feat: add stop words encoding offset for ctranslate model config (#371)
* feat: add stop words encoding offset for ctranslate model config
* feat: set default suffix to \n
* add special treatment for byte fallback tokens
release-0.0
parent
4da95892ba
commit
65836ee199
|
|
@@ -63,6 +63,8 @@ pub struct CTranslate2EngineOptions {
|
||||||
num_replicas_per_device: usize,
|
num_replicas_per_device: usize,
|
||||||
|
|
||||||
compute_type: String,
|
compute_type: String,
|
||||||
|
|
||||||
|
stop_words_encoding_offset: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct InferenceContext {
|
pub struct InferenceContext {
|
||||||
|
|
@@ -85,6 +87,7 @@ pub struct CTranslate2Engine {
|
||||||
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
|
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
|
stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
|
||||||
|
stop_words_encoding_offset: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CTranslate2Engine {
|
impl CTranslate2Engine {
|
||||||
|
|
@@ -102,6 +105,7 @@ impl CTranslate2Engine {
|
||||||
engine,
|
engine,
|
||||||
stop_regex_cache: DashMap::new(),
|
stop_regex_cache: DashMap::new(),
|
||||||
tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(),
|
tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(),
|
||||||
|
stop_words_encoding_offset: options.stop_words_encoding_offset,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@@ -123,7 +127,11 @@ impl TextGeneration for CTranslate2Engine {
|
||||||
if re.is_none() {
|
if re.is_none() {
|
||||||
self.stop_regex_cache.insert(
|
self.stop_regex_cache.insert(
|
||||||
options.stop_words,
|
options.stop_words,
|
||||||
create_stop_regex(&self.tokenizer, options.stop_words),
|
create_stop_regex(
|
||||||
|
&self.tokenizer,
|
||||||
|
options.stop_words,
|
||||||
|
self.stop_words_encoding_offset,
|
||||||
|
),
|
||||||
);
|
);
|
||||||
re = self.stop_regex_cache.get(options.stop_words);
|
re = self.stop_regex_cache.get(options.stop_words);
|
||||||
}
|
}
|
||||||
|
|
@@ -156,7 +164,7 @@ fn inference_callback(
|
||||||
if context.cancel.is_cancelled() {
|
if context.cancel.is_cancelled() {
|
||||||
true
|
true
|
||||||
} else if let Some(re) = &context.stop_re {
|
} else if let Some(re) = &context.stop_re {
|
||||||
let mut new_token = reverse(token);
|
let mut new_token = reverse(&token);
|
||||||
new_token.push_str(&context.reversed_output_text);
|
new_token.push_str(&context.reversed_output_text);
|
||||||
context.reversed_output_text = new_token;
|
context.reversed_output_text = new_token;
|
||||||
re.find(&context.reversed_output_text).is_some()
|
re.find(&context.reversed_output_text).is_some()
|
||||||
|
|
@@ -165,19 +173,37 @@ fn inference_callback(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reverse(s: String) -> String {
|
fn reverse(s: &String) -> String {
|
||||||
|
// Special treatment for byte fallback token.
|
||||||
|
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/byte_fallback.rs
|
||||||
|
if s.len() == 6 && s.starts_with("<0x") && s.ends_with('>') {
|
||||||
|
// Keep byte fallback tokens like <0x0A> as is, do not reverse it.
|
||||||
|
// This won't really affect stop words regex logic, but brings more readability when
|
||||||
|
// debugging decoding steps.
|
||||||
|
s.to_owned()
|
||||||
|
} else {
|
||||||
s.chars().rev().collect()
|
s.chars().rev().collect()
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn create_stop_regex(tokenizer: &Tokenizer, stop_words: &[&str]) -> Regex {
|
fn create_stop_regex(
|
||||||
|
tokenizer: &Tokenizer,
|
||||||
|
stop_words: &[&str],
|
||||||
|
stop_words_encoding_offset: Option<usize>,
|
||||||
|
) -> Regex {
|
||||||
let encodings = tokenizer
|
let encodings = tokenizer
|
||||||
.encode_batch(stop_words.to_owned(), false)
|
.encode_batch(stop_words.to_owned(), false)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let stop_tokens: Vec<String> = encodings
|
let stop_tokens: Vec<String> = encodings
|
||||||
.iter()
|
.iter()
|
||||||
.map(|x| x.get_tokens().join(""))
|
.map(|x| {
|
||||||
// Reverse for efficient suffix matching.
|
x.get_tokens()[stop_words_encoding_offset.unwrap_or(0)..]
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
.map(reverse)
|
.map(reverse)
|
||||||
|
.collect::<Vec<String>>()
|
||||||
|
.join("")
|
||||||
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// (?m) enables multi-line matching mode.
|
// (?m) enables multi-line matching mode.
|
||||||
|
|
|
||||||
|
|
@@ -140,6 +140,7 @@ impl CompletionState {
|
||||||
.device_indices(args.device_indices.clone())
|
.device_indices(args.device_indices.clone())
|
||||||
.num_replicas_per_device(args.num_replicas_per_device)
|
.num_replicas_per_device(args.num_replicas_per_device)
|
||||||
.compute_type(compute_type)
|
.compute_type(compute_type)
|
||||||
|
.stop_words_encoding_offset(metadata.stop_words_encoding_offset)
|
||||||
.build()
|
.build()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let engine = CTranslate2Engine::create(options);
|
let engine = CTranslate2Engine::create(options);
|
||||||
|
|
@@ -165,6 +166,7 @@ fn get_model_dir(model: &str) -> ModelDir {
|
||||||
struct Metadata {
|
struct Metadata {
|
||||||
auto_model: String,
|
auto_model: String,
|
||||||
prompt_template: Option<String>,
|
prompt_template: Option<String>,
|
||||||
|
stop_words_encoding_offset: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_metadata(model_dir: &ModelDir) -> Metadata {
|
fn read_metadata(model_dir: &ModelDir) -> Metadata {
|
||||||
|
|
|
||||||
|
|
@@ -52,7 +52,7 @@ impl PromptBuilder {
|
||||||
if let Some(suffix) = segments.suffix {
|
if let Some(suffix) = segments.suffix {
|
||||||
self.build_prompt(segments.prefix, suffix)
|
self.build_prompt(segments.prefix, suffix)
|
||||||
} else {
|
} else {
|
||||||
self.build_prompt(segments.prefix, "".to_owned())
|
self.build_prompt(segments.prefix, "\n".to_owned())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue