diff --git a/crates/http-api-bindings/src/vertex_ai.rs b/crates/http-api-bindings/src/vertex_ai.rs index 89b79f6..04b6402 100644 --- a/crates/http-api-bindings/src/vertex_ai.rs +++ b/crates/http-api-bindings/src/vertex_ai.rs @@ -66,14 +66,17 @@ impl VertexAIEngine { #[async_trait] impl TextGeneration for VertexAIEngine { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { - let stop_sequences: Vec = options - .language - .get_stop_words() - .iter() - .map(|x| x.to_string()) - // vertex supports at most 5 stop sequence. - .take(5) - .collect(); + let stop_sequences = if let Some(language) = options.language { + language + .get_stop_words() + .iter() + .map(|x| x.to_string()) + // vertex supports at most 5 stop sequence. + .take(5) + .collect() + } else { + vec![] + }; let tokens: Vec<&str> = prompt.split("").collect(); let request = Request { diff --git a/crates/tabby-common/src/languages.rs b/crates/tabby-common/src/languages.rs index c6f177c..b5f9c8b 100644 --- a/crates/tabby-common/src/languages.rs +++ b/crates/tabby-common/src/languages.rs @@ -36,11 +36,6 @@ pub struct Language { impl Language { pub fn get_stop_words(&self) -> Vec { - // Special handling for empty languages - returns empty stop words. - if self.get_hashkey() == "empty" { - return vec![]; - } - let mut out = vec![]; out.push(format!("\n{}", self.line_comment)); for word in &self.top_level_keywords { @@ -67,11 +62,6 @@ lazy_static! { line_comment: "".to_owned(), top_level_keywords: vec![], }; - pub static ref EMPTY_LANGUAGE: Language = Language { - languages: vec!["empty".to_owned()], - line_comment: "".to_owned(), - top_level_keywords: vec![], - }; } pub fn get_language(language: &str) -> &'static Language { @@ -81,10 +71,3 @@ pub fn get_language(language: &str) -> &'static Language { .find(|c| c.languages.iter().any(|x| x == language)) .unwrap_or(&UNKNOWN_LANGUAGE) } - -mod tests { - #[test] - fn test_empty_language() { - assert!(super::EMPTY_LANGUAGE.get_stop_words().is_empty()) - } -} diff --git a/crates/tabby-inference/src/decoding.rs b/crates/tabby-inference/src/decoding.rs index 5bf202a..cbbac95 100644 --- a/crates/tabby-inference/src/decoding.rs +++ b/crates/tabby-inference/src/decoding.rs @@ -22,8 +22,12 @@ impl Default for StopConditionFactory { } impl StopConditionFactory { - pub fn create(&self, text: &str, language: &'static Language) -> StopCondition { - StopCondition::new(self.get_re(language), text) + pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition { + if let Some(language) = language { + StopCondition::new(self.get_re(language), text) + } else { + StopCondition::new(None, text) + } } fn get_re(&self, language: &'static Language) -> Option { diff --git a/crates/tabby-inference/src/lib.rs b/crates/tabby-inference/src/lib.rs index 3c3990b..e4052d0 100644 --- a/crates/tabby-inference/src/lib.rs +++ b/crates/tabby-inference/src/lib.rs @@ -16,8 +16,8 @@ pub struct TextGenerationOptions { #[builder(default = "1.0")] pub sampling_temperature: f32, - #[builder(default = "&tabby_common::languages::UNKNOWN_LANGUAGE")] - pub language: &'static Language, + #[builder(default = "None")] + pub language: Option<&'static Language>, } #[async_trait] diff --git a/crates/tabby/src/services/chat.rs b/crates/tabby/src/services/chat.rs index d791496..e81096d 100644 --- a/crates/tabby/src/services/chat.rs +++ b/crates/tabby/src/services/chat.rs @@ -6,7 +6,6 @@ use async_stream::stream; use chat_prompt::ChatPromptBuilder; use futures::stream::BoxStream; use serde::{Deserialize, Serialize}; -use tabby_common::languages::EMPTY_LANGUAGE; use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder}; use tracing::debug; use utoipa::ToSchema; @@ -54,7 +53,6 @@ impl ChatService { TextGenerationOptionsBuilder::default() .max_input_length(2048) .max_decoding_length(1920) - .language(&EMPTY_LANGUAGE) .sampling_temperature(0.1) .build() .unwrap() diff --git a/crates/tabby/src/services/completion.rs b/crates/tabby/src/services/completion.rs index 366b421..153d744 100644 --- a/crates/tabby/src/services/completion.rs +++ b/crates/tabby/src/services/completion.rs @@ -201,7 +201,7 @@ impl CompletionService { .max_input_length(1024 + 512) .max_decoding_length(128) .sampling_temperature(0.1) - .language(get_language(language)) + .language(Some(get_language(language))) .build() .unwrap() }