refactor: extract ChatState -> ChatService (#730)

refactor-extract-code
Meng Zhang 2023-11-08 14:12:29 -08:00 committed by GitHub
parent 72d1d9f0bb
commit b51520062a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 92 additions and 72 deletions

1
Cargo.lock generated
View File

@ -4062,6 +4062,7 @@ dependencies = [
"axum-streams",
"axum-tracing-opentelemetry",
"clap 4.4.7",
"futures",
"http-api-bindings",
"hyper",
"lazy_static",

View File

@ -44,6 +44,7 @@ textdistance = "1.0.2"
regex.workspace = true
thiserror.workspace = true
llama-cpp-bindings = { path = "../llama-cpp-bindings" }
futures.workspace = true
[dependencies.uuid]
version = "1.3.3"

79
crates/tabby/src/chat.rs Normal file
View File

@ -0,0 +1,79 @@
mod prompt;
use std::sync::Arc;
use async_stream::stream;
use futures::stream::BoxStream;
use prompt::ChatPromptBuilder;
use serde::{Deserialize, Serialize};
use tabby_common::languages::EMPTY_LANGUAGE;
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
use tracing::debug;
use utoipa::ToSchema;
/// Body of a chat completion request: the conversation accumulated so far.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({
    "messages": [
        Message { role: "user".to_owned(), content: "What is tail recursion?".to_owned()},
        Message { role: "assistant".to_owned(), content: "It's a kind of optimization in compiler?".to_owned()},
        Message { role: "user".to_owned(), content: "Could you share more details?".to_owned()},
    ]
}))]
pub struct ChatCompletionRequest {
    // Conversation turns; rendered into a single prompt string by
    // `ChatPromptBuilder` (see `ChatService::parse_request`).
    messages: Vec<Message>,
}
/// A single conversation turn.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct Message {
    // Speaker of the turn; the schema example above uses "user" and "assistant".
    role: String,
    // Text content of the turn.
    content: String,
}
/// One streamed fragment of the model's reply, as yielded by
/// `ChatService::generate`.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct ChatCompletionChunk {
    // Text produced by the generation stream for this chunk.
    content: String,
}
/// Chat service extracted from the HTTP layer: renders a conversation into a
/// prompt and streams a reply from the underlying text-generation engine.
pub struct ChatService {
    // Shared handle to the inference engine.
    engine: Arc<Box<dyn TextGeneration>>,
    // Renders `Message` lists into a single prompt string from the
    // configured chat template.
    prompt_builder: ChatPromptBuilder,
}
impl ChatService {
    /// Creates a chat service backed by `engine`, rendering prompts with the
    /// given `chat_template`.
    pub fn new(engine: Arc<Box<dyn TextGeneration>>, chat_template: String) -> Self {
        Self {
            engine,
            prompt_builder: ChatPromptBuilder::new(chat_template),
        }
    }

    /// Renders the request's conversation into a prompt string and builds the
    /// fixed set of text-generation options used for chat completions.
    fn parse_request(&self, request: &ChatCompletionRequest) -> (String, TextGenerationOptions) {
        let mut builder = TextGenerationOptionsBuilder::default();
        builder
            // Hard-coded chat limits: 2048 input tokens, 1920 decoded tokens.
            .max_input_length(2048)
            .max_decoding_length(1920)
            .language(&EMPTY_LANGUAGE)
            // Low sampling temperature for near-deterministic chat answers.
            .sampling_temperature(0.1);
        (
            self.prompt_builder.build(&request.messages),
            // Every option is set to a fixed, valid value above, so building
            // cannot fail; state the invariant instead of a bare unwrap().
            builder
                .build()
                .expect("chat text-generation options are always valid"),
        )
    }

    /// Streams the model's reply to `request` as `ChatCompletionChunk`s.
    ///
    /// The returned stream borrows `self` (the engine) and owns the rendered
    /// prompt; it yields one chunk per piece of generated content.
    pub async fn generate(
        &self,
        request: &ChatCompletionRequest,
    ) -> BoxStream<ChatCompletionChunk> {
        let (prompt, options) = self.parse_request(request);
        debug!("PROMPT: {}", prompt);

        let s = stream! {
            for await content in self.engine.generate_stream(&prompt, options).await {
                yield ChatCompletionChunk { content }
            }
        };

        Box::pin(s)
    }
}

View File

@ -1,3 +1,4 @@
mod chat;
mod download;
mod search;
mod serve;

View File

