feat: support cuda devices in rust tabby (#149)
parent
484e754bc4
commit
8dfe49ec6c
|
|
@ -29,7 +29,10 @@ fn main() {
|
|||
|
||||
let dst = config.build();
|
||||
|
||||
println!("cargo:rustc-link-search=native={}", dst.join("lib").display());
|
||||
println!(
|
||||
"cargo:rustc-link-search=native={}",
|
||||
dst.join("lib").display()
|
||||
);
|
||||
println!("cargo:rustc-link-lib=ctranslate2");
|
||||
|
||||
// Tell cargo to invalidate the built crate whenever the wrapper changes
|
||||
|
|
|
|||
|
|
@ -16,5 +16,10 @@ class TextInferenceEngine {
|
|||
) const = 0;
|
||||
};
|
||||
|
||||
std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
|
||||
std::unique_ptr<TextInferenceEngine> create_engine(
|
||||
rust::Str model_path,
|
||||
rust::Str device,
|
||||
rust::Slice<const int32_t> device_indices,
|
||||
size_t num_replicas_per_device
|
||||
);
|
||||
} // namespace
|
||||
|
|
|
|||
|
|
@ -7,10 +7,7 @@ TextInferenceEngine::~TextInferenceEngine() {}
|
|||
|
||||
class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||
public:
|
||||
TextInferenceEngineImpl(const std::string& model_path) {
|
||||
ctranslate2::models::ModelLoader loader(model_path);
|
||||
translator_ = std::make_unique<ctranslate2::Translator>(loader);
|
||||
}
|
||||
TextInferenceEngineImpl(std::unique_ptr<ctranslate2::Translator> translator) : translator_(std::move(translator)) {}
|
||||
|
||||
~TextInferenceEngineImpl() {}
|
||||
|
||||
|
|
@ -41,7 +38,32 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
std::unique_ptr<ctranslate2::Translator> translator_;
|
||||
};
|
||||
|
||||
std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
|
||||
return std::make_unique<TextInferenceEngineImpl>(std::string(model_path));
|
||||
std::unique_ptr<TextInferenceEngine> create_engine(
|
||||
rust::Str model_path,
|
||||
rust::Str device,
|
||||
rust::Slice<const int32_t> device_indices,
|
||||
size_t num_replicas_per_device
|
||||
) {
|
||||
// model_path.
|
||||
std::string model_path_string(model_path);
|
||||
ctranslate2::models::ModelLoader loader(model_path_string);
|
||||
|
||||
// device.
|
||||
std::string device_string(device);
|
||||
if (device_string == "cuda") {
|
||||
loader.device = ctranslate2::Device::CUDA;
|
||||
} else if (device_string == "cpu") {
|
||||
loader.device = ctranslate2::Device::CPU;
|
||||
}
|
||||
|
||||
// device_indices
|
||||
loader.device_indices.clear();
|
||||
std::copy(device_indices.begin(), device_indices.end(), std::back_inserter(loader.device_indices));
|
||||
|
||||
// num_replicas_per_device
|
||||
loader.num_replicas_per_device = num_replicas_per_device;
|
||||
|
||||
auto translator = std::make_unique<ctranslate2::Translator>(loader);
|
||||
return std::make_unique<TextInferenceEngineImpl>(std::move(translator));
|
||||
}
|
||||
} // namespace tabby
|
||||
|
|
|
|||
|
|
@ -11,7 +11,13 @@ mod ffi {
|
|||
|
||||
type TextInferenceEngine;
|
||||
|
||||
fn create_engine(model_path: &str) -> UniquePtr<TextInferenceEngine>;
|
||||
fn create_engine(
|
||||
model_path: &str,
|
||||
device: &str,
|
||||
device_indices: &[i32],
|
||||
num_replicas_per_device: usize,
|
||||
) -> UniquePtr<TextInferenceEngine>;
|
||||
|
||||
fn inference(
|
||||
&self,
|
||||
tokens: &[String],
|
||||
|
|
@ -22,6 +28,19 @@ mod ffi {
|
|||
}
|
||||
}
|
||||
|
||||
// Options consumed by `TextInferenceEngine::create`. Constructed via the
// derive_builder-generated `TextInferenceEngineCreateOptionsBuilder`.
#[derive(Builder)]
|
||||
pub struct TextInferenceEngineCreateOptions {
|
||||
// Path to the converted model directory passed to the native engine.
model_path: String,
|
||||
|
||||
// Path to the `tokenizer.json` file loaded via `Tokenizer::from_file`.
tokenizer_path: String,
|
||||
|
||||
// Compute device name forwarded over FFI; "cpu" or "cuda" are recognized
// by the native side.
device: String,
|
||||
|
||||
// Device ordinals to place engine replicas on.
device_indices: Vec<i32>,
|
||||
|
||||
// Number of engine replicas per device.
num_replicas_per_device: usize,
|
||||
}
|
||||
|
||||
#[derive(Builder, Debug)]
|
||||
pub struct TextInferenceOptions {
|
||||
#[builder(default = "256")]
|
||||
|
|
@ -43,10 +62,16 @@ unsafe impl Send for TextInferenceEngine {}
|
|||
unsafe impl Sync for TextInferenceEngine {}
|
||||
|
||||
impl TextInferenceEngine {
|
||||
pub fn create(model_path: &str, tokenizer_path: &str) -> Self where {
|
||||
pub fn create(options: TextInferenceEngineCreateOptions) -> Self where {
|
||||
let engine = ffi::create_engine(
|
||||
&options.model_path,
|
||||
&options.device,
|
||||
&options.device_indices,
|
||||
options.num_replicas_per_device,
|
||||
);
|
||||
return TextInferenceEngine {
|
||||
engine: Mutex::new(ffi::create_engine(model_path)),
|
||||
tokenizer: Tokenizer::from_file(tokenizer_path).unwrap(),
|
||||
engine: Mutex::new(engine),
|
||||
tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(),
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -11,11 +11,7 @@ struct Cli {
|
|||
#[derive(Subcommand)]
|
||||
pub enum Commands {
|
||||
/// Serve the model
|
||||
Serve {
|
||||
/// path to model for serving
|
||||
#[clap(long)]
|
||||
model: String,
|
||||
},
|
||||
Serve(serve::ServeArgs),
|
||||
}
|
||||
|
||||
mod serve;
|
||||
|
|
@ -27,8 +23,8 @@ async fn main() {
|
|||
// You can check for the existence of subcommands, and if found use their
|
||||
// matches just as you would the top level cmd
|
||||
match &cli.command {
|
||||
Commands::Serve { model } => {
|
||||
serve::main(model)
|
||||
Commands::Serve(args) => {
|
||||
serve::main(args)
|
||||
.await
|
||||
.expect("Error happens during the serve");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
use axum::{extract::State, Json};
|
||||
use ctranslate2_bindings::{TextInferenceEngine, TextInferenceOptionsBuilder};
|
||||
use ctranslate2_bindings::{
|
||||
TextInferenceEngine, TextInferenceEngineCreateOptions, TextInferenceOptionsBuilder,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{path::Path, sync::Arc};
|
||||
use std::sync::Arc;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
mod languages;
|
||||
|
|
@ -62,11 +64,8 @@ pub struct CompletionState {
|
|||
}
|
||||
|
||||
impl CompletionState {
|
||||
pub fn new(model: &str) -> Self {
|
||||
let engine = TextInferenceEngine::create(
|
||||
Path::new(model).join("cpu").to_str().unwrap(),
|
||||
Path::new(model).join("tokenizer.json").to_str().unwrap(),
|
||||
);
|
||||
pub fn new(options: TextInferenceEngineCreateOptions) -> Self {
|
||||
let engine = TextInferenceEngine::create(options);
|
||||
return Self { engine: engine };
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,10 @@ use std::{
|
|||
};
|
||||
|
||||
use axum::{response::Redirect, routing, Router, Server};
|
||||
use clap::Args;
|
||||
use ctranslate2_bindings::TextInferenceEngineCreateOptionsBuilder;
|
||||
use hyper::Error;
|
||||
use std::path::Path;
|
||||
use tower_http::cors::CorsLayer;
|
||||
use utoipa::OpenApi;
|
||||
use utoipa_swagger_ui::SwaggerUi;
|
||||
|
|
@ -24,8 +27,60 @@ mod events;
|
|||
)]
|
||||
struct ApiDoc;
|
||||
|
||||
pub async fn main(model: &str) -> Result<(), Error> {
|
||||
let completions_state = Arc::new(completions::CompletionState::new(model));
|
||||
// Compute device selectable from the CLI; `Display` (below in this file)
// renders the lowercase name ("cpu"/"cuda") the native engine expects.
// Plain `//` comments on purpose: `///` would leak into clap's help output.
#[derive(clap::ValueEnum, Clone)]
|
||||
pub enum Device {
|
||||
// CPU execution (the CLI default).
CPU,
|
||||
// CUDA GPU execution.
CUDA,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Device {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
let printable = match *self {
|
||||
Device::CPU => "cpu",
|
||||
Device::CUDA => "cuda",
|
||||
};
|
||||
write!(f, "{}", printable)
|
||||
}
|
||||
}
|
||||
|
||||
// Flags for the `serve` subcommand; consumed by `serve::main` to build the
// engine creation options. Plain `//` comments used for the new notes so
// clap's generated help text is not altered.
#[derive(Args)]
|
||||
pub struct ServeArgs {
|
||||
/// path to model for serving
#[clap(long)]
|
||||
model: String,
|
||||
|
||||
// Compute device; defaults to CPU.
#[clap(long, default_value_t=Device::CPU)]
|
||||
device: Device,
|
||||
|
||||
// Device ordinals to load replicas on; defaults to [0].
#[clap(long, default_values_t=[0])]
|
||||
device_indices: Vec<i32>,
|
||||
|
||||
/// num_replicas_per_device
#[clap(long, default_value_t = 1)]
|
||||
num_replicas_per_device: usize,
|
||||
}
|
||||
|
||||
pub async fn main(args: &ServeArgs) -> Result<(), Error> {
|
||||
let device = format!("{}", args.device);
|
||||
let options = TextInferenceEngineCreateOptionsBuilder::default()
|
||||
.model_path(
|
||||
Path::new(&args.model)
|
||||
.join(device.clone())
|
||||
.display()
|
||||
.to_string(),
|
||||
)
|
||||
.tokenizer_path(
|
||||
Path::new(&args.model)
|
||||
.join("tokenizer.json")
|
||||
.display()
|
||||
.to_string(),
|
||||
)
|
||||
.device(device)
|
||||
.device_indices(args.device_indices.clone())
|
||||
.num_replicas_per_device(args.num_replicas_per_device)
|
||||
.build()
|
||||
.unwrap();
|
||||
let completions_state = Arc::new(completions::CompletionState::new(options));
|
||||
|
||||
let app = Router::new()
|
||||
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
|
||||
|
|
|
|||
Loading…
Reference in New Issue