refactor: extract TextGeneration trait (#324)

* add tabby-inference

* extract TextGeneration trait

* format

* Rename TextInferenceEngine to CTranslate2Engine
release-0.0
Meng Zhang 2023-08-02 14:12:51 +08:00 committed by GitHub
parent 83e1cf76d8
commit b8308b7118
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 89 additions and 53 deletions

61
Cargo.lock generated
View File

@ -161,18 +161,18 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.18", "syn 2.0.28",
] ]
[[package]] [[package]]
name = "async-trait" name = "async-trait"
version = "0.1.68" version = "0.1.72"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842" checksum = "cc6dde6e4ed435a4c1ee4e73592f5ba9da2151af10076cc04858746af9352d09"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.18", "syn 2.0.28",
] ]
[[package]] [[package]]
@ -443,7 +443,7 @@ dependencies = [
"heck", "heck",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.18", "syn 2.0.28",
] ]
[[package]] [[package]]
@ -613,6 +613,7 @@ dependencies = [
name = "ctranslate2-bindings" name = "ctranslate2-bindings"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"async-trait",
"cmake", "cmake",
"cxx", "cxx",
"cxx-build", "cxx-build",
@ -620,6 +621,7 @@ dependencies = [
"derive_builder", "derive_builder",
"regex", "regex",
"rust-cxx-cmake-bridge", "rust-cxx-cmake-bridge",
"tabby-inference",
"tokenizers", "tokenizers",
"tokio", "tokio",
"tokio-util", "tokio-util",
@ -649,7 +651,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"scratch", "scratch",
"syn 2.0.18", "syn 2.0.28",
] ]
[[package]] [[package]]
@ -666,7 +668,7 @@ checksum = "4a076022ece33e7686fb76513518e219cca4fce5750a8ae6d1ce6c0f48fd1af9"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.18", "syn 2.0.28",
] ]
[[package]] [[package]]
@ -1028,7 +1030,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.18", "syn 2.0.28",
] ]
[[package]] [[package]]
@ -1683,7 +1685,7 @@ checksum = "8795add3e14028f11f8e848bd3294898a8294767b3776b6f733560d33bd2530b"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.18", "syn 2.0.28",
] ]
[[package]] [[package]]
@ -1848,7 +1850,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.18", "syn 2.0.28",
] ]
[[package]] [[package]]
@ -2049,7 +2051,7 @@ checksum = "39407670928234ebc5e6e580247dd567ad73a3578460c5990f9503df207e8f07"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.18", "syn 2.0.28",
] ]
[[package]] [[package]]
@ -2127,9 +2129,9 @@ dependencies = [
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.59" version = "1.0.66"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6aeca18b86b413c660b781aa319e4e2648a3e6f9eadc9b47e9038e6fe9f3451b" checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9"
dependencies = [ dependencies = [
"unicode-ident", "unicode-ident",
] ]
@ -2190,9 +2192,9 @@ dependencies = [
[[package]] [[package]]
name = "quote" name = "quote"
version = "1.0.28" version = "1.0.32"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b9ab9c7eadfd8df19006f1cf1a4aed13540ed5cbc047010ece5826e10825488" checksum = "50f3b39ccfb720540debaa0164757101c08ecb8d326b15358ce76a62c7e85965"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
] ]
@ -2555,7 +2557,7 @@ checksum = "8c805777e3930c8883389c602315a24224bcc738b63905ef87cd1420353ea93e"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.18", "syn 2.0.28",
] ]
[[package]] [[package]]
@ -2757,9 +2759,9 @@ dependencies = [
[[package]] [[package]]
name = "syn" name = "syn"
version = "2.0.18" version = "2.0.28"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32d41677bcbe24c20c52e7c70b0d8db04134c5d1066bf98662e2871ad200ea3e" checksum = "04361975b3f5e348b2189d8dc55bc942f278b2d482a6a0365de5bdd62d351567"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -2794,6 +2796,7 @@ dependencies = [
"strum", "strum",
"tabby-common", "tabby-common",
"tabby-download", "tabby-download",
"tabby-inference",
"tabby-scheduler", "tabby-scheduler",
"tantivy", "tantivy",
"tokio", "tokio",
@ -2833,6 +2836,14 @@ dependencies = [
"tokio-retry", "tokio-retry",
] ]
[[package]]
name = "tabby-inference"
version = "0.1.0"
dependencies = [
"async-trait",
"derive_builder",
]
[[package]] [[package]]
name = "tabby-scheduler" name = "tabby-scheduler"
version = "0.1.0" version = "0.1.0"
@ -3002,7 +3013,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.18", "syn 2.0.28",
] ]
[[package]] [[package]]
@ -3141,7 +3152,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.18", "syn 2.0.28",
] ]
[[package]] [[package]]
@ -3359,7 +3370,7 @@ checksum = "0f57e3ca2a01450b1a921183a9c9cbfda207fd822cef4ccb00a65402cbba7a74"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.18", "syn 2.0.28",
] ]
[[package]] [[package]]
@ -3664,7 +3675,7 @@ dependencies = [
"proc-macro-error", "proc-macro-error",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.18", "syn 2.0.28",
] ]
[[package]] [[package]]
@ -3712,7 +3723,7 @@ checksum = "3f67b459f42af2e6e1ee213cb9da4dbd022d3320788c3fb3e1b893093f1e45da"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.18", "syn 2.0.28",
] ]
[[package]] [[package]]
@ -3786,7 +3797,7 @@ dependencies = [
"once_cell", "once_cell",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.18", "syn 2.0.28",
"wasm-bindgen-shared", "wasm-bindgen-shared",
] ]
@ -3820,7 +3831,7 @@ checksum = "e128beba882dd1eb6200e1dc92ae6c5dbaa4311aa7bb211ca035779e5efc39f8"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.18", "syn 2.0.28",
"wasm-bindgen-backend", "wasm-bindgen-backend",
"wasm-bindgen-shared", "wasm-bindgen-shared",
] ]

View File

@ -4,6 +4,7 @@ members = [
"crates/tabby-common", "crates/tabby-common",
"crates/tabby-scheduler", "crates/tabby-scheduler",
"crates/tabby-download", "crates/tabby-download",
"crates/tabby-inference",
"crates/ctranslate2-bindings", "crates/ctranslate2-bindings",
"crates/rust-cxx-cmake-bridge", "crates/rust-cxx-cmake-bridge",
] ]
@ -25,3 +26,4 @@ tracing-subscriber = "0.3"
anyhow = "1.0.71" anyhow = "1.0.71"
serde-jsonlines = "0.4.0" serde-jsonlines = "0.4.0"
tantivy = "0.19.2" tantivy = "0.19.2"
async-trait = "0.1.72"

View File

@ -11,6 +11,8 @@ regex = "1.8.4"
tokenizers = "0.13.3" tokenizers = "0.13.3"
tokio = { workspace = true, features = ["rt"] } tokio = { workspace = true, features = ["rt"] }
tokio-util = { workspace = true } tokio-util = { workspace = true }
tabby-inference = { path = "../tabby-inference" }
async-trait = { workspace = true }
[build-dependencies] [build-dependencies]
cxx-build = "1.0" cxx-build = "1.0"

View File

@ -1,11 +1,11 @@
use async_trait::async_trait;
use dashmap::DashMap; use dashmap::DashMap;
use derive_builder::Builder;
use regex::Regex; use regex::Regex;
use tabby_inference::{TextGeneration, TextGenerationOptions};
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
#[macro_use]
extern crate derive_builder;
#[cxx::bridge(namespace = "tabby")] #[cxx::bridge(namespace = "tabby")]
mod ffi { mod ffi {
extern "Rust" { extern "Rust" {
@ -49,7 +49,7 @@ unsafe impl Send for ffi::TextInferenceEngine {}
unsafe impl Sync for ffi::TextInferenceEngine {} unsafe impl Sync for ffi::TextInferenceEngine {}
#[derive(Builder, Debug)] #[derive(Builder, Debug)]
pub struct TextInferenceEngineCreateOptions { pub struct CTranslate2EngineOptions {
model_path: String, model_path: String,
model_type: String, model_type: String,
@ -65,17 +65,6 @@ pub struct TextInferenceEngineCreateOptions {
compute_type: String, compute_type: String,
} }
#[derive(Builder, Debug)]
pub struct TextInferenceOptions {
#[builder(default = "256")]
max_decoding_length: usize,
#[builder(default = "1.0")]
sampling_temperature: f32,
stop_words: &'static Vec<&'static str>,
}
pub struct InferenceContext { pub struct InferenceContext {
stop_re: Option<Regex>, stop_re: Option<Regex>,
cancel: CancellationToken, cancel: CancellationToken,
@ -92,14 +81,14 @@ impl InferenceContext {
} }
} }
pub struct TextInferenceEngine { pub struct CTranslate2Engine {
engine: cxx::SharedPtr<ffi::TextInferenceEngine>, engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
tokenizer: Tokenizer, tokenizer: Tokenizer,
stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>, stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
} }
impl TextInferenceEngine { impl CTranslate2Engine {
pub fn create(options: TextInferenceEngineCreateOptions) -> Self where { pub fn create(options: CTranslate2EngineOptions) -> Self where {
let engine = ffi::create_engine( let engine = ffi::create_engine(
&options.model_path, &options.model_path,
&options.model_type, &options.model_type,
@ -108,14 +97,18 @@ impl TextInferenceEngine {
&options.device_indices, &options.device_indices,
options.num_replicas_per_device, options.num_replicas_per_device,
); );
return TextInferenceEngine {
return Self {
engine, engine,
stop_regex_cache: DashMap::new(), stop_regex_cache: DashMap::new(),
tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(), tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(),
}; };
} }
}
pub async fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String { #[async_trait]
impl TextGeneration for CTranslate2Engine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let encoding = self.tokenizer.encode(prompt, true).unwrap(); let encoding = self.tokenizer.encode(prompt, true).unwrap();
let engine = self.engine.clone(); let engine = self.engine.clone();

View File

@ -0,0 +1,10 @@
[package]
name = "tabby-inference"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
async-trait = { workspace = true }
derive_builder = "0.12.0"

View File

@ -0,0 +1,18 @@
use async_trait::async_trait;
use derive_builder::Builder;
/// Per-request options passed to [`TextGeneration::generate`].
///
/// `#[derive(Builder)]` generates a companion `TextGenerationOptionsBuilder`;
/// the `#[builder(default = ...)]` attributes below supply the value used when
/// the corresponding setter is not called.
#[derive(Builder, Debug)]
pub struct TextGenerationOptions {
/// Upper bound on the decoding length; the builder default is 256.
// NOTE(review): presumably counted in tokens -- confirm against the engine implementation.
#[builder(default = "256")]
pub max_decoding_length: usize,
/// Sampling temperature; the builder default is 1.0.
#[builder(default = "1.0")]
pub sampling_temperature: f32,
/// Stop words for this request. No builder default is declared for this
/// field, so callers are expected to set it explicitly.
// The `&'static` borrow means the list must live for the whole program.
// NOTE(review): presumably this lets an engine key a cache on the reference -- confirm.
pub stop_words: &'static Vec<&'static str>,
}
/// Interface for a backend that produces text from a prompt.
///
/// Declared with `#[async_trait]` so that the trait can contain an `async fn`.
#[async_trait]
pub trait TextGeneration {
/// Generates text for `prompt` using the given `options` and returns it as a `String`.
///
/// The signature has no error channel: the result is always a plain `String`.
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String;
}

View File

@ -8,6 +8,7 @@ ctranslate2-bindings = { path = "../ctranslate2-bindings" }
tabby-common = { path = "../tabby-common" } tabby-common = { path = "../tabby-common" }
tabby-scheduler = { path = "../tabby-scheduler", optional = true } tabby-scheduler = { path = "../tabby-scheduler", optional = true }
tabby-download = { path = "../tabby-download" } tabby-download = { path = "../tabby-download" }
tabby-inference = { path = "../tabby-inference" }
axum = "0.6" axum = "0.6"
hyper = { version = "0.14", features = ["full"] } hyper = { version = "0.14", features = ["full"] }
tokio = { workspace = true, features = ["full"] } tokio = { workspace = true, features = ["full"] }

View File

@ -4,12 +4,11 @@ mod prompt;
use std::{path::Path, sync::Arc}; use std::{path::Path, sync::Arc};
use axum::{extract::State, Json}; use axum::{extract::State, Json};
use ctranslate2_bindings::{ use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};
TextInferenceEngine, TextInferenceEngineCreateOptionsBuilder, TextInferenceOptionsBuilder,
};
use hyper::StatusCode; use hyper::StatusCode;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tabby_common::{config::Config, events, path::ModelDir}; use tabby_common::{config::Config, events, path::ModelDir};
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
use tracing::{debug, instrument}; use tracing::{debug, instrument};
use utoipa::ToSchema; use utoipa::ToSchema;
@ -79,7 +78,7 @@ pub async fn completion(
Json(request): Json<CompletionRequest>, Json(request): Json<CompletionRequest>,
) -> Result<Json<CompletionResponse>, StatusCode> { ) -> Result<Json<CompletionResponse>, StatusCode> {
let language = request.language.unwrap_or("unknown".to_string()); let language = request.language.unwrap_or("unknown".to_string());
let options = TextInferenceOptionsBuilder::default() let options = TextGenerationOptionsBuilder::default()
.max_decoding_length(128) .max_decoding_length(128)
.sampling_temperature(0.1) .sampling_temperature(0.1)
.stop_words(get_stop_words(&language)) .stop_words(get_stop_words(&language))
@ -101,7 +100,7 @@ pub async fn completion(
let prompt = state.prompt_builder.build(&language, segments); let prompt = state.prompt_builder.build(&language, segments);
debug!("PROMPT: {}", prompt); debug!("PROMPT: {}", prompt);
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4()); let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
let text = state.engine.inference(&prompt, options).await; let text = state.engine.generate(&prompt, options).await;
events::Event::Completion { events::Event::Completion {
completion_id: &completion_id, completion_id: &completion_id,
@ -122,7 +121,7 @@ pub async fn completion(
} }
pub struct CompletionState { pub struct CompletionState {
engine: TextInferenceEngine, engine: CTranslate2Engine,
prompt_builder: prompt::PromptBuilder, prompt_builder: prompt::PromptBuilder,
} }
@ -133,7 +132,7 @@ impl CompletionState {
let device = format!("{}", args.device); let device = format!("{}", args.device);
let compute_type = format!("{}", args.compute_type); let compute_type = format!("{}", args.compute_type);
let options = TextInferenceEngineCreateOptionsBuilder::default() let options = CTranslate2EngineOptionsBuilder::default()
.model_path(model_dir.ctranslate2_dir()) .model_path(model_dir.ctranslate2_dir())
.tokenizer_path(model_dir.tokenizer_file()) .tokenizer_path(model_dir.tokenizer_file())
.device(device) .device(device)
@ -143,7 +142,7 @@ impl CompletionState {
.compute_type(compute_type) .compute_type(compute_type)
.build() .build()
.unwrap(); .unwrap();
let engine = TextInferenceEngine::create(options); let engine = CTranslate2Engine::create(options);
Self { Self {
engine, engine,
prompt_builder: prompt::PromptBuilder::new( prompt_builder: prompt::PromptBuilder::new(