refactor: make language optional in TextGenerationOptions (#897)

* refactor: make language optional in TextGenerationOptions

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
add-prompt-lookup
Meng Zhang 2023-11-26 11:17:31 +08:00 committed by GitHub
parent 39962c79ca
commit a7202318b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 20 additions and 32 deletions

View File

@ -66,14 +66,17 @@ impl VertexAIEngine {
#[async_trait] #[async_trait]
impl TextGeneration for VertexAIEngine { impl TextGeneration for VertexAIEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let stop_sequences: Vec<String> = options let stop_sequences = if let Some(language) = options.language {
.language language
.get_stop_words() .get_stop_words()
.iter() .iter()
.map(|x| x.to_string()) .map(|x| x.to_string())
// vertex supports at most 5 stop sequences. // vertex supports at most 5 stop sequences.
.take(5) .take(5)
.collect(); .collect()
} else {
vec![]
};
let tokens: Vec<&str> = prompt.split("<MID>").collect(); let tokens: Vec<&str> = prompt.split("<MID>").collect();
let request = Request { let request = Request {

View File

@ -36,11 +36,6 @@ pub struct Language {
impl Language { impl Language {
pub fn get_stop_words(&self) -> Vec<String> { pub fn get_stop_words(&self) -> Vec<String> {
// Special handling for empty languages - returns empty stop words.
if self.get_hashkey() == "empty" {
return vec![];
}
let mut out = vec![]; let mut out = vec![];
out.push(format!("\n{}", self.line_comment)); out.push(format!("\n{}", self.line_comment));
for word in &self.top_level_keywords { for word in &self.top_level_keywords {
@ -67,11 +62,6 @@ lazy_static! {
line_comment: "".to_owned(), line_comment: "".to_owned(),
top_level_keywords: vec![], top_level_keywords: vec![],
}; };
pub static ref EMPTY_LANGUAGE: Language = Language {
languages: vec!["empty".to_owned()],
line_comment: "".to_owned(),
top_level_keywords: vec![],
};
} }
pub fn get_language(language: &str) -> &'static Language { pub fn get_language(language: &str) -> &'static Language {
@ -81,10 +71,3 @@ pub fn get_language(language: &str) -> &'static Language {
.find(|c| c.languages.iter().any(|x| x == language)) .find(|c| c.languages.iter().any(|x| x == language))
.unwrap_or(&UNKNOWN_LANGUAGE) .unwrap_or(&UNKNOWN_LANGUAGE)
} }
mod tests {
#[test]
fn test_empty_language() {
assert!(super::EMPTY_LANGUAGE.get_stop_words().is_empty())
}
}

View File

@ -22,8 +22,12 @@ impl Default for StopConditionFactory {
} }
impl StopConditionFactory { impl StopConditionFactory {
pub fn create(&self, text: &str, language: &'static Language) -> StopCondition { pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition {
StopCondition::new(self.get_re(language), text) if let Some(language) = language {
StopCondition::new(self.get_re(language), text)
} else {
StopCondition::new(None, text)
}
} }
fn get_re(&self, language: &'static Language) -> Option<Regex> { fn get_re(&self, language: &'static Language) -> Option<Regex> {

View File

@ -16,8 +16,8 @@ pub struct TextGenerationOptions {
#[builder(default = "1.0")] #[builder(default = "1.0")]
pub sampling_temperature: f32, pub sampling_temperature: f32,
#[builder(default = "&tabby_common::languages::UNKNOWN_LANGUAGE")] #[builder(default = "None")]
pub language: &'static Language, pub language: Option<&'static Language>,
} }
#[async_trait] #[async_trait]

View File

@ -6,7 +6,6 @@ use async_stream::stream;
use chat_prompt::ChatPromptBuilder; use chat_prompt::ChatPromptBuilder;
use futures::stream::BoxStream; use futures::stream::BoxStream;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tabby_common::languages::EMPTY_LANGUAGE;
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder}; use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
use tracing::debug; use tracing::debug;
use utoipa::ToSchema; use utoipa::ToSchema;
@ -54,7 +53,6 @@ impl ChatService {
TextGenerationOptionsBuilder::default() TextGenerationOptionsBuilder::default()
.max_input_length(2048) .max_input_length(2048)
.max_decoding_length(1920) .max_decoding_length(1920)
.language(&EMPTY_LANGUAGE)
.sampling_temperature(0.1) .sampling_temperature(0.1)
.build() .build()
.unwrap() .unwrap()

View File

@ -201,7 +201,7 @@ impl CompletionService {
.max_input_length(1024 + 512) .max_input_length(1024 + 512)
.max_decoding_length(128) .max_decoding_length(128)
.sampling_temperature(0.1) .sampling_temperature(0.1)
.language(get_language(language)) .language(Some(get_language(language)))
.build() .build()
.unwrap() .unwrap()
} }