refactor: extract language configuration into individual toml file (#564)

* refactor: extract language configuration into individual toml file

* feat: add golang language configuration (#565)
dedup-snippet-at-index
Meng Zhang 2023-10-15 17:24:44 -07:00 committed by GitHub
parent 9d6a9a6fa5
commit 99a7053b6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 160 additions and 146 deletions

1
Cargo.lock generated
View File

@ -3185,6 +3185,7 @@ dependencies = [
"derive_builder",
"futures",
"regex",
"tabby-common",
"tokenizers",
]

View File

@ -124,7 +124,7 @@ impl TextGeneration for CTranslate2Engine {
let decoding = self.decoding_factory.create_incremental_decoding(
self.tokenizer.clone(),
truncate_tokens(encoding.get_ids(), options.max_input_length),
options.stop_words,
options.language,
);
let cancel = CancellationToken::new();

View File

@ -58,9 +58,6 @@ impl FastChatEngine {
#[async_trait]
impl TextGeneration for FastChatEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let _stop_sequences: Vec<String> =
options.stop_words.iter().map(|x| x.to_string()).collect();
let tokens: Vec<&str> = prompt.split("<MID>").collect();
let request = Request {
model: self.model_name.to_owned(),

View File

@ -67,7 +67,8 @@ impl VertexAIEngine {
impl TextGeneration for VertexAIEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let stop_sequences: Vec<String> = options
.stop_words
.language
.get_stop_words()
.iter()
.map(|x| x.to_string())
// vertex supports at most 5 stop sequence.

View File

@ -75,7 +75,7 @@ impl TextGeneration for LlamaEngine {
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
engine.as_mut().start(input_token_ids);
let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.stop_words);
let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.language);
let mut n_remains = options.max_decoding_length ;
while n_remains > 0 {
let Ok(next_token_id) = engine.as_mut().step() else {

View File

@ -0,0 +1,53 @@
# Per-language completion configuration.
# Each [[config]] entry maps one or more language identifiers to the
# language's line-comment prefix and the keywords that begin a new
# top-level declaration (used to derive stop words for generation).

[[config]]
languages = ["python"]
line_comment = "#"
top_level_keywords = ["def", "from", "class", "import"]

[[config]]
languages = ["rust"]
line_comment = "//"
# NOTE(review): "trait" was listed twice; duplicate removed.
top_level_keywords = [
  "fn",
  "trait",
  "impl",
  "enum",
  "pub",
  "extern",
  "static",
  "unsafe",
  "use",
]

[[config]]
languages = ["javascript", "typescript", "javascriptreact", "typescriptreact"]
line_comment = "//"
top_level_keywords = [
  "abstract",
  "async",
  "class",
  "const",
  "export",
  "function",
  "interface",
  "module",
  "package",
  "type",
  "var",
  "enum",
  "let",
]

[[config]]
languages = ["go"]
line_comment = "//"
top_level_keywords = [
  "func",
  "interface",
  "struct",
  "package",
  "type",
  "import",
  "var",
  "const",
]

View File

@ -0,0 +1,73 @@
use lazy_static::lazy_static;
use serde::Deserialize;
lazy_static! {
// Stop sequences shared by every language: a blank line followed by an
// increasing amount of space or tab indentation. Hitting one of these
// indicates the model has finished the current scope.
// NOTE(review): exact space counts below assumed from the tab analog
// (the diff viewer may have collapsed whitespace) — confirm against the
// committed file.
static ref DEFAULT: Vec<&'static str> = vec![
"\n\n",
"\n\n  ",
"\n\n    ",
"\n\n      ",
"\n\n        ",
"\n\n          ",
"\n\n            ",
"\n\n              ",
"\n\n\t",
"\n\n\t\t",
"\n\n\t\t\t",
"\n\n\t\t\t\t",
"\n\n\t\t\t\t\t",
"\n\n\t\t\t\t\t\t",
"\n\n\t\t\t\t\t\t\t",
];
}
// Top-level shape of assets/languages.toml: each [[config]] table
// deserializes into one Language entry.
#[derive(Deserialize)]
struct ConfigList {
config: Vec<Language>,
}
// Completion configuration for one language family, loaded from
// assets/languages.toml.
#[derive(Deserialize, Debug)]
pub struct Language {
// Language identifiers this entry serves (e.g. "typescript", "typescriptreact").
languages: Vec<String>,
// Keywords that begin a new top-level declaration; each becomes a stop word.
top_level_keywords: Vec<String>,
// Line-comment prefix (e.g. "//", "#"); empty for the unknown-language fallback.
pub line_comment: String,
}
impl Language {
    /// Builds the stop-word list for this language: "\n" + the line-comment
    /// prefix, "\n" + each top-level keyword, then the shared DEFAULT
    /// blank-line/indentation sequences.
    pub fn get_stop_words(&self) -> Vec<String> {
        let mut out = vec![];
        // Guard against an empty line_comment (the UNKNOWN_LANGUAGE fallback):
        // without this check the formatted stop word degenerates to a bare
        // "\n", which would halt generation at the very first newline.
        if !self.line_comment.is_empty() {
            out.push(format!("\n{}", self.line_comment));
        }
        for word in &self.top_level_keywords {
            out.push(format!("\n{}", word));
        }
        // Append the shared defaults; &'static str is Copy, so map to owned.
        out.extend(DEFAULT.iter().map(|x| (*x).to_owned()));
        out
    }

    /// Cache key used to memoize per-language artifacts (e.g. compiled stop
    /// regexes). The first identifier in `languages` serves as the canonical
    /// name; configuration entries are expected to be non-empty.
    pub fn get_hashkey(&self) -> String {
        self.languages[0].clone()
    }
}
lazy_static! {
// languages.toml parsed once at first access; the asset is embedded into
// the binary at compile time via include_str!.
static ref CONFIG: ConfigList =
serdeconv::from_toml_str(include_str!("../assets/languages.toml")).unwrap();
// Fallback entry returned when a request's language has no TOML config.
// Empty line_comment and no keywords: only the DEFAULT stop sequences apply.
pub static ref UNKNOWN_LANGUAGE: Language = Language {
languages: vec!["unknown".to_owned()],
line_comment: "".to_owned(),
top_level_keywords: vec![],
};
}
/// Looks up the configuration entry covering `language`, falling back to
/// UNKNOWN_LANGUAGE when no [[config]] table lists it.
pub fn get_language(language: &str) -> &'static Language {
    for entry in CONFIG.config.iter() {
        if entry.languages.iter().any(|name| name == language) {
            return entry;
        }
    }
    &UNKNOWN_LANGUAGE
}

View File

@ -1,6 +1,7 @@
pub mod config;
pub mod events;
pub mod index;
pub mod languages;
pub mod path;
pub mod usage;

View File

@ -13,3 +13,4 @@ derive_builder = "0.12.0"
futures = { workspace = true }
regex.workspace = true
tokenizers.workspace = true
tabby-common = { path = "../tabby-common" }

View File

@ -2,10 +2,11 @@ use std::sync::Arc;
use dashmap::DashMap;
use regex::Regex;
use tabby_common::languages::Language;
use tokenizers::tokenizer::Tokenizer;
pub struct DecodingFactory {
stop_regex_cache: DashMap<&'static [&'static str], Regex>,
stop_regex_cache: DashMap<String, Regex>,
}
fn reverse<T>(s: T) -> String
@ -28,32 +29,34 @@ impl DecodingFactory {
&self,
tokenizer: Arc<Tokenizer>,
input_token_ids: &[u32],
stop_words: &'static [&'static str],
language: &'static Language,
) -> IncrementalDecoding {
IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
IncrementalDecoding::new(tokenizer, self.get_re(language), input_token_ids)
}
fn get_re(&self, stop_words: &'static [&'static str]) -> Option<Regex> {
fn get_re(&self, language: &'static Language) -> Option<Regex> {
let stop_words = language.get_stop_words();
if stop_words.is_empty() {
None
} else {
let mut re = self.stop_regex_cache.get(stop_words);
let hashkey = language.get_hashkey();
let mut re = self.stop_regex_cache.get(&hashkey);
if re.is_none() {
self.stop_regex_cache
.insert(stop_words, create_stop_regex(stop_words));
re = self.stop_regex_cache.get(stop_words);
.insert(hashkey.clone(), create_stop_regex(stop_words));
re = self.stop_regex_cache.get(&hashkey);
}
re.map(|x| x.value().clone())
}
}
}
fn create_stop_regex(stop_words: &[&str]) -> Regex {
fn create_stop_regex(stop_words: Vec<String>) -> Regex {
// (?m) enables multi-line matching mode.
// \A means absolute begins of string.
let reversed_stop_words: Vec<_> = stop_words
.iter()
.map(|x| regex::escape(&reverse(*x)))
.map(|x| regex::escape(&reverse(x)))
.collect();
let regex_string = r"(?m)\A".to_owned() + "((" + &reversed_stop_words.join(")|(") + "))";
Regex::new(&regex_string).expect("Failed to create regex")
@ -131,7 +134,12 @@ mod tests {
#[test]
fn test_it_works() {
let text = reverse("void write_u32(std::uint32_t val) const {\n write_raw(&val, sizeof(val));\n }\n\n ~llama_file() {\n if (fp) {\n std::fclose(fp);\n }\n }\n};\n\nvoid");
assert!(!create_stop_regex(&["\n\n", "\n\n "]).is_match(&text));
assert!(create_stop_regex(&["\n\n", "\n\n ", "\nvoid"]).is_match(&text));
assert!(!create_stop_regex(vec!["\n\n".to_owned(), "\n\n ".to_owned()]).is_match(&text));
assert!(create_stop_regex(vec![
"\n\n".to_owned(),
"\n\n ".to_owned(),
"\nvoid".to_owned()
])
.is_match(&text));
}
}

View File

@ -3,6 +3,7 @@ pub mod decoding;
use async_trait::async_trait;
use derive_builder::Builder;
use futures::stream::BoxStream;
use tabby_common::languages::Language;
#[derive(Builder, Debug)]
pub struct TextGenerationOptions {
@ -15,12 +16,10 @@ pub struct TextGenerationOptions {
#[builder(default = "1.0")]
pub sampling_temperature: f32,
#[builder(default = "&EMPTY_STOP_WORDS")]
pub stop_words: &'static [&'static str],
#[builder(default = "&tabby_common::languages::UNKNOWN_LANGUAGE")]
pub language: &'static Language,
}
static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];
#[async_trait]
pub trait TextGeneration: Sync + Send {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String;

View File

@ -1,4 +1,3 @@
mod languages;
mod prompt;
use std::sync::Arc;
@ -6,12 +5,11 @@ use std::sync::Arc;
use axum::{extract::State, Json};
use hyper::StatusCode;
use serde::{Deserialize, Serialize};
use tabby_common::events;
use tabby_common::{events, languages::get_language};
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
use tracing::{debug, instrument};
use utoipa::ToSchema;
use self::languages::get_language;
use super::search::IndexServer;
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
@ -112,7 +110,7 @@ pub async fn completions(
.max_input_length(1024 + 512)
.max_decoding_length(128)
.sampling_temperature(0.1)
.stop_words(get_language(&language).stop_words)
.language(get_language(&language))
.build()
.unwrap();

View File

@ -1,116 +0,0 @@
use lazy_static::lazy_static;
// (deleted in this commit) Hard-coded language descriptor, superseded by the
// TOML-driven tabby_common::languages::Language.
pub struct Language {
// Precomputed stop words, including the shared DEFAULT sequences.
pub stop_words: &'static [&'static str],
// Line-comment prefix used when injecting snippets into prompts.
pub line_comment: &'static str,
}
lazy_static! {
// Shared default stop sequences: blank line + space/tab indentation runs.
// NOTE(review): exact space counts assumed from the tab analog — the diff
// viewer may have collapsed whitespace.
static ref DEFAULT: Vec<&'static str> = vec![
"\n\n",
"\n\n  ",
"\n\n    ",
"\n\n      ",
"\n\n        ",
"\n\n          ",
"\n\n            ",
"\n\n              ",
"\n\n\t",
"\n\n\t\t",
"\n\n\t\t\t",
"\n\n\t\t\t\t",
"\n\n\t\t\t\t\t",
"\n\n\t\t\t\t\t\t",
"\n\n\t\t\t\t\t\t\t",
];
// Fallback for unrecognized languages ("UNKONWN" is a typo preserved from
// the original source; the identifier is referenced by get_language below).
static ref UNKONWN: Language = Language {
stop_words: &DEFAULT,
line_comment: "#"
};
/* Python */
static ref PYTHON_STOP_WORDS: Vec<&'static str> =
vec!["\ndef", "\n#", "\nfrom", "\nclass", "\nimport"].with_default();
static ref PYTHON: Language = Language {
stop_words: &PYTHON_STOP_WORDS,
line_comment: "#",
};
/* Rust */
// NOTE(review): "\ntrait" appears twice in this list.
static ref RUST_STOP_WORDS: Vec<&'static str> = vec![
"\n//", "\nfn", "\ntrait", "\nimpl", "\nenum", "\npub", "\nextern", "\nstatic",
"\ntrait", "\nunsafe", "\nuse"
]
.with_default();
static ref RUST: Language = Language {
stop_words: &RUST_STOP_WORDS,
line_comment: "//",
};
/* Javascript / Typescript */
static ref JAVASCRIPT_TYPESCRIPT_STOP_WORDS: Vec<&'static str> = vec![
"\n//",
"\nabstract",
"\nasync",
"\nclass",
"\nconst",
"\nexport",
"\nfunction",
"\ninterface",
"\nmodule",
"\npackage",
"\ntype",
"\nvar",
"\nenum",
"\nlet",
]
.with_default();
static ref JAVASCRIPT_TYPESCRIPT: Language = Language {
stop_words: &JAVASCRIPT_TYPESCRIPT_STOP_WORDS,
line_comment: "//",
};
/* Golang */
static ref GO_STOP_WORDS: Vec<&'static str> = vec![
"\n//",
"\nfunc",
"\ninterface",
"\nstruct",
"\npackage",
"\ntype",
"\nimport",
"\nvar",
"\nconst",
]
.with_default();
static ref GO: Language = Language {
stop_words: &GO_STOP_WORDS,
line_comment: "//",
};
}
/// Maps a language identifier to its hard-coded descriptor, falling back to
/// the unknown-language entry for anything unrecognized.
pub fn get_language(language: &str) -> &'static Language {
    match language {
        "python" => &PYTHON,
        "rust" => &RUST,
        "javascript" | "typescript" => &JAVASCRIPT_TYPESCRIPT,
        "go" => &GO,
        _ => &UNKONWN,
    }
}
// Extension trait for appending the shared DEFAULT stop sequences to a
// per-language stop-word list.
trait WithDefault {
// Consumes the list and returns it with DEFAULT appended at the end.
fn with_default(self) -> Self;
}
impl WithDefault for Vec<&'static str> {
    /// Returns `self` with the shared DEFAULT stop sequences appended.
    fn with_default(mut self) -> Self {
        // &'static str is Copy, so extend from the iterator directly instead
        // of cloning DEFAULT into a temporary Vec and appending it.
        self.extend(DEFAULT.iter().copied());
        self
    }
}

View File

@ -3,14 +3,12 @@ use std::sync::Arc;
use lazy_static::lazy_static;
use regex::Regex;
use strfmt::strfmt;
use tabby_common::languages::get_language;
use textdistance::Algorithm;
use tracing::warn;
use super::{Segments, Snippet};
use crate::serve::{
completions::languages::get_language,
search::{IndexServer, IndexServerError},
};
use crate::serve::search::{IndexServer, IndexServerError};
static MAX_SNIPPETS_TO_FETCH: usize = 20;
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768;
@ -78,7 +76,7 @@ fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String {
return prefix.to_owned();
}
let comment_char = get_language(language).line_comment;
let comment_char = &get_language(language).line_comment;
let mut lines: Vec<String> = vec![];
for (i, snippet) in snippets.iter().enumerate() {