@ -1,5 +1,3 @@
mod prompt;
use std::sync::Arc;
use async_stream::stream;
@ -9,49 +7,9 @@ use axum::{
Json,
};
use axum_streams::StreamBodyAs;
use prompt::ChatPromptBuilder;
use serde::{Deserialize, Serialize};
use tabby_common::languages::EMPTY_LANGUAGE;
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
use tracing::{debug, instrument};
use utoipa::ToSchema;
use tracing::instrument;
// Pre-refactor state type removed by this commit (logic now lives in
// `crate::chat::ChatService`).
pub struct ChatState {
    // Shared handle to the inference engine.
    engine: Arc<Box<dyn TextGeneration>>,
    // Renders `Message` lists into a single prompt string.
    prompt_builder: ChatPromptBuilder,
}

impl ChatState {
    /// Creates the handler state backed by `engine`, rendering prompts with
    /// the given `chat_template`.
    pub fn new(engine: Arc<Box<dyn TextGeneration>>, chat_template: String) -> Self {
        Self {
            engine,
            prompt_builder: ChatPromptBuilder::new(chat_template),
        }
    }
}
/// Body of a chat completion request: the conversation accumulated so far.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({
    "messages": [
        Message { role: "user".to_owned(), content: "What is tail recursion?".to_owned()},
        Message { role: "assistant".to_owned(), content: "It's a kind of optimization in compiler?".to_owned()},
        Message { role: "user".to_owned(), content: "Could you share more details?".to_owned()},
    ]
}))]
pub struct ChatCompletionRequest {
    // Conversation turns, rendered into a single prompt by the prompt builder.
    messages: Vec<Message>,
}
/// A single conversation turn.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct Message {
    // Speaker of the turn; the schema example uses "user" and "assistant".
    role: String,
    // Text content of the turn.
    content: String,
}
/// One streamed fragment of the model's reply.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct ChatCompletionChunk {
    // Text produced by the generation stream for this chunk.
    content: String,
}
use crate::chat::{ChatCompletionRequest, ChatService};
#[utoipa::path(
post,
@ -66,34 +24,14 @@ pub struct ChatCompletionChunk {
)]
#[instrument(skip(state, request))]
pub async fn completions(
State(state): State<Arc<ChatState>>,
State(state): State<Arc<ChatService>>,
Json(request): Json<ChatCompletionRequest>,
) -> Response {
let (prompt, options) = parse_request(&state, request);
debug!("PROMPT: {}", prompt);
let s = stream! {
for await content in state.engine.generate_stream(&prompt, options).await {
yield ChatCompletionChunk { content }
for await content in state.generate(&request).await {
yield content;
}
};
StreamBodyAs::json_nl(s).into_response()
}
/// Renders the request's conversation into a prompt string and builds the
/// fixed set of text-generation options used for chat completions.
/// (Removed by this commit; now a method on `ChatService`.)
fn parse_request(
    state: &Arc<ChatState>,
    request: ChatCompletionRequest,
) -> (String, TextGenerationOptions) {
    let mut builder = TextGenerationOptionsBuilder::default();
    builder
        // Hard-coded chat limits: 2048 input tokens, 1920 decoded tokens.
        .max_input_length(2048)
        .max_decoding_length(1920)
        .language(&EMPTY_LANGUAGE)
        // Low sampling temperature for near-deterministic chat answers.
        .sampling_temperature(0.1);
    (
        state.prompt_builder.build(&request.messages),
        // All options above are fixed valid values, so build() cannot fail.
        builder.build().unwrap(),
    )
}

View File

@ -28,7 +28,7 @@ use self::{
engine::{create_engine, EngineInfo},
health::HealthState,
};
use crate::{fatal, search::CodeSearchService};
use crate::{chat::ChatService, fatal, search::CodeSearchService};
#[derive(OpenApi)]
#[openapi(
@ -57,9 +57,9 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
completions::Snippet,
completions::DebugOptions,
completions::DebugData,
chat::ChatCompletionRequest,
chat::Message,
chat::ChatCompletionChunk,
crate::chat::ChatCompletionRequest,
crate::chat::Message,
crate::chat::ChatCompletionChunk,
health::HealthState,
health::Version,
crate::search::SearchResponse,
@ -189,7 +189,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
panic!("Chat model requires specifying prompt template");
};
let engine = Arc::new(engine);
let state = chat::ChatState::new(engine, chat_template);
let state = ChatService::new(engine, chat_template);
Some(Arc::new(state))
} else {
None