feat: support request level stop words (#492)
parent
486e507079
commit
0d6840e372
|
|
@ -137,7 +137,7 @@ impl TextGeneration for CTranslate2Engine {
|
|||
|
||||
let decoding = self
|
||||
.decoding_factory
|
||||
.create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words);
|
||||
.create(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), &options.stop_words, options.static_stop_words);
|
||||
|
||||
let (sender, mut receiver) = channel::<String>(8);
|
||||
let context = InferenceContext::new(sender, decoding, cancel_for_inference);
|
||||
|
|
|
|||
|
|
@ -58,8 +58,11 @@ impl FastChatEngine {
|
|||
#[async_trait]
|
||||
impl TextGeneration for FastChatEngine {
|
||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||
let _stop_sequences: Vec<String> =
|
||||
options.stop_words.iter().map(|x| x.to_string()).collect();
|
||||
let _stop_sequences: Vec<String> = options
|
||||
.static_stop_words
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect();
|
||||
|
||||
let tokens: Vec<&str> = prompt.split("<MID>").collect();
|
||||
let request = Request {
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ impl VertexAIEngine {
|
|||
impl TextGeneration for VertexAIEngine {
|
||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||
let stop_sequences: Vec<String> = options
|
||||
.stop_words
|
||||
.static_stop_words
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
// vertex supports at most 5 stop sequence.
|
||||
|
|
|
|||
|
|
@ -15,29 +15,6 @@ static size_t N_BATCH = 512;
|
|||
template<class T>
|
||||
using owned = std::unique_ptr<T, std::function<void(T*)>>;
|
||||
|
||||
std::vector<llama_token> tokenize(struct llama_context * ctx, const std::string & text, size_t max_input_length, bool add_bos) {
|
||||
const struct llama_model* model = llama_get_model(ctx);
|
||||
// upper limit for the number of tokens
|
||||
int n_tokens = max_input_length;
|
||||
std::vector<llama_token> result(n_tokens);
|
||||
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
|
||||
if (n_tokens < 0) {
|
||||
result.resize(-n_tokens);
|
||||
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
|
||||
GGML_ASSERT(check == -n_tokens);
|
||||
|
||||
int start = check - max_input_length;
|
||||
GGML_ASSERT(start >= 0);
|
||||
result = std::vector<llama_token>(result.begin() + start, result.end());
|
||||
if (add_bos) {
|
||||
result[0] = llama_token_bos(ctx);
|
||||
}
|
||||
} else {
|
||||
result.resize(n_tokens);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||
public:
|
||||
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ impl TextGeneration for LlamaEngine {
|
|||
|
||||
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
|
||||
engine.start(input_token_ids);
|
||||
let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.stop_words);
|
||||
let mut decoding = self.decoding_factory.create(self.tokenizer.clone(), input_token_ids, &options.stop_words, options.static_stop_words);
|
||||
let mut n_remains = options.max_decoding_length ;
|
||||
while n_remains > 0 {
|
||||
let next_token_id = engine.step();
|
||||
|
|
|
|||
|
|
@ -24,16 +24,35 @@ impl Default for DecodingFactory {
|
|||
}
|
||||
|
||||
impl DecodingFactory {
|
||||
pub fn create_incremental_decoding(
|
||||
pub fn create(
|
||||
&self,
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
input_token_ids: &[u32],
|
||||
stop_words: &'static Vec<&'static str>,
|
||||
stop_words: &Vec<String>,
|
||||
static_stop_words: &'static Vec<&'static str>,
|
||||
) -> IncrementalDecoding {
|
||||
IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
|
||||
IncrementalDecoding::new(
|
||||
tokenizer,
|
||||
vec![
|
||||
self.get_static_re(static_stop_words),
|
||||
self.get_re(stop_words),
|
||||
]
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect(),
|
||||
input_token_ids,
|
||||
)
|
||||
}
|
||||
|
||||
fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
|
||||
fn get_re(&self, stop_words: &Vec<String>) -> Option<Regex> {
|
||||
if !stop_words.is_empty() {
|
||||
Some(create_stop_regex(stop_words))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn get_static_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
|
||||
if stop_words.is_empty() {
|
||||
None
|
||||
} else {
|
||||
|
|
@ -48,8 +67,8 @@ impl DecodingFactory {
|
|||
}
|
||||
}
|
||||
|
||||
fn create_stop_regex(stop_words: &[&str]) -> Regex {
|
||||
let tokens: Vec<String> = stop_words.iter().map(|x| reverse(*x)).collect();
|
||||
fn create_stop_regex<T: AsRef<str>>(stop_words: &[T]) -> Regex {
|
||||
let tokens: Vec<String> = stop_words.iter().map(|x| reverse(x.as_ref())).collect();
|
||||
|
||||
// (?m) enables multi-line matching mode.
|
||||
// \A means absolute begins of string.
|
||||
|
|
@ -59,7 +78,7 @@ fn create_stop_regex(stop_words: &[&str]) -> Regex {
|
|||
|
||||
pub struct IncrementalDecoding {
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
stop_re: Option<Regex>,
|
||||
stop_re: Vec<Regex>,
|
||||
|
||||
token_ids: Vec<u32>,
|
||||
prefix_offset: usize,
|
||||
|
|
@ -69,7 +88,7 @@ pub struct IncrementalDecoding {
|
|||
}
|
||||
|
||||
impl IncrementalDecoding {
|
||||
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>, input_token_ids: &[u32]) -> Self {
|
||||
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Vec<Regex>, input_token_ids: &[u32]) -> Self {
|
||||
let text = tokenizer
|
||||
.decode(input_token_ids, /* skip_special_token = */ true)
|
||||
.expect("Cannot decode token from tokenizer.");
|
||||
|
|
@ -110,8 +129,7 @@ impl IncrementalDecoding {
|
|||
|
||||
if !new_text.is_empty() {
|
||||
self.reversed_text = reverse(new_text) + &self.reversed_text;
|
||||
|
||||
if let Some(re) = &self.stop_re {
|
||||
for re in &self.stop_re {
|
||||
if re.find(&self.reversed_text).is_some() {
|
||||
return None;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,7 +16,10 @@ pub struct TextGenerationOptions {
|
|||
pub sampling_temperature: f32,
|
||||
|
||||
#[builder(default = "&EMPTY_STOP_WORDS")]
|
||||
pub stop_words: &'static Vec<&'static str>,
|
||||
pub static_stop_words: &'static Vec<&'static str>,
|
||||
|
||||
#[builder(default = "vec![]")]
|
||||
pub stop_words: Vec<String>,
|
||||
}
|
||||
|
||||
static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ pub async fn completion(
|
|||
.max_input_length(1024 + 512)
|
||||
.max_decoding_length(128)
|
||||
.sampling_temperature(0.1)
|
||||
.stop_words(get_stop_words(&language))
|
||||
.static_stop_words(get_stop_words(&language))
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ use std::sync::Arc;
|
|||
use async_stream::stream;
|
||||
use axum::{extract::State, response::IntoResponse, Json};
|
||||
use axum_streams::StreamBodyAs;
|
||||
use lazy_static::lazy_static;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
|
||||
use tracing::instrument;
|
||||
|
|
@ -20,11 +19,12 @@ impl GenerateState {
|
|||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
#[schema(example=json!({
|
||||
"prompt": "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n",
|
||||
}))]
|
||||
pub struct GenerateRequest {
|
||||
#[schema(
|
||||
example = "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n"
|
||||
)]
|
||||
prompt: String,
|
||||
stop_words: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
|
|
@ -47,9 +47,9 @@ pub async fn generate(
|
|||
State(state): State<Arc<GenerateState>>,
|
||||
Json(request): Json<GenerateRequest>,
|
||||
) -> impl IntoResponse {
|
||||
let options = build_options(&request);
|
||||
let (prompt, options) = parse_request(request);
|
||||
Json(GenerateResponse {
|
||||
text: state.engine.generate(&request.prompt, options).await,
|
||||
text: state.engine.generate(&prompt, options).await,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -68,9 +68,9 @@ pub async fn generate_stream(
|
|||
State(state): State<Arc<GenerateState>>,
|
||||
Json(request): Json<GenerateRequest>,
|
||||
) -> impl IntoResponse {
|
||||
let options = build_options(&request);
|
||||
let (prompt, options) = parse_request(request);
|
||||
let s = stream! {
|
||||
for await text in state.engine.generate_stream(&request.prompt, options).await {
|
||||
for await text in state.engine.generate_stream(&prompt, options).await {
|
||||
yield GenerateResponse { text }
|
||||
}
|
||||
};
|
||||
|
|
@ -78,16 +78,17 @@ pub async fn generate_stream(
|
|||
StreamBodyAs::json_nl(s)
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
static ref STOP_WORDS: Vec<&'static str> = vec!["\n\n",];
|
||||
}
|
||||
fn parse_request(request: GenerateRequest) -> (String, TextGenerationOptions) {
|
||||
let mut builder = TextGenerationOptionsBuilder::default();
|
||||
|
||||
fn build_options(_request: &GenerateRequest) -> TextGenerationOptions {
|
||||
TextGenerationOptionsBuilder::default()
|
||||
builder
|
||||
.max_input_length(1024)
|
||||
.max_decoding_length(1024)
|
||||
.sampling_temperature(0.1)
|
||||
.stop_words(&STOP_WORDS)
|
||||
.build()
|
||||
.unwrap()
|
||||
.max_decoding_length(968)
|
||||
.sampling_temperature(0.1);
|
||||
|
||||
if let Some(stop_words) = request.stop_words {
|
||||
builder.stop_words(stop_words);
|
||||
};
|
||||
|
||||
(request.prompt, builder.build().unwrap())
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue