feat: add chat_template field in tabby.json

release-0.2
Meng Zhang 2023-10-03 11:46:05 -07:00
parent 7fc76228f7
commit 0e5128e8fb
3 changed files with 43 additions and 11 deletions

View File

@ -21,10 +21,10 @@ pub struct ChatState {
} }
impl ChatState { impl ChatState {
pub fn new(engine: Arc<Box<dyn TextGeneration>>, prompt_template: String) -> Self { pub fn new(engine: Arc<Box<dyn TextGeneration>>, chat_template: String) -> Self {
Self { Self {
engine, engine,
prompt_builder: ChatPromptBuilder::new(prompt_template), prompt_builder: ChatPromptBuilder::new(chat_template),
} }
} }
} }

View File

@ -21,12 +21,18 @@ fn get_param(params: &Value, key: &str) -> String {
pub fn create_engine( pub fn create_engine(
model: &str, model: &str,
args: &crate::serve::ServeArgs, args: &crate::serve::ServeArgs,
) -> (Box<dyn TextGeneration>, Option<String>) { ) -> (Box<dyn TextGeneration>, EngineInfo) {
if args.device != super::Device::ExperimentalHttp { if args.device != super::Device::ExperimentalHttp {
let model_dir = get_model_dir(model); let model_dir = get_model_dir(model);
let metadata = read_metadata(&model_dir); let metadata = read_metadata(&model_dir);
let engine = create_local_engine(args, &model_dir, &metadata); let engine = create_local_engine(args, &model_dir, &metadata);
(engine, metadata.prompt_template) (
engine,
EngineInfo {
prompt_template: metadata.prompt_template,
chat_template: metadata.chat_template,
},
)
} else { } else {
let params: Value = serdeconv::from_json_str(model).expect("Failed to parse model string"); let params: Value = serdeconv::from_json_str(model).expect("Failed to parse model string");
@ -39,7 +45,13 @@ pub fn create_engine(
api_endpoint.as_str(), api_endpoint.as_str(),
authorization.as_str(), authorization.as_str(),
)); ));
(engine, Some(VertexAIEngine::prompt_template())) (
engine,
EngineInfo {
prompt_template: Some(VertexAIEngine::prompt_template()),
chat_template: None,
},
)
} else if kind == "fastchat" { } else if kind == "fastchat" {
let model_name = get_param(&params, "model_name"); let model_name = get_param(&params, "model_name");
let api_endpoint = get_param(&params, "api_endpoint"); let api_endpoint = get_param(&params, "api_endpoint");
@ -49,13 +61,24 @@ pub fn create_engine(
model_name.as_str(), model_name.as_str(),
authorization.as_str(), authorization.as_str(),
)); ));
(engine, Some(FastChatEngine::prompt_template())) (
engine,
EngineInfo {
prompt_template: Some(FastChatEngine::prompt_template()),
chat_template: None,
},
)
} else { } else {
fatal!("Only vertex_ai and fastchat are supported for http backend"); fatal!("Only vertex_ai and fastchat are supported for http backend");
} }
} }
} }
pub struct EngineInfo {
pub prompt_template: Option<String>,
pub chat_template: Option<String>,
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))] #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn create_local_engine( fn create_local_engine(
args: &crate::serve::ServeArgs, args: &crate::serve::ServeArgs,
@ -121,6 +144,7 @@ fn get_model_dir(model: &str) -> ModelDir {
struct Metadata { struct Metadata {
auto_model: String, auto_model: String,
prompt_template: Option<String>, prompt_template: Option<String>,
chat_template: Option<String>,
} }
fn read_metadata(model_dir: &ModelDir) -> Metadata { fn read_metadata(model_dir: &ModelDir) -> Metadata {

View File

@ -25,7 +25,10 @@ use tracing::{info, warn};
use utoipa::{openapi::ServerBuilder, OpenApi}; use utoipa::{openapi::ServerBuilder, OpenApi};
use utoipa_swagger_ui::SwaggerUi; use utoipa_swagger_ui::SwaggerUi;
use self::{engine::create_engine, health::HealthState}; use self::{
engine::{create_engine, EngineInfo},
health::HealthState,
};
use crate::fatal; use crate::fatal;
#[derive(OpenApi)] #[derive(OpenApi)]
@ -188,19 +191,24 @@ pub async fn main(config: &Config, args: &ServeArgs) {
fn api_router(args: &ServeArgs, config: &Config) -> Router { fn api_router(args: &ServeArgs, config: &Config) -> Router {
let completion_state = { let completion_state = {
let (engine, prompt_template) = create_engine(&args.model, args); let (
engine,
EngineInfo {
prompt_template, ..
},
) = create_engine(&args.model, args);
let engine = Arc::new(engine); let engine = Arc::new(engine);
let state = completions::CompletionState::new(engine.clone(), prompt_template, config); let state = completions::CompletionState::new(engine.clone(), prompt_template, config);
Arc::new(state) Arc::new(state)
}; };
let chat_state = if let Some(chat_model) = &args.chat_model { let chat_state = if let Some(chat_model) = &args.chat_model {
let (engine, prompt_template) = create_engine(chat_model, args); let (engine, EngineInfo { chat_template, .. }) = create_engine(chat_model, args);
let Some(prompt_template) = prompt_template else { let Some(chat_template) = chat_template else {
panic!("Chat model requires specifying prompt template"); panic!("Chat model requires specifying prompt template");
}; };
let engine = Arc::new(engine); let engine = Arc::new(engine);
let state = chat::ChatState::new(engine, prompt_template); let state = chat::ChatState::new(engine, chat_template);
Some(Arc::new(state)) Some(Arc::new(state))
} else { } else {
None None