2023-10-09 17:37:04 +00:00
|
|
|
mod fastchat;
|
|
|
|
|
mod vertex_ai;
|
|
|
|
|
|
2023-11-11 21:56:01 +00:00
|
|
|
use std::sync::Arc;
|
|
|
|
|
|
2023-10-09 17:37:04 +00:00
|
|
|
use fastchat::FastChatEngine;
|
|
|
|
|
use serde_json::Value;
|
|
|
|
|
use tabby_inference::TextGeneration;
|
|
|
|
|
use vertex_ai::VertexAIEngine;
|
|
|
|
|
|
2023-11-11 21:56:01 +00:00
|
|
|
pub fn create(model: &str) -> (Arc<dyn TextGeneration>, String) {
|
2023-10-09 17:37:04 +00:00
|
|
|
let params = serde_json::from_str(model).expect("Failed to parse model string");
|
|
|
|
|
let kind = get_param(¶ms, "kind");
|
|
|
|
|
if kind == "vertex-ai" {
|
|
|
|
|
let api_endpoint = get_param(¶ms, "api_endpoint");
|
|
|
|
|
let authorization = get_param(¶ms, "authorization");
|
2023-11-11 21:56:01 +00:00
|
|
|
let engine = VertexAIEngine::create(api_endpoint.as_str(), authorization.as_str());
|
|
|
|
|
(Arc::new(engine), VertexAIEngine::prompt_template())
|
2023-10-09 17:37:04 +00:00
|
|
|
} else if kind == "fastchat" {
|
|
|
|
|
let model_name = get_param(¶ms, "model_name");
|
|
|
|
|
let api_endpoint = get_param(¶ms, "api_endpoint");
|
|
|
|
|
let authorization = get_param(¶ms, "authorization");
|
2023-11-11 21:56:01 +00:00
|
|
|
let engine = FastChatEngine::create(
|
2023-10-09 17:37:04 +00:00
|
|
|
api_endpoint.as_str(),
|
|
|
|
|
model_name.as_str(),
|
|
|
|
|
authorization.as_str(),
|
2023-11-11 21:56:01 +00:00
|
|
|
);
|
|
|
|
|
(Arc::new(engine), FastChatEngine::prompt_template())
|
2023-10-09 17:37:04 +00:00
|
|
|
} else {
|
|
|
|
|
panic!("Only vertex_ai and fastchat are supported for http backend");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn get_param(params: &Value, key: &str) -> String {
|
|
|
|
|
params
|
|
|
|
|
.get(key)
|
|
|
|
|
.unwrap_or_else(|| panic!("Missing {} field", key))
|
|
|
|
|
.as_str()
|
|
|
|
|
.expect("Type unmatched")
|
|
|
|
|
.to_string()
|
|
|
|
|
}
|