refactor: extract ChatState -> ChatService (#730)
parent
72d1d9f0bb
commit
b51520062a
|
|
@ -4062,6 +4062,7 @@ dependencies = [
|
||||||
"axum-streams",
|
"axum-streams",
|
||||||
"axum-tracing-opentelemetry",
|
"axum-tracing-opentelemetry",
|
||||||
"clap 4.4.7",
|
"clap 4.4.7",
|
||||||
|
"futures",
|
||||||
"http-api-bindings",
|
"http-api-bindings",
|
||||||
"hyper",
|
"hyper",
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
|
|
|
||||||
|
|
@ -44,6 +44,7 @@ textdistance = "1.0.2"
|
||||||
regex.workspace = true
|
regex.workspace = true
|
||||||
thiserror.workspace = true
|
thiserror.workspace = true
|
||||||
llama-cpp-bindings = { path = "../llama-cpp-bindings" }
|
llama-cpp-bindings = { path = "../llama-cpp-bindings" }
|
||||||
|
futures.workspace = true
|
||||||
|
|
||||||
[dependencies.uuid]
|
[dependencies.uuid]
|
||||||
version = "1.3.3"
|
version = "1.3.3"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,79 @@
|
||||||
|
mod prompt;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use async_stream::stream;
|
||||||
|
use futures::stream::BoxStream;
|
||||||
|
use prompt::ChatPromptBuilder;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tabby_common::languages::EMPTY_LANGUAGE;
|
||||||
|
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
|
||||||
|
use tracing::debug;
|
||||||
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
|
#[schema(example=json!({
|
||||||
|
"messages": [
|
||||||
|
Message { role: "user".to_owned(), content: "What is tail recursion?".to_owned()},
|
||||||
|
Message { role: "assistant".to_owned(), content: "It's a kind of optimization in compiler?".to_owned()},
|
||||||
|
Message { role: "user".to_owned(), content: "Could you share more details?".to_owned()},
|
||||||
|
]
|
||||||
|
}))]
|
||||||
|
pub struct ChatCompletionRequest {
|
||||||
|
messages: Vec<Message>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
|
pub struct Message {
|
||||||
|
role: String,
|
||||||
|
content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
|
pub struct ChatCompletionChunk {
|
||||||
|
content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ChatService {
|
||||||
|
engine: Arc<Box<dyn TextGeneration>>,
|
||||||
|
prompt_builder: ChatPromptBuilder,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChatService {
|
||||||
|
pub fn new(engine: Arc<Box<dyn TextGeneration>>, chat_template: String) -> Self {
|
||||||
|
Self {
|
||||||
|
engine,
|
||||||
|
prompt_builder: ChatPromptBuilder::new(chat_template),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_request(&self, request: &ChatCompletionRequest) -> (String, TextGenerationOptions) {
|
||||||
|
let mut builder = TextGenerationOptionsBuilder::default();
|
||||||
|
|
||||||
|
builder
|
||||||
|
.max_input_length(2048)
|
||||||
|
.max_decoding_length(1920)
|
||||||
|
.language(&EMPTY_LANGUAGE)
|
||||||
|
.sampling_temperature(0.1);
|
||||||
|
|
||||||
|
(
|
||||||
|
self.prompt_builder.build(&request.messages),
|
||||||
|
builder.build().unwrap(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn generate(
|
||||||
|
&self,
|
||||||
|
request: &ChatCompletionRequest,
|
||||||
|
) -> BoxStream<ChatCompletionChunk> {
|
||||||
|
let (prompt, options) = self.parse_request(request);
|
||||||
|
debug!("PROMPT: {}", prompt);
|
||||||
|
let s = stream! {
|
||||||
|
for await content in self.engine.generate_stream(&prompt, options).await {
|
||||||
|
yield ChatCompletionChunk { content }
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Box::pin(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
mod chat;
|
||||||
mod download;
|
mod download;
|
||||||
mod search;
|
mod search;
|
||||||
mod serve;
|
mod serve;
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
mod prompt;
|
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use async_stream::stream;
|
use async_stream::stream;
|
||||||
|
|
@ -9,49 +7,9 @@ use axum::{
|
||||||
Json,
|
Json,
|
||||||
};
|
};
|
||||||
use axum_streams::StreamBodyAs;
|
use axum_streams::StreamBodyAs;
|
||||||
use prompt::ChatPromptBuilder;
|
use tracing::instrument;
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use tabby_common::languages::EMPTY_LANGUAGE;
|
|
||||||
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
|
|
||||||
use tracing::{debug, instrument};
|
|
||||||
use utoipa::ToSchema;
|
|
||||||
|
|
||||||
pub struct ChatState {
|
use crate::chat::{ChatCompletionRequest, ChatService};
|
||||||
engine: Arc<Box<dyn TextGeneration>>,
|
|
||||||
prompt_builder: ChatPromptBuilder,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ChatState {
|
|
||||||
pub fn new(engine: Arc<Box<dyn TextGeneration>>, chat_template: String) -> Self {
|
|
||||||
Self {
|
|
||||||
engine,
|
|
||||||
prompt_builder: ChatPromptBuilder::new(chat_template),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
|
||||||
#[schema(example=json!({
|
|
||||||
"messages": [
|
|
||||||
Message { role: "user".to_owned(), content: "What is tail recursion?".to_owned()},
|
|
||||||
Message { role: "assistant".to_owned(), content: "It's a kind of optimization in compiler?".to_owned()},
|
|
||||||
Message { role: "user".to_owned(), content: "Could you share more details?".to_owned()},
|
|
||||||
]
|
|
||||||
}))]
|
|
||||||
pub struct ChatCompletionRequest {
|
|
||||||
messages: Vec<Message>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
|
||||||
pub struct Message {
|
|
||||||
role: String,
|
|
||||||
content: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
|
||||||
pub struct ChatCompletionChunk {
|
|
||||||
content: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[utoipa::path(
|
#[utoipa::path(
|
||||||
post,
|
post,
|
||||||
|
|
@ -66,34 +24,14 @@ pub struct ChatCompletionChunk {
|
||||||
)]
|
)]
|
||||||
#[instrument(skip(state, request))]
|
#[instrument(skip(state, request))]
|
||||||
pub async fn completions(
|
pub async fn completions(
|
||||||
State(state): State<Arc<ChatState>>,
|
State(state): State<Arc<ChatService>>,
|
||||||
Json(request): Json<ChatCompletionRequest>,
|
Json(request): Json<ChatCompletionRequest>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
let (prompt, options) = parse_request(&state, request);
|
|
||||||
debug!("PROMPT: {}", prompt);
|
|
||||||
let s = stream! {
|
let s = stream! {
|
||||||
for await content in state.engine.generate_stream(&prompt, options).await {
|
for await content in state.generate(&request).await {
|
||||||
yield ChatCompletionChunk { content }
|
yield content;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
StreamBodyAs::json_nl(s).into_response()
|
StreamBodyAs::json_nl(s).into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_request(
|
|
||||||
state: &Arc<ChatState>,
|
|
||||||
request: ChatCompletionRequest,
|
|
||||||
) -> (String, TextGenerationOptions) {
|
|
||||||
let mut builder = TextGenerationOptionsBuilder::default();
|
|
||||||
|
|
||||||
builder
|
|
||||||
.max_input_length(2048)
|
|
||||||
.max_decoding_length(1920)
|
|
||||||
.language(&EMPTY_LANGUAGE)
|
|
||||||
.sampling_temperature(0.1);
|
|
||||||
|
|
||||||
(
|
|
||||||
state.prompt_builder.build(&request.messages),
|
|
||||||
builder.build().unwrap(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ use self::{
|
||||||
engine::{create_engine, EngineInfo},
|
engine::{create_engine, EngineInfo},
|
||||||
health::HealthState,
|
health::HealthState,
|
||||||
};
|
};
|
||||||
use crate::{fatal, search::CodeSearchService};
|
use crate::{chat::ChatService, fatal, search::CodeSearchService};
|
||||||
|
|
||||||
#[derive(OpenApi)]
|
#[derive(OpenApi)]
|
||||||
#[openapi(
|
#[openapi(
|
||||||
|
|
@ -57,9 +57,9 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
|
||||||
completions::Snippet,
|
completions::Snippet,
|
||||||
completions::DebugOptions,
|
completions::DebugOptions,
|
||||||
completions::DebugData,
|
completions::DebugData,
|
||||||
chat::ChatCompletionRequest,
|
crate::chat::ChatCompletionRequest,
|
||||||
chat::Message,
|
crate::chat::Message,
|
||||||
chat::ChatCompletionChunk,
|
crate::chat::ChatCompletionChunk,
|
||||||
health::HealthState,
|
health::HealthState,
|
||||||
health::Version,
|
health::Version,
|
||||||
crate::search::SearchResponse,
|
crate::search::SearchResponse,
|
||||||
|
|
@ -189,7 +189,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
panic!("Chat model requires specifying prompt template");
|
panic!("Chat model requires specifying prompt template");
|
||||||
};
|
};
|
||||||
let engine = Arc::new(engine);
|
let engine = Arc::new(engine);
|
||||||
let state = chat::ChatState::new(engine, chat_template);
|
let state = ChatService::new(engine, chat_template);
|
||||||
Some(Arc::new(state))
|
Some(Arc::new(state))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue