diff --git a/Cargo.lock b/Cargo.lock index 1f8bb67..c2f31e1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4062,6 +4062,7 @@ dependencies = [ "axum-streams", "axum-tracing-opentelemetry", "clap 4.4.7", + "futures", "http-api-bindings", "hyper", "lazy_static", diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index c1d08ed..df54be6 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -44,6 +44,7 @@ textdistance = "1.0.2" regex.workspace = true thiserror.workspace = true llama-cpp-bindings = { path = "../llama-cpp-bindings" } +futures.workspace = true [dependencies.uuid] version = "1.3.3" diff --git a/crates/tabby/src/chat.rs b/crates/tabby/src/chat.rs new file mode 100644 index 0000000..0d1d778 --- /dev/null +++ b/crates/tabby/src/chat.rs @@ -0,0 +1,79 @@ +mod prompt; + +use std::sync::Arc; + +use async_stream::stream; +use futures::stream::BoxStream; +use prompt::ChatPromptBuilder; +use serde::{Deserialize, Serialize}; +use tabby_common::languages::EMPTY_LANGUAGE; +use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder}; +use tracing::debug; +use utoipa::ToSchema; + +#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] +#[schema(example=json!({ + "messages": [ + Message { role: "user".to_owned(), content: "What is tail recursion?".to_owned()}, + Message { role: "assistant".to_owned(), content: "It's a kind of optimization in compiler?".to_owned()}, + Message { role: "user".to_owned(), content: "Could you share more details?".to_owned()}, + ] +}))] +pub struct ChatCompletionRequest { + messages: Vec, +} + +#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] +pub struct Message { + role: String, + content: String, +} + +#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] +pub struct ChatCompletionChunk { + content: String, +} + +pub struct ChatService { + engine: Arc>, + prompt_builder: ChatPromptBuilder, +} + +impl ChatService { + pub fn new(engine: Arc>, chat_template: String) -> Self { + Self { + engine, + prompt_builder: ChatPromptBuilder::new(chat_template), + } + } + + fn parse_request(&self, request: &ChatCompletionRequest) -> (String, TextGenerationOptions) { + let mut builder = TextGenerationOptionsBuilder::default(); + + builder + .max_input_length(2048) + .max_decoding_length(1920) + .language(&EMPTY_LANGUAGE) + .sampling_temperature(0.1); + + ( + self.prompt_builder.build(&request.messages), + builder.build().unwrap(), + ) + } + + pub async fn generate( + &self, + request: &ChatCompletionRequest, + ) -> BoxStream { + let (prompt, options) = self.parse_request(request); + debug!("PROMPT: {}", prompt); + let s = stream! { + for await content in self.engine.generate_stream(&prompt, options).await { + yield ChatCompletionChunk { content } + } + }; + + Box::pin(s) + } +} diff --git a/crates/tabby/src/serve/chat/prompt.rs b/crates/tabby/src/chat/prompt.rs similarity index 100% rename from crates/tabby/src/serve/chat/prompt.rs rename to crates/tabby/src/chat/prompt.rs diff --git a/crates/tabby/src/main.rs b/crates/tabby/src/main.rs index 194574a..e12c979 100644 --- a/crates/tabby/src/main.rs +++ b/crates/tabby/src/main.rs @@ -1,3 +1,4 @@ +mod chat; mod download; mod search; mod serve; diff --git a/crates/tabby/src/serve/chat.rs b/crates/tabby/src/serve/chat.rs index 2fb62a2..f24aa67 100644 --- a/crates/tabby/src/serve/chat.rs +++ b/crates/tabby/src/serve/chat.rs @@ -1,5 +1,3 @@ -mod prompt; - use std::sync::Arc; use async_stream::stream; @@ -9,49 +7,9 @@ use axum::{ Json, }; use axum_streams::StreamBodyAs; -use prompt::ChatPromptBuilder; -use serde::{Deserialize, Serialize}; -use tabby_common::languages::EMPTY_LANGUAGE; -use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder}; -use tracing::{debug, instrument}; -use utoipa::ToSchema; +use tracing::instrument; -pub struct ChatState { - engine: Arc>, - prompt_builder: ChatPromptBuilder, -} - -impl ChatState { - pub fn new(engine: Arc>, chat_template: String) -> Self { - Self { - engine, - prompt_builder: ChatPromptBuilder::new(chat_template), - } - } -} - -#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] -#[schema(example=json!({ - "messages": [ - Message { role: "user".to_owned(), content: "What is tail recursion?".to_owned()}, - Message { role: "assistant".to_owned(), content: "It's a kind of optimization in compiler?".to_owned()}, - Message { role: "user".to_owned(), content: "Could you share more details?".to_owned()}, - ] -}))] -pub struct ChatCompletionRequest { - messages: Vec, -} - -#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] -pub struct Message { - role: String, - content: String, -} - -#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] -pub struct ChatCompletionChunk { - content: String, -} +use crate::chat::{ChatCompletionRequest, ChatService}; #[utoipa::path( post, @@ -66,34 +24,14 @@ pub struct ChatCompletionChunk { )] #[instrument(skip(state, request))] pub async fn completions( - State(state): State>, + State(state): State>, Json(request): Json, ) -> Response { - let (prompt, options) = parse_request(&state, request); - debug!("PROMPT: {}", prompt); let s = stream! { - for await content in state.engine.generate_stream(&prompt, options).await { - yield ChatCompletionChunk { content } + for await content in state.generate(&request).await { + yield content; } }; StreamBodyAs::json_nl(s).into_response() } - -fn parse_request( - state: &Arc, - request: ChatCompletionRequest, -) -> (String, TextGenerationOptions) { - let mut builder = TextGenerationOptionsBuilder::default(); - - builder - .max_input_length(2048) - .max_decoding_length(1920) - .language(&EMPTY_LANGUAGE) - .sampling_temperature(0.1); - - ( - state.prompt_builder.build(&request.messages), - builder.build().unwrap(), - ) -} diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index cde9a94..bfafd0c 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -28,7 +28,7 @@ use self::{ engine::{create_engine, EngineInfo}, health::HealthState, }; -use crate::{fatal, search::CodeSearchService}; +use crate::{chat::ChatService, fatal, search::CodeSearchService}; #[derive(OpenApi)] #[openapi( @@ -57,9 +57,9 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi completions::Snippet, completions::DebugOptions, completions::DebugData, - chat::ChatCompletionRequest, - chat::Message, - chat::ChatCompletionChunk, + crate::chat::ChatCompletionRequest, + crate::chat::Message, + crate::chat::ChatCompletionChunk, health::HealthState, health::Version, crate::search::SearchResponse, @@ -189,7 +189,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router { panic!("Chat model requires specifying prompt template"); }; let engine = Arc::new(engine); - let state = chat::ChatState::new(engine, chat_template); + let state = ChatService::new(engine, chat_template); Some(Arc::new(state)) } else { None