feat: add support fastchat http bindings (#421)
* feat: add support fastchat http bindings Signed-off-by: Lei Wen <wenlei03@qiyi.com> Co-authored-by: Lei Wen <wenlei03@qiyi.com>release-0.2
parent
491e295a48
commit
e3c4a77fff
|
|
@ -0,0 +1,90 @@
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use reqwest::header;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
use tabby_inference::{TextGeneration, TextGenerationOptions};
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct Request {
|
||||||
|
model: String,
|
||||||
|
prompt: Vec<String>,
|
||||||
|
max_tokens: usize,
|
||||||
|
temperature: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct Response {
|
||||||
|
choices: Vec<Prediction>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct Prediction {
|
||||||
|
text: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct FastChatEngine {
|
||||||
|
client: reqwest::Client,
|
||||||
|
api_endpoint: String,
|
||||||
|
model_name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FastChatEngine {
|
||||||
|
pub fn create(api_endpoint: &str, model_name: &str, authorization: &str) -> Self {
|
||||||
|
let mut headers = reqwest::header::HeaderMap::new();
|
||||||
|
if authorization.len() > 0 {
|
||||||
|
headers.insert(
|
||||||
|
"Authorization",
|
||||||
|
header::HeaderValue::from_str(authorization)
|
||||||
|
.expect("Failed to create authorization header"),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.default_headers(headers)
|
||||||
|
.build()
|
||||||
|
.expect("Failed to construct HTTP client");
|
||||||
|
Self {
|
||||||
|
api_endpoint: api_endpoint.to_owned(),
|
||||||
|
model_name: model_name.to_owned(),
|
||||||
|
client,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn prompt_template() -> String {
|
||||||
|
"{prefix}<MID>{suffix}".to_owned()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl TextGeneration for FastChatEngine {
|
||||||
|
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||||
|
let _stop_sequences: Vec<String> =
|
||||||
|
options.stop_words.iter().map(|x| x.to_string()).collect();
|
||||||
|
|
||||||
|
let tokens: Vec<&str> = prompt.split("<MID>").collect();
|
||||||
|
let request = Request {
|
||||||
|
model: self.model_name.to_owned(),
|
||||||
|
prompt: vec![tokens[0].to_owned()],
|
||||||
|
max_tokens: options.max_decoding_length,
|
||||||
|
temperature: options.sampling_temperature,
|
||||||
|
};
|
||||||
|
|
||||||
|
// API Documentation: https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md
|
||||||
|
let resp = self
|
||||||
|
.client
|
||||||
|
.post(&self.api_endpoint)
|
||||||
|
.json(&request)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("Failed to making completion request");
|
||||||
|
|
||||||
|
if resp.status() != 200 {
|
||||||
|
let err: Value = resp.json().await.expect("Failed to parse response");
|
||||||
|
println!("Request failed: {}", err);
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
let resp: Response = resp.json().await.expect("Failed to parse response");
|
||||||
|
|
||||||
|
resp.choices[0].text[0].clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1 +1,2 @@
|
||||||
|
pub mod fastchat;
|
||||||
pub mod vertex_ai;
|
pub mod vertex_ai;
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ use std::{path::Path, sync::Arc};
|
||||||
|
|
||||||
use axum::{extract::State, Json};
|
use axum::{extract::State, Json};
|
||||||
use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};
|
use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};
|
||||||
use http_api_bindings::vertex_ai::VertexAIEngine;
|
use http_api_bindings::{fastchat::FastChatEngine, vertex_ai::VertexAIEngine};
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
@ -142,6 +142,15 @@ impl CompletionState {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_param(params: &Value, key: &str) -> String {
|
||||||
|
params
|
||||||
|
.get(key)
|
||||||
|
.expect(format!("Missing {} field", key).as_str())
|
||||||
|
.as_str()
|
||||||
|
.expect("Type unmatched")
|
||||||
|
.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
fn create_engine(args: &crate::serve::ServeArgs) -> (Box<dyn TextGeneration>, Option<String>) {
|
fn create_engine(args: &crate::serve::ServeArgs) -> (Box<dyn TextGeneration>, Option<String>) {
|
||||||
if args.device != super::Device::ExperimentalHttp {
|
if args.device != super::Device::ExperimentalHttp {
|
||||||
let model_dir = get_model_dir(&args.model);
|
let model_dir = get_model_dir(&args.model);
|
||||||
|
|
@ -152,28 +161,29 @@ fn create_engine(args: &crate::serve::ServeArgs) -> (Box<dyn TextGeneration>, Op
|
||||||
let params: Value =
|
let params: Value =
|
||||||
serdeconv::from_json_str(&args.model).expect("Failed to parse model string");
|
serdeconv::from_json_str(&args.model).expect("Failed to parse model string");
|
||||||
|
|
||||||
let kind = params
|
let kind = get_param(¶ms, "kind");
|
||||||
.get("kind")
|
|
||||||
.expect("Missing kind field")
|
|
||||||
.as_str()
|
|
||||||
.expect("Type unmatched");
|
|
||||||
|
|
||||||
if kind != "vertex-ai" {
|
if kind == "vertex-ai" {
|
||||||
fatal!("Only vertex_ai is supported for http backend");
|
let api_endpoint = get_param(¶ms, "api_endpoint");
|
||||||
}
|
let authorization = get_param(¶ms, "authorization");
|
||||||
|
let engine = Box::new(VertexAIEngine::create(
|
||||||
let api_endpoint = params
|
api_endpoint.as_str(),
|
||||||
.get("api_endpoint")
|
authorization.as_str(),
|
||||||
.expect("Missing api_endpoint field")
|
));
|
||||||
.as_str()
|
|
||||||
.expect("Type unmatched");
|
|
||||||
let authorization = params
|
|
||||||
.get("authorization")
|
|
||||||
.expect("Missing authorization field")
|
|
||||||
.as_str()
|
|
||||||
.expect("Type unmatched");
|
|
||||||
let engine = Box::new(VertexAIEngine::create(api_endpoint, authorization));
|
|
||||||
(engine, Some(VertexAIEngine::prompt_template()))
|
(engine, Some(VertexAIEngine::prompt_template()))
|
||||||
|
} else if kind == "fastchat" {
|
||||||
|
let model_name = get_param(¶ms, "model_name");
|
||||||
|
let api_endpoint = get_param(¶ms, "api_endpoint");
|
||||||
|
let authorization = get_param(¶ms, "authorization");
|
||||||
|
let engine = Box::new(FastChatEngine::create(
|
||||||
|
api_endpoint.as_str(),
|
||||||
|
model_name.as_str(),
|
||||||
|
authorization.as_str(),
|
||||||
|
));
|
||||||
|
(engine, Some(FastChatEngine::prompt_template()))
|
||||||
|
} else {
|
||||||
|
fatal!("Only vertex_ai and fastchat are supported for http backend");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -120,7 +120,7 @@ pub struct ServeArgs {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
|
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
|
||||||
fn should_download_ggml_files(device: &Device) -> bool {
|
fn should_download_ggml_files(_device: &Device) -> bool {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue