feat: support cuda devices in rust tabby (#149)

add-tracing
Meng Zhang 2023-05-25 23:23:07 -07:00 committed by GitHub
parent 484e754bc4
commit 8dfe49ec6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 133 additions and 28 deletions

View File

@ -29,7 +29,10 @@ fn main() {
let dst = config.build();
println!("cargo:rustc-link-search=native={}", dst.join("lib").display());
println!(
"cargo:rustc-link-search=native={}",
dst.join("lib").display()
);
println!("cargo:rustc-link-lib=ctranslate2");
// Tell cargo to invalidate the built crate whenever the wrapper changes

View File

@ -16,5 +16,10 @@ class TextInferenceEngine {
) const = 0;
};
std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
std::unique_ptr<TextInferenceEngine> create_engine(
rust::Str model_path,
rust::Str device,
rust::Slice<const int32_t> device_indices,
size_t num_replicas_per_device
);
} // namespace

View File

@ -7,10 +7,7 @@ TextInferenceEngine::~TextInferenceEngine() {}
class TextInferenceEngineImpl : public TextInferenceEngine {
public:
TextInferenceEngineImpl(const std::string& model_path) {
ctranslate2::models::ModelLoader loader(model_path);
translator_ = std::make_unique<ctranslate2::Translator>(loader);
}
TextInferenceEngineImpl(std::unique_ptr<ctranslate2::Translator> translator) : translator_(std::move(translator)) {}
~TextInferenceEngineImpl() {}
@ -41,7 +38,32 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
std::unique_ptr<ctranslate2::Translator> translator_;
};
std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
return std::make_unique<TextInferenceEngineImpl>(std::string(model_path));
// Build a TextInferenceEngine backed by a ctranslate2::Translator.
//
// model_path              - filesystem path to the ctranslate2 model directory.
// device                  - "cpu" or "cuda"; any other value is silently
//                           ignored and the loader's default device is kept.
//                           NOTE(review): consider rejecting unknown strings
//                           explicitly instead of falling through.
// device_indices          - device ordinals to load replicas onto.
// num_replicas_per_device - number of translator replicas per device.
std::unique_ptr<TextInferenceEngine> create_engine(
rust::Str model_path,
rust::Str device,
rust::Slice<const int32_t> device_indices,
size_t num_replicas_per_device
) {
// Copy the borrowed rust::Str view into an owned std::string for the loader.
std::string model_path_string(model_path);
ctranslate2::models::ModelLoader loader(model_path_string);
// Map the device name onto the ctranslate2 device enum.
std::string device_string(device);
if (device_string == "cuda") {
loader.device = ctranslate2::Device::CUDA;
} else if (device_string == "cpu") {
loader.device = ctranslate2::Device::CPU;
}
// Replace the loader's default indices with the caller-supplied ones
// (rust::Slice is not a std container, hence the explicit copy).
loader.device_indices.clear();
std::copy(device_indices.begin(), device_indices.end(), std::back_inserter(loader.device_indices));
// Fan-out factor: replicas created per listed device.
loader.num_replicas_per_device = num_replicas_per_device;
auto translator = std::make_unique<ctranslate2::Translator>(loader);
return std::make_unique<TextInferenceEngineImpl>(std::move(translator));
}
} // namespace tabby

View File

@ -11,7 +11,13 @@ mod ffi {
type TextInferenceEngine;
fn create_engine(model_path: &str) -> UniquePtr<TextInferenceEngine>;
fn create_engine(
model_path: &str,
device: &str,
device_indices: &[i32],
num_replicas_per_device: usize,
) -> UniquePtr<TextInferenceEngine>;
fn inference(
&self,
tokens: &[String],
@ -22,6 +28,19 @@ mod ffi {
}
}
// Construction parameters for a TextInferenceEngine. Instances are assembled
// via the derive_builder-generated TextInferenceEngineCreateOptionsBuilder.
#[derive(Builder)]
pub struct TextInferenceEngineCreateOptions {
// Path to the ctranslate2 model directory (device-specific subdirectory).
model_path: String,
// Path to the tokenizer.json file loaded with Tokenizer::from_file.
tokenizer_path: String,
// Device name forwarded to the native engine ("cpu" or "cuda").
device: String,
// Device ordinals to load model replicas onto.
device_indices: Vec<i32>,
// Number of engine replicas created per device.
num_replicas_per_device: usize,
}
#[derive(Builder, Debug)]
pub struct TextInferenceOptions {
#[builder(default = "256")]
@ -43,10 +62,16 @@ unsafe impl Send for TextInferenceEngine {}
unsafe impl Sync for TextInferenceEngine {}
impl TextInferenceEngine {
pub fn create(model_path: &str, tokenizer_path: &str) -> Self where {
pub fn create(options: TextInferenceEngineCreateOptions) -> Self where {
let engine = ffi::create_engine(
&options.model_path,
&options.device,
&options.device_indices,
options.num_replicas_per_device,
);
return TextInferenceEngine {
engine: Mutex::new(ffi::create_engine(model_path)),
tokenizer: Tokenizer::from_file(tokenizer_path).unwrap(),
engine: Mutex::new(engine),
tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(),
};
}

View File

@ -11,11 +11,7 @@ struct Cli {
// Top-level CLI subcommands.
#[derive(Subcommand)]
pub enum Commands {
/// Serve the model
// NOTE(review): the braced `Serve { model }` variant below duplicates the
// tuple `Serve(serve::ServeArgs)` variant — this looks like stale pre-change
// lines from a diff with its -/+ markers stripped; only the ServeArgs-based
// variant should remain.
Serve {
/// path to model for serving
#[clap(long)]
model: String,
},
Serve(serve::ServeArgs),
}
mod serve;
@ -27,8 +23,8 @@ async fn main() {
// You can check for the existence of subcommands, and if found use their
// matches just as you would the top level cmd
match &cli.command {
Commands::Serve { model } => {
serve::main(model)
Commands::Serve(args) => {
serve::main(args)
.await
.expect("Error happens during the serve");
}

View File

@ -1,7 +1,9 @@
use axum::{extract::State, Json};
use ctranslate2_bindings::{TextInferenceEngine, TextInferenceOptionsBuilder};
use ctranslate2_bindings::{
TextInferenceEngine, TextInferenceEngineCreateOptions, TextInferenceOptionsBuilder,
};
use serde::{Deserialize, Serialize};
use std::{path::Path, sync::Arc};
use std::sync::Arc;
use utoipa::ToSchema;
mod languages;
@ -62,11 +64,8 @@ pub struct CompletionState {
}
impl CompletionState {
// NOTE(review): two `new` constructors appear interleaved below — the
// old path-pair version and the new options-based version. This is diff
// residue (-/+ markers stripped); only the options-based constructor
// should remain in the real file.
pub fn new(model: &str) -> Self {
let engine = TextInferenceEngine::create(
Path::new(model).join("cpu").to_str().unwrap(),
Path::new(model).join("tokenizer.json").to_str().unwrap(),
);
// Build the shared inference engine from fully-specified create options.
pub fn new(options: TextInferenceEngineCreateOptions) -> Self {
let engine = TextInferenceEngine::create(options);
return Self { engine: engine };
}
}

View File

@ -4,7 +4,10 @@ use std::{
};
use axum::{response::Redirect, routing, Router, Server};
use clap::Args;
use ctranslate2_bindings::TextInferenceEngineCreateOptionsBuilder;
use hyper::Error;
use std::path::Path;
use tower_http::cors::CorsLayer;
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
@ -24,8 +27,60 @@ mod events;
)]
struct ApiDoc;
pub async fn main(model: &str) -> Result<(), Error> {
let completions_state = Arc::new(completions::CompletionState::new(model));
// Compute device selectable on the command line (--device); rendered to the
// lowercase string the native engine expects ("cpu"/"cuda") by the Display
// impl. NOTE(review): variant names are all-caps rather than the conventional
// Cpu/Cuda, but renaming would change the CLI-visible ValueEnum names.
#[derive(clap::ValueEnum, Clone)]
pub enum Device {
CPU,
CUDA,
}
// Render the lowercase device name ("cpu"/"cuda") that the inference
// backend and the model-path layout expect.
impl std::fmt::Display for Device {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(match self {
            Device::CPU => "cpu",
            Device::CUDA => "cuda",
        })
    }
}
// Arguments for the `serve` subcommand, parsed by clap's derive API.
// NOTE(review): field documentation below uses `//` rather than `///` on
// purpose — clap turns `///` doc comments into --help text, which would
// change the program's user-visible output.
#[derive(Args)]
pub struct ServeArgs {
/// path to model for serving
#[clap(long)]
model: String,
// Compute device to run on; rendered lowercase ("cpu"/"cuda") via the
// Display impl and forwarded to the native engine.
#[clap(long, default_value_t=Device::CPU)]
device: Device,
// Device ordinals to load replicas onto; defaults to device 0.
#[clap(long, default_values_t=[0])]
device_indices: Vec<i32>,
/// num_replicas_per_device
#[clap(long, default_value_t = 1)]
num_replicas_per_device: usize,
}
pub async fn main(args: &ServeArgs) -> Result<(), Error> {
let device = format!("{}", args.device);
let options = TextInferenceEngineCreateOptionsBuilder::default()
.model_path(
Path::new(&args.model)
.join(device.clone())
.display()
.to_string(),
)
.tokenizer_path(
Path::new(&args.model)
.join("tokenizer.json")
.display()
.to_string(),
)
.device(device)
.device_indices(args.device_indices.clone())
.num_replicas_per_device(args.num_replicas_per_device)
.build()
.unwrap();
let completions_state = Arc::new(completions::CompletionState::new(options));
let app = Router::new()
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))