refactor: use Arc<dyn TextGeneration> and Arc<dyn CodeSearch>

extract-routes
Meng Zhang 2023-11-11 13:56:01 -08:00
parent 22592374c1
commit fce94f622b
9 changed files with 29 additions and 34 deletions

View File

@ -1,32 +1,31 @@
mod fastchat; mod fastchat;
mod vertex_ai; mod vertex_ai;
use std::sync::Arc;
use fastchat::FastChatEngine; use fastchat::FastChatEngine;
use serde_json::Value; use serde_json::Value;
use tabby_inference::TextGeneration; use tabby_inference::TextGeneration;
use vertex_ai::VertexAIEngine; use vertex_ai::VertexAIEngine;
pub fn create(model: &str) -> (Box<dyn TextGeneration>, String) { pub fn create(model: &str) -> (Arc<dyn TextGeneration>, String) {
let params = serde_json::from_str(model).expect("Failed to parse model string"); let params = serde_json::from_str(model).expect("Failed to parse model string");
let kind = get_param(&params, "kind"); let kind = get_param(&params, "kind");
if kind == "vertex-ai" { if kind == "vertex-ai" {
let api_endpoint = get_param(&params, "api_endpoint"); let api_endpoint = get_param(&params, "api_endpoint");
let authorization = get_param(&params, "authorization"); let authorization = get_param(&params, "authorization");
let engine = Box::new(VertexAIEngine::create( let engine = VertexAIEngine::create(api_endpoint.as_str(), authorization.as_str());
api_endpoint.as_str(), (Arc::new(engine), VertexAIEngine::prompt_template())
authorization.as_str(),
));
(engine, VertexAIEngine::prompt_template())
} else if kind == "fastchat" { } else if kind == "fastchat" {
let model_name = get_param(&params, "model_name"); let model_name = get_param(&params, "model_name");
let api_endpoint = get_param(&params, "api_endpoint"); let api_endpoint = get_param(&params, "api_endpoint");
let authorization = get_param(&params, "authorization"); let authorization = get_param(&params, "authorization");
let engine = Box::new(FastChatEngine::create( let engine = FastChatEngine::create(
api_endpoint.as_str(), api_endpoint.as_str(),
model_name.as_str(), model_name.as_str(),
authorization.as_str(), authorization.as_str(),
)); );
(engine, FastChatEngine::prompt_template()) (Arc::new(engine), FastChatEngine::prompt_template())
} else { } else {
panic!("Only vertex_ai and fastchat are supported for http backend"); panic!("Only vertex_ai and fastchat are supported for http backend");
} }

View File

@ -55,5 +55,3 @@ pub trait CodeSearch: Send + Sync {
offset: usize, offset: usize,
) -> Result<SearchResponse, CodeSearchError>; ) -> Result<SearchResponse, CodeSearchError>;
} }
pub type BoxCodeSearch = Box<dyn CodeSearch>;

View File

@ -35,12 +35,12 @@ pub struct ChatCompletionChunk {
} }
pub struct ChatService { pub struct ChatService {
engine: Arc<Box<dyn TextGeneration>>, engine: Arc<dyn TextGeneration>,
prompt_builder: ChatPromptBuilder, prompt_builder: ChatPromptBuilder,
} }
impl ChatService { impl ChatService {
pub fn new(engine: Arc<Box<dyn TextGeneration>>, chat_template: String) -> Self { pub fn new(engine: Arc<dyn TextGeneration>, chat_template: String) -> Self {
Self { Self {
engine, engine,
prompt_builder: ChatPromptBuilder::new(chat_template), prompt_builder: ChatPromptBuilder::new(chat_template),

View File

@ -3,7 +3,7 @@ use std::{sync::Arc, time::Duration};
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use tabby_common::{ use tabby_common::{
api::code::{BoxCodeSearch, CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse}, api::code::{CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse},
index::{self, register_tokenizers, CodeSearchSchema}, index::{self, register_tokenizers, CodeSearchSchema},
path, path,
}; };
@ -156,8 +156,8 @@ impl CodeSearchService {
} }
} }
pub fn create_code_search() -> BoxCodeSearch { pub fn create_code_search() -> impl CodeSearch {
Box::new(CodeSearchService::new()) CodeSearchService::new()
} }
#[async_trait] #[async_trait]

View File

@ -5,7 +5,7 @@ use std::sync::Arc;
use axum::{extract::State, Json}; use axum::{extract::State, Json};
use hyper::StatusCode; use hyper::StatusCode;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tabby_common::{api::code::BoxCodeSearch, events, languages::get_language}; use tabby_common::{api::code::CodeSearch, events, languages::get_language};
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder}; use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
use tracing::{debug, instrument}; use tracing::{debug, instrument};
use utoipa::ToSchema; use utoipa::ToSchema;
@ -202,14 +202,14 @@ async fn build_prompt(
} }
pub struct CompletionState { pub struct CompletionState {
engine: Arc<Box<dyn TextGeneration>>, engine: Arc<dyn TextGeneration>,
prompt_builder: prompt::PromptBuilder, prompt_builder: prompt::PromptBuilder,
} }
impl CompletionState { impl CompletionState {
pub fn new( pub fn new(
engine: Arc<Box<dyn TextGeneration>>, engine: Arc<dyn TextGeneration>,
code: Arc<BoxCodeSearch>, code: Arc<dyn CodeSearch>,
prompt_template: Option<String>, prompt_template: Option<String>,
) -> Self { ) -> Self {
Self { Self {

View File

@ -4,7 +4,7 @@ use lazy_static::lazy_static;
use regex::Regex; use regex::Regex;
use strfmt::strfmt; use strfmt::strfmt;
use tabby_common::{ use tabby_common::{
api::code::{BoxCodeSearch, CodeSearchError}, api::code::{CodeSearch, CodeSearchError},
languages::get_language, languages::get_language,
}; };
use textdistance::Algorithm; use textdistance::Algorithm;
@ -18,11 +18,11 @@ static MAX_SIMILARITY_THRESHOLD: f32 = 0.9;
pub struct PromptBuilder { pub struct PromptBuilder {
prompt_template: Option<String>, prompt_template: Option<String>,
code: Option<Arc<BoxCodeSearch>>, code: Option<Arc<dyn CodeSearch>>,
} }
impl PromptBuilder { impl PromptBuilder {
pub fn new(prompt_template: Option<String>, code: Option<Arc<BoxCodeSearch>>) -> Self { pub fn new(prompt_template: Option<String>, code: Option<Arc<dyn CodeSearch>>) -> Self {
PromptBuilder { PromptBuilder {
prompt_template, prompt_template,
code, code,
@ -106,7 +106,7 @@ fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String {
format!("{}\n{}", comments, prefix) format!("{}\n{}", comments, prefix)
} }
async fn collect_snippets(code: &BoxCodeSearch, language: &str, text: &str) -> Vec<Snippet> { async fn collect_snippets(code: &dyn CodeSearch, language: &str, text: &str) -> Vec<Snippet> {
let mut ret = Vec::new(); let mut ret = Vec::new();
let mut tokens = tokenize_text(text); let mut tokens = tokenize_text(text);

View File

@ -1,4 +1,4 @@
use std::{fs, path::PathBuf}; use std::{fs, path::PathBuf, sync::Arc};
use serde::Deserialize; use serde::Deserialize;
use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH}; use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH};
@ -9,7 +9,7 @@ use crate::fatal;
pub async fn create_engine( pub async fn create_engine(
model_id: &str, model_id: &str,
args: &crate::serve::ServeArgs, args: &crate::serve::ServeArgs,
) -> (Box<dyn TextGeneration>, EngineInfo) { ) -> (Arc<dyn TextGeneration>, EngineInfo) {
#[cfg(feature = "experimental-http")] #[cfg(feature = "experimental-http")]
if args.device == crate::serve::Device::ExperimentalHttp { if args.device == crate::serve::Device::ExperimentalHttp {
let (engine, prompt_template) = http_api_bindings::create(model_id); let (engine, prompt_template) = http_api_bindings::create(model_id);
@ -31,7 +31,7 @@ pub async fn create_engine(
args.parallelism, args.parallelism,
); );
let engine_info = EngineInfo::read(path.join("tabby.json")); let engine_info = EngineInfo::read(path.join("tabby.json"));
(engine, engine_info) (Arc::new(engine), engine_info)
} else { } else {
let (registry, name) = parse_model_id(model_id); let (registry, name) = parse_model_id(model_id);
let registry = ModelRegistry::new(registry).await; let registry = ModelRegistry::new(registry).await;
@ -39,7 +39,7 @@ pub async fn create_engine(
let model_info = registry.get_model_info(name); let model_info = registry.get_model_info(name);
let engine = create_ggml_engine(&args.device, &model_path, args.parallelism); let engine = create_ggml_engine(&args.device, &model_path, args.parallelism);
( (
engine, Arc::new(engine),
EngineInfo { EngineInfo {
prompt_template: model_info.prompt_template.clone(), prompt_template: model_info.prompt_template.clone(),
chat_template: model_info.chat_template.clone(), chat_template: model_info.chat_template.clone(),
@ -65,7 +65,7 @@ fn create_ggml_engine(
device: &super::Device, device: &super::Device,
model_path: &str, model_path: &str,
parallelism: u8, parallelism: u8,
) -> Box<dyn TextGeneration> { ) -> impl TextGeneration {
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default() let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
.model_path(model_path.to_owned()) .model_path(model_path.to_owned())
.use_gpu(device.ggml_use_gpu()) .use_gpu(device.ggml_use_gpu())
@ -73,5 +73,5 @@ fn create_ggml_engine(
.build() .build()
.unwrap(); .unwrap();
Box::new(llama_cpp_bindings::LlamaTextGeneration::new(options)) llama_cpp_bindings::LlamaTextGeneration::new(options)
} }

View File

@ -189,7 +189,6 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
prompt_template, .. prompt_template, ..
}, },
) = create_engine(&args.model, args).await; ) = create_engine(&args.model, args).await;
let engine = Arc::new(engine);
let state = let state =
completions::CompletionState::new(engine.clone(), code.clone(), prompt_template); completions::CompletionState::new(engine.clone(), code.clone(), prompt_template);
Arc::new(state) Arc::new(state)
@ -200,7 +199,6 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
let Some(chat_template) = chat_template else { let Some(chat_template) = chat_template else {
panic!("Chat model requires specifying prompt template"); panic!("Chat model requires specifying prompt template");
}; };
let engine = Arc::new(engine);
let state = ChatService::new(engine, chat_template); let state = ChatService::new(engine, chat_template);
Some(Arc::new(state)) Some(Arc::new(state))
} else { } else {

View File

@ -7,7 +7,7 @@ use axum::{
}; };
use hyper::StatusCode; use hyper::StatusCode;
use serde::Deserialize; use serde::Deserialize;
use tabby_common::api::code::{BoxCodeSearch, CodeSearchError, SearchResponse}; use tabby_common::api::code::{CodeSearch, CodeSearchError, SearchResponse};
use tracing::{instrument, warn}; use tracing::{instrument, warn};
use utoipa::IntoParams; use utoipa::IntoParams;
@ -36,7 +36,7 @@ pub struct SearchQuery {
)] )]
#[instrument(skip(state, query))] #[instrument(skip(state, query))]
pub async fn search( pub async fn search(
State(state): State<Arc<BoxCodeSearch>>, State(state): State<Arc<dyn CodeSearch>>,
query: Query<SearchQuery>, query: Query<SearchQuery>,
) -> Result<Json<SearchResponse>, StatusCode> { ) -> Result<Json<SearchResponse>, StatusCode> {
match state match state