feat: add http api bindings (#410)

* feat: add http-api-bindings

* feat: add http-api-bindings

* handle max_input_length

* rename

* update

* update

* add examples/simple.rs

* update

* add default value for stop words

* update

* fix lint

* update
release-0.2
Meng Zhang 2023-09-09 11:59:42 +08:00 committed by GitHub
parent ad3b974d5c
commit 17397c8c8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 166 additions and 6 deletions

24
Cargo.lock generated
View File

@ -1233,6 +1233,18 @@ dependencies = [
"itoa",
]
[[package]]
name = "http-api-bindings"
version = "0.1.0"
dependencies = [
"async-trait",
"reqwest",
"serde",
"serde_json",
"tabby-inference",
"tokio",
]
[[package]]
name = "http-body"
version = "0.4.5"
@ -2660,9 +2672,9 @@ dependencies = [
[[package]]
name = "serde"
version = "1.0.163"
version = "1.0.171"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2113ab51b87a539ae008b5c6c02dc020ffa39afd2d83cffcb3f4eb2722cebec2"
checksum = "30e27d1e4fd7659406c492fd6cfaf2066ba8773de45ca75e855590f856dc34a9"
dependencies = [
"serde_derive",
]
@ -2679,9 +2691,9 @@ dependencies = [
[[package]]
name = "serde_derive"
version = "1.0.163"
version = "1.0.171"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c805777e3930c8883389c602315a24224bcc738b63905ef87cd1420353ea93e"
checksum = "389894603bd18c46fa56231694f8d827779c0951a667087194cf9de94ed24682"
dependencies = [
"proc-macro2",
"quote",
@ -2690,9 +2702,9 @@ dependencies = [
[[package]]
name = "serde_json"
version = "1.0.96"
version = "1.0.105"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1"
checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360"
dependencies = [
"itoa",
"ryu",

View File

@ -9,6 +9,7 @@ members = [
"crates/rust-cxx-cmake-bridge",
"crates/llama-cpp-bindings",
"crates/stop-words",
"crates/http-api-bindings",
]
[workspace.package]

View File

@ -0,0 +1,14 @@
[package]
name = "http-api-bindings"
version = "0.1.0"
edition = "2021"
[dependencies]
async-trait.workspace = true
reqwest = { workspace = true, features = ["json"] }
serde = { workspace = true, features = ["derive"] }
serde_json = "1.0.105"
tabby-inference = { version = "0.1.0", path = "../tabby-inference" }
[dev-dependencies]
tokio = { workspace = true, features = ["full"] }

View File

@ -0,0 +1,10 @@
## Usage
```bash
export MODEL_ID="code-gecko"
export PROJECT_ID="$(gcloud config get project)"
export API_ENDPOINT="https://us-central1-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/${MODEL_ID}:predict"
export AUTHORIZATION="Bearer $(gcloud auth print-access-token)"
cargo run --example simple
```

View File

@ -0,0 +1,20 @@
use std::env;
use http_api_bindings::vertex_ai::VertexAIEngine;
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
/// Example entry point: sends a short code prompt to the engine and prints
/// the prompt followed by the generated text.
///
/// The endpoint URL and the credential are read from the `API_ENDPOINT` and
/// `AUTHORIZATION` environment variables; the program panics if either one
/// is missing.
#[tokio::main]
async fn main() {
    // Connection settings come from the environment.
    let endpoint = env::var("API_ENDPOINT").expect("API_ENDPOINT not set");
    let credential = env::var("AUTHORIZATION").expect("AUTHORIZATION not set");
    let engine = VertexAIEngine::create(endpoint.as_str(), credential.as_str());

    // Low temperature and a short decoding budget keep the demo output small.
    let builder = TextGenerationOptionsBuilder::default()
        .sampling_temperature(0.1)
        .max_decoding_length(32);
    let options = builder.build().unwrap();

    let prompt = "def fib(n)";
    let completion = engine.generate(prompt, options).await;
    println!("{}{}", prompt, completion);
}

View File

@ -0,0 +1 @@
pub mod vertex_ai;

View File

@ -0,0 +1,99 @@
use async_trait::async_trait;
use reqwest::header;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tabby_inference::{TextGeneration, TextGenerationOptions};
/// JSON body that `generate` posts to the configured API endpoint.
#[derive(Serialize)]
struct Request {
// Prompts to complete; `generate` always sends exactly one instance.
instances: Vec<Instance>,
// Sampling and decoding settings applied to the request.
parameters: Parameters,
}
/// A single prompt entry inside a `Request`.
#[derive(Serialize)]
struct Instance {
// Text before the completion point; filled with the caller's prompt.
prefix: String,
// Text after the completion point; `generate` currently always sends `None`.
suffix: Option<String>,
}
/// Generation settings sent alongside the prompt.
///
/// Because of `rename_all = "camelCase"`, the fields are serialized as
/// `temperature`, `maxOutputTokens` and `stopSequences`.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct Parameters {
// Taken from `TextGenerationOptions::sampling_temperature`.
temperature: f32,
// Taken from `TextGenerationOptions::max_decoding_length`.
max_output_tokens: usize,
// Owned copies of `TextGenerationOptions::stop_words`.
stop_sequences: Vec<String>,
}
/// Body of a successful (HTTP 200) reply, as parsed by `generate`.
#[derive(Deserialize)]
struct Response {
// Candidate completions; `generate` only uses the first one.
predictions: Vec<Prediction>,
}
/// One candidate completion inside a `Response`.
#[derive(Deserialize)]
struct Prediction {
// Generated text that `generate` hands back to the caller.
content: String,
}
/// Text-generation engine that forwards prompts to an HTTP prediction
/// endpoint. Construct it with `VertexAIEngine::create`.
pub struct VertexAIEngine {
// HTTP client; `create` configures it to send the Authorization header on every request.
client: reqwest::Client,
// Full URL that completion requests are POSTed to.
api_endpoint: String,
}
impl VertexAIEngine {
    /// Builds an engine that POSTs completion requests to `api_endpoint`.
    ///
    /// `authorization` is sent verbatim as the `Authorization` header on
    /// every request (for example `"Bearer <token>"`).
    ///
    /// # Panics
    ///
    /// Panics if `authorization` is not a valid HTTP header value, or if the
    /// underlying HTTP client cannot be constructed.
    pub fn create(api_endpoint: &str, authorization: &str) -> Self {
        let mut auth_value = header::HeaderValue::from_str(authorization)
            .expect("Failed to create authorization header");
        // Mark the credential as sensitive so it is redacted from `Debug`
        // output of the header map instead of being printed in clear text.
        auth_value.set_sensitive(true);

        let mut headers = header::HeaderMap::new();
        headers.insert(header::AUTHORIZATION, auth_value);

        let client = reqwest::Client::builder()
            .default_headers(headers)
            .build()
            .expect("Failed to construct HTTP client");

        Self {
            api_endpoint: api_endpoint.to_owned(),
            client,
        }
    }
}
#[async_trait]
impl TextGeneration for VertexAIEngine {
    /// Requests a completion for `prompt` from the configured endpoint and
    /// returns the text of the first prediction.
    ///
    /// `options.sampling_temperature`, `options.max_decoding_length` and
    /// `options.stop_words` are forwarded to the API. If the service answers
    /// with an empty prediction list, an empty string is returned.
    ///
    /// # Panics
    ///
    /// Panics if the request cannot be sent or if a response body cannot be
    /// parsed as JSON of the expected shape.
    ///
    /// # Process exit
    ///
    /// On any status other than 200 the error payload is printed to stdout
    /// and the whole process exits with status 1.
    async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
        let stop_sequences: Vec<String> = options
            .stop_words
            .iter()
            .map(|word| word.to_string())
            .collect();

        let request = Request {
            instances: vec![Instance {
                prefix: prompt.to_owned(),
                suffix: None,
            }],
            // options.max_input_length is ignored.
            parameters: Parameters {
                temperature: options.sampling_temperature,
                max_output_tokens: options.max_decoding_length,
                stop_sequences,
            },
        };

        // API Documentation: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#code-completion-prompt-parameters
        let resp = self
            .client
            .post(&self.api_endpoint)
            .json(&request)
            .send()
            .await
            .expect("Failed to make completion request");

        if resp.status() != 200 {
            // NOTE(review): terminating the process from library code is
            // drastic; it is kept because the trait's `String` return type
            // leaves no way to report the failure to the caller.
            let err: Value = resp.json().await.expect("Failed to parse response");
            println!("Request failed: {}", err);
            std::process::exit(1);
        }

        let resp: Response = resp.json().await.expect("Failed to parse response");
        // Move the first prediction out instead of indexing and cloning: an
        // empty `predictions` array then yields an empty completion rather
        // than an out-of-bounds panic.
        resp.predictions
            .into_iter()
            .next()
            .map(|prediction| prediction.content)
            .unwrap_or_default()
    }
}

View File

@ -12,9 +12,12 @@ pub struct TextGenerationOptions {
#[builder(default = "1.0")]
pub sampling_temperature: f32,
#[builder(default = "&EMPTY_STOP_WORDS")]
pub stop_words: &'static Vec<&'static str>,
}
/// Shared empty stop-word list, used as the builder default for
/// `TextGenerationOptions::stop_words`.
static EMPTY_STOP_WORDS: Vec<&str> = Vec::new();
#[async_trait]
pub trait TextGeneration: Sync + Send {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String;