refactor: extract language configuration into individual toml file (#564)
* refactor: extract language configuration into individual toml file
* feat: add golang language configuration (#565)
dedup-snippet-at-index
parent
9d6a9a6fa5
commit
99a7053b6f
|
|
@ -3185,6 +3185,7 @@ dependencies = [
|
|||
"derive_builder",
|
||||
"futures",
|
||||
"regex",
|
||||
"tabby-common",
|
||||
"tokenizers",
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -124,7 +124,7 @@ impl TextGeneration for CTranslate2Engine {
|
|||
let decoding = self.decoding_factory.create_incremental_decoding(
|
||||
self.tokenizer.clone(),
|
||||
truncate_tokens(encoding.get_ids(), options.max_input_length),
|
||||
options.stop_words,
|
||||
options.language,
|
||||
);
|
||||
|
||||
let cancel = CancellationToken::new();
|
||||
|
|
|
|||
|
|
@ -58,9 +58,6 @@ impl FastChatEngine {
|
|||
#[async_trait]
|
||||
impl TextGeneration for FastChatEngine {
|
||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||
let _stop_sequences: Vec<String> =
|
||||
options.stop_words.iter().map(|x| x.to_string()).collect();
|
||||
|
||||
let tokens: Vec<&str> = prompt.split("<MID>").collect();
|
||||
let request = Request {
|
||||
model: self.model_name.to_owned(),
|
||||
|
|
|
|||
|
|
@ -67,7 +67,8 @@ impl VertexAIEngine {
|
|||
impl TextGeneration for VertexAIEngine {
|
||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||
let stop_sequences: Vec<String> = options
|
||||
.stop_words
|
||||
.language
|
||||
.get_stop_words()
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
// vertex supports at most 5 stop sequence.
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ impl TextGeneration for LlamaEngine {
|
|||
|
||||
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
|
||||
engine.as_mut().start(input_token_ids);
|
||||
let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.stop_words);
|
||||
let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.language);
|
||||
let mut n_remains = options.max_decoding_length ;
|
||||
while n_remains > 0 {
|
||||
let Ok(next_token_id) = engine.as_mut().step() else {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,53 @@
|
|||
[[config]]
|
||||
languages = ["python"]
|
||||
line_comment = "#"
|
||||
top_level_keywords = ["def", "from", "class", "import"]
|
||||
|
||||
[[config]]
languages = ["rust"]
line_comment = "//"
# Each keyword is turned into a "\n<keyword>" stop word, so every keyword is
# listed exactly once ("trait" used to appear twice).
top_level_keywords = [
    "fn",
    "trait",
    "impl",
    "enum",
    "pub",
    "extern",
    "static",
    "unsafe",
    "use",
]
|
||||
|
||||
[[config]]
|
||||
languages = ["javascript", "typescript", "javascriptreact", "typescriptreact"]
|
||||
line_comment = "//"
|
||||
top_level_keywords = [
|
||||
"abstract",
|
||||
"async",
|
||||
"class",
|
||||
"const",
|
||||
"export",
|
||||
"function",
|
||||
"interface",
|
||||
"module",
|
||||
"package",
|
||||
"type",
|
||||
"var",
|
||||
"enum",
|
||||
"let",
|
||||
]
|
||||
|
||||
[[config]]
|
||||
languages = ["go"]
|
||||
line_comment = "//"
|
||||
top_level_keywords = [
|
||||
"func",
|
||||
"interface",
|
||||
"struct",
|
||||
"package",
|
||||
"type",
|
||||
"import",
|
||||
"var",
|
||||
"const",
|
||||
]
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
use lazy_static::lazy_static;
|
||||
use serde::Deserialize;
|
||||
|
||||
lazy_static! {
    /// Stop sequences shared by every language: a blank line, optionally
    /// followed by leading indentation made of spaces or tabs.
    ///
    /// `Language::get_stop_words` appends these after the language-specific
    /// entries.
    // NOTE(review): in the rendered source all seven space-indented entries
    // were the identical string "\n\n " (whitespace runs collapsed), which
    // makes six of them dead duplicates. They are restored here as a 2-space
    // ladder mirroring the tab ladder below — confirm the step width against
    // the repository.
    static ref DEFAULT: Vec<&'static str> = vec![
        "\n\n",
        "\n\n  ",
        "\n\n    ",
        "\n\n      ",
        "\n\n        ",
        "\n\n          ",
        "\n\n            ",
        "\n\n              ",
        "\n\n\t",
        "\n\n\t\t",
        "\n\n\t\t\t",
        "\n\n\t\t\t\t",
        "\n\n\t\t\t\t\t",
        "\n\n\t\t\t\t\t\t",
        "\n\n\t\t\t\t\t\t\t",
    ];
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ConfigList {
|
||||
config: Vec<Language>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct Language {
|
||||
languages: Vec<String>,
|
||||
top_level_keywords: Vec<String>,
|
||||
|
||||
pub line_comment: String,
|
||||
}
|
||||
|
||||
impl Language {
|
||||
pub fn get_stop_words(&self) -> Vec<String> {
|
||||
let mut out = vec![];
|
||||
out.push(format!("\n{}", self.line_comment));
|
||||
for word in &self.top_level_keywords {
|
||||
out.push(format!("\n{}", word));
|
||||
}
|
||||
|
||||
for x in DEFAULT.iter() {
|
||||
out.push((*x).to_owned());
|
||||
}
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
pub fn get_hashkey(&self) -> String {
|
||||
self.languages[0].clone()
|
||||
}
|
||||
}
|
||||
|
||||
lazy_static! {
    /// Language table parsed on first use from the TOML asset that is
    /// embedded into the binary at compile time.
    static ref CONFIG: ConfigList =
        serdeconv::from_toml_str(include_str!("../assets/languages.toml"))
            // The asset is compiled in, so a parse failure is a packaging bug
            // rather than a runtime condition; say so in the panic message.
            .expect("embedded assets/languages.toml must deserialize into ConfigList");
    /// Fallback entry returned by `get_language` when no configured entry
    /// matches: no top-level keywords and an empty line-comment marker.
    pub static ref UNKNOWN_LANGUAGE: Language = Language {
        languages: vec!["unknown".to_owned()],
        line_comment: String::new(),
        top_level_keywords: vec![],
    };
}
|
||||
|
||||
pub fn get_language(language: &str) -> &'static Language {
|
||||
CONFIG
|
||||
.config
|
||||
.iter()
|
||||
.find(|c| c.languages.iter().any(|x| x == language))
|
||||
.unwrap_or(&UNKNOWN_LANGUAGE)
|
||||
}
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
pub mod config;
|
||||
pub mod events;
|
||||
pub mod index;
|
||||
pub mod languages;
|
||||
pub mod path;
|
||||
pub mod usage;
|
||||
|
||||
|
|
|
|||
|
|
@ -13,3 +13,4 @@ derive_builder = "0.12.0"
|
|||
futures = { workspace = true }
|
||||
regex.workspace = true
|
||||
tokenizers.workspace = true
|
||||
tabby-common = { path = "../tabby-common" }
|
||||
|
|
@ -2,10 +2,11 @@ use std::sync::Arc;
|
|||
|
||||
use dashmap::DashMap;
|
||||
use regex::Regex;
|
||||
use tabby_common::languages::Language;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
|
||||
pub struct DecodingFactory {
|
||||
stop_regex_cache: DashMap<&'static [&'static str], Regex>,
|
||||
stop_regex_cache: DashMap<String, Regex>,
|
||||
}
|
||||
|
||||
fn reverse<T>(s: T) -> String
|
||||
|
|
@ -28,32 +29,34 @@ impl DecodingFactory {
|
|||
&self,
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
input_token_ids: &[u32],
|
||||
stop_words: &'static [&'static str],
|
||||
language: &'static Language,
|
||||
) -> IncrementalDecoding {
|
||||
IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
|
||||
IncrementalDecoding::new(tokenizer, self.get_re(language), input_token_ids)
|
||||
}
|
||||
|
||||
fn get_re(&self, stop_words: &'static [&'static str]) -> Option<Regex> {
|
||||
fn get_re(&self, language: &'static Language) -> Option<Regex> {
|
||||
let stop_words = language.get_stop_words();
|
||||
if stop_words.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let mut re = self.stop_regex_cache.get(stop_words);
|
||||
let hashkey = language.get_hashkey();
|
||||
let mut re = self.stop_regex_cache.get(&hashkey);
|
||||
if re.is_none() {
|
||||
self.stop_regex_cache
|
||||
.insert(stop_words, create_stop_regex(stop_words));
|
||||
re = self.stop_regex_cache.get(stop_words);
|
||||
.insert(hashkey.clone(), create_stop_regex(stop_words));
|
||||
re = self.stop_regex_cache.get(&hashkey);
|
||||
}
|
||||
re.map(|x| x.value().clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn create_stop_regex(stop_words: &[&str]) -> Regex {
|
||||
fn create_stop_regex(stop_words: Vec<String>) -> Regex {
|
||||
// (?m) enables multi-line matching mode.
|
||||
// \A means absolute begins of string.
|
||||
let reversed_stop_words: Vec<_> = stop_words
|
||||
.iter()
|
||||
.map(|x| regex::escape(&reverse(*x)))
|
||||
.map(|x| regex::escape(&reverse(x)))
|
||||
.collect();
|
||||
let regex_string = r"(?m)\A".to_owned() + "((" + &reversed_stop_words.join(")|(") + "))";
|
||||
Regex::new(&regex_string).expect("Failed to create regex")
|
||||
|
|
@ -131,7 +134,12 @@ mod tests {
|
|||
#[test]
|
||||
fn test_it_works() {
|
||||
let text = reverse("void write_u32(std::uint32_t val) const {\n write_raw(&val, sizeof(val));\n }\n\n ~llama_file() {\n if (fp) {\n std::fclose(fp);\n }\n }\n};\n\nvoid");
|
||||
assert!(!create_stop_regex(&["\n\n", "\n\n "]).is_match(&text));
|
||||
assert!(create_stop_regex(&["\n\n", "\n\n ", "\nvoid"]).is_match(&text));
|
||||
assert!(!create_stop_regex(vec!["\n\n".to_owned(), "\n\n ".to_owned()]).is_match(&text));
|
||||
assert!(create_stop_regex(vec![
|
||||
"\n\n".to_owned(),
|
||||
"\n\n ".to_owned(),
|
||||
"\nvoid".to_owned()
|
||||
])
|
||||
.is_match(&text));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ pub mod decoding;
|
|||
use async_trait::async_trait;
|
||||
use derive_builder::Builder;
|
||||
use futures::stream::BoxStream;
|
||||
use tabby_common::languages::Language;
|
||||
|
||||
#[derive(Builder, Debug)]
|
||||
pub struct TextGenerationOptions {
|
||||
|
|
@ -15,12 +16,10 @@ pub struct TextGenerationOptions {
|
|||
#[builder(default = "1.0")]
|
||||
pub sampling_temperature: f32,
|
||||
|
||||
#[builder(default = "&EMPTY_STOP_WORDS")]
|
||||
pub stop_words: &'static [&'static str],
|
||||
#[builder(default = "&tabby_common::languages::UNKNOWN_LANGUAGE")]
|
||||
pub language: &'static Language,
|
||||
}
|
||||
|
||||
static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];
|
||||
|
||||
#[async_trait]
|
||||
pub trait TextGeneration: Sync + Send {
|
||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String;
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
mod languages;
|
||||
mod prompt;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
|
@ -6,12 +5,11 @@ use std::sync::Arc;
|
|||
use axum::{extract::State, Json};
|
||||
use hyper::StatusCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tabby_common::events;
|
||||
use tabby_common::{events, languages::get_language};
|
||||
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
|
||||
use tracing::{debug, instrument};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use self::languages::get_language;
|
||||
use super::search::IndexServer;
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
|
|
@ -112,7 +110,7 @@ pub async fn completions(
|
|||
.max_input_length(1024 + 512)
|
||||
.max_decoding_length(128)
|
||||
.sampling_temperature(0.1)
|
||||
.stop_words(get_language(&language).stop_words)
|
||||
.language(get_language(&language))
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
|
|
|
|||
|
|
@ -1,116 +0,0 @@
|
|||
use lazy_static::lazy_static;
|
||||
|
||||
pub struct Language {
|
||||
pub stop_words: &'static [&'static str],
|
||||
pub line_comment: &'static str,
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
static ref DEFAULT: Vec<&'static str> = vec![
|
||||
"\n\n",
|
||||
"\n\n ",
|
||||
"\n\n ",
|
||||
"\n\n ",
|
||||
"\n\n ",
|
||||
"\n\n ",
|
||||
"\n\n ",
|
||||
"\n\n ",
|
||||
"\n\n\t",
|
||||
"\n\n\t\t",
|
||||
"\n\n\t\t\t",
|
||||
"\n\n\t\t\t\t",
|
||||
"\n\n\t\t\t\t\t",
|
||||
"\n\n\t\t\t\t\t\t",
|
||||
"\n\n\t\t\t\t\t\t\t",
|
||||
];
|
||||
static ref UNKONWN: Language = Language {
|
||||
stop_words: &DEFAULT,
|
||||
line_comment: "#"
|
||||
};
|
||||
|
||||
/* Python */
|
||||
static ref PYTHON_STOP_WORDS: Vec<&'static str> =
|
||||
vec!["\ndef", "\n#", "\nfrom", "\nclass", "\nimport"].with_default();
|
||||
static ref PYTHON: Language = Language {
|
||||
stop_words: &PYTHON_STOP_WORDS,
|
||||
line_comment: "#",
|
||||
};
|
||||
|
||||
/* Rust */
|
||||
static ref RUST_STOP_WORDS: Vec<&'static str> = vec![
|
||||
"\n//", "\nfn", "\ntrait", "\nimpl", "\nenum", "\npub", "\nextern", "\nstatic",
|
||||
"\ntrait", "\nunsafe", "\nuse"
|
||||
]
|
||||
.with_default();
|
||||
static ref RUST: Language = Language {
|
||||
stop_words: &RUST_STOP_WORDS,
|
||||
line_comment: "//",
|
||||
};
|
||||
|
||||
/* Javascript / Typescript */
|
||||
static ref JAVASCRIPT_TYPESCRIPT_STOP_WORDS: Vec<&'static str> = vec![
|
||||
"\n//",
|
||||
"\nabstract",
|
||||
"\nasync",
|
||||
"\nclass",
|
||||
"\nconst",
|
||||
"\nexport",
|
||||
"\nfunction",
|
||||
"\ninterface",
|
||||
"\nmodule",
|
||||
"\npackage",
|
||||
"\ntype",
|
||||
"\nvar",
|
||||
"\nenum",
|
||||
"\nlet",
|
||||
]
|
||||
.with_default();
|
||||
static ref JAVASCRIPT_TYPESCRIPT: Language = Language {
|
||||
stop_words: &JAVASCRIPT_TYPESCRIPT_STOP_WORDS,
|
||||
line_comment: "//",
|
||||
};
|
||||
|
||||
/* Golang */
|
||||
static ref GO_STOP_WORDS: Vec<&'static str> = vec![
|
||||
"\n//",
|
||||
"\nfunc",
|
||||
"\ninterface",
|
||||
"\nstruct",
|
||||
"\npackage",
|
||||
"\ntype",
|
||||
"\nimport",
|
||||
"\nvar",
|
||||
"\nconst",
|
||||
]
|
||||
.with_default();
|
||||
static ref GO: Language = Language {
|
||||
stop_words: &GO_STOP_WORDS,
|
||||
line_comment: "//",
|
||||
};
|
||||
}
|
||||
|
||||
pub fn get_language(language: &str) -> &'static Language {
|
||||
if language == "python" {
|
||||
&PYTHON
|
||||
} else if language == "rust" {
|
||||
&RUST
|
||||
} else if language == "javascript" || language == "typescript" {
|
||||
&JAVASCRIPT_TYPESCRIPT
|
||||
} else if language == "go" {
|
||||
&GO
|
||||
} else {
|
||||
&UNKONWN
|
||||
}
|
||||
}
|
||||
|
||||
trait WithDefault {
|
||||
fn with_default(self) -> Self;
|
||||
}
|
||||
|
||||
impl WithDefault for Vec<&'static str> {
|
||||
fn with_default(mut self) -> Self {
|
||||
let mut x = DEFAULT.clone();
|
||||
self.append(&mut x);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
|
@ -3,14 +3,12 @@ use std::sync::Arc;
|
|||
use lazy_static::lazy_static;
|
||||
use regex::Regex;
|
||||
use strfmt::strfmt;
|
||||
use tabby_common::languages::get_language;
|
||||
use textdistance::Algorithm;
|
||||
use tracing::warn;
|
||||
|
||||
use super::{Segments, Snippet};
|
||||
use crate::serve::{
|
||||
completions::languages::get_language,
|
||||
search::{IndexServer, IndexServerError},
|
||||
};
|
||||
use crate::serve::search::{IndexServer, IndexServerError};
|
||||
|
||||
static MAX_SNIPPETS_TO_FETCH: usize = 20;
|
||||
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768;
|
||||
|
|
@ -78,7 +76,7 @@ fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String {
|
|||
return prefix.to_owned();
|
||||
}
|
||||
|
||||
let comment_char = get_language(language).line_comment;
|
||||
let comment_char = &get_language(language).line_comment;
|
||||
let mut lines: Vec<String> = vec![];
|
||||
|
||||
for (i, snippet) in snippets.iter().enumerate() {
|
||||
|
|
|
|||
Loading…
Reference in New Issue