feat: implement input truncation with options.max_input_length (#415)

release-0.2
Meng Zhang 2023-09-08 18:01:03 +08:00 committed by GitHub
parent acfdba68fa
commit 87b6b34120
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 11 additions and 2 deletions

View File

@ -124,7 +124,7 @@ impl TextGeneration for CTranslate2Engine {
engine.inference(
context,
inference_callback,
encoding.get_tokens(),
truncate_tokens(encoding.get_tokens(), options.max_input_length),
options.max_decoding_length,
options.sampling_temperature,
)
@ -135,6 +135,11 @@ impl TextGeneration for CTranslate2Engine {
}
}
/// Returns the trailing `max_length` tokens of `tokens`.
///
/// When `tokens` holds `max_length` or fewer entries the whole slice is
/// returned unchanged; otherwise the oldest (leading) tokens are dropped so
/// that only the last `max_length` remain. A `max_length` of zero yields an
/// empty slice.
fn truncate_tokens(tokens: &[String], max_length: usize) -> &[String] {
    // `usize` subtraction underflows when the input is shorter than the
    // limit (panic in debug builds, wrap-around and out-of-bounds slice in
    // release), so clamp the start index at zero with `saturating_sub`.
    let start = tokens.len().saturating_sub(max_length);
    &tokens[start..]
}
fn inference_callback(
context: &mut InferenceContext,
_step: usize,

View File

@ -3,6 +3,9 @@ use derive_builder::Builder;
#[derive(Builder, Debug)]
pub struct TextGenerationOptions {
#[builder(default = "1024")]
pub max_input_length: usize,
#[builder(default = "256")]
pub max_decoding_length: usize,

View File

@ -79,6 +79,7 @@ pub async fn completion(
) -> Result<Json<CompletionResponse>, StatusCode> {
let language = request.language.unwrap_or("unknown".to_string());
let options = TextGenerationOptionsBuilder::default()
.max_input_length(1024 + 512)
.max_decoding_length(128)
.sampling_temperature(0.1)
.stop_words(get_stop_words(&language))

View File

@ -13,7 +13,7 @@ use super::Segments;
static MAX_SNIPPETS_TO_FETCH: usize = 20;
static MAX_SNIPPET_PER_NAME: u32 = 1;
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 1024;
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 512;
pub struct PromptBuilder {
prompt_template: Option<String>,