feat: support prompt rewriting (#295)
* refactor: extract PromptBuilder * feat: load tantivy index in prompt builder * integrate with searcher * add enable_prompt_rewrite to control rewrite behavior * nit docs * limit 1 snippet per identifier * extract magic numberssweep/improve-logging-information
parent
207559b0a2
commit
4388fd0050
|
|
@ -2776,6 +2776,7 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
|
||||||
name = "tabby"
|
name = "tabby"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
"axum",
|
"axum",
|
||||||
"axum-tracing-opentelemetry",
|
"axum-tracing-opentelemetry",
|
||||||
"clap",
|
"clap",
|
||||||
|
|
@ -2794,6 +2795,7 @@ dependencies = [
|
||||||
"tabby-common",
|
"tabby-common",
|
||||||
"tabby-download",
|
"tabby-download",
|
||||||
"tabby-scheduler",
|
"tabby-scheduler",
|
||||||
|
"tantivy",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tower",
|
"tower",
|
||||||
"tower-http 0.4.0",
|
"tower-http 0.4.0",
|
||||||
|
|
|
||||||
|
|
@ -24,3 +24,4 @@ tracing = "0.1"
|
||||||
tracing-subscriber = "0.3"
|
tracing-subscriber = "0.3"
|
||||||
anyhow = "1.0.71"
|
anyhow = "1.0.71"
|
||||||
serde-jsonlines = "0.4.0"
|
serde-jsonlines = "0.4.0"
|
||||||
|
tantivy = "0.19.2"
|
||||||
|
|
|
||||||
|
|
@ -8,9 +8,15 @@ use serde::Deserialize;
|
||||||
|
|
||||||
use crate::path::{config_file, repositories_dir};
|
use crate::path::{config_file, repositories_dir};
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize, Default)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
pub repositories: Vec<Repository>,
|
pub repositories: Vec<Repository>,
|
||||||
|
pub experimental: Experimental,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Default)]
|
||||||
|
pub struct Experimental {
|
||||||
|
pub enable_prompt_rewrite: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ anyhow = { workspace = true }
|
||||||
filenamify = "0.1.0"
|
filenamify = "0.1.0"
|
||||||
job_scheduler = "1.2.1"
|
job_scheduler = "1.2.1"
|
||||||
tabby-common = { path = "../tabby-common" }
|
tabby-common = { path = "../tabby-common" }
|
||||||
tantivy = "0.19.2"
|
tantivy = { workspace = true }
|
||||||
tracing = { workspace = true }
|
tracing = { workspace = true }
|
||||||
tree-sitter-javascript = "0.20.0"
|
tree-sitter-javascript = "0.20.0"
|
||||||
tree-sitter-tags = "0.20.2"
|
tree-sitter-tags = "0.20.2"
|
||||||
|
|
|
||||||
|
|
@ -59,7 +59,7 @@ pub async fn scheduler(now: bool) -> Result<()> {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use tabby_common::{
|
use tabby_common::{
|
||||||
config::{Config, Repository},
|
config::{Config, Experimental, Repository},
|
||||||
path::set_tabby_root,
|
path::set_tabby_root,
|
||||||
};
|
};
|
||||||
use temp_testdir::*;
|
use temp_testdir::*;
|
||||||
|
|
@ -76,6 +76,7 @@ mod tests {
|
||||||
repositories: vec![Repository {
|
repositories: vec![Repository {
|
||||||
git_url: "https://github.com/TabbyML/interview-questions".to_owned(),
|
git_url: "https://github.com/TabbyML/interview-questions".to_owned(),
|
||||||
}],
|
}],
|
||||||
|
experimental: Experimental::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
repository::sync_repositories(&config).unwrap();
|
repository::sync_repositories(&config).unwrap();
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,8 @@ opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
|
||||||
opentelemetry-otlp = "0.11.0"
|
opentelemetry-otlp = "0.11.0"
|
||||||
axum-tracing-opentelemetry = "0.10.0"
|
axum-tracing-opentelemetry = "0.10.0"
|
||||||
tracing-opentelemetry = "0.18.0"
|
tracing-opentelemetry = "0.18.0"
|
||||||
|
tantivy = { workspace = true }
|
||||||
|
anyhow = { workspace = true }
|
||||||
|
|
||||||
|
|
||||||
[dependencies.uuid]
|
[dependencies.uuid]
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ use opentelemetry::{
|
||||||
KeyValue,
|
KeyValue,
|
||||||
};
|
};
|
||||||
use opentelemetry_otlp::WithExportConfig;
|
use opentelemetry_otlp::WithExportConfig;
|
||||||
|
use tabby_common::config::Config;
|
||||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer};
|
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer};
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
|
|
@ -47,8 +48,10 @@ async fn main() {
|
||||||
let cli = Cli::parse();
|
let cli = Cli::parse();
|
||||||
init_logging(cli.otlp_endpoint);
|
init_logging(cli.otlp_endpoint);
|
||||||
|
|
||||||
|
let config = Config::load().unwrap_or(Config::default());
|
||||||
|
|
||||||
match &cli.command {
|
match &cli.command {
|
||||||
Commands::Serve(args) => serve::main(args).await,
|
Commands::Serve(args) => serve::main(&config, args).await,
|
||||||
Commands::Download(args) => download::main(args).await,
|
Commands::Download(args) => download::main(args).await,
|
||||||
#[cfg(feature = "scheduler")]
|
#[cfg(feature = "scheduler")]
|
||||||
Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now)
|
Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now)
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,6 @@
|
||||||
|
mod languages;
|
||||||
|
mod prompt;
|
||||||
|
|
||||||
use std::{path::Path, sync::Arc};
|
use std::{path::Path, sync::Arc};
|
||||||
|
|
||||||
use axum::{extract::State, Json};
|
use axum::{extract::State, Json};
|
||||||
|
|
@ -6,16 +9,13 @@ use ctranslate2_bindings::{
|
||||||
};
|
};
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use strfmt::{strfmt, strfmt_builder};
|
use tabby_common::{config::Config, events, path::ModelDir};
|
||||||
use tabby_common::{events, path::ModelDir};
|
use tracing::{debug, instrument};
|
||||||
use tracing::instrument;
|
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
use self::languages::get_stop_words;
|
use self::languages::get_stop_words;
|
||||||
use crate::fatal;
|
use crate::fatal;
|
||||||
|
|
||||||
mod languages;
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
#[schema(example=json!({
|
#[schema(example=json!({
|
||||||
"language": "python",
|
"language": "python",
|
||||||
|
|
@ -86,23 +86,20 @@ pub async fn completion(
|
||||||
.build()
|
.build()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let prompt = if let Some(Segments { prefix, suffix }) = request.segments {
|
let segments = if let Some(segments) = request.segments {
|
||||||
if let (Some(prompt_template), Some(suffix)) = (&state.prompt_template, suffix) {
|
segments
|
||||||
if !suffix.is_empty() {
|
|
||||||
strfmt!(prompt_template, prefix => prefix, suffix => suffix).unwrap()
|
|
||||||
} else {
|
|
||||||
prefix
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// If there's no prompt template, just use prefix.
|
|
||||||
prefix
|
|
||||||
}
|
|
||||||
} else if let Some(prompt) = request.prompt {
|
} else if let Some(prompt) = request.prompt {
|
||||||
prompt
|
Segments {
|
||||||
|
prefix: prompt,
|
||||||
|
suffix: None,
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
return Err(StatusCode::BAD_REQUEST);
|
return Err(StatusCode::BAD_REQUEST);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix);
|
||||||
|
let prompt = state.prompt_builder.build(&language, segments);
|
||||||
|
debug!("PROMPT: {}", prompt);
|
||||||
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
|
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
|
||||||
let text = state.engine.inference(&prompt, options).await;
|
let text = state.engine.inference(&prompt, options).await;
|
||||||
|
|
||||||
|
|
@ -126,11 +123,11 @@ pub async fn completion(
|
||||||
|
|
||||||
pub struct CompletionState {
|
pub struct CompletionState {
|
||||||
engine: TextInferenceEngine,
|
engine: TextInferenceEngine,
|
||||||
prompt_template: Option<String>,
|
prompt_builder: prompt::PromptBuilder,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CompletionState {
|
impl CompletionState {
|
||||||
pub fn new(args: &crate::serve::ServeArgs) -> Self {
|
pub fn new(args: &crate::serve::ServeArgs, config: &Config) -> Self {
|
||||||
let model_dir = get_model_dir(&args.model);
|
let model_dir = get_model_dir(&args.model);
|
||||||
let metadata = read_metadata(&model_dir);
|
let metadata = read_metadata(&model_dir);
|
||||||
|
|
||||||
|
|
@ -149,7 +146,10 @@ impl CompletionState {
|
||||||
let engine = TextInferenceEngine::create(options);
|
let engine = TextInferenceEngine::create(options);
|
||||||
Self {
|
Self {
|
||||||
engine,
|
engine,
|
||||||
prompt_template: metadata.prompt_template,
|
prompt_builder: prompt::PromptBuilder::new(
|
||||||
|
metadata.prompt_template,
|
||||||
|
config.experimental.enable_prompt_rewrite,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,198 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use lazy_static::lazy_static;
|
||||||
|
use strfmt::strfmt;
|
||||||
|
use tabby_common::path::index_dir;
|
||||||
|
use tantivy::{
|
||||||
|
collector::TopDocs, query::QueryParser, schema::Field, Index, ReloadPolicy, Searcher,
|
||||||
|
};
|
||||||
|
use tracing::{info, warn};
|
||||||
|
|
||||||
|
use super::Segments;
|
||||||
|
|
||||||
|
static MAX_SNIPPETS_TO_FETCH: usize = 20;
|
||||||
|
static MAX_SNIPPET_PER_NAME: u32 = 1;
|
||||||
|
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 1024;
|
||||||
|
|
||||||
|
pub struct PromptBuilder {
|
||||||
|
prompt_template: Option<String>,
|
||||||
|
index: Option<IndexState>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PromptBuilder {
|
||||||
|
pub fn new(prompt_template: Option<String>, enable_prompt_rewrite: bool) -> Self {
|
||||||
|
let index = if enable_prompt_rewrite {
|
||||||
|
info!("Experimental feature `enable_prompt_rewrite` is enabled, loading index ...");
|
||||||
|
let index = IndexState::new();
|
||||||
|
if let Err(err) = &index {
|
||||||
|
warn!("Failed to open index in {:?}: {:?}", index_dir(), err);
|
||||||
|
}
|
||||||
|
index.ok()
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
PromptBuilder {
|
||||||
|
prompt_template,
|
||||||
|
index,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_prompt(&self, prefix: String, suffix: String) -> String {
|
||||||
|
if let Some(prompt_template) = &self.prompt_template {
|
||||||
|
strfmt!(prompt_template, prefix => prefix, suffix => suffix).unwrap()
|
||||||
|
} else {
|
||||||
|
prefix
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build(&self, language: &str, segments: Segments) -> String {
|
||||||
|
let segments = self.rewrite(language, segments);
|
||||||
|
if let Some(suffix) = segments.suffix {
|
||||||
|
self.build_prompt(segments.prefix, suffix)
|
||||||
|
} else {
|
||||||
|
self.build_prompt(segments.prefix, "".to_owned())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rewrite(&self, language: &str, segments: Segments) -> Segments {
|
||||||
|
if let Some(index) = &self.index {
|
||||||
|
rewrite_with_index(index, language, segments)
|
||||||
|
} else {
|
||||||
|
segments
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rewrite_with_index(index: &IndexState, language: &str, segments: Segments) -> Segments {
|
||||||
|
let snippets = collect_snippets(index, language, &segments.prefix);
|
||||||
|
if snippets.is_empty() {
|
||||||
|
segments
|
||||||
|
} else {
|
||||||
|
let prefix = build_prefix(language, &segments.prefix, snippets);
|
||||||
|
Segments { prefix, ..segments }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_prefix(language: &str, prefix: &str, snippets: Vec<String>) -> String {
|
||||||
|
let comment_char = LANGUAGE_LINE_COMMENT_CHAR.get(language).unwrap();
|
||||||
|
let mut lines: Vec<String> = vec![format!(
|
||||||
|
"Below are some relevant {} snippets found in the repository:",
|
||||||
|
language
|
||||||
|
)];
|
||||||
|
|
||||||
|
let mut count_characters = 0;
|
||||||
|
for (i, snippet) in snippets.iter().enumerate() {
|
||||||
|
if count_characters + snippet.len() > MAX_SNIPPET_CHARS_IN_PROMPT {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
lines.push(format!("== Snippet {} ==", i + 1));
|
||||||
|
for line in snippet.lines() {
|
||||||
|
lines.push(line.to_owned());
|
||||||
|
count_characters += line.len();
|
||||||
|
}
|
||||||
|
|
||||||
|
count_characters += snippet.len();
|
||||||
|
}
|
||||||
|
|
||||||
|
let commented_lines: Vec<String> = lines
|
||||||
|
.iter()
|
||||||
|
.map(|x| format!("{} {}", comment_char, x))
|
||||||
|
.collect();
|
||||||
|
let comments = commented_lines.join("\n");
|
||||||
|
format!("{}\n{}", comments, prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn collect_snippets(index: &IndexState, language: &str, text: &str) -> Vec<String> {
|
||||||
|
let mut ret = Vec::new();
|
||||||
|
let sanitized_text = sanitize_text(text);
|
||||||
|
if sanitized_text.is_empty() {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
let query_text = format!(
|
||||||
|
"language:{} AND kind:call AND ({})",
|
||||||
|
language, sanitized_text
|
||||||
|
);
|
||||||
|
let query = match index.query_parser.parse_query(&query_text) {
|
||||||
|
Ok(query) => query,
|
||||||
|
Err(err) => {
|
||||||
|
warn!("Failed to parse query: {}", err);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let top_docs = index
|
||||||
|
.searcher
|
||||||
|
.search(&query, &TopDocs::with_limit(MAX_SNIPPETS_TO_FETCH))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let mut names: HashMap<String, u32> = HashMap::new();
|
||||||
|
for (_score, doc_address) in top_docs {
|
||||||
|
let doc = index.searcher.doc(doc_address).unwrap();
|
||||||
|
let name = doc
|
||||||
|
.get_first(index.field_name)
|
||||||
|
.and_then(|x| x.as_text())
|
||||||
|
.unwrap();
|
||||||
|
let count = *names.get(name).unwrap_or(&0);
|
||||||
|
|
||||||
|
// Max 1 snippet per identifier.
|
||||||
|
if count >= MAX_SNIPPET_PER_NAME {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let body = doc
|
||||||
|
.get_first(index.field_body)
|
||||||
|
.and_then(|x| x.as_text())
|
||||||
|
.unwrap();
|
||||||
|
names.insert(name.to_owned(), count + 1);
|
||||||
|
ret.push(body.to_owned());
|
||||||
|
}
|
||||||
|
|
||||||
|
ret
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sanitize_text(text: &str) -> String {
|
||||||
|
let x = text.replace(|c: char| !c.is_ascii_digit() && !c.is_alphabetic(), " ");
|
||||||
|
let tokens: Vec<&str> = x.split(' ').collect();
|
||||||
|
tokens.join(" ")
|
||||||
|
}
|
||||||
|
|
||||||
|
struct IndexState {
|
||||||
|
searcher: Searcher,
|
||||||
|
query_parser: QueryParser,
|
||||||
|
field_name: Field,
|
||||||
|
field_body: Field,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IndexState {
|
||||||
|
fn new() -> Result<IndexState> {
|
||||||
|
let index = Index::open_in_dir(index_dir())?;
|
||||||
|
let reader = index
|
||||||
|
.reader_builder()
|
||||||
|
.reload_policy(ReloadPolicy::OnCommit)
|
||||||
|
.try_into()?;
|
||||||
|
let field_name = index
|
||||||
|
.schema()
|
||||||
|
.get_field("name")
|
||||||
|
.ok_or(anyhow!("Index doesn't have required field"))?;
|
||||||
|
let field_body = index
|
||||||
|
.schema()
|
||||||
|
.get_field("body")
|
||||||
|
.ok_or(anyhow!("Index doesn't have required field"))?;
|
||||||
|
let query_parser = QueryParser::for_index(&index, vec![field_name]);
|
||||||
|
Ok(Self {
|
||||||
|
searcher: reader.searcher(),
|
||||||
|
query_parser,
|
||||||
|
field_name,
|
||||||
|
field_body,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lazy_static! {
|
||||||
|
static ref LANGUAGE_LINE_COMMENT_CHAR: HashMap<&'static str, &'static str> =
|
||||||
|
HashMap::from([("python", "#")]);
|
||||||
|
}
|
||||||
|
|
@ -10,6 +10,7 @@ use std::{
|
||||||
use axum::{routing, Router, Server};
|
use axum::{routing, Router, Server};
|
||||||
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
|
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
|
||||||
use clap::Args;
|
use clap::Args;
|
||||||
|
use tabby_common::config::Config;
|
||||||
use tower_http::cors::CorsLayer;
|
use tower_http::cors::CorsLayer;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
use utoipa::OpenApi;
|
use utoipa::OpenApi;
|
||||||
|
|
@ -107,7 +108,7 @@ pub struct ServeArgs {
|
||||||
compute_type: ComputeType,
|
compute_type: ComputeType,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn main(args: &ServeArgs) {
|
pub async fn main(config: &Config, args: &ServeArgs) {
|
||||||
valid_args(args);
|
valid_args(args);
|
||||||
|
|
||||||
// Ensure model exists.
|
// Ensure model exists.
|
||||||
|
|
@ -123,7 +124,7 @@ pub async fn main(args: &ServeArgs) {
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
|
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
|
||||||
.nest("/v1", api_router(args))
|
.nest("/v1", api_router(args, config))
|
||||||
.fallback(fallback());
|
.fallback(fallback());
|
||||||
|
|
||||||
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
|
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
|
||||||
|
|
@ -134,7 +135,7 @@ pub async fn main(args: &ServeArgs) {
|
||||||
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
|
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn api_router(args: &ServeArgs) -> Router {
|
fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/events", routing::post(events::log_event))
|
.route("/events", routing::post(events::log_event))
|
||||||
.route(
|
.route(
|
||||||
|
|
@ -144,7 +145,7 @@ fn api_router(args: &ServeArgs) -> Router {
|
||||||
.route(
|
.route(
|
||||||
"/completions",
|
"/completions",
|
||||||
routing::post(completions::completion)
|
routing::post(completions::completion)
|
||||||
.with_state(Arc::new(completions::CompletionState::new(args))),
|
.with_state(Arc::new(completions::CompletionState::new(args, config))),
|
||||||
)
|
)
|
||||||
.layer(CorsLayer::permissive())
|
.layer(CorsLayer::permissive())
|
||||||
.layer(opentelemetry_tracing_layer())
|
.layer(opentelemetry_tracing_layer())
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue