feat: support request level stop words (#492)

release-0.2
Meng Zhang 2023-09-29 11:21:57 -07:00 committed by GitHub
parent 486e507079
commit 0d6840e372
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 60 additions and 58 deletions

View File

@ -137,7 +137,7 @@ impl TextGeneration for CTranslate2Engine {
let decoding = self let decoding = self
.decoding_factory .decoding_factory
.create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words); .create(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), &options.stop_words, options.static_stop_words);
let (sender, mut receiver) = channel::<String>(8); let (sender, mut receiver) = channel::<String>(8);
let context = InferenceContext::new(sender, decoding, cancel_for_inference); let context = InferenceContext::new(sender, decoding, cancel_for_inference);

View File

@ -58,8 +58,11 @@ impl FastChatEngine {
#[async_trait] #[async_trait]
impl TextGeneration for FastChatEngine { impl TextGeneration for FastChatEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let _stop_sequences: Vec<String> = let _stop_sequences: Vec<String> = options
options.stop_words.iter().map(|x| x.to_string()).collect(); .static_stop_words
.iter()
.map(|x| x.to_string())
.collect();
let tokens: Vec<&str> = prompt.split("<MID>").collect(); let tokens: Vec<&str> = prompt.split("<MID>").collect();
let request = Request { let request = Request {

View File

@ -67,7 +67,7 @@ impl VertexAIEngine {
impl TextGeneration for VertexAIEngine { impl TextGeneration for VertexAIEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let stop_sequences: Vec<String> = options let stop_sequences: Vec<String> = options
.stop_words .static_stop_words
.iter() .iter()
.map(|x| x.to_string()) .map(|x| x.to_string())
// vertex supports at most 5 stop sequences. // vertex supports at most 5 stop sequences.

View File

@ -15,29 +15,6 @@ static size_t N_BATCH = 512;
template<class T> template<class T>
using owned = std::unique_ptr<T, std::function<void(T*)>>; using owned = std::unique_ptr<T, std::function<void(T*)>>;
// Tokenizes `text` with the model obtained from `ctx` and returns the token ids.
//
// ctx              - llama context; the model is looked up via llama_get_model(ctx).
// text             - input text to tokenize.
// max_input_length - size of the initial token buffer and the number of tokens kept
//                    when the text needs more than that.
// add_bos          - forwarded to llama_tokenize; on the truncation path it also causes
//                    the first kept slot to be overwritten with the BOS token.
//
// When the first llama_tokenize call reports a negative count, the buffer is regrown,
// the text is tokenized again, and only the trailing `max_input_length` tokens are kept.
std::vector<llama_token> tokenize(struct llama_context * ctx, const std::string & text, size_t max_input_length, bool add_bos) {
const struct llama_model* model = llama_get_model(ctx);
// upper limit for the number of tokens
int n_tokens = max_input_length;
std::vector<llama_token> result(n_tokens);
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
// NOTE(review): a negative return is treated here as "buffer too small", with its
// magnitude used as the required token count - presumably the llama.cpp convention; confirm.
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
// The second pass must produce exactly the count requested by the first one.
GGML_ASSERT(check == -n_tokens);
// Index of the first token to keep so that exactly max_input_length tokens remain.
int start = check - max_input_length;
GGML_ASSERT(start >= 0);
result = std::vector<llama_token>(result.begin() + start, result.end());
if (add_bos) {
// The front of the sequence was dropped above, so the first kept slot is replaced with BOS.
// NOTE(review): with max_input_length == 0 the vector is empty at this point and
// result[0] would be out of bounds - verify callers never pass 0.
result[0] = llama_token_bos(ctx);
}
} else {
// Everything fit in the initial buffer: shrink it to the number of tokens reported.
result.resize(n_tokens);
}
return result;
}
class TextInferenceEngineImpl : public TextInferenceEngine { class TextInferenceEngineImpl : public TextInferenceEngine {
public: public:
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) : TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :

View File

@ -70,7 +70,7 @@ impl TextGeneration for LlamaEngine {
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length); let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
engine.start(input_token_ids); engine.start(input_token_ids);
let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.stop_words); let mut decoding = self.decoding_factory.create(self.tokenizer.clone(), input_token_ids, &options.stop_words, options.static_stop_words);
let mut n_remains = options.max_decoding_length ; let mut n_remains = options.max_decoding_length ;
while n_remains > 0 { while n_remains > 0 {
let next_token_id = engine.step(); let next_token_id = engine.step();

View File

@ -24,16 +24,35 @@ impl Default for DecodingFactory {
} }
impl DecodingFactory { impl DecodingFactory {
pub fn create_incremental_decoding( pub fn create(
&self, &self,
tokenizer: Arc<Tokenizer>, tokenizer: Arc<Tokenizer>,
input_token_ids: &[u32], input_token_ids: &[u32],
stop_words: &'static Vec<&'static str>, stop_words: &Vec<String>,
static_stop_words: &'static Vec<&'static str>,
) -> IncrementalDecoding { ) -> IncrementalDecoding {
IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids) IncrementalDecoding::new(
tokenizer,
vec![
self.get_static_re(static_stop_words),
self.get_re(stop_words),
]
.into_iter()
.flatten()
.collect(),
input_token_ids,
)
} }
fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> { fn get_re(&self, stop_words: &Vec<String>) -> Option<Regex> {
if !stop_words.is_empty() {
Some(create_stop_regex(stop_words))
} else {
None
}
}
fn get_static_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
if stop_words.is_empty() { if stop_words.is_empty() {
None None
} else { } else {
@ -48,8 +67,8 @@ impl DecodingFactory {
} }
} }
fn create_stop_regex(stop_words: &[&str]) -> Regex { fn create_stop_regex<T: AsRef<str>>(stop_words: &[T]) -> Regex {
let tokens: Vec<String> = stop_words.iter().map(|x| reverse(*x)).collect(); let tokens: Vec<String> = stop_words.iter().map(|x| reverse(x.as_ref())).collect();
// (?m) enables multi-line matching mode. // (?m) enables multi-line matching mode.
// \A means the absolute beginning of the string. // \A means the absolute beginning of the string.
@ -59,7 +78,7 @@ fn create_stop_regex(stop_words: &[&str]) -> Regex {
pub struct IncrementalDecoding { pub struct IncrementalDecoding {
tokenizer: Arc<Tokenizer>, tokenizer: Arc<Tokenizer>,
stop_re: Option<Regex>, stop_re: Vec<Regex>,
token_ids: Vec<u32>, token_ids: Vec<u32>,
prefix_offset: usize, prefix_offset: usize,
@ -69,7 +88,7 @@ pub struct IncrementalDecoding {
} }
impl IncrementalDecoding { impl IncrementalDecoding {
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>, input_token_ids: &[u32]) -> Self { pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Vec<Regex>, input_token_ids: &[u32]) -> Self {
let text = tokenizer let text = tokenizer
.decode(input_token_ids, /* skip_special_token = */ true) .decode(input_token_ids, /* skip_special_token = */ true)
.expect("Cannot decode token from tokenizer."); .expect("Cannot decode token from tokenizer.");
@ -110,8 +129,7 @@ impl IncrementalDecoding {
if !new_text.is_empty() { if !new_text.is_empty() {
self.reversed_text = reverse(new_text) + &self.reversed_text; self.reversed_text = reverse(new_text) + &self.reversed_text;
for re in &self.stop_re {
if let Some(re) = &self.stop_re {
if re.find(&self.reversed_text).is_some() { if re.find(&self.reversed_text).is_some() {
return None; return None;
} }

View File

@ -16,7 +16,10 @@ pub struct TextGenerationOptions {
pub sampling_temperature: f32, pub sampling_temperature: f32,
#[builder(default = "&EMPTY_STOP_WORDS")] #[builder(default = "&EMPTY_STOP_WORDS")]
pub stop_words: &'static Vec<&'static str>, pub static_stop_words: &'static Vec<&'static str>,
#[builder(default = "vec![]")]
pub stop_words: Vec<String>,
} }
static EMPTY_STOP_WORDS: Vec<&'static str> = vec![]; static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];

View File

@ -80,7 +80,7 @@ pub async fn completion(
.max_input_length(1024 + 512) .max_input_length(1024 + 512)
.max_decoding_length(128) .max_decoding_length(128)
.sampling_temperature(0.1) .sampling_temperature(0.1)
.stop_words(get_stop_words(&language)) .static_stop_words(get_stop_words(&language))
.build() .build()
.unwrap(); .unwrap();

View File

@ -3,7 +3,6 @@ use std::sync::Arc;
use async_stream::stream; use async_stream::stream;
use axum::{extract::State, response::IntoResponse, Json}; use axum::{extract::State, response::IntoResponse, Json};
use axum_streams::StreamBodyAs; use axum_streams::StreamBodyAs;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder}; use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
use tracing::instrument; use tracing::instrument;
@ -20,11 +19,12 @@ impl GenerateState {
} }
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({
"prompt": "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n",
}))]
pub struct GenerateRequest { pub struct GenerateRequest {
#[schema(
example = "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n"
)]
prompt: String, prompt: String,
stop_words: Option<Vec<String>>,
} }
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
@ -47,9 +47,9 @@ pub async fn generate(
State(state): State<Arc<GenerateState>>, State(state): State<Arc<GenerateState>>,
Json(request): Json<GenerateRequest>, Json(request): Json<GenerateRequest>,
) -> impl IntoResponse { ) -> impl IntoResponse {
let options = build_options(&request); let (prompt, options) = parse_request(request);
Json(GenerateResponse { Json(GenerateResponse {
text: state.engine.generate(&request.prompt, options).await, text: state.engine.generate(&prompt, options).await,
}) })
} }
@ -68,9 +68,9 @@ pub async fn generate_stream(
State(state): State<Arc<GenerateState>>, State(state): State<Arc<GenerateState>>,
Json(request): Json<GenerateRequest>, Json(request): Json<GenerateRequest>,
) -> impl IntoResponse { ) -> impl IntoResponse {
let options = build_options(&request); let (prompt, options) = parse_request(request);
let s = stream! { let s = stream! {
for await text in state.engine.generate_stream(&request.prompt, options).await { for await text in state.engine.generate_stream(&prompt, options).await {
yield GenerateResponse { text } yield GenerateResponse { text }
} }
}; };
@ -78,16 +78,17 @@ pub async fn generate_stream(
StreamBodyAs::json_nl(s) StreamBodyAs::json_nl(s)
} }
lazy_static! { fn parse_request(request: GenerateRequest) -> (String, TextGenerationOptions) {
static ref STOP_WORDS: Vec<&'static str> = vec!["\n\n",]; let mut builder = TextGenerationOptionsBuilder::default();
}
fn build_options(_request: &GenerateRequest) -> TextGenerationOptions { builder
TextGenerationOptionsBuilder::default()
.max_input_length(1024) .max_input_length(1024)
.max_decoding_length(1024) .max_decoding_length(968)
.sampling_temperature(0.1) .sampling_temperature(0.1);
.stop_words(&STOP_WORDS)
.build() if let Some(stop_words) = request.stop_words {
.unwrap() builder.stop_words(stop_words);
};
(request.prompt, builder.build().unwrap())
} }