refactor: extract language related data into languages.rs (#518)
* refactor: extract language related data into languages.rs
* fix
* cleanup index
* fix
* further sanitize
* add a score threshold

wsxiaoys-patch-1
parent
d85a7892d1
commit
8c09f75360
|
|
@ -5,7 +5,7 @@ use regex::Regex;
|
||||||
use tokenizers::tokenizer::Tokenizer;
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
|
|
||||||
pub struct DecodingFactory {
|
pub struct DecodingFactory {
|
||||||
stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
|
stop_regex_cache: DashMap<&'static [&'static str], Regex>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reverse<T>(s: T) -> String
|
fn reverse<T>(s: T) -> String
|
||||||
|
|
@ -28,12 +28,12 @@ impl DecodingFactory {
|
||||||
&self,
|
&self,
|
||||||
tokenizer: Arc<Tokenizer>,
|
tokenizer: Arc<Tokenizer>,
|
||||||
input_token_ids: &[u32],
|
input_token_ids: &[u32],
|
||||||
stop_words: &'static Vec<&'static str>,
|
stop_words: &'static [&'static str],
|
||||||
) -> IncrementalDecoding {
|
) -> IncrementalDecoding {
|
||||||
IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
|
IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
|
fn get_re(&self, stop_words: &'static [&'static str]) -> Option<Regex> {
|
||||||
if stop_words.is_empty() {
|
if stop_words.is_empty() {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ pub struct TextGenerationOptions {
|
||||||
pub sampling_temperature: f32,
|
pub sampling_temperature: f32,
|
||||||
|
|
||||||
#[builder(default = "&EMPTY_STOP_WORDS")]
|
#[builder(default = "&EMPTY_STOP_WORDS")]
|
||||||
pub stop_words: &'static Vec<&'static str>,
|
pub stop_words: &'static [&'static str],
|
||||||
}
|
}
|
||||||
|
|
||||||
static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];
|
static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ use tantivy::{
|
||||||
// Magic numbers
|
// Magic numbers
|
||||||
static MAX_LINE_LENGTH_THRESHOLD: usize = 300;
|
static MAX_LINE_LENGTH_THRESHOLD: usize = 300;
|
||||||
static AVG_LINE_LENGTH_THRESHOLD: f32 = 150f32;
|
static AVG_LINE_LENGTH_THRESHOLD: f32 = 150f32;
|
||||||
|
static MAX_BODY_LINES_THRESHOLD: usize = 15;
|
||||||
|
|
||||||
pub fn index_repositories(_config: &Config) -> Result<()> {
|
pub fn index_repositories(_config: &Config) -> Result<()> {
|
||||||
let mut builder = Schema::builder();
|
let mut builder = Schema::builder();
|
||||||
|
|
@ -82,19 +83,23 @@ struct IndexedDocument {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn from_source_file(file: SourceFile) -> impl Iterator<Item = IndexedDocument> {
|
fn from_source_file(file: SourceFile) -> impl Iterator<Item = IndexedDocument> {
|
||||||
file.tags.into_iter().map(move |tag| {
|
file.tags.into_iter().filter_map(move |tag| {
|
||||||
let name = file.content.get(tag.name_range).unwrap().to_owned();
|
let name = file.content.get(tag.name_range).unwrap().to_owned();
|
||||||
let body = file.content.get(tag.range).unwrap().to_owned();
|
let body = file.content.get(tag.range).unwrap().to_owned();
|
||||||
|
|
||||||
|
if body.lines().collect::<Vec<_>>().len() > MAX_BODY_LINES_THRESHOLD {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
let language = reduce_language_if_needed(&file.language).to_owned();
|
let language = reduce_language_if_needed(&file.language).to_owned();
|
||||||
IndexedDocument {
|
Some(IndexedDocument {
|
||||||
git_url: file.git_url.clone(),
|
git_url: file.git_url.clone(),
|
||||||
filepath: file.filepath.clone(),
|
filepath: file.filepath.clone(),
|
||||||
language,
|
language,
|
||||||
name,
|
name,
|
||||||
body,
|
body,
|
||||||
kind: tag.syntax_type_name,
|
kind: tag.syntax_type_name,
|
||||||
}
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -126,7 +131,7 @@ mod tests {
|
||||||
{
|
{
|
||||||
"range": {
|
"range": {
|
||||||
"start": 290,
|
"start": 290,
|
||||||
"end": 3094
|
"end": 320
|
||||||
},
|
},
|
||||||
"name_range": {
|
"name_range": {
|
||||||
"start": 296,
|
"start": 296,
|
||||||
|
|
@ -142,7 +147,7 @@ mod tests {
|
||||||
{
|
{
|
||||||
"range": {
|
"range": {
|
||||||
"start": 953,
|
"start": 953,
|
||||||
"end": 1507
|
"end": 970
|
||||||
},
|
},
|
||||||
"name_range": {
|
"name_range": {
|
||||||
"start": 957,
|
"start": 957,
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
|
||||||
use tracing::{debug, instrument};
|
use tracing::{debug, instrument};
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
use self::languages::get_stop_words;
|
use self::languages::get_language;
|
||||||
use super::search::IndexServer;
|
use super::search::IndexServer;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
|
|
@ -81,7 +81,7 @@ pub async fn completions(
|
||||||
.max_input_length(1024 + 512)
|
.max_input_length(1024 + 512)
|
||||||
.max_decoding_length(128)
|
.max_decoding_length(128)
|
||||||
.sampling_temperature(0.1)
|
.sampling_temperature(0.1)
|
||||||
.stop_words(get_stop_words(&language))
|
.stop_words(get_language(&language).stop_words)
|
||||||
.build()
|
.build()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,10 @@
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
|
|
||||||
|
pub struct Language {
|
||||||
|
pub stop_words: &'static [&'static str],
|
||||||
|
pub line_comment: &'static str,
|
||||||
|
}
|
||||||
|
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
static ref DEFAULT: Vec<&'static str> = vec![
|
static ref DEFAULT: Vec<&'static str> = vec![
|
||||||
"\n\n",
|
"\n\n",
|
||||||
|
|
@ -20,29 +23,48 @@ lazy_static! {
|
||||||
"\n\n\t\t\t\t\t\t",
|
"\n\n\t\t\t\t\t\t",
|
||||||
"\n\n\t\t\t\t\t\t\t",
|
"\n\n\t\t\t\t\t\t\t",
|
||||||
];
|
];
|
||||||
static ref LANGUAGES: HashMap<&'static str, Vec<&'static str>> = {
|
static ref UNKONWN: Language = Language {
|
||||||
let mut map = HashMap::new();
|
stop_words: &DEFAULT,
|
||||||
map.insert(
|
line_comment: "#"
|
||||||
"python",
|
|
||||||
vec!["\ndef", "\n#", "\nfrom", "\nclass"].with_default(),
|
|
||||||
);
|
|
||||||
map.insert(
|
|
||||||
"javascript",
|
|
||||||
vec!["\nfunction", "\n//", "\nimport", "\nclass"],
|
|
||||||
);
|
|
||||||
map.insert(
|
|
||||||
"typescript",
|
|
||||||
vec![
|
|
||||||
"\nfunction",
|
|
||||||
"\n//",
|
|
||||||
"\nimport",
|
|
||||||
"\nclass",
|
|
||||||
"\ninterface",
|
|
||||||
"\ntype",
|
|
||||||
],
|
|
||||||
);
|
|
||||||
map
|
|
||||||
};
|
};
|
||||||
|
static ref PYTHON_STOP_WORDS: Vec<&'static str> =
|
||||||
|
vec!["\ndef", "\n#", "\nfrom", "\nclass"].with_default();
|
||||||
|
static ref PYTHON: Language = Language {
|
||||||
|
stop_words: &PYTHON_STOP_WORDS,
|
||||||
|
line_comment: "#",
|
||||||
|
};
|
||||||
|
static ref RUST_STOP_WORDS: Vec<&'static str> =
|
||||||
|
vec!["\ndef", "\n#", "\nfrom", "\nclass"].with_default();
|
||||||
|
static ref RUST: Language = Language {
|
||||||
|
stop_words: &RUST_STOP_WORDS,
|
||||||
|
line_comment: "//",
|
||||||
|
};
|
||||||
|
static ref JAVASCRIPT_STOP_WORDS: Vec<&'static str> =
|
||||||
|
vec!["\ndef", "\n#", "\nfrom", "\nclass"].with_default();
|
||||||
|
static ref JAVASCRIPT: Language = Language {
|
||||||
|
stop_words: &JAVASCRIPT_STOP_WORDS,
|
||||||
|
line_comment: "",
|
||||||
|
};
|
||||||
|
static ref TYPESCRIPT_STOP_WORDS: Vec<&'static str> =
|
||||||
|
vec!["\ndef", "\n#", "\nfrom", "\nclass"].with_default();
|
||||||
|
static ref TYPESCRIPT: Language = Language {
|
||||||
|
stop_words: &TYPESCRIPT_STOP_WORDS,
|
||||||
|
line_comment: "",
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_language(language: &str) -> &'static Language {
|
||||||
|
if language == "python" {
|
||||||
|
&PYTHON
|
||||||
|
} else if language == "rust" {
|
||||||
|
&RUST
|
||||||
|
} else if language == "javascript" {
|
||||||
|
&JAVASCRIPT
|
||||||
|
} else if language == "typescript" {
|
||||||
|
&TYPESCRIPT
|
||||||
|
} else {
|
||||||
|
&UNKONWN
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
trait WithDefault {
|
trait WithDefault {
|
||||||
|
|
@ -56,7 +78,3 @@ impl WithDefault for Vec<&'static str> {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_stop_words(language: &str) -> &'static Vec<&'static str> {
|
|
||||||
LANGUAGES.get(language).unwrap_or(&DEFAULT)
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,14 @@
|
||||||
use std::{collections::HashMap, env, sync::Arc};
|
use std::{env, sync::Arc};
|
||||||
|
|
||||||
use lazy_static::lazy_static;
|
|
||||||
use strfmt::strfmt;
|
use strfmt::strfmt;
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
|
|
||||||
use super::Segments;
|
use super::Segments;
|
||||||
use crate::serve::search::IndexServer;
|
use crate::serve::{completions::languages::get_language, search::IndexServer};
|
||||||
|
|
||||||
static MAX_SNIPPETS_TO_FETCH: usize = 20;
|
static MAX_SNIPPETS_TO_FETCH: usize = 20;
|
||||||
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 512;
|
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 512;
|
||||||
|
static SNIPPET_SCORE_THRESHOLD: f32 = 5.0;
|
||||||
|
|
||||||
pub struct PromptBuilder {
|
pub struct PromptBuilder {
|
||||||
prompt_template: Option<String>,
|
prompt_template: Option<String>,
|
||||||
|
|
@ -84,7 +84,7 @@ fn build_prefix(language: &str, prefix: &str, snippets: Vec<String>) -> String {
|
||||||
return prefix.to_owned();
|
return prefix.to_owned();
|
||||||
}
|
}
|
||||||
|
|
||||||
let comment_char = LANGUAGE_LINE_COMMENT_CHAR.get(language).unwrap();
|
let comment_char = get_language(language).line_comment;
|
||||||
let mut lines: Vec<String> = vec![
|
let mut lines: Vec<String> = vec![
|
||||||
format!(
|
format!(
|
||||||
"Below are some relevant {} snippets found in the repository:",
|
"Below are some relevant {} snippets found in the repository:",
|
||||||
|
|
@ -142,6 +142,10 @@ fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> V
|
||||||
};
|
};
|
||||||
|
|
||||||
for hit in serp.hits {
|
for hit in serp.hits {
|
||||||
|
if hit.score < SNIPPET_SCORE_THRESHOLD {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
let body = hit.doc.body;
|
let body = hit.doc.body;
|
||||||
|
|
||||||
if text.contains(&body) {
|
if text.contains(&body) {
|
||||||
|
|
@ -161,15 +165,13 @@ fn sanitize_text(text: &str) -> String {
|
||||||
|c: char| !c.is_ascii_digit() && !c.is_alphabetic() && c != '_' && c != '-',
|
|c: char| !c.is_ascii_digit() && !c.is_alphabetic() && c != '_' && c != '-',
|
||||||
" ",
|
" ",
|
||||||
);
|
);
|
||||||
let tokens: Vec<&str> = x.split(' ').filter(|x| x.len() > 5).collect();
|
let tokens: Vec<&str> = x
|
||||||
|
.split(' ')
|
||||||
|
.filter(|x| *x != "AND" && *x != "NOT" && *x != "OR" && x.len() > 5)
|
||||||
|
.collect();
|
||||||
tokens.join(" ")
|
tokens.join(" ")
|
||||||
}
|
}
|
||||||
|
|
||||||
lazy_static! {
|
|
||||||
static ref LANGUAGE_LINE_COMMENT_CHAR: HashMap<&'static str, &'static str> =
|
|
||||||
HashMap::from([("python", "#"), ("rust", "//"),]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
|
||||||
|
|
@ -3,36 +3,18 @@ import requests
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
|
|
||||||
class Doc(NamedTuple):
|
|
||||||
name: str
|
|
||||||
body: str
|
|
||||||
score: float
|
|
||||||
filepath: str
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_json(json: dict):
|
|
||||||
doc = json["doc"]
|
|
||||||
return Doc(
|
|
||||||
name=doc["name"][0],
|
|
||||||
body=doc["body"][0],
|
|
||||||
score=json["score"],
|
|
||||||
filepath=doc["filepath"][0],
|
|
||||||
)
|
|
||||||
|
|
||||||
# force wide mode
|
# force wide mode
|
||||||
st.set_page_config(layout="wide")
|
st.set_page_config(layout="wide")
|
||||||
|
|
||||||
language = st.text_input("Language", "rust")
|
language = st.text_input("Language", "rust")
|
||||||
|
|
||||||
query = st.text_area("Query", "get")
|
query = st.text_area("Query", "get")
|
||||||
tokens = re.findall(r"\w+", query)
|
tokens = re.findall(r"\w+", query)
|
||||||
tokens = [x for x in tokens if x != "AND" and x != "OR" and x != "NOT"]
|
tokens = [x for x in tokens if x != "AND" and x != "OR" and x != "NOT"]
|
||||||
|
|
||||||
query = "(" + " ".join(tokens) + ")" + " " + "AND language:" + language
|
query = "(" + " ".join(tokens) + ")" + " " + "AND language:" + language
|
||||||
|
|
||||||
if query:
|
if query:
|
||||||
r = requests.get("http://localhost:8080/v1beta/search", params=dict(q=query))
|
r = requests.get("http://localhost:8080/v1beta/search", params=dict(q=query))
|
||||||
hits = r.json()["hits"]
|
hits = r.json()["hits"]
|
||||||
for x in hits:
|
for x in hits:
|
||||||
doc = Doc.from_json(x)
|
st.write(x)
|
||||||
st.write(doc.name + "@" + doc.filepath + " : " + str(doc.score))
|
|
||||||
st.code(doc.body)
|
|
||||||
Loading…
Reference in New Issue