refactor: extract ChatState -> ChatService (#730)

refactor-extract-code
Meng Zhang 2023-11-08 14:12:29 -08:00 committed by GitHub
parent 72d1d9f0bb
commit b51520062a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 92 additions and 72 deletions

1
Cargo.lock generated
View File

@ -4062,6 +4062,7 @@ dependencies = [
"axum-streams", "axum-streams",
"axum-tracing-opentelemetry", "axum-tracing-opentelemetry",
"clap 4.4.7", "clap 4.4.7",
"futures",
"http-api-bindings", "http-api-bindings",
"hyper", "hyper",
"lazy_static", "lazy_static",

View File

@ -44,6 +44,7 @@ textdistance = "1.0.2"
regex.workspace = true regex.workspace = true
thiserror.workspace = true thiserror.workspace = true
llama-cpp-bindings = { path = "../llama-cpp-bindings" } llama-cpp-bindings = { path = "../llama-cpp-bindings" }
futures.workspace = true
[dependencies.uuid] [dependencies.uuid]
version = "1.3.3" version = "1.3.3"

79
crates/tabby/src/chat.rs Normal file
View File

@ -0,0 +1,79 @@
mod prompt;
use std::sync::Arc;
use async_stream::stream;
use futures::stream::BoxStream;
use prompt::ChatPromptBuilder;
use serde::{Deserialize, Serialize};
use tabby_common::languages::EMPTY_LANGUAGE;
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
use tracing::debug;
use utoipa::ToSchema;
// Input accepted by `ChatService::generate`: the conversation to be completed.
// Derives serde (de)serialization plus a utoipa schema; the `example` attribute
// below supplies a sample three-message conversation for that schema.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({
"messages": [
Message { role: "user".to_owned(), content: "What is tail recursion?".to_owned()},
Message { role: "assistant".to_owned(), content: "It's a kind of optimization in compiler?".to_owned()},
Message { role: "user".to_owned(), content: "Could you share more details?".to_owned()},
]
}))]
pub struct ChatCompletionRequest {
// Messages of the conversation; `ChatService::parse_request` hands this list
// to the prompt builder to produce the prompt string.
messages: Vec<Message>,
}
// A single entry of a chat conversation, as carried in
// `ChatCompletionRequest::messages`.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct Message {
// Who authored the message; the request schema example uses "user" and
// "assistant". No validation of the value is visible in this file.
role: String,
// Text of the message.
content: String,
}
// One item of the stream returned by `ChatService::generate`.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct ChatCompletionChunk {
// The value yielded by the engine's `generate_stream` for this step,
// wrapped unchanged.
content: String,
}
/// Chat completion service: pairs a text generation engine with the prompt
/// builder used to turn a list of chat messages into a prompt for it.
pub struct ChatService {
// Shared, dynamically dispatched text generation backend.
engine: Arc<Box<dyn TextGeneration>>,
// Builds the prompt string from request messages; constructed from the
// chat template passed to `ChatService::new`.
prompt_builder: ChatPromptBuilder,
}
impl ChatService {
    /// Creates a service backed by `engine`, building prompts from
    /// `chat_template`.
    pub fn new(engine: Arc<Box<dyn TextGeneration>>, chat_template: String) -> Self {
        let prompt_builder = ChatPromptBuilder::new(chat_template);
        Self {
            engine,
            prompt_builder,
        }
    }

    /// Turns a request into the `(prompt, options)` pair handed to the engine.
    ///
    /// The generation options are fixed for every chat request: input length
    /// 2048, decoding length 1920, the empty language and a sampling
    /// temperature of 0.1.
    ///
    /// # Panics
    ///
    /// Panics if the options builder reports an error on `build()`.
    fn parse_request(&self, request: &ChatCompletionRequest) -> (String, TextGenerationOptions) {
        let mut options_builder = TextGenerationOptionsBuilder::default();
        options_builder
            .max_input_length(2048)
            .max_decoding_length(1920)
            .language(&EMPTY_LANGUAGE)
            .sampling_temperature(0.1);

        // Same evaluation order as a tuple literal: prompt first, then options.
        let prompt = self.prompt_builder.build(&request.messages);
        let options = options_builder.build().unwrap();
        (prompt, options)
    }

    /// Produces the completion for `request` as a boxed stream of chunks.
    ///
    /// The prompt is built and logged at debug level before the stream is
    /// returned. The call to the engine's `generate_stream` sits inside the
    /// stream body, and each item it yields is wrapped in a
    /// [`ChatCompletionChunk`].
    pub async fn generate(
        &self,
        request: &ChatCompletionRequest,
    ) -> BoxStream<ChatCompletionChunk> {
        let (prompt, options) = self.parse_request(request);
        debug!("PROMPT: {}", prompt);

        let chunks = stream! {
            for await content in self.engine.generate_stream(&prompt, options).await {
                yield ChatCompletionChunk { content }
            }
        };
        Box::pin(chunks)
    }
}

