feat: add /generate and /generate_streaming (#482)

* feat: add generate_stream interface

* extract engine::create_engine

* feat add generate::generate

* support streaming in llama.cpp

* support streaming in ctranslate2

* update

* fix formatting

* refactor: extract helpers functions
release-0.2
Meng Zhang 2023-09-28 10:20:50 -07:00 committed by GitHub
parent 1d6ac7836b
commit 44f013f26e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 422 additions and 190 deletions

36
Cargo.lock generated
View File

@ -247,6 +247,26 @@ dependencies = [
"tower-service",
]
[[package]]
name = "axum-streams"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a3e367d27d8c1ce16fbd0d96ddf05105fd1147f5d35ffc55e254dab914e72e8"
dependencies = [
"axum",
"bytes",
"cargo-husky",
"futures",
"futures-util",
"http",
"mime",
"serde",
"serde_json",
"tokio",
"tokio-stream",
"tokio-util",
]
[[package]]
name = "axum-tracing-opentelemetry"
version = "0.10.0"
@ -417,6 +437,12 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a4f925191b4367301851c6d99b09890311d74b0d43f274c0b34c86d308a3663"
[[package]]
name = "cargo-husky"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b02b629252fe8ef6460461409564e2c21d0c8e77e0944f3d189ff06c4e932ad"
[[package]]
name = "cc"
version = "1.0.79"
@ -666,11 +692,13 @@ dependencies = [
name = "ctranslate2-bindings"
version = "0.1.0"
dependencies = [
"async-stream",
"async-trait",
"cmake",
"cxx",
"cxx-build",
"derive_builder",
"futures",
"rust-cxx-cmake-bridge",
"stop-words",
"tabby-inference",
@ -1295,6 +1323,7 @@ name = "http-api-bindings"
version = "0.1.0"
dependencies = [
"async-trait",
"futures",
"reqwest",
"serde",
"serde_json",
@ -1625,11 +1654,13 @@ checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519"
name = "llama-cpp-bindings"
version = "0.1.0"
dependencies = [
"async-stream",
"async-trait",
"cmake",
"cxx",
"cxx-build",
"derive_builder",
"futures",
"stop-words",
"tabby-inference",
"tokenizers",
@ -3012,10 +3043,13 @@ name = "tabby"
version = "0.1.1"
dependencies = [
"anyhow",
"async-stream",
"axum",
"axum-streams",
"axum-tracing-opentelemetry",
"clap",
"ctranslate2-bindings",
"futures",
"http-api-bindings",
"hyper",
"lazy_static",
@ -3086,8 +3120,10 @@ dependencies = [
name = "tabby-inference"
version = "0.1.0"
dependencies = [
"async-stream",
"async-trait",
"derive_builder",
"futures",
]
[[package]]

View File

@ -35,3 +35,5 @@ async-trait = "0.1.72"
reqwest = { version = "0.11.18" }
derive_builder = "0.12.0"
tokenizers = "0.13.4-rc3"
futures = "0.3.28"
async-stream = "0.3.5"

View File

@ -12,6 +12,8 @@ tokio-util = { workspace = true }
tabby-inference = { path = "../tabby-inference" }
async-trait = { workspace = true }
stop-words = { path = "../stop-words" }
futures.workspace = true
async-stream.workspace = true
[build-dependencies]
cxx-build = "1.0"

View File

@ -1,10 +1,13 @@
use std::sync::Arc;
use async_stream::stream;
use async_trait::async_trait;
use derive_builder::Builder;
use futures::stream::BoxStream;
use stop_words::{StopWords, StopWordsCondition};
use tabby_inference::{TextGeneration, TextGenerationOptions};
use tabby_inference::{helpers, TextGeneration, TextGenerationOptions};
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::mpsc::{channel, Sender};
use tokio_util::sync::CancellationToken;
#[cxx::bridge(namespace = "tabby")]
@ -67,13 +70,19 @@ pub struct CTranslate2EngineOptions {
}
pub struct InferenceContext {
sender: Sender<u32>,
stop_condition: StopWordsCondition,
cancel: CancellationToken,
}
impl InferenceContext {
fn new(stop_condition: StopWordsCondition, cancel: CancellationToken) -> Self {
fn new(
sender: Sender<u32>,
stop_condition: StopWordsCondition,
cancel: CancellationToken,
) -> Self {
InferenceContext {
sender,
stop_condition,
cancel,
}
@ -108,30 +117,45 @@ impl CTranslate2Engine {
#[async_trait]
impl TextGeneration for CTranslate2Engine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let s = self.generate_stream(prompt, options).await;
helpers::stream_to_string(s).await
}
async fn generate_stream(
&self,
prompt: &str,
options: TextGenerationOptions,
) -> BoxStream<String> {
let encoding = self.tokenizer.encode(prompt, true).unwrap();
let engine = self.engine.clone();
let s = stream! {
let cancel = CancellationToken::new();
let cancel_for_inference = cancel.clone();
let _guard = cancel.drop_guard();
let cancel = CancellationToken::new();
let cancel_for_inference = cancel.clone();
let _guard = cancel.drop_guard();
let stop_condition = self
.stop_words
.create_condition(self.tokenizer.clone(), options.stop_words);
let stop_condition = self
.stop_words
.create_condition(self.tokenizer.clone(), options.stop_words);
let context = InferenceContext::new(stop_condition, cancel_for_inference);
let output_ids = tokio::task::spawn_blocking(move || {
let context = Box::new(context);
engine.inference(
context,
inference_callback,
truncate_tokens(encoding.get_tokens(), options.max_input_length),
options.max_decoding_length,
options.sampling_temperature,
)
})
.await
.expect("Inference failed");
self.tokenizer.decode(&output_ids, true).unwrap()
let (sender, mut receiver) = channel::<u32>(8);
let context = InferenceContext::new(sender, stop_condition, cancel_for_inference);
tokio::task::spawn(async move {
let context = Box::new(context);
engine.inference(
context,
inference_callback,
truncate_tokens(encoding.get_tokens(), options.max_input_length),
options.max_decoding_length,
options.sampling_temperature,
);
});
while let Some(next_token_id) = receiver.recv().await {
let text = self.tokenizer.decode(&[next_token_id], true).unwrap();
yield text;
}
};
Box::pin(s)
}
}
@ -150,6 +174,7 @@ fn inference_callback(
token_id: u32,
_token: String,
) -> bool {
let _ = context.sender.blocking_send(token_id);
if context.cancel.is_cancelled() {
true
} else {

View File

@ -5,6 +5,7 @@ edition = "2021"
[dependencies]
async-trait.workspace = true
futures.workspace = true
reqwest = { workspace = true, features = ["json"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }

View File

@ -1,8 +1,9 @@
use async_trait::async_trait;
use futures::stream::BoxStream;
use reqwest::header;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tabby_inference::{TextGeneration, TextGenerationOptions};
use tabby_inference::{helpers, TextGeneration, TextGenerationOptions};
#[derive(Serialize)]
struct Request {
@ -87,4 +88,12 @@ impl TextGeneration for FastChatEngine {
resp.choices[0].text[0].clone()
}
async fn generate_stream(
&self,
prompt: &str,
options: TextGenerationOptions,
) -> BoxStream<String> {
helpers::string_to_stream(self.generate(prompt, options).await).await
}
}

View File

@ -1,8 +1,9 @@
use async_trait::async_trait;
use futures::stream::BoxStream;
use reqwest::header;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tabby_inference::{TextGeneration, TextGenerationOptions};
use tabby_inference::{helpers, TextGeneration, TextGenerationOptions};
#[derive(Serialize)]
struct Request {
@ -107,4 +108,12 @@ impl TextGeneration for VertexAIEngine {
resp.predictions[0].content.clone()
}
async fn generate_stream(
&self,
prompt: &str,
options: TextGenerationOptions,
) -> BoxStream<String> {
helpers::string_to_stream(self.generate(prompt, options).await).await
}
}

View File

@ -16,3 +16,5 @@ derive_builder = { workspace = true }
tokenizers = { workspace = true }
stop-words = { version = "0.1.0", path = "../stop-words" }
tokio-util = { workspace = true }
futures.workspace = true
async-stream.workspace = true

View File

@ -1,12 +1,13 @@
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use async_stream::stream;
use async_trait::async_trait;
use derive_builder::Builder;
use ffi::create_engine;
use futures::{lock::Mutex, stream::BoxStream};
use stop_words::StopWords;
use tabby_inference::{TextGeneration, TextGenerationOptions};
use tabby_inference::{helpers, TextGeneration, TextGenerationOptions};
use tokenizers::tokenizer::Tokenizer;
use tokio_util::sync::CancellationToken;
#[cxx::bridge(namespace = "llama")]
mod ffi {
@ -35,7 +36,7 @@ pub struct LlamaEngineOptions {
}
pub struct LlamaEngine {
engine: Arc<Mutex<cxx::SharedPtr<ffi::TextInferenceEngine>>>,
engine: Mutex<cxx::SharedPtr<ffi::TextInferenceEngine>>,
tokenizer: Arc<Tokenizer>,
stop_words: StopWords,
}
@ -43,7 +44,7 @@ pub struct LlamaEngine {
impl LlamaEngine {
pub fn create(options: LlamaEngineOptions) -> Self {
LlamaEngine {
engine: Arc::new(Mutex::new(create_engine(&options.model_path))),
engine: Mutex::new(create_engine(&options.model_path)),
tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
stop_words: StopWords::default(),
}
@ -53,51 +54,49 @@ impl LlamaEngine {
#[async_trait]
impl TextGeneration for LlamaEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let cancel = CancellationToken::new();
let cancel_for_inference = cancel.clone();
let _guard = cancel.drop_guard();
let s = self.generate_stream(prompt, options).await;
helpers::stream_to_string(s).await
}
async fn generate_stream(
&self,
prompt: &str,
options: TextGenerationOptions,
) -> BoxStream<String> {
let prompt = prompt.to_owned();
let engine = self.engine.clone();
let mut stop_condition = self
.stop_words
.create_condition(self.tokenizer.clone(), options.stop_words);
let output_ids = tokio::task::spawn_blocking(move || {
let engine = engine.lock().unwrap();
let s = stream! {
let engine = self.engine.lock().await;
let eos_token = engine.eos_token();
let mut next_token_id = engine.start(&prompt, options.max_input_length);
if next_token_id == eos_token {
return Vec::new();
}
yield "".to_owned();
} else {
let mut n_remains = options.max_decoding_length - 1;
let mut n_remains = options.max_decoding_length - 1;
let mut output_ids = vec![next_token_id];
while n_remains > 0 {
next_token_id = engine.step(next_token_id);
if next_token_id == eos_token {
break;
}
while n_remains > 0 {
if cancel_for_inference.is_cancelled() {
// The token was cancelled
break;
if stop_condition.next_token(next_token_id) {
break;
}
let text = self.tokenizer.decode(&[next_token_id], true).unwrap();
yield text;
n_remains -= 1;
}
next_token_id = engine.step(next_token_id);
if next_token_id == eos_token {
break;
}
if stop_condition.next_token(next_token_id) {
break;
}
output_ids.push(next_token_id);
n_remains -= 1;
}
engine.end();
output_ids
})
.await
.expect("Inference failed");
self.tokenizer.decode(&output_ids, true).unwrap()
};
Box::pin(s)
}
}

View File

@ -6,5 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
async-stream = { workspace = true }
async-trait = { workspace = true }
derive_builder = "0.12.0"
futures = { workspace = true }

View File

@ -1,5 +1,6 @@
use async_trait::async_trait;
use derive_builder::Builder;
use futures::stream::BoxStream;
#[derive(Builder, Debug)]
pub struct TextGenerationOptions {
@ -21,4 +22,33 @@ static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];
#[async_trait]
pub trait TextGeneration: Sync + Send {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String;
async fn generate_stream(
&self,
prompt: &str,
options: TextGenerationOptions,
) -> BoxStream<String>;
}
pub mod helpers {
use async_stream::stream;
use futures::{pin_mut, stream::BoxStream, Stream, StreamExt};
pub async fn stream_to_string(s: impl Stream<Item = String>) -> String {
pin_mut!(s);
let mut text = "".to_owned();
while let Some(value) = s.next().await {
text += &value;
}
text
}
pub async fn string_to_stream(s: String) -> BoxStream<'static, String> {
let stream = stream! {
yield s
};
Box::pin(stream)
}
}

View File

@ -36,6 +36,9 @@ anyhow = { workspace = true }
sysinfo = "0.29.8"
nvml-wrapper = "0.9.0"
http-api-bindings = { path = "../http-api-bindings" }
futures = { workspace = true }
async-stream = { workspace = true }
axum-streams = { version = "0.9.1", features = ["json"] }
[target.'cfg(all(target_os="macos", target_arch="aarch64"))'.dependencies]
llama-cpp-bindings = { path = "../llama-cpp-bindings" }

View File

@ -1,21 +1,17 @@
mod languages;
mod prompt;
use std::{path::Path, sync::Arc};
use std::sync::Arc;
use axum::{extract::State, Json};
use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};
use http_api_bindings::{fastchat::FastChatEngine, vertex_ai::VertexAIEngine};
use hyper::StatusCode;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tabby_common::{config::Config, events, path::ModelDir};
use tabby_common::{config::Config, events};
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
use tracing::{debug, instrument};
use utoipa::ToSchema;
use self::languages::get_stop_words;
use crate::fatal;
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({
@ -124,14 +120,16 @@ pub async fn completion(
}
pub struct CompletionState {
engine: Box<dyn TextGeneration>,
engine: Arc<Box<dyn TextGeneration>>,
prompt_builder: prompt::PromptBuilder,
}
impl CompletionState {
pub fn new(args: &crate::serve::ServeArgs, config: &Config) -> Self {
let (engine, prompt_template) = create_engine(args);
pub fn new(
engine: Arc<Box<dyn TextGeneration>>,
prompt_template: Option<String>,
config: &Config,
) -> Self {
Self {
engine,
prompt_builder: prompt::PromptBuilder::new(
@ -141,120 +139,3 @@ impl CompletionState {
}
}
}
fn get_param(params: &Value, key: &str) -> String {
params
.get(key)
.unwrap_or_else(|| panic!("Missing {} field", key))
.as_str()
.expect("Type unmatched")
.to_string()
}
fn create_engine(args: &crate::serve::ServeArgs) -> (Box<dyn TextGeneration>, Option<String>) {
if args.device != super::Device::ExperimentalHttp {
let model_dir = get_model_dir(&args.model);
let metadata = read_metadata(&model_dir);
let engine = create_local_engine(args, &model_dir, &metadata);
(engine, metadata.prompt_template)
} else {
let params: Value =
serdeconv::from_json_str(&args.model).expect("Failed to parse model string");
let kind = get_param(&params, "kind");
if kind == "vertex-ai" {
let api_endpoint = get_param(&params, "api_endpoint");
let authorization = get_param(&params, "authorization");
let engine = Box::new(VertexAIEngine::create(
api_endpoint.as_str(),
authorization.as_str(),
));
(engine, Some(VertexAIEngine::prompt_template()))
} else if kind == "fastchat" {
let model_name = get_param(&params, "model_name");
let api_endpoint = get_param(&params, "api_endpoint");
let authorization = get_param(&params, "authorization");
let engine = Box::new(FastChatEngine::create(
api_endpoint.as_str(),
model_name.as_str(),
authorization.as_str(),
));
(engine, Some(FastChatEngine::prompt_template()))
} else {
fatal!("Only vertex_ai and fastchat are supported for http backend");
}
}
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn create_local_engine(
args: &crate::serve::ServeArgs,
model_dir: &ModelDir,
metadata: &Metadata,
) -> Box<dyn TextGeneration> {
create_ctranslate2_engine(args, model_dir, metadata)
}
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn create_local_engine(
args: &crate::serve::ServeArgs,
model_dir: &ModelDir,
metadata: &Metadata,
) -> Box<dyn TextGeneration> {
if args.device != super::Device::Metal {
create_ctranslate2_engine(args, model_dir, metadata)
} else {
create_llama_engine(model_dir)
}
}
fn create_ctranslate2_engine(
args: &crate::serve::ServeArgs,
model_dir: &ModelDir,
metadata: &Metadata,
) -> Box<dyn TextGeneration> {
let device = format!("{}", args.device);
let compute_type = format!("{}", args.compute_type);
let options = CTranslate2EngineOptionsBuilder::default()
.model_path(model_dir.ctranslate2_dir())
.tokenizer_path(model_dir.tokenizer_file())
.device(device)
.model_type(metadata.auto_model.clone())
.device_indices(args.device_indices.clone())
.num_replicas_per_device(args.num_replicas_per_device)
.compute_type(compute_type)
.build()
.unwrap();
Box::new(CTranslate2Engine::create(options))
}
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn create_llama_engine(model_dir: &ModelDir) -> Box<dyn TextGeneration> {
let options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default()
.model_path(model_dir.ggml_q8_0_file())
.tokenizer_path(model_dir.tokenizer_file())
.build()
.unwrap();
Box::new(llama_cpp_bindings::LlamaEngine::create(options))
}
fn get_model_dir(model: &str) -> ModelDir {
if Path::new(model).exists() {
ModelDir::from(model)
} else {
ModelDir::new(model)
}
}
#[derive(Deserialize)]
struct Metadata {
auto_model: String,
prompt_template: Option<String>,
}
fn read_metadata(model_dir: &ModelDir) -> Metadata {
serdeconv::from_json_file(model_dir.metadata_file())
.unwrap_or_else(|_| fatal!("Invalid metadata file: {}", model_dir.metadata_file()))
}

View File

@ -0,0 +1,127 @@
use std::path::Path;
use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};
use http_api_bindings::{fastchat::FastChatEngine, vertex_ai::VertexAIEngine};
use serde::Deserialize;
use serde_json::Value;
use tabby_common::path::ModelDir;
use tabby_inference::TextGeneration;
use crate::fatal;
fn get_param(params: &Value, key: &str) -> String {
params
.get(key)
.unwrap_or_else(|| panic!("Missing {} field", key))
.as_str()
.expect("Type unmatched")
.to_string()
}
pub fn create_engine(args: &crate::serve::ServeArgs) -> (Box<dyn TextGeneration>, Option<String>) {
if args.device != super::Device::ExperimentalHttp {
let model_dir = get_model_dir(&args.model);
let metadata = read_metadata(&model_dir);
let engine = create_local_engine(args, &model_dir, &metadata);
(engine, metadata.prompt_template)
} else {
let params: Value =
serdeconv::from_json_str(&args.model).expect("Failed to parse model string");
let kind = get_param(&params, "kind");
if kind == "vertex-ai" {
let api_endpoint = get_param(&params, "api_endpoint");
let authorization = get_param(&params, "authorization");
let engine = Box::new(VertexAIEngine::create(
api_endpoint.as_str(),
authorization.as_str(),
));
(engine, Some(VertexAIEngine::prompt_template()))
} else if kind == "fastchat" {
let model_name = get_param(&params, "model_name");
let api_endpoint = get_param(&params, "api_endpoint");
let authorization = get_param(&params, "authorization");
let engine = Box::new(FastChatEngine::create(
api_endpoint.as_str(),
model_name.as_str(),
authorization.as_str(),
));
(engine, Some(FastChatEngine::prompt_template()))
} else {
fatal!("Only vertex_ai and fastchat are supported for http backend");
}
}
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn create_local_engine(
args: &crate::serve::ServeArgs,
model_dir: &ModelDir,
metadata: &Metadata,
) -> Box<dyn TextGeneration> {
create_ctranslate2_engine(args, model_dir, metadata)
}
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn create_local_engine(
args: &crate::serve::ServeArgs,
model_dir: &ModelDir,
metadata: &Metadata,
) -> Box<dyn TextGeneration> {
if args.device != super::Device::Metal {
create_ctranslate2_engine(args, model_dir, metadata)
} else {
create_llama_engine(model_dir)
}
}
fn create_ctranslate2_engine(
args: &crate::serve::ServeArgs,
model_dir: &ModelDir,
metadata: &Metadata,
) -> Box<dyn TextGeneration> {
let device = format!("{}", args.device);
let compute_type = format!("{}", args.compute_type);
let options = CTranslate2EngineOptionsBuilder::default()
.model_path(model_dir.ctranslate2_dir())
.tokenizer_path(model_dir.tokenizer_file())
.device(device)
.model_type(metadata.auto_model.clone())
.device_indices(args.device_indices.clone())
.num_replicas_per_device(args.num_replicas_per_device)
.compute_type(compute_type)
.build()
.unwrap();
Box::new(CTranslate2Engine::create(options))
}
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn create_llama_engine(model_dir: &ModelDir) -> Box<dyn TextGeneration> {
let options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default()
.model_path(model_dir.ggml_q8_0_file())
.tokenizer_path(model_dir.tokenizer_file())
.build()
.unwrap();
Box::new(llama_cpp_bindings::LlamaEngine::create(options))
}
fn get_model_dir(model: &str) -> ModelDir {
if Path::new(model).exists() {
ModelDir::from(model)
} else {
ModelDir::new(model)
}
}
#[derive(Deserialize)]
struct Metadata {
auto_model: String,
prompt_template: Option<String>,
}
fn read_metadata(model_dir: &ModelDir) -> Metadata {
serdeconv::from_json_file(model_dir.metadata_file())
.unwrap_or_else(|_| fatal!("Invalid metadata file: {}", model_dir.metadata_file()))
}

View File

@ -0,0 +1,87 @@
use std::sync::Arc;
use async_stream::stream;
use axum::{extract::State, response::IntoResponse, Json};
use axum_streams::StreamBodyAs;
use serde::{Deserialize, Serialize};
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
use tracing::instrument;
use utoipa::ToSchema;
pub struct GenerateState {
engine: Arc<Box<dyn TextGeneration>>,
}
impl GenerateState {
pub fn new(engine: Arc<Box<dyn TextGeneration>>) -> Self {
Self { engine }
}
}
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct GenerateRequest {
#[schema(
example = "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\ndef"
)]
prompt: String,
}
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct GenerateResponse {
text: String,
}
#[utoipa::path(
post,
path = "/v1/generate",
request_body = GenerateRequest,
operation_id = "generate",
tag = "v1",
responses(
(status = 200, description = "Success", body = GenerateResponse, content_type = "application/json"),
)
)]
#[instrument(skip(state, request))]
pub async fn generate(
State(state): State<Arc<GenerateState>>,
Json(request): Json<GenerateRequest>,
) -> impl IntoResponse {
let options = build_options(&request);
Json(GenerateResponse {
text: state.engine.generate(&request.prompt, options).await,
})
}
#[utoipa::path(
post,
path = "/v1/generate_stream",
request_body = GenerateRequest,
operation_id = "generate_stream",
tag = "v1",
responses(
(status = 200, description = "Success", body = GenerateResponse, content_type = "application/jsonstream"),
)
)]
#[instrument(skip(state, request))]
pub async fn generate_stream(
State(state): State<Arc<GenerateState>>,
Json(request): Json<GenerateRequest>,
) -> impl IntoResponse {
let options = build_options(&request);
let s = stream! {
for await text in state.engine.generate_stream(&request.prompt, options).await {
yield GenerateResponse { text }
}
};
StreamBodyAs::json_nl(s)
}
fn build_options(_request: &GenerateRequest) -> TextGenerationOptions {
TextGenerationOptionsBuilder::default()
.max_input_length(2048)
.max_decoding_length(usize::MAX)
.sampling_temperature(0.1)
.build()
.unwrap()
}

View File

@ -1,5 +1,7 @@
mod completions;
mod engine;
mod events;
mod generate;
mod health;
use std::{
@ -19,7 +21,7 @@ use tracing::{info, warn};
use utoipa::{openapi::ServerBuilder, OpenApi};
use utoipa_swagger_ui::SwaggerUi;
use self::health::HealthState;
use self::{engine::create_engine, health::HealthState};
use crate::fatal;
#[derive(OpenApi)]
@ -39,13 +41,15 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
servers(
(url = "https://playground.app.tabbyml.com", description = "Playground server"),
),
paths(events::log_event, completions::completion, health::health),
paths(events::log_event, completions::completion, generate::generate, generate::generate_stream, health::health),
components(schemas(
events::LogEventRequest,
completions::CompletionRequest,
completions::CompletionResponse,
completions::Segments,
completions::Choice,
generate::GenerateRequest,
generate::GenerateResponse,
health::HealthState,
health::Version,
))
@ -171,6 +175,8 @@ pub async fn main(config: &Config, args: &ServeArgs) {
}
fn api_router(args: &ServeArgs, config: &Config) -> Router {
let (engine, prompt_template) = create_engine(args);
let engine = Arc::new(engine);
Router::new()
.route("/events", routing::post(events::log_event))
.route(
@ -179,8 +185,19 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
)
.route(
"/completions",
routing::post(completions::completion)
.with_state(Arc::new(completions::CompletionState::new(args, config))),
routing::post(completions::completion).with_state(Arc::new(
completions::CompletionState::new(engine.clone(), prompt_template, config),
)),
)
.route(
"/generate",
routing::post(generate::generate)
.with_state(Arc::new(generate::GenerateState::new(engine.clone()))),
)
.route(
"/generate_stream",
routing::post(generate::generate_stream)
.with_state(Arc::new(generate::GenerateState::new(engine.clone()))),
)
.layer(CorsLayer::permissive())
.layer(opentelemetry_tracing_layer())