refactor: extract language configuration into individual toml file (#564)

* refactor: extract language configuration into individual toml file

* feat: add golang language configuration (#565)
dedup-snippet-at-index
Meng Zhang 2023-10-15 17:24:44 -07:00 committed by GitHub
parent 9d6a9a6fa5
commit 99a7053b6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 160 additions and 146 deletions

1
Cargo.lock generated
View File

@ -3185,6 +3185,7 @@ dependencies = [
"derive_builder", "derive_builder",
"futures", "futures",
"regex", "regex",
"tabby-common",
"tokenizers", "tokenizers",
] ]

View File

@ -124,7 +124,7 @@ impl TextGeneration for CTranslate2Engine {
let decoding = self.decoding_factory.create_incremental_decoding( let decoding = self.decoding_factory.create_incremental_decoding(
self.tokenizer.clone(), self.tokenizer.clone(),
truncate_tokens(encoding.get_ids(), options.max_input_length), truncate_tokens(encoding.get_ids(), options.max_input_length),
options.stop_words, options.language,
); );
let cancel = CancellationToken::new(); let cancel = CancellationToken::new();

View File

@ -58,9 +58,6 @@ impl FastChatEngine {
#[async_trait] #[async_trait]
impl TextGeneration for FastChatEngine { impl TextGeneration for FastChatEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let _stop_sequences: Vec<String> =
options.stop_words.iter().map(|x| x.to_string()).collect();
let tokens: Vec<&str> = prompt.split("<MID>").collect(); let tokens: Vec<&str> = prompt.split("<MID>").collect();
let request = Request { let request = Request {
model: self.model_name.to_owned(), model: self.model_name.to_owned(),

View File

@ -67,7 +67,8 @@ impl VertexAIEngine {
impl TextGeneration for VertexAIEngine { impl TextGeneration for VertexAIEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let stop_sequences: Vec<String> = options let stop_sequences: Vec<String> = options
.stop_words .language
.get_stop_words()
.iter() .iter()
.map(|x| x.to_string()) .map(|x| x.to_string())
// vertex supports at most 5 stop sequence. // vertex supports at most 5 stop sequence.

View File

@ -75,7 +75,7 @@ impl TextGeneration for LlamaEngine {
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length); let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
engine.as_mut().start(input_token_ids); engine.as_mut().start(input_token_ids);
let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.stop_words); let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.language);
let mut n_remains = options.max_decoding_length ; let mut n_remains = options.max_decoding_length ;
while n_remains > 0 { while n_remains > 0 {
let Ok(next_token_id) = engine.as_mut().step() else { let Ok(next_token_id) = engine.as_mut().step() else {

View File

@ -0,0 +1,53 @@
# Per-language configuration used to derive completion stop words.
# Each [[config]] entry maps one or more language identifiers to the
# language's line-comment prefix and the keywords that begin top-level
# declarations (each keyword becomes a newline-prefixed stop word).

[[config]]
languages = ["python"]
line_comment = "#"
top_level_keywords = ["def", "from", "class", "import"]

[[config]]
languages = ["rust"]
line_comment = "//"
# fix: "trait" was listed twice; duplicates add redundant stop words.
top_level_keywords = [
  "fn",
  "trait",
  "impl",
  "enum",
  "pub",
  "extern",
  "static",
  "unsafe",
  "use",
]

[[config]]
languages = ["javascript", "typescript", "javascriptreact", "typescriptreact"]
line_comment = "//"
top_level_keywords = [
  "abstract",
  "async",
  "class",
  "const",
  "export",
  "function",
  "interface",
  "module",
  "package",
  "type",
  "var",
  "enum",
  "let",
]

[[config]]
languages = ["go"]
line_comment = "//"
top_level_keywords = [
  "func",
  "interface",
  "struct",
  "package",
  "type",
  "import",
  "var",
  "const",
]

View File

@ -0,0 +1,73 @@
use lazy_static::lazy_static;
use serde::Deserialize;
lazy_static! {
// Language-agnostic default stop sequences shared by every language:
// a blank line ("\n\n") followed by progressively deeper space- or
// tab-indentation. Hitting one of these during generation signals the
// end of the current block.
// NOTE(review): the space-indented entries below appear to carry an
// increasing number of trailing spaces; the rendering may have collapsed
// runs of spaces — verify against the original file bytes.
static ref DEFAULT: Vec<&'static str> = vec![
"\n\n",
"\n\n ",
"\n\n ",
"\n\n ",
"\n\n ",
"\n\n ",
"\n\n ",
"\n\n ",
"\n\n\t",
"\n\n\t\t",
"\n\n\t\t\t",
"\n\n\t\t\t\t",
"\n\n\t\t\t\t\t",
"\n\n\t\t\t\t\t\t",
"\n\n\t\t\t\t\t\t\t",
];
}
/// Deserialized shape of `assets/languages.toml`: a flat list of
/// `[[config]]` entries, one per language family.
#[derive(Deserialize)]
struct ConfigList {
config: Vec<Language>,
}
/// Stop-word and comment configuration for one language family, loaded
/// from `assets/languages.toml`.
#[derive(Deserialize, Debug)]
pub struct Language {
// Language identifiers this entry applies to (e.g. "python",
// "typescriptreact"); the first one doubles as the cache hash key.
languages: Vec<String>,
// Keywords that begin top-level declarations; each becomes a
// newline-prefixed stop word via get_stop_words().
top_level_keywords: Vec<String>,
/// Line-comment prefix (e.g. "//", "#"); empty for unknown languages.
pub line_comment: String,
}
impl Language {
    /// Builds the full stop-word list for this language: the
    /// newline-prefixed line-comment marker, each newline-prefixed
    /// top-level keyword, then the shared `DEFAULT` sequences.
    pub fn get_stop_words(&self) -> Vec<String> {
        std::iter::once(format!("\n{}", self.line_comment))
            .chain(
                self.top_level_keywords
                    .iter()
                    .map(|keyword| format!("\n{}", keyword)),
            )
            .chain(DEFAULT.iter().map(|default| (*default).to_owned()))
            .collect()
    }

    /// Cache key identifying this configuration: its first language id.
    pub fn get_hashkey(&self) -> String {
        self.languages[0].clone()
    }
}
lazy_static! {
// Language table parsed once from the TOML asset embedded at compile
// time; panics at first access if the asset is malformed.
static ref CONFIG: ConfigList =
serdeconv::from_toml_str(include_str!("../assets/languages.toml")).unwrap();
// Fallback for unrecognized languages: no comment prefix and no
// keywords, so only the shared DEFAULT stop sequences apply.
pub static ref UNKNOWN_LANGUAGE: Language = Language {
languages: vec!["unknown".to_owned()],
line_comment: "".to_owned(),
top_level_keywords: vec![],
};
}
/// Looks up the configuration for `language`, falling back to
/// `UNKNOWN_LANGUAGE` when no `[[config]]` entry lists it.
pub fn get_language(language: &str) -> &'static Language {
    for config in CONFIG.config.iter() {
        if config.languages.iter().any(|candidate| candidate == language) {
            return config;
        }
    }
    &UNKNOWN_LANGUAGE
}

View File

@ -1,6 +1,7 @@
pub mod config; pub mod config;
pub mod events; pub mod events;
pub mod index; pub mod index;
pub mod languages;
pub mod path; pub mod path;
pub mod usage; pub mod usage;

View File

@ -13,3 +13,4 @@ derive_builder = "0.12.0"
futures = { workspace = true } futures = { workspace = true }
regex.workspace = true regex.workspace = true
tokenizers.workspace = true tokenizers.workspace = true
tabby-common = { path = "../tabby-common" }

View File

@ -2,10 +2,11 @@ use std::sync::Arc;
use dashmap::DashMap; use dashmap::DashMap;
use regex::Regex; use regex::Regex;
use tabby_common::languages::Language;
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
pub struct DecodingFactory { pub struct DecodingFactory {
stop_regex_cache: DashMap<&'static [&'static str], Regex>, stop_regex_cache: DashMap<String, Regex>,
} }
fn reverse<T>(s: T) -> String fn reverse<T>(s: T) -> String
@ -28,32 +29,34 @@ impl DecodingFactory {
&self, &self,
tokenizer: Arc<Tokenizer>, tokenizer: Arc<Tokenizer>,
input_token_ids: &[u32], input_token_ids: &[u32],
stop_words: &'static [&'static str], language: &'static Language,
) -> IncrementalDecoding { ) -> IncrementalDecoding {
IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids) IncrementalDecoding::new(tokenizer, self.get_re(language), input_token_ids)
} }
fn get_re(&self, stop_words: &'static [&'static str]) -> Option<Regex> { fn get_re(&self, language: &'static Language) -> Option<Regex> {
let stop_words = language.get_stop_words();
if stop_words.is_empty() { if stop_words.is_empty() {
None None
} else { } else {
let mut re = self.stop_regex_cache.get(stop_words); let hashkey = language.get_hashkey();
let mut re = self.stop_regex_cache.get(&hashkey);
if re.is_none() { if re.is_none() {
self.stop_regex_cache self.stop_regex_cache
.insert(stop_words, create_stop_regex(stop_words)); .insert(hashkey.clone(), create_stop_regex(stop_words));
re = self.stop_regex_cache.get(stop_words); re = self.stop_regex_cache.get(&hashkey);
} }
re.map(|x| x.value().clone()) re.map(|x| x.value().clone())
} }
} }
} }
fn create_stop_regex(stop_words: &[&str]) -> Regex { fn create_stop_regex(stop_words: Vec<String>) -> Regex {
// (?m) enables multi-line matching mode. // (?m) enables multi-line matching mode.
// \A means absolute begins of string. // \A means absolute begins of string.
let reversed_stop_words: Vec<_> = stop_words let reversed_stop_words: Vec<_> = stop_words
.iter() .iter()
.map(|x| regex::escape(&reverse(*x))) .map(|x| regex::escape(&reverse(x)))
.collect(); .collect();
let regex_string = r"(?m)\A".to_owned() + "((" + &reversed_stop_words.join(")|(") + "))"; let regex_string = r"(?m)\A".to_owned() + "((" + &reversed_stop_words.join(")|(") + "))";
Regex::new(&regex_string).expect("Failed to create regex") Regex::new(&regex_string).expect("Failed to create regex")
@ -131,7 +134,12 @@ mod tests {
#[test] #[test]
fn test_it_works() { fn test_it_works() {
let text = reverse("void write_u32(std::uint32_t val) const {\n write_raw(&val, sizeof(val));\n }\n\n ~llama_file() {\n if (fp) {\n std::fclose(fp);\n }\n }\n};\n\nvoid"); let text = reverse("void write_u32(std::uint32_t val) const {\n write_raw(&val, sizeof(val));\n }\n\n ~llama_file() {\n if (fp) {\n std::fclose(fp);\n }\n }\n};\n\nvoid");
assert!(!create_stop_regex(&["\n\n", "\n\n "]).is_match(&text)); assert!(!create_stop_regex(vec!["\n\n".to_owned(), "\n\n ".to_owned()]).is_match(&text));
assert!(create_stop_regex(&["\n\n", "\n\n ", "\nvoid"]).is_match(&text)); assert!(create_stop_regex(vec![
"\n\n".to_owned(),
"\n\n ".to_owned(),
"\nvoid".to_owned()
])
.is_match(&text));
} }
} }

View File

@ -3,6 +3,7 @@ pub mod decoding;
use async_trait::async_trait; use async_trait::async_trait;
use derive_builder::Builder; use derive_builder::Builder;
use futures::stream::BoxStream; use futures::stream::BoxStream;
use tabby_common::languages::Language;
#[derive(Builder, Debug)] #[derive(Builder, Debug)]
pub struct TextGenerationOptions { pub struct TextGenerationOptions {
@ -15,12 +16,10 @@ pub struct TextGenerationOptions {
#[builder(default = "1.0")] #[builder(default = "1.0")]
pub sampling_temperature: f32, pub sampling_temperature: f32,
#[builder(default = "&EMPTY_STOP_WORDS")] #[builder(default = "&tabby_common::languages::UNKNOWN_LANGUAGE")]
pub stop_words: &'static [&'static str], pub language: &'static Language,
} }
static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];
#[async_trait] #[async_trait]
pub trait TextGeneration: Sync + Send { pub trait TextGeneration: Sync + Send {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String; async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String;

View File

@ -1,4 +1,3 @@
mod languages;
mod prompt; mod prompt;
use std::sync::Arc; use std::sync::Arc;
@ -6,12 +5,11 @@ use std::sync::Arc;
use axum::{extract::State, Json}; use axum::{extract::State, Json};
use hyper::StatusCode; use hyper::StatusCode;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tabby_common::events; use tabby_common::{events, languages::get_language};
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder}; use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
use tracing::{debug, instrument}; use tracing::{debug, instrument};
use utoipa::ToSchema; use utoipa::ToSchema;
use self::languages::get_language;
use super::search::IndexServer; use super::search::IndexServer;
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
@ -112,7 +110,7 @@ pub async fn completions(
.max_input_length(1024 + 512) .max_input_length(1024 + 512)
.max_decoding_length(128) .max_decoding_length(128)
.sampling_temperature(0.1) .sampling_temperature(0.1)
.stop_words(get_language(&language).stop_words) .language(get_language(&language))
.build() .build()
.unwrap(); .unwrap();

View File

@ -1,116 +0,0 @@
use lazy_static::lazy_static;
/// Static stop-word and comment configuration for one language family.
pub struct Language {
// Sequences that terminate generation for this language.
pub stop_words: &'static [&'static str],
// Line-comment prefix (e.g. "//", "#").
pub line_comment: &'static str,
}
lazy_static! {
// Language-agnostic default stop sequences: a blank line followed by
// progressively deeper space- or tab-indentation.
// NOTE(review): the space-indented entries appear whitespace-collapsed
// in this rendering — verify the original bytes before relying on them.
static ref DEFAULT: Vec<&'static str> = vec![
"\n\n",
"\n\n ",
"\n\n ",
"\n\n ",
"\n\n ",
"\n\n ",
"\n\n ",
"\n\n ",
"\n\n\t",
"\n\n\t\t",
"\n\n\t\t\t",
"\n\n\t\t\t\t",
"\n\n\t\t\t\t\t",
"\n\n\t\t\t\t\t\t",
"\n\n\t\t\t\t\t\t\t",
];
// Fallback for unrecognized languages.
// NOTE(review): "UNKONWN" is a typo for "UNKNOWN"; renaming requires
// updating the reference in get_language below.
static ref UNKONWN: Language = Language {
stop_words: &DEFAULT,
line_comment: "#"
};
/* Python */
static ref PYTHON_STOP_WORDS: Vec<&'static str> =
vec!["\ndef", "\n#", "\nfrom", "\nclass", "\nimport"].with_default();
static ref PYTHON: Language = Language {
stop_words: &PYTHON_STOP_WORDS,
line_comment: "#",
};
/* Rust */
// NOTE(review): "\ntrait" is listed twice — redundant duplicate.
static ref RUST_STOP_WORDS: Vec<&'static str> = vec![
"\n//", "\nfn", "\ntrait", "\nimpl", "\nenum", "\npub", "\nextern", "\nstatic",
"\ntrait", "\nunsafe", "\nuse"
]
.with_default();
static ref RUST: Language = Language {
stop_words: &RUST_STOP_WORDS,
line_comment: "//",
};
/* Javascript / Typescript */
static ref JAVASCRIPT_TYPESCRIPT_STOP_WORDS: Vec<&'static str> = vec![
"\n//",
"\nabstract",
"\nasync",
"\nclass",
"\nconst",
"\nexport",
"\nfunction",
"\ninterface",
"\nmodule",
"\npackage",
"\ntype",
"\nvar",
"\nenum",
"\nlet",
]
.with_default();
static ref JAVASCRIPT_TYPESCRIPT: Language = Language {
stop_words: &JAVASCRIPT_TYPESCRIPT_STOP_WORDS,
line_comment: "//",
};
/* Golang */
static ref GO_STOP_WORDS: Vec<&'static str> = vec![
"\n//",
"\nfunc",
"\ninterface",
"\nstruct",
"\npackage",
"\ntype",
"\nimport",
"\nvar",
"\nconst",
]
.with_default();
static ref GO: Language = Language {
stop_words: &GO_STOP_WORDS,
line_comment: "//",
};
}
/// Maps a language identifier to its static configuration; unrecognized
/// identifiers fall back to the default (`UNKONWN`) entry.
pub fn get_language(language: &str) -> &'static Language {
    match language {
        "python" => &PYTHON,
        "rust" => &RUST,
        "javascript" | "typescript" => &JAVASCRIPT_TYPESCRIPT,
        "go" => &GO,
        _ => &UNKONWN,
    }
}
// Helper for appending the shared DEFAULT stop sequences to a
// per-language stop-word list.
trait WithDefault {
    fn with_default(self) -> Self;
}

