feat: add chat_template field in tabby.json
parent
7fc76228f7
commit
0e5128e8fb
|
|
@ -21,10 +21,10 @@ pub struct ChatState {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ChatState {
|
impl ChatState {
|
||||||
pub fn new(engine: Arc<Box<dyn TextGeneration>>, prompt_template: String) -> Self {
|
pub fn new(engine: Arc<Box<dyn TextGeneration>>, chat_template: String) -> Self {
|
||||||
Self {
|
Self {
|
||||||
engine,
|
engine,
|
||||||
prompt_builder: ChatPromptBuilder::new(prompt_template),
|
prompt_builder: ChatPromptBuilder::new(chat_template),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -21,12 +21,18 @@ fn get_param(params: &Value, key: &str) -> String {
|
||||||
pub fn create_engine(
|
pub fn create_engine(
|
||||||
model: &str,
|
model: &str,
|
||||||
args: &crate::serve::ServeArgs,
|
args: &crate::serve::ServeArgs,
|
||||||
) -> (Box<dyn TextGeneration>, Option<String>) {
|
) -> (Box<dyn TextGeneration>, EngineInfo) {
|
||||||
if args.device != super::Device::ExperimentalHttp {
|
if args.device != super::Device::ExperimentalHttp {
|
||||||
let model_dir = get_model_dir(model);
|
let model_dir = get_model_dir(model);
|
||||||
let metadata = read_metadata(&model_dir);
|
let metadata = read_metadata(&model_dir);
|
||||||
let engine = create_local_engine(args, &model_dir, &metadata);
|
let engine = create_local_engine(args, &model_dir, &metadata);
|
||||||
(engine, metadata.prompt_template)
|
(
|
||||||
|
engine,
|
||||||
|
EngineInfo {
|
||||||
|
prompt_template: metadata.prompt_template,
|
||||||
|
chat_template: metadata.chat_template,
|
||||||
|
},
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
let params: Value = serdeconv::from_json_str(model).expect("Failed to parse model string");
|
let params: Value = serdeconv::from_json_str(model).expect("Failed to parse model string");
|
||||||
|
|
||||||
|
|
@ -39,7 +45,13 @@ pub fn create_engine(
|
||||||
api_endpoint.as_str(),
|
api_endpoint.as_str(),
|
||||||
authorization.as_str(),
|
authorization.as_str(),
|
||||||
));
|
));
|
||||||
(engine, Some(VertexAIEngine::prompt_template()))
|
(
|
||||||
|
engine,
|
||||||
|
EngineInfo {
|
||||||
|
prompt_template: Some(VertexAIEngine::prompt_template()),
|
||||||
|
chat_template: None,
|
||||||
|
},
|
||||||
|
)
|
||||||
} else if kind == "fastchat" {
|
} else if kind == "fastchat" {
|
||||||
let model_name = get_param(¶ms, "model_name");
|
let model_name = get_param(¶ms, "model_name");
|
||||||
let api_endpoint = get_param(¶ms, "api_endpoint");
|
let api_endpoint = get_param(¶ms, "api_endpoint");
|
||||||
|
|
@ -49,13 +61,24 @@ pub fn create_engine(
|
||||||
model_name.as_str(),
|
model_name.as_str(),
|
||||||
authorization.as_str(),
|
authorization.as_str(),
|
||||||
));
|
));
|
||||||
(engine, Some(FastChatEngine::prompt_template()))
|
(
|
||||||
|
engine,
|
||||||
|
EngineInfo {
|
||||||
|
prompt_template: Some(FastChatEngine::prompt_template()),
|
||||||
|
chat_template: None,
|
||||||
|
},
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
fatal!("Only vertex_ai and fastchat are supported for http backend");
|
fatal!("Only vertex_ai and fastchat are supported for http backend");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct EngineInfo {
|
||||||
|
pub prompt_template: Option<String>,
|
||||||
|
pub chat_template: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
|
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
|
||||||
fn create_local_engine(
|
fn create_local_engine(
|
||||||
args: &crate::serve::ServeArgs,
|
args: &crate::serve::ServeArgs,
|
||||||
|
|
@ -121,6 +144,7 @@ fn get_model_dir(model: &str) -> ModelDir {
|
||||||
struct Metadata {
|
struct Metadata {
|
||||||
auto_model: String,
|
auto_model: String,
|
||||||
prompt_template: Option<String>,
|
prompt_template: Option<String>,
|
||||||
|
chat_template: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_metadata(model_dir: &ModelDir) -> Metadata {
|
fn read_metadata(model_dir: &ModelDir) -> Metadata {
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,10 @@ use tracing::{info, warn};
|
||||||
use utoipa::{openapi::ServerBuilder, OpenApi};
|
use utoipa::{openapi::ServerBuilder, OpenApi};
|
||||||
use utoipa_swagger_ui::SwaggerUi;
|
use utoipa_swagger_ui::SwaggerUi;
|
||||||
|
|
||||||
use self::{engine::create_engine, health::HealthState};
|
use self::{
|
||||||
|
engine::{create_engine, EngineInfo},
|
||||||
|
health::HealthState,
|
||||||
|
};
|
||||||
use crate::fatal;
|
use crate::fatal;
|
||||||
|
|
||||||
#[derive(OpenApi)]
|
#[derive(OpenApi)]
|
||||||
|
|
@ -188,19 +191,24 @@ pub async fn main(config: &Config, args: &ServeArgs) {
|
||||||
|
|
||||||
fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
let completion_state = {
|
let completion_state = {
|
||||||
let (engine, prompt_template) = create_engine(&args.model, args);
|
let (
|
||||||
|
engine,
|
||||||
|
EngineInfo {
|
||||||
|
prompt_template, ..
|
||||||
|
},
|
||||||
|
) = create_engine(&args.model, args);
|
||||||
let engine = Arc::new(engine);
|
let engine = Arc::new(engine);
|
||||||
let state = completions::CompletionState::new(engine.clone(), prompt_template, config);
|
let state = completions::CompletionState::new(engine.clone(), prompt_template, config);
|
||||||
Arc::new(state)
|
Arc::new(state)
|
||||||
};
|
};
|
||||||
|
|
||||||
let chat_state = if let Some(chat_model) = &args.chat_model {
|
let chat_state = if let Some(chat_model) = &args.chat_model {
|
||||||
let (engine, prompt_template) = create_engine(chat_model, args);
|
let (engine, EngineInfo { chat_template, .. }) = create_engine(chat_model, args);
|
||||||
let Some(prompt_template) = prompt_template else {
|
let Some(chat_template) = chat_template else {
|
||||||
panic!("Chat model requires specifying prompt template");
|
panic!("Chat model requires specifying prompt template");
|
||||||
};
|
};
|
||||||
let engine = Arc::new(engine);
|
let engine = Arc::new(engine);
|
||||||
let state = chat::ChatState::new(engine, prompt_template);
|
let state = chat::ChatState::new(engine, chat_template);
|
||||||
Some(Arc::new(state))
|
Some(Arc::new(state))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue