add ctranslate2-bindings / tabby rust packages (#146)
* add ctranslate2-bindings * add fixme for linux build * turn off shared lib * add tabby-cliadd-tracing
parent
c08f5acf26
commit
a2476af373
|
|
@ -0,0 +1,3 @@
|
||||||
|
[submodule "crates/ctranslate2-bindings/CTranslate2"]
|
||||||
|
path = crates/ctranslate2-bindings/CTranslate2
|
||||||
|
url = https://github.com/OpenNMT/CTranslate2.git
|
||||||
|
|
@ -0,0 +1,2 @@
|
||||||
|
/target
|
||||||
|
/Cargo.lock
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
[package]
|
||||||
|
name = "ctranslate2-bindings"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
cxx = "1.0"
|
||||||
|
derive_builder = "0.12.0"
|
||||||
|
tokenizers = "0.13.3"
|
||||||
|
|
||||||
|
[build-dependencies]
|
||||||
|
bindgen = "0.53.1"
|
||||||
|
cxx-build = "1.0"
|
||||||
|
cmake = "0.1"
|
||||||
|
|
@ -0,0 +1,32 @@
|
||||||
|
use cmake::Config;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let dst = Config::new("CTranslate2")
|
||||||
|
// Default flags.
|
||||||
|
.define("CMAKE_BUILD_TYPE", "Release")
|
||||||
|
.define("BUILD_CLI", "OFF")
|
||||||
|
.define("CMAKE_INSTALL_RPATH_USE_LINK_PATH", "ON")
|
||||||
|
|
||||||
|
// FIXME(meng): support linux build.
|
||||||
|
// OSX flags.
|
||||||
|
.define("CMAKE_OSX_ARCHITECTURES", "arm64")
|
||||||
|
.define("WITH_ACCELERATE", "ON")
|
||||||
|
.define("WITH_MKL", "OFF")
|
||||||
|
.define("OPENMP_RUNTIME", "NONE")
|
||||||
|
.define("WITH_RUY", "ON")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
println!("cargo:rustc-link-search=native={}", dst.join("lib").display());
|
||||||
|
println!("cargo:rustc-link-lib=ctranslate2");
|
||||||
|
|
||||||
|
// Tell cargo to invalidate the built crate whenever the wrapper changes
|
||||||
|
println!("cargo:rerun-if-changed=include/ctranslate2.h");
|
||||||
|
println!("cargo:rerun-if-changed=src/ctranslate2.cc");
|
||||||
|
println!("cargo:rerun-if-changed=src/lib.rs");
|
||||||
|
|
||||||
|
cxx_build::bridge("src/lib.rs")
|
||||||
|
.file("src/ctranslate2.cc")
|
||||||
|
.flag_if_supported("-std=c++17")
|
||||||
|
.flag_if_supported(&format!("-I{}", dst.join("include").display()))
|
||||||
|
.compile("cxxbridge");
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit 692fb607ab67573fa5cf6e410aec24e8655844f8
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "rust/cxx.h"
|
||||||
|
|
||||||
|
namespace tabby {
|
||||||
|
|
||||||
|
class TextInferenceEngine {
|
||||||
|
public:
|
||||||
|
virtual ~TextInferenceEngine();
|
||||||
|
virtual rust::Vec<rust::String> inference(
|
||||||
|
rust::Slice<const rust::String> tokens,
|
||||||
|
size_t max_decoding_length,
|
||||||
|
float sampling_temperature,
|
||||||
|
size_t beam_size
|
||||||
|
) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
|
||||||
|
} // namespace
|
||||||
|
|
@ -0,0 +1,47 @@
|
||||||
|
#include "ctranslate2-bindings/include/ctranslate2.h"
|
||||||
|
|
||||||
|
#include "ctranslate2/translator.h"
|
||||||
|
|
||||||
|
namespace tabby {
|
||||||
|
TextInferenceEngine::~TextInferenceEngine() {}
|
||||||
|
|
||||||
|
class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
|
public:
|
||||||
|
TextInferenceEngineImpl(const std::string& model_path) {
|
||||||
|
ctranslate2::models::ModelLoader loader(model_path);
|
||||||
|
translator_ = std::make_unique<ctranslate2::Translator>(loader);
|
||||||
|
}
|
||||||
|
|
||||||
|
~TextInferenceEngineImpl() {}
|
||||||
|
|
||||||
|
rust::Vec<rust::String> inference(
|
||||||
|
rust::Slice<const rust::String> tokens,
|
||||||
|
size_t max_decoding_length,
|
||||||
|
float sampling_temperature,
|
||||||
|
size_t beam_size
|
||||||
|
) const {
|
||||||
|
// Create options.
|
||||||
|
ctranslate2::TranslationOptions options;
|
||||||
|
options.max_decoding_length = max_decoding_length;
|
||||||
|
options.sampling_temperature = sampling_temperature;
|
||||||
|
options.beam_size = beam_size;
|
||||||
|
|
||||||
|
// Inference.
|
||||||
|
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
|
||||||
|
ctranslate2::TranslationResult result = translator_->translate_batch({ input_tokens }, options)[0];
|
||||||
|
const auto& output_tokens = result.output();
|
||||||
|
|
||||||
|
// Convert to rust vec.
|
||||||
|
rust::Vec<rust::String> output;
|
||||||
|
output.reserve(output_tokens.size());
|
||||||
|
std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
private:
|
||||||
|
std::unique_ptr<ctranslate2::Translator> translator_;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
|
||||||
|
return std::make_unique<TextInferenceEngineImpl>(std::string(model_path));
|
||||||
|
}
|
||||||
|
} // namespace tabby
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
use std::sync::Mutex;
|
||||||
|
use tokenizers::tokenizer::{Model, Tokenizer};
|
||||||
|
|
||||||
|
#[macro_use]
|
||||||
|
extern crate derive_builder;
|
||||||
|
|
||||||
|
#[cxx::bridge(namespace = "tabby")]
|
||||||
|
mod ffi {
|
||||||
|
unsafe extern "C++" {
|
||||||
|
include!("ctranslate2-bindings/include/ctranslate2.h");
|
||||||
|
|
||||||
|
type TextInferenceEngine;
|
||||||
|
|
||||||
|
fn create_engine(model_path: &str) -> UniquePtr<TextInferenceEngine>;
|
||||||
|
fn inference(
|
||||||
|
&self,
|
||||||
|
tokens: &[String],
|
||||||
|
max_decoding_length: usize,
|
||||||
|
sampling_temperature: f32,
|
||||||
|
beam_size: usize,
|
||||||
|
) -> Vec<String>;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Builder, Debug)]
|
||||||
|
pub struct TextInferenceOptions {
|
||||||
|
#[builder(default = "256")]
|
||||||
|
max_decoding_length: usize,
|
||||||
|
|
||||||
|
#[builder(default = "1.0")]
|
||||||
|
sampling_temperature: f32,
|
||||||
|
|
||||||
|
#[builder(default = "2")]
|
||||||
|
beam_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct TextInferenceEngine {
|
||||||
|
engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe impl Send for TextInferenceEngine {}
|
||||||
|
unsafe impl Sync for TextInferenceEngine {}
|
||||||
|
|
||||||
|
impl TextInferenceEngine {
|
||||||
|
pub fn create(model_path: &str, tokenizer_path: &str) -> Self where {
|
||||||
|
return TextInferenceEngine {
|
||||||
|
engine: Mutex::new(ffi::create_engine(model_path)),
|
||||||
|
tokenizer: Tokenizer::from_file(tokenizer_path).unwrap(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String {
|
||||||
|
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
||||||
|
let output_tokens = self.engine.lock().unwrap().inference(
|
||||||
|
encoding.get_tokens(),
|
||||||
|
options.max_decoding_length,
|
||||||
|
options.sampling_temperature,
|
||||||
|
options.beam_size,
|
||||||
|
);
|
||||||
|
|
||||||
|
let model = self.tokenizer.get_model();
|
||||||
|
let output_ids: Vec<u32> = output_tokens
|
||||||
|
.iter()
|
||||||
|
.map(|x| model.token_to_id(x).unwrap())
|
||||||
|
.collect();
|
||||||
|
self.tokenizer.decode(output_ids, true).unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
/target
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,29 @@
|
||||||
|
[package]
|
||||||
|
name = "tabby"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
axum = "0.6"
|
||||||
|
hyper = { version = "0.14", features = ["full"] }
|
||||||
|
tokio = { version = "1.17", features = ["full"] }
|
||||||
|
tower = "0.4"
|
||||||
|
utoipa = { version = "3.3", features = ["axum_extras", "preserve_order"] }
|
||||||
|
utoipa-swagger-ui = { version = "3.1", features = ["axum"] }
|
||||||
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
|
serde_json = "1.0"
|
||||||
|
env_logger = "0.10.0"
|
||||||
|
log = "0.4"
|
||||||
|
ctranslate2-bindings = { path = "../ctranslate2-bindings" }
|
||||||
|
tower-http = { version = "0.4.0", features = ["cors"] }
|
||||||
|
clap = { version = "4.3.0", features = ["derive"] }
|
||||||
|
regex = "1.8.3"
|
||||||
|
lazy_static = "1.4.0"
|
||||||
|
|
||||||
|
[dependencies.uuid]
|
||||||
|
version = "1.3.3"
|
||||||
|
features = [
|
||||||
|
"v4", # Lets you generate random UUIDs
|
||||||
|
"fast-rng", # Use a faster (but still sufficiently random) RNG
|
||||||
|
"macro-diagnostics", # Enable better diagnostics for compile-time UUIDs
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
use clap::{Parser, Subcommand};
|
||||||
|
|
||||||
|
#[derive(Parser)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
#[command(propagate_version = true)]
|
||||||
|
struct Cli {
|
||||||
|
#[command(subcommand)]
|
||||||
|
command: Commands,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Subcommand)]
|
||||||
|
pub enum Commands {
|
||||||
|
/// Serve the model
|
||||||
|
Serve {
|
||||||
|
/// path to model for serving
|
||||||
|
#[clap(long)]
|
||||||
|
model: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mod serve;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
let cli = Cli::parse();
|
||||||
|
|
||||||
|
// You can check for the existence of subcommands, and if found use their
|
||||||
|
// matches just as you would the top level cmd
|
||||||
|
match &cli.command {
|
||||||
|
Commands::Serve { model } => {
|
||||||
|
serve::main(model)
|
||||||
|
.await
|
||||||
|
.expect("Error happens during the serve");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,81 @@
|
||||||
|
use axum::{extract::State, Json};
|
||||||
|
use ctranslate2_bindings::{TextInferenceEngine, TextInferenceOptionsBuilder};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::{path::Path, sync::Arc};
|
||||||
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
|
mod languages;
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
|
pub struct CompletionRequest {
|
||||||
|
/// https://code.visualstudio.com/docs/languages/identifiers
|
||||||
|
#[schema(example = "python")]
|
||||||
|
language: String,
|
||||||
|
|
||||||
|
#[schema(example = "def fib(n):")]
|
||||||
|
prompt: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
|
pub struct Choice {
|
||||||
|
index: u32,
|
||||||
|
text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
|
pub struct CompletionResponse {
|
||||||
|
id: String,
|
||||||
|
created: u64,
|
||||||
|
choices: Vec<Choice>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[utoipa::path(
|
||||||
|
post,
|
||||||
|
path = "/v1/completions",
|
||||||
|
request_body = CompletionRequest ,
|
||||||
|
)]
|
||||||
|
pub async fn completion(
|
||||||
|
State(state): State<Arc<CompletionState>>,
|
||||||
|
Json(request): Json<CompletionRequest>,
|
||||||
|
) -> Json<CompletionResponse> {
|
||||||
|
let options = TextInferenceOptionsBuilder::default()
|
||||||
|
.max_decoding_length(64)
|
||||||
|
.sampling_temperature(0.2)
|
||||||
|
.build()
|
||||||
|
.unwrap();
|
||||||
|
let text = state.engine.inference(&request.prompt, options);
|
||||||
|
let filtered_text = languages::remove_stop_words(&request.language, &text);
|
||||||
|
|
||||||
|
Json(CompletionResponse {
|
||||||
|
id: format!("cmpl-{}", uuid::Uuid::new_v4()),
|
||||||
|
created: timestamp(),
|
||||||
|
choices: [Choice {
|
||||||
|
index: 0,
|
||||||
|
text: filtered_text.to_string(),
|
||||||
|
}]
|
||||||
|
.to_vec(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CompletionState {
|
||||||
|
engine: TextInferenceEngine,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompletionState {
|
||||||
|
pub fn new(model: &str) -> Self {
|
||||||
|
let engine = TextInferenceEngine::create(
|
||||||
|
Path::new(model).join("cpu").to_str().unwrap(),
|
||||||
|
Path::new(model).join("tokenizer.json").to_str().unwrap(),
|
||||||
|
);
|
||||||
|
return Self { engine: engine };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn timestamp() -> u64 {
|
||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
let start = SystemTime::now();
|
||||||
|
start
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.expect("Time went backwards")
|
||||||
|
.as_secs()
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
use lazy_static::lazy_static;
|
||||||
|
use regex::Regex;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
lazy_static! {
|
||||||
|
static ref DEFAULT: Regex = Regex::new(r"(?m)^\n\n").unwrap();
|
||||||
|
static ref LANGUAGES: HashMap<&'static str, Regex> = {
|
||||||
|
let mut map = HashMap::new();
|
||||||
|
map.insert(
|
||||||
|
"python",
|
||||||
|
Regex::new(r"(?m)^(\n\n|def|#|from|class)").unwrap(),
|
||||||
|
);
|
||||||
|
map
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn remove_stop_words<'a>(language: &'a str, text: &'a str) -> &'a str {
|
||||||
|
let re = LANGUAGES.get(language).unwrap_or(&DEFAULT);
|
||||||
|
let position = re.find_iter(&text).next();
|
||||||
|
if let Some(m) = position {
|
||||||
|
&text[..m.start()]
|
||||||
|
} else {
|
||||||
|
&text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,22 @@
|
||||||
|
use axum::Json;
|
||||||
|
use hyper::StatusCode;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
|
pub struct LogEventRequest {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
event_type: String,
|
||||||
|
completion_id: String,
|
||||||
|
choice_index: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[utoipa::path(
|
||||||
|
post,
|
||||||
|
path = "/v1/events",
|
||||||
|
request_body = LogEventRequest,
|
||||||
|
)]
|
||||||
|
pub async fn log_event(Json(request): Json<LogEventRequest>) -> StatusCode {
|
||||||
|
println!("log_event: {:?}", request);
|
||||||
|
StatusCode::OK
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
use std::{
|
||||||
|
net::{Ipv4Addr, SocketAddr},
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
|
|
||||||
|
use axum::{response::Redirect, routing, Router, Server};
|
||||||
|
use hyper::Error;
|
||||||
|
use tower_http::cors::CorsLayer;
|
||||||
|
use utoipa::OpenApi;
|
||||||
|
use utoipa_swagger_ui::SwaggerUi;
|
||||||
|
|
||||||
|
mod completions;
|
||||||
|
mod events;
|
||||||
|
|
||||||
|
#[derive(OpenApi)]
|
||||||
|
#[openapi(
|
||||||
|
paths(events::log_event, completions::completion,),
|
||||||
|
components(schemas(
|
||||||
|
events::LogEventRequest,
|
||||||
|
completions::CompletionRequest,
|
||||||
|
completions::CompletionResponse,
|
||||||
|
completions::Choice
|
||||||
|
))
|
||||||
|
)]
|
||||||
|
struct ApiDoc;
|
||||||
|
|
||||||
|
pub async fn main(model: &str) -> Result<(), Error> {
|
||||||
|
let completions_state = Arc::new(completions::CompletionState::new(model));
|
||||||
|
|
||||||
|
let app = Router::new()
|
||||||
|
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
|
||||||
|
.route("/v1/events", routing::post(events::log_event))
|
||||||
|
.route("/v1/completions", routing::post(completions::completion))
|
||||||
|
.with_state(completions_state)
|
||||||
|
.route(
|
||||||
|
"/",
|
||||||
|
routing::get(|| async { Redirect::temporary("/swagger-ui") }),
|
||||||
|
)
|
||||||
|
.layer(CorsLayer::permissive());
|
||||||
|
|
||||||
|
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, 8080));
|
||||||
|
println!("Listening at {}", address);
|
||||||
|
Server::bind(&address).serve(app.into_make_service()).await
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue