diff --git a/Cargo.lock b/Cargo.lock index 52e8e15..9999339 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1233,6 +1233,18 @@ dependencies = [ "itoa", ] +[[package]] +name = "http-api-bindings" +version = "0.1.0" +dependencies = [ + "async-trait", + "reqwest", + "serde", + "serde_json", + "tabby-inference", + "tokio", +] + [[package]] name = "http-body" version = "0.4.5" @@ -2660,9 +2672,9 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.163" +version = "1.0.171" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2113ab51b87a539ae008b5c6c02dc020ffa39afd2d83cffcb3f4eb2722cebec2" +checksum = "30e27d1e4fd7659406c492fd6cfaf2066ba8773de45ca75e855590f856dc34a9" dependencies = [ "serde_derive", ] @@ -2679,9 +2691,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.163" +version = "1.0.171" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c805777e3930c8883389c602315a24224bcc738b63905ef87cd1420353ea93e" +checksum = "389894603bd18c46fa56231694f8d827779c0951a667087194cf9de94ed24682" dependencies = [ "proc-macro2", "quote", @@ -2690,9 +2702,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.96" +version = "1.0.105" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1" +checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360" dependencies = [ "itoa", "ryu", diff --git a/Cargo.toml b/Cargo.toml index be7706a..7e04bd2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ members = [ "crates/rust-cxx-cmake-bridge", "crates/llama-cpp-bindings", "crates/stop-words", + "crates/http-api-bindings", ] [workspace.package] diff --git a/crates/http-api-bindings/Cargo.toml b/crates/http-api-bindings/Cargo.toml new file mode 100644 index 0000000..492e046 --- /dev/null +++ b/crates/http-api-bindings/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "http-api-bindings" +version 
= "0.1.0" +edition = "2021" + +[dependencies] +async-trait.workspace = true +reqwest = { workspace = true, features = ["json"] } +serde = { workspace = true, features = ["derive"] } +serde_json = "1.0.105" +tabby-inference = { version = "0.1.0", path = "../tabby-inference" } + +[dev-dependencies] +tokio = { workspace = true, features = ["full"] } diff --git a/crates/http-api-bindings/README.md b/crates/http-api-bindings/README.md new file mode 100644 index 0000000..81ad900 --- /dev/null +++ b/crates/http-api-bindings/README.md @@ -0,0 +1,10 @@ +## Usage + +```bash +export MODEL_ID="code-gecko" +export PROJECT_ID="$(gcloud config get project)" +export API_ENDPOINT="https://us-central1-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/${MODEL_ID}:predict" +export AUTHORIZATION="Bearer $(gcloud auth print-access-token)" + +cargo run --example simple +``` diff --git a/crates/http-api-bindings/examples/simple.rs b/crates/http-api-bindings/examples/simple.rs new file mode 100644 index 0000000..2c357e1 --- /dev/null +++ b/crates/http-api-bindings/examples/simple.rs @@ -0,0 +1,20 @@ +use std::env; + +use http_api_bindings::vertex_ai::VertexAIEngine; +use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder}; + +#[tokio::main] +async fn main() { + let api_endpoint = env::var("API_ENDPOINT").expect("API_ENDPOINT not set"); + let authorization = env::var("AUTHORIZATION").expect("AUTHORIZATION not set"); + let engine = VertexAIEngine::create(&api_endpoint, &authorization); + + let options = TextGenerationOptionsBuilder::default() + .sampling_temperature(0.1) + .max_decoding_length(32) + .build() + .unwrap(); + let prompt = "def fib(n)"; + let text = engine.generate(prompt, options).await; + println!("{}{}", prompt, text); +} diff --git a/crates/http-api-bindings/src/lib.rs b/crates/http-api-bindings/src/lib.rs new file mode 100644 index 0000000..63d4c4f --- /dev/null +++ b/crates/http-api-bindings/src/lib.rs @@ -0,0 +1 
@@ +pub mod vertex_ai; diff --git a/crates/http-api-bindings/src/vertex_ai.rs b/crates/http-api-bindings/src/vertex_ai.rs new file mode 100644 index 0000000..e1f7e8a --- /dev/null +++ b/crates/http-api-bindings/src/vertex_ai.rs @@ -0,0 +1,99 @@ +use async_trait::async_trait; +use reqwest::header; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use tabby_inference::{TextGeneration, TextGenerationOptions}; + +#[derive(Serialize)] +struct Request { + instances: Vec, + parameters: Parameters, +} + +#[derive(Serialize)] +struct Instance { + prefix: String, + suffix: Option, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct Parameters { + temperature: f32, + max_output_tokens: usize, + stop_sequences: Vec, +} + +#[derive(Deserialize)] +struct Response { + predictions: Vec, +} + +#[derive(Deserialize)] +struct Prediction { + content: String, +} + +pub struct VertexAIEngine { + client: reqwest::Client, + api_endpoint: String, +} + +impl VertexAIEngine { + pub fn create(api_endpoint: &str, authorization: &str) -> Self { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + "Authorization", + header::HeaderValue::from_str(authorization) + .expect("Failed to create authorization header"), + ); + let client = reqwest::Client::builder() + .default_headers(headers) + .build() + .expect("Failed to construct HTTP client"); + Self { + api_endpoint: api_endpoint.to_owned(), + client, + } + } +} + +#[async_trait] +impl TextGeneration for VertexAIEngine { + async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { + let stop_sequences: Vec = + options.stop_words.iter().map(|x| x.to_string()).collect(); + + let request = Request { + instances: vec![Instance { + prefix: prompt.to_owned(), + suffix: None, + }], + // options.max_input_length is ignored. 
+ parameters: Parameters { + temperature: options.sampling_temperature, + max_output_tokens: options.max_decoding_length, + stop_sequences, + }, + }; + + // API Documentation: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#code-completion-prompt-parameters + let resp = self + .client + .post(&self.api_endpoint) + .json(&request) + .send() + .await + .expect("Failed to make completion request"); + + if resp.status() != 200 { + let err: Value = resp.json().await.expect("Failed to parse response"); + println!("Request failed: {}", err); + std::process::exit(1); + } + + let resp: Response = resp.json().await.expect("Failed to parse response"); + + resp.predictions[0].content.clone() + } +} diff --git a/crates/tabby-inference/src/lib.rs b/crates/tabby-inference/src/lib.rs index c822034..1622dc0 100644 --- a/crates/tabby-inference/src/lib.rs +++ b/crates/tabby-inference/src/lib.rs @@ -12,9 +12,12 @@ pub struct TextGenerationOptions { #[builder(default = "1.0")] pub sampling_temperature: f32, + #[builder(default = "&EMPTY_STOP_WORDS")] pub stop_words: &'static Vec<&'static str>, } + +static EMPTY_STOP_WORDS: Vec<&'static str> = vec![]; + #[async_trait] pub trait TextGeneration: Sync + Send { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String;