diff --git a/crates/http-api-bindings/src/lib.rs b/crates/http-api-bindings/src/lib.rs index fc743ca..83c09cd 100644 --- a/crates/http-api-bindings/src/lib.rs +++ b/crates/http-api-bindings/src/lib.rs @@ -1,32 +1,31 @@ mod fastchat; mod vertex_ai; +use std::sync::Arc; + use fastchat::FastChatEngine; use serde_json::Value; use tabby_inference::TextGeneration; use vertex_ai::VertexAIEngine; -pub fn create(model: &str) -> (Box, String) { +pub fn create(model: &str) -> (Arc, String) { let params = serde_json::from_str(model).expect("Failed to parse model string"); let kind = get_param(¶ms, "kind"); if kind == "vertex-ai" { let api_endpoint = get_param(¶ms, "api_endpoint"); let authorization = get_param(¶ms, "authorization"); - let engine = Box::new(VertexAIEngine::create( - api_endpoint.as_str(), - authorization.as_str(), - )); - (engine, VertexAIEngine::prompt_template()) + let engine = VertexAIEngine::create(api_endpoint.as_str(), authorization.as_str()); + (Arc::new(engine), VertexAIEngine::prompt_template()) } else if kind == "fastchat" { let model_name = get_param(¶ms, "model_name"); let api_endpoint = get_param(¶ms, "api_endpoint"); let authorization = get_param(¶ms, "authorization"); - let engine = Box::new(FastChatEngine::create( + let engine = FastChatEngine::create( api_endpoint.as_str(), model_name.as_str(), authorization.as_str(), - )); - (engine, FastChatEngine::prompt_template()) + ); + (Arc::new(engine), FastChatEngine::prompt_template()) } else { panic!("Only vertex_ai and fastchat are supported for http backend"); } diff --git a/crates/tabby-common/src/api/code.rs b/crates/tabby-common/src/api/code.rs index 47cbe8e..2f07329 100644 --- a/crates/tabby-common/src/api/code.rs +++ b/crates/tabby-common/src/api/code.rs @@ -55,5 +55,3 @@ pub trait CodeSearch: Send + Sync { offset: usize, ) -> Result; } - -pub type BoxCodeSearch = Box; diff --git a/crates/tabby/src/chat.rs b/crates/tabby/src/chat.rs index 0d1d778..1f2c91a 100644 --- a/crates/tabby/src/chat.rs +++ b/crates/tabby/src/chat.rs @@ -35,12 +35,12 @@ pub struct ChatCompletionChunk { } pub struct ChatService { - engine: Arc>, + engine: Arc, prompt_builder: ChatPromptBuilder, } impl ChatService { - pub fn new(engine: Arc>, chat_template: String) -> Self { + pub fn new(engine: Arc, chat_template: String) -> Self { Self { engine, prompt_builder: ChatPromptBuilder::new(chat_template), diff --git a/crates/tabby/src/search.rs b/crates/tabby/src/search.rs index b7de2da..7ca3188 100644 --- a/crates/tabby/src/search.rs +++ b/crates/tabby/src/search.rs @@ -3,7 +3,7 @@ use std::{sync::Arc, time::Duration}; use anyhow::Result; use async_trait::async_trait; use tabby_common::{ - api::code::{BoxCodeSearch, CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse}, + api::code::{CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse}, index::{self, register_tokenizers, CodeSearchSchema}, path, }; @@ -156,8 +156,8 @@ impl CodeSearchService { } } -pub fn create_code_search() -> BoxCodeSearch { - Box::new(CodeSearchService::new()) +pub fn create_code_search() -> impl CodeSearch { + CodeSearchService::new() } #[async_trait] diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index 4a66220..f630b6d 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use axum::{extract::State, Json}; use hyper::StatusCode; use serde::{Deserialize, Serialize}; -use tabby_common::{api::code::BoxCodeSearch, events, languages::get_language}; +use tabby_common::{api::code::CodeSearch, events, languages::get_language}; use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder}; use tracing::{debug, instrument}; use utoipa::ToSchema; @@ -202,14 +202,14 @@ async fn build_prompt( } pub struct CompletionState { - engine: Arc>, + engine: Arc, prompt_builder: prompt::PromptBuilder, } impl CompletionState { pub fn new( - engine: Arc>, - code: Arc, + engine: Arc, + code: Arc, prompt_template: Option, ) -> Self { Self { diff --git a/crates/tabby/src/serve/completions/prompt.rs b/crates/tabby/src/serve/completions/prompt.rs index a863d73..742e5bf 100644 --- a/crates/tabby/src/serve/completions/prompt.rs +++ b/crates/tabby/src/serve/completions/prompt.rs @@ -4,7 +4,7 @@ use lazy_static::lazy_static; use regex::Regex; use strfmt::strfmt; use tabby_common::{ - api::code::{BoxCodeSearch, CodeSearchError}, + api::code::{CodeSearch, CodeSearchError}, languages::get_language, }; use textdistance::Algorithm; @@ -18,11 +18,11 @@ static MAX_SIMILARITY_THRESHOLD: f32 = 0.9; pub struct PromptBuilder { prompt_template: Option, - code: Option>, + code: Option>, } impl PromptBuilder { - pub fn new(prompt_template: Option, code: Option>) -> Self { + pub fn new(prompt_template: Option, code: Option>) -> Self { PromptBuilder { prompt_template, code, @@ -106,7 +106,7 @@ fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String { format!("{}\n{}", comments, prefix) } -async fn collect_snippets(code: &BoxCodeSearch, language: &str, text: &str) -> Vec { +async fn collect_snippets(code: &dyn CodeSearch, language: &str, text: &str) -> Vec { let mut ret = Vec::new(); let mut tokens = tokenize_text(text); diff --git a/crates/tabby/src/serve/engine.rs b/crates/tabby/src/serve/engine.rs index 8e28d12..ce35914 100644 --- a/crates/tabby/src/serve/engine.rs +++ b/crates/tabby/src/serve/engine.rs @@ -1,4 +1,4 @@ -use std::{fs, path::PathBuf}; +use std::{fs, path::PathBuf, sync::Arc}; use serde::Deserialize; use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH}; @@ -9,7 +9,7 @@ use crate::fatal; pub async fn create_engine( model_id: &str, args: &crate::serve::ServeArgs, -) -> (Box, EngineInfo) { +) -> (Arc, EngineInfo) { #[cfg(feature = "experimental-http")] if args.device == crate::serve::Device::ExperimentalHttp { let (engine, prompt_template) = http_api_bindings::create(model_id); @@ -31,7 +31,7 @@ pub async fn create_engine( args.parallelism, ); let engine_info = EngineInfo::read(path.join("tabby.json")); - (engine, engine_info) + (Arc::new(engine), engine_info) } else { let (registry, name) = parse_model_id(model_id); let registry = ModelRegistry::new(registry).await; @@ -39,7 +39,7 @@ pub async fn create_engine( let model_info = registry.get_model_info(name); let engine = create_ggml_engine(&args.device, &model_path, args.parallelism); ( - engine, + Arc::new(engine), EngineInfo { prompt_template: model_info.prompt_template.clone(), chat_template: model_info.chat_template.clone(), @@ -65,7 +65,7 @@ fn create_ggml_engine( device: &super::Device, model_path: &str, parallelism: u8, -) -> Box { +) -> impl TextGeneration { let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default() .model_path(model_path.to_owned()) .use_gpu(device.ggml_use_gpu()) @@ -73,5 +73,5 @@ fn create_ggml_engine( .build() .unwrap(); - Box::new(llama_cpp_bindings::LlamaTextGeneration::new(options)) + llama_cpp_bindings::LlamaTextGeneration::new(options) } diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index 7bee9df..9298814 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -189,7 +189,6 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router { prompt_template, .. }, ) = create_engine(&args.model, args).await; - let engine = Arc::new(engine); let state = completions::CompletionState::new(engine.clone(), code.clone(), prompt_template); Arc::new(state) @@ -200,7 +199,6 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router { let Some(chat_template) = chat_template else { panic!("Chat model requires specifying prompt template"); }; - let engine = Arc::new(engine); let state = ChatService::new(engine, chat_template); Some(Arc::new(state)) } else { diff --git a/crates/tabby/src/serve/search.rs b/crates/tabby/src/serve/search.rs index 28a2a67..99edade 100644 --- a/crates/tabby/src/serve/search.rs +++ b/crates/tabby/src/serve/search.rs @@ -7,7 +7,7 @@ use axum::{ }; use hyper::StatusCode; use serde::Deserialize; -use tabby_common::api::code::{BoxCodeSearch, CodeSearchError, SearchResponse}; +use tabby_common::api::code::{CodeSearch, CodeSearchError, SearchResponse}; use tracing::{instrument, warn}; use utoipa::IntoParams; @@ -36,7 +36,7 @@ pub struct SearchQuery { )] #[instrument(skip(state, query))] pub async fn search( - State(state): State>, + State(state): State>, query: Query, ) -> Result, StatusCode> { match state