feat: add chat_template field in tabby.json

release-0.2
Meng Zhang 2023-10-03 11:46:05 -07:00
parent 7fc76228f7
commit 0e5128e8fb
3 changed files with 43 additions and 11 deletions

View File

@ -21,10 +21,10 @@ pub struct ChatState {
}
impl ChatState {
pub fn new(engine: Arc<Box<dyn TextGeneration>>, prompt_template: String) -> Self {
pub fn new(engine: Arc<Box<dyn TextGeneration>>, chat_template: String) -> Self {
Self {
engine,
prompt_builder: ChatPromptBuilder::new(prompt_template),
prompt_builder: ChatPromptBuilder::new(chat_template),
}
}
}

View File

@ -21,12 +21,18 @@ fn get_param(params: &Value, key: &str) -> String {
pub fn create_engine(
model: &str,
args: &crate::serve::ServeArgs,
) -> (Box<dyn TextGeneration>, Option<String>) {
) -> (Box<dyn TextGeneration>, EngineInfo) {
if args.device != super::Device::ExperimentalHttp {
let model_dir = get_model_dir(model);
let metadata = read_metadata(&model_dir);
let engine = create_local_engine(args, &model_dir, &metadata);
(engine, metadata.prompt_template)
(
engine,
EngineInfo {
prompt_template: metadata.prompt_template,
chat_template: metadata.chat_template,
},
)
} else {
let params: Value = serdeconv::from_json_str(model).expect("Failed to parse model string");
@ -39,7 +45,13 @@ pub fn create_engine(
api_endpoint.as_str(),
authorization.as_str(),
));
(engine, Some(VertexAIEngine::prompt_template()))
(
engine,
EngineInfo {
prompt_template: Some(VertexAIEngine::prompt_template()),
chat_template: None,
},
)
} else if kind == "fastchat" {
let model_name = get_param(&params, "model_name");
let api_endpoint = get_param(&params, "api_endpoint");
@ -49,13 +61,24 @@ pub fn create_engine(
model_name.as_str(),
authorization.as_str(),
));
(engine, Some(FastChatEngine::prompt_template()))
(
engine,
EngineInfo {
prompt_template: Some(FastChatEngine::prompt_template()),
chat_template: None,
},
)
} else {
fatal!("Only vertex_ai and fastchat are supported for http backend");
}
}
}
/// Templates returned by `create_engine` alongside the engine itself.
///
/// Replaces the previous bare `Option<String>` second tuple element so that
/// both a completion prompt template and a chat template can be carried back
/// to the caller (`api_router` destructures whichever field it needs).
pub struct EngineInfo {
// Template for code-completion prompts. For http backends this is the
// backend's own template (e.g. `VertexAIEngine::prompt_template()`);
// for local models it comes from the model metadata.
pub prompt_template: Option<String>,
// Template for chat prompts; populated from local model metadata only —
// both http backends (vertex_ai, fastchat) set this to `None`.
pub chat_template: Option<String>,
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn create_local_engine(
args: &crate::serve::ServeArgs,
@ -121,6 +144,7 @@ fn get_model_dir(model: &str) -> ModelDir {
/// Model metadata read from the model directory by `read_metadata`
/// (per the commit title, presumably parsed from `tabby.json` —
/// derive attributes such as `Deserialize` are outside this view; confirm).
struct Metadata {
// Identifier of the model implementation to instantiate.
auto_model: String,
// Optional template for code-completion prompts.
prompt_template: Option<String>,
// Optional template for chat prompts — the field added by this commit.
chat_template: Option<String>,
}
fn read_metadata(model_dir: &ModelDir) -> Metadata {

View File

@ -25,7 +25,10 @@ use tracing::{info, warn};
use utoipa::{openapi::ServerBuilder, OpenApi};
use utoipa_swagger_ui::SwaggerUi;
use self::{engine::create_engine, health::HealthState};
use self::{
engine::{create_engine, EngineInfo},
health::HealthState,
};
use crate::fatal;
#[derive(OpenApi)]
@ -188,19 +191,24 @@ pub async fn main(config: &Config, args: &ServeArgs) {
fn api_router(args: &ServeArgs, config: &Config) -> Router {
let completion_state = {
let (engine, prompt_template) = create_engine(&args.model, args);
let (
engine,
EngineInfo {
prompt_template, ..
},
) = create_engine(&args.model, args);
let engine = Arc::new(engine);
let state = completions::CompletionState::new(engine.clone(), prompt_template, config);
Arc::new(state)
};
let chat_state = if let Some(chat_model) = &args.chat_model {
let (engine, prompt_template) = create_engine(chat_model, args);
let Some(prompt_template) = prompt_template else {
let (engine, EngineInfo { chat_template, .. }) = create_engine(chat_model, args);
let Some(chat_template) = chat_template else {
panic!("Chat model requires specifying prompt template");
};
let engine = Arc::new(engine);
let state = chat::ChatState::new(engine, prompt_template);
let state = chat::ChatState::new(engine, chat_template);
Some(Arc::new(state))
} else {
None