refactor: move http engine creation to its sub crates (#524)

r0.3
Meng Zhang 2023-10-09 10:37:04 -07:00 committed by GitHub
parent 41e48dc119
commit 0f8ee7f589
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 50 additions and 70 deletions

View File

@ -1,20 +0,0 @@
use std::env;
use http_api_bindings::vertex_ai::VertexAIEngine;
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
#[tokio::main]
async fn main() {
let api_endpoint = env::var("API_ENDPOINT").expect("API_ENDPOINT not set");
let authorization = env::var("AUTHORIZATION").expect("AUTHORIZATION not set");
let engine = VertexAIEngine::create(&api_endpoint, &authorization);
let options = TextGenerationOptionsBuilder::default()
.sampling_temperature(0.1)
.max_decoding_length(32)
.build()
.unwrap();
let prompt = "def fib(n)";
let text = engine.generate(prompt, options).await;
println!("{}{}", prompt, text);
}

View File

@ -1,2 +1,42 @@
pub mod fastchat;
pub mod vertex_ai;
mod fastchat;
mod vertex_ai;
use fastchat::FastChatEngine;
use serde_json::Value;
use tabby_inference::TextGeneration;
use vertex_ai::VertexAIEngine;
pub fn create(model: &str) -> (Box<dyn TextGeneration>, String) {
let params = serde_json::from_str(model).expect("Failed to parse model string");
let kind = get_param(&params, "kind");
if kind == "vertex-ai" {
let api_endpoint = get_param(&params, "api_endpoint");
let authorization = get_param(&params, "authorization");
let engine = Box::new(VertexAIEngine::create(
api_endpoint.as_str(),
authorization.as_str(),
));
(engine, VertexAIEngine::prompt_template())
} else if kind == "fastchat" {
let model_name = get_param(&params, "model_name");
let api_endpoint = get_param(&params, "api_endpoint");
let authorization = get_param(&params, "authorization");
let engine = Box::new(FastChatEngine::create(
api_endpoint.as_str(),
model_name.as_str(),
authorization.as_str(),
));
(engine, FastChatEngine::prompt_template())
} else {
panic!("Only vertex_ai and fastchat are supported for http backend");
}
}
fn get_param(params: &Value, key: &str) -> String {
params
.get(key)
.unwrap_or_else(|| panic!("Missing {} field", key))
.as_str()
.expect("Type unmatched")
.to_string()
}

View File

@ -1,23 +1,12 @@
use std::path::Path;
use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};
use http_api_bindings::{fastchat::FastChatEngine, vertex_ai::VertexAIEngine};
use serde::Deserialize;
use serde_json::Value;
use tabby_common::path::ModelDir;
use tabby_inference::TextGeneration;
use crate::fatal;
fn get_param(params: &Value, key: &str) -> String {
params
.get(key)
.unwrap_or_else(|| panic!("Missing {} field", key))
.as_str()
.expect("Type unmatched")
.to_string()
}
pub fn create_engine(
model: &str,
args: &crate::serve::ServeArgs,
@ -34,43 +23,14 @@ pub fn create_engine(
},
)
} else {
let params: Value = serdeconv::from_json_str(model).expect("Failed to parse model string");
let kind = get_param(&params, "kind");
if kind == "vertex-ai" {
let api_endpoint = get_param(&params, "api_endpoint");
let authorization = get_param(&params, "authorization");
let engine = Box::new(VertexAIEngine::create(
api_endpoint.as_str(),
authorization.as_str(),
));
let (engine, prompt_template) = http_api_bindings::create(model);
(
engine,
EngineInfo {
prompt_template: Some(VertexAIEngine::prompt_template()),
prompt_template: Some(prompt_template),
chat_template: None,
},
)
} else if kind == "fastchat" {
let model_name = get_param(&params, "model_name");
let api_endpoint = get_param(&params, "api_endpoint");
let authorization = get_param(&params, "authorization");
let engine = Box::new(FastChatEngine::create(
api_endpoint.as_str(),
model_name.as_str(),
authorization.as_str(),
));
(
engine,
EngineInfo {
prompt_template: Some(FastChatEngine::prompt_template()),
chat_template: None,
},
)
} else {
fatal!("Only vertex_ai and fastchat are supported for http backend");
}
}
}