refactor: extract TextGeneration trait (#324)
* add tabby-inference * extract TextGeneration trait * format * Rename TextInferenceEngine to CTranslate2Enginerelease-0.0
parent
83e1cf76d8
commit
b8308b7118
|
|
@ -161,18 +161,18 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.18",
|
||||
"syn 2.0.28",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.68"
|
||||
version = "0.1.72"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842"
|
||||
checksum = "cc6dde6e4ed435a4c1ee4e73592f5ba9da2151af10076cc04858746af9352d09"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.18",
|
||||
"syn 2.0.28",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -443,7 +443,7 @@ dependencies = [
|
|||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.18",
|
||||
"syn 2.0.28",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -613,6 +613,7 @@ dependencies = [
|
|||
name = "ctranslate2-bindings"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"cmake",
|
||||
"cxx",
|
||||
"cxx-build",
|
||||
|
|
@ -620,6 +621,7 @@ dependencies = [
|
|||
"derive_builder",
|
||||
"regex",
|
||||
"rust-cxx-cmake-bridge",
|
||||
"tabby-inference",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
|
|
@ -649,7 +651,7 @@ dependencies = [
|
|||
"proc-macro2",
|
||||
"quote",
|
||||
"scratch",
|
||||
"syn 2.0.18",
|
||||
"syn 2.0.28",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -666,7 +668,7 @@ checksum = "4a076022ece33e7686fb76513518e219cca4fce5750a8ae6d1ce6c0f48fd1af9"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.18",
|
||||
"syn 2.0.28",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1028,7 +1030,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.18",
|
||||
"syn 2.0.28",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1683,7 +1685,7 @@ checksum = "8795add3e14028f11f8e848bd3294898a8294767b3776b6f733560d33bd2530b"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.18",
|
||||
"syn 2.0.28",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1848,7 +1850,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.18",
|
||||
"syn 2.0.28",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -2049,7 +2051,7 @@ checksum = "39407670928234ebc5e6e580247dd567ad73a3578460c5990f9503df207e8f07"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.18",
|
||||
"syn 2.0.28",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -2127,9 +2129,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.59"
|
||||
version = "1.0.66"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6aeca18b86b413c660b781aa319e4e2648a3e6f9eadc9b47e9038e6fe9f3451b"
|
||||
checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
|
@ -2190,9 +2192,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.28"
|
||||
version = "1.0.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1b9ab9c7eadfd8df19006f1cf1a4aed13540ed5cbc047010ece5826e10825488"
|
||||
checksum = "50f3b39ccfb720540debaa0164757101c08ecb8d326b15358ce76a62c7e85965"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
|
@ -2555,7 +2557,7 @@ checksum = "8c805777e3930c8883389c602315a24224bcc738b63905ef87cd1420353ea93e"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.18",
|
||||
"syn 2.0.28",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -2757,9 +2759,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.18"
|
||||
version = "2.0.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "32d41677bcbe24c20c52e7c70b0d8db04134c5d1066bf98662e2871ad200ea3e"
|
||||
checksum = "04361975b3f5e348b2189d8dc55bc942f278b2d482a6a0365de5bdd62d351567"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
|
@ -2794,6 +2796,7 @@ dependencies = [
|
|||
"strum",
|
||||
"tabby-common",
|
||||
"tabby-download",
|
||||
"tabby-inference",
|
||||
"tabby-scheduler",
|
||||
"tantivy",
|
||||
"tokio",
|
||||
|
|
@ -2833,6 +2836,14 @@ dependencies = [
|
|||
"tokio-retry",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tabby-inference"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"derive_builder",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tabby-scheduler"
|
||||
version = "0.1.0"
|
||||
|
|
@ -3002,7 +3013,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.18",
|
||||
"syn 2.0.28",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -3141,7 +3152,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.18",
|
||||
"syn 2.0.28",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -3359,7 +3370,7 @@ checksum = "0f57e3ca2a01450b1a921183a9c9cbfda207fd822cef4ccb00a65402cbba7a74"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.18",
|
||||
"syn 2.0.28",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -3664,7 +3675,7 @@ dependencies = [
|
|||
"proc-macro-error",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.18",
|
||||
"syn 2.0.28",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -3712,7 +3723,7 @@ checksum = "3f67b459f42af2e6e1ee213cb9da4dbd022d3320788c3fb3e1b893093f1e45da"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.18",
|
||||
"syn 2.0.28",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -3786,7 +3797,7 @@ dependencies = [
|
|||
"once_cell",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.18",
|
||||
"syn 2.0.28",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
||||
|
|
@ -3820,7 +3831,7 @@ checksum = "e128beba882dd1eb6200e1dc92ae6c5dbaa4311aa7bb211ca035779e5efc39f8"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.18",
|
||||
"syn 2.0.28",
|
||||
"wasm-bindgen-backend",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ members = [
|
|||
"crates/tabby-common",
|
||||
"crates/tabby-scheduler",
|
||||
"crates/tabby-download",
|
||||
"crates/tabby-inference",
|
||||
"crates/ctranslate2-bindings",
|
||||
"crates/rust-cxx-cmake-bridge",
|
||||
]
|
||||
|
|
@ -25,3 +26,4 @@ tracing-subscriber = "0.3"
|
|||
anyhow = "1.0.71"
|
||||
serde-jsonlines = "0.4.0"
|
||||
tantivy = "0.19.2"
|
||||
async-trait = "0.1.72"
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ regex = "1.8.4"
|
|||
tokenizers = "0.13.3"
|
||||
tokio = { workspace = true, features = ["rt"] }
|
||||
tokio-util = { workspace = true }
|
||||
tabby-inference = { path = "../tabby-inference" }
|
||||
async-trait = { workspace = true }
|
||||
|
||||
[build-dependencies]
|
||||
cxx-build = "1.0"
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use derive_builder::Builder;
|
||||
use regex::Regex;
|
||||
use tabby_inference::{TextGeneration, TextGenerationOptions};
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_builder;
|
||||
|
||||
#[cxx::bridge(namespace = "tabby")]
|
||||
mod ffi {
|
||||
extern "Rust" {
|
||||
|
|
@ -49,7 +49,7 @@ unsafe impl Send for ffi::TextInferenceEngine {}
|
|||
unsafe impl Sync for ffi::TextInferenceEngine {}
|
||||
|
||||
#[derive(Builder, Debug)]
|
||||
pub struct TextInferenceEngineCreateOptions {
|
||||
pub struct CTranslate2EngineOptions {
|
||||
model_path: String,
|
||||
|
||||
model_type: String,
|
||||
|
|
@ -65,17 +65,6 @@ pub struct TextInferenceEngineCreateOptions {
|
|||
compute_type: String,
|
||||
}
|
||||
|
||||
#[derive(Builder, Debug)]
|
||||
pub struct TextInferenceOptions {
|
||||
#[builder(default = "256")]
|
||||
max_decoding_length: usize,
|
||||
|
||||
#[builder(default = "1.0")]
|
||||
sampling_temperature: f32,
|
||||
|
||||
stop_words: &'static Vec<&'static str>,
|
||||
}
|
||||
|
||||
pub struct InferenceContext {
|
||||
stop_re: Option<Regex>,
|
||||
cancel: CancellationToken,
|
||||
|
|
@ -92,14 +81,14 @@ impl InferenceContext {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct TextInferenceEngine {
|
||||
pub struct CTranslate2Engine {
|
||||
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
|
||||
tokenizer: Tokenizer,
|
||||
stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
|
||||
}
|
||||
|
||||
impl TextInferenceEngine {
|
||||
pub fn create(options: TextInferenceEngineCreateOptions) -> Self where {
|
||||
impl CTranslate2Engine {
|
||||
pub fn create(options: CTranslate2EngineOptions) -> Self where {
|
||||
let engine = ffi::create_engine(
|
||||
&options.model_path,
|
||||
&options.model_type,
|
||||
|
|
@ -108,14 +97,18 @@ impl TextInferenceEngine {
|
|||
&options.device_indices,
|
||||
options.num_replicas_per_device,
|
||||
);
|
||||
return TextInferenceEngine {
|
||||
|
||||
return Self {
|
||||
engine,
|
||||
stop_regex_cache: DashMap::new(),
|
||||
tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String {
|
||||
#[async_trait]
|
||||
impl TextGeneration for CTranslate2Engine {
|
||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
||||
let engine = self.engine.clone();
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,10 @@
|
|||
[package]
|
||||
name = "tabby-inference"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
async-trait = { workspace = true }
|
||||
derive_builder = "0.12.0"
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
use async_trait::async_trait;
|
||||
use derive_builder::Builder;
|
||||
|
||||
#[derive(Builder, Debug)]
|
||||
pub struct TextGenerationOptions {
|
||||
#[builder(default = "256")]
|
||||
pub max_decoding_length: usize,
|
||||
|
||||
#[builder(default = "1.0")]
|
||||
pub sampling_temperature: f32,
|
||||
|
||||
pub stop_words: &'static Vec<&'static str>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait TextGeneration {
|
||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String;
|
||||
}
|
||||
|
|
@ -8,6 +8,7 @@ ctranslate2-bindings = { path = "../ctranslate2-bindings" }
|
|||
tabby-common = { path = "../tabby-common" }
|
||||
tabby-scheduler = { path = "../tabby-scheduler", optional = true }
|
||||
tabby-download = { path = "../tabby-download" }
|
||||
tabby-inference = { path = "../tabby-inference" }
|
||||
axum = "0.6"
|
||||
hyper = { version = "0.14", features = ["full"] }
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
|
|
|
|||
|
|
@ -4,12 +4,11 @@ mod prompt;
|
|||
use std::{path::Path, sync::Arc};
|
||||
|
||||
use axum::{extract::State, Json};
|
||||
use ctranslate2_bindings::{
|
||||
TextInferenceEngine, TextInferenceEngineCreateOptionsBuilder, TextInferenceOptionsBuilder,
|
||||
};
|
||||
use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};
|
||||
use hyper::StatusCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tabby_common::{config::Config, events, path::ModelDir};
|
||||
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
|
||||
use tracing::{debug, instrument};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
|
|
@ -79,7 +78,7 @@ pub async fn completion(
|
|||
Json(request): Json<CompletionRequest>,
|
||||
) -> Result<Json<CompletionResponse>, StatusCode> {
|
||||
let language = request.language.unwrap_or("unknown".to_string());
|
||||
let options = TextInferenceOptionsBuilder::default()
|
||||
let options = TextGenerationOptionsBuilder::default()
|
||||
.max_decoding_length(128)
|
||||
.sampling_temperature(0.1)
|
||||
.stop_words(get_stop_words(&language))
|
||||
|
|
@ -101,7 +100,7 @@ pub async fn completion(
|
|||
let prompt = state.prompt_builder.build(&language, segments);
|
||||
debug!("PROMPT: {}", prompt);
|
||||
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
|
||||
let text = state.engine.inference(&prompt, options).await;
|
||||
let text = state.engine.generate(&prompt, options).await;
|
||||
|
||||
events::Event::Completion {
|
||||
completion_id: &completion_id,
|
||||
|
|
@ -122,7 +121,7 @@ pub async fn completion(
|
|||
}
|
||||
|
||||
pub struct CompletionState {
|
||||
engine: TextInferenceEngine,
|
||||
engine: CTranslate2Engine,
|
||||
prompt_builder: prompt::PromptBuilder,
|
||||
}
|
||||
|
||||
|
|
@ -133,7 +132,7 @@ impl CompletionState {
|
|||
|
||||
let device = format!("{}", args.device);
|
||||
let compute_type = format!("{}", args.compute_type);
|
||||
let options = TextInferenceEngineCreateOptionsBuilder::default()
|
||||
let options = CTranslate2EngineOptionsBuilder::default()
|
||||
.model_path(model_dir.ctranslate2_dir())
|
||||
.tokenizer_path(model_dir.tokenizer_file())
|
||||
.device(device)
|
||||
|
|
@ -143,7 +142,7 @@ impl CompletionState {
|
|||
.compute_type(compute_type)
|
||||
.build()
|
||||
.unwrap();
|
||||
let engine = TextInferenceEngine::create(options);
|
||||
let engine = CTranslate2Engine::create(options);
|
||||
Self {
|
||||
engine,
|
||||
prompt_builder: prompt::PromptBuilder::new(
|
||||
|
|
|
|||
Loading…
Reference in New Issue