feat: add http api bindings (#410)
* feat: add http-api-bindings
* feat: add http-api-bindings
* handle max_input_length
* rename
* update
* update
* add examples/simple.rs
* update
* add default value for stop words
* update
* fix lint
* update

release-0.2
parent
ad3b974d5c
commit
17397c8c8c
|
|
@ -1233,6 +1233,18 @@ dependencies = [
|
|||
"itoa",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-api-bindings"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tabby-inference",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-body"
|
||||
version = "0.4.5"
|
||||
|
|
@ -2660,9 +2672,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.163"
|
||||
version = "1.0.171"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2113ab51b87a539ae008b5c6c02dc020ffa39afd2d83cffcb3f4eb2722cebec2"
|
||||
checksum = "30e27d1e4fd7659406c492fd6cfaf2066ba8773de45ca75e855590f856dc34a9"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
|
@ -2679,9 +2691,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.163"
|
||||
version = "1.0.171"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8c805777e3930c8883389c602315a24224bcc738b63905ef87cd1420353ea93e"
|
||||
checksum = "389894603bd18c46fa56231694f8d827779c0951a667087194cf9de94ed24682"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
|
@ -2690,9 +2702,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.96"
|
||||
version = "1.0.105"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1"
|
||||
checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"ryu",
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ members = [
|
|||
"crates/rust-cxx-cmake-bridge",
|
||||
"crates/llama-cpp-bindings",
|
||||
"crates/stop-words",
|
||||
"crates/http-api-bindings",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,14 @@
|
|||
[package]
|
||||
name = "http-api-bindings"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
async-trait.workspace = true
|
||||
reqwest = { workspace = true, features = ["json"] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = "1.0.105"
|
||||
tabby-inference = { version = "0.1.0", path = "../tabby-inference" }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
## Usage
|
||||
|
||||
```bash
|
||||
export MODEL_ID="code-gecko"
|
||||
export PROJECT_ID="$(gcloud config get project)"
|
||||
export API_ENDPOINT="https://us-central1-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/${MODEL_ID}:predict"
|
||||
export AUTHORIZATION="Bearer $(gcloud auth print-access-token)"
|
||||
|
||||
cargo run --example simple
|
||||
```
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
use std::env;
|
||||
|
||||
use http_api_bindings::vertex_ai::VertexAIEngine;
|
||||
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let api_endpoint = env::var("API_ENDPOINT").expect("API_ENDPOINT not set");
|
||||
let authorization = env::var("AUTHORIZATION").expect("AUTHORIZATION not set");
|
||||
let engine = VertexAIEngine::create(&api_endpoint, &authorization);
|
||||
|
||||
let options = TextGenerationOptionsBuilder::default()
|
||||
.sampling_temperature(0.1)
|
||||
.max_decoding_length(32)
|
||||
.build()
|
||||
.unwrap();
|
||||
let prompt = "def fib(n)";
|
||||
let text = engine.generate(prompt, options).await;
|
||||
println!("{}{}", prompt, text);
|
||||
}
|
||||
|
|
@ -0,0 +1 @@
|
|||
/// Bindings for Google Vertex AI's code-completion HTTP API.
pub mod vertex_ai;
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
use async_trait::async_trait;
|
||||
use reqwest::header;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use tabby_inference::{TextGeneration, TextGenerationOptions};
|
||||
|
||||
/// JSON request body for the Vertex AI `:predict` endpoint.
#[derive(Serialize)]
struct Request {
    // The API accepts a batch of inputs; this client always sends exactly one.
    instances: Vec<Instance>,
    parameters: Parameters,
}
|
||||
|
||||
/// A single completion input: the code before (and optionally after) the cursor.
#[derive(Serialize)]
struct Instance {
    prefix: String,
    // Always `None` in this client; serde serializes it as JSON `null`.
    suffix: Option<String>,
}
|
||||
|
||||
/// Generation parameters. `rename_all = "camelCase"` produces the wire
/// field names the API expects (`maxOutputTokens`, `stopSequences`).
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct Parameters {
    temperature: f32,
    max_output_tokens: usize,
    stop_sequences: Vec<String>,
}
|
||||
|
||||
/// Top-level JSON response from the `:predict` endpoint.
#[derive(Deserialize)]
struct Response {
    predictions: Vec<Prediction>,
}
|
||||
|
||||
/// One generated completion returned by the API.
#[derive(Deserialize)]
struct Prediction {
    content: String,
}
|
||||
|
||||
/// Text-generation engine backed by Google Vertex AI's HTTP `:predict` API.
pub struct VertexAIEngine {
    // Client pre-configured with the Authorization header (see `create`).
    client: reqwest::Client,
    api_endpoint: String,
}
|
||||
|
||||
impl VertexAIEngine {
|
||||
pub fn create(api_endpoint: &str, authorization: &str) -> Self {
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
"Authorization",
|
||||
header::HeaderValue::from_str(authorization)
|
||||
.expect("Failed to create authorization header"),
|
||||
);
|
||||
let client = reqwest::Client::builder()
|
||||
.default_headers(headers)
|
||||
.build()
|
||||
.expect("Failed to construct HTTP client");
|
||||
Self {
|
||||
api_endpoint: api_endpoint.to_owned(),
|
||||
client,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TextGeneration for VertexAIEngine {
|
||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||
let stop_sequences: Vec<String> =
|
||||
options.stop_words.iter().map(|x| x.to_string()).collect();
|
||||
|
||||
let request = Request {
|
||||
instances: vec![Instance {
|
||||
prefix: prompt.to_owned(),
|
||||
suffix: None,
|
||||
}],
|
||||
// options.max_input_length is ignored.
|
||||
parameters: Parameters {
|
||||
temperature: options.sampling_temperature,
|
||||
max_output_tokens: options.max_decoding_length,
|
||||
stop_sequences,
|
||||
},
|
||||
};
|
||||
|
||||
// API Documentation: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#code-completion-prompt-parameters
|
||||
let resp = self
|
||||
.client
|
||||
.post(&self.api_endpoint)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to making completion request");
|
||||
|
||||
if resp.status() != 200 {
|
||||
let err: Value = resp.json().await.expect("Failed to parse response");
|
||||
println!("Request failed: {}", err);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let resp: Response = resp.json().await.expect("Failed to parse response");
|
||||
|
||||
resp.predictions[0].content.clone()
|
||||
}
|
||||
}
|
||||
|
|
@ -12,9 +12,12 @@ pub struct TextGenerationOptions {
|
|||
#[builder(default = "1.0")]
|
||||
pub sampling_temperature: f32,
|
||||
|
||||
#[builder(default = "&EMPTY_STOP_WORDS")]
|
||||
pub stop_words: &'static Vec<&'static str>,
|
||||
}
|
||||
|
||||
// Default for `TextGenerationOptions::stop_words`: no stop words.
static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];
|
||||
|
||||
#[async_trait]
|
||||
pub trait TextGeneration: Sync + Send {
|
||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String;
|
||||
|
|
|
|||
Loading…
Reference in New Issue