refactor: serve/engine.rs => services/model.rs
parent
4359b0cc4b
commit
bad87a99a2
|
|
@ -1,6 +1,5 @@
|
||||||
mod chat;
|
mod chat;
|
||||||
mod completions;
|
mod completions;
|
||||||
mod engine;
|
|
||||||
mod events;
|
mod events;
|
||||||
mod health;
|
mod health;
|
||||||
mod search;
|
mod search;
|
||||||
|
|
@ -24,14 +23,15 @@ use tracing::info;
|
||||||
use utoipa::OpenApi;
|
use utoipa::OpenApi;
|
||||||
use utoipa_swagger_ui::SwaggerUi;
|
use utoipa_swagger_ui::SwaggerUi;
|
||||||
|
|
||||||
use self::{
|
use self::health::HealthState;
|
||||||
engine::{create_engine, EngineInfo},
|
|
||||||
health::HealthState,
|
|
||||||
};
|
|
||||||
use crate::{
|
use crate::{
|
||||||
api::{Hit, HitDocument, SearchResponse},
|
api::{Hit, HitDocument, SearchResponse},
|
||||||
fatal,
|
fatal,
|
||||||
services::{chat::ChatService, completions::CompletionService},
|
services::{
|
||||||
|
chat::ChatService,
|
||||||
|
completions::CompletionService,
|
||||||
|
model::{load_text_generation, PromptInfo},
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(OpenApi)]
|
#[derive(OpenApi)]
|
||||||
|
|
@ -93,17 +93,17 @@ pub enum Device {
|
||||||
|
|
||||||
impl Device {
|
impl Device {
|
||||||
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
||||||
fn ggml_use_gpu(&self) -> bool {
|
pub fn ggml_use_gpu(&self) -> bool {
|
||||||
*self == Device::Metal
|
*self == Device::Metal
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
fn ggml_use_gpu(&self) -> bool {
|
pub fn ggml_use_gpu(&self) -> bool {
|
||||||
*self == Device::Cuda
|
*self == Device::Cuda
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(any(all(target_os = "macos", target_arch = "aarch64"), feature = "cuda")))]
|
#[cfg(not(any(all(target_os = "macos", target_arch = "aarch64"), feature = "cuda")))]
|
||||||
fn ggml_use_gpu(&self) -> bool {
|
pub fn ggml_use_gpu(&self) -> bool {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -178,16 +178,17 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
let completion_state = {
|
let completion_state = {
|
||||||
let (
|
let (
|
||||||
engine,
|
engine,
|
||||||
EngineInfo {
|
PromptInfo {
|
||||||
prompt_template, ..
|
prompt_template, ..
|
||||||
},
|
},
|
||||||
) = create_engine(&args.model, args).await;
|
) = load_text_generation(&args.model, &args.device, args.parallelism).await;
|
||||||
let state = CompletionService::new(engine.clone(), code.clone(), prompt_template);
|
let state = CompletionService::new(engine.clone(), code.clone(), prompt_template);
|
||||||
Arc::new(state)
|
Arc::new(state)
|
||||||
};
|
};
|
||||||
|
|
||||||
let chat_state = if let Some(chat_model) = &args.chat_model {
|
let chat_state = if let Some(chat_model) = &args.chat_model {
|
||||||
let (engine, EngineInfo { chat_template, .. }) = create_engine(chat_model, args).await;
|
let (engine, PromptInfo { chat_template, .. }) =
|
||||||
|
load_text_generation(chat_model, &args.device, args.parallelism).await;
|
||||||
let Some(chat_template) = chat_template else {
|
let Some(chat_template) = chat_template else {
|
||||||
panic!("Chat model requires specifying prompt template");
|
panic!("Chat model requires specifying prompt template");
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
pub mod chat;
|
pub mod chat;
|
||||||
pub mod code;
|
pub mod code;
|
||||||
pub mod completions;
|
pub mod completions;
|
||||||
|
pub mod model;
|
||||||
|
|
|
||||||
|
|
@ -4,18 +4,19 @@ use serde::Deserialize;
|
||||||
use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH};
|
use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH};
|
||||||
use tabby_inference::TextGeneration;
|
use tabby_inference::TextGeneration;
|
||||||
|
|
||||||
use crate::fatal;
|
use crate::{fatal, serve::Device};
|
||||||
|
|
||||||
pub async fn create_engine(
|
pub async fn load_text_generation(
|
||||||
model_id: &str,
|
model_id: &str,
|
||||||
args: &crate::serve::ServeArgs,
|
device: &Device,
|
||||||
) -> (Arc<dyn TextGeneration>, EngineInfo) {
|
parallelism: u8,
|
||||||
|
) -> (Arc<dyn TextGeneration>, PromptInfo) {
|
||||||
#[cfg(feature = "experimental-http")]
|
#[cfg(feature = "experimental-http")]
|
||||||
if args.device == crate::serve::Device::ExperimentalHttp {
|
if args.device == crate::serve::Device::ExperimentalHttp {
|
||||||
let (engine, prompt_template) = http_api_bindings::create(model_id);
|
let (engine, prompt_template) = http_api_bindings::create(model_id);
|
||||||
return (
|
return (
|
||||||
engine,
|
engine,
|
||||||
EngineInfo {
|
PromptInfo {
|
||||||
prompt_template: Some(prompt_template),
|
prompt_template: Some(prompt_template),
|
||||||
chat_template: None,
|
chat_template: None,
|
||||||
},
|
},
|
||||||
|
|
@ -26,21 +27,21 @@ pub async fn create_engine(
|
||||||
let path = PathBuf::from(model_id);
|
let path = PathBuf::from(model_id);
|
||||||
let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
|
let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
|
||||||
let engine = create_ggml_engine(
|
let engine = create_ggml_engine(
|
||||||
&args.device,
|
device,
|
||||||
model_path.display().to_string().as_str(),
|
model_path.display().to_string().as_str(),
|
||||||
args.parallelism,
|
parallelism,
|
||||||
);
|
);
|
||||||
let engine_info = EngineInfo::read(path.join("tabby.json"));
|
let engine_info = PromptInfo::read(path.join("tabby.json"));
|
||||||
(Arc::new(engine), engine_info)
|
(Arc::new(engine), engine_info)
|
||||||
} else {
|
} else {
|
||||||
let (registry, name) = parse_model_id(model_id);
|
let (registry, name) = parse_model_id(model_id);
|
||||||
let registry = ModelRegistry::new(registry).await;
|
let registry = ModelRegistry::new(registry).await;
|
||||||
let model_path = registry.get_model_path(name).display().to_string();
|
let model_path = registry.get_model_path(name).display().to_string();
|
||||||
let model_info = registry.get_model_info(name);
|
let model_info = registry.get_model_info(name);
|
||||||
let engine = create_ggml_engine(&args.device, &model_path, args.parallelism);
|
let engine = create_ggml_engine(device, &model_path, parallelism);
|
||||||
(
|
(
|
||||||
Arc::new(engine),
|
Arc::new(engine),
|
||||||
EngineInfo {
|
PromptInfo {
|
||||||
prompt_template: model_info.prompt_template.clone(),
|
prompt_template: model_info.prompt_template.clone(),
|
||||||
chat_template: model_info.chat_template.clone(),
|
chat_template: model_info.chat_template.clone(),
|
||||||
},
|
},
|
||||||
|
|
@ -49,23 +50,19 @@ pub async fn create_engine(
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
pub struct EngineInfo {
|
pub struct PromptInfo {
|
||||||
pub prompt_template: Option<String>,
|
pub prompt_template: Option<String>,
|
||||||
pub chat_template: Option<String>,
|
pub chat_template: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EngineInfo {
|
impl PromptInfo {
|
||||||
fn read(filepath: PathBuf) -> EngineInfo {
|
fn read(filepath: PathBuf) -> PromptInfo {
|
||||||
serdeconv::from_json_file(&filepath)
|
serdeconv::from_json_file(&filepath)
|
||||||
.unwrap_or_else(|_| fatal!("Invalid metadata file: {}", filepath.display()))
|
.unwrap_or_else(|_| fatal!("Invalid metadata file: {}", filepath.display()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_ggml_engine(
|
fn create_ggml_engine(device: &Device, model_path: &str, parallelism: u8) -> impl TextGeneration {
|
||||||
device: &super::Device,
|
|
||||||
model_path: &str,
|
|
||||||
parallelism: u8,
|
|
||||||
) -> impl TextGeneration {
|
|
||||||
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
|
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
|
||||||
.model_path(model_path.to_owned())
|
.model_path(model_path.to_owned())
|
||||||
.use_gpu(device.ggml_use_gpu())
|
.use_gpu(device.ggml_use_gpu())
|
||||||
Loading…
Reference in New Issue