refactor: handle max output length in StopCondition (#910)

* refactor: handle max output length in StopCondition

* trim stop words

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
add-prompt-lookup
Meng Zhang 2023-11-28 16:57:16 +08:00 committed by GitHub
parent c049f23a0c
commit 2b131ad1d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 51 additions and 16 deletions

View File

@ -72,8 +72,19 @@ impl LlamaTextGeneration {
#[async_trait]
impl TextGeneration for LlamaTextGeneration {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let language = options.language;
let s = self.generate_stream(prompt, options).await;
helpers::stream_to_string(s).await
let text = helpers::stream_to_string(s).await;
let Some(language) = language else {
return text;
};
let Some(trimmed) = self.stop_condition_factory.trim_stop_words(language, &text) else {
return text;
};
trimmed
}
async fn generate_stream(
@ -81,7 +92,11 @@ impl TextGeneration for LlamaTextGeneration {
prompt: &str,
options: TextGenerationOptions,
) -> BoxStream<String> {
let stop_condition = self.stop_condition_factory.create(prompt, options.language);
let stop_condition = self.stop_condition_factory.create(
prompt,
options.max_decoding_length,
options.language,
);
let mut rx = self
.service
@ -89,13 +104,8 @@ impl TextGeneration for LlamaTextGeneration {
.await;
let s = stream! {
let mut length = 0;
while let Some(new_text) = rx.recv().await {
yield new_text;
length += 1;
if length >= options.max_decoding_length {
break;
}
}
rx.close();

View File

@ -87,14 +87,13 @@ impl LlamaServiceImpl {
if tx.is_closed() || text.is_empty() {
// Cancelled by client side or hit eos.
stopped = true;
} else if !stop_condition.should_stop(&text) {
} else {
stopped = stop_condition.should_stop(&text);
match tx.send(text).await {
Ok(_) => (),
Err(_) => stopped = true,
}
} else {
// Stoop words stopped
stopped = true;
}
if stopped {

View File

@ -22,11 +22,16 @@ impl Default for StopConditionFactory {
}
impl StopConditionFactory {
pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition {
pub fn create(
&self,
text: &str,
max_decoding_length: usize,
language: Option<&'static Language>,
) -> StopCondition {
if let Some(language) = language {
StopCondition::new(self.get_re(language), text)
StopCondition::new(self.get_re(language), max_decoding_length, text)
} else {
StopCondition::new(None, text)
StopCondition::new(None, max_decoding_length, text)
}
}
@ -45,6 +50,22 @@ impl StopConditionFactory {
re.map(|x| x.value().clone())
}
}
pub fn trim_stop_words(&self, language: &'static Language, text: &str) -> Option<String> {
let Some(re) = self.get_re(language) else {
return None;
};
let text = reverse(text);
let text = if let Some(m) = re.find_at(&text, 0) {
&text[m.end()..]
} else {
&text
};
Some(reverse(text))
}
}
fn create_stop_regex(stop_words: Vec<String>) -> Regex {
@ -60,14 +81,18 @@ fn create_stop_regex(stop_words: Vec<String>) -> Regex {
pub struct StopCondition {
stop_re: Option<Regex>,
max_decoding_length: usize,
reversed_text: String,
num_decoded: usize,
}
impl StopCondition {
pub fn new(stop_re: Option<Regex>, text: &str) -> Self {
pub fn new(stop_re: Option<Regex>, max_decoding_length: usize, text: &str) -> Self {
Self {
stop_re,
max_decoding_length,
reversed_text: reverse(text),
num_decoded: 0,
}
}
@ -82,7 +107,8 @@ impl StopCondition {
}
}
false
self.num_decoded += 1;
self.num_decoded >= self.max_decoding_length
}
}