diff --git a/crates/http-api-bindings/examples/simple.rs b/crates/http-api-bindings/examples/simple.rs deleted file mode 100644 index 2c357e1..0000000 --- a/crates/http-api-bindings/examples/simple.rs +++ /dev/null @@ -1,20 +0,0 @@ -use std::env; - -use http_api_bindings::vertex_ai::VertexAIEngine; -use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder}; - -#[tokio::main] -async fn main() { - let api_endpoint = env::var("API_ENDPOINT").expect("API_ENDPOINT not set"); - let authorization = env::var("AUTHORIZATION").expect("AUTHORIZATION not set"); - let engine = VertexAIEngine::create(&api_endpoint, &authorization); - - let options = TextGenerationOptionsBuilder::default() - .sampling_temperature(0.1) - .max_decoding_length(32) - .build() - .unwrap(); - let prompt = "def fib(n)"; - let text = engine.generate(prompt, options).await; - println!("{}{}", prompt, text); -} diff --git a/crates/http-api-bindings/src/lib.rs b/crates/http-api-bindings/src/lib.rs index f28bac6..fc743ca 100644 --- a/crates/http-api-bindings/src/lib.rs +++ b/crates/http-api-bindings/src/lib.rs @@ -1,2 +1,42 @@ -pub mod fastchat; -pub mod vertex_ai; +mod fastchat; +mod vertex_ai; + +use fastchat::FastChatEngine; +use serde_json::Value; +use tabby_inference::TextGeneration; +use vertex_ai::VertexAIEngine; + +pub fn create(model: &str) -> (Box<dyn TextGeneration>, String) { + let params = serde_json::from_str(model).expect("Failed to parse model string"); + let kind = get_param(&params, "kind"); + if kind == "vertex-ai" { + let api_endpoint = get_param(&params, "api_endpoint"); + let authorization = get_param(&params, "authorization"); + let engine = Box::new(VertexAIEngine::create( + api_endpoint.as_str(), + authorization.as_str(), + )); + (engine, VertexAIEngine::prompt_template()) + } else if kind == "fastchat" { + let model_name = get_param(&params, "model_name"); + let api_endpoint = get_param(&params, "api_endpoint"); + let authorization = get_param(&params, "authorization"); + let engine = 
Box::new(FastChatEngine::create( + api_endpoint.as_str(), + model_name.as_str(), + authorization.as_str(), + )); + (engine, FastChatEngine::prompt_template()) + } else { + panic!("Only vertex_ai and fastchat are supported for http backend"); + } +} + +fn get_param(params: &Value, key: &str) -> String { + params + .get(key) + .unwrap_or_else(|| panic!("Missing {} field", key)) + .as_str() + .expect("Type unmatched") + .to_string() +} diff --git a/crates/tabby/src/serve/engine.rs b/crates/tabby/src/serve/engine.rs index 3998ec1..9eb86f9 100644 --- a/crates/tabby/src/serve/engine.rs +++ b/crates/tabby/src/serve/engine.rs @@ -1,23 +1,12 @@ use std::path::Path; use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder}; -use http_api_bindings::{fastchat::FastChatEngine, vertex_ai::VertexAIEngine}; use serde::Deserialize; -use serde_json::Value; use tabby_common::path::ModelDir; use tabby_inference::TextGeneration; use crate::fatal; -fn get_param(params: &Value, key: &str) -> String { - params - .get(key) - .unwrap_or_else(|| panic!("Missing {} field", key)) - .as_str() - .expect("Type unmatched") - .to_string() -} - pub fn create_engine( model: &str, args: &crate::serve::ServeArgs, @@ -34,43 +23,14 @@ pub fn create_engine( }, ) } else { - let params: Value = serdeconv::from_json_str(model).expect("Failed to parse model string"); - - let kind = get_param(&params, "kind"); - - if kind == "vertex-ai" { - let api_endpoint = get_param(&params, "api_endpoint"); - let authorization = get_param(&params, "authorization"); - let engine = Box::new(VertexAIEngine::create( - api_endpoint.as_str(), - authorization.as_str(), - )); - ( - engine, - EngineInfo { - prompt_template: Some(VertexAIEngine::prompt_template()), - chat_template: None, - }, - ) - } else if kind == "fastchat" { - let model_name = get_param(&params, "model_name"); - let api_endpoint = get_param(&params, "api_endpoint"); - let authorization = get_param(&params, "authorization"); - let engine = 
Box::new(FastChatEngine::create( - api_endpoint.as_str(), - model_name.as_str(), - authorization.as_str(), - )); - ( - engine, - EngineInfo { - prompt_template: Some(FastChatEngine::prompt_template()), - chat_template: None, - }, - ) - } else { - fatal!("Only vertex_ai and fastchat are supported for http backend"); - } + let (engine, prompt_template) = http_api_bindings::create(model); + ( + engine, + EngineInfo { + prompt_template: Some(prompt_template), + chat_template: None, + }, + ) } }