feat: support request level stop words (#492)

release-0.2
Meng Zhang 2023-09-29 11:21:57 -07:00 committed by GitHub
parent 486e507079
commit 0d6840e372
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 60 additions and 58 deletions

View File

@ -137,7 +137,7 @@ impl TextGeneration for CTranslate2Engine {
let decoding = self
.decoding_factory
.create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words);
.create(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), &options.stop_words, options.static_stop_words);
let (sender, mut receiver) = channel::<String>(8);
let context = InferenceContext::new(sender, decoding, cancel_for_inference);

View File

@ -58,8 +58,11 @@ impl FastChatEngine {
#[async_trait]
impl TextGeneration for FastChatEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let _stop_sequences: Vec<String> =
options.stop_words.iter().map(|x| x.to_string()).collect();
let _stop_sequences: Vec<String> = options
.static_stop_words
.iter()
.map(|x| x.to_string())
.collect();
let tokens: Vec<&str> = prompt.split("<MID>").collect();
let request = Request {

View File

@ -67,7 +67,7 @@ impl VertexAIEngine {
impl TextGeneration for VertexAIEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let stop_sequences: Vec<String> = options
.stop_words
.static_stop_words
.iter()
.map(|x| x.to_string())
// vertex supports at most 5 stop sequence.

View File

@ -15,29 +15,6 @@ static size_t N_BATCH = 512;
template<class T>
using owned = std::unique_ptr<T, std::function<void(T*)>>;
// Tokenizes `text` with the model attached to `ctx`, returning at most
// `max_input_length` tokens. When the text tokenizes to more than the limit,
// the *tail* of the token sequence is kept (the front is truncated), which
// suits code-completion prompts where the most recent context matters.
// `add_bos` asks the tokenizer to prepend a beginning-of-sequence token;
// after tail-truncation the first kept token is overwritten with BOS again
// so the sequence still starts correctly.
std::vector<llama_token> tokenize(struct llama_context * ctx, const std::string & text, size_t max_input_length, bool add_bos) {
const struct llama_model* model = llama_get_model(ctx);
// upper limit for the number of tokens
int n_tokens = max_input_length;
std::vector<llama_token> result(n_tokens);
// NOTE(review): llama_tokenize returns the token count on success; a
// negative return value indicates the buffer was too small and its
// magnitude is the required size — TODO confirm against the llama.h
// version this commit pins.
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
if (n_tokens < 0) {
// Buffer was too small: grow to the exact required size and retokenize.
result.resize(-n_tokens);
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
// The second call must report exactly the size the first call demanded.
GGML_ASSERT(check == -n_tokens);
// Keep only the last `max_input_length` tokens (drop `start` from the front).
int start = check - max_input_length;
GGML_ASSERT(start >= 0);
result = std::vector<llama_token>(result.begin() + start, result.end());
if (add_bos) {
// Truncation discarded the original BOS token; restore it at position 0.
result[0] = llama_token_bos(ctx);
}
} else {
// Fit within the limit: shrink to the actual token count.
result.resize(n_tokens);
}
return result;
}
class TextInferenceEngineImpl : public TextInferenceEngine {
public:
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :

View File

@ -70,7 +70,7 @@ impl TextGeneration for LlamaEngine {
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
engine.start(input_token_ids);
let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.stop_words);
let mut decoding = self.decoding_factory.create(self.tokenizer.clone(), input_token_ids, &options.stop_words, options.static_stop_words);
let mut n_remains = options.max_decoding_length ;
while n_remains > 0 {
let next_token_id = engine.step();

View File

@ -24,16 +24,35 @@ impl Default for DecodingFactory {
}
impl DecodingFactory {
pub fn create_incremental_decoding(
pub fn create(
&self,
tokenizer: Arc<Tokenizer>,
input_token_ids: &[u32],
stop_words: &'static Vec<&'static str>,
stop_words: &Vec<String>,
static_stop_words: &'static Vec<&'static str>,
) -> IncrementalDecoding {
IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
IncrementalDecoding::new(
tokenizer,
vec![
self.get_static_re(static_stop_words),
self.get_re(stop_words),
]
.into_iter()
.flatten()
.collect(),
input_token_ids,
)
}
fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
fn get_re(&self, stop_words: &Vec<String>) -> Option<Regex> {
if !stop_words.is_empty() {
Some(create_stop_regex(stop_words))
} else {
None
}
}
fn get_static_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
if stop_words.is_empty() {
None
} else {
@ -48,8 +67,8 @@ impl DecodingFactory {
}
}
fn create_stop_regex(stop_words: &[&str]) -> Regex {
let tokens: Vec<String> = stop_words.iter().map(|x| reverse(*x)).collect();
fn create_stop_regex<T: AsRef<str>>(stop_words: &[T]) -> Regex {
let tokens: Vec<String> = stop_words.iter().map(|x| reverse(x.as_ref())).collect();
// (?m) enables multi-line matching mode.
// \A means absolute begins of string.
@ -59,7 +78,7 @@ fn create_stop_regex(stop_words: &[&str]) -> Regex {
pub struct IncrementalDecoding {
tokenizer: Arc<Tokenizer>,
stop_re: Option<Regex>,
stop_re: Vec<Regex>,
token_ids: Vec<u32>,
prefix_offset: usize,
@ -69,7 +88,7 @@ pub struct IncrementalDecoding {
}
impl IncrementalDecoding {
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>, input_token_ids: &[u32]) -> Self {
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Vec<Regex>, input_token_ids: &[u32]) -> Self {
let text = tokenizer
.decode(input_token_ids, /* skip_special_token = */ true)
.expect("Cannot decode token from tokenizer.");
@ -110,8 +129,7 @@ impl IncrementalDecoding {
if !new_text.is_empty() {
self.reversed_text = reverse(new_text) + &self.reversed_text;
if let Some(re) = &self.stop_re {
for re in &self.stop_re {
if re.find(&self.reversed_text).is_some() {
return None;
}

View File

@ -16,7 +16,10 @@ pub struct TextGenerationOptions {
pub sampling_temperature: f32,
#[builder(default = "&EMPTY_STOP_WORDS")]
pub stop_words: &'static Vec<&'static str>,
pub static_stop_words: &'static Vec<&'static str>,
#[builder(default = "vec![]")]
pub stop_words: Vec<String>,
}
static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];

View File

@ -80,7 +80,7 @@ pub async fn completion(
.max_input_length(1024 + 512)
.max_decoding_length(128)
.sampling_temperature(0.1)
.stop_words(get_stop_words(&language))
.static_stop_words(get_stop_words(&language))
.build()
.unwrap();

View File

@ -3,7 +3,6 @@ use std::sync::Arc;
use async_stream::stream;
use axum::{extract::State, response::IntoResponse, Json};
use axum_streams::StreamBodyAs;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
use tracing::instrument;
@ -20,11 +19,12 @@ impl GenerateState {
}
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({
"prompt": "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n",
}))]
pub struct GenerateRequest {
#[schema(
example = "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n"
)]
prompt: String,
stop_words: Option<Vec<String>>,
}
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
@ -47,9 +47,9 @@ pub async fn generate(
State(state): State<Arc<GenerateState>>,
Json(request): Json<GenerateRequest>,
) -> impl IntoResponse {
let options = build_options(&request);
let (prompt, options) = parse_request(request);
Json(GenerateResponse {
text: state.engine.generate(&request.prompt, options).await,
text: state.engine.generate(&prompt, options).await,
})
}
@ -68,9 +68,9 @@ pub async fn generate_stream(
State(state): State<Arc<GenerateState>>,
Json(request): Json<GenerateRequest>,
) -> impl IntoResponse {
let options = build_options(&request);
let (prompt, options) = parse_request(request);
let s = stream! {
for await text in state.engine.generate_stream(&request.prompt, options).await {
for await text in state.engine.generate_stream(&prompt, options).await {
yield GenerateResponse { text }
}
};
@ -78,16 +78,17 @@ pub async fn generate_stream(
StreamBodyAs::json_nl(s)
}
lazy_static! {
static ref STOP_WORDS: Vec<&'static str> = vec!["\n\n",];
}
fn parse_request(request: GenerateRequest) -> (String, TextGenerationOptions) {
let mut builder = TextGenerationOptionsBuilder::default();
fn build_options(_request: &GenerateRequest) -> TextGenerationOptions {
TextGenerationOptionsBuilder::default()
builder
.max_input_length(1024)
.max_decoding_length(1024)
.sampling_temperature(0.1)
.stop_words(&STOP_WORDS)
.build()
.unwrap()
.max_decoding_length(968)
.sampling_temperature(0.1);
if let Some(stop_words) = request.stop_words {
builder.stop_words(stop_words);
};
(request.prompt, builder.build().unwrap())
}