feat: add support vertex-ai http bindings (#419)

* feat: add support vertex-ai http bindings

* support prefix / suffix
release-0.2
Meng Zhang 2023-09-09 19:22:58 +08:00 committed by GitHub
parent 17397c8c8c
commit f0ed366420
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 96 additions and 28 deletions

1
Cargo.lock generated
View File

@ -2959,6 +2959,7 @@ dependencies = [
"axum-tracing-opentelemetry", "axum-tracing-opentelemetry",
"clap", "clap",
"ctranslate2-bindings", "ctranslate2-bindings",
"http-api-bindings",
"hyper", "hyper",
"lazy_static", "lazy_static",
"llama-cpp-bindings", "llama-cpp-bindings",

View File

@ -1,4 +1,4 @@
## Usage ## Examples
```bash ```bash
export MODEL_ID="code-gecko" export MODEL_ID="code-gecko"
@ -8,3 +8,14 @@ export AUTHORIZATION="Bearer $(gcloud auth print-access-token)"
cargo run --example simple cargo run --example simple
``` ```
## Usage
```bash
export MODEL_ID="code-gecko"
export PROJECT_ID="$(gcloud config get project)"
export API_ENDPOINT="https://us-central1-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/${MODEL_ID}:predict"
export AUTHORIZATION="Bearer $(gcloud auth print-access-token)"
cargo run serve --device experimental-http --model "{\"kind\": \"vertex-ai\", \"api_endpoint\": \"$API_ENDPOINT\", \"authorization\": \"$AUTHORIZATION\"}"
```

View File

@ -56,23 +56,34 @@ impl VertexAIEngine {
client, client,
} }
} }
pub fn prompt_template() -> String {
"{prefix}<MID>{suffix}".to_owned()
}
} }
#[async_trait] #[async_trait]
impl TextGeneration for VertexAIEngine { impl TextGeneration for VertexAIEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let stop_sequences: Vec<String> = let stop_sequences: Vec<String> = options
options.stop_words.iter().map(|x| x.to_string()).collect(); .stop_words
.iter()
.map(|x| x.to_string())
// vertex supports at most 5 stop sequence.
.take(5)
.collect();
let tokens: Vec<&str> = prompt.split("<MID>").collect();
let request = Request { let request = Request {
instances: vec![Instance { instances: vec![Instance {
prefix: prompt.to_owned(), prefix: tokens[0].to_owned(),
suffix: None, suffix: Some(tokens[1].to_owned()),
}], }],
// options.max_input_length is ignored. // options.max_input_length is ignored.
parameters: Parameters { parameters: Parameters {
temperature: options.sampling_temperature, temperature: options.sampling_temperature,
max_output_tokens: options.max_decoding_length, // vertex supports at most 64 output tokens.
max_output_tokens: std::cmp::min(options.max_decoding_length, 64),
stop_sequences, stop_sequences,
}, },
}; };

View File

@ -35,6 +35,7 @@ tantivy = { workspace = true }
anyhow = { workspace = true } anyhow = { workspace = true }
sysinfo = "0.29.8" sysinfo = "0.29.8"
nvml-wrapper = "0.9.0" nvml-wrapper = "0.9.0"
http-api-bindings = { path = "../http-api-bindings" }
[target.'cfg(all(target_os="macos", target_arch="aarch64"))'.dependencies] [target.'cfg(all(target_os="macos", target_arch="aarch64"))'.dependencies]
llama-cpp-bindings = { path = "../llama-cpp-bindings" } llama-cpp-bindings = { path = "../llama-cpp-bindings" }

View File

@ -5,8 +5,10 @@ use std::{path::Path, sync::Arc};
use axum::{extract::State, Json}; use axum::{extract::State, Json};
use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder}; use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};
use http_api_bindings::vertex_ai::VertexAIEngine;
use hyper::StatusCode; use hyper::StatusCode;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value;
use tabby_common::{config::Config, events, path::ModelDir}; use tabby_common::{config::Config, events, path::ModelDir};
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder}; use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
use tracing::{debug, instrument}; use tracing::{debug, instrument};
@ -128,22 +130,55 @@ pub struct CompletionState {
impl CompletionState { impl CompletionState {
pub fn new(args: &crate::serve::ServeArgs, config: &Config) -> Self { pub fn new(args: &crate::serve::ServeArgs, config: &Config) -> Self {
let model_dir = get_model_dir(&args.model); let (engine, prompt_template) = create_engine(args);
let metadata = read_metadata(&model_dir);
let engine = create_engine(args, &model_dir, &metadata);
Self { Self {
engine, engine,
prompt_builder: prompt::PromptBuilder::new( prompt_builder: prompt::PromptBuilder::new(
metadata.prompt_template, prompt_template,
config.experimental.enable_prompt_rewrite, config.experimental.enable_prompt_rewrite,
), ),
} }
} }
} }
fn create_engine(args: &crate::serve::ServeArgs) -> (Box<dyn TextGeneration>, Option<String>) {
if args.device != super::Device::ExperimentalHttp {
let model_dir = get_model_dir(&args.model);
let metadata = read_metadata(&model_dir);
let engine = create_local_engine(args, &model_dir, &metadata);
(engine, metadata.prompt_template)
} else {
let params: Value =
serdeconv::from_json_str(&args.model).expect("Failed to parse model string");
let kind = params
.get("kind")
.expect("Missing kind field")
.as_str()
.expect("Type unmatched");
if kind != "vertex-ai" {
fatal!("Only vertex_ai is supported for http backend");
}
let api_endpoint = params
.get("api_endpoint")
.expect("Missing api_endpoint field")
.as_str()
.expect("Type unmatched");
let authorization = params
.get("authorization")
.expect("Missing authorization field")
.as_str()
.expect("Type unmatched");
let engine = Box::new(VertexAIEngine::create(api_endpoint, authorization));
(engine, Some(VertexAIEngine::prompt_template()))
}
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))] #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn create_engine( fn create_local_engine(
args: &crate::serve::ServeArgs, args: &crate::serve::ServeArgs,
model_dir: &ModelDir, model_dir: &ModelDir,
metadata: &Metadata, metadata: &Metadata,
@ -152,7 +187,7 @@ fn create_engine(
} }
#[cfg(all(target_os = "macos", target_arch = "aarch64"))] #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn create_engine( fn create_local_engine(
args: &crate::serve::ServeArgs, args: &crate::serve::ServeArgs,
model_dir: &ModelDir, model_dir: &ModelDir,
metadata: &Metadata, metadata: &Metadata,

