feat: add stop words encoding offset for ctranslate model config (#371)
* feat: add stop words encoding offset for ctranslate model config
* feat: set default suffix to \n
* add special treatment for byte fallback tokens

release-0.0
parent
4da95892ba
commit
65836ee199
|
|
@ -63,6 +63,8 @@ pub struct CTranslate2EngineOptions {
|
|||
num_replicas_per_device: usize,
|
||||
|
||||
compute_type: String,
|
||||
|
||||
stop_words_encoding_offset: Option<usize>,
|
||||
}
|
||||
|
||||
pub struct InferenceContext {
|
||||
|
|
@ -85,6 +87,7 @@ pub struct CTranslate2Engine {
|
|||
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
|
||||
tokenizer: Tokenizer,
|
||||
stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
|
||||
stop_words_encoding_offset: Option<usize>,
|
||||
}
|
||||
|
||||
impl CTranslate2Engine {
|
||||
|
|
@ -102,6 +105,7 @@ impl CTranslate2Engine {
|
|||
engine,
|
||||
stop_regex_cache: DashMap::new(),
|
||||
tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(),
|
||||
stop_words_encoding_offset: options.stop_words_encoding_offset,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
@ -123,7 +127,11 @@ impl TextGeneration for CTranslate2Engine {
|
|||
if re.is_none() {
|
||||
self.stop_regex_cache.insert(
|
||||
options.stop_words,
|
||||
create_stop_regex(&self.tokenizer, options.stop_words),
|
||||
create_stop_regex(
|
||||
&self.tokenizer,
|
||||
options.stop_words,
|
||||
self.stop_words_encoding_offset,
|
||||
),
|
||||
);
|
||||
re = self.stop_regex_cache.get(options.stop_words);
|
||||
}
|
||||
|
|
@ -156,7 +164,7 @@ fn inference_callback(
|
|||
if context.cancel.is_cancelled() {
|
||||
true
|
||||
} else if let Some(re) = &context.stop_re {
|
||||
let mut new_token = reverse(token);
|
||||
let mut new_token = reverse(&token);
|
||||
new_token.push_str(&context.reversed_output_text);
|
||||
context.reversed_output_text = new_token;
|
||||
re.find(&context.reversed_output_text).is_some()
|
||||
|
|
@ -165,19 +173,37 @@ fn inference_callback(
|
|||
}
|
||||
}
|
||||
|
||||
fn reverse(s: String) -> String {
|
||||
s.chars().rev().collect()
|
||||
/// Returns `s` with its characters in reverse order, for suffix matching against reversed text.
///
/// Byte fallback tokens (`<0x` + two hex digits + `>`, e.g. `<0x0A>`) are returned unchanged.
/// See https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/byte_fallback.rs
///
/// Keeping them as is won't really affect stop words regex logic, but brings more readability
/// when debugging decoding steps.
///
/// The parameter stays `&String` (rather than `&str`) because this function is also passed
/// directly to `Iterator::map` over `&String` items.
#[allow(clippy::ptr_arg)] // Signature is relied on by `.map(reverse)` call sites.
fn reverse(s: &String) -> String {
    let bytes = s.as_bytes();
    // Validate the two payload bytes as hex digits so that look-alikes such as `<0xZZ>` (or a
    // two-byte UTF-8 character between the markers) are reversed like any other text.
    let is_byte_fallback = bytes.len() == 6
        && s.starts_with("<0x")
        && s.ends_with('>')
        && bytes[3..5].iter().all(u8::is_ascii_hexdigit);

    if is_byte_fallback {
        s.to_owned()
    } else {
        s.chars().rev().collect()
    }
}
|
||||
|
||||
fn create_stop_regex(tokenizer: &Tokenizer, stop_words: &[&str]) -> Regex {
|
||||
fn create_stop_regex(
|
||||
tokenizer: &Tokenizer,
|
||||
stop_words: &[&str],
|
||||
stop_words_encoding_offset: Option<usize>,
|
||||
) -> Regex {
|
||||
let encodings = tokenizer
|
||||
.encode_batch(stop_words.to_owned(), false)
|
||||
.unwrap();
|
||||
let stop_tokens: Vec<String> = encodings
|
||||
.iter()
|
||||
.map(|x| x.get_tokens().join(""))
|
||||
// Reverse for efficient suffix matching.
|
||||
.map(reverse)
|
||||
.map(|x| {
|
||||
x.get_tokens()[stop_words_encoding_offset.unwrap_or(0)..]
|
||||
.iter()
|
||||
.rev()
|
||||
.map(reverse)
|
||||
.collect::<Vec<String>>()
|
||||
.join("")
|
||||
})
|
||||
.collect();
|
||||
|
||||
// (?m) enables multi-line matching mode.
|
||||
|
|
|
|||
|
|
@ -140,6 +140,7 @@ impl CompletionState {
|
|||
.device_indices(args.device_indices.clone())
|
||||
.num_replicas_per_device(args.num_replicas_per_device)
|
||||
.compute_type(compute_type)
|
||||
.stop_words_encoding_offset(metadata.stop_words_encoding_offset)
|
||||
.build()
|
||||
.unwrap();
|
||||
let engine = CTranslate2Engine::create(options);
|
||||
|
|
@ -165,6 +166,7 @@ fn get_model_dir(model: &str) -> ModelDir {
|
|||
// Per-model metadata; returned by `read_metadata` for a model directory.
struct Metadata {
|
||||
auto_model: String,
|
||||
// Optional prompt template (may be absent).
prompt_template: Option<String>,
|
||||
// Optional offset forwarded to the CTranslate2 engine options; there it is the index at which
// each stop word's token list starts being used, with 0 used when absent.
stop_words_encoding_offset: Option<usize>,
|
||||
}
|
||||
|
||||
fn read_metadata(model_dir: &ModelDir) -> Metadata {
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ impl PromptBuilder {
|
|||
if let Some(suffix) = segments.suffix {
|
||||
self.build_prompt(segments.prefix, suffix)
|
||||
} else {
|
||||
self.build_prompt(segments.prefix, "".to_owned())
|
||||
self.build_prompt(segments.prefix, "\n".to_owned())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue