refactor: make language optional in TextGenerationOptions (#897)
* refactor: make language optional in TextGenerationOptions * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>add-prompt-lookup
parent
39962c79ca
commit
a7202318b1
|
|
@ -66,14 +66,17 @@ impl VertexAIEngine {
|
|||
#[async_trait]
|
||||
impl TextGeneration for VertexAIEngine {
|
||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||
let stop_sequences: Vec<String> = options
|
||||
.language
|
||||
.get_stop_words()
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
// vertex supports at most 5 stop sequence.
|
||||
.take(5)
|
||||
.collect();
|
||||
let stop_sequences = if let Some(language) = options.language {
|
||||
language
|
||||
.get_stop_words()
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
// vertex supports at most 5 stop sequence.
|
||||
.take(5)
|
||||
.collect()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
let tokens: Vec<&str> = prompt.split("<MID>").collect();
|
||||
let request = Request {
|
||||
|
|
|
|||
|
|
@ -36,11 +36,6 @@ pub struct Language {
|
|||
|
||||
impl Language {
|
||||
pub fn get_stop_words(&self) -> Vec<String> {
|
||||
// Special handling for empty languages - returns empty stop words.
|
||||
if self.get_hashkey() == "empty" {
|
||||
return vec![];
|
||||
}
|
||||
|
||||
let mut out = vec![];
|
||||
out.push(format!("\n{}", self.line_comment));
|
||||
for word in &self.top_level_keywords {
|
||||
|
|
@ -67,11 +62,6 @@ lazy_static! {
|
|||
line_comment: "".to_owned(),
|
||||
top_level_keywords: vec![],
|
||||
};
|
||||
pub static ref EMPTY_LANGUAGE: Language = Language {
|
||||
languages: vec!["empty".to_owned()],
|
||||
line_comment: "".to_owned(),
|
||||
top_level_keywords: vec![],
|
||||
};
|
||||
}
|
||||
|
||||
pub fn get_language(language: &str) -> &'static Language {
|
||||
|
|
@ -81,10 +71,3 @@ pub fn get_language(language: &str) -> &'static Language {
|
|||
.find(|c| c.languages.iter().any(|x| x == language))
|
||||
.unwrap_or(&UNKNOWN_LANGUAGE)
|
||||
}
|
||||
|
||||
mod tests {
|
||||
#[test]
|
||||
fn test_empty_language() {
|
||||
assert!(super::EMPTY_LANGUAGE.get_stop_words().is_empty())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,8 +22,12 @@ impl Default for StopConditionFactory {
|
|||
}
|
||||
|
||||
impl StopConditionFactory {
|
||||
pub fn create(&self, text: &str, language: &'static Language) -> StopCondition {
|
||||
StopCondition::new(self.get_re(language), text)
|
||||
pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition {
|
||||
if let Some(language) = language {
|
||||
StopCondition::new(self.get_re(language), text)
|
||||
} else {
|
||||
StopCondition::new(None, text)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_re(&self, language: &'static Language) -> Option<Regex> {
|
||||
|
|
|
|||
|
|
@ -16,8 +16,8 @@ pub struct TextGenerationOptions {
|
|||
#[builder(default = "1.0")]
|
||||
pub sampling_temperature: f32,
|
||||
|
||||
#[builder(default = "&tabby_common::languages::UNKNOWN_LANGUAGE")]
|
||||
pub language: &'static Language,
|
||||
#[builder(default = "None")]
|
||||
pub language: Option<&'static Language>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ use async_stream::stream;
|
|||
use chat_prompt::ChatPromptBuilder;
|
||||
use futures::stream::BoxStream;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tabby_common::languages::EMPTY_LANGUAGE;
|
||||
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
|
||||
use tracing::debug;
|
||||
use utoipa::ToSchema;
|
||||
|
|
@ -54,7 +53,6 @@ impl ChatService {
|
|||
TextGenerationOptionsBuilder::default()
|
||||
.max_input_length(2048)
|
||||
.max_decoding_length(1920)
|
||||
.language(&EMPTY_LANGUAGE)
|
||||
.sampling_temperature(0.1)
|
||||
.build()
|
||||
.unwrap()
|
||||
|
|
|
|||
|
|
@ -201,7 +201,7 @@ impl CompletionService {
|
|||
.max_input_length(1024 + 512)
|
||||
.max_decoding_length(128)
|
||||
.sampling_temperature(0.1)
|
||||
.language(get_language(language))
|
||||
.language(Some(get_language(language)))
|
||||
.build()
|
||||
.unwrap()
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue