From 8dfe49ec6c84fc3b8bc83e1b5621a3860eda34e1 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Thu, 25 May 2023 23:23:07 -0700 Subject: [PATCH] feat: support cuda devices in rust tabby (#149) --- crates/ctranslate2-bindings/build.rs | 5 +- .../include/ctranslate2.h | 7 ++- .../ctranslate2-bindings/src/ctranslate2.cc | 34 +++++++++-- crates/ctranslate2-bindings/src/lib.rs | 33 +++++++++-- crates/tabby/src/main.rs | 10 +--- crates/tabby/src/serve/completions.rs | 13 ++-- crates/tabby/src/serve/mod.rs | 59 ++++++++++++++++++- 7 files changed, 133 insertions(+), 28 deletions(-) diff --git a/crates/ctranslate2-bindings/build.rs b/crates/ctranslate2-bindings/build.rs index b725305..916e1fd 100644 --- a/crates/ctranslate2-bindings/build.rs +++ b/crates/ctranslate2-bindings/build.rs @@ -29,7 +29,10 @@ fn main() { let dst = config.build(); - println!("cargo:rustc-link-search=native={}", dst.join("lib").display()); + println!( + "cargo:rustc-link-search=native={}", + dst.join("lib").display() + ); println!("cargo:rustc-link-lib=ctranslate2"); // Tell cargo to invalidate the built crate whenever the wrapper changes diff --git a/crates/ctranslate2-bindings/include/ctranslate2.h b/crates/ctranslate2-bindings/include/ctranslate2.h index 2cb0df6..3ed8f1b 100644 --- a/crates/ctranslate2-bindings/include/ctranslate2.h +++ b/crates/ctranslate2-bindings/include/ctranslate2.h @@ -16,5 +16,10 @@ class TextInferenceEngine { ) const = 0; }; -std::unique_ptr create_engine(rust::Str model_path); +std::unique_ptr create_engine( + rust::Str model_path, + rust::Str device, + rust::Slice device_indices, + size_t num_replicas_per_device +); } // namespace diff --git a/crates/ctranslate2-bindings/src/ctranslate2.cc b/crates/ctranslate2-bindings/src/ctranslate2.cc index 99206bc..61dffb6 100644 --- a/crates/ctranslate2-bindings/src/ctranslate2.cc +++ b/crates/ctranslate2-bindings/src/ctranslate2.cc @@ -7,10 +7,7 @@ TextInferenceEngine::~TextInferenceEngine() {} class TextInferenceEngineImpl : public TextInferenceEngine { public: - TextInferenceEngineImpl(const std::string& model_path) { - ctranslate2::models::ModelLoader loader(model_path); - translator_ = std::make_unique(loader); - } + TextInferenceEngineImpl(std::unique_ptr translator) : translator_(std::move(translator)) {} ~TextInferenceEngineImpl() {} @@ -41,7 +38,32 @@ class TextInferenceEngineImpl : public TextInferenceEngine { std::unique_ptr translator_; }; -std::unique_ptr create_engine(rust::Str model_path) { - return std::make_unique(std::string(model_path)); +std::unique_ptr create_engine( + rust::Str model_path, + rust::Str device, + rust::Slice device_indices, + size_t num_replicas_per_device +) { + // model_path. + std::string model_path_string(model_path); + ctranslate2::models::ModelLoader loader(model_path_string); + + // device. + std::string device_string(device); + if (device_string == "cuda") { + loader.device = ctranslate2::Device::CUDA; + } else if (device_string == "cpu") { + loader.device = ctranslate2::Device::CPU; + } + + // device_indices + loader.device_indices.clear(); + std::copy(device_indices.begin(), device_indices.end(), std::back_inserter(loader.device_indices)); + + // num_replicas_per_device + loader.num_replicas_per_device = num_replicas_per_device; + + auto translator = std::make_unique(loader); + return std::make_unique(std::move(translator)); } } // namespace tabby diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs index 5b59a7f..6a6ce74 100644 --- a/crates/ctranslate2-bindings/src/lib.rs +++ b/crates/ctranslate2-bindings/src/lib.rs @@ -11,7 +11,13 @@ mod ffi { type TextInferenceEngine; - fn create_engine(model_path: &str) -> UniquePtr; + fn create_engine( + model_path: &str, + device: &str, + device_indices: &[i32], + num_replicas_per_device: usize, + ) -> UniquePtr; + fn inference( &self, tokens: &[String], @@ -22,6 +28,19 @@ mod ffi { } } +#[derive(Builder)] +pub struct TextInferenceEngineCreateOptions { + model_path: String, + + tokenizer_path: String, + + device: String, + + device_indices: Vec, + + num_replicas_per_device: usize, +} + #[derive(Builder, Debug)] pub struct TextInferenceOptions { #[builder(default = "256")] @@ -43,10 +62,16 @@ unsafe impl Send for TextInferenceEngine {} unsafe impl Sync for TextInferenceEngine {} impl TextInferenceEngine { - pub fn create(model_path: &str, tokenizer_path: &str) -> Self where { + pub fn create(options: TextInferenceEngineCreateOptions) -> Self where { + let engine = ffi::create_engine( + &options.model_path, + &options.device, + &options.device_indices, + options.num_replicas_per_device, + ); return TextInferenceEngine { - engine: Mutex::new(ffi::create_engine(model_path)), - tokenizer: Tokenizer::from_file(tokenizer_path).unwrap(), + engine: Mutex::new(engine), + tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(), }; } diff --git a/crates/tabby/src/main.rs b/crates/tabby/src/main.rs index 4fcf593..7e3605e 100644 --- a/crates/tabby/src/main.rs +++ b/crates/tabby/src/main.rs @@ -11,11 +11,7 @@ struct Cli { #[derive(Subcommand)] pub enum Commands { /// Serve the model - Serve { - /// path to model for serving - #[clap(long)] - model: String, - }, + Serve(serve::ServeArgs), } mod serve; @@ -27,8 +23,8 @@ async fn main() { // You can check for the existence of subcommands, and if found use their // matches just as you would the top level cmd match &cli.command { - Commands::Serve { model } => { - serve::main(model) + Commands::Serve(args) => { + serve::main(args) .await .expect("Error happens during the serve"); } diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index fb0326c..125ac26 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -1,7 +1,9 @@ use axum::{extract::State, Json}; -use ctranslate2_bindings::{TextInferenceEngine, TextInferenceOptionsBuilder}; +use ctranslate2_bindings::{ + TextInferenceEngine, TextInferenceEngineCreateOptions, TextInferenceOptionsBuilder, +}; use serde::{Deserialize, Serialize}; -use std::{path::Path, sync::Arc}; +use std::sync::Arc; use utoipa::ToSchema; mod languages; @@ -62,11 +64,8 @@ pub struct CompletionState { } impl CompletionState { - pub fn new(model: &str) -> Self { - let engine = TextInferenceEngine::create( - Path::new(model).join("cpu").to_str().unwrap(), - Path::new(model).join("tokenizer.json").to_str().unwrap(), - ); + pub fn new(options: TextInferenceEngineCreateOptions) -> Self { + let engine = TextInferenceEngine::create(options); return Self { engine: engine }; } } diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index 36dda4f..7a8a9c1 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -4,7 +4,10 @@ use std::{ }; use axum::{response::Redirect, routing, Router, Server}; +use clap::Args; +use ctranslate2_bindings::TextInferenceEngineCreateOptionsBuilder; use hyper::Error; +use std::path::Path; use tower_http::cors::CorsLayer; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; @@ -24,8 +27,60 @@ mod events; )] struct ApiDoc; -pub async fn main(model: &str) -> Result<(), Error> { - let completions_state = Arc::new(completions::CompletionState::new(model)); +#[derive(clap::ValueEnum, Clone)] +pub enum Device { + CPU, + CUDA, +} + +impl std::fmt::Display for Device { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let printable = match *self { + Device::CPU => "cpu", + Device::CUDA => "cuda", + }; + write!(f, "{}", printable) + } +} + +#[derive(Args)] +pub struct ServeArgs { + /// path to model for serving + #[clap(long)] + model: String, + + #[clap(long, default_value_t=Device::CPU)] + device: Device, + + #[clap(long, default_values_t=[0])] + device_indices: Vec, + + /// num_replicas_per_device + #[clap(long, default_value_t = 1)] + num_replicas_per_device: usize, +} + +pub async fn main(args: &ServeArgs) -> Result<(), Error> { + let device = format!("{}", args.device); + let options = TextInferenceEngineCreateOptionsBuilder::default() + .model_path( + Path::new(&args.model) + .join(device.clone()) + .display() + .to_string(), + ) + .tokenizer_path( + Path::new(&args.model) + .join("tokenizer.json") + .display() + .to_string(), + ) + .device(device) + .device_indices(args.device_indices.clone()) + .num_replicas_per_device(args.num_replicas_per_device) + .build() + .unwrap(); + let completions_state = Arc::new(completions::CompletionState::new(options)); let app = Router::new() .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))