feat: add param --instruct-model, allowing a different model to be specified for Q&A use cases. (#494)
parent
892aa61a53
commit
10bf2d6c0c
|
|
@ -18,15 +18,17 @@ fn get_param(params: &Value, key: &str) -> String {
|
|||
.to_string()
|
||||
}
|
||||
|
||||
pub fn create_engine(args: &crate::serve::ServeArgs) -> (Box<dyn TextGeneration>, Option<String>) {
|
||||
pub fn create_engine(
|
||||
model: &str,
|
||||
args: &crate::serve::ServeArgs,
|
||||
) -> (Box<dyn TextGeneration>, Option<String>) {
|
||||
if args.device != super::Device::ExperimentalHttp {
|
||||
let model_dir = get_model_dir(&args.model);
|
||||
let model_dir = get_model_dir(model);
|
||||
let metadata = read_metadata(&model_dir);
|
||||
let engine = create_local_engine(args, &model_dir, &metadata);
|
||||
(engine, metadata.prompt_template)
|
||||
} else {
|
||||
let params: Value =
|
||||
serdeconv::from_json_str(&args.model).expect("Failed to parse model string");
|
||||
let params: Value = serdeconv::from_json_str(model).expect("Failed to parse model string");
|
||||
|
||||
let kind = get_param(&params, "kind");
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ use utoipa::ToSchema;
|
|||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
pub struct HealthState {
|
||||
model: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
instruct_model: Option<String>,
|
||||
device: String,
|
||||
compute_type: String,
|
||||
arch: String,
|
||||
|
|
@ -30,6 +32,7 @@ impl HealthState {
|
|||
|
||||
Self {
|
||||
model: args.model.clone(),
|
||||
instruct_model: args.instruct_model.clone(),
|
||||
device: args.device.to_string(),
|
||||
compute_type: args.compute_type.to_string(),
|
||||
arch: ARCH.to_string(),
|
||||
|
|
|
|||
|
|
@ -105,10 +105,15 @@ pub enum ComputeType {
|
|||
|
||||
#[derive(Args)]
|
||||
pub struct ServeArgs {
|
||||
/// Model id for serving.
|
||||
/// Model id for `/completion` API endpoint.
|
||||
#[clap(long)]
|
||||
model: String,
|
||||
|
||||
/// Model id for `/generate` and `/generate_stream` API endpoints.
|
||||
/// If not set, `model` is used for these endpoints as well.
|
||||
#[clap(long)]
|
||||
instruct_model: Option<String>,
|
||||
|
||||
#[clap(long, default_value_t = 8080)]
|
||||
port: u16,
|
||||
|
||||
|
|
@ -142,16 +147,11 @@ fn should_download_ggml_files(device: &Device) -> bool {
|
|||
pub async fn main(config: &Config, args: &ServeArgs) {
|
||||
valid_args(args);
|
||||
|
||||
let downloader = Downloader::new(&args.model, /* prefer_local_file= */ true);
|
||||
if args.device != Device::ExperimentalHttp {
|
||||
let handler = |err| fatal!("Failed to fetch model '{}' due to '{}'", args.model, err,);
|
||||
let download_result = if should_download_ggml_files(&args.device) {
|
||||
downloader.download_ggml_files().await
|
||||
} else {
|
||||
downloader.download_ctranslate2_files().await
|
||||
};
|
||||
|
||||
download_result.unwrap_or_else(handler);
|
||||
download_model(&args.model, &args.device).await;
|
||||
if let Some(instruct_model) = &args.instruct_model {
|
||||
download_model(instruct_model, &args.device).await;
|
||||
}
|
||||
} else {
|
||||
warn!("HTTP device is unstable and does not comply with semver expectations.")
|
||||
}
|
||||
|
|
@ -177,8 +177,14 @@ pub async fn main(config: &Config, args: &ServeArgs) {
|
|||
}
|
||||
|
||||
fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||
let (engine, prompt_template) = create_engine(args);
|
||||
let (engine, prompt_template) = create_engine(&args.model, args);
|
||||
let engine = Arc::new(engine);
|
||||
let instruct_engine = if let Some(instruct_model) = &args.instruct_model {
|
||||
Arc::new(create_engine(instruct_model, args).0)
|
||||
} else {
|
||||
engine.clone()
|
||||
};
|
||||
|
||||
Router::new()
|
||||
.route("/v1/events", routing::post(events::log_event))
|
||||
.route(
|
||||
|
|
@ -193,13 +199,15 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
|||
)
|
||||
.route(
|
||||
"/v1beta/generate",
|
||||
routing::post(generate::generate)
|
||||
.with_state(Arc::new(generate::GenerateState::new(engine.clone()))),
|
||||
routing::post(generate::generate).with_state(Arc::new(generate::GenerateState::new(
|
||||
instruct_engine.clone(),
|
||||
))),
|
||||
)
|
||||
.route(
|
||||
"/v1beta/generate_stream",
|
||||
routing::post(generate::generate_stream)
|
||||
.with_state(Arc::new(generate::GenerateState::new(engine.clone()))),
|
||||
routing::post(generate::generate_stream).with_state(Arc::new(
|
||||
generate::GenerateState::new(instruct_engine.clone()),
|
||||
)),
|
||||
)
|
||||
.layer(CorsLayer::permissive())
|
||||
.layer(opentelemetry_tracing_layer())
|
||||
|
|
@ -272,3 +280,15 @@ fn add_proxy_server(
|
|||
|
||||
doc
|
||||
}
|
||||
|
||||
async fn download_model(model: &str, device: &Device) {
|
||||
let downloader = Downloader::new(model, /* prefer_local_file= */ true);
|
||||
let handler = |err| fatal!("Failed to fetch model '{}' due to '{}'", model, err,);
|
||||
let download_result = if should_download_ggml_files(device) {
|
||||
downloader.download_ggml_files().await
|
||||
} else {
|
||||
downloader.download_ctranslate2_files().await
|
||||
};
|
||||
|
||||
download_result.unwrap_or_else(handler);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue