From e3c4a77fff5ed1aab4a42d3968e2e546549a53af Mon Sep 17 00:00:00 2001 From: leiwen83 Date: Sun, 10 Sep 2023 22:17:58 +0800 Subject: [PATCH] feat: add support fastchat http bindings (#421) * feat: add support fastchat http bindings Signed-off-by: Lei Wen Co-authored-by: Lei Wen --- crates/http-api-bindings/src/fastchat.rs | 90 ++++++++++++++++++++++++ crates/http-api-bindings/src/lib.rs | 1 + crates/tabby/src/serve/completions.rs | 52 ++++++++------ crates/tabby/src/serve/mod.rs | 2 +- 4 files changed, 123 insertions(+), 22 deletions(-) create mode 100644 crates/http-api-bindings/src/fastchat.rs diff --git a/crates/http-api-bindings/src/fastchat.rs b/crates/http-api-bindings/src/fastchat.rs new file mode 100644 index 0000000..a577b90 --- /dev/null +++ b/crates/http-api-bindings/src/fastchat.rs @@ -0,0 +1,90 @@ +use async_trait::async_trait; +use reqwest::header; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use tabby_inference::{TextGeneration, TextGenerationOptions}; + +#[derive(Serialize)] +struct Request { + model: String, + prompt: Vec<String>, + max_tokens: usize, + temperature: f32, +} + +#[derive(Deserialize)] +struct Response { + choices: Vec<Prediction>, +} + +#[derive(Deserialize)] +struct Prediction { + text: Vec<String>, +} + +pub struct FastChatEngine { + client: reqwest::Client, + api_endpoint: String, + model_name: String, +} + +impl FastChatEngine { + pub fn create(api_endpoint: &str, model_name: &str, authorization: &str) -> Self { + let mut headers = reqwest::header::HeaderMap::new(); + if authorization.len() > 0 { + headers.insert( + "Authorization", + header::HeaderValue::from_str(authorization) + .expect("Failed to create authorization header"), + ); + } + let client = reqwest::Client::builder() + .default_headers(headers) + .build() + .expect("Failed to construct HTTP client"); + Self { + api_endpoint: api_endpoint.to_owned(), + model_name: model_name.to_owned(), + client, + } + } + + pub fn prompt_template() -> String { + "{prefix}<MID>{suffix}".to_owned() +
} +} + +#[async_trait] +impl TextGeneration for FastChatEngine { + async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { + let _stop_sequences: Vec<String> = + options.stop_words.iter().map(|x| x.to_string()).collect(); + + let tokens: Vec<&str> = prompt.split("<MID>").collect(); + let request = Request { + model: self.model_name.to_owned(), + prompt: vec![tokens[0].to_owned()], + max_tokens: options.max_decoding_length, + temperature: options.sampling_temperature, + }; + + // API Documentation: https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md + let resp = self + .client + .post(&self.api_endpoint) + .json(&request) + .send() + .await + .expect("Failed to making completion request"); + + if resp.status() != 200 { + let err: Value = resp.json().await.expect("Failed to parse response"); + println!("Request failed: {}", err); + std::process::exit(1); + } + + let resp: Response = resp.json().await.expect("Failed to parse response"); + + resp.choices[0].text[0].clone() + } +} diff --git a/crates/http-api-bindings/src/lib.rs b/crates/http-api-bindings/src/lib.rs index 63d4c4f..f28bac6 100644 --- a/crates/http-api-bindings/src/lib.rs +++ b/crates/http-api-bindings/src/lib.rs @@ -1 +1,2 @@ +pub mod fastchat; pub mod vertex_ai; diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index 10e558f..8b345f9 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -5,7 +5,7 @@ use std::{path::Path, sync::Arc}; use axum::{extract::State, Json}; use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder}; -use http_api_bindings::vertex_ai::VertexAIEngine; +use http_api_bindings::{fastchat::FastChatEngine, vertex_ai::VertexAIEngine}; use hyper::StatusCode; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -142,6 +142,15 @@ impl CompletionState { } } +fn get_param(params: &Value, key: &str) -> String { + params + .get(key) +
.expect(format!("Missing {} field", key).as_str()) + .as_str() + .expect("Type unmatched") + .to_string() +} + fn create_engine(args: &crate::serve::ServeArgs) -> (Box<dyn TextGeneration>, Option<String>) { if args.device != super::Device::ExperimentalHttp { let model_dir = get_model_dir(&args.model); @@ -152,28 +161,29 @@ fn create_engine(args: &crate::serve::ServeArgs) -> (Box<dyn TextGeneration>, Op let params: Value = serdeconv::from_json_str(&args.model).expect("Failed to parse model string"); - let kind = params - .get("kind") - .expect("Missing kind field") - .as_str() - .expect("Type unmatched"); + let kind = get_param(&params, "kind"); - if kind != "vertex-ai" { - fatal!("Only vertex_ai is supported for http backend"); + if kind == "vertex-ai" { + let api_endpoint = get_param(&params, "api_endpoint"); + let authorization = get_param(&params, "authorization"); + let engine = Box::new(VertexAIEngine::create( + api_endpoint.as_str(), + authorization.as_str(), + )); + (engine, Some(VertexAIEngine::prompt_template())) + } else if kind == "fastchat" { + let model_name = get_param(&params, "model_name"); + let api_endpoint = get_param(&params, "api_endpoint"); + let authorization = get_param(&params, "authorization"); + let engine = Box::new(FastChatEngine::create( + api_endpoint.as_str(), + model_name.as_str(), + authorization.as_str(), + )); + (engine, Some(FastChatEngine::prompt_template())) + } else { + fatal!("Only vertex_ai and fastchat are supported for http backend"); } - - let api_endpoint = params - .get("api_endpoint") - .expect("Missing api_endpoint field") - .as_str() - .expect("Type unmatched"); - let authorization = params - .get("authorization") - .expect("Missing authorization field") - .as_str() - .expect("Type unmatched"); - let engine = Box::new(VertexAIEngine::create(api_endpoint, authorization)); - (engine, Some(VertexAIEngine::prompt_template())) } } diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index 2e0535d..64625ee 100644 --- a/crates/tabby/src/serve/mod.rs +++ 
b/crates/tabby/src/serve/mod.rs @@ -120,7 +120,7 @@ pub struct ServeArgs { } #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))] -fn should_download_ggml_files(device: &Device) -> bool { +fn should_download_ggml_files(_device: &Device) -> bool { false }