refactor: handle max output length in StopCondition (#910)
* refactor: handle max output length in StopCondition * trim stop words * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>add-prompt-lookup
parent
c049f23a0c
commit
2b131ad1d2
|
|
@ -72,8 +72,19 @@ impl LlamaTextGeneration {
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl TextGeneration for LlamaTextGeneration {
|
impl TextGeneration for LlamaTextGeneration {
|
||||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||||
|
let language = options.language;
|
||||||
let s = self.generate_stream(prompt, options).await;
|
let s = self.generate_stream(prompt, options).await;
|
||||||
helpers::stream_to_string(s).await
|
let text = helpers::stream_to_string(s).await;
|
||||||
|
|
||||||
|
let Some(language) = language else {
|
||||||
|
return text;
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(trimmed) = self.stop_condition_factory.trim_stop_words(language, &text) else {
|
||||||
|
return text;
|
||||||
|
};
|
||||||
|
|
||||||
|
trimmed
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn generate_stream(
|
async fn generate_stream(
|
||||||
|
|
@ -81,7 +92,11 @@ impl TextGeneration for LlamaTextGeneration {
|
||||||
prompt: &str,
|
prompt: &str,
|
||||||
options: TextGenerationOptions,
|
options: TextGenerationOptions,
|
||||||
) -> BoxStream<String> {
|
) -> BoxStream<String> {
|
||||||
let stop_condition = self.stop_condition_factory.create(prompt, options.language);
|
let stop_condition = self.stop_condition_factory.create(
|
||||||
|
prompt,
|
||||||
|
options.max_decoding_length,
|
||||||
|
options.language,
|
||||||
|
);
|
||||||
|
|
||||||
let mut rx = self
|
let mut rx = self
|
||||||
.service
|
.service
|
||||||
|
|
@ -89,13 +104,8 @@ impl TextGeneration for LlamaTextGeneration {
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let s = stream! {
|
let s = stream! {
|
||||||
let mut length = 0;
|
|
||||||
while let Some(new_text) = rx.recv().await {
|
while let Some(new_text) = rx.recv().await {
|
||||||
yield new_text;
|
yield new_text;
|
||||||
length += 1;
|
|
||||||
if length >= options.max_decoding_length {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rx.close();
|
rx.close();
|
||||||
|
|
|
||||||
|
|
@ -87,14 +87,13 @@ impl LlamaServiceImpl {
|
||||||
if tx.is_closed() || text.is_empty() {
|
if tx.is_closed() || text.is_empty() {
|
||||||
// Cancelled by client side or hit eos.
|
// Cancelled by client side or hit eos.
|
||||||
stopped = true;
|
stopped = true;
|
||||||
} else if !stop_condition.should_stop(&text) {
|
} else {
|
||||||
|
stopped = stop_condition.should_stop(&text);
|
||||||
|
|
||||||
match tx.send(text).await {
|
match tx.send(text).await {
|
||||||
Ok(_) => (),
|
Ok(_) => (),
|
||||||
Err(_) => stopped = true,
|
Err(_) => stopped = true,
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// Stoop words stopped
|
|
||||||
stopped = true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if stopped {
|
if stopped {
|
||||||
|
|
|
||||||
|
|
@ -22,11 +22,16 @@ impl Default for StopConditionFactory {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StopConditionFactory {
|
impl StopConditionFactory {
|
||||||
pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition {
|
pub fn create(
|
||||||
|
&self,
|
||||||
|
text: &str,
|
||||||
|
max_decoding_length: usize,
|
||||||
|
language: Option<&'static Language>,
|
||||||
|
) -> StopCondition {
|
||||||
if let Some(language) = language {
|
if let Some(language) = language {
|
||||||
StopCondition::new(self.get_re(language), text)
|
StopCondition::new(self.get_re(language), max_decoding_length, text)
|
||||||
} else {
|
} else {
|
||||||
StopCondition::new(None, text)
|
StopCondition::new(None, max_decoding_length, text)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -45,6 +50,22 @@ impl StopConditionFactory {
|
||||||
re.map(|x| x.value().clone())
|
re.map(|x| x.value().clone())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn trim_stop_words(&self, language: &'static Language, text: &str) -> Option<String> {
|
||||||
|
let Some(re) = self.get_re(language) else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
let text = reverse(text);
|
||||||
|
|
||||||
|
let text = if let Some(m) = re.find_at(&text, 0) {
|
||||||
|
&text[m.end()..]
|
||||||
|
} else {
|
||||||
|
&text
|
||||||
|
};
|
||||||
|
|
||||||
|
Some(reverse(text))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_stop_regex(stop_words: Vec<String>) -> Regex {
|
fn create_stop_regex(stop_words: Vec<String>) -> Regex {
|
||||||
|
|
@ -60,14 +81,18 @@ fn create_stop_regex(stop_words: Vec<String>) -> Regex {
|
||||||
|
|
||||||
pub struct StopCondition {
|
pub struct StopCondition {
|
||||||
stop_re: Option<Regex>,
|
stop_re: Option<Regex>,
|
||||||
|
max_decoding_length: usize,
|
||||||
reversed_text: String,
|
reversed_text: String,
|
||||||
|
num_decoded: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StopCondition {
|
impl StopCondition {
|
||||||
pub fn new(stop_re: Option<Regex>, text: &str) -> Self {
|
pub fn new(stop_re: Option<Regex>, max_decoding_length: usize, text: &str) -> Self {
|
||||||
Self {
|
Self {
|
||||||
stop_re,
|
stop_re,
|
||||||
|
max_decoding_length,
|
||||||
reversed_text: reverse(text),
|
reversed_text: reverse(text),
|
||||||
|
num_decoded: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -82,7 +107,8 @@ impl StopCondition {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
false
|
self.num_decoded += 1;
|
||||||
|
self.num_decoded >= self.max_decoding_length
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue