refactor: serve/engine.rs => services/model.rs

extract-routes
Meng Zhang 2023-11-12 20:37:31 -08:00
parent 4359b0cc4b
commit bad87a99a2
3 changed files with 29 additions and 30 deletions

View File

@ -1,6 +1,5 @@
mod chat; mod chat;
mod completions; mod completions;
mod engine;
mod events; mod events;
mod health; mod health;
mod search; mod search;
@ -24,14 +23,15 @@ use tracing::info;
use utoipa::OpenApi; use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi; use utoipa_swagger_ui::SwaggerUi;
use self::{ use self::health::HealthState;
engine::{create_engine, EngineInfo},
health::HealthState,
};
use crate::{ use crate::{
api::{Hit, HitDocument, SearchResponse}, api::{Hit, HitDocument, SearchResponse},
fatal, fatal,
services::{chat::ChatService, completions::CompletionService}, services::{
chat::ChatService,
completions::CompletionService,
model::{load_text_generation, PromptInfo},
},
}; };
#[derive(OpenApi)] #[derive(OpenApi)]
@ -93,17 +93,17 @@ pub enum Device {
impl Device { impl Device {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))] #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn ggml_use_gpu(&self) -> bool { pub fn ggml_use_gpu(&self) -> bool {
*self == Device::Metal *self == Device::Metal
} }
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
fn ggml_use_gpu(&self) -> bool { pub fn ggml_use_gpu(&self) -> bool {
*self == Device::Cuda *self == Device::Cuda
} }
#[cfg(not(any(all(target_os = "macos", target_arch = "aarch64"), feature = "cuda")))] #[cfg(not(any(all(target_os = "macos", target_arch = "aarch64"), feature = "cuda")))]
fn ggml_use_gpu(&self) -> bool { pub fn ggml_use_gpu(&self) -> bool {
false false
} }
} }
@ -178,16 +178,17 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
let completion_state = { let completion_state = {
let ( let (
engine, engine,
EngineInfo { PromptInfo {
prompt_template, .. prompt_template, ..
}, },
) = create_engine(&args.model, args).await; ) = load_text_generation(&args.model, &args.device, args.parallelism).await;
let state = CompletionService::new(engine.clone(), code.clone(), prompt_template); let state = CompletionService::new(engine.clone(), code.clone(), prompt_template);
Arc::new(state) Arc::new(state)
}; };
let chat_state = if let Some(chat_model) = &args.chat_model { let chat_state = if let Some(chat_model) = &args.chat_model {
let (engine, EngineInfo { chat_template, .. }) = create_engine(chat_model, args).await; let (engine, PromptInfo { chat_template, .. }) =
load_text_generation(chat_model, &args.device, args.parallelism).await;
let Some(chat_template) = chat_template else { let Some(chat_template) = chat_template else {
panic!("Chat model requires specifying prompt template"); panic!("Chat model requires specifying prompt template");
}; };

View File

@ -1,3 +1,4 @@
pub mod chat; pub mod chat;
pub mod code; pub mod code;
pub mod completions; pub mod completions;
pub mod model;

View File

@ -4,18 +4,19 @@ use serde::Deserialize;
use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH}; use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH};
use tabby_inference::TextGeneration; use tabby_inference::TextGeneration;
use crate::fatal; use crate::{fatal, serve::Device};
pub async fn create_engine( pub async fn load_text_generation(
model_id: &str, model_id: &str,
args: &crate::serve::ServeArgs, device: &Device,
) -> (Arc<dyn TextGeneration>, EngineInfo) { parallelism: u8,
) -> (Arc<dyn TextGeneration>, PromptInfo) {
#[cfg(feature = "experimental-http")] #[cfg(feature = "experimental-http")]
if args.device == crate::serve::Device::ExperimentalHttp { if args.device == crate::serve::Device::ExperimentalHttp {
let (engine, prompt_template) = http_api_bindings::create(model_id); let (engine, prompt_template) = http_api_bindings::create(model_id);
return ( return (
engine, engine,
EngineInfo { PromptInfo {
prompt_template: Some(prompt_template), prompt_template: Some(prompt_template),
chat_template: None, chat_template: None,
}, },
@ -26,21 +27,21 @@ pub async fn create_engine(
let path = PathBuf::from(model_id); let path = PathBuf::from(model_id);
let model_path = path.join(GGML_MODEL_RELATIVE_PATH); let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
let engine = create_ggml_engine( let engine = create_ggml_engine(
&args.device, device,
model_path.display().to_string().as_str(), model_path.display().to_string().as_str(),
args.parallelism, parallelism,
); );
let engine_info = EngineInfo::read(path.join("tabby.json")); let engine_info = PromptInfo::read(path.join("tabby.json"));
(Arc::new(engine), engine_info) (Arc::new(engine), engine_info)
} else { } else {
let (registry, name) = parse_model_id(model_id); let (registry, name) = parse_model_id(model_id);
let registry = ModelRegistry::new(registry).await; let registry = ModelRegistry::new(registry).await;
let model_path = registry.get_model_path(name).display().to_string(); let model_path = registry.get_model_path(name).display().to_string();
let model_info = registry.get_model_info(name); let model_info = registry.get_model_info(name);
let engine = create_ggml_engine(&args.device, &model_path, args.parallelism); let engine = create_ggml_engine(device, &model_path, parallelism);
( (
Arc::new(engine), Arc::new(engine),
EngineInfo { PromptInfo {
prompt_template: model_info.prompt_template.clone(), prompt_template: model_info.prompt_template.clone(),
chat_template: model_info.chat_template.clone(), chat_template: model_info.chat_template.clone(),
}, },
@ -49,23 +50,19 @@ pub async fn create_engine(
} }
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct EngineInfo { pub struct PromptInfo {
pub prompt_template: Option<String>, pub prompt_template: Option<String>,
pub chat_template: Option<String>, pub chat_template: Option<String>,
} }
impl EngineInfo { impl PromptInfo {
fn read(filepath: PathBuf) -> EngineInfo { fn read(filepath: PathBuf) -> PromptInfo {
serdeconv::from_json_file(&filepath) serdeconv::from_json_file(&filepath)
.unwrap_or_else(|_| fatal!("Invalid metadata file: {}", filepath.display())) .unwrap_or_else(|_| fatal!("Invalid metadata file: {}", filepath.display()))
} }
} }
fn create_ggml_engine( fn create_ggml_engine(device: &Device, model_path: &str, parallelism: u8) -> impl TextGeneration {
device: &super::Device,
model_path: &str,
parallelism: u8,
) -> impl TextGeneration {
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default() let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
.model_path(model_path.to_owned()) .model_path(model_path.to_owned())
.use_gpu(device.ggml_use_gpu()) .use_gpu(device.ggml_use_gpu())