feat: add param --instruct-model, allowing a different model to be specified for Q&A use cases. (#494)

release-0.2
Meng Zhang 2023-09-29 16:44:53 -07:00 committed by GitHub
parent 892aa61a53
commit 10bf2d6c0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 44 additions and 19 deletions

View File

@ -18,15 +18,17 @@ fn get_param(params: &Value, key: &str) -> String {
.to_string()
}
pub fn create_engine(args: &crate::serve::ServeArgs) -> (Box<dyn TextGeneration>, Option<String>) {
pub fn create_engine(
model: &str,
args: &crate::serve::ServeArgs,
) -> (Box<dyn TextGeneration>, Option<String>) {
if args.device != super::Device::ExperimentalHttp {
let model_dir = get_model_dir(&args.model);
let model_dir = get_model_dir(model);
let metadata = read_metadata(&model_dir);
let engine = create_local_engine(args, &model_dir, &metadata);
(engine, metadata.prompt_template)
} else {
let params: Value =
serdeconv::from_json_str(&args.model).expect("Failed to parse model string");
let params: Value = serdeconv::from_json_str(model).expect("Failed to parse model string");
let kind = get_param(&params, "kind");

View File

@ -10,6 +10,8 @@ use utoipa::ToSchema;
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct HealthState {
model: String,
#[serde(skip_serializing_if = "Option::is_none")]
instruct_model: Option<String>,
device: String,
compute_type: String,
arch: String,
@ -30,6 +32,7 @@ impl HealthState {
Self {
model: args.model.clone(),
instruct_model: args.instruct_model.clone(),
device: args.device.to_string(),
compute_type: args.compute_type.to_string(),
arch: ARCH.to_string(),

View File

@ -105,10 +105,15 @@ pub enum ComputeType {
#[derive(Args)]
pub struct ServeArgs {
/// Model id for serving.
/// Model id for `/completion` API endpoint.
#[clap(long)]
model: String,
/// Model id for `/generate` and `/generate_stream` API endpoints.
/// If not set, `model` is used for these endpoints as well.
#[clap(long)]
instruct_model: Option<String>,
#[clap(long, default_value_t = 8080)]
port: u16,
@ -142,16 +147,11 @@ fn should_download_ggml_files(device: &Device) -> bool {
pub async fn main(config: &Config, args: &ServeArgs) {
valid_args(args);
let downloader = Downloader::new(&args.model, /* prefer_local_file= */ true);
if args.device != Device::ExperimentalHttp {
let handler = |err| fatal!("Failed to fetch model '{}' due to '{}'", args.model, err,);
let download_result = if should_download_ggml_files(&args.device) {
downloader.download_ggml_files().await
} else {
downloader.download_ctranslate2_files().await
};
download_result.unwrap_or_else(handler);
download_model(&args.model, &args.device).await;
if let Some(instruct_model) = &args.instruct_model {
download_model(instruct_model, &args.device).await;
}
} else {
warn!("HTTP device is unstable and does not comply with semver expectations.")
}
@ -177,8 +177,14 @@ pub async fn main(config: &Config, args: &ServeArgs) {
}
fn api_router(args: &ServeArgs, config: &Config) -> Router {
let (engine, prompt_template) = create_engine(args);
let (engine, prompt_template) = create_engine(&args.model, args);
let engine = Arc::new(engine);
let instruct_engine = if let Some(instruct_model) = &args.instruct_model {
Arc::new(create_engine(instruct_model, args).0)
} else {
engine.clone()
};
Router::new()
.route("/v1/events", routing::post(events::log_event))
.route(
@ -193,13 +199,15 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
)
.route(
"/v1beta/generate",
routing::post(generate::generate)
.with_state(Arc::new(generate::GenerateState::new(engine.clone()))),
routing::post(generate::generate).with_state(Arc::new(generate::GenerateState::new(
instruct_engine.clone(),
))),
)
.route(
"/v1beta/generate_stream",
routing::post(generate::generate_stream)
.with_state(Arc::new(generate::GenerateState::new(engine.clone()))),
routing::post(generate::generate_stream).with_state(Arc::new(
generate::GenerateState::new(instruct_engine.clone()),
)),
)
.layer(CorsLayer::permissive())
.layer(opentelemetry_tracing_layer())
@ -272,3 +280,15 @@ fn add_proxy_server(
doc
}
/// Fetches the artifacts for `model` that the given `device` needs, preferring
/// any files already present on disk over re-downloading.
///
/// GGML-capable devices pull GGML weight files; every other device pulls the
/// CTranslate2 files. On any download failure the process is terminated via
/// `fatal!` with the model id and the underlying error.
async fn download_model(model: &str, device: &Device) {
    let downloader = Downloader::new(model, /* prefer_local_file= */ true);
    let outcome = match should_download_ggml_files(device) {
        true => downloader.download_ggml_files().await,
        false => downloader.download_ctranslate2_files().await,
    };
    if let Err(err) = outcome {
        fatal!("Failed to fetch model '{}' due to '{}'", model, err,);
    }
}