refactor: cleanup chat api make it message oriented (#497)

* refactor: refactor into /chat/completions api

* Revert "feat: support request level stop words (#492)"

This reverts commit 0d6840e372.

* feat: adjust interface

* switch interface in tabby-playground

* move to chat/prompt, add unit test

* update interface
release-0.2
Meng Zhang 2023-10-02 08:39:15 -07:00 committed by GitHub
parent dfdd0373a6
commit f05dd3a2f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 347 additions and 203 deletions

24
Cargo.lock generated
View File

@ -1779,6 +1779,12 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "memo-map"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "374c335b2df19e62d4cb323103473cbc6510980253119180de862d89184f6a83"
[[package]] [[package]]
name = "memoffset" name = "memoffset"
version = "0.8.0" version = "0.8.0"
@ -1804,6 +1810,17 @@ dependencies = [
"unicase", "unicase",
] ]
[[package]]
name = "minijinja"
version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80084fa3099f58b7afab51e5f92e24c2c2c68dcad26e96ad104bd6011570461d"
dependencies = [
"memo-map",
"self_cell",
"serde",
]
[[package]] [[package]]
name = "minimal-lexical" name = "minimal-lexical"
version = "0.2.1" version = "0.2.1"
@ -2790,6 +2807,12 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "self_cell"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c309e515543e67811222dbc9e3dd7e1056279b782e1dacffe4242b718734fb6"
[[package]] [[package]]
name = "serde" name = "serde"
version = "1.0.171" version = "1.0.171"
@ -3078,6 +3101,7 @@ dependencies = [
"lazy_static", "lazy_static",
"llama-cpp-bindings", "llama-cpp-bindings",
"mime_guess", "mime_guess",
"minijinja",
"nvml-wrapper", "nvml-wrapper",
"opentelemetry", "opentelemetry",
"opentelemetry-otlp", "opentelemetry-otlp",

View File

@ -39,9 +39,6 @@ export function Chat({ id, initialMessages, className }: ChatProps) {
} }
} }
}) })
if (messages.length > 2) {
setMessages(messages.slice(messages.length - 2, messages.length))
}
return ( return (
<> <>
<div className={cn('pb-[200px] pt-4 md:pt-10', className)}> <div className={cn('pb-[200px] pt-4 md:pt-10', className)}>

View File

@ -1,5 +1,6 @@
import { type Message } from 'ai/react' import { type Message } from 'ai/react'
import { CohereStream, StreamingTextResponse } from 'ai' import { StreamingTextResponse } from 'ai'
import { TabbyStream } from '@/lib/tabby-stream'
import { useEffect } from 'react' import { useEffect } from 'react'
const serverUrl = const serverUrl =
@ -15,25 +16,17 @@ export function usePatchFetch() {
} }
const { messages } = JSON.parse(options!.body as string) const { messages } = JSON.parse(options!.body as string)
const res = await fetch(`${serverUrl}/v1beta/generate_stream`, { const res = await fetch(`${serverUrl}/v1beta/chat/completions`, {
...options, ...options,
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json' 'Content-Type': 'application/json'
}, },
body: JSON.stringify({
prompt: messagesToPrompt(messages)
})
}) })
const stream = CohereStream(res, undefined) const stream = TabbyStream(res, undefined)
return new StreamingTextResponse(stream) return new StreamingTextResponse(stream)
} }
}, []) }, [])
} }
function messagesToPrompt(messages: Message[]) {
const instruction = messages[messages.length - 1].content
const prompt = `Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n${instruction}\n\n### Response:`
return prompt
}

View File

@ -0,0 +1,71 @@
import {
type AIStreamCallbacksAndOptions,
createCallbacksTransformer,
createStreamDataTransformer
} from 'ai';
const utf8Decoder = new TextDecoder('utf-8');
/**
 * Parses each line as a JSON record and forwards its `content` field to the
 * stream controller.
 *
 * Blank lines are skipped instead of being handed to `JSON.parse`: the caller
 * splits on `\r\n`, `\n` and a lone `\r`, so a CRLF pair that is cut in half by
 * a chunk boundary (or a stray empty line from the server) produces an empty
 * entry that would otherwise abort the whole stream with a `SyntaxError`.
 *
 * @param lines - Complete lines, one JSON object per non-blank line.
 * @param controller - Controller of the readable stream receiving the text.
 * @throws SyntaxError when a non-blank line is not valid JSON.
 */
async function processLines(
  lines: string[],
  controller: ReadableStreamDefaultController<string>,
): Promise<void> {
  for (const line of lines) {
    if (line.trim() === '') {
      continue;
    }
    const { content } = JSON.parse(line);
    controller.enqueue(content);
  }
}
/**
 * Drains `reader`, decodes the bytes as UTF-8, splits the text into lines and
 * passes every complete line to `processLines`. Closes `controller` once the
 * source is exhausted.
 *
 * @param reader - Reader over the raw response body.
 * @param controller - Controller of the readable stream receiving the text.
 */
async function readAndProcessLines(
  reader: ReadableStreamDefaultReader<Uint8Array>,
  controller: ReadableStreamDefaultController<string>,
): Promise<void> {
  // A decoder used with `{ stream: true }` keeps incomplete multi-byte
  // sequences buffered between calls. Each response therefore gets its own
  // instance so that leftover bytes can never bleed into another stream.
  const decoder = new TextDecoder('utf-8');
  let segment = '';

  while (true) {
    const { value: chunk, done } = await reader.read();
    if (done) {
      break;
    }

    segment += decoder.decode(chunk, { stream: true });
    const linesArray = segment.split(/\r\n|\n|\r/g);
    // The last element is either '' or a line that is not terminated yet;
    // carry it over to the next chunk.
    segment = linesArray.pop() || '';
    await processLines(linesArray, controller);
  }

  // Flush whatever the decoder is still holding back before handling the
  // final, unterminated line.
  segment += decoder.decode();
  if (segment) {
    await processLines([segment], controller);
  }

  controller.close();
}
/**
 * Wraps a fetch `Response` in a `ReadableStream<string>` that emits the text
 * produced by `readAndProcessLines`. A response without a body yields a stream
 * that is closed immediately.
 */
function createParser(res: Response) {
  const bodyReader = res.body?.getReader();

  const pump = async (
    controller: ReadableStreamDefaultController<string>,
  ): Promise<void> => {
    if (bodyReader) {
      await readAndProcessLines(bodyReader, controller);
    } else {
      controller.close();
    }
  };

  return new ReadableStream<string>({ start: pump });
}
/**
 * Turns a streaming chat response into a stream usable with
 * `StreamingTextResponse`.
 *
 * The response is parsed by `createParser`, then piped through the callbacks
 * transformer and the stream-data transformer from the `ai` package.
 *
 * @param reader - The fetch `Response` whose body is streamed.
 * @param callbacks - Optional stream callbacks and options.
 */
export function TabbyStream(
  reader: Response,
  callbacks?: AIStreamCallbacksAndOptions,
): ReadableStream {
  const parsed = createParser(reader);
  const withCallbacks = parsed.pipeThrough(
    createCallbacksTransformer(callbacks),
  );
  const streamDataTransformer = createStreamDataTransformer(
    callbacks?.experimental_streamData,
  );
  return withCallbacks.pipeThrough(streamDataTransformer);
}

View File

@ -137,7 +137,7 @@ impl TextGeneration for CTranslate2Engine {
let decoding = self let decoding = self
.decoding_factory .decoding_factory
.create(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), &options.stop_words, options.static_stop_words); .create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words);
let (sender, mut receiver) = channel::<String>(8); let (sender, mut receiver) = channel::<String>(8);
let context = InferenceContext::new(sender, decoding, cancel_for_inference); let context = InferenceContext::new(sender, decoding, cancel_for_inference);

View File

@ -58,11 +58,8 @@ impl FastChatEngine {
#[async_trait] #[async_trait]
impl TextGeneration for FastChatEngine { impl TextGeneration for FastChatEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let _stop_sequences: Vec<String> = options let _stop_sequences: Vec<String> =
.static_stop_words options.stop_words.iter().map(|x| x.to_string()).collect();
.iter()
.map(|x| x.to_string())
.collect();
let tokens: Vec<&str> = prompt.split("<MID>").collect(); let tokens: Vec<&str> = prompt.split("<MID>").collect();
let request = Request { let request = Request {

View File

@ -67,7 +67,7 @@ impl VertexAIEngine {
impl TextGeneration for VertexAIEngine { impl TextGeneration for VertexAIEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let stop_sequences: Vec<String> = options let stop_sequences: Vec<String> = options
.static_stop_words .stop_words
.iter() .iter()
.map(|x| x.to_string()) .map(|x| x.to_string())
// vertex supports at most 5 stop sequence. // vertex supports at most 5 stop sequence.

View File

@ -10,7 +10,8 @@ namespace llama {
TextInferenceEngine::~TextInferenceEngine() {} TextInferenceEngine::~TextInferenceEngine() {}
namespace { namespace {
static size_t N_BATCH = 512; static size_t N_BATCH = 512; // # per batch inference.
static size_t N_CTX = 4096; // # max kv history.
template<class T> template<class T>
using owned = std::unique_ptr<T, std::function<void(T*)>>; using owned = std::unique_ptr<T, std::function<void(T*)>>;
@ -59,7 +60,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
return std::distance(logits, std::max_element(logits, logits + n_vocab)); return std::distance(logits, std::max_element(logits, logits + n_vocab));
} }
bool eval(llama_token* data, size_t size, bool reset) { void eval(llama_token* data, size_t size, bool reset) {
if (reset) { if (reset) {
n_past_ = 0; n_past_ = 0;
} }
@ -76,12 +77,10 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
auto* ctx = ctx_.get(); auto* ctx = ctx_.get();
llama_kv_cache_tokens_rm(ctx, n_past_, -1); llama_kv_cache_tokens_rm(ctx, n_past_, -1);
if (llama_decode(ctx, batch_)) { if (llama_decode(ctx, batch_)) {
fprintf(stderr, "%s : failed to eval\n", __func__); throw std::runtime_error("Failed to eval");
return false;
} }
n_past_ += size; n_past_ += size;
return true;
} }
size_t n_past_; size_t n_past_;
@ -127,7 +126,7 @@ std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
} }
llama_context_params ctx_params = llama_context_default_params(); llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = 2048; ctx_params.n_ctx = N_CTX;
ctx_params.n_batch = N_BATCH; ctx_params.n_batch = N_BATCH;
llama_context* ctx = llama_new_context_with_model(model, ctx_params); llama_context* ctx = llama_new_context_with_model(model, ctx_params);

View File

@ -18,7 +18,7 @@ mod ffi {
fn create_engine(model_path: &str) -> UniquePtr<TextInferenceEngine>; fn create_engine(model_path: &str) -> UniquePtr<TextInferenceEngine>;
fn start(self: Pin<&mut TextInferenceEngine>, input_token_ids: &[u32]); fn start(self: Pin<&mut TextInferenceEngine>, input_token_ids: &[u32]);
fn step(self: Pin<&mut TextInferenceEngine>) -> u32; fn step(self: Pin<&mut TextInferenceEngine>) -> Result<u32>;
fn end(self: Pin<&mut TextInferenceEngine>); fn end(self: Pin<&mut TextInferenceEngine>);
fn eos_token(&self) -> u32; fn eos_token(&self) -> u32;
@ -75,10 +75,12 @@ impl TextGeneration for LlamaEngine {
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length); let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
engine.as_mut().start(input_token_ids); engine.as_mut().start(input_token_ids);
let mut decoding = self.decoding_factory.create(self.tokenizer.clone(), input_token_ids, &options.stop_words, options.static_stop_words); let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.stop_words);
let mut n_remains = options.max_decoding_length ; let mut n_remains = options.max_decoding_length ;
while n_remains > 0 { while n_remains > 0 {
let next_token_id = engine.as_mut().step(); let Ok(next_token_id) = engine.as_mut().step() else {
panic!("Failed to eval");
};
if next_token_id == eos_token { if next_token_id == eos_token {
break; break;
} }

View File

@ -24,35 +24,16 @@ impl Default for DecodingFactory {
} }
impl DecodingFactory { impl DecodingFactory {
pub fn create( pub fn create_incremental_decoding(
&self, &self,
tokenizer: Arc<Tokenizer>, tokenizer: Arc<Tokenizer>,
input_token_ids: &[u32], input_token_ids: &[u32],
stop_words: &Vec<String>, stop_words: &'static Vec<&'static str>,
static_stop_words: &'static Vec<&'static str>,
) -> IncrementalDecoding { ) -> IncrementalDecoding {
IncrementalDecoding::new( IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
tokenizer,
vec![
self.get_static_re(static_stop_words),
self.get_re(stop_words),
]
.into_iter()
.flatten()
.collect(),
input_token_ids,
)
} }
fn get_re(&self, stop_words: &Vec<String>) -> Option<Regex> { fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
if !stop_words.is_empty() {
Some(create_stop_regex(stop_words))
} else {
None
}
}
fn get_static_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
if stop_words.is_empty() { if stop_words.is_empty() {
None None
} else { } else {
@ -67,8 +48,8 @@ impl DecodingFactory {
} }
} }
fn create_stop_regex<T: AsRef<str>>(stop_words: &[T]) -> Regex { fn create_stop_regex(stop_words: &[&str]) -> Regex {
let tokens: Vec<String> = stop_words.iter().map(|x| reverse(x.as_ref())).collect(); let tokens: Vec<String> = stop_words.iter().map(|x| reverse(*x)).collect();
// (?m) enables multi-line matching mode. // (?m) enables multi-line matching mode.
// \A means absolute begins of string. // \A means absolute begins of string.
@ -78,7 +59,7 @@ fn create_stop_regex<T: AsRef<str>>(stop_words: &[T]) -> Regex {
pub struct IncrementalDecoding { pub struct IncrementalDecoding {
tokenizer: Arc<Tokenizer>, tokenizer: Arc<Tokenizer>,
stop_re: Vec<Regex>, stop_re: Option<Regex>,
token_ids: Vec<u32>, token_ids: Vec<u32>,
prefix_offset: usize, prefix_offset: usize,
@ -88,7 +69,7 @@ pub struct IncrementalDecoding {
} }
impl IncrementalDecoding { impl IncrementalDecoding {
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Vec<Regex>, input_token_ids: &[u32]) -> Self { pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>, input_token_ids: &[u32]) -> Self {
let text = tokenizer let text = tokenizer
.decode(input_token_ids, /* skip_special_token = */ true) .decode(input_token_ids, /* skip_special_token = */ true)
.expect("Cannot decode token from tokenizer."); .expect("Cannot decode token from tokenizer.");
@ -129,7 +110,8 @@ impl IncrementalDecoding {
if !new_text.is_empty() { if !new_text.is_empty() {
self.reversed_text = reverse(new_text) + &self.reversed_text; self.reversed_text = reverse(new_text) + &self.reversed_text;
for re in &self.stop_re {
if let Some(re) = &self.stop_re {
if re.find(&self.reversed_text).is_some() { if re.find(&self.reversed_text).is_some() {
return None; return None;
} }

View File

@ -16,10 +16,7 @@ pub struct TextGenerationOptions {
pub sampling_temperature: f32, pub sampling_temperature: f32,
#[builder(default = "&EMPTY_STOP_WORDS")] #[builder(default = "&EMPTY_STOP_WORDS")]
pub static_stop_words: &'static Vec<&'static str>, pub stop_words: &'static Vec<&'static str>,
#[builder(default = "vec![]")]
pub stop_words: Vec<String>,
} }
static EMPTY_STOP_WORDS: Vec<&'static str> = vec![]; static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];

View File

@ -39,6 +39,7 @@ http-api-bindings = { path = "../http-api-bindings" }
futures = { workspace = true } futures = { workspace = true }
async-stream = { workspace = true } async-stream = { workspace = true }
axum-streams = { version = "0.9.1", features = ["json"] } axum-streams = { version = "0.9.1", features = ["json"] }
minijinja = { version = "1.0.8", features = ["loader"] }
[target.'cfg(all(target_os="macos", target_arch="aarch64"))'.dependencies] [target.'cfg(all(target_os="macos", target_arch="aarch64"))'.dependencies]
llama-cpp-bindings = { path = "../llama-cpp-bindings" } llama-cpp-bindings = { path = "../llama-cpp-bindings" }

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1,13 +1,13 @@
1:HL["/playground/_next/static/media/86fdec36ddd9097e-s.p.woff2","font",{"crossOrigin":"","type":"font/woff2"}] 1:HL["/playground/_next/static/media/86fdec36ddd9097e-s.p.woff2","font",{"crossOrigin":"","type":"font/woff2"}]
2:HL["/playground/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2","font",{"crossOrigin":"","type":"font/woff2"}] 2:HL["/playground/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2","font",{"crossOrigin":"","type":"font/woff2"}]
3:HL["/playground/_next/static/css/d091dc2da2a795e4.css","style"] 3:HL["/playground/_next/static/css/d091dc2da2a795e4.css","style"]
0:["f6rsO7djEUh4Fn3OO-Bie",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],"$L4",[[["$","link","0",{"rel":"stylesheet","href":"/playground/_next/static/css/d091dc2da2a795e4.css","precedence":"next"}]],"$L5"]]]] 0:["9a4m76mRTGOnagTYXSPKd",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],"$L4",[[["$","link","0",{"rel":"stylesheet","href":"/playground/_next/static/css/d091dc2da2a795e4.css","precedence":"next"}]],"$L5"]]]]
6:I{"id":5925,"chunks":["346:static/chunks/346-c4227fa5fd95e485.js","524:static/chunks/524-6309ecb76a77fdcf.js","185:static/chunks/app/layout-38d79c8bb16c51be.js"],"name":"Toaster","async":false} 6:I{"id":5925,"chunks":["346:static/chunks/346-c4227fa5fd95e485.js","524:static/chunks/524-6309ecb76a77fdcf.js","185:static/chunks/app/layout-38d79c8bb16c51be.js"],"name":"Toaster","async":false}
7:I{"id":78495,"chunks":["346:static/chunks/346-c4227fa5fd95e485.js","524:static/chunks/524-6309ecb76a77fdcf.js","185:static/chunks/app/layout-38d79c8bb16c51be.js"],"name":"Providers","async":false} 7:I{"id":78495,"chunks":["346:static/chunks/346-c4227fa5fd95e485.js","524:static/chunks/524-6309ecb76a77fdcf.js","185:static/chunks/app/layout-38d79c8bb16c51be.js"],"name":"Providers","async":false}
8:I{"id":78963,"chunks":["346:static/chunks/346-c4227fa5fd95e485.js","524:static/chunks/524-6309ecb76a77fdcf.js","185:static/chunks/app/layout-38d79c8bb16c51be.js"],"name":"Header","async":false} 8:I{"id":78963,"chunks":["346:static/chunks/346-c4227fa5fd95e485.js","524:static/chunks/524-6309ecb76a77fdcf.js","185:static/chunks/app/layout-38d79c8bb16c51be.js"],"name":"Header","async":false}
9:I{"id":81443,"chunks":["272:static/chunks/webpack-e23fff8c5b5084ca.js","971:static/chunks/fd9d1056-5dfc77aa37d8c76f.js","864:static/chunks/864-1669531662d5540a.js"],"name":"","async":false} 9:I{"id":81443,"chunks":["272:static/chunks/webpack-e23fff8c5b5084ca.js","971:static/chunks/fd9d1056-5dfc77aa37d8c76f.js","864:static/chunks/864-1669531662d5540a.js"],"name":"","async":false}
a:I{"id":18639,"chunks":["272:static/chunks/webpack-e23fff8c5b5084ca.js","971:static/chunks/fd9d1056-5dfc77aa37d8c76f.js","864:static/chunks/864-1669531662d5540a.js"],"name":"","async":false} a:I{"id":18639,"chunks":["272:static/chunks/webpack-e23fff8c5b5084ca.js","971:static/chunks/fd9d1056-5dfc77aa37d8c76f.js","864:static/chunks/864-1669531662d5540a.js"],"name":"","async":false}
c:I{"id":64074,"chunks":["346:static/chunks/346-c4227fa5fd95e485.js","978:static/chunks/978-ab68c4a2390585a1.js","524:static/chunks/524-6309ecb76a77fdcf.js","931:static/chunks/app/page-757d8cb1ec33d4cb.js"],"name":"Chat","async":false} c:I{"id":10413,"chunks":["346:static/chunks/346-c4227fa5fd95e485.js","978:static/chunks/978-342eae78521d80e5.js","524:static/chunks/524-6309ecb76a77fdcf.js","931:static/chunks/app/page-2ebc2d344df80bd2.js"],"name":"Chat","async":false}
5:[["$","meta","0",{"charSet":"utf-8"}],["$","title","1",{"children":"Tabby Playground"}],["$","meta","2",{"name":"description","content":"Tabby, an opensource, self-hosted AI coding assistant."}],["$","meta","3",{"name":"theme-color","media":"(prefers-color-scheme: light)","content":"white"}],["$","meta","4",{"name":"theme-color","media":"(prefers-color-scheme: dark)","content":"black"}],["$","meta","5",{"name":"viewport","content":"width=device-width, initial-scale=1"}],["$","meta","6",{"name":"next-size-adjust"}]] 5:[["$","meta","0",{"charSet":"utf-8"}],["$","title","1",{"children":"Tabby Playground"}],["$","meta","2",{"name":"description","content":"Tabby, an opensource, self-hosted AI coding assistant."}],["$","meta","3",{"name":"theme-color","media":"(prefers-color-scheme: light)","content":"white"}],["$","meta","4",{"name":"theme-color","media":"(prefers-color-scheme: dark)","content":"black"}],["$","meta","5",{"name":"viewport","content":"width=device-width, initial-scale=1"}],["$","meta","6",{"name":"next-size-adjust"}]]
4:[null,["$","html",null,{"lang":"en","suppressHydrationWarning":true,"children":[["$","head",null,{}],["$","body",null,{"className":"font-sans antialiased __variable_4e6684 __variable_3d950d","children":[["$","$L6",null,{}],["$","$L7",null,{"attribute":"class","defaultTheme":"system","enableSystem":true,"children":[["$","div",null,{"className":"flex flex-col min-h-screen","children":[["$","$L8",null,{}],["$","main",null,{"className":"flex flex-col flex-1 bg-muted/50","children":["$","$L9",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","template":["$","$La",null,{}],"templateStyles":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"childProp":{"current":["$Lb",["$","$Lc",null,{"id":"JBeXiJC"}],null],"segment":"__PAGE__"},"styles":[]}]}]]}],null]}]]}]]}],null] 
4:[null,["$","html",null,{"lang":"en","suppressHydrationWarning":true,"children":[["$","head",null,{}],["$","body",null,{"className":"font-sans antialiased __variable_4e6684 __variable_3d950d","children":[["$","$L6",null,{}],["$","$L7",null,{"attribute":"class","defaultTheme":"system","enableSystem":true,"children":[["$","div",null,{"className":"flex flex-col min-h-screen","children":[["$","$L8",null,{}],["$","main",null,{"className":"flex flex-col flex-1 bg-muted/50","children":["$","$L9",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","template":["$","$La",null,{}],"templateStyles":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"childProp":{"current":["$Lb",["$","$Lc",null,{"id":"Z43ogQe"}],null],"segment":"__PAGE__"},"styles":[]}]}]]}],null]}]]}]]}],null]
b:null b:null

View File

@ -0,0 +1,96 @@
mod prompt;
use std::sync::Arc;
use async_stream::stream;
use axum::{
extract::State,
response::{IntoResponse, Response},
Json,
};
use axum_streams::StreamBodyAs;
use prompt::ChatPromptBuilder;
use serde::{Deserialize, Serialize};
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
use tracing::instrument;
use utoipa::ToSchema;
/// State shared with the chat completions handler.
pub struct ChatState {
// Engine whose `generate_stream` produces the reply text.
engine: Arc<Box<dyn TextGeneration>>,
// Turns the request's message list into the prompt string fed to the engine.
prompt_builder: ChatPromptBuilder,
}
impl ChatState {
pub fn new(engine: Arc<Box<dyn TextGeneration>>, prompt_template: String) -> Self {
Self {
engine,
prompt_builder: ChatPromptBuilder::new(prompt_template),
}
}
}
// Request body of the chat completions endpoint: the conversation so far,
// given as an ordered list of messages.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({
"messages": [
Message { role: "user".to_owned(), content: "What is tail recursion?".to_owned()},
Message { role: "assistant".to_owned(), content: "It's a kind of optimization in compiler?".to_owned()},
Message { role: "user".to_owned(), content: "Could you share more details?".to_owned()},
]
}))]
pub struct ChatCompletionRequest {
messages: Vec<Message>,
}
// One entry of a chat conversation.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct Message {
// Speaker of the message; the schema example uses "user" and "assistant".
role: String,
// Text of the message.
content: String,
}
// One streamed piece of the reply produced by the chat completions handler.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct ChatCompletionChunk {
// Text yielded by the engine for this chunk.
content: String,
}
// HTTP handler for `POST /v1beta/chat/completions`.
//
// Builds the prompt and generation options from the request, then streams the
// engine output back through `StreamBodyAs::json_nl`, one
// `ChatCompletionChunk` per piece of text yielded by the engine.
#[utoipa::path(
post,
path = "/v1beta/chat/completions",
request_body = ChatCompletionRequest,
operation_id = "chat_completions",
tag = "v1beta",
responses(
(status = 200, description = "Success", body = ChatCompletionChunk, content_type = "application/jsonstream"),
(status = 405, description = "When chat model is not specified, the endpoint will returns 405 Method Not Allowed"),
)
)]
#[instrument(skip(state, request))]
pub async fn completions(
State(state): State<Arc<ChatState>>,
Json(request): Json<ChatCompletionRequest>,
) -> Response {
let (prompt, options) = parse_request(&state, request);
// Wrap every piece of generated text in a chunk object as it arrives.
let s = stream! {
for await content in state.engine.generate_stream(&prompt, options).await {
yield ChatCompletionChunk { content }
}
};
StreamBodyAs::json_nl(s).into_response()
}
/// Converts a chat request into the prompt string and the generation options
/// passed to the engine.
///
/// The prompt is rendered from the request's messages by the state's prompt
/// builder; the options use fixed limits (2048 input tokens, 1920 decoding
/// tokens) and a sampling temperature of 0.1.
fn parse_request(
    state: &Arc<ChatState>,
    request: ChatCompletionRequest,
) -> (String, TextGenerationOptions) {
    let prompt = state.prompt_builder.build(&request.messages);

    let options = TextGenerationOptionsBuilder::default()
        .max_input_length(2048)
        .max_decoding_length(1920)
        .sampling_temperature(0.1)
        .build()
        .unwrap();

    (prompt, options)
}

View File

@ -0,0 +1,65 @@
use minijinja::{context, Environment};
use super::Message;
/// Renders a list of chat messages into a single prompt string using a
/// minijinja template.
pub struct ChatPromptBuilder {
// Template environment holding the compiled template registered as "prompt".
env: Environment<'static>,
}
impl ChatPromptBuilder {
pub fn new(prompt_template: String) -> Self {
let mut env = Environment::new();
env.add_function("raise_exception", |e: String| panic!("{}", e));
env.add_template_owned("prompt", prompt_template)
.expect("Failed to compile template");
Self { env }
}
pub fn build(&self, messages: &[Message]) -> String {
self.env
.get_template("prompt")
.unwrap()
.render(context!(
messages => messages
))
.expect("Failed to evaluate")
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Template that wraps user turns in `[INST] ... [/INST]`, terminates
    // assistant turns with `</s> `, and raises on any other role or on turns
    // that do not alternate user/assistant.
    static PROMPT_TEMPLATE : &str = "<s>{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + '</s> ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}";

    /// Shorthand for constructing a `Message` from string slices.
    fn message(role: &str, content: &str) -> Message {
        Message {
            role: role.to_owned(),
            content: content.to_owned(),
        }
    }

    /// A well-formed user/assistant/user conversation renders as expected.
    #[test]
    fn test_it_works() {
        let builder = ChatPromptBuilder::new(PROMPT_TEMPLATE.to_owned());
        let messages = vec![
            message("user", "What is tail recursion?"),
            message("assistant", "It's a kind of optimization in compiler?"),
            message("user", "Could you share more details?"),
        ];

        let rendered = builder.build(&messages);

        assert_eq!(rendered, "<s>[INST] What is tail recursion? [/INST]It's a kind of optimization in compiler?</s> [INST] Could you share more details? [/INST]")
    }

    /// A role the template does not support makes `build` panic.
    #[test]
    #[should_panic]
    fn test_it_panic() {
        let builder = ChatPromptBuilder::new(PROMPT_TEMPLATE.to_owned());
        let messages = vec![message("system", "system")];
        builder.build(&messages);
    }
}

View File

@ -71,7 +71,7 @@ pub struct CompletionResponse {
) )
)] )]
#[instrument(skip(state, request))] #[instrument(skip(state, request))]
pub async fn completion( pub async fn completions(
State(state): State<Arc<CompletionState>>, State(state): State<Arc<CompletionState>>,
Json(request): Json<CompletionRequest>, Json(request): Json<CompletionRequest>,
) -> Result<Json<CompletionResponse>, StatusCode> { ) -> Result<Json<CompletionResponse>, StatusCode> {
@ -80,7 +80,7 @@ pub async fn completion(
.max_input_length(1024 + 512) .max_input_length(1024 + 512)
.max_decoding_length(128) .max_decoding_length(128)
.sampling_temperature(0.1) .sampling_temperature(0.1)
.static_stop_words(get_stop_words(&language)) .stop_words(get_stop_words(&language))
.build() .build()
.unwrap(); .unwrap();

View File

@ -1,94 +0,0 @@
use std::sync::Arc;
use async_stream::stream;
use axum::{extract::State, response::IntoResponse, Json};
use axum_streams::StreamBodyAs;
use serde::{Deserialize, Serialize};
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
use tracing::instrument;
use utoipa::ToSchema;
/// State shared with the `generate` and `generate_stream` handlers.
pub struct GenerateState {
// Engine that produces text for a prompt.
engine: Arc<Box<dyn TextGeneration>>,
}
impl GenerateState {
pub fn new(engine: Arc<Box<dyn TextGeneration>>) -> Self {
Self { engine }
}
}
// Request body shared by the generate endpoints.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({
"prompt": "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n",
}))]
pub struct GenerateRequest {
// Prompt text passed to the engine unchanged.
prompt: String,
// Optional request-level stop words; forwarded to the generation options
// only when present.
stop_words: Option<Vec<String>>,
}
// Response body of the generate endpoints; also used as the element type of
// the streamed response.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct GenerateResponse {
// Text produced by the engine.
text: String,
}
// HTTP handler for `POST /v1beta/generate`.
//
// Runs the engine to completion for the request's prompt and returns the whole
// generated text in a single JSON response.
#[utoipa::path(
    post,
    path = "/v1beta/generate",
    request_body = GenerateRequest,
    operation_id = "generate",
    tag = "v1beta",
    responses(
        (status = 200, description = "Success", body = GenerateResponse, content_type = "application/json"),
    )
)]
#[instrument(skip(state, request))]
pub async fn generate(
    State(state): State<Arc<GenerateState>>,
    Json(request): Json<GenerateRequest>,
) -> impl IntoResponse {
    let (prompt, options) = parse_request(request);
    let text = state.engine.generate(&prompt, options).await;
    Json(GenerateResponse { text })
}
// HTTP handler for `POST /v1beta/generate_stream`.
//
// Same input as `generate`, but the engine output is streamed back through
// `StreamBodyAs::json_nl`, one `GenerateResponse` per piece of text yielded by
// the engine.
#[utoipa::path(
post,
path = "/v1beta/generate_stream",
request_body = GenerateRequest,
operation_id = "generate_stream",
tag = "v1beta",
responses(
(status = 200, description = "Success", body = GenerateResponse, content_type = "application/jsonstream"),
)
)]
#[instrument(skip(state, request))]
pub async fn generate_stream(
State(state): State<Arc<GenerateState>>,
Json(request): Json<GenerateRequest>,
) -> impl IntoResponse {
let (prompt, options) = parse_request(request);
// Wrap every piece of generated text in a response object as it arrives.
let s = stream! {
for await text in state.engine.generate_stream(&prompt, options).await {
yield GenerateResponse { text }
}
};
StreamBodyAs::json_nl(s)
}
fn parse_request(request: GenerateRequest) -> (String, TextGenerationOptions) {
let mut builder = TextGenerationOptionsBuilder::default();
builder
.max_input_length(1024)
.max_decoding_length(968)
.sampling_temperature(0.1);
if let Some(stop_words) = request.stop_words {
builder.stop_words(stop_words);
};
(request.prompt, builder.build().unwrap())
}

View File

@ -11,7 +11,7 @@ use utoipa::ToSchema;
pub struct HealthState { pub struct HealthState {
model: String, model: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
instruct_model: Option<String>, chat_model: Option<String>,
device: String, device: String,
compute_type: String, compute_type: String,
arch: String, arch: String,
@ -32,7 +32,7 @@ impl HealthState {
Self { Self {
model: args.model.clone(), model: args.model.clone(),
instruct_model: args.instruct_model.clone(), chat_model: args.chat_model.clone(),
device: args.device.to_string(), device: args.device.to_string(),
compute_type: args.compute_type.to_string(), compute_type: args.compute_type.to_string(),
arch: ARCH.to_string(), arch: ARCH.to_string(),

View File

@ -1,7 +1,7 @@
mod chat;
mod completions; mod completions;
mod engine; mod engine;
mod events; mod events;
mod generate;
mod health; mod health;
mod playground; mod playground;
@ -42,15 +42,16 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
servers( servers(
(url = "https://playground.app.tabbyml.com", description = "Playground server"), (url = "https://playground.app.tabbyml.com", description = "Playground server"),
), ),
paths(events::log_event, completions::completion, generate::generate, generate::generate_stream, health::health), paths(events::log_event, completions::completions, chat::completions, health::health),
components(schemas( components(schemas(
events::LogEventRequest, events::LogEventRequest,
completions::CompletionRequest, completions::CompletionRequest,
completions::CompletionResponse, completions::CompletionResponse,
completions::Segments, completions::Segments,
completions::Choice, completions::Choice,
generate::GenerateRequest, chat::ChatCompletionRequest,
generate::GenerateResponse, chat::Message,
chat::ChatCompletionChunk,
health::HealthState, health::HealthState,
health::Version, health::Version,
)) ))
@ -105,14 +106,13 @@ pub enum ComputeType {
#[derive(Args)] #[derive(Args)]
pub struct ServeArgs { pub struct ServeArgs {
/// Model id for `/completion` API endpoint. /// Model id for `/completions` API endpoint.
#[clap(long)] #[clap(long)]
model: String, model: String,
/// Model id for `/generate` and `/generate_stream` API endpoints. /// Model id for `/chat/completions` API endpoints.
/// If not set, `model` will be loaded for the purpose.
#[clap(long)] #[clap(long)]
instruct_model: Option<String>, chat_model: Option<String>,
#[clap(long, default_value_t = 8080)] #[clap(long, default_value_t = 8080)]
port: u16, port: u16,
@ -149,8 +149,8 @@ pub async fn main(config: &Config, args: &ServeArgs) {
if args.device != Device::ExperimentalHttp { if args.device != Device::ExperimentalHttp {
download_model(&args.model, &args.device).await; download_model(&args.model, &args.device).await;
if let Some(instruct_model) = &args.instruct_model { if let Some(chat_model) = &args.chat_model {
download_model(instruct_model, &args.device).await; download_model(chat_model, &args.device).await;
} }
} else { } else {
warn!("HTTP device is unstable and does not comply with semver expectations.") warn!("HTTP device is unstable and does not comply with semver expectations.")
@ -160,12 +160,18 @@ pub async fn main(config: &Config, args: &ServeArgs) {
let doc = add_localhost_server(ApiDoc::openapi(), args.port); let doc = add_localhost_server(ApiDoc::openapi(), args.port);
let doc = add_proxy_server(doc, config.swagger.server_url.clone()); let doc = add_proxy_server(doc, config.swagger.server_url.clone());
let app = api_router(args, config) let app = Router::new()
.merge(api_router(args, config))
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc)) .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc))
.route("/playground", routing::get(playground::handler))
.route("/playground/*path", routing::get(playground::handler))
.fallback(fallback()); .fallback(fallback());
let app = if args.chat_model.is_some() {
app.route("/playground", routing::get(playground::handler))
.route("/playground/*path", routing::get(playground::handler))
} else {
app
};
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port)); let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
info!("Listening at {}", address); info!("Listening at {}", address);
@ -177,15 +183,26 @@ pub async fn main(config: &Config, args: &ServeArgs) {
} }
fn api_router(args: &ServeArgs, config: &Config) -> Router { fn api_router(args: &ServeArgs, config: &Config) -> Router {
let completion_state = {
let (engine, prompt_template) = create_engine(&args.model, args); let (engine, prompt_template) = create_engine(&args.model, args);
let engine = Arc::new(engine); let engine = Arc::new(engine);
let instruct_engine = if let Some(instruct_model) = &args.instruct_model { let state = completions::CompletionState::new(engine.clone(), prompt_template, config);
Arc::new(create_engine(instruct_model, args).0) Arc::new(state)
} else {
engine.clone()
}; };
Router::new() let chat_state = if let Some(chat_model) = &args.chat_model {
let (engine, prompt_template) = create_engine(chat_model, args);
let Some(prompt_template) = prompt_template else {
panic!("Chat model requires specifying prompt template");
};
let engine = Arc::new(engine);
let state = chat::ChatState::new(engine, prompt_template);
Some(Arc::new(state))
} else {
None
};
let router = Router::new()
.route("/v1/events", routing::post(events::log_event)) .route("/v1/events", routing::post(events::log_event))
.route( .route(
"/v1/health", "/v1/health",
@ -193,22 +210,19 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
) )
.route( .route(
"/v1/completions", "/v1/completions",
routing::post(completions::completion).with_state(Arc::new( routing::post(completions::completions).with_state(completion_state),
completions::CompletionState::new(engine.clone(), prompt_template, config), );
)),
) let router = if let Some(chat_state) = chat_state {
.route( router.route(
"/v1beta/generate", "/v1beta/chat/completions",
routing::post(generate::generate).with_state(Arc::new(generate::GenerateState::new( routing::post(chat::completions).with_state(chat_state),
instruct_engine.clone(),
))),
)
.route(
"/v1beta/generate_stream",
routing::post(generate::generate_stream).with_state(Arc::new(
generate::GenerateState::new(instruct_engine.clone()),
)),
) )
} else {
router
};
router
.layer(CorsLayer::permissive()) .layer(CorsLayer::permissive())
.layer(opentelemetry_tracing_layer()) .layer(opentelemetry_tracing_layer())
} }