feat: add support vertex-ai http bindings (#419)
* feat: add support vertex-ai http bindings * support prefix / suffix
release-0.2
parent
17397c8c8c
commit
f0ed366420
|
|
@ -2959,6 +2959,7 @@ dependencies = [
|
||||||
"axum-tracing-opentelemetry",
|
"axum-tracing-opentelemetry",
|
||||||
"clap",
|
"clap",
|
||||||
"ctranslate2-bindings",
|
"ctranslate2-bindings",
|
||||||
|
"http-api-bindings",
|
||||||
"hyper",
|
"hyper",
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
"llama-cpp-bindings",
|
"llama-cpp-bindings",
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
## Usage
|
## Examples
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export MODEL_ID="code-gecko"
|
export MODEL_ID="code-gecko"
|
||||||
|
|
@ -8,3 +8,14 @@ export AUTHORIZATION="Bearer $(gcloud auth print-access-token)"
|
||||||
|
|
||||||
cargo run --example simple
|
cargo run --example simple
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export MODEL_ID="code-gecko"
|
||||||
|
export PROJECT_ID="$(gcloud config get project)"
|
||||||
|
export API_ENDPOINT="https://us-central1-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/${MODEL_ID}:predict"
|
||||||
|
export AUTHORIZATION="Bearer $(gcloud auth print-access-token)"
|
||||||
|
|
||||||
|
cargo run serve --device experimental-http --model "{\"kind\": \"vertex-ai\", \"api_endpoint\": \"$API_ENDPOINT\", \"authorization\": \"$AUTHORIZATION\"}"
|
||||||
|
```
|
||||||
|
|
|
||||||
|
|
@ -56,23 +56,34 @@ impl VertexAIEngine {
|
||||||
client,
|
client,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn prompt_template() -> String {
|
||||||
|
"{prefix}<MID>{suffix}".to_owned()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl TextGeneration for VertexAIEngine {
|
impl TextGeneration for VertexAIEngine {
|
||||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||||
let stop_sequences: Vec<String> =
|
let stop_sequences: Vec<String> = options
|
||||||
options.stop_words.iter().map(|x| x.to_string()).collect();
|
.stop_words
|
||||||
|
.iter()
|
||||||
|
.map(|x| x.to_string())
|
||||||
|
// vertex supports at most 5 stop sequence.
|
||||||
|
.take(5)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let tokens: Vec<&str> = prompt.split("<MID>").collect();
|
||||||
let request = Request {
|
let request = Request {
|
||||||
instances: vec![Instance {
|
instances: vec![Instance {
|
||||||
prefix: prompt.to_owned(),
|
prefix: tokens[0].to_owned(),
|
||||||
suffix: None,
|
suffix: Some(tokens[1].to_owned()),
|
||||||
}],
|
}],
|
||||||
// options.max_input_length is ignored.
|
// options.max_input_length is ignored.
|
||||||
parameters: Parameters {
|
parameters: Parameters {
|
||||||
temperature: options.sampling_temperature,
|
temperature: options.sampling_temperature,
|
||||||
max_output_tokens: options.max_decoding_length,
|
// vertex supports at most 64 output tokens.
|
||||||
|
max_output_tokens: std::cmp::min(options.max_decoding_length, 64),
|
||||||
stop_sequences,
|
stop_sequences,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,7 @@ tantivy = { workspace = true }
|
||||||
anyhow = { workspace = true }
|
anyhow = { workspace = true }
|
||||||
sysinfo = "0.29.8"
|
sysinfo = "0.29.8"
|
||||||
nvml-wrapper = "0.9.0"
|
nvml-wrapper = "0.9.0"
|
||||||
|
http-api-bindings = { path = "../http-api-bindings" }
|
||||||
|
|
||||||
[target.'cfg(all(target_os="macos", target_arch="aarch64"))'.dependencies]
|
[target.'cfg(all(target_os="macos", target_arch="aarch64"))'.dependencies]
|
||||||
llama-cpp-bindings = { path = "../llama-cpp-bindings" }
|
llama-cpp-bindings = { path = "../llama-cpp-bindings" }
|
||||||
|
|
|
||||||
|
|
@ -5,8 +5,10 @@ use std::{path::Path, sync::Arc};
|
||||||
|
|
||||||
use axum::{extract::State, Json};
|
use axum::{extract::State, Json};
|
||||||
use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};
|
use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};
|
||||||
|
use http_api_bindings::vertex_ai::VertexAIEngine;
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
use tabby_common::{config::Config, events, path::ModelDir};
|
use tabby_common::{config::Config, events, path::ModelDir};
|
||||||
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
|
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
|
||||||
use tracing::{debug, instrument};
|
use tracing::{debug, instrument};
|
||||||
|
|
@ -128,22 +130,55 @@ pub struct CompletionState {
|
||||||
|
|
||||||
impl CompletionState {
|
impl CompletionState {
|
||||||
pub fn new(args: &crate::serve::ServeArgs, config: &Config) -> Self {
|
pub fn new(args: &crate::serve::ServeArgs, config: &Config) -> Self {
|
||||||
let model_dir = get_model_dir(&args.model);
|
let (engine, prompt_template) = create_engine(args);
|
||||||
let metadata = read_metadata(&model_dir);
|
|
||||||
let engine = create_engine(args, &model_dir, &metadata);
|
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
engine,
|
engine,
|
||||||
prompt_builder: prompt::PromptBuilder::new(
|
prompt_builder: prompt::PromptBuilder::new(
|
||||||
metadata.prompt_template,
|
prompt_template,
|
||||||
config.experimental.enable_prompt_rewrite,
|
config.experimental.enable_prompt_rewrite,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn create_engine(args: &crate::serve::ServeArgs) -> (Box<dyn TextGeneration>, Option<String>) {
|
||||||
|
if args.device != super::Device::ExperimentalHttp {
|
||||||
|
let model_dir = get_model_dir(&args.model);
|
||||||
|
let metadata = read_metadata(&model_dir);
|
||||||
|
let engine = create_local_engine(args, &model_dir, &metadata);
|
||||||
|
(engine, metadata.prompt_template)
|
||||||
|
} else {
|
||||||
|
let params: Value =
|
||||||
|
serdeconv::from_json_str(&args.model).expect("Failed to parse model string");
|
||||||
|
|
||||||
|
let kind = params
|
||||||
|
.get("kind")
|
||||||
|
.expect("Missing kind field")
|
||||||
|
.as_str()
|
||||||
|
.expect("Type unmatched");
|
||||||
|
|
||||||
|
if kind != "vertex-ai" {
|
||||||
|
fatal!("Only vertex_ai is supported for http backend");
|
||||||
|
}
|
||||||
|
|
||||||
|
let api_endpoint = params
|
||||||
|
.get("api_endpoint")
|
||||||
|
.expect("Missing api_endpoint field")
|
||||||
|
.as_str()
|
||||||
|
.expect("Type unmatched");
|
||||||
|
let authorization = params
|
||||||
|
.get("authorization")
|
||||||
|
.expect("Missing authorization field")
|
||||||
|
.as_str()
|
||||||
|
.expect("Type unmatched");
|
||||||
|
let engine = Box::new(VertexAIEngine::create(api_endpoint, authorization));
|
||||||
|
(engine, Some(VertexAIEngine::prompt_template()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
|
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
|
||||||
fn create_engine(
|
fn create_local_engine(
|
||||||
args: &crate::serve::ServeArgs,
|
args: &crate::serve::ServeArgs,
|
||||||
model_dir: &ModelDir,
|
model_dir: &ModelDir,
|
||||||
metadata: &Metadata,
|
metadata: &Metadata,
|
||||||
|
|
@ -152,7 +187,7 @@ fn create_engine(
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
||||||
fn create_engine(
|
fn create_local_engine(
|
||||||
args: &crate::serve::ServeArgs,
|
args: &crate::serve::ServeArgs,
|
||||||
model_dir: &ModelDir,
|
model_dir: &ModelDir,
|
||||||
metadata: &Metadata,
|
metadata: &Metadata,
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ use clap::Args;
|
||||||
use tabby_common::{config::Config, usage};
|
use tabby_common::{config::Config, usage};
|
||||||
use tokio::time::sleep;
|
use tokio::time::sleep;
|
||||||
use tower_http::cors::CorsLayer;
|
use tower_http::cors::CorsLayer;
|
||||||
use tracing::info;
|
use tracing::{info, warn};
|
||||||
use utoipa::OpenApi;
|
use utoipa::OpenApi;
|
||||||
use utoipa_swagger_ui::SwaggerUi;
|
use utoipa_swagger_ui::SwaggerUi;
|
||||||
|
|
||||||
|
|
@ -58,6 +58,9 @@ pub enum Device {
|
||||||
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
||||||
#[strum(serialize = "metal")]
|
#[strum(serialize = "metal")]
|
||||||
Metal,
|
Metal,
|
||||||
|
|
||||||
|
#[strum(serialize = "experimental_http")]
|
||||||
|
ExperimentalHttp,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
|
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
|
||||||
|
|
@ -129,22 +132,28 @@ fn should_download_ggml_files(device: &Device) -> bool {
|
||||||
pub async fn main(config: &Config, args: &ServeArgs) {
|
pub async fn main(config: &Config, args: &ServeArgs) {
|
||||||
valid_args(args);
|
valid_args(args);
|
||||||
|
|
||||||
// Ensure model exists.
|
if args.device != Device::ExperimentalHttp {
|
||||||
tabby_download::download_model(
|
let download_ctranslate2_files = !should_download_ggml_files(&args.device);
|
||||||
&args.model,
|
let download_ggml_files = should_download_ggml_files(&args.device);
|
||||||
/* download_ctranslate2_files= */
|
|
||||||
!should_download_ggml_files(&args.device),
|
// Ensure model exists.
|
||||||
/* download_ggml_files= */ should_download_ggml_files(&args.device),
|
tabby_download::download_model(
|
||||||
/* prefer_local_file= */ true,
|
&args.model,
|
||||||
)
|
download_ctranslate2_files,
|
||||||
.await
|
download_ggml_files,
|
||||||
.unwrap_or_else(|err| {
|
/* prefer_local_file= */ true,
|
||||||
fatal!(
|
|
||||||
"Failed to fetch model due to '{}', is '{}' a valid model id?",
|
|
||||||
err,
|
|
||||||
args.model
|
|
||||||
)
|
)
|
||||||
});
|
.await
|
||||||
|
.unwrap_or_else(|err| {
|
||||||
|
fatal!(
|
||||||
|
"Failed to fetch model due to '{}', is '{}' a valid model id?",
|
||||||
|
err,
|
||||||
|
args.model
|
||||||
|
)
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
warn!("HTTP device is unstable and does not comply with semver expectations.")
|
||||||
|
}
|
||||||
|
|
||||||
info!("Starting server, this might takes a few minutes...");
|
info!("Starting server, this might takes a few minutes...");
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue