refactor: use Arc<dyn TextGeneration> and Arc<dyn CodeSearch>

extract-routes
Meng Zhang 2023-11-11 13:56:01 -08:00
parent 22592374c1
commit fce94f622b
9 changed files with 29 additions and 34 deletions

View File

@@ -1,32 +1,31 @@
mod fastchat;
mod vertex_ai;
use std::sync::Arc;
use fastchat::FastChatEngine;
use serde_json::Value;
use tabby_inference::TextGeneration;
use vertex_ai::VertexAIEngine;
pub fn create(model: &str) -> (Box<dyn TextGeneration>, String) {
pub fn create(model: &str) -> (Arc<dyn TextGeneration>, String) {
let params = serde_json::from_str(model).expect("Failed to parse model string");
let kind = get_param(&params, "kind");
if kind == "vertex-ai" {
let api_endpoint = get_param(&params, "api_endpoint");
let authorization = get_param(&params, "authorization");
let engine = Box::new(VertexAIEngine::create(
api_endpoint.as_str(),
authorization.as_str(),
));
(engine, VertexAIEngine::prompt_template())
let engine = VertexAIEngine::create(api_endpoint.as_str(), authorization.as_str());
(Arc::new(engine), VertexAIEngine::prompt_template())
} else if kind == "fastchat" {
let model_name = get_param(&params, "model_name");
let api_endpoint = get_param(&params, "api_endpoint");
let authorization = get_param(&params, "authorization");
let engine = Box::new(FastChatEngine::create(
let engine = FastChatEngine::create(
api_endpoint.as_str(),
model_name.as_str(),
authorization.as_str(),
));
(engine, FastChatEngine::prompt_template())
);
(Arc::new(engine), FastChatEngine::prompt_template())
} else {
panic!("Only vertex_ai and fastchat are supported for http backend");
}

View File

@@ -55,5 +55,3 @@ pub trait CodeSearch: Send + Sync {
offset: usize,
) -> Result<SearchResponse, CodeSearchError>;
}
pub type BoxCodeSearch = Box<dyn CodeSearch>;

View File

@@ -35,12 +35,12 @@ pub struct ChatCompletionChunk {
}
pub struct ChatService {
engine: Arc<Box<dyn TextGeneration>>,
engine: Arc<dyn TextGeneration>,
prompt_builder: ChatPromptBuilder,
}
impl ChatService {
pub fn new(engine: Arc<Box<dyn TextGeneration>>, chat_template: String) -> Self {
pub fn new(engine: Arc<dyn TextGeneration>, chat_template: String) -> Self {
Self {
engine,
prompt_builder: ChatPromptBuilder::new(chat_template),

View File

@@ -3,7 +3,7 @@ use std::{sync::Arc, time::Duration};
use anyhow::Result;
use async_trait::async_trait;
use tabby_common::{
api::code::{BoxCodeSearch, CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse},
api::code::{CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse},
index::{self, register_tokenizers, CodeSearchSchema},
path,
};
@@ -156,8 +156,8 @@ impl CodeSearchService {
}
}
pub fn create_code_search() -> BoxCodeSearch {
Box::new(CodeSearchService::new())
pub fn create_code_search() -> impl CodeSearch {
CodeSearchService::new()
}
#[async_trait]

View File

@@ -5,7 +5,7 @@ use std::sync::Arc;
use axum::{extract::State, Json};
use hyper::StatusCode;
use serde::{Deserialize, Serialize};
use tabby_common::{api::code::BoxCodeSearch, events, languages::get_language};
use tabby_common::{api::code::CodeSearch, events, languages::get_language};
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
use tracing::{debug, instrument};
use utoipa::ToSchema;
@@ -202,14 +202,14 @@ async fn build_prompt(
}
pub struct CompletionState {
engine: Arc<Box<dyn TextGeneration>>,
engine: Arc<dyn TextGeneration>,
prompt_builder: prompt::PromptBuilder,
}
impl CompletionState {
pub fn new(
engine: Arc<Box<dyn TextGeneration>>,
code: Arc<BoxCodeSearch>,
engine: Arc<dyn TextGeneration>,
code: Arc<dyn CodeSearch>,
prompt_template: Option<String>,
) -> Self {
Self {

View File

@@ -4,7 +4,7 @@ use lazy_static::lazy_static;
use regex::Regex;
use strfmt::strfmt;
use tabby_common::{
api::code::{BoxCodeSearch, CodeSearchError},
api::code::{CodeSearch, CodeSearchError},
languages::get_language,
};
use textdistance::Algorithm;
@@ -18,11 +18,11 @@ static MAX_SIMILARITY_THRESHOLD: f32 = 0.9;
pub struct PromptBuilder {
prompt_template: Option<String>,
code: Option<Arc<BoxCodeSearch>>,
code: Option<Arc<dyn CodeSearch>>,
}
impl PromptBuilder {
pub fn new(prompt_template: Option<String>, code: Option<Arc<BoxCodeSearch>>) -> Self {
pub fn new(prompt_template: Option<String>, code: Option<Arc<dyn CodeSearch>>) -> Self {
PromptBuilder {
prompt_template,
code,
@@ -106,7 +106,7 @@ fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String {
format!("{}\n{}", comments, prefix)
}
async fn collect_snippets(code: &BoxCodeSearch, language: &str, text: &str) -> Vec<Snippet> {
async fn collect_snippets(code: &dyn CodeSearch, language: &str, text: &str) -> Vec<Snippet> {
let mut ret = Vec::new();
let mut tokens = tokenize_text(text);

View File

@@ -1,4 +1,4 @@
use std::{fs, path::PathBuf};
use std::{fs, path::PathBuf, sync::Arc};
use serde::Deserialize;
use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH};
@@ -9,7 +9,7 @@ use crate::fatal;
pub async fn create_engine(
model_id: &str,
args: &crate::serve::ServeArgs,
) -> (Box<dyn TextGeneration>, EngineInfo) {
) -> (Arc<dyn TextGeneration>, EngineInfo) {
#[cfg(feature = "experimental-http")]
if args.device == crate::serve::Device::ExperimentalHttp {
let (engine, prompt_template) = http_api_bindings::create(model_id);
@@ -31,7 +31,7 @@ pub async fn create_engine(
args.parallelism,
);
let engine_info = EngineInfo::read(path.join("tabby.json"));
(engine, engine_info)
(Arc::new(engine), engine_info)
} else {
let (registry, name) = parse_model_id(model_id);
let registry = ModelRegistry::new(registry).await;
@@ -39,7 +39,7 @@ pub async fn create_engine(
let model_info = registry.get_model_info(name);
let engine = create_ggml_engine(&args.device, &model_path, args.parallelism);
(
engine,
Arc::new(engine),
EngineInfo {
prompt_template: model_info.prompt_template.clone(),
chat_template: model_info.chat_template.clone(),
@@ -65,7 +65,7 @@ fn create_ggml_engine(
device: &super::Device,
model_path: &str,
parallelism: u8,
) -> Box<dyn TextGeneration> {
) -> impl TextGeneration {
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
.model_path(model_path.to_owned())
.use_gpu(device.ggml_use_gpu())
@@ -73,5 +73,5 @@ fn create_ggml_engine(
.build()
.unwrap();
Box::new(llama_cpp_bindings::LlamaTextGeneration::new(options))
llama_cpp_bindings::LlamaTextGeneration::new(options)
}

View File

@@ -189,7 +189,6 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
prompt_template, ..
},
) = create_engine(&args.model, args).await;
let engine = Arc::new(engine);
let state =
completions::CompletionState::new(engine.clone(), code.clone(), prompt_template);
Arc::new(state)
@@ -200,7 +199,6 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
let Some(chat_template) = chat_template else {
panic!("Chat model requires specifying prompt template");
};
let engine = Arc::new(engine);
let state = ChatService::new(engine, chat_template);
Some(Arc::new(state))
} else {

View File

@@ -7,7 +7,7 @@ use axum::{
};
use hyper::StatusCode;
use serde::Deserialize;
use tabby_common::api::code::{BoxCodeSearch, CodeSearchError, SearchResponse};
use tabby_common::api::code::{CodeSearch, CodeSearchError, SearchResponse};
use tracing::{instrument, warn};
use utoipa::IntoParams;
@@ -36,7 +36,7 @@ pub struct SearchQuery {
)]
#[instrument(skip(state, query))]
pub async fn search(
State(state): State<Arc<BoxCodeSearch>>,
State(state): State<Arc<dyn CodeSearch>>,
query: Query<SearchQuery>,
) -> Result<Json<SearchResponse>, StatusCode> {
match state