diff --git a/crates/tabby/src/serve/engine.rs b/crates/tabby/src/serve/engine.rs index d46a450..b99b175 100644 --- a/crates/tabby/src/serve/engine.rs +++ b/crates/tabby/src/serve/engine.rs @@ -18,15 +18,17 @@ fn get_param(params: &Value, key: &str) -> String { .to_string() } -pub fn create_engine(args: &crate::serve::ServeArgs) -> (Box, Option) { +pub fn create_engine( + model: &str, + args: &crate::serve::ServeArgs, +) -> (Box, Option) { if args.device != super::Device::ExperimentalHttp { - let model_dir = get_model_dir(&args.model); + let model_dir = get_model_dir(model); let metadata = read_metadata(&model_dir); let engine = create_local_engine(args, &model_dir, &metadata); (engine, metadata.prompt_template) } else { - let params: Value = - serdeconv::from_json_str(&args.model).expect("Failed to parse model string"); + let params: Value = serdeconv::from_json_str(model).expect("Failed to parse model string"); let kind = get_param(&params, "kind"); diff --git a/crates/tabby/src/serve/health.rs b/crates/tabby/src/serve/health.rs index 2ef5e6c..401255a 100644 --- a/crates/tabby/src/serve/health.rs +++ b/crates/tabby/src/serve/health.rs @@ -10,6 +10,8 @@ use utoipa::ToSchema; #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] pub struct HealthState { model: String, + #[serde(skip_serializing_if = "Option::is_none")] + instruct_model: Option, device: String, compute_type: String, arch: String, @@ -30,6 +32,7 @@ impl HealthState { Self { model: args.model.clone(), + instruct_model: args.instruct_model.clone(), device: args.device.to_string(), compute_type: args.compute_type.to_string(), arch: ARCH.to_string(), diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index 9d7cace..2ce45dc 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -105,10 +105,15 @@ pub enum ComputeType { #[derive(Args)] pub struct ServeArgs { - /// Model id for serving. + /// Model id for `/completion` API endpoint.
#[clap(long)] model: String, + /// Model id for `/generate` and `/generate_stream` API endpoints. + /// If not set, `model` will be loaded for the purpose. + #[clap(long)] + instruct_model: Option, + #[clap(long, default_value_t = 8080)] port: u16, @@ -142,16 +147,11 @@ fn should_download_ggml_files(device: &Device) -> bool { pub async fn main(config: &Config, args: &ServeArgs) { valid_args(args); - let downloader = Downloader::new(&args.model, /* prefer_local_file= */ true); if args.device != Device::ExperimentalHttp { - let handler = |err| fatal!("Failed to fetch model '{}' due to '{}'", args.model, err,); - let download_result = if should_download_ggml_files(&args.device) { - downloader.download_ggml_files().await - } else { - downloader.download_ctranslate2_files().await - }; - - download_result.unwrap_or_else(handler); + download_model(&args.model, &args.device).await; + if let Some(instruct_model) = &args.instruct_model { + download_model(instruct_model, &args.device).await; + } } else { warn!("HTTP device is unstable and does not comply with semver expectations.") } @@ -177,8 +177,14 @@ pub async fn main(config: &Config, args: &ServeArgs) { } fn api_router(args: &ServeArgs, config: &Config) -> Router { - let (engine, prompt_template) = create_engine(args); + let (engine, prompt_template) = create_engine(&args.model, args); let engine = Arc::new(engine); + let instruct_engine = if let Some(instruct_model) = &args.instruct_model { + Arc::new(create_engine(instruct_model, args).0) + } else { + engine.clone() + }; + Router::new() .route("/v1/events", routing::post(events::log_event)) .route( @@ -193,13 +199,15 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router { ) .route( "/v1beta/generate", - routing::post(generate::generate) - .with_state(Arc::new(generate::GenerateState::new(engine.clone()))), + routing::post(generate::generate).with_state(Arc::new(generate::GenerateState::new( + instruct_engine.clone(), + ))), ) .route( "/v1beta/generate_stream", 
- routing::post(generate::generate_stream) - .with_state(Arc::new(generate::GenerateState::new(engine.clone()))), + routing::post(generate::generate_stream).with_state(Arc::new( + generate::GenerateState::new(instruct_engine.clone()), + )), ) .layer(CorsLayer::permissive()) .layer(opentelemetry_tracing_layer()) @@ -272,3 +280,15 @@ fn add_proxy_server( doc } + +async fn download_model(model: &str, device: &Device) { + let downloader = Downloader::new(model, /* prefer_local_file= */ true); + let handler = |err| fatal!("Failed to fetch model '{}' due to '{}'", model, err,); + let download_result = if should_download_ggml_files(device) { + downloader.download_ggml_files().await + } else { + downloader.download_ctranslate2_files().await + }; + + download_result.unwrap_or_else(handler); +}