diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs
index 6c4c184..2a5d3ae 100644
--- a/crates/tabby/src/serve/mod.rs
+++ b/crates/tabby/src/serve/mod.rs
@@ -1,6 +1,5 @@
 mod chat;
 mod completions;
-mod engine;
 mod events;
 mod health;
 mod search;
@@ -24,14 +23,15 @@ use tracing::info;
 use utoipa::OpenApi;
 use utoipa_swagger_ui::SwaggerUi;
 
-use self::{
-    engine::{create_engine, EngineInfo},
-    health::HealthState,
-};
+use self::health::HealthState;
 use crate::{
     api::{Hit, HitDocument, SearchResponse},
     fatal,
-    services::{chat::ChatService, completions::CompletionService},
+    services::{
+        chat::ChatService,
+        completions::CompletionService,
+        model::{load_text_generation, PromptInfo},
+    },
 };
 
 #[derive(OpenApi)]
@@ -93,17 +93,17 @@ pub enum Device {
 
 impl Device {
     #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
-    fn ggml_use_gpu(&self) -> bool {
+    pub fn ggml_use_gpu(&self) -> bool {
         *self == Device::Metal
     }
 
     #[cfg(feature = "cuda")]
-    fn ggml_use_gpu(&self) -> bool {
+    pub fn ggml_use_gpu(&self) -> bool {
         *self == Device::Cuda
     }
 
     #[cfg(not(any(all(target_os = "macos", target_arch = "aarch64"), feature = "cuda")))]
-    fn ggml_use_gpu(&self) -> bool {
+    pub fn ggml_use_gpu(&self) -> bool {
         false
     }
 }
@@ -178,16 +178,17 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
     let completion_state = {
         let (
             engine,
-            EngineInfo {
+            PromptInfo {
                 prompt_template, ..
             },
-        ) = create_engine(&args.model, args).await;
+        ) = load_text_generation(&args.model, &args.device, args.parallelism).await;
         let state = CompletionService::new(engine.clone(), code.clone(), prompt_template);
         Arc::new(state)
     };
 
     let chat_state = if let Some(chat_model) = &args.chat_model {
-        let (engine, EngineInfo { chat_template, .. }) = create_engine(chat_model, args).await;
+        let (engine, PromptInfo { chat_template, .. }) =
+            load_text_generation(chat_model, &args.device, args.parallelism).await;
         let Some(chat_template) = chat_template else {
             panic!("Chat model requires specifying prompt template");
         };
diff --git a/crates/tabby/src/services/mod.rs b/crates/tabby/src/services/mod.rs
index a0bba28..9ef0b4d 100644
--- a/crates/tabby/src/services/mod.rs
+++ b/crates/tabby/src/services/mod.rs
@@ -1,3 +1,4 @@
 pub mod chat;
 pub mod code;
 pub mod completions;
+pub mod model;
diff --git a/crates/tabby/src/serve/engine.rs b/crates/tabby/src/services/model.rs
similarity index 76%
rename from crates/tabby/src/serve/engine.rs
rename to crates/tabby/src/services/model.rs
index ce35914..21f9545 100644
--- a/crates/tabby/src/serve/engine.rs
+++ b/crates/tabby/src/services/model.rs
@@ -4,18 +4,19 @@ use serde::Deserialize;
 use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH};
 use tabby_inference::TextGeneration;
 
-use crate::fatal;
+use crate::{fatal, serve::Device};
 
-pub async fn create_engine(
+pub async fn load_text_generation(
     model_id: &str,
-    args: &crate::serve::ServeArgs,
-) -> (Arc<dyn TextGeneration>, EngineInfo) {
+    device: &Device,
+    parallelism: u8,
+) -> (Arc<dyn TextGeneration>, PromptInfo) {
     #[cfg(feature = "experimental-http")]
-    if args.device == crate::serve::Device::ExperimentalHttp {
+    if device == &Device::ExperimentalHttp {
         let (engine, prompt_template) = http_api_bindings::create(model_id);
         return (
             engine,
-            EngineInfo {
+            PromptInfo {
                 prompt_template: Some(prompt_template),
                 chat_template: None,
             },
@@ -26,21 +27,21 @@
         let path = PathBuf::from(model_id);
         let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
         let engine = create_ggml_engine(
-            &args.device,
+            device,
             model_path.display().to_string().as_str(),
-            args.parallelism,
+            parallelism,
         );
-        let engine_info = EngineInfo::read(path.join("tabby.json"));
+        let engine_info = PromptInfo::read(path.join("tabby.json"));
         (Arc::new(engine), engine_info)
     } else {
         let (registry, name) = parse_model_id(model_id);
         let registry = ModelRegistry::new(registry).await;
         let model_path = registry.get_model_path(name).display().to_string();
         let model_info = registry.get_model_info(name);
-        let engine = create_ggml_engine(&args.device, &model_path, args.parallelism);
+        let engine = create_ggml_engine(device, &model_path, parallelism);
         (
             Arc::new(engine),
-            EngineInfo {
+            PromptInfo {
                 prompt_template: model_info.prompt_template.clone(),
                 chat_template: model_info.chat_template.clone(),
             },
@@ -49,23 +50,19 @@
 }
 
 #[derive(Deserialize)]
-pub struct EngineInfo {
+pub struct PromptInfo {
     pub prompt_template: Option<String>,
     pub chat_template: Option<String>,
 }
 
-impl EngineInfo {
-    fn read(filepath: PathBuf) -> EngineInfo {
+impl PromptInfo {
+    fn read(filepath: PathBuf) -> PromptInfo {
         serdeconv::from_json_file(&filepath)
             .unwrap_or_else(|_| fatal!("Invalid metadata file: {}", filepath.display()))
     }
 }
 
-fn create_ggml_engine(
-    device: &super::Device,
-    model_path: &str,
-    parallelism: u8,
-) -> impl TextGeneration {
+fn create_ggml_engine(device: &Device, model_path: &str, parallelism: u8) -> impl TextGeneration {
     let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
         .model_path(model_path.to_owned())
         .use_gpu(device.ggml_use_gpu())