refactor: use Arc<dyn TextGeneration> and Arc<dyn CodeSearch>
parent
22592374c1
commit
fce94f622b
|
|
@ -1,32 +1,31 @@
|
|||
mod fastchat;
|
||||
mod vertex_ai;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use fastchat::FastChatEngine;
|
||||
use serde_json::Value;
|
||||
use tabby_inference::TextGeneration;
|
||||
use vertex_ai::VertexAIEngine;
|
||||
|
||||
pub fn create(model: &str) -> (Box<dyn TextGeneration>, String) {
|
||||
pub fn create(model: &str) -> (Arc<dyn TextGeneration>, String) {
|
||||
let params = serde_json::from_str(model).expect("Failed to parse model string");
|
||||
let kind = get_param(&params, "kind");
|
||||
if kind == "vertex-ai" {
|
||||
let api_endpoint = get_param(&params, "api_endpoint");
|
||||
let authorization = get_param(&params, "authorization");
|
||||
let engine = Box::new(VertexAIEngine::create(
|
||||
api_endpoint.as_str(),
|
||||
authorization.as_str(),
|
||||
));
|
||||
(engine, VertexAIEngine::prompt_template())
|
||||
let engine = VertexAIEngine::create(api_endpoint.as_str(), authorization.as_str());
|
||||
(Arc::new(engine), VertexAIEngine::prompt_template())
|
||||
} else if kind == "fastchat" {
|
||||
let model_name = get_param(&params, "model_name");
|
||||
let api_endpoint = get_param(&params, "api_endpoint");
|
||||
let authorization = get_param(&params, "authorization");
|
||||
let engine = Box::new(FastChatEngine::create(
|
||||
let engine = FastChatEngine::create(
|
||||
api_endpoint.as_str(),
|
||||
model_name.as_str(),
|
||||
authorization.as_str(),
|
||||
));
|
||||
(engine, FastChatEngine::prompt_template())
|
||||
);
|
||||
(Arc::new(engine), FastChatEngine::prompt_template())
|
||||
} else {
|
||||
panic!("Only vertex_ai and fastchat are supported for http backend");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -55,5 +55,3 @@ pub trait CodeSearch: Send + Sync {
|
|||
offset: usize,
|
||||
) -> Result<SearchResponse, CodeSearchError>;
|
||||
}
|
||||
|
||||
pub type BoxCodeSearch = Box<dyn CodeSearch>;
|
||||
|
|
|
|||
|
|
@ -35,12 +35,12 @@ pub struct ChatCompletionChunk {
|
|||
}
|
||||
|
||||
pub struct ChatService {
|
||||
engine: Arc<Box<dyn TextGeneration>>,
|
||||
engine: Arc<dyn TextGeneration>,
|
||||
prompt_builder: ChatPromptBuilder,
|
||||
}
|
||||
|
||||
impl ChatService {
|
||||
pub fn new(engine: Arc<Box<dyn TextGeneration>>, chat_template: String) -> Self {
|
||||
pub fn new(engine: Arc<dyn TextGeneration>, chat_template: String) -> Self {
|
||||
Self {
|
||||
engine,
|
||||
prompt_builder: ChatPromptBuilder::new(chat_template),
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ use std::{sync::Arc, time::Duration};
|
|||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use tabby_common::{
|
||||
api::code::{BoxCodeSearch, CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse},
|
||||
api::code::{CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse},
|
||||
index::{self, register_tokenizers, CodeSearchSchema},
|
||||
path,
|
||||
};
|
||||
|
|
@ -156,8 +156,8 @@ impl CodeSearchService {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn create_code_search() -> BoxCodeSearch {
|
||||
Box::new(CodeSearchService::new())
|
||||
pub fn create_code_search() -> impl CodeSearch {
|
||||
CodeSearchService::new()
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ use std::sync::Arc;
|
|||
use axum::{extract::State, Json};
|
||||
use hyper::StatusCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tabby_common::{api::code::BoxCodeSearch, events, languages::get_language};
|
||||
use tabby_common::{api::code::CodeSearch, events, languages::get_language};
|
||||
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
|
||||
use tracing::{debug, instrument};
|
||||
use utoipa::ToSchema;
|
||||
|
|
@ -202,14 +202,14 @@ async fn build_prompt(
|
|||
}
|
||||
|
||||
pub struct CompletionState {
|
||||
engine: Arc<Box<dyn TextGeneration>>,
|
||||
engine: Arc<dyn TextGeneration>,
|
||||
prompt_builder: prompt::PromptBuilder,
|
||||
}
|
||||
|
||||
impl CompletionState {
|
||||
pub fn new(
|
||||
engine: Arc<Box<dyn TextGeneration>>,
|
||||
code: Arc<BoxCodeSearch>,
|
||||
engine: Arc<dyn TextGeneration>,
|
||||
code: Arc<dyn CodeSearch>,
|
||||
prompt_template: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use lazy_static::lazy_static;
|
|||
use regex::Regex;
|
||||
use strfmt::strfmt;
|
||||
use tabby_common::{
|
||||
api::code::{BoxCodeSearch, CodeSearchError},
|
||||
api::code::{CodeSearch, CodeSearchError},
|
||||
languages::get_language,
|
||||
};
|
||||
use textdistance::Algorithm;
|
||||
|
|
@ -18,11 +18,11 @@ static MAX_SIMILARITY_THRESHOLD: f32 = 0.9;
|
|||
|
||||
pub struct PromptBuilder {
|
||||
prompt_template: Option<String>,
|
||||
code: Option<Arc<BoxCodeSearch>>,
|
||||
code: Option<Arc<dyn CodeSearch>>,
|
||||
}
|
||||
|
||||
impl PromptBuilder {
|
||||
pub fn new(prompt_template: Option<String>, code: Option<Arc<BoxCodeSearch>>) -> Self {
|
||||
pub fn new(prompt_template: Option<String>, code: Option<Arc<dyn CodeSearch>>) -> Self {
|
||||
PromptBuilder {
|
||||
prompt_template,
|
||||
code,
|
||||
|
|
@ -106,7 +106,7 @@ fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String {
|
|||
format!("{}\n{}", comments, prefix)
|
||||
}
|
||||
|
||||
async fn collect_snippets(code: &BoxCodeSearch, language: &str, text: &str) -> Vec<Snippet> {
|
||||
async fn collect_snippets(code: &dyn CodeSearch, language: &str, text: &str) -> Vec<Snippet> {
|
||||
let mut ret = Vec::new();
|
||||
let mut tokens = tokenize_text(text);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use std::{fs, path::PathBuf};
|
||||
use std::{fs, path::PathBuf, sync::Arc};
|
||||
|
||||
use serde::Deserialize;
|
||||
use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH};
|
||||
|
|
@ -9,7 +9,7 @@ use crate::fatal;
|
|||
pub async fn create_engine(
|
||||
model_id: &str,
|
||||
args: &crate::serve::ServeArgs,
|
||||
) -> (Box<dyn TextGeneration>, EngineInfo) {
|
||||
) -> (Arc<dyn TextGeneration>, EngineInfo) {
|
||||
#[cfg(feature = "experimental-http")]
|
||||
if args.device == crate::serve::Device::ExperimentalHttp {
|
||||
let (engine, prompt_template) = http_api_bindings::create(model_id);
|
||||
|
|
@ -31,7 +31,7 @@ pub async fn create_engine(
|
|||
args.parallelism,
|
||||
);
|
||||
let engine_info = EngineInfo::read(path.join("tabby.json"));
|
||||
(engine, engine_info)
|
||||
(Arc::new(engine), engine_info)
|
||||
} else {
|
||||
let (registry, name) = parse_model_id(model_id);
|
||||
let registry = ModelRegistry::new(registry).await;
|
||||
|
|
@ -39,7 +39,7 @@ pub async fn create_engine(
|
|||
let model_info = registry.get_model_info(name);
|
||||
let engine = create_ggml_engine(&args.device, &model_path, args.parallelism);
|
||||
(
|
||||
engine,
|
||||
Arc::new(engine),
|
||||
EngineInfo {
|
||||
prompt_template: model_info.prompt_template.clone(),
|
||||
chat_template: model_info.chat_template.clone(),
|
||||
|
|
@ -65,7 +65,7 @@ fn create_ggml_engine(
|
|||
device: &super::Device,
|
||||
model_path: &str,
|
||||
parallelism: u8,
|
||||
) -> Box<dyn TextGeneration> {
|
||||
) -> impl TextGeneration {
|
||||
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
|
||||
.model_path(model_path.to_owned())
|
||||
.use_gpu(device.ggml_use_gpu())
|
||||
|
|
@ -73,5 +73,5 @@ fn create_ggml_engine(
|
|||
.build()
|
||||
.unwrap();
|
||||
|
||||
Box::new(llama_cpp_bindings::LlamaTextGeneration::new(options))
|
||||
llama_cpp_bindings::LlamaTextGeneration::new(options)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -189,7 +189,6 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
|||
prompt_template, ..
|
||||
},
|
||||
) = create_engine(&args.model, args).await;
|
||||
let engine = Arc::new(engine);
|
||||
let state =
|
||||
completions::CompletionState::new(engine.clone(), code.clone(), prompt_template);
|
||||
Arc::new(state)
|
||||
|
|
@ -200,7 +199,6 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
|||
let Some(chat_template) = chat_template else {
|
||||
panic!("Chat model requires specifying prompt template");
|
||||
};
|
||||
let engine = Arc::new(engine);
|
||||
let state = ChatService::new(engine, chat_template);
|
||||
Some(Arc::new(state))
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ use axum::{
|
|||
};
|
||||
use hyper::StatusCode;
|
||||
use serde::Deserialize;
|
||||
use tabby_common::api::code::{BoxCodeSearch, CodeSearchError, SearchResponse};
|
||||
use tabby_common::api::code::{CodeSearch, CodeSearchError, SearchResponse};
|
||||
use tracing::{instrument, warn};
|
||||
use utoipa::IntoParams;
|
||||
|
||||
|
|
@ -36,7 +36,7 @@ pub struct SearchQuery {
|
|||
)]
|
||||
#[instrument(skip(state, query))]
|
||||
pub async fn search(
|
||||
State(state): State<Arc<BoxCodeSearch>>,
|
||||
State(state): State<Arc<dyn CodeSearch>>,
|
||||
query: Query<SearchQuery>,
|
||||
) -> Result<Json<SearchResponse>, StatusCode> {
|
||||
match state
|
||||
|
|
|
|||
Loading…
Reference in New Issue