fmt
parent
c3e57147cf
commit
54901515ef
|
|
@ -69,18 +69,22 @@ pub struct TextInferenceOptions {
|
||||||
#[builder(default = "1.0")]
|
#[builder(default = "1.0")]
|
||||||
sampling_temperature: f32,
|
sampling_temperature: f32,
|
||||||
|
|
||||||
stop_words: &'static Vec<&'static str>
|
stop_words: &'static Vec<&'static str>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct InferenceContext {
|
pub struct InferenceContext {
|
||||||
stop_re: Option<Regex>,
|
stop_re: Option<Regex>,
|
||||||
cancel: CancellationToken,
|
cancel: CancellationToken,
|
||||||
reversed_output_text: String
|
reversed_output_text: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl InferenceContext {
|
impl InferenceContext {
|
||||||
fn new(stop_re: Option<Regex>, cancel: CancellationToken) -> Self {
|
fn new(stop_re: Option<Regex>, cancel: CancellationToken) -> Self {
|
||||||
InferenceContext { stop_re, cancel, reversed_output_text: "".to_owned() }
|
InferenceContext {
|
||||||
|
stop_re,
|
||||||
|
cancel,
|
||||||
|
reversed_output_text: "".to_owned(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -116,8 +120,11 @@ impl TextInferenceEngine {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
// FIXME(meng): consider cache the regexp.
|
// FIXME(meng): consider cache the regexp.
|
||||||
let encodings = self.tokenizer.encode_batch(options.stop_words.clone(), false).unwrap();
|
let encodings = self
|
||||||
let stop_tokens : Vec<String> = encodings
|
.tokenizer
|
||||||
|
.encode_batch(options.stop_words.clone(), false)
|
||||||
|
.unwrap();
|
||||||
|
let stop_tokens: Vec<String> = encodings
|
||||||
.iter()
|
.iter()
|
||||||
.map(|x| x.get_tokens().join(""))
|
.map(|x| x.get_tokens().join(""))
|
||||||
// Reverse for efficient suffix matching.
|
// Reverse for efficient suffix matching.
|
||||||
|
|
@ -146,18 +153,21 @@ impl TextInferenceEngine {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn inference_callback(context: &mut InferenceContext, _step: usize, _token_id: u32, token: String) -> bool {
|
fn inference_callback(
|
||||||
|
context: &mut InferenceContext,
|
||||||
|
_step: usize,
|
||||||
|
_token_id: u32,
|
||||||
|
token: String,
|
||||||
|
) -> bool {
|
||||||
if context.cancel.is_cancelled() {
|
if context.cancel.is_cancelled() {
|
||||||
true
|
true
|
||||||
|
} else if let Some(re) = &context.stop_re {
|
||||||
|
let mut new_token = reverse(token);
|
||||||
|
new_token.push_str(&context.reversed_output_text);
|
||||||
|
context.reversed_output_text = new_token;
|
||||||
|
re.find(&context.reversed_output_text).is_some()
|
||||||
} else {
|
} else {
|
||||||
if let Some(re) = &context.stop_re {
|
false
|
||||||
let mut new_token = reverse(token);
|
|
||||||
new_token.push_str(&context.reversed_output_text);
|
|
||||||
context.reversed_output_text = new_token;
|
|
||||||
re.find(&context.reversed_output_text).is_some()
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -100,10 +100,7 @@ pub async fn completion(
|
||||||
|
|
||||||
Json(CompletionResponse {
|
Json(CompletionResponse {
|
||||||
id: completion_id,
|
id: completion_id,
|
||||||
choices: vec![Choice {
|
choices: vec![Choice { index: 0, text }],
|
||||||
index: 0,
|
|
||||||
text,
|
|
||||||
}],
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,10 +6,7 @@ lazy_static! {
|
||||||
static ref DEFAULT: Vec<&'static str> = vec!("\n\n");
|
static ref DEFAULT: Vec<&'static str> = vec!("\n\n");
|
||||||
static ref LANGUAGES: HashMap<&'static str, Vec<&'static str>> = {
|
static ref LANGUAGES: HashMap<&'static str, Vec<&'static str>> = {
|
||||||
let mut map = HashMap::new();
|
let mut map = HashMap::new();
|
||||||
map.insert(
|
map.insert("python", vec!["\n\n", "\ndef", "\n#", "\nfrom", "\nclass"]);
|
||||||
"python",
|
|
||||||
vec!("\n\n", "\ndef", "\n#", "\nfrom", "\nclass")
|
|
||||||
);
|
|
||||||
map
|
map
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue