refactor: cleanup chat api make it message oriented (#497)
* refactor: refactor into /chat/completions api
* Revert "feat: support request level stop words (#492)"
This reverts commit 0d6840e372.
* feat: adjust interface
* switch interface in tabby-playground
* move to chat/prompt, add unit test
* update interface
release-0.2
parent
dfdd0373a6
commit
f05dd3a2f6
|
|
@ -1779,6 +1779,12 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "memo-map"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "374c335b2df19e62d4cb323103473cbc6510980253119180de862d89184f6a83"
|
||||
|
||||
[[package]]
|
||||
name = "memoffset"
|
||||
version = "0.8.0"
|
||||
|
|
@ -1804,6 +1810,17 @@ dependencies = [
|
|||
"unicase",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "minijinja"
|
||||
version = "1.0.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "80084fa3099f58b7afab51e5f92e24c2c2c68dcad26e96ad104bd6011570461d"
|
||||
dependencies = [
|
||||
"memo-map",
|
||||
"self_cell",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "minimal-lexical"
|
||||
version = "0.2.1"
|
||||
|
|
@ -2790,6 +2807,12 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "self_cell"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4c309e515543e67811222dbc9e3dd7e1056279b782e1dacffe4242b718734fb6"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.171"
|
||||
|
|
@ -3078,6 +3101,7 @@ dependencies = [
|
|||
"lazy_static",
|
||||
"llama-cpp-bindings",
|
||||
"mime_guess",
|
||||
"minijinja",
|
||||
"nvml-wrapper",
|
||||
"opentelemetry",
|
||||
"opentelemetry-otlp",
|
||||
|
|
|
|||
|
|
@ -39,9 +39,6 @@ export function Chat({ id, initialMessages, className }: ChatProps) {
|
|||
}
|
||||
}
|
||||
})
|
||||
if (messages.length > 2) {
|
||||
setMessages(messages.slice(messages.length - 2, messages.length))
|
||||
}
|
||||
return (
|
||||
<>
|
||||
<div className={cn('pb-[200px] pt-4 md:pt-10', className)}>
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import { type Message } from 'ai/react'
|
||||
import { CohereStream, StreamingTextResponse } from 'ai'
|
||||
import { StreamingTextResponse } from 'ai'
|
||||
import { TabbyStream } from '@/lib/tabby-stream'
|
||||
import { useEffect } from 'react'
|
||||
|
||||
const serverUrl =
|
||||
|
|
@ -15,25 +16,17 @@ export function usePatchFetch() {
|
|||
}
|
||||
|
||||
const { messages } = JSON.parse(options!.body as string)
|
||||
const res = await fetch(`${serverUrl}/v1beta/generate_stream`, {
|
||||
const res = await fetch(`${serverUrl}/v1beta/chat/completions`, {
|
||||
...options,
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
body: JSON.stringify({
|
||||
prompt: messagesToPrompt(messages)
|
||||
})
|
||||
})
|
||||
|
||||
const stream = CohereStream(res, undefined)
|
||||
const stream = TabbyStream(res, undefined)
|
||||
return new StreamingTextResponse(stream)
|
||||
}
|
||||
}, [])
|
||||
}
|
||||
|
||||
function messagesToPrompt(messages: Message[]) {
|
||||
const instruction = messages[messages.length - 1].content
|
||||
const prompt = `Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n${instruction}\n\n### Response:`
|
||||
return prompt
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,71 @@
|
|||
import {
|
||||
type AIStreamCallbacksAndOptions,
|
||||
createCallbacksTransformer,
|
||||
createStreamDataTransformer
|
||||
} from 'ai';
|
||||
|
||||
const utf8Decoder = new TextDecoder('utf-8');
|
||||
|
||||
/**
 * Parses newline-delimited JSON records and forwards the `content` field of
 * each record to the stream controller.
 *
 * Blank (empty or whitespace-only) lines are skipped instead of being handed
 * to `JSON.parse`. A line splitter can legitimately produce them — for
 * example when a `\r\n` terminator is split across two network chunks, or
 * when the server emits an empty keep-alive line — and parsing `''` would
 * throw and error the whole stream.
 *
 * @param lines - Complete lines, without their line terminators.
 * @param controller - Controller that receives each record's `content`.
 * @throws SyntaxError when a non-blank line is not valid JSON.
 */
async function processLines(
  lines: string[],
  controller: ReadableStreamDefaultController<string>,
): Promise<void> {
  for (const line of lines) {
    if (line.trim() === '') {
      // Nothing to parse; see the note on blank lines above.
      continue;
    }

    const { content } = JSON.parse(line);
    controller.enqueue(content);
  }
}
|
||||
|
||||
/**
 * Drains `reader`, decodes the bytes as UTF-8, splits the text into lines and
 * hands every complete line to `processLines`. Closes `controller` once the
 * reader reports that it is done.
 *
 * A decoder is created per call: a decoder used with `{ stream: true }`
 * buffers incomplete multi-byte sequences between calls, so sharing a single
 * instance between concurrently running streams could splice bytes from one
 * response into another.
 *
 * @param reader - Byte reader obtained from the response body.
 * @param controller - Controller of the string stream being produced.
 */
async function readAndProcessLines(
  reader: ReadableStreamDefaultReader<Uint8Array>,
  controller: ReadableStreamDefaultController<string>,
): Promise<void> {
  const decoder = new TextDecoder('utf-8');
  // Text received so far that is not yet terminated by a line break.
  let segment = '';

  while (true) {
    const { value: chunk, done } = await reader.read();
    if (done) {
      break;
    }

    segment += decoder.decode(chunk, { stream: true });

    // Accept \r\n, \n and \r as terminators. The last element is whatever
    // follows the final terminator, i.e. a possibly incomplete line.
    const linesArray = segment.split(/\r\n|\n|\r/g);
    segment = linesArray.pop() ?? '';

    await processLines(linesArray, controller);
  }

  // Flush the decoder so bytes still buffered from the last chunk are not
  // silently dropped.
  segment += decoder.decode();

  if (segment) {
    // The body ended without a trailing line break; process the remainder.
    await processLines([segment], controller);
  }

  controller.close();
}
|
||||
|
||||
/**
 * Wraps the body of `res` in a `ReadableStream<string>` whose chunks are
 * produced by `readAndProcessLines`.
 *
 * When the response has no body the returned stream is closed immediately.
 *
 * @param res - Fetch response whose body is consumed.
 * @returns A string stream fed from the response body.
 */
function createParser(res: Response) {
  const bodyReader = res.body?.getReader();

  const start = async (
    controller: ReadableStreamDefaultController<string>,
  ): Promise<void> => {
    if (bodyReader) {
      await readAndProcessLines(bodyReader, controller);
    } else {
      // No body to read: finish the stream right away.
      controller.close();
    }
  };

  return new ReadableStream<string>({ start });
}
|
||||
|
||||
/**
 * Builds a stream from a response by piping the output of `createParser`
 * through the `ai` package's callbacks transformer and then its stream-data
 * transformer.
 *
 * @param reader - The fetch `Response` to consume (named `reader` for
 *   historical reasons).
 * @param callbacks - Optional callbacks/options passed to the transformers.
 * @returns The fully piped stream.
 */
export function TabbyStream(
  reader: Response,
  callbacks?: AIStreamCallbacksAndOptions,
): ReadableStream {
  const parsed = createParser(reader);
  const withCallbacks = parsed.pipeThrough(
    createCallbacksTransformer(callbacks),
  );
  const streamDataTransformer = createStreamDataTransformer(
    callbacks?.experimental_streamData,
  );

  return withCallbacks.pipeThrough(streamDataTransformer);
}
|
||||
|
|
@ -137,7 +137,7 @@ impl TextGeneration for CTranslate2Engine {
|
|||
|
||||
let decoding = self
|
||||
.decoding_factory
|
||||
.create(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), &options.stop_words, options.static_stop_words);
|
||||
.create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words);
|
||||
|
||||
let (sender, mut receiver) = channel::<String>(8);
|
||||
let context = InferenceContext::new(sender, decoding, cancel_for_inference);
|
||||
|
|
|
|||
|
|
@ -58,11 +58,8 @@ impl FastChatEngine {
|
|||
#[async_trait]
|
||||
impl TextGeneration for FastChatEngine {
|
||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||
let _stop_sequences: Vec<String> = options
|
||||
.static_stop_words
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect();
|
||||
let _stop_sequences: Vec<String> =
|
||||
options.stop_words.iter().map(|x| x.to_string()).collect();
|
||||
|
||||
let tokens: Vec<&str> = prompt.split("<MID>").collect();
|
||||
let request = Request {
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ impl VertexAIEngine {
|
|||
impl TextGeneration for VertexAIEngine {
|
||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||
let stop_sequences: Vec<String> = options
|
||||
.static_stop_words
|
||||
.stop_words
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
// vertex supports at most 5 stop sequence.
|
||||
|
|
|
|||
|
|
@ -10,7 +10,8 @@ namespace llama {
|
|||
TextInferenceEngine::~TextInferenceEngine() {}
|
||||
|
||||
namespace {
|
||||
static size_t N_BATCH = 512;
|
||||
static size_t N_BATCH = 512; // # per batch inference.
|
||||
static size_t N_CTX = 4096; // # max kv history.
|
||||
|
||||
template<class T>
|
||||
using owned = std::unique_ptr<T, std::function<void(T*)>>;
|
||||
|
|
@ -59,7 +60,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
return std::distance(logits, std::max_element(logits, logits + n_vocab));
|
||||
}
|
||||
|
||||
bool eval(llama_token* data, size_t size, bool reset) {
|
||||
void eval(llama_token* data, size_t size, bool reset) {
|
||||
if (reset) {
|
||||
n_past_ = 0;
|
||||
}
|
||||
|
|
@ -76,12 +77,10 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
auto* ctx = ctx_.get();
|
||||
llama_kv_cache_tokens_rm(ctx, n_past_, -1);
|
||||
if (llama_decode(ctx, batch_)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
throw std::runtime_error("Failed to eval");
|
||||
}
|
||||
|
||||
n_past_ += size;
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t n_past_;
|
||||
|
|
@ -127,7 +126,7 @@ std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
|
|||
}
|
||||
|
||||
llama_context_params ctx_params = llama_context_default_params();
|
||||
ctx_params.n_ctx = 2048;
|
||||
ctx_params.n_ctx = N_CTX;
|
||||
ctx_params.n_batch = N_BATCH;
|
||||
llama_context* ctx = llama_new_context_with_model(model, ctx_params);
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ mod ffi {
|
|||
fn create_engine(model_path: &str) -> UniquePtr<TextInferenceEngine>;
|
||||
|
||||
fn start(self: Pin<&mut TextInferenceEngine>, input_token_ids: &[u32]);
|
||||
fn step(self: Pin<&mut TextInferenceEngine>) -> u32;
|
||||
fn step(self: Pin<&mut TextInferenceEngine>) -> Result<u32>;
|
||||
fn end(self: Pin<&mut TextInferenceEngine>);
|
||||
|
||||
fn eos_token(&self) -> u32;
|
||||
|
|
@ -75,10 +75,12 @@ impl TextGeneration for LlamaEngine {
|
|||
|
||||
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
|
||||
engine.as_mut().start(input_token_ids);
|
||||
let mut decoding = self.decoding_factory.create(self.tokenizer.clone(), input_token_ids, &options.stop_words, options.static_stop_words);
|
||||
let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.stop_words);
|
||||
let mut n_remains = options.max_decoding_length ;
|
||||
while n_remains > 0 {
|
||||
let next_token_id = engine.as_mut().step();
|
||||
let Ok(next_token_id) = engine.as_mut().step() else {
|
||||
panic!("Failed to eval");
|
||||
};
|
||||
if next_token_id == eos_token {
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,35 +24,16 @@ impl Default for DecodingFactory {
|
|||
}
|
||||
|
||||
impl DecodingFactory {
|
||||
pub fn create(
|
||||
pub fn create_incremental_decoding(
|
||||
&self,
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
input_token_ids: &[u32],
|
||||
stop_words: &Vec<String>,
|
||||
static_stop_words: &'static Vec<&'static str>,
|
||||
stop_words: &'static Vec<&'static str>,
|
||||
) -> IncrementalDecoding {
|
||||
IncrementalDecoding::new(
|
||||
tokenizer,
|
||||
vec![
|
||||
self.get_static_re(static_stop_words),
|
||||
self.get_re(stop_words),
|
||||
]
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect(),
|
||||
input_token_ids,
|
||||
)
|
||||
IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
|
||||
}
|
||||
|
||||
fn get_re(&self, stop_words: &Vec<String>) -> Option<Regex> {
|
||||
if !stop_words.is_empty() {
|
||||
Some(create_stop_regex(stop_words))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn get_static_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
|
||||
fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
|
||||
if stop_words.is_empty() {
|
||||
None
|
||||
} else {
|
||||
|
|
@ -67,8 +48,8 @@ impl DecodingFactory {
|
|||
}
|
||||
}
|
||||
|
||||
fn create_stop_regex<T: AsRef<str>>(stop_words: &[T]) -> Regex {
|
||||
let tokens: Vec<String> = stop_words.iter().map(|x| reverse(x.as_ref())).collect();
|
||||
fn create_stop_regex(stop_words: &[&str]) -> Regex {
|
||||
let tokens: Vec<String> = stop_words.iter().map(|x| reverse(*x)).collect();
|
||||
|
||||
// (?m) enables multi-line matching mode.
|
||||
// \A means absolute begins of string.
|
||||
|
|
@ -78,7 +59,7 @@ fn create_stop_regex<T: AsRef<str>>(stop_words: &[T]) -> Regex {
|
|||
|
||||
pub struct IncrementalDecoding {
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
stop_re: Vec<Regex>,
|
||||
stop_re: Option<Regex>,
|
||||
|
||||
token_ids: Vec<u32>,
|
||||
prefix_offset: usize,
|
||||
|
|
@ -88,7 +69,7 @@ pub struct IncrementalDecoding {
|
|||
}
|
||||
|
||||
impl IncrementalDecoding {
|
||||
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Vec<Regex>, input_token_ids: &[u32]) -> Self {
|
||||
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>, input_token_ids: &[u32]) -> Self {
|
||||
let text = tokenizer
|
||||
.decode(input_token_ids, /* skip_special_token = */ true)
|
||||
.expect("Cannot decode token from tokenizer.");
|
||||
|
|
@ -129,7 +110,8 @@ impl IncrementalDecoding {
|
|||
|
||||
if !new_text.is_empty() {
|
||||
self.reversed_text = reverse(new_text) + &self.reversed_text;
|
||||
for re in &self.stop_re {
|
||||
|
||||
if let Some(re) = &self.stop_re {
|
||||
if re.find(&self.reversed_text).is_some() {
|
||||
return None;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,10 +16,7 @@ pub struct TextGenerationOptions {
|
|||
pub sampling_temperature: f32,
|
||||
|
||||
#[builder(default = "&EMPTY_STOP_WORDS")]
|
||||
pub static_stop_words: &'static Vec<&'static str>,
|
||||
|
||||
#[builder(default = "vec![]")]
|
||||
pub stop_words: Vec<String>,
|
||||
pub stop_words: &'static Vec<&'static str>,
|
||||
}
|
||||
|
||||
static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ http-api-bindings = { path = "../http-api-bindings" }
|
|||
futures = { workspace = true }
|
||||
async-stream = { workspace = true }
|
||||
axum-streams = { version = "0.9.1", features = ["json"] }
|
||||
minijinja = { version = "1.0.8", features = ["loader"] }
|
||||
|
||||
[target.'cfg(all(target_os="macos", target_arch="aarch64"))'.dependencies]
|
||||
llama-cpp-bindings = { path = "../llama-cpp-bindings" }
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -1,13 +1,13 @@
|
|||
1:HL["/playground/_next/static/media/86fdec36ddd9097e-s.p.woff2","font",{"crossOrigin":"","type":"font/woff2"}]
|
||||
2:HL["/playground/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2","font",{"crossOrigin":"","type":"font/woff2"}]
|
||||
3:HL["/playground/_next/static/css/d091dc2da2a795e4.css","style"]
|
||||
0:["f6rsO7djEUh4Fn3OO-Bie",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],"$L4",[[["$","link","0",{"rel":"stylesheet","href":"/playground/_next/static/css/d091dc2da2a795e4.css","precedence":"next"}]],"$L5"]]]]
|
||||
0:["9a4m76mRTGOnagTYXSPKd",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],"$L4",[[["$","link","0",{"rel":"stylesheet","href":"/playground/_next/static/css/d091dc2da2a795e4.css","precedence":"next"}]],"$L5"]]]]
|
||||
6:I{"id":5925,"chunks":["346:static/chunks/346-c4227fa5fd95e485.js","524:static/chunks/524-6309ecb76a77fdcf.js","185:static/chunks/app/layout-38d79c8bb16c51be.js"],"name":"Toaster","async":false}
|
||||
7:I{"id":78495,"chunks":["346:static/chunks/346-c4227fa5fd95e485.js","524:static/chunks/524-6309ecb76a77fdcf.js","185:static/chunks/app/layout-38d79c8bb16c51be.js"],"name":"Providers","async":false}
|
||||
8:I{"id":78963,"chunks":["346:static/chunks/346-c4227fa5fd95e485.js","524:static/chunks/524-6309ecb76a77fdcf.js","185:static/chunks/app/layout-38d79c8bb16c51be.js"],"name":"Header","async":false}
|
||||
9:I{"id":81443,"chunks":["272:static/chunks/webpack-e23fff8c5b5084ca.js","971:static/chunks/fd9d1056-5dfc77aa37d8c76f.js","864:static/chunks/864-1669531662d5540a.js"],"name":"","async":false}
|
||||
a:I{"id":18639,"chunks":["272:static/chunks/webpack-e23fff8c5b5084ca.js","971:static/chunks/fd9d1056-5dfc77aa37d8c76f.js","864:static/chunks/864-1669531662d5540a.js"],"name":"","async":false}
|
||||
c:I{"id":64074,"chunks":["346:static/chunks/346-c4227fa5fd95e485.js","978:static/chunks/978-ab68c4a2390585a1.js","524:static/chunks/524-6309ecb76a77fdcf.js","931:static/chunks/app/page-757d8cb1ec33d4cb.js"],"name":"Chat","async":false}
|
||||
c:I{"id":10413,"chunks":["346:static/chunks/346-c4227fa5fd95e485.js","978:static/chunks/978-342eae78521d80e5.js","524:static/chunks/524-6309ecb76a77fdcf.js","931:static/chunks/app/page-2ebc2d344df80bd2.js"],"name":"Chat","async":false}
|
||||
5:[["$","meta","0",{"charSet":"utf-8"}],["$","title","1",{"children":"Tabby Playground"}],["$","meta","2",{"name":"description","content":"Tabby, an opensource, self-hosted AI coding assistant."}],["$","meta","3",{"name":"theme-color","media":"(prefers-color-scheme: light)","content":"white"}],["$","meta","4",{"name":"theme-color","media":"(prefers-color-scheme: dark)","content":"black"}],["$","meta","5",{"name":"viewport","content":"width=device-width, initial-scale=1"}],["$","meta","6",{"name":"next-size-adjust"}]]
|
||||
4:[null,["$","html",null,{"lang":"en","suppressHydrationWarning":true,"children":[["$","head",null,{}],["$","body",null,{"className":"font-sans antialiased __variable_4e6684 __variable_3d950d","children":[["$","$L6",null,{}],["$","$L7",null,{"attribute":"class","defaultTheme":"system","enableSystem":true,"children":[["$","div",null,{"className":"flex flex-col min-h-screen","children":[["$","$L8",null,{}],["$","main",null,{"className":"flex flex-col flex-1 bg-muted/50","children":["$","$L9",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","template":["$","$La",null,{}],"templateStyles":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"childProp":{"current":["$Lb",["$","$Lc",null,{"id":"JBeXiJC"}],null],"segment":"__PAGE__"},"styles":[]}]}]]}],null]}]]}]]}],null]
|
||||
4:[null,["$","html",null,{"lang":"en","suppressHydrationWarning":true,"children":[["$","head",null,{}],["$","body",null,{"className":"font-sans antialiased __variable_4e6684 __variable_3d950d","children":[["$","$L6",null,{}],["$","$L7",null,{"attribute":"class","defaultTheme":"system","enableSystem":true,"children":[["$","div",null,{"className":"flex flex-col min-h-screen","children":[["$","$L8",null,{}],["$","main",null,{"className":"flex flex-col flex-1 bg-muted/50","children":["$","$L9",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","template":["$","$La",null,{}],"templateStyles":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"childProp":{"current":["$Lb",["$","$Lc",null,{"id":"Z43ogQe"}],null],"segment":"__PAGE__"},"styles":[]}]}]]}],null]}]]}]]}],null]
|
||||
b:null
|
||||
|
|
|
|||
|
|
@ -0,0 +1,96 @@
|
|||
mod prompt;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_stream::stream;
|
||||
use axum::{
|
||||
extract::State,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use axum_streams::StreamBodyAs;
|
||||
use prompt::ChatPromptBuilder;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
|
||||
use tracing::instrument;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
pub struct ChatState {
|
||||
engine: Arc<Box<dyn TextGeneration>>,
|
||||
prompt_builder: ChatPromptBuilder,
|
||||
}
|
||||
|
||||
impl ChatState {
|
||||
pub fn new(engine: Arc<Box<dyn TextGeneration>>, prompt_template: String) -> Self {
|
||||
Self {
|
||||
engine,
|
||||
prompt_builder: ChatPromptBuilder::new(prompt_template),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
#[schema(example=json!({
|
||||
"messages": [
|
||||
Message { role: "user".to_owned(), content: "What is tail recursion?".to_owned()},
|
||||
Message { role: "assistant".to_owned(), content: "It's a kind of optimization in compiler?".to_owned()},
|
||||
Message { role: "user".to_owned(), content: "Could you share more details?".to_owned()},
|
||||
]
|
||||
}))]
|
||||
pub struct ChatCompletionRequest {
|
||||
messages: Vec<Message>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
pub struct Message {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
pub struct ChatCompletionChunk {
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/v1beta/chat/completions",
|
||||
request_body = ChatCompletionRequest,
|
||||
operation_id = "chat_completions",
|
||||
tag = "v1beta",
|
||||
responses(
|
||||
(status = 200, description = "Success", body = ChatCompletionChunk, content_type = "application/jsonstream"),
|
||||
(status = 405, description = "When chat model is not specified, the endpoint will returns 405 Method Not Allowed"),
|
||||
)
|
||||
)]
|
||||
#[instrument(skip(state, request))]
|
||||
pub async fn completions(
|
||||
State(state): State<Arc<ChatState>>,
|
||||
Json(request): Json<ChatCompletionRequest>,
|
||||
) -> Response {
|
||||
let (prompt, options) = parse_request(&state, request);
|
||||
let s = stream! {
|
||||
for await content in state.engine.generate_stream(&prompt, options).await {
|
||||
yield ChatCompletionChunk { content }
|
||||
}
|
||||
};
|
||||
|
||||
StreamBodyAs::json_nl(s).into_response()
|
||||
}
|
||||
|
||||
fn parse_request(
|
||||
state: &Arc<ChatState>,
|
||||
request: ChatCompletionRequest,
|
||||
) -> (String, TextGenerationOptions) {
|
||||
let mut builder = TextGenerationOptionsBuilder::default();
|
||||
|
||||
builder
|
||||
.max_input_length(2048)
|
||||
.max_decoding_length(1920)
|
||||
.sampling_temperature(0.1);
|
||||
|
||||
(
|
||||
state.prompt_builder.build(&request.messages),
|
||||
builder.build().unwrap(),
|
||||
)
|
||||
}
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
use minijinja::{context, Environment};
|
||||
|
||||
use super::Message;
|
||||
|
||||
pub struct ChatPromptBuilder {
|
||||
env: Environment<'static>,
|
||||
}
|
||||
|
||||
impl ChatPromptBuilder {
|
||||
pub fn new(prompt_template: String) -> Self {
|
||||
let mut env = Environment::new();
|
||||
env.add_function("raise_exception", |e: String| panic!("{}", e));
|
||||
env.add_template_owned("prompt", prompt_template)
|
||||
.expect("Failed to compile template");
|
||||
|
||||
Self { env }
|
||||
}
|
||||
|
||||
pub fn build(&self, messages: &[Message]) -> String {
|
||||
self.env
|
||||
.get_template("prompt")
|
||||
.unwrap()
|
||||
.render(context!(
|
||||
messages => messages
|
||||
))
|
||||
.expect("Failed to evaluate")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Template used by the tests below. It wraps user messages in
    /// `[INST] ... [/INST]`, appends `</s> ` to assistant messages and calls
    /// `raise_exception` for anything else.
    static PROMPT_TEMPLATE : &str = "<s>{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + '</s> ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}";

    /// Shorthand for building a `Message` from string slices.
    fn message(role: &str, content: &str) -> Message {
        Message {
            role: role.to_owned(),
            content: content.to_owned(),
        }
    }

    /// An alternating user/assistant/user conversation renders as expected.
    #[test]
    fn test_it_works() {
        let builder = ChatPromptBuilder::new(PROMPT_TEMPLATE.to_owned());
        let conversation = [
            message("user", "What is tail recursion?"),
            message("assistant", "It's a kind of optimization in compiler?"),
            message("user", "Could you share more details?"),
        ];

        let expected = "<s>[INST] What is tail recursion? [/INST]It's a kind of optimization in compiler?</s> [INST] Could you share more details? [/INST]";
        assert_eq!(builder.build(&conversation), expected);
    }

    /// A conversation that starts with a non-user role makes the template
    /// call `raise_exception`, which panics.
    #[test]
    #[should_panic]
    fn test_it_panic() {
        let builder = ChatPromptBuilder::new(PROMPT_TEMPLATE.to_owned());
        let conversation = [message("system", "system")];

        builder.build(&conversation);
    }
}
|
||||
|
|
@ -71,7 +71,7 @@ pub struct CompletionResponse {
|
|||
)
|
||||
)]
|
||||
#[instrument(skip(state, request))]
|
||||
pub async fn completion(
|
||||
pub async fn completions(
|
||||
State(state): State<Arc<CompletionState>>,
|
||||
Json(request): Json<CompletionRequest>,
|
||||
) -> Result<Json<CompletionResponse>, StatusCode> {
|
||||
|
|
@ -80,7 +80,7 @@ pub async fn completion(
|
|||
.max_input_length(1024 + 512)
|
||||
.max_decoding_length(128)
|
||||
.sampling_temperature(0.1)
|
||||
.static_stop_words(get_stop_words(&language))
|
||||
.stop_words(get_stop_words(&language))
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
|
|
|
|||
|
|
@ -1,94 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use async_stream::stream;
|
||||
use axum::{extract::State, response::IntoResponse, Json};
|
||||
use axum_streams::StreamBodyAs;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
|
||||
use tracing::instrument;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
pub struct GenerateState {
|
||||
engine: Arc<Box<dyn TextGeneration>>,
|
||||
}
|
||||
|
||||
impl GenerateState {
|
||||
pub fn new(engine: Arc<Box<dyn TextGeneration>>) -> Self {
|
||||
Self { engine }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
#[schema(example=json!({
|
||||
"prompt": "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n",
|
||||
}))]
|
||||
pub struct GenerateRequest {
|
||||
prompt: String,
|
||||
stop_words: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
pub struct GenerateResponse {
|
||||
text: String,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/v1beta/generate",
|
||||
request_body = GenerateRequest,
|
||||
operation_id = "generate",
|
||||
tag = "v1beta",
|
||||
responses(
|
||||
(status = 200, description = "Success", body = GenerateResponse, content_type = "application/json"),
|
||||
)
|
||||
)]
|
||||
#[instrument(skip(state, request))]
|
||||
pub async fn generate(
|
||||
State(state): State<Arc<GenerateState>>,
|
||||
Json(request): Json<GenerateRequest>,
|
||||
) -> impl IntoResponse {
|
||||
let (prompt, options) = parse_request(request);
|
||||
Json(GenerateResponse {
|
||||
text: state.engine.generate(&prompt, options).await,
|
||||
})
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/v1beta/generate_stream",
|
||||
request_body = GenerateRequest,
|
||||
operation_id = "generate_stream",
|
||||
tag = "v1beta",
|
||||
responses(
|
||||
(status = 200, description = "Success", body = GenerateResponse, content_type = "application/jsonstream"),
|
||||
)
|
||||
)]
|
||||
#[instrument(skip(state, request))]
|
||||
pub async fn generate_stream(
|
||||
State(state): State<Arc<GenerateState>>,
|
||||
Json(request): Json<GenerateRequest>,
|
||||
) -> impl IntoResponse {
|
||||
let (prompt, options) = parse_request(request);
|
||||
let s = stream! {
|
||||
for await text in state.engine.generate_stream(&prompt, options).await {
|
||||
yield GenerateResponse { text }
|
||||
}
|
||||
};
|
||||
|
||||
StreamBodyAs::json_nl(s)
|
||||
}
|
||||
|
||||
fn parse_request(request: GenerateRequest) -> (String, TextGenerationOptions) {
|
||||
let mut builder = TextGenerationOptionsBuilder::default();
|
||||
|
||||
builder
|
||||
.max_input_length(1024)
|
||||
.max_decoding_length(968)
|
||||
.sampling_temperature(0.1);
|
||||
|
||||
if let Some(stop_words) = request.stop_words {
|
||||
builder.stop_words(stop_words);
|
||||
};
|
||||
|
||||
(request.prompt, builder.build().unwrap())
|
||||
}
|
||||
|
|
@ -11,7 +11,7 @@ use utoipa::ToSchema;
|
|||
pub struct HealthState {
|
||||
model: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
instruct_model: Option<String>,
|
||||
chat_model: Option<String>,
|
||||
device: String,
|
||||
compute_type: String,
|
||||
arch: String,
|
||||
|
|
@ -32,7 +32,7 @@ impl HealthState {
|
|||
|
||||
Self {
|
||||
model: args.model.clone(),
|
||||
instruct_model: args.instruct_model.clone(),
|
||||
chat_model: args.chat_model.clone(),
|
||||
device: args.device.to_string(),
|
||||
compute_type: args.compute_type.to_string(),
|
||||
arch: ARCH.to_string(),
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
mod chat;
|
||||
mod completions;
|
||||
mod engine;
|
||||
mod events;
|
||||
mod generate;
|
||||
mod health;
|
||||
mod playground;
|
||||
|
||||
|
|
@ -42,15 +42,16 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
|
|||
servers(
|
||||
(url = "https://playground.app.tabbyml.com", description = "Playground server"),
|
||||
),
|
||||
paths(events::log_event, completions::completion, generate::generate, generate::generate_stream, health::health),
|
||||
paths(events::log_event, completions::completions, chat::completions, health::health),
|
||||
components(schemas(
|
||||
events::LogEventRequest,
|
||||
completions::CompletionRequest,
|
||||
completions::CompletionResponse,
|
||||
completions::Segments,
|
||||
completions::Choice,
|
||||
generate::GenerateRequest,
|
||||
generate::GenerateResponse,
|
||||
chat::ChatCompletionRequest,
|
||||
chat::Message,
|
||||
chat::ChatCompletionChunk,
|
||||
health::HealthState,
|
||||
health::Version,
|
||||
))
|
||||
|
|
@ -105,14 +106,13 @@ pub enum ComputeType {
|
|||
|
||||
#[derive(Args)]
|
||||
pub struct ServeArgs {
|
||||
/// Model id for `/completion` API endpoint.
|
||||
/// Model id for `/completions` API endpoint.
|
||||
#[clap(long)]
|
||||
model: String,
|
||||
|
||||
/// Model id for `/generate` and `/generate_stream` API endpoints.
|
||||
/// If not set, `model` will be loaded for the purpose.
|
||||
/// Model id for `/chat/completions` API endpoints.
|
||||
#[clap(long)]
|
||||
instruct_model: Option<String>,
|
||||
chat_model: Option<String>,
|
||||
|
||||
#[clap(long, default_value_t = 8080)]
|
||||
port: u16,
|
||||
|
|
@ -149,8 +149,8 @@ pub async fn main(config: &Config, args: &ServeArgs) {
|
|||
|
||||
if args.device != Device::ExperimentalHttp {
|
||||
download_model(&args.model, &args.device).await;
|
||||
if let Some(instruct_model) = &args.instruct_model {
|
||||
download_model(instruct_model, &args.device).await;
|
||||
if let Some(chat_model) = &args.chat_model {
|
||||
download_model(chat_model, &args.device).await;
|
||||
}
|
||||
} else {
|
||||
warn!("HTTP device is unstable and does not comply with semver expectations.")
|
||||
|
|
@ -160,12 +160,18 @@ pub async fn main(config: &Config, args: &ServeArgs) {
|
|||
|
||||
let doc = add_localhost_server(ApiDoc::openapi(), args.port);
|
||||
let doc = add_proxy_server(doc, config.swagger.server_url.clone());
|
||||
let app = api_router(args, config)
|
||||
let app = Router::new()
|
||||
.merge(api_router(args, config))
|
||||
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc))
|
||||
.route("/playground", routing::get(playground::handler))
|
||||
.route("/playground/*path", routing::get(playground::handler))
|
||||
.fallback(fallback());
|
||||
|
||||
let app = if args.chat_model.is_some() {
|
||||
app.route("/playground", routing::get(playground::handler))
|
||||
.route("/playground/*path", routing::get(playground::handler))
|
||||
} else {
|
||||
app
|
||||
};
|
||||
|
||||
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
|
||||
info!("Listening at {}", address);
|
||||
|
||||
|
|
@ -177,15 +183,26 @@ pub async fn main(config: &Config, args: &ServeArgs) {
|
|||
}
|
||||
|
||||
fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||
let (engine, prompt_template) = create_engine(&args.model, args);
|
||||
let engine = Arc::new(engine);
|
||||
let instruct_engine = if let Some(instruct_model) = &args.instruct_model {
|
||||
Arc::new(create_engine(instruct_model, args).0)
|
||||
} else {
|
||||
engine.clone()
|
||||
let completion_state = {
|
||||
let (engine, prompt_template) = create_engine(&args.model, args);
|
||||
let engine = Arc::new(engine);
|
||||
let state = completions::CompletionState::new(engine.clone(), prompt_template, config);
|
||||
Arc::new(state)
|
||||
};
|
||||
|
||||
Router::new()
|
||||
let chat_state = if let Some(chat_model) = &args.chat_model {
|
||||
let (engine, prompt_template) = create_engine(chat_model, args);
|
||||
let Some(prompt_template) = prompt_template else {
|
||||
panic!("Chat model requires specifying prompt template");
|
||||
};
|
||||
let engine = Arc::new(engine);
|
||||
let state = chat::ChatState::new(engine, prompt_template);
|
||||
Some(Arc::new(state))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let router = Router::new()
|
||||
.route("/v1/events", routing::post(events::log_event))
|
||||
.route(
|
||||
"/v1/health",
|
||||
|
|
@ -193,22 +210,19 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
|||
)
|
||||
.route(
|
||||
"/v1/completions",
|
||||
routing::post(completions::completion).with_state(Arc::new(
|
||||
completions::CompletionState::new(engine.clone(), prompt_template, config),
|
||||
)),
|
||||
)
|
||||
.route(
|
||||
"/v1beta/generate",
|
||||
routing::post(generate::generate).with_state(Arc::new(generate::GenerateState::new(
|
||||
instruct_engine.clone(),
|
||||
))),
|
||||
)
|
||||
.route(
|
||||
"/v1beta/generate_stream",
|
||||
routing::post(generate::generate_stream).with_state(Arc::new(
|
||||
generate::GenerateState::new(instruct_engine.clone()),
|
||||
)),
|
||||
routing::post(completions::completions).with_state(completion_state),
|
||||
);
|
||||
|
||||
let router = if let Some(chat_state) = chat_state {
|
||||
router.route(
|
||||
"/v1beta/chat/completions",
|
||||
routing::post(chat::completions).with_state(chat_state),
|
||||
)
|
||||
} else {
|
||||
router
|
||||
};
|
||||
|
||||
router
|
||||
.layer(CorsLayer::permissive())
|
||||
.layer(opentelemetry_tracing_layer())
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue