refactor: clean up chat API to make it message-oriented (#497)

* refactor: refactor into /chat/completions api

* Revert "feat: support request level stop words (#492)"

This reverts commit 0d6840e372.

* feat: adjust interface

* switch interface in tabby-playground

* move to chat/prompt, add unit test

* update interface
release-0.2
Meng Zhang 2023-10-02 08:39:15 -07:00 committed by GitHub
parent dfdd0373a6
commit f05dd3a2f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 347 additions and 203 deletions

24
Cargo.lock generated
View File

@ -1779,6 +1779,12 @@ dependencies = [
"libc",
]
[[package]]
name = "memo-map"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "374c335b2df19e62d4cb323103473cbc6510980253119180de862d89184f6a83"
[[package]]
name = "memoffset"
version = "0.8.0"
@ -1804,6 +1810,17 @@ dependencies = [
"unicase",
]
[[package]]
name = "minijinja"
version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80084fa3099f58b7afab51e5f92e24c2c2c68dcad26e96ad104bd6011570461d"
dependencies = [
"memo-map",
"self_cell",
"serde",
]
[[package]]
name = "minimal-lexical"
version = "0.2.1"
@ -2790,6 +2807,12 @@ dependencies = [
"libc",
]
[[package]]
name = "self_cell"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c309e515543e67811222dbc9e3dd7e1056279b782e1dacffe4242b718734fb6"
[[package]]
name = "serde"
version = "1.0.171"
@ -3078,6 +3101,7 @@ dependencies = [
"lazy_static",
"llama-cpp-bindings",
"mime_guess",
"minijinja",
"nvml-wrapper",
"opentelemetry",
"opentelemetry-otlp",

View File

@ -39,9 +39,6 @@ export function Chat({ id, initialMessages, className }: ChatProps) {
}
}
})
if (messages.length > 2) {
setMessages(messages.slice(messages.length - 2, messages.length))
}
return (
<>
<div className={cn('pb-[200px] pt-4 md:pt-10', className)}>

View File

@ -1,5 +1,6 @@
import { type Message } from 'ai/react'
import { CohereStream, StreamingTextResponse } from 'ai'
import { StreamingTextResponse } from 'ai'
import { TabbyStream } from '@/lib/tabby-stream'
import { useEffect } from 'react'
const serverUrl =
@ -15,25 +16,17 @@ export function usePatchFetch() {
}
const { messages } = JSON.parse(options!.body as string)
const res = await fetch(`${serverUrl}/v1beta/generate_stream`, {
const res = await fetch(`${serverUrl}/v1beta/chat/completions`, {
...options,
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
prompt: messagesToPrompt(messages)
})
})
const stream = CohereStream(res, undefined)
const stream = TabbyStream(res, undefined)
return new StreamingTextResponse(stream)
}
}, [])
}
function messagesToPrompt(messages: Message[]) {
const instruction = messages[messages.length - 1].content
const prompt = `Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n${instruction}\n\n### Response:`
return prompt
}

View File

@ -0,0 +1,71 @@
import {
type AIStreamCallbacksAndOptions,
createCallbacksTransformer,
createStreamDataTransformer
} from 'ai';
const utf8Decoder = new TextDecoder('utf-8');
/**
 * Parse each newline-delimited JSON line and forward its `content`
 * field to the stream controller.
 *
 * Blank lines are skipped: a chunk boundary like `"\r\n\r\n"` or a
 * trailing newline would otherwise feed an empty string to
 * `JSON.parse`, which throws a SyntaxError and kills the stream.
 */
async function processLines(
  lines: string[],
  controller: ReadableStreamDefaultController<string>,
) {
  for (const line of lines) {
    if (!line.trim()) {
      continue;
    }
    const { content } = JSON.parse(line);
    controller.enqueue(content);
  }
}
/**
 * Drain the response body reader, splitting the byte stream into
 * newline-delimited lines and handing complete lines to
 * `processLines`. The final partial line (if any) is flushed after the
 * reader reports completion, then the output controller is closed.
 */
async function readAndProcessLines(
  reader: ReadableStreamDefaultReader<Uint8Array>,
  controller: ReadableStreamDefaultController<string>,
) {
  // Carries the trailing partial line between chunks.
  let pending = '';
  for (;;) {
    const { value, done } = await reader.read();
    if (done) {
      break;
    }
    // `stream: true` keeps multi-byte UTF-8 sequences split across
    // chunks from being mangled.
    pending += utf8Decoder.decode(value, { stream: true });
    const pieces = pending.split(/\r\n|\n|\r/g);
    // The last piece may be incomplete; keep it for the next chunk.
    pending = pieces.pop() ?? '';
    await processLines(pieces, controller);
  }
  if (pending) {
    await processLines([pending], controller);
  }
  controller.close();
}
/**
 * Turn an HTTP response carrying newline-delimited JSON into a
 * ReadableStream of content strings. A missing body yields an
 * immediately-closed stream.
 */
function createParser(res: Response) {
  const bodyReader = res.body?.getReader();
  return new ReadableStream<string>({
    async start(controller): Promise<void> {
      if (bodyReader === undefined) {
        controller.close();
        return;
      }
      await readAndProcessLines(bodyReader, controller);
    },
  });
}
/**
 * Adapt a Tabby chat-completions HTTP response into a stream consumable
 * by the `ai` SDK: parse newline-delimited JSON content, run the user
 * callbacks, then apply the SDK's stream-data framing.
 */
export function TabbyStream(
  reader: Response,
  callbacks?: AIStreamCallbacksAndOptions,
): ReadableStream {
  const contentStream = createParser(reader);
  const withCallbacks = contentStream.pipeThrough(
    createCallbacksTransformer(callbacks),
  );
  return withCallbacks.pipeThrough(
    createStreamDataTransformer(callbacks?.experimental_streamData),
  );
}

View File

@ -137,7 +137,7 @@ impl TextGeneration for CTranslate2Engine {
let decoding = self
.decoding_factory
.create(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), &options.stop_words, options.static_stop_words);
.create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words);
let (sender, mut receiver) = channel::<String>(8);
let context = InferenceContext::new(sender, decoding, cancel_for_inference);

View File

@ -58,11 +58,8 @@ impl FastChatEngine {
#[async_trait]
impl TextGeneration for FastChatEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let _stop_sequences: Vec<String> = options
.static_stop_words
.iter()
.map(|x| x.to_string())
.collect();
let _stop_sequences: Vec<String> =
options.stop_words.iter().map(|x| x.to_string()).collect();
let tokens: Vec<&str> = prompt.split("<MID>").collect();
let request = Request {

View File

@ -67,7 +67,7 @@ impl VertexAIEngine {
impl TextGeneration for VertexAIEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let stop_sequences: Vec<String> = options
.static_stop_words
.stop_words
.iter()
.map(|x| x.to_string())
// vertex supports at most 5 stop sequence.

View File

@ -10,7 +10,8 @@ namespace llama {
TextInferenceEngine::~TextInferenceEngine() {}
namespace {
static size_t N_BATCH = 512;
static size_t N_BATCH = 512; // # per batch inference.
static size_t N_CTX = 4096; // # max kv history.
template<class T>
using owned = std::unique_ptr<T, std::function<void(T*)>>;
@ -59,7 +60,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
return std::distance(logits, std::max_element(logits, logits + n_vocab));
}
bool eval(llama_token* data, size_t size, bool reset) {
void eval(llama_token* data, size_t size, bool reset) {
if (reset) {
n_past_ = 0;
}
@ -76,12 +77,10 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
auto* ctx = ctx_.get();
llama_kv_cache_tokens_rm(ctx, n_past_, -1);
if (llama_decode(ctx, batch_)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
throw std::runtime_error("Failed to eval");
}
n_past_ += size;
return true;
}
size_t n_past_;
@ -127,7 +126,7 @@ std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
}
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = 2048;
ctx_params.n_ctx = N_CTX;
ctx_params.n_batch = N_BATCH;
llama_context* ctx = llama_new_context_with_model(model, ctx_params);

View File

@ -18,7 +18,7 @@ mod ffi {
fn create_engine(model_path: &str) -> UniquePtr<TextInferenceEngine>;
fn start(self: Pin<&mut TextInferenceEngine>, input_token_ids: &[u32]);
fn step(self: Pin<&mut TextInferenceEngine>) -> u32;
fn step(self: Pin<&mut TextInferenceEngine>) -> Result<u32>;
fn end(self: Pin<&mut TextInferenceEngine>);
fn eos_token(&self) -> u32;
@ -75,10 +75,12 @@ impl TextGeneration for LlamaEngine {
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
engine.as_mut().start(input_token_ids);
let mut decoding = self.decoding_factory.create(self.tokenizer.clone(), input_token_ids, &options.stop_words, options.static_stop_words);
let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.stop_words);
let mut n_remains = options.max_decoding_length ;
while n_remains > 0 {
let next_token_id = engine.as_mut().step();
let Ok(next_token_id) = engine.as_mut().step() else {
panic!("Failed to eval");
};
if next_token_id == eos_token {
break;
}

View File

@ -24,35 +24,16 @@ impl Default for DecodingFactory {
}
impl DecodingFactory {
pub fn create(
pub fn create_incremental_decoding(
&self,
tokenizer: Arc<Tokenizer>,
input_token_ids: &[u32],
stop_words: &Vec<String>,
static_stop_words: &'static Vec<&'static str>,
stop_words: &'static Vec<&'static str>,
) -> IncrementalDecoding {
IncrementalDecoding::new(
tokenizer,
vec![
self.get_static_re(static_stop_words),
self.get_re(stop_words),
]
.into_iter()
.flatten()
.collect(),
input_token_ids,
)
IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
}
fn get_re(&self, stop_words: &Vec<String>) -> Option<Regex> {
if !stop_words.is_empty() {
Some(create_stop_regex(stop_words))
} else {
None
}
}
fn get_static_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
if stop_words.is_empty() {
None
} else {
@ -67,8 +48,8 @@ impl DecodingFactory {
}
}
fn create_stop_regex<T: AsRef<str>>(stop_words: &[T]) -> Regex {
let tokens: Vec<String> = stop_words.iter().map(|x| reverse(x.as_ref())).collect();
fn create_stop_regex(stop_words: &[&str]) -> Regex {
let tokens: Vec<String> = stop_words.iter().map(|x| reverse(*x)).collect();
// (?m) enables multi-line matching mode.
// \A means absolute begins of string.
@ -78,7 +59,7 @@ fn create_stop_regex<T: AsRef<str>>(stop_words: &[T]) -> Regex {
pub struct IncrementalDecoding {
tokenizer: Arc<Tokenizer>,
stop_re: Vec<Regex>,
stop_re: Option<Regex>,
token_ids: Vec<u32>,
prefix_offset: usize,
@ -88,7 +69,7 @@ pub struct IncrementalDecoding {
}
impl IncrementalDecoding {
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Vec<Regex>, input_token_ids: &[u32]) -> Self {
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>, input_token_ids: &[u32]) -> Self {
let text = tokenizer
.decode(input_token_ids, /* skip_special_token = */ true)
.expect("Cannot decode token from tokenizer.");
@ -129,7 +110,8 @@ impl IncrementalDecoding {
if !new_text.is_empty() {
self.reversed_text = reverse(new_text) + &self.reversed_text;
for re in &self.stop_re {
if let Some(re) = &self.stop_re {
if re.find(&self.reversed_text).is_some() {
return None;
}

View File

@ -16,10 +16,7 @@ pub struct TextGenerationOptions {
pub sampling_temperature: f32,
#[builder(default = "&EMPTY_STOP_WORDS")]
pub static_stop_words: &'static Vec<&'static str>,
#[builder(default = "vec![]")]
pub stop_words: Vec<String>,
pub stop_words: &'static Vec<&'static str>,
}
static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];

View File

@ -39,6 +39,7 @@ http-api-bindings = { path = "../http-api-bindings" }
futures = { workspace = true }
async-stream = { workspace = true }
axum-streams = { version = "0.9.1", features = ["json"] }
minijinja = { version = "1.0.8", features = ["loader"] }
[target.'cfg(all(target_os="macos", target_arch="aarch64"))'.dependencies]
llama-cpp-bindings = { path = "../llama-cpp-bindings" }

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1,13 +1,13 @@
1:HL["/playground/_next/static/media/86fdec36ddd9097e-s.p.woff2","font",{"crossOrigin":"","type":"font/woff2"}]
2:HL["/playground/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2","font",{"crossOrigin":"","type":"font/woff2"}]
3:HL["/playground/_next/static/css/d091dc2da2a795e4.css","style"]
0:["f6rsO7djEUh4Fn3OO-Bie",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],"$L4",[[["$","link","0",{"rel":"stylesheet","href":"/playground/_next/static/css/d091dc2da2a795e4.css","precedence":"next"}]],"$L5"]]]]
0:["9a4m76mRTGOnagTYXSPKd",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],"$L4",[[["$","link","0",{"rel":"stylesheet","href":"/playground/_next/static/css/d091dc2da2a795e4.css","precedence":"next"}]],"$L5"]]]]
6:I{"id":5925,"chunks":["346:static/chunks/346-c4227fa5fd95e485.js","524:static/chunks/524-6309ecb76a77fdcf.js","185:static/chunks/app/layout-38d79c8bb16c51be.js"],"name":"Toaster","async":false}
7:I{"id":78495,"chunks":["346:static/chunks/346-c4227fa5fd95e485.js","524:static/chunks/524-6309ecb76a77fdcf.js","185:static/chunks/app/layout-38d79c8bb16c51be.js"],"name":"Providers","async":false}
8:I{"id":78963,"chunks":["346:static/chunks/346-c4227fa5fd95e485.js","524:static/chunks/524-6309ecb76a77fdcf.js","185:static/chunks/app/layout-38d79c8bb16c51be.js"],"name":"Header","async":false}
9:I{"id":81443,"chunks":["272:static/chunks/webpack-e23fff8c5b5084ca.js","971:static/chunks/fd9d1056-5dfc77aa37d8c76f.js","864:static/chunks/864-1669531662d5540a.js"],"name":"","async":false}
a:I{"id":18639,"chunks":["272:static/chunks/webpack-e23fff8c5b5084ca.js","971:static/chunks/fd9d1056-5dfc77aa37d8c76f.js","864:static/chunks/864-1669531662d5540a.js"],"name":"","async":false}
c:I{"id":64074,"chunks":["346:static/chunks/346-c4227fa5fd95e485.js","978:static/chunks/978-ab68c4a2390585a1.js","524:static/chunks/524-6309ecb76a77fdcf.js","931:static/chunks/app/page-757d8cb1ec33d4cb.js"],"name":"Chat","async":false}
c:I{"id":10413,"chunks":["346:static/chunks/346-c4227fa5fd95e485.js","978:static/chunks/978-342eae78521d80e5.js","524:static/chunks/524-6309ecb76a77fdcf.js","931:static/chunks/app/page-2ebc2d344df80bd2.js"],"name":"Chat","async":false}
5:[["$","meta","0",{"charSet":"utf-8"}],["$","title","1",{"children":"Tabby Playground"}],["$","meta","2",{"name":"description","content":"Tabby, an opensource, self-hosted AI coding assistant."}],["$","meta","3",{"name":"theme-color","media":"(prefers-color-scheme: light)","content":"white"}],["$","meta","4",{"name":"theme-color","media":"(prefers-color-scheme: dark)","content":"black"}],["$","meta","5",{"name":"viewport","content":"width=device-width, initial-scale=1"}],["$","meta","6",{"name":"next-size-adjust"}]]
4:[null,["$","html",null,{"lang":"en","suppressHydrationWarning":true,"children":[["$","head",null,{}],["$","body",null,{"className":"font-sans antialiased __variable_4e6684 __variable_3d950d","children":[["$","$L6",null,{}],["$","$L7",null,{"attribute":"class","defaultTheme":"system","enableSystem":true,"children":[["$","div",null,{"className":"flex flex-col min-h-screen","children":[["$","$L8",null,{}],["$","main",null,{"className":"flex flex-col flex-1 bg-muted/50","children":["$","$L9",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","template":["$","$La",null,{}],"templateStyles":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"childProp":{"current":["$Lb",["$","$Lc",null,{"id":"JBeXiJC"}],null],"segment":"__PAGE__"},"styles":[]}]}]]}],null]}]]}]]}],null]
4:[null,["$","html",null,{"lang":"en","suppressHydrationWarning":true,"children":[["$","head",null,{}],["$","body",null,{"className":"font-sans antialiased __variable_4e6684 __variable_3d950d","children":[["$","$L6",null,{}],["$","$L7",null,{"attribute":"class","defaultTheme":"system","enableSystem":true,"children":[["$","div",null,{"className":"flex flex-col min-h-screen","children":[["$","$L8",null,{}],["$","main",null,{"className":"flex flex-col flex-1 bg-muted/50","children":["$","$L9",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","template":["$","$La",null,{}],"templateStyles":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"childProp":{"current":["$Lb",["$","$Lc",null,{"id":"Z43ogQe"}],null],"segment":"__PAGE__"},"styles":[]}]}]]}],null]}]]}]]}],null]
b:null

View File

@ -0,0 +1,96 @@
mod prompt;
use std::sync::Arc;
use async_stream::stream;
use axum::{
extract::State,
response::{IntoResponse, Response},
Json,
};
use axum_streams::StreamBodyAs;
use prompt::ChatPromptBuilder;
use serde::{Deserialize, Serialize};
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
use tracing::instrument;
use utoipa::ToSchema;
/// Shared state for the `/v1beta/chat/completions` endpoint: the text
/// generation engine plus the prompt builder that renders a chat
/// history into a single model prompt.
pub struct ChatState {
    engine: Arc<Box<dyn TextGeneration>>,
    prompt_builder: ChatPromptBuilder,
}

impl ChatState {
    /// Creates the endpoint state.
    ///
    /// `prompt_template` is compiled eagerly by `ChatPromptBuilder::new`,
    /// which panics on an invalid template, so failures surface at
    /// startup rather than per request.
    pub fn new(engine: Arc<Box<dyn TextGeneration>>, prompt_template: String) -> Self {
        Self {
            engine,
            prompt_builder: ChatPromptBuilder::new(prompt_template),
        }
    }
}
/// Request body for `/v1beta/chat/completions`: the conversation
/// history, ordered oldest to newest, to be rendered into one prompt.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({
    "messages": [
        Message { role: "user".to_owned(), content: "What is tail recursion?".to_owned()},
        Message { role: "assistant".to_owned(), content: "It's a kind of optimization in compiler?".to_owned()},
        Message { role: "user".to_owned(), content: "Could you share more details?".to_owned()},
    ]
}))]
pub struct ChatCompletionRequest {
    // Role ordering rules (e.g. alternating user/assistant) are enforced
    // by the configured prompt template, not validated here.
    messages: Vec<Message>,
}
/// A single chat turn.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct Message {
    // Free-form role string, e.g. "user" or "assistant"; accepted values
    // depend on the prompt template in use.
    role: String,
    // The message text.
    content: String,
}
/// One streamed piece of the generated reply; the endpoint emits these
/// as newline-delimited JSON (one object per line).
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct ChatCompletionChunk {
    content: String,
}
/// Handles `POST /v1beta/chat/completions`.
///
/// Renders the conversation into a model prompt via the configured
/// prompt template, then streams generated chunks back as
/// newline-delimited JSON (`ChatCompletionChunk` per line).
#[utoipa::path(
    post,
    path = "/v1beta/chat/completions",
    request_body = ChatCompletionRequest,
    operation_id = "chat_completions",
    tag = "v1beta",
    responses(
        (status = 200, description = "Success", body = ChatCompletionChunk, content_type = "application/jsonstream"),
        // Fixed grammar in the user-facing OpenAPI description ("will returns").
        (status = 405, description = "When the chat model is not specified, the endpoint returns 405 Method Not Allowed"),
    )
)]
#[instrument(skip(state, request))]
pub async fn completions(
    State(state): State<Arc<ChatState>>,
    Json(request): Json<ChatCompletionRequest>,
) -> Response {
    let (prompt, options) = parse_request(&state, request);
    let s = stream! {
        for await content in state.engine.generate_stream(&prompt, options).await {
            yield ChatCompletionChunk { content }
        }
    };

    StreamBodyAs::json_nl(s).into_response()
}
/// Converts an incoming chat request into the `(prompt, options)` pair
/// consumed by the generation engine.
fn parse_request(
    state: &Arc<ChatState>,
    request: ChatCompletionRequest,
) -> (String, TextGenerationOptions) {
    // Fixed generation budget for chat completions; all remaining
    // options use the builder defaults, so build() cannot fail.
    let options = TextGenerationOptionsBuilder::default()
        .max_input_length(2048)
        .max_decoding_length(1920)
        .sampling_temperature(0.1)
        .build()
        .unwrap();

    let prompt = state.prompt_builder.build(&request.messages);
    (prompt, options)
}

View File

@ -0,0 +1,65 @@
use minijinja::{context, Environment};
use super::Message;
/// Renders chat `Message` histories into a single model prompt using a
/// minijinja template supplied at construction time.
pub struct ChatPromptBuilder {
    env: Environment<'static>,
}

impl ChatPromptBuilder {
    /// Compiles `prompt_template` eagerly and registers the
    /// `raise_exception` helper that chat templates use to reject
    /// unsupported conversations.
    ///
    /// # Panics
    /// Panics if the template fails to compile.
    pub fn new(prompt_template: String) -> Self {
        let mut env = Environment::new();
        // Templates signal invalid input (e.g. unsupported roles) by
        // calling raise_exception; surface that as a panic.
        env.add_function("raise_exception", |e: String| panic!("{}", e));
        env.add_template_owned("prompt", prompt_template)
            .expect("Failed to compile template");

        Self { env }
    }

    /// Renders `messages` through the compiled template.
    ///
    /// # Panics
    /// Panics if rendering fails, including when the template invokes
    /// `raise_exception`.
    pub fn build(&self, messages: &[Message]) -> String {
        self.env
            // The template was registered under this name in new(), so
            // the lookup cannot fail.
            .get_template("prompt")
            .unwrap()
            .render(context!(
                messages => messages
            ))
            .expect("Failed to evaluate")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Mistral/Llama-style chat template: user turns wrapped in
    // [INST] ... [/INST], assistant turns emitted raw with a trailing
    // "</s> "; raise_exception fires on non-alternating or unknown roles.
    static PROMPT_TEMPLATE : &str = "<s>{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + '</s> ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}";

    // A well-formed user/assistant/user history renders to the expected
    // concatenated prompt.
    #[test]
    fn test_it_works() {
        let builder = ChatPromptBuilder::new(PROMPT_TEMPLATE.to_owned());

        let messages = vec![
            Message {
                role: "user".to_owned(),
                content: "What is tail recursion?".to_owned(),
            },
            Message {
                role: "assistant".to_owned(),
                content: "It's a kind of optimization in compiler?".to_owned(),
            },
            Message {
                role: "user".to_owned(),
                content: "Could you share more details?".to_owned(),
            },
        ];
        assert_eq!(builder.build(&messages), "<s>[INST] What is tail recursion? [/INST]It's a kind of optimization in compiler?</s> [INST] Could you share more details? [/INST]")
    }

    // Unsupported roles make the template call raise_exception, which
    // panics via the helper registered in ChatPromptBuilder::new.
    #[test]
    #[should_panic]
    fn test_it_panic() {
        let builder = ChatPromptBuilder::new(PROMPT_TEMPLATE.to_owned());

        let messages = vec![Message {
            role: "system".to_owned(),
            content: "system".to_owned(),
        }];
        builder.build(&messages);
    }
}

View File

@ -71,7 +71,7 @@ pub struct CompletionResponse {
)
)]
#[instrument(skip(state, request))]
pub async fn completion(
pub async fn completions(
State(state): State<Arc<CompletionState>>,
Json(request): Json<CompletionRequest>,
) -> Result<Json<CompletionResponse>, StatusCode> {
@ -80,7 +80,7 @@ pub async fn completion(
.max_input_length(1024 + 512)
.max_decoding_length(128)
.sampling_temperature(0.1)
.static_stop_words(get_stop_words(&language))
.stop_words(get_stop_words(&language))
.build()
.unwrap();

View File

@ -1,94 +0,0 @@
use std::sync::Arc;
use async_stream::stream;
use axum::{extract::State, response::IntoResponse, Json};
use axum_streams::StreamBodyAs;
use serde::{Deserialize, Serialize};
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
use tracing::instrument;
use utoipa::ToSchema;
/// Shared state for the `/v1beta/generate` endpoints: just the text
/// generation engine.
pub struct GenerateState {
    engine: Arc<Box<dyn TextGeneration>>,
}

impl GenerateState {
    /// Wraps a generation engine into endpoint state.
    pub fn new(engine: Arc<Box<dyn TextGeneration>>) -> Self {
        Self { engine }
    }
}
/// Request body for the `/v1beta/generate` endpoints.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({
    // Fixed a leaked shell-quoting artifact ('\'' rendered the example
    // prompt as "Dijkstra'''s" in the generated OpenAPI docs).
    "prompt": "# Dijkstra's shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n",
}))]
pub struct GenerateRequest {
    // The text the model should continue from.
    prompt: String,
    // Optional request-level stop words, forwarded into
    // `TextGenerationOptions` when present.
    stop_words: Option<Vec<String>>,
}

/// Response body: the generated continuation text.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct GenerateResponse {
    text: String,
}
/// Handles `POST /v1beta/generate`: runs generation to completion and
/// returns the whole result as a single JSON response.
#[utoipa::path(
    post,
    path = "/v1beta/generate",
    request_body = GenerateRequest,
    operation_id = "generate",
    tag = "v1beta",
    responses(
        (status = 200, description = "Success", body = GenerateResponse, content_type = "application/json"),
    )
)]
#[instrument(skip(state, request))]
pub async fn generate(
    State(state): State<Arc<GenerateState>>,
    Json(request): Json<GenerateRequest>,
) -> impl IntoResponse {
    let (prompt, options) = parse_request(request);
    Json(GenerateResponse {
        text: state.engine.generate(&prompt, options).await,
    })
}
/// Handles `POST /v1beta/generate_stream`: same inputs as `generate`,
/// but streams partial results back as newline-delimited JSON
/// (`GenerateResponse` per line).
#[utoipa::path(
    post,
    path = "/v1beta/generate_stream",
    request_body = GenerateRequest,
    operation_id = "generate_stream",
    tag = "v1beta",
    responses(
        (status = 200, description = "Success", body = GenerateResponse, content_type = "application/jsonstream"),
    )
)]
#[instrument(skip(state, request))]
pub async fn generate_stream(
    State(state): State<Arc<GenerateState>>,
    Json(request): Json<GenerateRequest>,
) -> impl IntoResponse {
    let (prompt, options) = parse_request(request);
    let s = stream! {
        for await text in state.engine.generate_stream(&prompt, options).await {
            yield GenerateResponse { text }
        }
    };

    StreamBodyAs::json_nl(s)
}
/// Splits a `GenerateRequest` into the prompt text and the generation
/// options derived from it.
fn parse_request(request: GenerateRequest) -> (String, TextGenerationOptions) {
    let mut options = TextGenerationOptionsBuilder::default();
    // Fixed decoding budget for the generate endpoints.
    options
        .max_input_length(1024)
        .max_decoding_length(968)
        .sampling_temperature(0.1);

    // Stop words are only applied when the caller supplied them.
    if let Some(stop_words) = request.stop_words {
        options.stop_words(stop_words);
    }

    (request.prompt, options.build().unwrap())
}

View File

@ -11,7 +11,7 @@ use utoipa::ToSchema;
pub struct HealthState {
model: String,
#[serde(skip_serializing_if = "Option::is_none")]
instruct_model: Option<String>,
chat_model: Option<String>,
device: String,
compute_type: String,
arch: String,
@ -32,7 +32,7 @@ impl HealthState {
Self {
model: args.model.clone(),
instruct_model: args.instruct_model.clone(),
chat_model: args.chat_model.clone(),
device: args.device.to_string(),
compute_type: args.compute_type.to_string(),
arch: ARCH.to_string(),

View File

@ -1,7 +1,7 @@
mod chat;
mod completions;
mod engine;
mod events;
mod generate;
mod health;
mod playground;
@ -42,15 +42,16 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
servers(
(url = "https://playground.app.tabbyml.com", description = "Playground server"),
),
paths(events::log_event, completions::completion, generate::generate, generate::generate_stream, health::health),
paths(events::log_event, completions::completions, chat::completions, health::health),
components(schemas(
events::LogEventRequest,
completions::CompletionRequest,
completions::CompletionResponse,
completions::Segments,
completions::Choice,
generate::GenerateRequest,
generate::GenerateResponse,
chat::ChatCompletionRequest,
chat::Message,
chat::ChatCompletionChunk,
health::HealthState,
health::Version,
))
@ -105,14 +106,13 @@ pub enum ComputeType {
#[derive(Args)]
pub struct ServeArgs {
/// Model id for `/completion` API endpoint.
/// Model id for `/completions` API endpoint.
#[clap(long)]
model: String,
/// Model id for `/generate` and `/generate_stream` API endpoints.
/// If not set, `model` will be loaded for the purpose.
/// Model id for `/chat/completions` API endpoints.
#[clap(long)]
instruct_model: Option<String>,
chat_model: Option<String>,
#[clap(long, default_value_t = 8080)]
port: u16,
@ -149,8 +149,8 @@ pub async fn main(config: &Config, args: &ServeArgs) {
if args.device != Device::ExperimentalHttp {
download_model(&args.model, &args.device).await;
if let Some(instruct_model) = &args.instruct_model {
download_model(instruct_model, &args.device).await;
if let Some(chat_model) = &args.chat_model {
download_model(chat_model, &args.device).await;
}
} else {
warn!("HTTP device is unstable and does not comply with semver expectations.")
@ -160,12 +160,18 @@ pub async fn main(config: &Config, args: &ServeArgs) {
let doc = add_localhost_server(ApiDoc::openapi(), args.port);
let doc = add_proxy_server(doc, config.swagger.server_url.clone());
let app = api_router(args, config)
let app = Router::new()
.merge(api_router(args, config))
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc))
.route("/playground", routing::get(playground::handler))
.route("/playground/*path", routing::get(playground::handler))
.fallback(fallback());
let app = if args.chat_model.is_some() {
app.route("/playground", routing::get(playground::handler))
.route("/playground/*path", routing::get(playground::handler))
} else {
app
};
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
info!("Listening at {}", address);
@ -177,15 +183,26 @@ pub async fn main(config: &Config, args: &ServeArgs) {
}
fn api_router(args: &ServeArgs, config: &Config) -> Router {
let (engine, prompt_template) = create_engine(&args.model, args);
let engine = Arc::new(engine);
let instruct_engine = if let Some(instruct_model) = &args.instruct_model {
Arc::new(create_engine(instruct_model, args).0)
} else {
engine.clone()
let completion_state = {
let (engine, prompt_template) = create_engine(&args.model, args);
let engine = Arc::new(engine);
let state = completions::CompletionState::new(engine.clone(), prompt_template, config);
Arc::new(state)
};
Router::new()
let chat_state = if let Some(chat_model) = &args.chat_model {
let (engine, prompt_template) = create_engine(chat_model, args);
let Some(prompt_template) = prompt_template else {
panic!("Chat model requires specifying prompt template");
};
let engine = Arc::new(engine);
let state = chat::ChatState::new(engine, prompt_template);
Some(Arc::new(state))
} else {
None
};
let router = Router::new()
.route("/v1/events", routing::post(events::log_event))
.route(
"/v1/health",
@ -193,22 +210,19 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
)
.route(
"/v1/completions",
routing::post(completions::completion).with_state(Arc::new(
completions::CompletionState::new(engine.clone(), prompt_template, config),
)),
)
.route(
"/v1beta/generate",
routing::post(generate::generate).with_state(Arc::new(generate::GenerateState::new(
instruct_engine.clone(),
))),
)
.route(
"/v1beta/generate_stream",
routing::post(generate::generate_stream).with_state(Arc::new(
generate::GenerateState::new(instruct_engine.clone()),
)),
routing::post(completions::completions).with_state(completion_state),
);
let router = if let Some(chat_state) = chat_state {
router.route(
"/v1beta/chat/completions",
routing::post(chat::completions).with_state(chat_state),
)
} else {
router
};
router
.layer(CorsLayer::permissive())
.layer(opentelemetry_tracing_layer())
}