feat: implement input truncation with options.max_input_length (#415)
parent
acfdba68fa
commit
87b6b34120
|
|
@ -124,7 +124,7 @@ impl TextGeneration for CTranslate2Engine {
|
|||
engine.inference(
|
||||
context,
|
||||
inference_callback,
|
||||
encoding.get_tokens(),
|
||||
truncate_tokens(encoding.get_tokens(), options.max_input_length),
|
||||
options.max_decoding_length,
|
||||
options.sampling_temperature,
|
||||
)
|
||||
|
|
@ -135,6 +135,11 @@ impl TextGeneration for CTranslate2Engine {
|
|||
}
|
||||
}
|
||||
|
||||
fn truncate_tokens(tokens: &[String], max_length: usize) -> &[String] {
|
||||
let start = std::cmp::max(tokens.len() - max_length, 0);
|
||||
&tokens[start..]
|
||||
}
|
||||
|
||||
fn inference_callback(
|
||||
context: &mut InferenceContext,
|
||||
_step: usize,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,9 @@ use derive_builder::Builder;
|
|||
|
||||
#[derive(Builder, Debug)]
|
||||
pub struct TextGenerationOptions {
|
||||
#[builder(default = "1024")]
|
||||
pub max_input_length: usize,
|
||||
|
||||
#[builder(default = "256")]
|
||||
pub max_decoding_length: usize,
|
||||
|
||||
|
|
|
|||
|
|
@ -79,6 +79,7 @@ pub async fn completion(
|
|||
) -> Result<Json<CompletionResponse>, StatusCode> {
|
||||
let language = request.language.unwrap_or("unknown".to_string());
|
||||
let options = TextGenerationOptionsBuilder::default()
|
||||
.max_input_length(1024 + 512)
|
||||
.max_decoding_length(128)
|
||||
.sampling_temperature(0.1)
|
||||
.stop_words(get_stop_words(&language))
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ use super::Segments;
|
|||
|
||||
static MAX_SNIPPETS_TO_FETCH: usize = 20;
|
||||
static MAX_SNIPPET_PER_NAME: u32 = 1;
|
||||
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 1024;
|
||||
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 512;
|
||||
|
||||
pub struct PromptBuilder {
|
||||
prompt_template: Option<String>,
|
||||
|
|
|
|||
Loading…
Reference in New Issue