refactor: extract language related data into languages.rs (#518)
* refactor: extract language related data into languages.rs * fix * cleanup index * fix * further sanitize * add a score thresholdwsxiaoys-patch-1
parent
d85a7892d1
commit
8c09f75360
|
|
@ -5,7 +5,7 @@ use regex::Regex;
|
|||
use tokenizers::tokenizer::Tokenizer;
|
||||
|
||||
pub struct DecodingFactory {
|
||||
stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
|
||||
stop_regex_cache: DashMap<&'static [&'static str], Regex>,
|
||||
}
|
||||
|
||||
fn reverse<T>(s: T) -> String
|
||||
|
|
@ -28,12 +28,12 @@ impl DecodingFactory {
|
|||
&self,
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
input_token_ids: &[u32],
|
||||
stop_words: &'static Vec<&'static str>,
|
||||
stop_words: &'static [&'static str],
|
||||
) -> IncrementalDecoding {
|
||||
IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
|
||||
}
|
||||
|
||||
fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
|
||||
fn get_re(&self, stop_words: &'static [&'static str]) -> Option<Regex> {
|
||||
if stop_words.is_empty() {
|
||||
None
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ pub struct TextGenerationOptions {
|
|||
pub sampling_temperature: f32,
|
||||
|
||||
#[builder(default = "&EMPTY_STOP_WORDS")]
|
||||
pub stop_words: &'static Vec<&'static str>,
|
||||
pub stop_words: &'static [&'static str],
|
||||
}
|
||||
|
||||
static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ use tantivy::{
|
|||
// Magic numbers
|
||||
static MAX_LINE_LENGTH_THRESHOLD: usize = 300;
|
||||
static AVG_LINE_LENGTH_THRESHOLD: f32 = 150f32;
|
||||
static MAX_BODY_LINES_THRESHOLD: usize = 15;
|
||||
|
||||
pub fn index_repositories(_config: &Config) -> Result<()> {
|
||||
let mut builder = Schema::builder();
|
||||
|
|
@ -82,19 +83,23 @@ struct IndexedDocument {
|
|||
}
|
||||
|
||||
fn from_source_file(file: SourceFile) -> impl Iterator<Item = IndexedDocument> {
|
||||
file.tags.into_iter().map(move |tag| {
|
||||
file.tags.into_iter().filter_map(move |tag| {
|
||||
let name = file.content.get(tag.name_range).unwrap().to_owned();
|
||||
let body = file.content.get(tag.range).unwrap().to_owned();
|
||||
|
||||
if body.lines().collect::<Vec<_>>().len() > MAX_BODY_LINES_THRESHOLD {
|
||||
return None;
|
||||
}
|
||||
|
||||
let language = reduce_language_if_needed(&file.language).to_owned();
|
||||
IndexedDocument {
|
||||
Some(IndexedDocument {
|
||||
git_url: file.git_url.clone(),
|
||||
filepath: file.filepath.clone(),
|
||||
language,
|
||||
name,
|
||||
body,
|
||||
kind: tag.syntax_type_name,
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -126,7 +131,7 @@ mod tests {
|
|||
{
|
||||
"range": {
|
||||
"start": 290,
|
||||
"end": 3094
|
||||
"end": 320
|
||||
},
|
||||
"name_range": {
|
||||
"start": 296,
|
||||
|
|
@ -142,7 +147,7 @@ mod tests {
|
|||
{
|
||||
"range": {
|
||||
"start": 953,
|
||||
"end": 1507
|
||||
"end": 970
|
||||
},
|
||||
"name_range": {
|
||||
"start": 957,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
|
|||
use tracing::{debug, instrument};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use self::languages::get_stop_words;
|
||||
use self::languages::get_language;
|
||||
use super::search::IndexServer;
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
|
|
@ -81,7 +81,7 @@ pub async fn completions(
|
|||
.max_input_length(1024 + 512)
|
||||
.max_decoding_length(128)
|
||||
.sampling_temperature(0.1)
|
||||
.stop_words(get_stop_words(&language))
|
||||
.stop_words(get_language(&language).stop_words)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
pub struct Language {
|
||||
pub stop_words: &'static [&'static str],
|
||||
pub line_comment: &'static str,
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
static ref DEFAULT: Vec<&'static str> = vec![
|
||||
"\n\n",
|
||||
|
|
@ -20,29 +23,48 @@ lazy_static! {
|
|||
"\n\n\t\t\t\t\t\t",
|
||||
"\n\n\t\t\t\t\t\t\t",
|
||||
];
|
||||
static ref LANGUAGES: HashMap<&'static str, Vec<&'static str>> = {
|
||||
let mut map = HashMap::new();
|
||||
map.insert(
|
||||
"python",
|
||||
vec!["\ndef", "\n#", "\nfrom", "\nclass"].with_default(),
|
||||
);
|
||||
map.insert(
|
||||
"javascript",
|
||||
vec!["\nfunction", "\n//", "\nimport", "\nclass"],
|
||||
);
|
||||
map.insert(
|
||||
"typescript",
|
||||
vec![
|
||||
"\nfunction",
|
||||
"\n//",
|
||||
"\nimport",
|
||||
"\nclass",
|
||||
"\ninterface",
|
||||
"\ntype",
|
||||
],
|
||||
);
|
||||
map
|
||||
static ref UNKONWN: Language = Language {
|
||||
stop_words: &DEFAULT,
|
||||
line_comment: "#"
|
||||
};
|
||||
static ref PYTHON_STOP_WORDS: Vec<&'static str> =
|
||||
vec!["\ndef", "\n#", "\nfrom", "\nclass"].with_default();
|
||||
static ref PYTHON: Language = Language {
|
||||
stop_words: &PYTHON_STOP_WORDS,
|
||||
line_comment: "#",
|
||||
};
|
||||
static ref RUST_STOP_WORDS: Vec<&'static str> =
|
||||
vec!["\ndef", "\n#", "\nfrom", "\nclass"].with_default();
|
||||
static ref RUST: Language = Language {
|
||||
stop_words: &RUST_STOP_WORDS,
|
||||
line_comment: "//",
|
||||
};
|
||||
static ref JAVASCRIPT_STOP_WORDS: Vec<&'static str> =
|
||||
vec!["\ndef", "\n#", "\nfrom", "\nclass"].with_default();
|
||||
static ref JAVASCRIPT: Language = Language {
|
||||
stop_words: &JAVASCRIPT_STOP_WORDS,
|
||||
line_comment: "",
|
||||
};
|
||||
static ref TYPESCRIPT_STOP_WORDS: Vec<&'static str> =
|
||||
vec!["\ndef", "\n#", "\nfrom", "\nclass"].with_default();
|
||||
static ref TYPESCRIPT: Language = Language {
|
||||
stop_words: &TYPESCRIPT_STOP_WORDS,
|
||||
line_comment: "",
|
||||
};
|
||||
}
|
||||
|
||||
pub fn get_language(language: &str) -> &'static Language {
|
||||
if language == "python" {
|
||||
&PYTHON
|
||||
} else if language == "rust" {
|
||||
&RUST
|
||||
} else if language == "javascript" {
|
||||
&JAVASCRIPT
|
||||
} else if language == "typescript" {
|
||||
&TYPESCRIPT
|
||||
} else {
|
||||
&UNKONWN
|
||||
}
|
||||
}
|
||||
|
||||
trait WithDefault {
|
||||
|
|
@ -56,7 +78,3 @@ impl WithDefault for Vec<&'static str> {
|
|||
self
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_stop_words(language: &str) -> &'static Vec<&'static str> {
|
||||
LANGUAGES.get(language).unwrap_or(&DEFAULT)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
use std::{collections::HashMap, env, sync::Arc};
|
||||
use std::{env, sync::Arc};
|
||||
|
||||
use lazy_static::lazy_static;
|
||||
use strfmt::strfmt;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use super::Segments;
|
||||
use crate::serve::search::IndexServer;
|
||||
use crate::serve::{completions::languages::get_language, search::IndexServer};
|
||||
|
||||
static MAX_SNIPPETS_TO_FETCH: usize = 20;
|
||||
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 512;
|
||||
static SNIPPET_SCORE_THRESHOLD: f32 = 5.0;
|
||||
|
||||
pub struct PromptBuilder {
|
||||
prompt_template: Option<String>,
|
||||
|
|
@ -84,7 +84,7 @@ fn build_prefix(language: &str, prefix: &str, snippets: Vec<String>) -> String {
|
|||
return prefix.to_owned();
|
||||
}
|
||||
|
||||
let comment_char = LANGUAGE_LINE_COMMENT_CHAR.get(language).unwrap();
|
||||
let comment_char = get_language(language).line_comment;
|
||||
let mut lines: Vec<String> = vec![
|
||||
format!(
|
||||
"Below are some relevant {} snippets found in the repository:",
|
||||
|
|
@ -142,6 +142,10 @@ fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> V
|
|||
};
|
||||
|
||||
for hit in serp.hits {
|
||||
if hit.score < SNIPPET_SCORE_THRESHOLD {
|
||||
break;
|
||||
}
|
||||
|
||||
let body = hit.doc.body;
|
||||
|
||||
if text.contains(&body) {
|
||||
|
|
@ -161,15 +165,13 @@ fn sanitize_text(text: &str) -> String {
|
|||
|c: char| !c.is_ascii_digit() && !c.is_alphabetic() && c != '_' && c != '-',
|
||||
" ",
|
||||
);
|
||||
let tokens: Vec<&str> = x.split(' ').filter(|x| x.len() > 5).collect();
|
||||
let tokens: Vec<&str> = x
|
||||
.split(' ')
|
||||
.filter(|x| *x != "AND" && *x != "NOT" && *x != "OR" && x.len() > 5)
|
||||
.collect();
|
||||
tokens.join(" ")
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
static ref LANGUAGE_LINE_COMMENT_CHAR: HashMap<&'static str, &'static str> =
|
||||
HashMap::from([("python", "#"), ("rust", "//"),]);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
|||
|
|
@ -3,36 +3,18 @@ import requests
|
|||
import streamlit as st
|
||||
from typing import NamedTuple
|
||||
|
||||
class Doc(NamedTuple):
|
||||
name: str
|
||||
body: str
|
||||
score: float
|
||||
filepath: str
|
||||
|
||||
@staticmethod
|
||||
def from_json(json: dict):
|
||||
doc = json["doc"]
|
||||
return Doc(
|
||||
name=doc["name"][0],
|
||||
body=doc["body"][0],
|
||||
score=json["score"],
|
||||
filepath=doc["filepath"][0],
|
||||
)
|
||||
|
||||
# force wide mode
|
||||
st.set_page_config(layout="wide")
|
||||
|
||||
language = st.text_input("Language", "rust")
|
||||
|
||||
query = st.text_area("Query", "get")
|
||||
tokens = re.findall(r"\w+", query)
|
||||
tokens = [x for x in tokens if x != "AND" and x != "OR" and x != "NOT"]
|
||||
|
||||
query = "(" + " ".join(tokens) + ")" + " " + "AND language:" + language
|
||||
|
||||
if query:
|
||||
r = requests.get("http://localhost:8080/v1beta/search", params=dict(q=query))
|
||||
hits = r.json()["hits"]
|
||||
for x in hits:
|
||||
doc = Doc.from_json(x)
|
||||
st.write(doc.name + "@" + doc.filepath + " : " + str(doc.score))
|
||||
st.code(doc.body)
|
||||
st.write(x)
|
||||
Loading…
Reference in New Issue