feat: add support vertex-ai http bindings (#419)

* feat: add support vertex-ai http bindings

* support prefix / suffix
release-0.2
Meng Zhang 2023-09-09 19:22:58 +08:00 committed by GitHub
parent 17397c8c8c
commit f0ed366420
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 96 additions and 28 deletions

1
Cargo.lock generated
View File

@ -2959,6 +2959,7 @@ dependencies = [
"axum-tracing-opentelemetry",
"clap",
"ctranslate2-bindings",
"http-api-bindings",
"hyper",
"lazy_static",
"llama-cpp-bindings",

View File

@ -1,4 +1,4 @@
## Usage
## Examples
```bash
export MODEL_ID="code-gecko"
@ -8,3 +8,14 @@ export AUTHORIZATION="Bearer $(gcloud auth print-access-token)"
cargo run --example simple
```
## Usage
```bash
export MODEL_ID="code-gecko"
export PROJECT_ID="$(gcloud config get project)"
export API_ENDPOINT="https://us-central1-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/${MODEL_ID}:predict"
export AUTHORIZATION="Bearer $(gcloud auth print-access-token)"
cargo run serve --device experimental-http --model "{\"kind\": \"vertex-ai\", \"api_endpoint\": \"$API_ENDPOINT\", \"authorization\": \"$AUTHORIZATION\"}"
```

View File

@ -56,23 +56,34 @@ impl VertexAIEngine {
client,
}
}
pub fn prompt_template() -> String {
    // The "<MID>" marker is where generate() splits the rendered prompt back
    // into the prefix/suffix pair expected by the Vertex AI request payload.
    String::from("{prefix}<MID>{suffix}")
}
}
#[async_trait]
impl TextGeneration for VertexAIEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let stop_sequences: Vec<String> =
options.stop_words.iter().map(|x| x.to_string()).collect();
let stop_sequences: Vec<String> = options
.stop_words
.iter()
.map(|x| x.to_string())
// Vertex AI supports at most 5 stop sequences.
.take(5)
.collect();
let tokens: Vec<&str> = prompt.split("<MID>").collect();
let request = Request {
instances: vec![Instance {
prefix: prompt.to_owned(),
suffix: None,
prefix: tokens[0].to_owned(),
suffix: Some(tokens[1].to_owned()),
}],
// options.max_input_length is ignored.
parameters: Parameters {
temperature: options.sampling_temperature,
max_output_tokens: options.max_decoding_length,
// vertex supports at most 64 output tokens.
max_output_tokens: std::cmp::min(options.max_decoding_length, 64),
stop_sequences,
},
};

View File

@ -35,6 +35,7 @@ tantivy = { workspace = true }
anyhow = { workspace = true }
sysinfo = "0.29.8"
nvml-wrapper = "0.9.0"
http-api-bindings = { path = "../http-api-bindings" }
[target.'cfg(all(target_os="macos", target_arch="aarch64"))'.dependencies]
llama-cpp-bindings = { path = "../llama-cpp-bindings" }

View File

