refactor: use Arc<dyn TextGeneration> and Arc<dyn CodeSearch>
parent
22592374c1
commit
fce94f622b
|
|
@ -1,32 +1,31 @@
|
||||||
mod fastchat;
|
mod fastchat;
|
||||||
mod vertex_ai;
|
mod vertex_ai;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use fastchat::FastChatEngine;
|
use fastchat::FastChatEngine;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use tabby_inference::TextGeneration;
|
use tabby_inference::TextGeneration;
|
||||||
use vertex_ai::VertexAIEngine;
|
use vertex_ai::VertexAIEngine;
|
||||||
|
|
||||||
pub fn create(model: &str) -> (Box<dyn TextGeneration>, String) {
|
pub fn create(model: &str) -> (Arc<dyn TextGeneration>, String) {
|
||||||
let params = serde_json::from_str(model).expect("Failed to parse model string");
|
let params = serde_json::from_str(model).expect("Failed to parse model string");
|
||||||
let kind = get_param(&params, "kind");
|
let kind = get_param(&params, "kind");
|
||||||
if kind == "vertex-ai" {
|
if kind == "vertex-ai" {
|
||||||
let api_endpoint = get_param(&params, "api_endpoint");
|
let api_endpoint = get_param(&params, "api_endpoint");
|
||||||
let authorization = get_param(&params, "authorization");
|
let authorization = get_param(&params, "authorization");
|
||||||
let engine = Box::new(VertexAIEngine::create(
|
let engine = VertexAIEngine::create(api_endpoint.as_str(), authorization.as_str());
|
||||||
api_endpoint.as_str(),
|
(Arc::new(engine), VertexAIEngine::prompt_template())
|
||||||
authorization.as_str(),
|
|
||||||
));
|
|
||||||
(engine, VertexAIEngine::prompt_template())
|
|
||||||
} else if kind == "fastchat" {
|
} else if kind == "fastchat" {
|
||||||
let model_name = get_param(&params, "model_name");
|
let model_name = get_param(&params, "model_name");
|
||||||
let api_endpoint = get_param(&params, "api_endpoint");
|
let api_endpoint = get_param(&params, "api_endpoint");
|
||||||
let authorization = get_param(&params, "authorization");
|
let authorization = get_param(&params, "authorization");
|
||||||
let engine = Box::new(FastChatEngine::create(
|
let engine = FastChatEngine::create(
|
||||||
api_endpoint.as_str(),
|
api_endpoint.as_str(),
|
||||||
model_name.as_str(),
|
model_name.as_str(),
|
||||||
authorization.as_str(),
|
authorization.as_str(),
|
||||||
));
|
);
|
||||||
(engine, FastChatEngine::prompt_template())
|
(Arc::new(engine), FastChatEngine::prompt_template())
|
||||||
} else {
|
} else {
|
||||||
panic!("Only vertex_ai and fastchat are supported for http backend");
|
panic!("Only vertex_ai and fastchat are supported for http backend");
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -55,5 +55,3 @@ pub trait CodeSearch: Send + Sync {
|
||||||
offset: usize,
|
offset: usize,
|
||||||
) -> Result<SearchResponse, CodeSearchError>;
|
) -> Result<SearchResponse, CodeSearchError>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type BoxCodeSearch = Box<dyn CodeSearch>;
|
|
||||||
|
|
|
||||||
|
|
@ -35,12 +35,12 @@ pub struct ChatCompletionChunk {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ChatService {
|
pub struct ChatService {
|
||||||
engine: Arc<Box<dyn TextGeneration>>,
|
engine: Arc<dyn TextGeneration>,
|
||||||
prompt_builder: ChatPromptBuilder,
|
prompt_builder: ChatPromptBuilder,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ChatService {
|
impl ChatService {
|
||||||
pub fn new(engine: Arc<Box<dyn TextGeneration>>, chat_template: String) -> Self {
|
pub fn new(engine: Arc<dyn TextGeneration>, chat_template: String) -> Self {
|
||||||
Self {
|
Self {
|
||||||
engine,
|
engine,
|
||||||
prompt_builder: ChatPromptBuilder::new(chat_template),
|
prompt_builder: ChatPromptBuilder::new(chat_template),
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ use std::{sync::Arc, time::Duration};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use tabby_common::{
|
use tabby_common::{
|
||||||
api::code::{BoxCodeSearch, CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse},
|
api::code::{CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse},
|
||||||
index::{self, register_tokenizers, CodeSearchSchema},
|
index::{self, register_tokenizers, CodeSearchSchema},
|
||||||
path,
|
path,
|
||||||
};
|
};
|
||||||
|
|
@ -156,8 +156,8 @@ impl CodeSearchService {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn create_code_search() -> BoxCodeSearch {
|
pub fn create_code_search() -> impl CodeSearch {
|
||||||
Box::new(CodeSearchService::new())
|
CodeSearchService::new()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ use std::sync::Arc;
|
||||||
use axum::{extract::State, Json};
|
use axum::{extract::State, Json};
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tabby_common::{api::code::BoxCodeSearch, events, languages::get_language};
|
use tabby_common::{api::code::CodeSearch, events, languages::get_language};
|
||||||
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
|
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
|
||||||
use tracing::{debug, instrument};
|
use tracing::{debug, instrument};
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
|
|
@ -202,14 +202,14 @@ async fn build_prompt(
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct CompletionState {
|
pub struct CompletionState {
|
||||||
engine: Arc<Box<dyn TextGeneration>>,
|
engine: Arc<dyn TextGeneration>,
|
||||||
prompt_builder: prompt::PromptBuilder,
|
prompt_builder: prompt::PromptBuilder,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CompletionState {
|
impl CompletionState {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
engine: Arc<Box<dyn TextGeneration>>,
|
engine: Arc<dyn TextGeneration>,
|
||||||
code: Arc<BoxCodeSearch>,
|
code: Arc<dyn CodeSearch>,
|
||||||
prompt_template: Option<String>,
|
prompt_template: Option<String>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ use lazy_static::lazy_static;
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
use strfmt::strfmt;
|
use strfmt::strfmt;
|
||||||
use tabby_common::{
|
use tabby_common::{
|
||||||
api::code::{BoxCodeSearch, CodeSearchError},
|
api::code::{CodeSearch, CodeSearchError},
|
||||||
languages::get_language,
|
languages::get_language,
|
||||||
};
|
};
|
||||||
use textdistance::Algorithm;
|
use textdistance::Algorithm;
|
||||||
|
|
@ -18,11 +18,11 @@ static MAX_SIMILARITY_THRESHOLD: f32 = 0.9;
|
||||||
|
|
||||||
pub struct PromptBuilder {
|
pub struct PromptBuilder {
|
||||||
prompt_template: Option<String>,
|
prompt_template: Option<String>,
|
||||||
code: Option<Arc<BoxCodeSearch>>,
|
code: Option<Arc<dyn CodeSearch>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PromptBuilder {
|
impl PromptBuilder {
|
||||||
pub fn new(prompt_template: Option<String>, code: Option<Arc<BoxCodeSearch>>) -> Self {
|
pub fn new(prompt_template: Option<String>, code: Option<Arc<dyn CodeSearch>>) -> Self {
|
||||||
PromptBuilder {
|
PromptBuilder {
|
||||||
prompt_template,
|
prompt_template,
|
||||||
code,
|
code,
|
||||||
|
|
@ -106,7 +106,7 @@ fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String {
|
||||||
format!("{}\n{}", comments, prefix)
|
format!("{}\n{}", comments, prefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn collect_snippets(code: &BoxCodeSearch, language: &str, text: &str) -> Vec<Snippet> {
|
async fn collect_snippets(code: &dyn CodeSearch, language: &str, text: &str) -> Vec<Snippet> {
|
||||||
let mut ret = Vec::new();
|
let mut ret = Vec::new();
|
||||||
let mut tokens = tokenize_text(text);
|
let mut tokens = tokenize_text(text);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
use std::{fs, path::PathBuf};
|
use std::{fs, path::PathBuf, sync::Arc};
|
||||||
|
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH};
|
use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH};
|
||||||
|
|
@ -9,7 +9,7 @@ use crate::fatal;
|
||||||
pub async fn create_engine(
|
pub async fn create_engine(
|
||||||
model_id: &str,
|
model_id: &str,
|
||||||
args: &crate::serve::ServeArgs,
|
args: &crate::serve::ServeArgs,
|
||||||
) -> (Box<dyn TextGeneration>, EngineInfo) {
|
) -> (Arc<dyn TextGeneration>, EngineInfo) {
|
||||||
#[cfg(feature = "experimental-http")]
|
#[cfg(feature = "experimental-http")]
|
||||||
if args.device == crate::serve::Device::ExperimentalHttp {
|
if args.device == crate::serve::Device::ExperimentalHttp {
|
||||||
let (engine, prompt_template) = http_api_bindings::create(model_id);
|
let (engine, prompt_template) = http_api_bindings::create(model_id);
|
||||||
|
|
@ -31,7 +31,7 @@ pub async fn create_engine(
|
||||||
args.parallelism,
|
args.parallelism,
|
||||||
);
|
);
|
||||||
let engine_info = EngineInfo::read(path.join("tabby.json"));
|
let engine_info = EngineInfo::read(path.join("tabby.json"));
|
||||||
(engine, engine_info)
|
(Arc::new(engine), engine_info)
|
||||||
} else {
|
} else {
|
||||||
let (registry, name) = parse_model_id(model_id);
|
let (registry, name) = parse_model_id(model_id);
|
||||||
let registry = ModelRegistry::new(registry).await;
|
let registry = ModelRegistry::new(registry).await;
|
||||||
|
|
@ -39,7 +39,7 @@ pub async fn create_engine(
|
||||||
let model_info = registry.get_model_info(name);
|
let model_info = registry.get_model_info(name);
|
||||||
let engine = create_ggml_engine(&args.device, &model_path, args.parallelism);
|
let engine = create_ggml_engine(&args.device, &model_path, args.parallelism);
|
||||||
(
|
(
|
||||||
engine,
|
Arc::new(engine),
|
||||||
EngineInfo {
|
EngineInfo {
|
||||||
prompt_template: model_info.prompt_template.clone(),
|
prompt_template: model_info.prompt_template.clone(),
|
||||||
chat_template: model_info.chat_template.clone(),
|
chat_template: model_info.chat_template.clone(),
|
||||||
|
|
@ -65,7 +65,7 @@ fn create_ggml_engine(
|
||||||
device: &super::Device,
|
device: &super::Device,
|
||||||
model_path: &str,
|
model_path: &str,
|
||||||
parallelism: u8,
|
parallelism: u8,
|
||||||
) -> Box<dyn TextGeneration> {
|
) -> impl TextGeneration {
|
||||||
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
|
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
|
||||||
.model_path(model_path.to_owned())
|
.model_path(model_path.to_owned())
|
||||||
.use_gpu(device.ggml_use_gpu())
|
.use_gpu(device.ggml_use_gpu())
|
||||||
|
|
@ -73,5 +73,5 @@ fn create_ggml_engine(
|
||||||
.build()
|
.build()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
Box::new(llama_cpp_bindings::LlamaTextGeneration::new(options))
|
llama_cpp_bindings::LlamaTextGeneration::new(options)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -189,7 +189,6 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
prompt_template, ..
|
prompt_template, ..
|
||||||
},
|
},
|
||||||
) = create_engine(&args.model, args).await;
|
) = create_engine(&args.model, args).await;
|
||||||
let engine = Arc::new(engine);
|
|
||||||
let state =
|
let state =
|
||||||
completions::CompletionState::new(engine.clone(), code.clone(), prompt_template);
|
completions::CompletionState::new(engine.clone(), code.clone(), prompt_template);
|
||||||
Arc::new(state)
|
Arc::new(state)
|
||||||
|
|
@ -200,7 +199,6 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
let Some(chat_template) = chat_template else {
|
let Some(chat_template) = chat_template else {
|
||||||
panic!("Chat model requires specifying prompt template");
|
panic!("Chat model requires specifying prompt template");
|
||||||
};
|
};
|
||||||
let engine = Arc::new(engine);
|
|
||||||
let state = ChatService::new(engine, chat_template);
|
let state = ChatService::new(engine, chat_template);
|
||||||
Some(Arc::new(state))
|
Some(Arc::new(state))
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ use axum::{
|
||||||
};
|
};
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use tabby_common::api::code::{BoxCodeSearch, CodeSearchError, SearchResponse};
|
use tabby_common::api::code::{CodeSearch, CodeSearchError, SearchResponse};
|
||||||
use tracing::{instrument, warn};
|
use tracing::{instrument, warn};
|
||||||
use utoipa::IntoParams;
|
use utoipa::IntoParams;
|
||||||
|
|
||||||
|
|
@ -36,7 +36,7 @@ pub struct SearchQuery {
|
||||||
)]
|
)]
|
||||||
#[instrument(skip(state, query))]
|
#[instrument(skip(state, query))]
|
||||||
pub async fn search(
|
pub async fn search(
|
||||||
State(state): State<Arc<BoxCodeSearch>>,
|
State(state): State<Arc<dyn CodeSearch>>,
|
||||||
query: Query<SearchQuery>,
|
query: Query<SearchQuery>,
|
||||||
) -> Result<Json<SearchResponse>, StatusCode> {
|
) -> Result<Json<SearchResponse>, StatusCode> {
|
||||||
match state
|
match state
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue