feat: support request level stop words (#492)
parent
486e507079
commit
0d6840e372
|
|
@ -137,7 +137,7 @@ impl TextGeneration for CTranslate2Engine {
|
||||||
|
|
||||||
let decoding = self
|
let decoding = self
|
||||||
.decoding_factory
|
.decoding_factory
|
||||||
.create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words);
|
.create(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), &options.stop_words, options.static_stop_words);
|
||||||
|
|
||||||
let (sender, mut receiver) = channel::<String>(8);
|
let (sender, mut receiver) = channel::<String>(8);
|
||||||
let context = InferenceContext::new(sender, decoding, cancel_for_inference);
|
let context = InferenceContext::new(sender, decoding, cancel_for_inference);
|
||||||
|
|
|
||||||
|
|
@ -58,8 +58,11 @@ impl FastChatEngine {
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl TextGeneration for FastChatEngine {
|
impl TextGeneration for FastChatEngine {
|
||||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||||
let _stop_sequences: Vec<String> =
|
let _stop_sequences: Vec<String> = options
|
||||||
options.stop_words.iter().map(|x| x.to_string()).collect();
|
.static_stop_words
|
||||||
|
.iter()
|
||||||
|
.map(|x| x.to_string())
|
||||||
|
.collect();
|
||||||
|
|
||||||
let tokens: Vec<&str> = prompt.split("<MID>").collect();
|
let tokens: Vec<&str> = prompt.split("<MID>").collect();
|
||||||
let request = Request {
|
let request = Request {
|
||||||
|
|
|
||||||
|
|
@ -67,7 +67,7 @@ impl VertexAIEngine {
|
||||||
impl TextGeneration for VertexAIEngine {
|
impl TextGeneration for VertexAIEngine {
|
||||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||||
let stop_sequences: Vec<String> = options
|
let stop_sequences: Vec<String> = options
|
||||||
.stop_words
|
.static_stop_words
|
||||||
.iter()
|
.iter()
|
||||||
.map(|x| x.to_string())
|
.map(|x| x.to_string())
|
||||||
// vertex supports at most 5 stop sequence.
|
// vertex supports at most 5 stop sequence.
|
||||||
|
|
|
||||||
|
|
@ -15,29 +15,6 @@ static size_t N_BATCH = 512;
|
||||||
template<class T>
|
template<class T>
|
||||||
using owned = std::unique_ptr<T, std::function<void(T*)>>;
|
using owned = std::unique_ptr<T, std::function<void(T*)>>;
|
||||||
|
|
||||||
std::vector<llama_token> tokenize(struct llama_context * ctx, const std::string & text, size_t max_input_length, bool add_bos) {
|
|
||||||
const struct llama_model* model = llama_get_model(ctx);
|
|
||||||
// upper limit for the number of tokens
|
|
||||||
int n_tokens = max_input_length;
|
|
||||||
std::vector<llama_token> result(n_tokens);
|
|
||||||
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
|
|
||||||
if (n_tokens < 0) {
|
|
||||||
result.resize(-n_tokens);
|
|
||||||
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
|
|
||||||
GGML_ASSERT(check == -n_tokens);
|
|
||||||
|
|
||||||
int start = check - max_input_length;
|
|
||||||
GGML_ASSERT(start >= 0);
|
|
||||||
result = std::vector<llama_token>(result.begin() + start, result.end());
|
|
||||||
if (add_bos) {
|
|
||||||
result[0] = llama_token_bos(ctx);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
result.resize(n_tokens);
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
class TextInferenceEngineImpl : public TextInferenceEngine {
|
class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
public:
|
public:
|
||||||
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :
|
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :
|
||||||
|
|
|
||||||
|
|
@ -70,7 +70,7 @@ impl TextGeneration for LlamaEngine {
|
||||||
|
|
||||||
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
|
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
|
||||||
engine.start(input_token_ids);
|
engine.start(input_token_ids);
|
||||||
let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.stop_words);
|
let mut decoding = self.decoding_factory.create(self.tokenizer.clone(), input_token_ids, &options.stop_words, options.static_stop_words);
|
||||||
let mut n_remains = options.max_decoding_length ;
|
let mut n_remains = options.max_decoding_length ;
|
||||||
while n_remains > 0 {
|
while n_remains > 0 {
|
||||||
let next_token_id = engine.step();
|
let next_token_id = engine.step();
|
||||||
|
|
|
||||||
|
|
@ -24,16 +24,35 @@ impl Default for DecodingFactory {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DecodingFactory {
|
impl DecodingFactory {
|
||||||
pub fn create_incremental_decoding(
|
pub fn create(
|
||||||
&self,
|
&self,
|
||||||
tokenizer: Arc<Tokenizer>,
|
tokenizer: Arc<Tokenizer>,
|
||||||
input_token_ids: &[u32],
|
input_token_ids: &[u32],
|
||||||
stop_words: &'static Vec<&'static str>,
|
stop_words: &Vec<String>,
|
||||||
|
static_stop_words: &'static Vec<&'static str>,
|
||||||
) -> IncrementalDecoding {
|
) -> IncrementalDecoding {
|
||||||
IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
|
IncrementalDecoding::new(
|
||||||
|
tokenizer,
|
||||||
|
vec![
|
||||||
|
self.get_static_re(static_stop_words),
|
||||||
|
self.get_re(stop_words),
|
||||||
|
]
|
||||||
|
.into_iter()
|
||||||
|
.flatten()
|
||||||
|
.collect(),
|
||||||
|
input_token_ids,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
|
fn get_re(&self, stop_words: &Vec<String>) -> Option<Regex> {
|
||||||
|
if !stop_words.is_empty() {
|
||||||
|
Some(create_stop_regex(stop_words))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_static_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
|
||||||
if stop_words.is_empty() {
|
if stop_words.is_empty() {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -48,8 +67,8 @@ impl DecodingFactory {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_stop_regex(stop_words: &[&str]) -> Regex {
|
fn create_stop_regex<T: AsRef<str>>(stop_words: &[T]) -> Regex {
|
||||||
let tokens: Vec<String> = stop_words.iter().map(|x| reverse(*x)).collect();
|
let tokens: Vec<String> = stop_words.iter().map(|x| reverse(x.as_ref())).collect();
|
||||||
|
|
||||||
// (?m) enables multi-line matching mode.
|
// (?m) enables multi-line matching mode.
|
||||||
// \A means absolute begins of string.
|
// \A means absolute begins of string.
|
||||||
|
|
@ -59,7 +78,7 @@ fn create_stop_regex(stop_words: &[&str]) -> Regex {
|
||||||
|
|
||||||
pub struct IncrementalDecoding {
|
pub struct IncrementalDecoding {
|
||||||
tokenizer: Arc<Tokenizer>,
|
tokenizer: Arc<Tokenizer>,
|
||||||
stop_re: Option<Regex>,
|
stop_re: Vec<Regex>,
|
||||||
|
|
||||||
token_ids: Vec<u32>,
|
token_ids: Vec<u32>,
|
||||||
prefix_offset: usize,
|
prefix_offset: usize,
|
||||||
|
|
@ -69,7 +88,7 @@ pub struct IncrementalDecoding {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IncrementalDecoding {
|
impl IncrementalDecoding {
|
||||||
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>, input_token_ids: &[u32]) -> Self {
|
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Vec<Regex>, input_token_ids: &[u32]) -> Self {
|
||||||
let text = tokenizer
|
let text = tokenizer
|
||||||
.decode(input_token_ids, /* skip_special_token = */ true)
|
.decode(input_token_ids, /* skip_special_token = */ true)
|
||||||
.expect("Cannot decode token from tokenizer.");
|
.expect("Cannot decode token from tokenizer.");
|
||||||
|
|
@ -110,8 +129,7 @@ impl IncrementalDecoding {
|
||||||
|
|
||||||
if !new_text.is_empty() {
|
if !new_text.is_empty() {
|
||||||
self.reversed_text = reverse(new_text) + &self.reversed_text;
|
self.reversed_text = reverse(new_text) + &self.reversed_text;
|
||||||
|
for re in &self.stop_re {
|
||||||
if let Some(re) = &self.stop_re {
|
|
||||||
if re.find(&self.reversed_text).is_some() {
|
if re.find(&self.reversed_text).is_some() {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,10 @@ pub struct TextGenerationOptions {
|
||||||
pub sampling_temperature: f32,
|
pub sampling_temperature: f32,
|
||||||
|
|
||||||
#[builder(default = "&EMPTY_STOP_WORDS")]
|
#[builder(default = "&EMPTY_STOP_WORDS")]
|
||||||
pub stop_words: &'static Vec<&'static str>,
|
pub static_stop_words: &'static Vec<&'static str>,
|
||||||
|
|
||||||
|
#[builder(default = "vec![]")]
|
||||||
|
pub stop_words: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];
|
static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];
|
||||||
|
|
|
||||||
|
|
@ -80,7 +80,7 @@ pub async fn completion(
|
||||||
.max_input_length(1024 + 512)
|
.max_input_length(1024 + 512)
|
||||||
.max_decoding_length(128)
|
.max_decoding_length(128)
|
||||||
.sampling_temperature(0.1)
|
.sampling_temperature(0.1)
|
||||||
.stop_words(get_stop_words(&language))
|
.static_stop_words(get_stop_words(&language))
|
||||||
.build()
|
.build()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ use std::sync::Arc;
|
||||||
use async_stream::stream;
|
use async_stream::stream;
|
||||||
use axum::{extract::State, response::IntoResponse, Json};
|
use axum::{extract::State, response::IntoResponse, Json};
|
||||||
use axum_streams::StreamBodyAs;
|
use axum_streams::StreamBodyAs;
|
||||||
use lazy_static::lazy_static;
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
|
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
|
@ -20,11 +19,12 @@ impl GenerateState {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
|
#[schema(example=json!({
|
||||||
|
"prompt": "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n",
|
||||||
|
}))]
|
||||||
pub struct GenerateRequest {
|
pub struct GenerateRequest {
|
||||||
#[schema(
|
|
||||||
example = "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n"
|
|
||||||
)]
|
|
||||||
prompt: String,
|
prompt: String,
|
||||||
|
stop_words: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
|
|
@ -47,9 +47,9 @@ pub async fn generate(
|
||||||
State(state): State<Arc<GenerateState>>,
|
State(state): State<Arc<GenerateState>>,
|
||||||
Json(request): Json<GenerateRequest>,
|
Json(request): Json<GenerateRequest>,
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
let options = build_options(&request);
|
let (prompt, options) = parse_request(request);
|
||||||
Json(GenerateResponse {
|
Json(GenerateResponse {
|
||||||
text: state.engine.generate(&request.prompt, options).await,
|
text: state.engine.generate(&prompt, options).await,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -68,9 +68,9 @@ pub async fn generate_stream(
|
||||||
State(state): State<Arc<GenerateState>>,
|
State(state): State<Arc<GenerateState>>,
|
||||||
Json(request): Json<GenerateRequest>,
|
Json(request): Json<GenerateRequest>,
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
let options = build_options(&request);
|
let (prompt, options) = parse_request(request);
|
||||||
let s = stream! {
|
let s = stream! {
|
||||||
for await text in state.engine.generate_stream(&request.prompt, options).await {
|
for await text in state.engine.generate_stream(&prompt, options).await {
|
||||||
yield GenerateResponse { text }
|
yield GenerateResponse { text }
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -78,16 +78,17 @@ pub async fn generate_stream(
|
||||||
StreamBodyAs::json_nl(s)
|
StreamBodyAs::json_nl(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
lazy_static! {
|
fn parse_request(request: GenerateRequest) -> (String, TextGenerationOptions) {
|
||||||
static ref STOP_WORDS: Vec<&'static str> = vec!["\n\n",];
|
let mut builder = TextGenerationOptionsBuilder::default();
|
||||||
}
|
|
||||||
|
|
||||||
fn build_options(_request: &GenerateRequest) -> TextGenerationOptions {
|
builder
|
||||||
TextGenerationOptionsBuilder::default()
|
|
||||||
.max_input_length(1024)
|
.max_input_length(1024)
|
||||||
.max_decoding_length(1024)
|
.max_decoding_length(968)
|
||||||
.sampling_temperature(0.1)
|
.sampling_temperature(0.1);
|
||||||
.stop_words(&STOP_WORDS)
|
|
||||||
.build()
|
if let Some(stop_words) = request.stop_words {
|
||||||
.unwrap()
|
builder.stop_words(stop_words);
|
||||||
|
};
|
||||||
|
|
||||||
|
(request.prompt, builder.build().unwrap())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue