feat: add support vertex-ai http bindings (#419)
* feat: add support vertex-ai http bindings * support prefix / suffixrelease-0.2
parent
17397c8c8c
commit
f0ed366420
|
|
@ -2959,6 +2959,7 @@ dependencies = [
|
|||
"axum-tracing-opentelemetry",
|
||||
"clap",
|
||||
"ctranslate2-bindings",
|
||||
"http-api-bindings",
|
||||
"hyper",
|
||||
"lazy_static",
|
||||
"llama-cpp-bindings",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
## Usage
|
||||
## Examples
|
||||
|
||||
```bash
|
||||
export MODEL_ID="code-gecko"
|
||||
|
|
@ -8,3 +8,14 @@ export AUTHORIZATION="Bearer $(gcloud auth print-access-token)"
|
|||
|
||||
cargo run --example simple
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
export MODEL_ID="code-gecko"
|
||||
export PROJECT_ID="$(gcloud config get project)"
|
||||
export API_ENDPOINT="https://us-central1-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/${MODEL_ID}:predict"
|
||||
export AUTHORIZATION="Bearer $(gcloud auth print-access-token)"
|
||||
|
||||
cargo run serve --device experimental-http --model "{\"kind\": \"vertex-ai\", \"api_endpoint\": \"$API_ENDPOINT\", \"authorization\": \"$AUTHORIZATION\"}"
|
||||
```
|
||||
|
|
|
|||
|
|
@ -56,23 +56,34 @@ impl VertexAIEngine {
|
|||
client,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn prompt_template() -> String {
|
||||
"{prefix}<MID>{suffix}".to_owned()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TextGeneration for VertexAIEngine {
|
||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||
let stop_sequences: Vec<String> =
|
||||
options.stop_words.iter().map(|x| x.to_string()).collect();
|
||||
let stop_sequences: Vec<String> = options
|
||||
.stop_words
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
// vertex supports at most 5 stop sequence.
|
||||
.take(5)
|
||||
.collect();
|
||||
|
||||
let tokens: Vec<&str> = prompt.split("<MID>").collect();
|
||||
let request = Request {
|
||||
instances: vec![Instance {
|
||||
prefix: prompt.to_owned(),
|
||||
suffix: None,
|
||||
prefix: tokens[0].to_owned(),
|
||||
suffix: Some(tokens[1].to_owned()),
|
||||
}],
|
||||
// options.max_input_length is ignored.
|
||||
parameters: Parameters {
|
||||
temperature: options.sampling_temperature,
|
||||
max_output_tokens: options.max_decoding_length,
|
||||
// vertex supports at most 64 output tokens.
|
||||
max_output_tokens: std::cmp::min(options.max_decoding_length, 64),
|
||||
stop_sequences,
|
||||
},
|
||||
};
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ tantivy = { workspace = true }
|
|||
anyhow = { workspace = true }
|
||||
sysinfo = "0.29.8"
|
||||
nvml-wrapper = "0.9.0"
|
||||
http-api-bindings = { path = "../http-api-bindings" }
|
||||
|
||||
[target.'cfg(all(target_os="macos", target_arch="aarch64"))'.dependencies]
|
||||
llama-cpp-bindings = { path = "../llama-cpp-bindings" }
|
||||
|
|
|
|||
|
|
@ -5,8 +5,10 @@ use std::{path::Path, sync::Arc};
|
|||
|
||||
use axum::{extract::State, Json};
|
||||
use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};
|
||||
use http_api_bindings::vertex_ai::VertexAIEngine;
|
||||
use hyper::StatusCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use tabby_common::{config::Config, events, path::ModelDir};
|
||||
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
|
||||
use tracing::{debug, instrument};
|
||||
|
|
@ -128,22 +130,55 @@ pub struct CompletionState {
|
|||
|
||||
impl CompletionState {
|
||||
pub fn new(args: &crate::serve::ServeArgs, config: &Config) -> Self {
|
||||
let model_dir = get_model_dir(&args.model);
|
||||
let metadata = read_metadata(&model_dir);
|
||||
let engine = create_engine(args, &model_dir, &metadata);
|
||||
let (engine, prompt_template) = create_engine(args);
|
||||
|
||||
Self {
|
||||
engine,
|
||||
prompt_builder: prompt::PromptBuilder::new(
|
||||
metadata.prompt_template,
|
||||
prompt_template,
|
||||
config.experimental.enable_prompt_rewrite,
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn create_engine(args: &crate::serve::ServeArgs) -> (Box<dyn TextGeneration>, Option<String>) {
|
||||
if args.device != super::Device::ExperimentalHttp {
|
||||
let model_dir = get_model_dir(&args.model);
|
||||
let metadata = read_metadata(&model_dir);
|
||||
let engine = create_local_engine(args, &model_dir, &metadata);
|
||||
(engine, metadata.prompt_template)
|
||||
} else {
|
||||
let params: Value =
|
||||
serdeconv::from_json_str(&args.model).expect("Failed to parse model string");
|
||||
|
||||
let kind = params
|
||||
.get("kind")
|
||||
.expect("Missing kind field")
|
||||
.as_str()
|
||||
.expect("Type unmatched");
|
||||
|
||||
if kind != "vertex-ai" {
|
||||
fatal!("Only vertex_ai is supported for http backend");
|
||||
}
|
||||
|
||||
let api_endpoint = params
|
||||
.get("api_endpoint")
|
||||
.expect("Missing api_endpoint field")
|
||||
.as_str()
|
||||
.expect("Type unmatched");
|
||||
let authorization = params
|
||||
.get("authorization")
|
||||
.expect("Missing authorization field")
|
||||
.as_str()
|
||||
.expect("Type unmatched");
|
||||
let engine = Box::new(VertexAIEngine::create(api_endpoint, authorization));
|
||||
(engine, Some(VertexAIEngine::prompt_template()))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
|
||||
fn create_engine(
|
||||
fn create_local_engine(
|
||||
args: &crate::serve::ServeArgs,
|
||||
model_dir: &ModelDir,
|
||||
metadata: &Metadata,
|
||||
|
|
@ -152,7 +187,7 @@ fn create_engine(
|
|||
}
|
||||
|
||||
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
||||
fn create_engine(
|
||||
fn create_local_engine(
|
||||
args: &crate::serve::ServeArgs,
|
||||
model_dir: &ModelDir,
|
||||
metadata: &Metadata,
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ use clap::Args;
|
|||
use tabby_common::{config::Config, usage};
|
||||
use tokio::time::sleep;
|
||||
use tower_http::cors::CorsLayer;
|
||||
use tracing::info;
|
||||
use tracing::{info, warn};
|
||||
use utoipa::OpenApi;
|
||||
use utoipa_swagger_ui::SwaggerUi;
|
||||
|
||||
|
|
@ -58,6 +58,9 @@ pub enum Device {
|
|||
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
||||
#[strum(serialize = "metal")]
|
||||
Metal,
|
||||
|
||||
#[strum(serialize = "experimental_http")]
|
||||
ExperimentalHttp,
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
|
||||
|
|
@ -129,22 +132,28 @@ fn should_download_ggml_files(device: &Device) -> bool {
|
|||
pub async fn main(config: &Config, args: &ServeArgs) {
|
||||
valid_args(args);
|
||||
|
||||
// Ensure model exists.
|
||||
tabby_download::download_model(
|
||||
&args.model,
|
||||
/* download_ctranslate2_files= */
|
||||
!should_download_ggml_files(&args.device),
|
||||
/* download_ggml_files= */ should_download_ggml_files(&args.device),
|
||||
/* prefer_local_file= */ true,
|
||||
)
|
||||
.await
|
||||
.unwrap_or_else(|err| {
|
||||
fatal!(
|
||||
"Failed to fetch model due to '{}', is '{}' a valid model id?",
|
||||
err,
|
||||
args.model
|
||||
if args.device != Device::ExperimentalHttp {
|
||||
let download_ctranslate2_files = !should_download_ggml_files(&args.device);
|
||||
let download_ggml_files = should_download_ggml_files(&args.device);
|
||||
|
||||
// Ensure model exists.
|
||||
tabby_download::download_model(
|
||||
&args.model,
|
||||
download_ctranslate2_files,
|
||||
download_ggml_files,
|
||||
/* prefer_local_file= */ true,
|
||||
)
|
||||
});
|
||||
.await
|
||||
.unwrap_or_else(|err| {
|
||||
fatal!(
|
||||
"Failed to fetch model due to '{}', is '{}' a valid model id?",
|
||||
err,
|
||||
args.model
|
||||
)
|
||||
});
|
||||
} else {
|
||||
warn!("HTTP device is unstable and does not comply with semver expectations.")
|
||||
}
|
||||
|
||||
info!("Starting server, this might takes a few minutes...");
|
||||
let app = Router::new()
|
||||
|
|
|
|||
Loading…
Reference in New Issue