feat: add param --instruct-model, allowing specify different model for q&a use cases. (#494)
parent
892aa61a53
commit
10bf2d6c0c
|
|
@ -18,15 +18,17 @@ fn get_param(params: &Value, key: &str) -> String {
|
||||||
.to_string()
|
.to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn create_engine(args: &crate::serve::ServeArgs) -> (Box<dyn TextGeneration>, Option<String>) {
|
pub fn create_engine(
|
||||||
|
model: &str,
|
||||||
|
args: &crate::serve::ServeArgs,
|
||||||
|
) -> (Box<dyn TextGeneration>, Option<String>) {
|
||||||
if args.device != super::Device::ExperimentalHttp {
|
if args.device != super::Device::ExperimentalHttp {
|
||||||
let model_dir = get_model_dir(&args.model);
|
let model_dir = get_model_dir(model);
|
||||||
let metadata = read_metadata(&model_dir);
|
let metadata = read_metadata(&model_dir);
|
||||||
let engine = create_local_engine(args, &model_dir, &metadata);
|
let engine = create_local_engine(args, &model_dir, &metadata);
|
||||||
(engine, metadata.prompt_template)
|
(engine, metadata.prompt_template)
|
||||||
} else {
|
} else {
|
||||||
let params: Value =
|
let params: Value = serdeconv::from_json_str(model).expect("Failed to parse model string");
|
||||||
serdeconv::from_json_str(&args.model).expect("Failed to parse model string");
|
|
||||||
|
|
||||||
let kind = get_param(¶ms, "kind");
|
let kind = get_param(¶ms, "kind");
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,8 @@ use utoipa::ToSchema;
|
||||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
pub struct HealthState {
|
pub struct HealthState {
|
||||||
model: String,
|
model: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
instruct_model: Option<String>,
|
||||||
device: String,
|
device: String,
|
||||||
compute_type: String,
|
compute_type: String,
|
||||||
arch: String,
|
arch: String,
|
||||||
|
|
@ -30,6 +32,7 @@ impl HealthState {
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
model: args.model.clone(),
|
model: args.model.clone(),
|
||||||
|
instruct_model: args.instruct_model.clone(),
|
||||||
device: args.device.to_string(),
|
device: args.device.to_string(),
|
||||||
compute_type: args.compute_type.to_string(),
|
compute_type: args.compute_type.to_string(),
|
||||||
arch: ARCH.to_string(),
|
arch: ARCH.to_string(),
|
||||||
|
|
|
||||||
|
|
@ -105,10 +105,15 @@ pub enum ComputeType {
|
||||||
|
|
||||||
#[derive(Args)]
|
#[derive(Args)]
|
||||||
pub struct ServeArgs {
|
pub struct ServeArgs {
|
||||||
/// Model id for serving.
|
/// Model id for `/completion` API endpoint.
|
||||||
#[clap(long)]
|
#[clap(long)]
|
||||||
model: String,
|
model: String,
|
||||||
|
|
||||||
|
/// Model id for `/generate` and `/generate_stream` API endpoints.
|
||||||
|
/// If not set, `model` will be loaded for the purpose.
|
||||||
|
#[clap(long)]
|
||||||
|
instruct_model: Option<String>,
|
||||||
|
|
||||||
#[clap(long, default_value_t = 8080)]
|
#[clap(long, default_value_t = 8080)]
|
||||||
port: u16,
|
port: u16,
|
||||||
|
|
||||||
|
|
@ -142,16 +147,11 @@ fn should_download_ggml_files(device: &Device) -> bool {
|
||||||
pub async fn main(config: &Config, args: &ServeArgs) {
|
pub async fn main(config: &Config, args: &ServeArgs) {
|
||||||
valid_args(args);
|
valid_args(args);
|
||||||
|
|
||||||
let downloader = Downloader::new(&args.model, /* prefer_local_file= */ true);
|
|
||||||
if args.device != Device::ExperimentalHttp {
|
if args.device != Device::ExperimentalHttp {
|
||||||
let handler = |err| fatal!("Failed to fetch model '{}' due to '{}'", args.model, err,);
|
download_model(&args.model, &args.device).await;
|
||||||
let download_result = if should_download_ggml_files(&args.device) {
|
if let Some(instruct_model) = &args.instruct_model {
|
||||||
downloader.download_ggml_files().await
|
download_model(instruct_model, &args.device).await;
|
||||||
} else {
|
}
|
||||||
downloader.download_ctranslate2_files().await
|
|
||||||
};
|
|
||||||
|
|
||||||
download_result.unwrap_or_else(handler);
|
|
||||||
} else {
|
} else {
|
||||||
warn!("HTTP device is unstable and does not comply with semver expectations.")
|
warn!("HTTP device is unstable and does not comply with semver expectations.")
|
||||||
}
|
}
|
||||||
|
|
@ -177,8 +177,14 @@ pub async fn main(config: &Config, args: &ServeArgs) {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
let (engine, prompt_template) = create_engine(args);
|
let (engine, prompt_template) = create_engine(&args.model, args);
|
||||||
let engine = Arc::new(engine);
|
let engine = Arc::new(engine);
|
||||||
|
let instruct_engine = if let Some(instruct_model) = &args.instruct_model {
|
||||||
|
Arc::new(create_engine(instruct_model, args).0)
|
||||||
|
} else {
|
||||||
|
engine.clone()
|
||||||
|
};
|
||||||
|
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/v1/events", routing::post(events::log_event))
|
.route("/v1/events", routing::post(events::log_event))
|
||||||
.route(
|
.route(
|
||||||
|
|
@ -193,13 +199,15 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
)
|
)
|
||||||
.route(
|
.route(
|
||||||
"/v1beta/generate",
|
"/v1beta/generate",
|
||||||
routing::post(generate::generate)
|
routing::post(generate::generate).with_state(Arc::new(generate::GenerateState::new(
|
||||||
.with_state(Arc::new(generate::GenerateState::new(engine.clone()))),
|
instruct_engine.clone(),
|
||||||
|
))),
|
||||||
)
|
)
|
||||||
.route(
|
.route(
|
||||||
"/v1beta/generate_stream",
|
"/v1beta/generate_stream",
|
||||||
routing::post(generate::generate_stream)
|
routing::post(generate::generate_stream).with_state(Arc::new(
|
||||||
.with_state(Arc::new(generate::GenerateState::new(engine.clone()))),
|
generate::GenerateState::new(instruct_engine.clone()),
|
||||||
|
)),
|
||||||
)
|
)
|
||||||
.layer(CorsLayer::permissive())
|
.layer(CorsLayer::permissive())
|
||||||
.layer(opentelemetry_tracing_layer())
|
.layer(opentelemetry_tracing_layer())
|
||||||
|
|
@ -272,3 +280,15 @@ fn add_proxy_server(
|
||||||
|
|
||||||
doc
|
doc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn download_model(model: &str, device: &Device) {
|
||||||
|
let downloader = Downloader::new(model, /* prefer_local_file= */ true);
|
||||||
|
let handler = |err| fatal!("Failed to fetch model '{}' due to '{}'", model, err,);
|
||||||
|
let download_result = if should_download_ggml_files(device) {
|
||||||
|
downloader.download_ggml_files().await
|
||||||
|
} else {
|
||||||
|
downloader.download_ctranslate2_files().await
|
||||||
|
};
|
||||||
|
|
||||||
|
download_result.unwrap_or_else(handler);
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue