refactor: serve/engine.rs => services/model.rs

extract-routes
Meng Zhang 2023-11-12 20:37:31 -08:00
parent 4359b0cc4b
commit bad87a99a2
3 changed files with 29 additions and 30 deletions

View File

@ -1,6 +1,5 @@
mod chat;
mod completions;
mod engine;
mod events;
mod health;
mod search;
@ -24,14 +23,15 @@ use tracing::info;
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
use self::{
engine::{create_engine, EngineInfo},
health::HealthState,
};
use self::health::HealthState;
use crate::{
api::{Hit, HitDocument, SearchResponse},
fatal,
services::{chat::ChatService, completions::CompletionService},
services::{
chat::ChatService,
completions::CompletionService,
model::{load_text_generation, PromptInfo},
},
};
#[derive(OpenApi)]
@ -93,17 +93,17 @@ pub enum Device {
impl Device {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn ggml_use_gpu(&self) -> bool {
pub fn ggml_use_gpu(&self) -> bool {
*self == Device::Metal
}
#[cfg(feature = "cuda")]
fn ggml_use_gpu(&self) -> bool {
pub fn ggml_use_gpu(&self) -> bool {
*self == Device::Cuda
}
#[cfg(not(any(all(target_os = "macos", target_arch = "aarch64"), feature = "cuda")))]
fn ggml_use_gpu(&self) -> bool {
pub fn ggml_use_gpu(&self) -> bool {
false
}
}
@ -178,16 +178,17 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
let completion_state = {
let (
engine,
EngineInfo {
PromptInfo {
prompt_template, ..
},
) = create_engine(&args.model, args).await;
) = load_text_generation(&args.model, &args.device, args.parallelism).await;
let state = CompletionService::new(engine.clone(), code.clone(), prompt_template);
Arc::new(state)
};
let chat_state = if let Some(chat_model) = &args.chat_model {
let (engine, EngineInfo { chat_template, .. }) = create_engine(chat_model, args).await;
let (engine, PromptInfo { chat_template, .. }) =
load_text_generation(chat_model, &args.device, args.parallelism).await;
let Some(chat_template) = chat_template else {
panic!("Chat model requires specifying prompt template");
};

View File

@ -1,3 +1,4 @@
pub mod chat;
pub mod code;
pub mod completions;
pub mod model;

View File

@ -4,18 +4,19 @@ use serde::Deserialize;
use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH};
use tabby_inference::TextGeneration;
use crate::fatal;
use crate::{fatal, serve::Device};
pub async fn create_engine(
pub async fn load_text_generation(
model_id: &str,
args: &crate::serve::ServeArgs,
) -> (Arc<dyn TextGeneration>, EngineInfo) {
device: &Device,
parallelism: u8,
) -> (Arc<dyn TextGeneration>, PromptInfo) {
#[cfg(feature = "experimental-http")]
if args.device == crate::serve::Device::ExperimentalHttp {
let (engine, prompt_template) = http_api_bindings::create(model_id);
return (
engine,
EngineInfo {
PromptInfo {
prompt_template: Some(prompt_template),
chat_template: None,
},
@ -26,21 +27,21 @@ pub async fn create_engine(
let path = PathBuf::from(model_id);
let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
let engine = create_ggml_engine(
&args.device,
device,
model_path.display().to_string().as_str(),
args.parallelism,
parallelism,
);
let engine_info = EngineInfo::read(path.join("tabby.json"));
let engine_info = PromptInfo::read(path.join("tabby.json"));
(Arc::new(engine), engine_info)
} else {
let (registry, name) = parse_model_id(model_id);
let registry = ModelRegistry::new(registry).await;
let model_path = registry.get_model_path(name).display().to_string();
let model_info = registry.get_model_info(name);
let engine = create_ggml_engine(&args.device, &model_path, args.parallelism);
let engine = create_ggml_engine(device, &model_path, parallelism);
(
Arc::new(engine),
EngineInfo {
PromptInfo {
prompt_template: model_info.prompt_template.clone(),
chat_template: model_info.chat_template.clone(),
},
@ -49,23 +50,19 @@ pub async fn create_engine(
}
#[derive(Deserialize)]
pub struct EngineInfo {
pub struct PromptInfo {
pub prompt_template: Option<String>,
pub chat_template: Option<String>,
}
impl EngineInfo {
fn read(filepath: PathBuf) -> EngineInfo {
impl PromptInfo {
fn read(filepath: PathBuf) -> PromptInfo {
serdeconv::from_json_file(&filepath)
.unwrap_or_else(|_| fatal!("Invalid metadata file: {}", filepath.display()))
}
}
fn create_ggml_engine(
device: &super::Device,
model_path: &str,
parallelism: u8,
) -> impl TextGeneration {
fn create_ggml_engine(device: &Device, model_path: &str, parallelism: u8) -> impl TextGeneration {
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
.model_path(model_path.to_owned())
.use_gpu(device.ggml_use_gpu())