feat: add chat_template field in tabby.json
parent
7fc76228f7
commit
0e5128e8fb
|
|
@ -21,10 +21,10 @@ pub struct ChatState {
|
|||
}
|
||||
|
||||
impl ChatState {
|
||||
pub fn new(engine: Arc<Box<dyn TextGeneration>>, prompt_template: String) -> Self {
|
||||
pub fn new(engine: Arc<Box<dyn TextGeneration>>, chat_template: String) -> Self {
|
||||
Self {
|
||||
engine,
|
||||
prompt_builder: ChatPromptBuilder::new(prompt_template),
|
||||
prompt_builder: ChatPromptBuilder::new(chat_template),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,12 +21,18 @@ fn get_param(params: &Value, key: &str) -> String {
|
|||
pub fn create_engine(
|
||||
model: &str,
|
||||
args: &crate::serve::ServeArgs,
|
||||
) -> (Box<dyn TextGeneration>, Option<String>) {
|
||||
) -> (Box<dyn TextGeneration>, EngineInfo) {
|
||||
if args.device != super::Device::ExperimentalHttp {
|
||||
let model_dir = get_model_dir(model);
|
||||
let metadata = read_metadata(&model_dir);
|
||||
let engine = create_local_engine(args, &model_dir, &metadata);
|
||||
(engine, metadata.prompt_template)
|
||||
(
|
||||
engine,
|
||||
EngineInfo {
|
||||
prompt_template: metadata.prompt_template,
|
||||
chat_template: metadata.chat_template,
|
||||
},
|
||||
)
|
||||
} else {
|
||||
let params: Value = serdeconv::from_json_str(model).expect("Failed to parse model string");
|
||||
|
||||
|
|
@ -39,7 +45,13 @@ pub fn create_engine(
|
|||
api_endpoint.as_str(),
|
||||
authorization.as_str(),
|
||||
));
|
||||
(engine, Some(VertexAIEngine::prompt_template()))
|
||||
(
|
||||
engine,
|
||||
EngineInfo {
|
||||
prompt_template: Some(VertexAIEngine::prompt_template()),
|
||||
chat_template: None,
|
||||
},
|
||||
)
|
||||
} else if kind == "fastchat" {
|
||||
let model_name = get_param(¶ms, "model_name");
|
||||
let api_endpoint = get_param(¶ms, "api_endpoint");
|
||||
|
|
@ -49,13 +61,24 @@ pub fn create_engine(
|
|||
model_name.as_str(),
|
||||
authorization.as_str(),
|
||||
));
|
||||
(engine, Some(FastChatEngine::prompt_template()))
|
||||
(
|
||||
engine,
|
||||
EngineInfo {
|
||||
prompt_template: Some(FastChatEngine::prompt_template()),
|
||||
chat_template: None,
|
||||
},
|
||||
)
|
||||
} else {
|
||||
fatal!("Only vertex_ai and fastchat are supported for http backend");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EngineInfo {
|
||||
pub prompt_template: Option<String>,
|
||||
pub chat_template: Option<String>,
|
||||
}
|
||||
|
||||
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
|
||||
fn create_local_engine(
|
||||
args: &crate::serve::ServeArgs,
|
||||
|
|
@ -121,6 +144,7 @@ fn get_model_dir(model: &str) -> ModelDir {
|
|||
struct Metadata {
|
||||
auto_model: String,
|
||||
prompt_template: Option<String>,
|
||||
chat_template: Option<String>,
|
||||
}
|
||||
|
||||
fn read_metadata(model_dir: &ModelDir) -> Metadata {
|
||||
|
|
|
|||
|
|
@ -25,7 +25,10 @@ use tracing::{info, warn};
|
|||
use utoipa::{openapi::ServerBuilder, OpenApi};
|
||||
use utoipa_swagger_ui::SwaggerUi;
|
||||
|
||||
use self::{engine::create_engine, health::HealthState};
|
||||
use self::{
|
||||
engine::{create_engine, EngineInfo},
|
||||
health::HealthState,
|
||||
};
|
||||
use crate::fatal;
|
||||
|
||||
#[derive(OpenApi)]
|
||||
|
|
@ -188,19 +191,24 @@ pub async fn main(config: &Config, args: &ServeArgs) {
|
|||
|
||||
fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||
let completion_state = {
|
||||
let (engine, prompt_template) = create_engine(&args.model, args);
|
||||
let (
|
||||
engine,
|
||||
EngineInfo {
|
||||
prompt_template, ..
|
||||
},
|
||||
) = create_engine(&args.model, args);
|
||||
let engine = Arc::new(engine);
|
||||
let state = completions::CompletionState::new(engine.clone(), prompt_template, config);
|
||||
Arc::new(state)
|
||||
};
|
||||
|
||||
let chat_state = if let Some(chat_model) = &args.chat_model {
|
||||
let (engine, prompt_template) = create_engine(chat_model, args);
|
||||
let Some(prompt_template) = prompt_template else {
|
||||
let (engine, EngineInfo { chat_template, .. }) = create_engine(chat_model, args);
|
||||
let Some(chat_template) = chat_template else {
|
||||
panic!("Chat model requires specifying prompt template");
|
||||
};
|
||||
let engine = Arc::new(engine);
|
||||
let state = chat::ChatState::new(engine, prompt_template);
|
||||
let state = chat::ChatState::new(engine, chat_template);
|
||||
Some(Arc::new(state))
|
||||
} else {
|
||||
None
|
||||
|
|
|
|||
Loading…
Reference in New Issue