feat: support prompt rewriting (#295)
* refactor: extract PromptBuilder * feat: load tantivy index in prompt builder * integrate with searcher * add enable_prompt_rewrite to control rewrite behavior * nit docs * limit 1 snippet per identifier * extract magic numbers
parent
207559b0a2
commit
4388fd0050
|
|
@ -2776,6 +2776,7 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
|
|||
name = "tabby"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"axum",
|
||||
"axum-tracing-opentelemetry",
|
||||
"clap",
|
||||
|
|
@ -2794,6 +2795,7 @@ dependencies = [
|
|||
"tabby-common",
|
||||
"tabby-download",
|
||||
"tabby-scheduler",
|
||||
"tantivy",
|
||||
"tokio",
|
||||
"tower",
|
||||
"tower-http 0.4.0",
|
||||
|
|
|
|||
|
|
@ -24,3 +24,4 @@ tracing = "0.1"
|
|||
tracing-subscriber = "0.3"
|
||||
anyhow = "1.0.71"
|
||||
serde-jsonlines = "0.4.0"
|
||||
tantivy = "0.19.2"
|
||||
|
|
|
|||
|
|
@ -8,9 +8,15 @@ use serde::Deserialize;
|
|||
|
||||
use crate::path::{config_file, repositories_dir};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Deserialize, Default)]
|
||||
pub struct Config {
|
||||
pub repositories: Vec<Repository>,
|
||||
pub experimental: Experimental,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Default)]
|
||||
pub struct Experimental {
|
||||
pub enable_prompt_rewrite: bool,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ anyhow = { workspace = true }
|
|||
filenamify = "0.1.0"
|
||||
job_scheduler = "1.2.1"
|
||||
tabby-common = { path = "../tabby-common" }
|
||||
tantivy = "0.19.2"
|
||||
tantivy = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tree-sitter-javascript = "0.20.0"
|
||||
tree-sitter-tags = "0.20.2"
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ pub async fn scheduler(now: bool) -> Result<()> {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use tabby_common::{
|
||||
config::{Config, Repository},
|
||||
config::{Config, Experimental, Repository},
|
||||
path::set_tabby_root,
|
||||
};
|
||||
use temp_testdir::*;
|
||||
|
|
@ -76,6 +76,7 @@ mod tests {
|
|||
repositories: vec![Repository {
|
||||
git_url: "https://github.com/TabbyML/interview-questions".to_owned(),
|
||||
}],
|
||||
experimental: Experimental::default(),
|
||||
};
|
||||
|
||||
repository::sync_repositories(&config).unwrap();
|
||||
|
|
|
|||
|
|
@ -30,6 +30,8 @@ opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
|
|||
opentelemetry-otlp = "0.11.0"
|
||||
axum-tracing-opentelemetry = "0.10.0"
|
||||
tracing-opentelemetry = "0.18.0"
|
||||
tantivy = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
|
||||
|
||||
[dependencies.uuid]
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ use opentelemetry::{
|
|||
KeyValue,
|
||||
};
|
||||
use opentelemetry_otlp::WithExportConfig;
|
||||
use tabby_common::config::Config;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer};
|
||||
|
||||
#[derive(Parser)]
|
||||
|
|
@ -47,8 +48,10 @@ async fn main() {
|
|||
let cli = Cli::parse();
|
||||
init_logging(cli.otlp_endpoint);
|
||||
|
||||
let config = Config::load().unwrap_or(Config::default());
|
||||
|
||||
match &cli.command {
|
||||
Commands::Serve(args) => serve::main(args).await,
|
||||
Commands::Serve(args) => serve::main(&config, args).await,
|
||||
Commands::Download(args) => download::main(args).await,
|
||||
#[cfg(feature = "scheduler")]
|
||||
Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,6 @@
|
|||
mod languages;
|
||||
mod prompt;
|
||||
|
||||
use std::{path::Path, sync::Arc};
|
||||
|
||||
use axum::{extract::State, Json};
|
||||
|
|
@ -6,16 +9,13 @@ use ctranslate2_bindings::{
|
|||
};
|
||||
use hyper::StatusCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use strfmt::{strfmt, strfmt_builder};
|
||||
use tabby_common::{events, path::ModelDir};
|
||||
use tracing::instrument;
|
||||
use tabby_common::{config::Config, events, path::ModelDir};
|
||||
use tracing::{debug, instrument};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use self::languages::get_stop_words;
|
||||
use crate::fatal;
|
||||
|
||||
mod languages;
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
#[schema(example=json!({
|
||||
"language": "python",
|
||||
|
|
@ -86,23 +86,20 @@ pub async fn completion(
|
|||
.build()
|
||||
.unwrap();
|
||||
|
||||
let prompt = if let Some(Segments { prefix, suffix }) = request.segments {
|
||||
if let (Some(prompt_template), Some(suffix)) = (&state.prompt_template, suffix) {
|
||||
if !suffix.is_empty() {
|
||||
strfmt!(prompt_template, prefix => prefix, suffix => suffix).unwrap()
|
||||
} else {
|
||||
prefix
|
||||
}
|
||||
} else {
|
||||
// If there's no prompt template, just use prefix.
|
||||
prefix
|
||||
}
|
||||
let segments = if let Some(segments) = request.segments {
|
||||
segments
|
||||
} else if let Some(prompt) = request.prompt {
|
||||
prompt
|
||||
Segments {
|
||||
prefix: prompt,
|
||||
suffix: None,
|
||||
}
|
||||
} else {
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
};
|
||||
|
||||
debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix);
|
||||
let prompt = state.prompt_builder.build(&language, segments);
|
||||
debug!("PROMPT: {}", prompt);
|
||||
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
|
||||
let text = state.engine.inference(&prompt, options).await;
|
||||
|
||||
|
|
@ -126,11 +123,11 @@ pub async fn completion(
|
|||
|
||||
pub struct CompletionState {
|
||||
engine: TextInferenceEngine,
|
||||
prompt_template: Option<String>,
|
||||
prompt_builder: prompt::PromptBuilder,
|
||||
}
|
||||
|
||||
impl CompletionState {
|
||||
pub fn new(args: &crate::serve::ServeArgs) -> Self {
|
||||
pub fn new(args: &crate::serve::ServeArgs, config: &Config) -> Self {
|
||||
let model_dir = get_model_dir(&args.model);
|
||||
let metadata = read_metadata(&model_dir);
|
||||
|
||||
|
|
@ -149,7 +146,10 @@ impl CompletionState {
|
|||
let engine = TextInferenceEngine::create(options);
|
||||
Self {
|
||||
engine,
|
||||
prompt_template: metadata.prompt_template,
|
||||
prompt_builder: prompt::PromptBuilder::new(
|
||||
metadata.prompt_template,
|
||||
config.experimental.enable_prompt_rewrite,
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,198 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use lazy_static::lazy_static;
|
||||
use strfmt::strfmt;
|
||||
use tabby_common::path::index_dir;
|
||||
use tantivy::{
|
||||
collector::TopDocs, query::QueryParser, schema::Field, Index, ReloadPolicy, Searcher,
|
||||
};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use super::Segments;
|
||||
|
||||
/// Maximum number of documents fetched from the index per query.
const MAX_SNIPPETS_TO_FETCH: usize = 20;
/// At most this many snippets are kept per distinct identifier (`name` field).
const MAX_SNIPPET_PER_NAME: u32 = 1;
/// Character budget for snippet text injected into the prompt prefix.
const MAX_SNIPPET_CHARS_IN_PROMPT: usize = 1024;
|
||||
|
||||
/// Builds the final inference prompt from request segments, optionally
/// enriching the prefix with related snippets looked up in the tantivy index.
pub struct PromptBuilder {
    /// Model-provided FIM template with `{prefix}` / `{suffix}` placeholders;
    /// `None` means the raw prefix is used as the prompt unchanged.
    prompt_template: Option<String>,
    /// Loaded index handles; `None` when prompt rewriting is disabled or the
    /// index failed to open at construction time.
    index: Option<IndexState>,
}
|
||||
|
||||
impl PromptBuilder {
|
||||
pub fn new(prompt_template: Option<String>, enable_prompt_rewrite: bool) -> Self {
|
||||
let index = if enable_prompt_rewrite {
|
||||
info!("Experimental feature `enable_prompt_rewrite` is enabled, loading index ...");
|
||||
let index = IndexState::new();
|
||||
if let Err(err) = &index {
|
||||
warn!("Failed to open index in {:?}: {:?}", index_dir(), err);
|
||||
}
|
||||
index.ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
PromptBuilder {
|
||||
prompt_template,
|
||||
index,
|
||||
}
|
||||
}
|
||||
|
||||
fn build_prompt(&self, prefix: String, suffix: String) -> String {
|
||||
if let Some(prompt_template) = &self.prompt_template {
|
||||
strfmt!(prompt_template, prefix => prefix, suffix => suffix).unwrap()
|
||||
} else {
|
||||
prefix
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build(&self, language: &str, segments: Segments) -> String {
|
||||
let segments = self.rewrite(language, segments);
|
||||
if let Some(suffix) = segments.suffix {
|
||||
self.build_prompt(segments.prefix, suffix)
|
||||
} else {
|
||||
self.build_prompt(segments.prefix, "".to_owned())
|
||||
}
|
||||
}
|
||||
|
||||
fn rewrite(&self, language: &str, segments: Segments) -> Segments {
|
||||
if let Some(index) = &self.index {
|
||||
rewrite_with_index(index, language, segments)
|
||||
} else {
|
||||
segments
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn rewrite_with_index(index: &IndexState, language: &str, segments: Segments) -> Segments {
|
||||
let snippets = collect_snippets(index, language, &segments.prefix);
|
||||
if snippets.is_empty() {
|
||||
segments
|
||||
} else {
|
||||
let prefix = build_prefix(language, &segments.prefix, snippets);
|
||||
Segments { prefix, ..segments }
|
||||
}
|
||||
}
|
||||
|
||||
fn build_prefix(language: &str, prefix: &str, snippets: Vec<String>) -> String {
|
||||
let comment_char = LANGUAGE_LINE_COMMENT_CHAR.get(language).unwrap();
|
||||
let mut lines: Vec<String> = vec![format!(
|
||||
"Below are some relevant {} snippets found in the repository:",
|
||||
language
|
||||
)];
|
||||
|
||||
let mut count_characters = 0;
|
||||
for (i, snippet) in snippets.iter().enumerate() {
|
||||
if count_characters + snippet.len() > MAX_SNIPPET_CHARS_IN_PROMPT {
|
||||
break;
|
||||
}
|
||||
|
||||
lines.push(format!("== Snippet {} ==", i + 1));
|
||||
for line in snippet.lines() {
|
||||
lines.push(line.to_owned());
|
||||
count_characters += line.len();
|
||||
}
|
||||
|
||||
count_characters += snippet.len();
|
||||
}
|
||||
|
||||
let commented_lines: Vec<String> = lines
|
||||
.iter()
|
||||
.map(|x| format!("{} {}", comment_char, x))
|
||||
.collect();
|
||||
let comments = commented_lines.join("\n");
|
||||
format!("{}\n{}", comments, prefix)
|
||||
}
|
||||
|
||||
fn collect_snippets(index: &IndexState, language: &str, text: &str) -> Vec<String> {
|
||||
let mut ret = Vec::new();
|
||||
let sanitized_text = sanitize_text(text);
|
||||
if sanitized_text.is_empty() {
|
||||
return ret;
|
||||
}
|
||||
|
||||
let query_text = format!(
|
||||
"language:{} AND kind:call AND ({})",
|
||||
language, sanitized_text
|
||||
);
|
||||
let query = match index.query_parser.parse_query(&query_text) {
|
||||
Ok(query) => query,
|
||||
Err(err) => {
|
||||
warn!("Failed to parse query: {}", err);
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
let top_docs = index
|
||||
.searcher
|
||||
.search(&query, &TopDocs::with_limit(MAX_SNIPPETS_TO_FETCH))
|
||||
.unwrap();
|
||||
|
||||
let mut names: HashMap<String, u32> = HashMap::new();
|
||||
for (_score, doc_address) in top_docs {
|
||||
let doc = index.searcher.doc(doc_address).unwrap();
|
||||
let name = doc
|
||||
.get_first(index.field_name)
|
||||
.and_then(|x| x.as_text())
|
||||
.unwrap();
|
||||
let count = *names.get(name).unwrap_or(&0);
|
||||
|
||||
// Max 1 snippet per identifier.
|
||||
if count >= MAX_SNIPPET_PER_NAME {
|
||||
continue;
|
||||
}
|
||||
|
||||
let body = doc
|
||||
.get_first(index.field_body)
|
||||
.and_then(|x| x.as_text())
|
||||
.unwrap();
|
||||
names.insert(name.to_owned(), count + 1);
|
||||
ret.push(body.to_owned());
|
||||
}
|
||||
|
||||
ret
|
||||
}
|
||||
|
||||
/// Replaces every character of `text` that is neither an ASCII digit nor
/// alphabetic with a space, producing a bag of identifier tokens safe to
/// embed inside a tantivy query string.
fn sanitize_text(text: &str) -> String {
    // The original additionally did `split(' ')` followed by `join(" ")`,
    // which is an exact identity transform (empty tokens between consecutive
    // spaces are preserved by both steps) — dropped as dead work.
    text.replace(|c: char| !c.is_ascii_digit() && !c.is_alphabetic(), " ")
}
|
||||
|
||||
/// Cached handles into the tantivy snippet index, built once by
/// `IndexState::new`.
struct IndexState {
    // Searcher obtained once from the index reader in `IndexState::new`.
    searcher: Searcher,
    // Parses query strings; defaults unqualified terms to the `name` field.
    query_parser: QueryParser,
    // Schema field holding the tagged identifier name.
    field_name: Field,
    // Schema field holding the snippet body text.
    field_body: Field,
}
|
||||
|
||||
impl IndexState {
|
||||
fn new() -> Result<IndexState> {
|
||||
let index = Index::open_in_dir(index_dir())?;
|
||||
let reader = index
|
||||
.reader_builder()
|
||||
.reload_policy(ReloadPolicy::OnCommit)
|
||||
.try_into()?;
|
||||
let field_name = index
|
||||
.schema()
|
||||
.get_field("name")
|
||||
.ok_or(anyhow!("Index doesn't have required field"))?;
|
||||
let field_body = index
|
||||
.schema()
|
||||
.get_field("body")
|
||||
.ok_or(anyhow!("Index doesn't have required field"))?;
|
||||
let query_parser = QueryParser::for_index(&index, vec![field_name]);
|
||||
Ok(Self {
|
||||
searcher: reader.searcher(),
|
||||
query_parser,
|
||||
field_name,
|
||||
field_body,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
lazy_static! {
    /// Maps a language identifier to its single-line comment token.
    /// Currently only `python` is supported for snippet injection.
    static ref LANGUAGE_LINE_COMMENT_CHAR: HashMap<&'static str, &'static str> =
        HashMap::from([("python", "#")]);
}
|
||||
|
|
@ -10,6 +10,7 @@ use std::{
|
|||
use axum::{routing, Router, Server};
|
||||
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
|
||||
use clap::Args;
|
||||
use tabby_common::config::Config;
|
||||
use tower_http::cors::CorsLayer;
|
||||
use tracing::info;
|
||||
use utoipa::OpenApi;
|
||||
|
|
@ -107,7 +108,7 @@ pub struct ServeArgs {
|
|||
compute_type: ComputeType,
|
||||
}
|
||||
|
||||
pub async fn main(args: &ServeArgs) {
|
||||
pub async fn main(config: &Config, args: &ServeArgs) {
|
||||
valid_args(args);
|
||||
|
||||
// Ensure model exists.
|
||||
|
|
@ -123,7 +124,7 @@ pub async fn main(args: &ServeArgs) {
|
|||
|
||||
let app = Router::new()
|
||||
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
|
||||
.nest("/v1", api_router(args))
|
||||
.nest("/v1", api_router(args, config))
|
||||
.fallback(fallback());
|
||||
|
||||
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
|
||||
|
|
@ -134,7 +135,7 @@ pub async fn main(args: &ServeArgs) {
|
|||
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
|
||||
}
|
||||
|
||||
fn api_router(args: &ServeArgs) -> Router {
|
||||
fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||
Router::new()
|
||||
.route("/events", routing::post(events::log_event))
|
||||
.route(
|
||||
|
|
@ -144,7 +145,7 @@ fn api_router(args: &ServeArgs) -> Router {
|
|||
.route(
|
||||
"/completions",
|
||||
routing::post(completions::completion)
|
||||
.with_state(Arc::new(completions::CompletionState::new(args))),
|
||||
.with_state(Arc::new(completions::CompletionState::new(args, config))),
|
||||
)
|
||||
.layer(CorsLayer::permissive())
|
||||
.layer(opentelemetry_tracing_layer())
|
||||
|
|
|
|||
Loading…
Reference in New Issue