diff --git a/crates/tabby-inference/src/decoding.rs b/crates/tabby-inference/src/decoding.rs index ffefe94..77bf1e2 100644 --- a/crates/tabby-inference/src/decoding.rs +++ b/crates/tabby-inference/src/decoding.rs @@ -5,7 +5,7 @@ use regex::Regex; use tokenizers::tokenizer::Tokenizer; pub struct DecodingFactory { - stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>, + stop_regex_cache: DashMap<&'static [&'static str], Regex>, } fn reverse(s: T) -> String @@ -28,12 +28,12 @@ impl DecodingFactory { &self, tokenizer: Arc, input_token_ids: &[u32], - stop_words: &'static Vec<&'static str>, + stop_words: &'static [&'static str], ) -> IncrementalDecoding { IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids) } - fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option { + fn get_re(&self, stop_words: &'static [&'static str]) -> Option { if stop_words.is_empty() { None } else { diff --git a/crates/tabby-inference/src/lib.rs b/crates/tabby-inference/src/lib.rs index 495785e..28ab134 100644 --- a/crates/tabby-inference/src/lib.rs +++ b/crates/tabby-inference/src/lib.rs @@ -16,7 +16,7 @@ pub struct TextGenerationOptions { pub sampling_temperature: f32, #[builder(default = "&EMPTY_STOP_WORDS")] - pub stop_words: &'static Vec<&'static str>, + pub stop_words: &'static [&'static str], } static EMPTY_STOP_WORDS: Vec<&'static str> = vec![]; diff --git a/crates/tabby-scheduler/src/index.rs b/crates/tabby-scheduler/src/index.rs index feed157..da48760 100644 --- a/crates/tabby-scheduler/src/index.rs +++ b/crates/tabby-scheduler/src/index.rs @@ -17,6 +17,7 @@ use tantivy::{ // Magic numbers static MAX_LINE_LENGTH_THRESHOLD: usize = 300; static AVG_LINE_LENGTH_THRESHOLD: f32 = 150f32; +static MAX_BODY_LINES_THRESHOLD: usize = 15; pub fn index_repositories(_config: &Config) -> Result<()> { let mut builder = Schema::builder(); @@ -82,19 +83,23 @@ struct IndexedDocument { } fn from_source_file(file: SourceFile) -> impl Iterator { - file.tags.into_iter().map(move |tag| { + file.tags.into_iter().filter_map(move |tag| { let name = file.content.get(tag.name_range).unwrap().to_owned(); let body = file.content.get(tag.range).unwrap().to_owned(); + if body.lines().collect::>().len() > MAX_BODY_LINES_THRESHOLD { + return None; + } + let language = reduce_language_if_needed(&file.language).to_owned(); - IndexedDocument { + Some(IndexedDocument { git_url: file.git_url.clone(), filepath: file.filepath.clone(), language, name, body, kind: tag.syntax_type_name, - } + }) }) } @@ -126,7 +131,7 @@ mod tests { { "range": { "start": 290, - "end": 3094 + "end": 320 }, "name_range": { "start": 296, @@ -142,7 +147,7 @@ mod tests { { "range": { "start": 953, - "end": 1507 + "end": 970 }, "name_range": { "start": 957, diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index 8bfc16e..e0ab8be 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -11,7 +11,7 @@ use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder}; use tracing::{debug, instrument}; use utoipa::ToSchema; -use self::languages::get_stop_words; +use self::languages::get_language; use super::search::IndexServer; #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] @@ -81,7 +81,7 @@ pub async fn completions( .max_input_length(1024 + 512) .max_decoding_length(128) .sampling_temperature(0.1) - .stop_words(get_stop_words(&language)) + .stop_words(get_language(&language).stop_words) .build() .unwrap(); diff --git a/crates/tabby/src/serve/completions/languages.rs b/crates/tabby/src/serve/completions/languages.rs index 8dbe04f..2d11caa 100644 --- a/crates/tabby/src/serve/completions/languages.rs +++ b/crates/tabby/src/serve/completions/languages.rs @@ -1,7 +1,10 @@ -use std::collections::HashMap; - use lazy_static::lazy_static; +pub struct Language { + pub stop_words: &'static [&'static str], + pub line_comment: &'static str, +} + lazy_static! { static ref DEFAULT: Vec<&'static str> = vec![ "\n\n", @@ -20,29 +23,48 @@ lazy_static! { "\n\n\t\t\t\t\t\t", "\n\n\t\t\t\t\t\t\t", ]; - static ref LANGUAGES: HashMap<&'static str, Vec<&'static str>> = { - let mut map = HashMap::new(); - map.insert( - "python", - vec!["\ndef", "\n#", "\nfrom", "\nclass"].with_default(), - ); - map.insert( - "javascript", - vec!["\nfunction", "\n//", "\nimport", "\nclass"], - ); - map.insert( - "typescript", - vec![ - "\nfunction", - "\n//", - "\nimport", - "\nclass", - "\ninterface", - "\ntype", - ], - ); - map + static ref UNKONWN: Language = Language { + stop_words: &DEFAULT, + line_comment: "#" }; + static ref PYTHON_STOP_WORDS: Vec<&'static str> = + vec!["\ndef", "\n#", "\nfrom", "\nclass"].with_default(); + static ref PYTHON: Language = Language { + stop_words: &PYTHON_STOP_WORDS, + line_comment: "#", + }; + static ref RUST_STOP_WORDS: Vec<&'static str> = + vec!["\ndef", "\n#", "\nfrom", "\nclass"].with_default(); + static ref RUST: Language = Language { + stop_words: &RUST_STOP_WORDS, + line_comment: "//", + }; + static ref JAVASCRIPT_STOP_WORDS: Vec<&'static str> = + vec!["\ndef", "\n#", "\nfrom", "\nclass"].with_default(); + static ref JAVASCRIPT: Language = Language { + stop_words: &JAVASCRIPT_STOP_WORDS, + line_comment: "", + }; + static ref TYPESCRIPT_STOP_WORDS: Vec<&'static str> = + vec!["\ndef", "\n#", "\nfrom", "\nclass"].with_default(); + static ref TYPESCRIPT: Language = Language { + stop_words: &TYPESCRIPT_STOP_WORDS, + line_comment: "", + }; +} + +pub fn get_language(language: &str) -> &'static Language { + if language == "python" { + &PYTHON + } else if language == "rust" { + &RUST + } else if language == "javascript" { + &JAVASCRIPT + } else if language == "typescript" { + &TYPESCRIPT + } else { + &UNKONWN + } } trait WithDefault { @@ -56,7 +78,3 @@ impl WithDefault for Vec<&'static str> { self } } - -pub fn get_stop_words(language: &str) -> &'static Vec<&'static str> { - LANGUAGES.get(language).unwrap_or(&DEFAULT) -} diff --git a/crates/tabby/src/serve/completions/prompt.rs b/crates/tabby/src/serve/completions/prompt.rs index ff99a86..9bb9f21 100644 --- a/crates/tabby/src/serve/completions/prompt.rs +++ b/crates/tabby/src/serve/completions/prompt.rs @@ -1,14 +1,14 @@ -use std::{collections::HashMap, env, sync::Arc}; +use std::{env, sync::Arc}; -use lazy_static::lazy_static; use strfmt::strfmt; use tracing::{info, warn}; use super::Segments; -use crate::serve::search::IndexServer; +use crate::serve::{completions::languages::get_language, search::IndexServer}; static MAX_SNIPPETS_TO_FETCH: usize = 20; static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 512; +static SNIPPET_SCORE_THRESHOLD: f32 = 5.0; pub struct PromptBuilder { prompt_template: Option, @@ -84,7 +84,7 @@ fn build_prefix(language: &str, prefix: &str, snippets: Vec) -> String { return prefix.to_owned(); } - let comment_char = LANGUAGE_LINE_COMMENT_CHAR.get(language).unwrap(); + let comment_char = get_language(language).line_comment; let mut lines: Vec = vec![ format!( "Below are some relevant {} snippets found in the repository:", @@ -142,6 +142,10 @@ fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> V }; for hit in serp.hits { + if hit.score < SNIPPET_SCORE_THRESHOLD { + break; + } + let body = hit.doc.body; if text.contains(&body) { @@ -161,15 +165,13 @@ fn sanitize_text(text: &str) -> String { |c: char| !c.is_ascii_digit() && !c.is_alphabetic() && c != '_' && c != '-', " ", ); - let tokens: Vec<&str> = x.split(' ').filter(|x| x.len() > 5).collect(); + let tokens: Vec<&str> = x + .split(' ') + .filter(|x| *x != "AND" && *x != "NOT" && *x != "OR" && x.len() > 5) + .collect(); tokens.join(" ") } -lazy_static! { - static ref LANGUAGE_LINE_COMMENT_CHAR: HashMap<&'static str, &'static str> = - HashMap::from([("python", "#"), ("rust", "//"),]); -} - #[cfg(test)] mod tests { use super::*; diff --git a/experimental/scheduler/completion.py b/experimental/scheduler/completion.py index 1c00e0c..c3743c7 100644 --- a/experimental/scheduler/completion.py +++ b/experimental/scheduler/completion.py @@ -3,36 +3,18 @@ import requests import streamlit as st from typing import NamedTuple -class Doc(NamedTuple): - name: str - body: str - score: float - filepath: str - - @staticmethod - def from_json(json: dict): - doc = json["doc"] - return Doc( - name=doc["name"][0], - body=doc["body"][0], - score=json["score"], - filepath=doc["filepath"][0], - ) - # force wide mode st.set_page_config(layout="wide") language = st.text_input("Language", "rust") + query = st.text_area("Query", "get") tokens = re.findall(r"\w+", query) tokens = [x for x in tokens if x != "AND" and x != "OR" and x != "NOT"] - query = "(" + " ".join(tokens) + ")" + " " + "AND language:" + language if query: r = requests.get("http://localhost:8080/v1beta/search", params=dict(q=query)) hits = r.json()["hits"] for x in hits: - doc = Doc.from_json(x) - st.write(doc.name + "@" + doc.filepath + " : " + str(doc.score)) - st.code(doc.body) + st.write(x) \ No newline at end of file