impl WithDefault for Vec<&'static str> {
    // Extends the vector with every DEFAULT entry. `&'static str` is
    // Copy, so this is equivalent to cloning DEFAULT and appending it.
    fn with_default(mut self) -> Self {
        self.extend(DEFAULT.iter().copied());
        self
    }
}

View File

@ -3,14 +3,12 @@ use std::sync::Arc;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use regex::Regex; use regex::Regex;
use strfmt::strfmt; use strfmt::strfmt;
use tabby_common::languages::get_language;
use textdistance::Algorithm; use textdistance::Algorithm;
use tracing::warn; use tracing::warn;
use super::{Segments, Snippet}; use super::{Segments, Snippet};
use crate::serve::{ use crate::serve::search::{IndexServer, IndexServerError};
completions::languages::get_language,
search::{IndexServer, IndexServerError},
};
static MAX_SNIPPETS_TO_FETCH: usize = 20; static MAX_SNIPPETS_TO_FETCH: usize = 20;
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768; static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768;
@ -78,7 +76,7 @@ fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String {
return prefix.to_owned(); return prefix.to_owned();
} }
let comment_char = get_language(language).line_comment; let comment_char = &get_language(language).line_comment;
let mut lines: Vec<String> = vec![]; let mut lines: Vec<String> = vec![];
for (i, snippet) in snippets.iter().enumerate() { for (i, snippet) in snippets.iter().enumerate() {