feat: add http api bindings (#410)
* feat: add http-api-bindings * feat: add http-api-bindings * handle max_input_length * rename * update * update * add examples/simple.rs * update * add default value for stop words * update * fix lint * update
release-0.2
parent
ad3b974d5c
commit
17397c8c8c
|
|
@ -1233,6 +1233,18 @@ dependencies = [
|
||||||
"itoa",
|
"itoa",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "http-api-bindings"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"async-trait",
|
||||||
|
"reqwest",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"tabby-inference",
|
||||||
|
"tokio",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "http-body"
|
name = "http-body"
|
||||||
version = "0.4.5"
|
version = "0.4.5"
|
||||||
|
|
@ -2660,9 +2672,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde"
|
name = "serde"
|
||||||
version = "1.0.163"
|
version = "1.0.171"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "2113ab51b87a539ae008b5c6c02dc020ffa39afd2d83cffcb3f4eb2722cebec2"
|
checksum = "30e27d1e4fd7659406c492fd6cfaf2066ba8773de45ca75e855590f856dc34a9"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"serde_derive",
|
"serde_derive",
|
||||||
]
|
]
|
||||||
|
|
@ -2679,9 +2691,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_derive"
|
name = "serde_derive"
|
||||||
version = "1.0.163"
|
version = "1.0.171"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8c805777e3930c8883389c602315a24224bcc738b63905ef87cd1420353ea93e"
|
checksum = "389894603bd18c46fa56231694f8d827779c0951a667087194cf9de94ed24682"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
|
|
@ -2690,9 +2702,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_json"
|
name = "serde_json"
|
||||||
version = "1.0.96"
|
version = "1.0.105"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1"
|
checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"itoa",
|
"itoa",
|
||||||
"ryu",
|
"ryu",
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ members = [
|
||||||
"crates/rust-cxx-cmake-bridge",
|
"crates/rust-cxx-cmake-bridge",
|
||||||
"crates/llama-cpp-bindings",
|
"crates/llama-cpp-bindings",
|
||||||
"crates/stop-words",
|
"crates/stop-words",
|
||||||
|
"crates/http-api-bindings",
|
||||||
]
|
]
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
[package]
|
||||||
|
name = "http-api-bindings"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
async-trait.workspace = true
|
||||||
|
reqwest = { workspace = true, features = ["json"] }
|
||||||
|
serde = { workspace = true, features = ["derive"] }
|
||||||
|
serde_json = "1.0.105"
|
||||||
|
tabby-inference = { version = "0.1.0", path = "../tabby-inference" }
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
tokio = { workspace = true, features = ["full"] }
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export MODEL_ID="code-gecko"
|
||||||
|
export PROJECT_ID="$(gcloud config get project)"
|
||||||
|
export API_ENDPOINT="https://us-central1-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/${MODEL_ID}:predict"
|
||||||
|
export AUTHORIZATION="Bearer $(gcloud auth print-access-token)"
|
||||||
|
|
||||||
|
cargo run --example simple
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,20 @@
|
||||||
|
use std::env;
|
||||||
|
|
||||||
|
use http_api_bindings::vertex_ai::VertexAIEngine;
|
||||||
|
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
let api_endpoint = env::var("API_ENDPOINT").expect("API_ENDPOINT not set");
|
||||||
|
let authorization = env::var("AUTHORIZATION").expect("AUTHORIZATION not set");
|
||||||
|
let engine = VertexAIEngine::create(&api_endpoint, &authorization);
|
||||||
|
|
||||||
|
let options = TextGenerationOptionsBuilder::default()
|
||||||
|
.sampling_temperature(0.1)
|
||||||
|
.max_decoding_length(32)
|
||||||
|
.build()
|
||||||
|
.unwrap();
|
||||||
|
let prompt = "def fib(n)";
|
||||||
|
let text = engine.generate(prompt, options).await;
|
||||||
|
println!("{}{}", prompt, text);
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
pub mod vertex_ai;
|
||||||
|
|
@ -0,0 +1,99 @@
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use reqwest::header;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
use tabby_inference::{TextGeneration, TextGenerationOptions};
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct Request {
|
||||||
|
instances: Vec<Instance>,
|
||||||
|
parameters: Parameters,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct Instance {
|
||||||
|
prefix: String,
|
||||||
|
suffix: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
struct Parameters {
|
||||||
|
temperature: f32,
|
||||||
|
max_output_tokens: usize,
|
||||||
|
stop_sequences: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct Response {
|
||||||
|
predictions: Vec<Prediction>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct Prediction {
|
||||||
|
content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct VertexAIEngine {
|
||||||
|
client: reqwest::Client,
|
||||||
|
api_endpoint: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VertexAIEngine {
|
||||||
|
pub fn create(api_endpoint: &str, authorization: &str) -> Self {
|
||||||
|
let mut headers = reqwest::header::HeaderMap::new();
|
||||||
|
headers.insert(
|
||||||
|
"Authorization",
|
||||||
|
header::HeaderValue::from_str(authorization)
|
||||||
|
.expect("Failed to create authorization header"),
|
||||||
|
);
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.default_headers(headers)
|
||||||
|
.build()
|
||||||
|
.expect("Failed to construct HTTP client");
|
||||||
|
Self {
|
||||||
|
api_endpoint: api_endpoint.to_owned(),
|
||||||
|
client,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl TextGeneration for VertexAIEngine {
|
||||||
|
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||||
|
let stop_sequences: Vec<String> =
|
||||||
|
options.stop_words.iter().map(|x| x.to_string()).collect();
|
||||||
|
|
||||||
|
let request = Request {
|
||||||
|
instances: vec![Instance {
|
||||||
|
prefix: prompt.to_owned(),
|
||||||
|
suffix: None,
|
||||||
|
}],
|
||||||
|
// options.max_input_length is ignored.
|
||||||
|
parameters: Parameters {
|
||||||
|
temperature: options.sampling_temperature,
|
||||||
|
max_output_tokens: options.max_decoding_length,
|
||||||
|
stop_sequences,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
// API Documentation: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#code-completion-prompt-parameters
|
||||||
|
let resp = self
|
||||||
|
.client
|
||||||
|
.post(&self.api_endpoint)
|
||||||
|
.json(&request)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("Failed to making completion request");
|
||||||
|
|
||||||
|
if resp.status() != 200 {
|
||||||
|
let err: Value = resp.json().await.expect("Failed to parse response");
|
||||||
|
println!("Request failed: {}", err);
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
let resp: Response = resp.json().await.expect("Failed to parse response");
|
||||||
|
|
||||||
|
resp.predictions[0].content.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -12,9 +12,12 @@ pub struct TextGenerationOptions {
|
||||||
#[builder(default = "1.0")]
|
#[builder(default = "1.0")]
|
||||||
pub sampling_temperature: f32,
|
pub sampling_temperature: f32,
|
||||||
|
|
||||||
|
#[builder(default = "&EMPTY_STOP_WORDS")]
|
||||||
pub stop_words: &'static Vec<&'static str>,
|
pub stop_words: &'static Vec<&'static str>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Shared empty list used as the builder default for `stop_words`, which is
// declared as a `&'static Vec<&'static str>` and therefore needs a static to borrow from.
static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait TextGeneration: Sync + Send {
|
pub trait TextGeneration: Sync + Send {
|
||||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String;
|
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue