refactor: move http engine creation to its sub crates (#524)
parent
41e48dc119
commit
0f8ee7f589
|
|
@ -1,20 +0,0 @@
|
||||||
use std::env;
|
|
||||||
|
|
||||||
use http_api_bindings::vertex_ai::VertexAIEngine;
|
|
||||||
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
|
|
||||||
|
|
||||||
#[tokio::main]
|
|
||||||
async fn main() {
|
|
||||||
let api_endpoint = env::var("API_ENDPOINT").expect("API_ENDPOINT not set");
|
|
||||||
let authorization = env::var("AUTHORIZATION").expect("AUTHORIZATION not set");
|
|
||||||
let engine = VertexAIEngine::create(&api_endpoint, &authorization);
|
|
||||||
|
|
||||||
let options = TextGenerationOptionsBuilder::default()
|
|
||||||
.sampling_temperature(0.1)
|
|
||||||
.max_decoding_length(32)
|
|
||||||
.build()
|
|
||||||
.unwrap();
|
|
||||||
let prompt = "def fib(n)";
|
|
||||||
let text = engine.generate(prompt, options).await;
|
|
||||||
println!("{}{}", prompt, text);
|
|
||||||
}
|
|
||||||
|
|
@ -1,2 +1,42 @@
|
||||||
pub mod fastchat;
|
mod fastchat;
|
||||||
pub mod vertex_ai;
|
mod vertex_ai;
|
||||||
|
|
||||||
|
use fastchat::FastChatEngine;
|
||||||
|
use serde_json::Value;
|
||||||
|
use tabby_inference::TextGeneration;
|
||||||
|
use vertex_ai::VertexAIEngine;
|
||||||
|
|
||||||
|
pub fn create(model: &str) -> (Box<dyn TextGeneration>, String) {
|
||||||
|
let params = serde_json::from_str(model).expect("Failed to parse model string");
|
||||||
|
let kind = get_param(¶ms, "kind");
|
||||||
|
if kind == "vertex-ai" {
|
||||||
|
let api_endpoint = get_param(¶ms, "api_endpoint");
|
||||||
|
let authorization = get_param(¶ms, "authorization");
|
||||||
|
let engine = Box::new(VertexAIEngine::create(
|
||||||
|
api_endpoint.as_str(),
|
||||||
|
authorization.as_str(),
|
||||||
|
));
|
||||||
|
(engine, VertexAIEngine::prompt_template())
|
||||||
|
} else if kind == "fastchat" {
|
||||||
|
let model_name = get_param(¶ms, "model_name");
|
||||||
|
let api_endpoint = get_param(¶ms, "api_endpoint");
|
||||||
|
let authorization = get_param(¶ms, "authorization");
|
||||||
|
let engine = Box::new(FastChatEngine::create(
|
||||||
|
api_endpoint.as_str(),
|
||||||
|
model_name.as_str(),
|
||||||
|
authorization.as_str(),
|
||||||
|
));
|
||||||
|
(engine, FastChatEngine::prompt_template())
|
||||||
|
} else {
|
||||||
|
panic!("Only vertex_ai and fastchat are supported for http backend");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_param(params: &Value, key: &str) -> String {
|
||||||
|
params
|
||||||
|
.get(key)
|
||||||
|
.unwrap_or_else(|| panic!("Missing {} field", key))
|
||||||
|
.as_str()
|
||||||
|
.expect("Type unmatched")
|
||||||
|
.to_string()
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,12 @@
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};
|
use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};
|
||||||
use http_api_bindings::{fastchat::FastChatEngine, vertex_ai::VertexAIEngine};
|
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::Value;
|
|
||||||
use tabby_common::path::ModelDir;
|
use tabby_common::path::ModelDir;
|
||||||
use tabby_inference::TextGeneration;
|
use tabby_inference::TextGeneration;
|
||||||
|
|
||||||
use crate::fatal;
|
use crate::fatal;
|
||||||
|
|
||||||
fn get_param(params: &Value, key: &str) -> String {
|
|
||||||
params
|
|
||||||
.get(key)
|
|
||||||
.unwrap_or_else(|| panic!("Missing {} field", key))
|
|
||||||
.as_str()
|
|
||||||
.expect("Type unmatched")
|
|
||||||
.to_string()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn create_engine(
|
pub fn create_engine(
|
||||||
model: &str,
|
model: &str,
|
||||||
args: &crate::serve::ServeArgs,
|
args: &crate::serve::ServeArgs,
|
||||||
|
|
@ -34,43 +23,14 @@ pub fn create_engine(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
let params: Value = serdeconv::from_json_str(model).expect("Failed to parse model string");
|
let (engine, prompt_template) = http_api_bindings::create(model);
|
||||||
|
(
|
||||||
let kind = get_param(¶ms, "kind");
|
engine,
|
||||||
|
EngineInfo {
|
||||||
if kind == "vertex-ai" {
|
prompt_template: Some(prompt_template),
|
||||||
let api_endpoint = get_param(¶ms, "api_endpoint");
|
chat_template: None,
|
||||||
let authorization = get_param(¶ms, "authorization");
|
},
|
||||||
let engine = Box::new(VertexAIEngine::create(
|
)
|
||||||
api_endpoint.as_str(),
|
|
||||||
authorization.as_str(),
|
|
||||||
));
|
|
||||||
(
|
|
||||||
engine,
|
|
||||||
EngineInfo {
|
|
||||||
prompt_template: Some(VertexAIEngine::prompt_template()),
|
|
||||||
chat_template: None,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
} else if kind == "fastchat" {
|
|
||||||
let model_name = get_param(¶ms, "model_name");
|
|
||||||
let api_endpoint = get_param(¶ms, "api_endpoint");
|
|
||||||
let authorization = get_param(¶ms, "authorization");
|
|
||||||
let engine = Box::new(FastChatEngine::create(
|
|
||||||
api_endpoint.as_str(),
|
|
||||||
model_name.as_str(),
|
|
||||||
authorization.as_str(),
|
|
||||||
));
|
|
||||||
(
|
|
||||||
engine,
|
|
||||||
EngineInfo {
|
|
||||||
prompt_template: Some(FastChatEngine::prompt_template()),
|
|
||||||
chat_template: None,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
fatal!("Only vertex_ai and fastchat are supported for http backend");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue