fmt
parent
c3e57147cf
commit
54901515ef
|
|
@ -69,18 +69,22 @@ pub struct TextInferenceOptions {
|
|||
#[builder(default = "1.0")]
|
||||
sampling_temperature: f32,
|
||||
|
||||
stop_words: &'static Vec<&'static str>
|
||||
stop_words: &'static Vec<&'static str>,
|
||||
}
|
||||
|
||||
pub struct InferenceContext {
|
||||
stop_re: Option<Regex>,
|
||||
cancel: CancellationToken,
|
||||
reversed_output_text: String
|
||||
reversed_output_text: String,
|
||||
}
|
||||
|
||||
impl InferenceContext {
|
||||
fn new(stop_re: Option<Regex>, cancel: CancellationToken) -> Self {
|
||||
InferenceContext { stop_re, cancel, reversed_output_text: "".to_owned() }
|
||||
InferenceContext {
|
||||
stop_re,
|
||||
cancel,
|
||||
reversed_output_text: "".to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -116,8 +120,11 @@ impl TextInferenceEngine {
|
|||
None
|
||||
} else {
|
||||
// FIXME(meng): consider cache the regexp.
|
||||
let encodings = self.tokenizer.encode_batch(options.stop_words.clone(), false).unwrap();
|
||||
let stop_tokens : Vec<String> = encodings
|
||||
let encodings = self
|
||||
.tokenizer
|
||||
.encode_batch(options.stop_words.clone(), false)
|
||||
.unwrap();
|
||||
let stop_tokens: Vec<String> = encodings
|
||||
.iter()
|
||||
.map(|x| x.get_tokens().join(""))
|
||||
// Reverse for efficient suffix matching.
|
||||
|
|
@ -146,18 +153,21 @@ impl TextInferenceEngine {
|
|||
}
|
||||
}
|
||||
|
||||
fn inference_callback(context: &mut InferenceContext, _step: usize, _token_id: u32, token: String) -> bool {
|
||||
fn inference_callback(
|
||||
context: &mut InferenceContext,
|
||||
_step: usize,
|
||||
_token_id: u32,
|
||||
token: String,
|
||||
) -> bool {
|
||||
if context.cancel.is_cancelled() {
|
||||
true
|
||||
} else if let Some(re) = &context.stop_re {
|
||||
let mut new_token = reverse(token);
|
||||
new_token.push_str(&context.reversed_output_text);
|
||||
context.reversed_output_text = new_token;
|
||||
re.find(&context.reversed_output_text).is_some()
|
||||
} else {
|
||||
if let Some(re) = &context.stop_re {
|
||||
let mut new_token = reverse(token);
|
||||
new_token.push_str(&context.reversed_output_text);
|
||||
context.reversed_output_text = new_token;
|
||||
re.find(&context.reversed_output_text).is_some()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -100,10 +100,7 @@ pub async fn completion(
|
|||
|
||||
Json(CompletionResponse {
|
||||
id: completion_id,
|
||||
choices: vec![Choice {
|
||||
index: 0,
|
||||
text,
|
||||
}],
|
||||
choices: vec![Choice { index: 0, text }],
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,10 +6,7 @@ lazy_static! {
|
|||
static ref DEFAULT: Vec<&'static str> = vec!("\n\n");
|
||||
static ref LANGUAGES: HashMap<&'static str, Vec<&'static str>> = {
|
||||
let mut map = HashMap::new();
|
||||
map.insert(
|
||||
"python",
|
||||
vec!("\n\n", "\ndef", "\n#", "\nfrom", "\nclass")
|
||||
);
|
||||
map.insert("python", vec!["\n\n", "\ndef", "\n#", "\nfrom", "\nclass"]);
|
||||
map
|
||||
};
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue