diff --git a/crates/tabby/src/serve/chat.rs b/crates/tabby/src/serve/chat.rs
index c6a7ea0..9bd98d1 100644
--- a/crates/tabby/src/serve/chat.rs
+++ b/crates/tabby/src/serve/chat.rs
@@ -21,10 +21,10 @@ pub struct ChatState {
 }
 
 impl ChatState {
-    pub fn new(engine: Arc<Box<dyn TextGeneration>>, prompt_template: String) -> Self {
+    pub fn new(engine: Arc<Box<dyn TextGeneration>>, chat_template: String) -> Self {
         Self {
             engine,
-            prompt_builder: ChatPromptBuilder::new(prompt_template),
+            prompt_builder: ChatPromptBuilder::new(chat_template),
         }
     }
 }
diff --git a/crates/tabby/src/serve/engine.rs b/crates/tabby/src/serve/engine.rs
index b99b175..665e2f6 100644
--- a/crates/tabby/src/serve/engine.rs
+++ b/crates/tabby/src/serve/engine.rs
@@ -21,12 +21,18 @@ fn get_param(params: &Value, key: &str) -> String {
 pub fn create_engine(
     model: &str,
     args: &crate::serve::ServeArgs,
-) -> (Box<dyn TextGeneration>, Option<String>) {
+) -> (Box<dyn TextGeneration>, EngineInfo) {
     if args.device != super::Device::ExperimentalHttp {
         let model_dir = get_model_dir(model);
         let metadata = read_metadata(&model_dir);
         let engine = create_local_engine(args, &model_dir, &metadata);
-        (engine, metadata.prompt_template)
+        (
+            engine,
+            EngineInfo {
+                prompt_template: metadata.prompt_template,
+                chat_template: metadata.chat_template,
+            },
+        )
     } else {
         let params: Value =
             serdeconv::from_json_str(model).expect("Failed to parse model string");
@@ -39,7 +45,13 @@ pub fn create_engine(
                 api_endpoint.as_str(),
                 authorization.as_str(),
             ));
-            (engine, Some(VertexAIEngine::prompt_template()))
+            (
+                engine,
+                EngineInfo {
+                    prompt_template: Some(VertexAIEngine::prompt_template()),
+                    chat_template: None,
+                },
+            )
         } else if kind == "fastchat" {
             let model_name = get_param(&params, "model_name");
             let api_endpoint = get_param(&params, "api_endpoint");
@@ -49,13 +61,24 @@ pub fn create_engine(
                 model_name.as_str(),
                 authorization.as_str(),
             ));
-            (engine, Some(FastChatEngine::prompt_template()))
+            (
+                engine,
+                EngineInfo {
+                    prompt_template: Some(FastChatEngine::prompt_template()),
+                    chat_template: None,
+                },
+            )
         } else {
             fatal!("Only vertex_ai and fastchat are supported for http backend");
         }
     }
 }
 
+pub struct EngineInfo {
+    pub prompt_template: Option<String>,
+    pub chat_template: Option<String>,
+}
+
 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
 fn create_local_engine(
     args: &crate::serve::ServeArgs,
@@ -121,6 +144,7 @@ fn get_model_dir(model: &str) -> ModelDir {
 struct Metadata {
     auto_model: String,
     prompt_template: Option<String>,
+    chat_template: Option<String>,
 }
 
 fn read_metadata(model_dir: &ModelDir) -> Metadata {
diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs
index f975875..40be980 100644
--- a/crates/tabby/src/serve/mod.rs
+++ b/crates/tabby/src/serve/mod.rs
@@ -25,7 +25,10 @@ use tracing::{info, warn};
 use utoipa::{openapi::ServerBuilder, OpenApi};
 use utoipa_swagger_ui::SwaggerUi;
 
-use self::{engine::create_engine, health::HealthState};
+use self::{
+    engine::{create_engine, EngineInfo},
+    health::HealthState,
+};
 use crate::fatal;
 
 #[derive(OpenApi)]
@@ -188,19 +191,24 @@ pub async fn main(config: &Config, args: &ServeArgs) {
 fn api_router(args: &ServeArgs, config: &Config) -> Router {
     let completion_state = {
-        let (engine, prompt_template) = create_engine(&args.model, args);
+        let (
+            engine,
+            EngineInfo {
+                prompt_template, ..
+            },
+        ) = create_engine(&args.model, args);
         let engine = Arc::new(engine);
         let state = completions::CompletionState::new(engine.clone(), prompt_template, config);
         Arc::new(state)
     };
 
     let chat_state = if let Some(chat_model) = &args.chat_model {
-        let (engine, prompt_template) = create_engine(chat_model, args);
-        let Some(prompt_template) = prompt_template else {
+        let (engine, EngineInfo { chat_template, .. }) = create_engine(chat_model, args);
+        let Some(chat_template) = chat_template else {
             panic!("Chat model requires specifying prompt template");
         };
         let engine = Arc::new(engine);
-        let state = chat::ChatState::new(engine, prompt_template);
+        let state = chat::ChatState::new(engine, chat_template);
         Some(Arc::new(state))
     } else {
         None
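For context on the pattern this diff adopts: create_engine's second return value changes from a bare Option<String> to an EngineInfo struct, so each call site destructures only the template it needs, and the chat path can fail fast with let-else when no chat template exists. Below is a minimal, self-contained sketch of that pattern; engine_info and the template strings are illustrative stand-ins, not tabby source.

// Standalone sketch, not tabby code: engine_info stands in for the second
// value returned by create_engine after this refactor.
pub struct EngineInfo {
    pub prompt_template: Option<String>,
    pub chat_template: Option<String>,
}

fn engine_info() -> EngineInfo {
    EngineInfo {
        // Hypothetical fill-in-the-middle and chat templates for illustration.
        prompt_template: Some("<PRE>{prefix}<SUF>{suffix}<MID>".to_string()),
        chat_template: Some("<|user|>{message}<|assistant|>".to_string()),
    }
}

fn main() {
    // Completion path: destructure only the field it needs; `..` ignores the rest.
    let EngineInfo { prompt_template, .. } = engine_info();
    println!("completion template: {:?}", prompt_template);

    // Chat path: a chat model must carry a chat template, so fail fast otherwise.
    let EngineInfo { chat_template, .. } = engine_info();
    let Some(chat_template) = chat_template else {
        panic!("Chat model requires specifying chat template");
    };
    println!("chat template: {chat_template}");
}

Compared with the old tuple-of-Option return, the struct makes each field's role explicit at the call site and lets future fields be added without touching callers that ignore them via `..`.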