View File

@ -1,3 +1,4 @@
mod chat;
mod download; mod download;
mod search; mod search;
mod serve; mod serve;

View File

@ -1,5 +1,3 @@
mod prompt;
use std::sync::Arc; use std::sync::Arc;
use async_stream::stream; use async_stream::stream;
@ -9,49 +7,9 @@ use axum::{
Json, Json,
}; };
use axum_streams::StreamBodyAs; use axum_streams::StreamBodyAs;
use prompt::ChatPromptBuilder; use tracing::instrument;
use serde::{Deserialize, Serialize};
use tabby_common::languages::EMPTY_LANGUAGE;
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
use tracing::{debug, instrument};
use utoipa::ToSchema;
pub struct ChatState { use crate::chat::{ChatCompletionRequest, ChatService};
engine: Arc<Box<dyn TextGeneration>>,
prompt_builder: ChatPromptBuilder,
}
impl ChatState {
pub fn new(engine: Arc<Box<dyn TextGeneration>>, chat_template: String) -> Self {
Self {
engine,
prompt_builder: ChatPromptBuilder::new(chat_template),
}
}
}
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({
"messages": [
Message { role: "user".to_owned(), content: "What is tail recursion?".to_owned()},
Message { role: "assistant".to_owned(), content: "It's a kind of optimization in compiler?".to_owned()},
Message { role: "user".to_owned(), content: "Could you share more details?".to_owned()},
]
}))]
pub struct ChatCompletionRequest {
messages: Vec<Message>,
}
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct Message {
role: String,
content: String,
}
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct ChatCompletionChunk {
content: String,
}
#[utoipa::path( #[utoipa::path(
post, post,
@ -66,34 +24,14 @@ pub struct ChatCompletionChunk {
)] )]
#[instrument(skip(state, request))] #[instrument(skip(state, request))]
pub async fn completions( pub async fn completions(
State(state): State<Arc<ChatState>>, State(state): State<Arc<ChatService>>,
Json(request): Json<ChatCompletionRequest>, Json(request): Json<ChatCompletionRequest>,
) -> Response { ) -> Response {
let (prompt, options) = parse_request(&state, request);
debug!("PROMPT: {}", prompt);
let s = stream! { let s = stream! {
for await content in state.engine.generate_stream(&prompt, options).await { for await content in state.generate(&request).await {
yield ChatCompletionChunk { content } yield content;
} }
}; };
StreamBodyAs::json_nl(s).into_response() StreamBodyAs::json_nl(s).into_response()
} }
fn parse_request(
state: &Arc<ChatState>,
request: ChatCompletionRequest,
) -> (String, TextGenerationOptions) {
let mut builder = TextGenerationOptionsBuilder::default();
builder
.max_input_length(2048)
.max_decoding_length(1920)
.language(&EMPTY_LANGUAGE)
.sampling_temperature(0.1);
(
state.prompt_builder.build(&request.messages),
builder.build().unwrap(),
)
}

View File

@ -28,7 +28,7 @@ use self::{
engine::{create_engine, EngineInfo}, engine::{create_engine, EngineInfo},
health::HealthState, health::HealthState,
}; };
use crate::{fatal, search::CodeSearchService}; use crate::{chat::ChatService, fatal, search::CodeSearchService};
#[derive(OpenApi)] #[derive(OpenApi)]
#[openapi( #[openapi(
@ -57,9 +57,9 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
completions::Snippet, completions::Snippet,
completions::DebugOptions, completions::DebugOptions,
completions::DebugData, completions::DebugData,
chat::ChatCompletionRequest, crate::chat::ChatCompletionRequest,
chat::Message, crate::chat::Message,
chat::ChatCompletionChunk, crate::chat::ChatCompletionChunk,
health::HealthState, health::HealthState,
health::Version, health::Version,
crate::search::SearchResponse, crate::search::SearchResponse,
@ -189,7 +189,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
panic!("Chat model requires specifying prompt template"); panic!("Chat model requires specifying prompt template");
}; };
let engine = Arc::new(engine); let engine = Arc::new(engine);
let state = chat::ChatState::new(engine, chat_template); let state = ChatService::new(engine, chat_template);
Some(Arc::new(state)) Some(Arc::new(state))
} else { } else {
None None