From 87b6b34120706ee050072c9687a2ea7cbfba824e Mon Sep 17 00:00:00 2001
From: Meng Zhang
Date: Fri, 8 Sep 2023 18:01:03 +0800
Subject: [PATCH] feat: implement input truncation with options.max_input_length (#415)

---
 crates/ctranslate2-bindings/src/lib.rs       | 13 ++++++++++++-
 crates/tabby-inference/src/lib.rs            |  3 +++
 crates/tabby/src/serve/completions.rs        |  1 +
 crates/tabby/src/serve/completions/prompt.rs |  2 +-
 4 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs
index ab49996..7afb7f9 100644
--- a/crates/ctranslate2-bindings/src/lib.rs
+++ b/crates/ctranslate2-bindings/src/lib.rs
@@ -124,7 +124,7 @@ impl TextGeneration for CTranslate2Engine {
             engine.inference(
                 context,
                 inference_callback,
-                encoding.get_tokens(),
+                truncate_tokens(encoding.get_tokens(), options.max_input_length),
                 options.max_decoding_length,
                 options.sampling_temperature,
             )
@@ -135,6 +135,17 @@ impl TextGeneration for CTranslate2Engine {
     }
 }
 
+/// Returns the last `max_length` entries of `tokens`.
+///
+/// When `tokens` holds `max_length` entries or fewer, the whole slice is
+/// returned; otherwise entries are dropped from the front.
+fn truncate_tokens(tokens: &[String], max_length: usize) -> &[String] {
+    // `usize` subtraction underflows when `max_length > tokens.len()`, so
+    // clamp the start index at zero with `saturating_sub`.
+    let start = tokens.len().saturating_sub(max_length);
+    &tokens[start..]
+}
+
 fn inference_callback(
     context: &mut InferenceContext,
     _step: usize,
diff --git a/crates/tabby-inference/src/lib.rs b/crates/tabby-inference/src/lib.rs
index 4d1befa..c822034 100644
--- a/crates/tabby-inference/src/lib.rs
+++ b/crates/tabby-inference/src/lib.rs
@@ -3,6 +3,9 @@ use derive_builder::Builder;
 
 #[derive(Builder, Debug)]
 pub struct TextGenerationOptions {
+    #[builder(default = "1024")]
+    pub max_input_length: usize,
+
     #[builder(default = "256")]
     pub max_decoding_length: usize,
 
diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs
index 10755ac..64255d7 100644
--- a/crates/tabby/src/serve/completions.rs
+++ b/crates/tabby/src/serve/completions.rs
@@ -79,6 +79,7 @@ pub async fn completion(
 ) -> Result, StatusCode> {
     let language = request.language.unwrap_or("unknown".to_string());
     let options = TextGenerationOptionsBuilder::default()
+        .max_input_length(1024 + 512)
         .max_decoding_length(128)
         .sampling_temperature(0.1)
         .stop_words(get_stop_words(&language))
diff --git a/crates/tabby/src/serve/completions/prompt.rs b/crates/tabby/src/serve/completions/prompt.rs
index 5b94f7d..998bff2 100644
--- a/crates/tabby/src/serve/completions/prompt.rs
+++ b/crates/tabby/src/serve/completions/prompt.rs
@@ -13,7 +13,7 @@ use super::Segments;
 
 static MAX_SNIPPETS_TO_FETCH: usize = 20;
 static MAX_SNIPPET_PER_NAME: u32 = 1;
-static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 1024;
+static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 512;
 
 pub struct PromptBuilder {
     prompt_template: Option,