View File

@ -14,7 +14,7 @@ use clap::Args;
use tabby_common::{config::Config, usage}; use tabby_common::{config::Config, usage};
use tokio::time::sleep; use tokio::time::sleep;
use tower_http::cors::CorsLayer; use tower_http::cors::CorsLayer;
use tracing::info; use tracing::{info, warn};
use utoipa::OpenApi; use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi; use utoipa_swagger_ui::SwaggerUi;
@ -58,6 +58,9 @@ pub enum Device {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))] #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
#[strum(serialize = "metal")] #[strum(serialize = "metal")]
Metal, Metal,
#[strum(serialize = "experimental_http")]
ExperimentalHttp,
} }
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)] #[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
@ -129,22 +132,28 @@ fn should_download_ggml_files(device: &Device) -> bool {
pub async fn main(config: &Config, args: &ServeArgs) { pub async fn main(config: &Config, args: &ServeArgs) {
valid_args(args); valid_args(args);
// Ensure model exists. if args.device != Device::ExperimentalHttp {
tabby_download::download_model( let download_ctranslate2_files = !should_download_ggml_files(&args.device);
&args.model, let download_ggml_files = should_download_ggml_files(&args.device);
/* download_ctranslate2_files= */
!should_download_ggml_files(&args.device), // Ensure model exists.
/* download_ggml_files= */ should_download_ggml_files(&args.device), tabby_download::download_model(
/* prefer_local_file= */ true, &args.model,
) download_ctranslate2_files,
.await download_ggml_files,
.unwrap_or_else(|err| { /* prefer_local_file= */ true,
fatal!(
"Failed to fetch model due to '{}', is '{}' a valid model id?",
err,
args.model
) )
}); .await
.unwrap_or_else(|err| {
fatal!(
"Failed to fetch model due to '{}', is '{}' a valid model id?",
err,
args.model
)
});
} else {
warn!("HTTP device is unstable and does not comply with semver expectations.")
}
info!("Starting server, this might takes a few minutes..."); info!("Starting server, this might takes a few minutes...");
let app = Router::new() let app = Router::new()