refactor: make language optional in TextGenerationOptions (#897)

* refactor: make language optional in TextGenerationOptions

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
add-prompt-lookup
Meng Zhang 2023-11-26 11:17:31 +08:00 committed by GitHub
parent 39962c79ca
commit a7202318b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 20 additions and 32 deletions

View File

@ -66,14 +66,17 @@ impl VertexAIEngine {
#[async_trait]
impl TextGeneration for VertexAIEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let stop_sequences: Vec<String> = options
.language
.get_stop_words()
.iter()
.map(|x| x.to_string())
// Vertex AI supports at most 5 stop sequences.
.take(5)
.collect();
let stop_sequences = if let Some(language) = options.language {
language
.get_stop_words()
.iter()
.map(|x| x.to_string())
// Vertex AI supports at most 5 stop sequences.
.take(5)
.collect()
} else {
vec![]
};
let tokens: Vec<&str> = prompt.split("<MID>").collect();
let request = Request {

View File

@ -36,11 +36,6 @@ pub struct Language {
impl Language {
pub fn get_stop_words(&self) -> Vec<String> {
// Special handling for empty languages - returns empty stop words.
if self.get_hashkey() == "empty" {
return vec![];
}
let mut out = vec![];
out.push(format!("\n{}", self.line_comment));
for word in &self.top_level_keywords {
@ -67,11 +62,6 @@ lazy_static! {
line_comment: "".to_owned(),
top_level_keywords: vec![],
};
pub static ref EMPTY_LANGUAGE: Language = Language {
languages: vec!["empty".to_owned()],
line_comment: "".to_owned(),
top_level_keywords: vec![],
};
}
pub fn get_language(language: &str) -> &'static Language {
@ -81,10 +71,3 @@ pub fn get_language(language: &str) -> &'static Language {
.find(|c| c.languages.iter().any(|x| x == language))
.unwrap_or(&UNKNOWN_LANGUAGE)
}
// Unit tests for this module; gated so the module only exists in test builds.
#[cfg(test)]
mod tests {
    /// `EMPTY_LANGUAGE` must yield no stop words at all.
    #[test]
    fn test_empty_language() {
        assert!(super::EMPTY_LANGUAGE.get_stop_words().is_empty())
    }
}

View File

@ -22,8 +22,12 @@ impl Default for StopConditionFactory {
}
impl StopConditionFactory {
pub fn create(&self, text: &str, language: &'static Language) -> StopCondition {
StopCondition::new(self.get_re(language), text)
/// Builds the [`StopCondition`] for `text`.
///
/// When a language is given, the stop regex is whatever `get_re` returns for it;
/// when no language is given, the condition is created without a regex.
pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition {
    // `and_then` flattens "no language" and "no regex for this language" into one `None`.
    let stop_re = language.and_then(|lang| self.get_re(lang));
    StopCondition::new(stop_re, text)
}
fn get_re(&self, language: &'static Language) -> Option<Regex> {

View File

@ -16,8 +16,8 @@ pub struct TextGenerationOptions {
#[builder(default = "1.0")]
pub sampling_temperature: f32,
#[builder(default = "&tabby_common::languages::UNKNOWN_LANGUAGE")]
pub language: &'static Language,
#[builder(default = "None")]
pub language: Option<&'static Language>,
}
#[async_trait]

View File

@ -6,7 +6,6 @@ use async_stream::stream;
use chat_prompt::ChatPromptBuilder;
use futures::stream::BoxStream;
use serde::{Deserialize, Serialize};
use tabby_common::languages::EMPTY_LANGUAGE;
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
use tracing::debug;
use utoipa::ToSchema;
@ -54,7 +53,6 @@ impl ChatService {
TextGenerationOptionsBuilder::default()
.max_input_length(2048)
.max_decoding_length(1920)
.language(&EMPTY_LANGUAGE)
.sampling_temperature(0.1)
.build()
.unwrap()

View File

@ -201,7 +201,7 @@ impl CompletionService {
.max_input_length(1024 + 512)
.max_decoding_length(128)
.sampling_temperature(0.1)
.language(get_language(language))
.language(Some(get_language(language)))
.build()
.unwrap()
}