refactor: extract ChatState -> ChatService (#730)
parent
72d1d9f0bb
commit
b51520062a
|
|
@ -4062,6 +4062,7 @@ dependencies = [
|
|||
"axum-streams",
|
||||
"axum-tracing-opentelemetry",
|
||||
"clap 4.4.7",
|
||||
"futures",
|
||||
"http-api-bindings",
|
||||
"hyper",
|
||||
"lazy_static",
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ textdistance = "1.0.2"
|
|||
regex.workspace = true
|
||||
thiserror.workspace = true
|
||||
llama-cpp-bindings = { path = "../llama-cpp-bindings" }
|
||||
futures.workspace = true
|
||||
|
||||
[dependencies.uuid]
|
||||
version = "1.3.3"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,79 @@
|
|||
mod prompt;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_stream::stream;
|
||||
use futures::stream::BoxStream;
|
||||
use prompt::ChatPromptBuilder;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tabby_common::languages::EMPTY_LANGUAGE;
|
||||
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
|
||||
use tracing::debug;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
/// Request body for the chat completions endpoint: the conversation
/// history the model should continue (see the schema example below).
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({
    "messages": [
        Message { role: "user".to_owned(), content: "What is tail recursion?".to_owned()},
        Message { role: "assistant".to_owned(), content: "It's a kind of optimization in compiler?".to_owned()},
        Message { role: "user".to_owned(), content: "Could you share more details?".to_owned()},
    ]
}))]
pub struct ChatCompletionRequest {
    // Full conversation so far; rendered into a single prompt by
    // `ChatService::parse_request` via the prompt builder.
    messages: Vec<Message>,
}
|
||||
|
||||
/// A single turn in a chat conversation.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct Message {
    // Speaker tag for this turn, e.g. "user" or "assistant".
    role: String,
    // The text of the turn.
    content: String,
}
|
||||
|
||||
/// One streamed piece of the generated chat completion; the client
/// concatenates `content` fields to reconstruct the full response.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct ChatCompletionChunk {
    content: String,
}
|
||||
|
||||
/// Chat completion service: renders a conversation into a model prompt
/// and streams generated chunks from the text-generation engine.
pub struct ChatService {
    // Shared inference engine. NOTE(review): `Arc<Box<dyn …>>` is a
    // double indirection; `Arc<dyn TextGeneration>` would suffice, but
    // changing it would alter the public `new` signature.
    engine: Arc<Box<dyn TextGeneration>>,
    // Renders `Vec<Message>` into a single prompt string using the
    // chat template supplied at construction time.
    prompt_builder: ChatPromptBuilder,
}
|
||||
|
||||
impl ChatService {
|
||||
pub fn new(engine: Arc<Box<dyn TextGeneration>>, chat_template: String) -> Self {
|
||||
Self {
|
||||
engine,
|
||||
prompt_builder: ChatPromptBuilder::new(chat_template),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_request(&self, request: &ChatCompletionRequest) -> (String, TextGenerationOptions) {
|
||||
let mut builder = TextGenerationOptionsBuilder::default();
|
||||
|
||||
builder
|
||||
.max_input_length(2048)
|
||||
.max_decoding_length(1920)
|
||||
.language(&EMPTY_LANGUAGE)
|
||||
.sampling_temperature(0.1);
|
||||
|
||||
(
|
||||
self.prompt_builder.build(&request.messages),
|
||||
builder.build().unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn generate(
|
||||
&self,
|
||||
request: &ChatCompletionRequest,
|
||||
) -> BoxStream<ChatCompletionChunk> {
|
||||
let (prompt, options) = self.parse_request(request);
|
||||
debug!("PROMPT: {}", prompt);
|
||||
let s = stream! {
|
||||
for await content in self.engine.generate_stream(&prompt, options).await {
|
||||
yield ChatCompletionChunk { content }
|
||||
}
|
||||
};
|
||||
|
||||
Box::pin(s)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
mod chat;
|
||||
mod download;
|
||||
mod search;
|
||||
mod serve;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
mod prompt;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_stream::stream;
|
||||
|
|
@ -9,49 +7,9 @@ use axum::{
|
|||
Json,
|
||||
};
|
||||
use axum_streams::StreamBodyAs;
|
||||
use prompt::ChatPromptBuilder;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tabby_common::languages::EMPTY_LANGUAGE;
|
||||
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
|
||||
use tracing::{debug, instrument};
|
||||
use utoipa::ToSchema;
|
||||
use tracing::instrument;
|
||||
|
||||
pub struct ChatState {
|
||||
engine: Arc<Box<dyn TextGeneration>>,
|
||||
prompt_builder: ChatPromptBuilder,
|
||||
}
|
||||
|
||||
impl ChatState {
|
||||
pub fn new(engine: Arc<Box<dyn TextGeneration>>, chat_template: String) -> Self {
|
||||
Self {
|
||||
engine,
|
||||
prompt_builder: ChatPromptBuilder::new(chat_template),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
#[schema(example=json!({
|
||||
"messages": [
|
||||
Message { role: "user".to_owned(), content: "What is tail recursion?".to_owned()},
|
||||
Message { role: "assistant".to_owned(), content: "It's a kind of optimization in compiler?".to_owned()},
|
||||
Message { role: "user".to_owned(), content: "Could you share more details?".to_owned()},
|
||||
]
|
||||
}))]
|
||||
pub struct ChatCompletionRequest {
|
||||
messages: Vec<Message>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
pub struct Message {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
pub struct ChatCompletionChunk {
|
||||
content: String,
|
||||
}
|
||||
use crate::chat::{ChatCompletionRequest, ChatService};
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
|
|
@ -66,34 +24,14 @@ pub struct ChatCompletionChunk {
|
|||
)]
|
||||
#[instrument(skip(state, request))]
|
||||
pub async fn completions(
|
||||
State(state): State<Arc<ChatState>>,
|
||||
State(state): State<Arc<ChatService>>,
|
||||
Json(request): Json<ChatCompletionRequest>,
|
||||
) -> Response {
|
||||
let (prompt, options) = parse_request(&state, request);
|
||||
debug!("PROMPT: {}", prompt);
|
||||
let s = stream! {
|
||||
for await content in state.engine.generate_stream(&prompt, options).await {
|
||||
yield ChatCompletionChunk { content }
|
||||
for await content in state.generate(&request).await {
|
||||
yield content;
|
||||
}
|
||||
};
|
||||
|
||||
StreamBodyAs::json_nl(s).into_response()
|
||||
}
|
||||
|
||||
fn parse_request(
|
||||
state: &Arc<ChatState>,
|
||||
request: ChatCompletionRequest,
|
||||
) -> (String, TextGenerationOptions) {
|
||||
let mut builder = TextGenerationOptionsBuilder::default();
|
||||
|
||||
builder
|
||||
.max_input_length(2048)
|
||||
.max_decoding_length(1920)
|
||||
.language(&EMPTY_LANGUAGE)
|
||||
.sampling_temperature(0.1);
|
||||
|
||||
(
|
||||
state.prompt_builder.build(&request.messages),
|
||||
builder.build().unwrap(),
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ use self::{
|
|||
engine::{create_engine, EngineInfo},
|
||||
health::HealthState,
|
||||
};
|
||||
use crate::{fatal, search::CodeSearchService};
|
||||
use crate::{chat::ChatService, fatal, search::CodeSearchService};
|
||||
|
||||
#[derive(OpenApi)]
|
||||
#[openapi(
|
||||
|
|
@ -57,9 +57,9 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
|
|||
completions::Snippet,
|
||||
completions::DebugOptions,
|
||||
completions::DebugData,
|
||||
chat::ChatCompletionRequest,
|
||||
chat::Message,
|
||||
chat::ChatCompletionChunk,
|
||||
crate::chat::ChatCompletionRequest,
|
||||
crate::chat::Message,
|
||||
crate::chat::ChatCompletionChunk,
|
||||
health::HealthState,
|
||||
health::Version,
|
||||
crate::search::SearchResponse,
|
||||
|
|
@ -189,7 +189,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
|||
panic!("Chat model requires specifying prompt template");
|
||||
};
|
||||
let engine = Arc::new(engine);
|
||||
let state = chat::ChatState::new(engine, chat_template);
|
||||
let state = ChatService::new(engine, chat_template);
|
||||
Some(Arc::new(state))
|
||||
} else {
|
||||
None
|
||||
|
|
|
|||
Loading…
Reference in New Issue