refactor: extract language related data into languages.rs (#518)

* refactor: extract language related data into languages.rs

* fix

* cleanup index

* fix

* further sanitize

* add a score threshold
wsxiaoys-patch-1
Meng Zhang 2023-10-06 18:40:21 -07:00 committed by GitHub
parent d85a7892d1
commit 8c09f75360
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 76 additions and 69 deletions

View File

@ -5,7 +5,7 @@ use regex::Regex;
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
pub struct DecodingFactory { pub struct DecodingFactory {
stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>, stop_regex_cache: DashMap<&'static [&'static str], Regex>,
} }
fn reverse<T>(s: T) -> String fn reverse<T>(s: T) -> String
@ -28,12 +28,12 @@ impl DecodingFactory {
&self, &self,
tokenizer: Arc<Tokenizer>, tokenizer: Arc<Tokenizer>,
input_token_ids: &[u32], input_token_ids: &[u32],
stop_words: &'static Vec<&'static str>, stop_words: &'static [&'static str],
) -> IncrementalDecoding { ) -> IncrementalDecoding {
IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids) IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
} }
fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> { fn get_re(&self, stop_words: &'static [&'static str]) -> Option<Regex> {
if stop_words.is_empty() { if stop_words.is_empty() {
None None
} else { } else {

View File

@ -16,7 +16,7 @@ pub struct TextGenerationOptions {
pub sampling_temperature: f32, pub sampling_temperature: f32,
#[builder(default = "&EMPTY_STOP_WORDS")] #[builder(default = "&EMPTY_STOP_WORDS")]
pub stop_words: &'static Vec<&'static str>, pub stop_words: &'static [&'static str],
} }
static EMPTY_STOP_WORDS: Vec<&'static str> = vec![]; static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];

View File

@ -17,6 +17,7 @@ use tantivy::{
// Magic numbers // Magic numbers
static MAX_LINE_LENGTH_THRESHOLD: usize = 300; static MAX_LINE_LENGTH_THRESHOLD: usize = 300;
static AVG_LINE_LENGTH_THRESHOLD: f32 = 150f32; static AVG_LINE_LENGTH_THRESHOLD: f32 = 150f32;
static MAX_BODY_LINES_THRESHOLD: usize = 15;
pub fn index_repositories(_config: &Config) -> Result<()> { pub fn index_repositories(_config: &Config) -> Result<()> {
let mut builder = Schema::builder(); let mut builder = Schema::builder();
@ -82,19 +83,23 @@ struct IndexedDocument {
} }
fn from_source_file(file: SourceFile) -> impl Iterator<Item = IndexedDocument> { fn from_source_file(file: SourceFile) -> impl Iterator<Item = IndexedDocument> {
file.tags.into_iter().map(move |tag| { file.tags.into_iter().filter_map(move |tag| {
let name = file.content.get(tag.name_range).unwrap().to_owned(); let name = file.content.get(tag.name_range).unwrap().to_owned();
let body = file.content.get(tag.range).unwrap().to_owned(); let body = file.content.get(tag.range).unwrap().to_owned();
if body.lines().collect::<Vec<_>>().len() > MAX_BODY_LINES_THRESHOLD {
return None;
}
let language = reduce_language_if_needed(&file.language).to_owned(); let language = reduce_language_if_needed(&file.language).to_owned();
IndexedDocument { Some(IndexedDocument {
git_url: file.git_url.clone(), git_url: file.git_url.clone(),
filepath: file.filepath.clone(), filepath: file.filepath.clone(),
language, language,
name, name,
body, body,
kind: tag.syntax_type_name, kind: tag.syntax_type_name,
} })
}) })
} }
@ -126,7 +131,7 @@ mod tests {
{ {
"range": { "range": {
"start": 290, "start": 290,
"end": 3094 "end": 320
}, },
"name_range": { "name_range": {
"start": 296, "start": 296,
@ -142,7 +147,7 @@ mod tests {
{ {
"range": { "range": {
"start": 953, "start": 953,
"end": 1507 "end": 970
}, },
"name_range": { "name_range": {
"start": 957, "start": 957,

View File

@ -11,7 +11,7 @@ use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
use tracing::{debug, instrument}; use tracing::{debug, instrument};
use utoipa::ToSchema; use utoipa::ToSchema;
use self::languages::get_stop_words; use self::languages::get_language;
use super::search::IndexServer; use super::search::IndexServer;
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
@ -81,7 +81,7 @@ pub async fn completions(
.max_input_length(1024 + 512) .max_input_length(1024 + 512)
.max_decoding_length(128) .max_decoding_length(128)
.sampling_temperature(0.1) .sampling_temperature(0.1)
.stop_words(get_stop_words(&language)) .stop_words(get_language(&language).stop_words)
.build() .build()
.unwrap(); .unwrap();

View File

@ -1,7 +1,10 @@
use std::collections::HashMap;
use lazy_static::lazy_static; use lazy_static::lazy_static;
/// Static, per-language settings shared by completion and prompt building.
///
/// Instances are handed out as `&'static Language` by `get_language`, so every
/// field borrows `'static` data and the type is cheap to copy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Language {
    /// Stop words passed to the text-generation options for this language.
    pub stop_words: &'static [&'static str],
    /// Line-comment marker for this language (e.g. `#` or `//`); read by the
    /// prompt builder when it prepends repository snippets to the prefix.
    pub line_comment: &'static str,
}
lazy_static! { lazy_static! {
static ref DEFAULT: Vec<&'static str> = vec![ static ref DEFAULT: Vec<&'static str> = vec![
"\n\n", "\n\n",
@ -20,29 +23,48 @@ lazy_static! {
"\n\n\t\t\t\t\t\t", "\n\n\t\t\t\t\t\t",
"\n\n\t\t\t\t\t\t\t", "\n\n\t\t\t\t\t\t\t",
]; ];
static ref LANGUAGES: HashMap<&'static str, Vec<&'static str>> = { static ref UNKONWN: Language = Language {
let mut map = HashMap::new(); stop_words: &DEFAULT,
map.insert( line_comment: "#"
"python",
vec!["\ndef", "\n#", "\nfrom", "\nclass"].with_default(),
);
map.insert(
"javascript",
vec!["\nfunction", "\n//", "\nimport", "\nclass"],
);
map.insert(
"typescript",
vec![
"\nfunction",
"\n//",
"\nimport",
"\nclass",
"\ninterface",
"\ntype",
],
);
map
}; };
// Python: stop at the next top-level `def` / `class` / `from` statement or at
// a `#` comment line. `with_default()` comes from the `WithDefault` trait in
// this file — presumably it appends the shared `DEFAULT` entries; confirm
// against the trait impl.
static ref PYTHON_STOP_WORDS: Vec<&'static str> =
vec!["\ndef", "\n#", "\nfrom", "\nclass"].with_default();
// Language entry returned by `get_language("python")`.
static ref PYTHON: Language = Language {
stop_words: &PYTHON_STOP_WORDS,
line_comment: "#",
};
// Rust: stop at the start of the next top-level item or at a `//` comment
// line. The previous list was a copy of the Python one (`\ndef`, `\n#`,
// `\nfrom`, `\nclass`); none of those begin a Rust item, and `\n#` would also
// cut generation at attributes such as `#[derive(..)]`.
static ref RUST_STOP_WORDS: Vec<&'static str> = vec![
    "\nfn",
    "\n//",
    "\nuse",
    "\nmod",
    "\npub",
    "\nstruct",
    "\nenum",
    "\nimpl",
    "\ntrait",
]
.with_default();
// Language entry returned by `get_language("rust")`.
static ref RUST: Language = Language {
    stop_words: &RUST_STOP_WORDS,
    line_comment: "//",
};
// JavaScript: stop at the next top-level `function` / `import` / `class` or
// at a `//` comment line. This restores the list the old `LANGUAGES` map
// carried for "javascript"; the Python stop words had been pasted over it.
static ref JAVASCRIPT_STOP_WORDS: Vec<&'static str> =
    vec!["\nfunction", "\n//", "\nimport", "\nclass"].with_default();
// Language entry returned by `get_language("javascript")`.
// `line_comment` must be a real comment marker: it is read by the prompt
// builder for snippet lines, and an empty string would leave them uncommented.
static ref JAVASCRIPT: Language = Language {
    stop_words: &JAVASCRIPT_STOP_WORDS,
    line_comment: "//",
};
// TypeScript: the JavaScript stop words plus `interface` and `type`
// declarations. This restores the list the old `LANGUAGES` map carried for
// "typescript"; the Python stop words had been pasted over it.
static ref TYPESCRIPT_STOP_WORDS: Vec<&'static str> = vec![
    "\nfunction",
    "\n//",
    "\nimport",
    "\nclass",
    "\ninterface",
    "\ntype",
]
.with_default();
// Language entry returned by `get_language("typescript")`.
// `line_comment` must be a real comment marker: it is read by the prompt
// builder for snippet lines, and an empty string would leave them uncommented.
static ref TYPESCRIPT: Language = Language {
    stop_words: &TYPESCRIPT_STOP_WORDS,
    line_comment: "//",
};
}
pub fn get_language(language: &str) -> &'static Language {
if language == "python" {
&PYTHON
} else if language == "rust" {
&RUST
} else if language == "javascript" {
&JAVASCRIPT
} else if language == "typescript" {
&TYPESCRIPT
} else {
&UNKONWN
}
} }
trait WithDefault { trait WithDefault {
@ -56,7 +78,3 @@ impl WithDefault for Vec<&'static str> {
self self
} }
} }
pub fn get_stop_words(language: &str) -> &'static Vec<&'static str> {
LANGUAGES.get(language).unwrap_or(&DEFAULT)
}

View File

@ -1,14 +1,14 @@
use std::{collections::HashMap, env, sync::Arc}; use std::{env, sync::Arc};
use lazy_static::lazy_static;
use strfmt::strfmt; use strfmt::strfmt;
use tracing::{info, warn}; use tracing::{info, warn};
use super::Segments; use super::Segments;
use crate::serve::search::IndexServer; use crate::serve::{completions::languages::get_language, search::IndexServer};
static MAX_SNIPPETS_TO_FETCH: usize = 20; static MAX_SNIPPETS_TO_FETCH: usize = 20;
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 512; static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 512;
static SNIPPET_SCORE_THRESHOLD: f32 = 5.0;
pub struct PromptBuilder { pub struct PromptBuilder {
prompt_template: Option<String>, prompt_template: Option<String>,
@ -84,7 +84,7 @@ fn build_prefix(language: &str, prefix: &str, snippets: Vec<String>) -> String {
return prefix.to_owned(); return prefix.to_owned();
} }
let comment_char = LANGUAGE_LINE_COMMENT_CHAR.get(language).unwrap(); let comment_char = get_language(language).line_comment;
let mut lines: Vec<String> = vec![ let mut lines: Vec<String> = vec![
format!( format!(
"Below are some relevant {} snippets found in the repository:", "Below are some relevant {} snippets found in the repository:",
@ -142,6 +142,10 @@ fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> V
}; };
for hit in serp.hits { for hit in serp.hits {
if hit.score < SNIPPET_SCORE_THRESHOLD {
break;
}
let body = hit.doc.body; let body = hit.doc.body;
if text.contains(&body) { if text.contains(&body) {
@ -161,15 +165,13 @@ fn sanitize_text(text: &str) -> String {
|c: char| !c.is_ascii_digit() && !c.is_alphabetic() && c != '_' && c != '-', |c: char| !c.is_ascii_digit() && !c.is_alphabetic() && c != '_' && c != '-',
" ", " ",
); );
let tokens: Vec<&str> = x.split(' ').filter(|x| x.len() > 5).collect(); let tokens: Vec<&str> = x
.split(' ')
.filter(|x| *x != "AND" && *x != "NOT" && *x != "OR" && x.len() > 5)
.collect();
tokens.join(" ") tokens.join(" ")
} }
lazy_static! {
static ref LANGUAGE_LINE_COMMENT_CHAR: HashMap<&'static str, &'static str> =
HashMap::from([("python", "#"), ("rust", "//"),]);
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View File

@ -3,36 +3,18 @@ import requests
import streamlit as st import streamlit as st
from typing import NamedTuple from typing import NamedTuple
class Doc(NamedTuple):
name: str
body: str
score: float
filepath: str
@staticmethod
def from_json(json: dict):
doc = json["doc"]
return Doc(
name=doc["name"][0],
body=doc["body"][0],
score=json["score"],
filepath=doc["filepath"][0],
)
# force wide mode # force wide mode
st.set_page_config(layout="wide") st.set_page_config(layout="wide")
language = st.text_input("Language", "rust") language = st.text_input("Language", "rust")
query = st.text_area("Query", "get") query = st.text_area("Query", "get")
tokens = re.findall(r"\w+", query) tokens = re.findall(r"\w+", query)
tokens = [x for x in tokens if x != "AND" and x != "OR" and x != "NOT"] tokens = [x for x in tokens if x != "AND" and x != "OR" and x != "NOT"]
query = "(" + " ".join(tokens) + ")" + " " + "AND language:" + language query = "(" + " ".join(tokens) + ")" + " " + "AND language:" + language
if query: if query:
r = requests.get("http://localhost:8080/v1beta/search", params=dict(q=query)) r = requests.get("http://localhost:8080/v1beta/search", params=dict(q=query))
hits = r.json()["hits"] hits = r.json()["hits"]
for x in hits: for x in hits:
doc = Doc.from_json(x) st.write(x)
st.write(doc.name + "@" + doc.filepath + " : " + str(doc.score))
st.code(doc.body)