feat: add stop words encoding offset for ctranslate model config (#371)

* feat: add stop words encoding offset for ctranslate model config

* feat: set default suffix to \n

* add special treatment for bytefallback tokens
release-0.0
Meng Zhang 2023-08-28 14:07:01 +08:00 committed by GitHub
parent 4da95892ba
commit 65836ee199
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 37 additions and 9 deletions

View File

@ -63,6 +63,8 @@ pub struct CTranslate2EngineOptions {
num_replicas_per_device: usize,
compute_type: String,
stop_words_encoding_offset: Option<usize>,
}
pub struct InferenceContext {
@ -85,6 +87,7 @@ pub struct CTranslate2Engine {
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
tokenizer: Tokenizer,
stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
stop_words_encoding_offset: Option<usize>,
}
impl CTranslate2Engine {
@ -102,6 +105,7 @@ impl CTranslate2Engine {
engine,
stop_regex_cache: DashMap::new(),
tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(),
stop_words_encoding_offset: options.stop_words_encoding_offset,
};
}
}
@ -123,7 +127,11 @@ impl TextGeneration for CTranslate2Engine {
if re.is_none() {
self.stop_regex_cache.insert(
options.stop_words,
create_stop_regex(&self.tokenizer, options.stop_words),
create_stop_regex(
&self.tokenizer,
options.stop_words,
self.stop_words_encoding_offset,
),
);
re = self.stop_regex_cache.get(options.stop_words);
}
@ -156,7 +164,7 @@ fn inference_callback(
if context.cancel.is_cancelled() {
true
} else if let Some(re) = &context.stop_re {
let mut new_token = reverse(token);
let mut new_token = reverse(&token);
new_token.push_str(&context.reversed_output_text);
context.reversed_output_text = new_token;
re.find(&context.reversed_output_text).is_some()
@ -165,19 +173,37 @@ fn inference_callback(
}
}
fn reverse(s: String) -> String {
s.chars().rev().collect()
/// Reverses a token string for efficient suffix matching against the
/// accumulated (reversed) output text.
///
/// Accepts anything string-like (`&str`, `String`, `&String`, ...) instead of
/// the non-idiomatic `&String` (clippy `ptr_arg`); this is strictly more
/// general, so existing callers (`reverse(&token)`, `.map(reverse)` over
/// `&String` items) keep working unchanged.
fn reverse(s: impl AsRef<str>) -> String {
    let s = s.as_ref();
    // Special treatment for byte fallback tokens such as "<0x0A>".
    // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/byte_fallback.rs
    //
    // In addition to the length/prefix/suffix shape check, verify the two
    // payload bytes are ASCII hex digits; otherwise a 6-byte lookalike
    // (e.g. "<0x¢>") would be misclassified and skipped from reversal.
    // Byte indexing is safe here: "<0x" guarantees bytes 0..3 are ASCII.
    let bytes = s.as_bytes();
    let is_byte_fallback = bytes.len() == 6
        && s.starts_with("<0x")
        && s.ends_with('>')
        && bytes[3].is_ascii_hexdigit()
        && bytes[4].is_ascii_hexdigit();
    if is_byte_fallback {
        // Keep byte fallback tokens like <0x0A> as is, do not reverse them.
        // This won't really affect stop words regex logic, but brings more
        // readability when debugging decoding steps.
        s.to_owned()
    } else {
        s.chars().rev().collect()
    }
}
fn create_stop_regex(tokenizer: &Tokenizer, stop_words: &[&str]) -> Regex {
fn create_stop_regex(
tokenizer: &Tokenizer,
stop_words: &[&str],
stop_words_encoding_offset: Option<usize>,
) -> Regex {
let encodings = tokenizer
.encode_batch(stop_words.to_owned(), false)
.unwrap();
let stop_tokens: Vec<String> = encodings
.iter()
.map(|x| x.get_tokens().join(""))
// Reverse for efficient suffix matching.
.map(reverse)
.map(|x| {
x.get_tokens()[stop_words_encoding_offset.unwrap_or(0)..]
.iter()
.rev()
.map(reverse)
.collect::<Vec<String>>()
.join("")
})
.collect();
// (?m) enables multi-line matching mode.

View File

@ -140,6 +140,7 @@ impl CompletionState {
.device_indices(args.device_indices.clone())
.num_replicas_per_device(args.num_replicas_per_device)
.compute_type(compute_type)
.stop_words_encoding_offset(metadata.stop_words_encoding_offset)
.build()
.unwrap();
let engine = CTranslate2Engine::create(options);
@ -165,6 +166,7 @@ fn get_model_dir(model: &str) -> ModelDir {
struct Metadata {
auto_model: String,
prompt_template: Option<String>,
stop_words_encoding_offset: Option<usize>,
}
fn read_metadata(model_dir: &ModelDir) -> Metadata {

View File

@ -52,7 +52,7 @@ impl PromptBuilder {
if let Some(suffix) = segments.suffix {
self.build_prompt(segments.prefix, suffix)
} else {
self.build_prompt(segments.prefix, "".to_owned())
self.build_prompt(segments.prefix, "\n".to_owned())
}
}