refactor: serve/engine.rs => services/model.rs
parent
4359b0cc4b
commit
bad87a99a2
|
|
@ -1,6 +1,5 @@
|
|||
mod chat;
|
||||
mod completions;
|
||||
mod engine;
|
||||
mod events;
|
||||
mod health;
|
||||
mod search;
|
||||
|
|
@ -24,14 +23,15 @@ use tracing::info;
|
|||
use utoipa::OpenApi;
|
||||
use utoipa_swagger_ui::SwaggerUi;
|
||||
|
||||
use self::{
|
||||
engine::{create_engine, EngineInfo},
|
||||
health::HealthState,
|
||||
};
|
||||
use self::health::HealthState;
|
||||
use crate::{
|
||||
api::{Hit, HitDocument, SearchResponse},
|
||||
fatal,
|
||||
services::{chat::ChatService, completions::CompletionService},
|
||||
services::{
|
||||
chat::ChatService,
|
||||
completions::CompletionService,
|
||||
model::{load_text_generation, PromptInfo},
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(OpenApi)]
|
||||
|
|
@ -93,17 +93,17 @@ pub enum Device {
|
|||
|
||||
impl Device {
|
||||
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
||||
fn ggml_use_gpu(&self) -> bool {
|
||||
pub fn ggml_use_gpu(&self) -> bool {
|
||||
*self == Device::Metal
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn ggml_use_gpu(&self) -> bool {
|
||||
pub fn ggml_use_gpu(&self) -> bool {
|
||||
*self == Device::Cuda
|
||||
}
|
||||
|
||||
#[cfg(not(any(all(target_os = "macos", target_arch = "aarch64"), feature = "cuda")))]
|
||||
fn ggml_use_gpu(&self) -> bool {
|
||||
pub fn ggml_use_gpu(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
|
@ -178,16 +178,17 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
|||
let completion_state = {
|
||||
let (
|
||||
engine,
|
||||
EngineInfo {
|
||||
PromptInfo {
|
||||
prompt_template, ..
|
||||
},
|
||||
) = create_engine(&args.model, args).await;
|
||||
) = load_text_generation(&args.model, &args.device, args.parallelism).await;
|
||||
let state = CompletionService::new(engine.clone(), code.clone(), prompt_template);
|
||||
Arc::new(state)
|
||||
};
|
||||
|
||||
let chat_state = if let Some(chat_model) = &args.chat_model {
|
||||
let (engine, EngineInfo { chat_template, .. }) = create_engine(chat_model, args).await;
|
||||
let (engine, PromptInfo { chat_template, .. }) =
|
||||
load_text_generation(chat_model, &args.device, args.parallelism).await;
|
||||
let Some(chat_template) = chat_template else {
|
||||
panic!("Chat model requires specifying prompt template");
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
pub mod chat;
|
||||
pub mod code;
|
||||
pub mod completions;
|
||||
pub mod model;
|
||||
|
|
|
|||
|
|
@ -4,18 +4,19 @@ use serde::Deserialize;
|
|||
use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH};
|
||||
use tabby_inference::TextGeneration;
|
||||
|
||||
use crate::fatal;
|
||||
use crate::{fatal, serve::Device};
|
||||
|
||||
pub async fn create_engine(
|
||||
pub async fn load_text_generation(
|
||||
model_id: &str,
|
||||
args: &crate::serve::ServeArgs,
|
||||
) -> (Arc<dyn TextGeneration>, EngineInfo) {
|
||||
device: &Device,
|
||||
parallelism: u8,
|
||||
) -> (Arc<dyn TextGeneration>, PromptInfo) {
|
||||
#[cfg(feature = "experimental-http")]
|
||||
if args.device == crate::serve::Device::ExperimentalHttp {
|
||||
let (engine, prompt_template) = http_api_bindings::create(model_id);
|
||||
return (
|
||||
engine,
|
||||
EngineInfo {
|
||||
PromptInfo {
|
||||
prompt_template: Some(prompt_template),
|
||||
chat_template: None,
|
||||
},
|
||||
|
|
@ -26,21 +27,21 @@ pub async fn create_engine(
|
|||
let path = PathBuf::from(model_id);
|
||||
let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
|
||||
let engine = create_ggml_engine(
|
||||
&args.device,
|
||||
device,
|
||||
model_path.display().to_string().as_str(),
|
||||
args.parallelism,
|
||||
parallelism,
|
||||
);
|
||||
let engine_info = EngineInfo::read(path.join("tabby.json"));
|
||||
let engine_info = PromptInfo::read(path.join("tabby.json"));
|
||||
(Arc::new(engine), engine_info)
|
||||
} else {
|
||||
let (registry, name) = parse_model_id(model_id);
|
||||
let registry = ModelRegistry::new(registry).await;
|
||||
let model_path = registry.get_model_path(name).display().to_string();
|
||||
let model_info = registry.get_model_info(name);
|
||||
let engine = create_ggml_engine(&args.device, &model_path, args.parallelism);
|
||||
let engine = create_ggml_engine(device, &model_path, parallelism);
|
||||
(
|
||||
Arc::new(engine),
|
||||
EngineInfo {
|
||||
PromptInfo {
|
||||
prompt_template: model_info.prompt_template.clone(),
|
||||
chat_template: model_info.chat_template.clone(),
|
||||
},
|
||||
|
|
@ -49,23 +50,19 @@ pub async fn create_engine(
|
|||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct EngineInfo {
|
||||
pub struct PromptInfo {
|
||||
pub prompt_template: Option<String>,
|
||||
pub chat_template: Option<String>,
|
||||
}
|
||||
|
||||
impl EngineInfo {
|
||||
fn read(filepath: PathBuf) -> EngineInfo {
|
||||
impl PromptInfo {
|
||||
fn read(filepath: PathBuf) -> PromptInfo {
|
||||
serdeconv::from_json_file(&filepath)
|
||||
.unwrap_or_else(|_| fatal!("Invalid metadata file: {}", filepath.display()))
|
||||
}
|
||||
}
|
||||
|
||||
fn create_ggml_engine(
|
||||
device: &super::Device,
|
||||
model_path: &str,
|
||||
parallelism: u8,
|
||||
) -> impl TextGeneration {
|
||||
fn create_ggml_engine(device: &Device, model_path: &str, parallelism: u8) -> impl TextGeneration {
|
||||
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
|
||||
.model_path(model_path.to_owned())
|
||||
.use_gpu(device.ggml_use_gpu())
|
||||
Loading…
Reference in New Issue