@ -5,8 +5,10 @@ use std::{path::Path, sync::Arc};
use axum::{extract::State, Json};
use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};
use http_api_bindings::vertex_ai::VertexAIEngine;
use hyper::StatusCode;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tabby_common::{config::Config, events, path::ModelDir};
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
use tracing::{debug, instrument};
@ -128,22 +130,55 @@ pub struct CompletionState {
impl CompletionState {
pub fn new(args: &crate::serve::ServeArgs, config: &Config) -> Self {
let model_dir = get_model_dir(&args.model);
let metadata = read_metadata(&model_dir);
let engine = create_engine(args, &model_dir, &metadata);
let (engine, prompt_template) = create_engine(args);
Self {
engine,
prompt_builder: prompt::PromptBuilder::new(
metadata.prompt_template,
prompt_template,
config.experimental.enable_prompt_rewrite,
),
}
}
}
/// Builds the text-generation engine selected by `--device`, returning the
/// engine together with its prompt template (if any).
///
/// * Local devices: `--model` is a model id; the engine and prompt template
///   come from the downloaded model directory's metadata.
/// * `experimental-http`: `--model` carries a JSON config instead, e.g.
///   `{"kind": "vertex-ai", "api_endpoint": "...", "authorization": "Bearer ..."}`.
///
/// Aborts the process (via `expect` / `fatal!`) on a malformed config — this
/// runs once at startup, before the server begins accepting requests.
fn create_engine(args: &crate::serve::ServeArgs) -> (Box<dyn TextGeneration>, Option<String>) {
    if args.device != super::Device::ExperimentalHttp {
        let model_dir = get_model_dir(&args.model);
        let metadata = read_metadata(&model_dir);
        let engine = create_local_engine(args, &model_dir, &metadata);
        (engine, metadata.prompt_template)
    } else {
        let params: Value =
            serdeconv::from_json_str(&args.model).expect("Failed to parse model string");
        let kind = params
            .get("kind")
            .expect("Missing kind field")
            .as_str()
            .expect("Type unmatched");
        if kind != "vertex-ai" {
            // Keep this message in sync with the accepted `kind` value above:
            // the config uses the hyphenated form "vertex-ai", not "vertex_ai".
            fatal!("Only vertex-ai is supported for http backend");
        }
        let api_endpoint = params
            .get("api_endpoint")
            .expect("Missing api_endpoint field")
            .as_str()
            .expect("Type unmatched");
        let authorization = params
            .get("authorization")
            .expect("Missing authorization field")
            .as_str()
            .expect("Type unmatched");
        let engine = Box::new(VertexAIEngine::create(api_endpoint, authorization));
        // The HTTP engine defines its own template; local metadata is not consulted.
        (engine, Some(VertexAIEngine::prompt_template()))
    }
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn create_engine(
fn create_local_engine(
args: &crate::serve::ServeArgs,
model_dir: &ModelDir,
metadata: &Metadata,
@ -152,7 +187,7 @@ fn create_engine(
}
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn create_engine(
fn create_local_engine(
args: &crate::serve::ServeArgs,
model_dir: &ModelDir,
metadata: &Metadata,

View File

@ -14,7 +14,7 @@ use clap::Args;
use tabby_common::{config::Config, usage};
use tokio::time::sleep;
use tower_http::cors::CorsLayer;
use tracing::info;
use tracing::{info, warn};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
@ -58,6 +58,9 @@ pub enum Device {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
#[strum(serialize = "metal")]
Metal,
#[strum(serialize = "experimental_http")]
ExperimentalHttp,
}
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
@ -129,22 +132,28 @@ fn should_download_ggml_files(device: &Device) -> bool {
pub async fn main(config: &Config, args: &ServeArgs) {
valid_args(args);
// Ensure model exists.
tabby_download::download_model(
&args.model,
/* download_ctranslate2_files= */
!should_download_ggml_files(&args.device),
/* download_ggml_files= */ should_download_ggml_files(&args.device),
/* prefer_local_file= */ true,
)
.await
.unwrap_or_else(|err| {
fatal!(
"Failed to fetch model due to '{}', is '{}' a valid model id?",
err,
args.model
if args.device != Device::ExperimentalHttp {
let download_ctranslate2_files = !should_download_ggml_files(&args.device);
let download_ggml_files = should_download_ggml_files(&args.device);
// Ensure model exists.
tabby_download::download_model(
&args.model,
download_ctranslate2_files,
download_ggml_files,
/* prefer_local_file= */ true,
)
});
.await
.unwrap_or_else(|err| {
fatal!(
"Failed to fetch model due to '{}', is '{}' a valid model id?",
err,
args.model
)
});
} else {
warn!("HTTP device is unstable and does not comply with semver expectations.")
}
info!("Starting server, this might take a few minutes...");
let app = Router::new()