feat: support cuda devices in rust tabby (#149)
parent
484e754bc4
commit
8dfe49ec6c
|
|
@ -29,7 +29,10 @@ fn main() {
|
||||||
|
|
||||||
let dst = config.build();
|
let dst = config.build();
|
||||||
|
|
||||||
println!("cargo:rustc-link-search=native={}", dst.join("lib").display());
|
println!(
|
||||||
|
"cargo:rustc-link-search=native={}",
|
||||||
|
dst.join("lib").display()
|
||||||
|
);
|
||||||
println!("cargo:rustc-link-lib=ctranslate2");
|
println!("cargo:rustc-link-lib=ctranslate2");
|
||||||
|
|
||||||
// Tell cargo to invalidate the built crate whenever the wrapper changes
|
// Tell cargo to invalidate the built crate whenever the wrapper changes
|
||||||
|
|
|
||||||
|
|
@ -16,5 +16,10 @@ class TextInferenceEngine {
|
||||||
) const = 0;
|
) const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
|
std::unique_ptr<TextInferenceEngine> create_engine(
|
||||||
|
rust::Str model_path,
|
||||||
|
rust::Str device,
|
||||||
|
rust::Slice<const int32_t> device_indices,
|
||||||
|
size_t num_replicas_per_device
|
||||||
|
);
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
||||||
|
|
@ -7,10 +7,7 @@ TextInferenceEngine::~TextInferenceEngine() {}
|
||||||
|
|
||||||
class TextInferenceEngineImpl : public TextInferenceEngine {
|
class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
public:
|
public:
|
||||||
TextInferenceEngineImpl(const std::string& model_path) {
|
TextInferenceEngineImpl(std::unique_ptr<ctranslate2::Translator> translator) : translator_(std::move(translator)) {}
|
||||||
ctranslate2::models::ModelLoader loader(model_path);
|
|
||||||
translator_ = std::make_unique<ctranslate2::Translator>(loader);
|
|
||||||
}
|
|
||||||
|
|
||||||
~TextInferenceEngineImpl() {}
|
~TextInferenceEngineImpl() {}
|
||||||
|
|
||||||
|
|
@ -41,7 +38,32 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
std::unique_ptr<ctranslate2::Translator> translator_;
|
std::unique_ptr<ctranslate2::Translator> translator_;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
|
std::unique_ptr<TextInferenceEngine> create_engine(
|
||||||
return std::make_unique<TextInferenceEngineImpl>(std::string(model_path));
|
rust::Str model_path,
|
||||||
|
rust::Str device,
|
||||||
|
rust::Slice<const int32_t> device_indices,
|
||||||
|
size_t num_replicas_per_device
|
||||||
|
) {
|
||||||
|
// model_path.
|
||||||
|
std::string model_path_string(model_path);
|
||||||
|
ctranslate2::models::ModelLoader loader(model_path_string);
|
||||||
|
|
||||||
|
// device.
|
||||||
|
std::string device_string(device);
|
||||||
|
if (device_string == "cuda") {
|
||||||
|
loader.device = ctranslate2::Device::CUDA;
|
||||||
|
} else if (device_string == "cpu") {
|
||||||
|
loader.device = ctranslate2::Device::CPU;
|
||||||
|
}
|
||||||
|
|
||||||
|
// device_indices
|
||||||
|
loader.device_indices.clear();
|
||||||
|
std::copy(device_indices.begin(), device_indices.end(), std::back_inserter(loader.device_indices));
|
||||||
|
|
||||||
|
// num_replicas_per_device
|
||||||
|
loader.num_replicas_per_device = num_replicas_per_device;
|
||||||
|
|
||||||
|
auto translator = std::make_unique<ctranslate2::Translator>(loader);
|
||||||
|
return std::make_unique<TextInferenceEngineImpl>(std::move(translator));
|
||||||
}
|
}
|
||||||
} // namespace tabby
|
} // namespace tabby
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,13 @@ mod ffi {
|
||||||
|
|
||||||
type TextInferenceEngine;
|
type TextInferenceEngine;
|
||||||
|
|
||||||
fn create_engine(model_path: &str) -> UniquePtr<TextInferenceEngine>;
|
fn create_engine(
|
||||||
|
model_path: &str,
|
||||||
|
device: &str,
|
||||||
|
device_indices: &[i32],
|
||||||
|
num_replicas_per_device: usize,
|
||||||
|
) -> UniquePtr<TextInferenceEngine>;
|
||||||
|
|
||||||
fn inference(
|
fn inference(
|
||||||
&self,
|
&self,
|
||||||
tokens: &[String],
|
tokens: &[String],
|
||||||
|
|
@ -22,6 +28,19 @@ mod ffi {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Builder)]
|
||||||
|
pub struct TextInferenceEngineCreateOptions {
|
||||||
|
model_path: String,
|
||||||
|
|
||||||
|
tokenizer_path: String,
|
||||||
|
|
||||||
|
device: String,
|
||||||
|
|
||||||
|
device_indices: Vec<i32>,
|
||||||
|
|
||||||
|
num_replicas_per_device: usize,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Builder, Debug)]
|
#[derive(Builder, Debug)]
|
||||||
pub struct TextInferenceOptions {
|
pub struct TextInferenceOptions {
|
||||||
#[builder(default = "256")]
|
#[builder(default = "256")]
|
||||||
|
|
@ -43,10 +62,16 @@ unsafe impl Send for TextInferenceEngine {}
|
||||||
unsafe impl Sync for TextInferenceEngine {}
|
unsafe impl Sync for TextInferenceEngine {}
|
||||||
|
|
||||||
impl TextInferenceEngine {
|
impl TextInferenceEngine {
|
||||||
pub fn create(model_path: &str, tokenizer_path: &str) -> Self where {
|
pub fn create(options: TextInferenceEngineCreateOptions) -> Self where {
|
||||||
|
let engine = ffi::create_engine(
|
||||||
|
&options.model_path,
|
||||||
|
&options.device,
|
||||||
|
&options.device_indices,
|
||||||
|
options.num_replicas_per_device,
|
||||||
|
);
|
||||||
return TextInferenceEngine {
|
return TextInferenceEngine {
|
||||||
engine: Mutex::new(ffi::create_engine(model_path)),
|
engine: Mutex::new(engine),
|
||||||
tokenizer: Tokenizer::from_file(tokenizer_path).unwrap(),
|
tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,11 +11,7 @@ struct Cli {
|
||||||
#[derive(Subcommand)]
|
#[derive(Subcommand)]
|
||||||
pub enum Commands {
|
pub enum Commands {
|
||||||
/// Serve the model
|
/// Serve the model
|
||||||
Serve {
|
Serve(serve::ServeArgs),
|
||||||
/// path to model for serving
|
|
||||||
#[clap(long)]
|
|
||||||
model: String,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mod serve;
|
mod serve;
|
||||||
|
|
@ -27,8 +23,8 @@ async fn main() {
|
||||||
// You can check for the existence of subcommands, and if found use their
|
// You can check for the existence of subcommands, and if found use their
|
||||||
// matches just as you would the top level cmd
|
// matches just as you would the top level cmd
|
||||||
match &cli.command {
|
match &cli.command {
|
||||||
Commands::Serve { model } => {
|
Commands::Serve(args) => {
|
||||||
serve::main(model)
|
serve::main(args)
|
||||||
.await
|
.await
|
||||||
.expect("Error happens during the serve");
|
.expect("Error happens during the serve");
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
use axum::{extract::State, Json};
|
use axum::{extract::State, Json};
|
||||||
use ctranslate2_bindings::{TextInferenceEngine, TextInferenceOptionsBuilder};
|
use ctranslate2_bindings::{
|
||||||
|
TextInferenceEngine, TextInferenceEngineCreateOptions, TextInferenceOptionsBuilder,
|
||||||
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::{path::Path, sync::Arc};
|
use std::sync::Arc;
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
mod languages;
|
mod languages;
|
||||||
|
|
@ -62,11 +64,8 @@ pub struct CompletionState {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CompletionState {
|
impl CompletionState {
|
||||||
pub fn new(model: &str) -> Self {
|
pub fn new(options: TextInferenceEngineCreateOptions) -> Self {
|
||||||
let engine = TextInferenceEngine::create(
|
let engine = TextInferenceEngine::create(options);
|
||||||
Path::new(model).join("cpu").to_str().unwrap(),
|
|
||||||
Path::new(model).join("tokenizer.json").to_str().unwrap(),
|
|
||||||
);
|
|
||||||
return Self { engine: engine };
|
return Self { engine: engine };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,10 @@ use std::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use axum::{response::Redirect, routing, Router, Server};
|
use axum::{response::Redirect, routing, Router, Server};
|
||||||
|
use clap::Args;
|
||||||
|
use ctranslate2_bindings::TextInferenceEngineCreateOptionsBuilder;
|
||||||
use hyper::Error;
|
use hyper::Error;
|
||||||
|
use std::path::Path;
|
||||||
use tower_http::cors::CorsLayer;
|
use tower_http::cors::CorsLayer;
|
||||||
use utoipa::OpenApi;
|
use utoipa::OpenApi;
|
||||||
use utoipa_swagger_ui::SwaggerUi;
|
use utoipa_swagger_ui::SwaggerUi;
|
||||||
|
|
@ -24,8 +27,60 @@ mod events;
|
||||||
)]
|
)]
|
||||||
struct ApiDoc;
|
struct ApiDoc;
|
||||||
|
|
||||||
pub async fn main(model: &str) -> Result<(), Error> {
|
#[derive(clap::ValueEnum, Clone)]
|
||||||
let completions_state = Arc::new(completions::CompletionState::new(model));
|
pub enum Device {
|
||||||
|
CPU,
|
||||||
|
CUDA,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for Device {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
let printable = match *self {
|
||||||
|
Device::CPU => "cpu",
|
||||||
|
Device::CUDA => "cuda",
|
||||||
|
};
|
||||||
|
write!(f, "{}", printable)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Args)]
|
||||||
|
pub struct ServeArgs {
|
||||||
|
/// path to model for serving
|
||||||
|
#[clap(long)]
|
||||||
|
model: String,
|
||||||
|
|
||||||
|
#[clap(long, default_value_t=Device::CPU)]
|
||||||
|
device: Device,
|
||||||
|
|
||||||
|
#[clap(long, default_values_t=[0])]
|
||||||
|
device_indices: Vec<i32>,
|
||||||
|
|
||||||
|
/// num_replicas_per_device
|
||||||
|
#[clap(long, default_value_t = 1)]
|
||||||
|
num_replicas_per_device: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn main(args: &ServeArgs) -> Result<(), Error> {
|
||||||
|
let device = format!("{}", args.device);
|
||||||
|
let options = TextInferenceEngineCreateOptionsBuilder::default()
|
||||||
|
.model_path(
|
||||||
|
Path::new(&args.model)
|
||||||
|
.join(device.clone())
|
||||||
|
.display()
|
||||||
|
.to_string(),
|
||||||
|
)
|
||||||
|
.tokenizer_path(
|
||||||
|
Path::new(&args.model)
|
||||||
|
.join("tokenizer.json")
|
||||||
|
.display()
|
||||||
|
.to_string(),
|
||||||
|
)
|
||||||
|
.device(device)
|
||||||
|
.device_indices(args.device_indices.clone())
|
||||||
|
.num_replicas_per_device(args.num_replicas_per_device)
|
||||||
|
.build()
|
||||||
|
.unwrap();
|
||||||
|
let completions_state = Arc::new(completions::CompletionState::new(options));
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
|
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue