diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs index 25ce843..ffc07ea 100644 --- a/crates/ctranslate2-bindings/src/lib.rs +++ b/crates/ctranslate2-bindings/src/lib.rs @@ -137,7 +137,7 @@ impl TextGeneration for CTranslate2Engine { let decoding = self .decoding_factory - .create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words); + .create(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), &options.stop_words, options.static_stop_words); let (sender, mut receiver) = channel::(8); let context = InferenceContext::new(sender, decoding, cancel_for_inference); diff --git a/crates/http-api-bindings/src/fastchat.rs b/crates/http-api-bindings/src/fastchat.rs index f71e048..c08ad01 100644 --- a/crates/http-api-bindings/src/fastchat.rs +++ b/crates/http-api-bindings/src/fastchat.rs @@ -58,8 +58,11 @@ impl FastChatEngine { #[async_trait] impl TextGeneration for FastChatEngine { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { - let _stop_sequences: Vec = - options.stop_words.iter().map(|x| x.to_string()).collect(); + let _stop_sequences: Vec = options + .static_stop_words + .iter() + .map(|x| x.to_string()) + .collect(); let tokens: Vec<&str> = prompt.split("").collect(); let request = Request { diff --git a/crates/http-api-bindings/src/vertex_ai.rs b/crates/http-api-bindings/src/vertex_ai.rs index 1d74b59..6d83210 100644 --- a/crates/http-api-bindings/src/vertex_ai.rs +++ b/crates/http-api-bindings/src/vertex_ai.rs @@ -67,7 +67,7 @@ impl VertexAIEngine { impl TextGeneration for VertexAIEngine { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { let stop_sequences: Vec = options - .stop_words + .static_stop_words .iter() .map(|x| x.to_string()) // vertex supports at most 5 stop sequence. diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc index e433c41..0f1a8cd 100644 --- a/crates/llama-cpp-bindings/src/engine.cc +++ b/crates/llama-cpp-bindings/src/engine.cc @@ -15,29 +15,6 @@ static size_t N_BATCH = 512; template using owned = std::unique_ptr>; -std::vector tokenize(struct llama_context * ctx, const std::string & text, size_t max_input_length, bool add_bos) { - const struct llama_model* model = llama_get_model(ctx); - // upper limit for the number of tokens - int n_tokens = max_input_length; - std::vector result(n_tokens); - n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos); - if (n_tokens < 0) { - result.resize(-n_tokens); - int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos); - GGML_ASSERT(check == -n_tokens); - - int start = check - max_input_length; - GGML_ASSERT(start >= 0); - result = std::vector(result.begin() + start, result.end()); - if (add_bos) { - result[0] = llama_token_bos(ctx); - } - } else { - result.resize(n_tokens); - } - return result; -} - class TextInferenceEngineImpl : public TextInferenceEngine { public: TextInferenceEngineImpl(owned model, owned ctx) : diff --git a/crates/llama-cpp-bindings/src/lib.rs b/crates/llama-cpp-bindings/src/lib.rs index ec42066..02a171f 100644 --- a/crates/llama-cpp-bindings/src/lib.rs +++ b/crates/llama-cpp-bindings/src/lib.rs @@ -70,7 +70,7 @@ impl TextGeneration for LlamaEngine { let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length); engine.start(input_token_ids); - let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.stop_words); + let mut decoding = self.decoding_factory.create(self.tokenizer.clone(), input_token_ids, &options.stop_words, options.static_stop_words); let mut n_remains = options.max_decoding_length ; while n_remains > 0 { let next_token_id = engine.step(); diff --git a/crates/tabby-inference/src/decoding.rs b/crates/tabby-inference/src/decoding.rs index 78cb1a7..ac7a60a 100644 --- a/crates/tabby-inference/src/decoding.rs +++ b/crates/tabby-inference/src/decoding.rs @@ -24,16 +24,35 @@ impl Default for DecodingFactory { } impl DecodingFactory { - pub fn create_incremental_decoding( + pub fn create( &self, tokenizer: Arc, input_token_ids: &[u32], - stop_words: &'static Vec<&'static str>, + stop_words: &Vec, + static_stop_words: &'static Vec<&'static str>, ) -> IncrementalDecoding { - IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids) + IncrementalDecoding::new( + tokenizer, + vec![ + self.get_static_re(static_stop_words), + self.get_re(stop_words), + ] + .into_iter() + .flatten() + .collect(), + input_token_ids, + ) } - fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option { + fn get_re(&self, stop_words: &Vec) -> Option { + if !stop_words.is_empty() { + Some(create_stop_regex(stop_words)) + } else { + None + } + } + + fn get_static_re(&self, stop_words: &'static Vec<&'static str>) -> Option { if stop_words.is_empty() { None } else { @@ -48,8 +67,8 @@ impl DecodingFactory { } } -fn create_stop_regex(stop_words: &[&str]) -> Regex { - let tokens: Vec = stop_words.iter().map(|x| reverse(*x)).collect(); +fn create_stop_regex>(stop_words: &[T]) -> Regex { + let tokens: Vec = stop_words.iter().map(|x| reverse(x.as_ref())).collect(); // (?m) enables multi-line matching mode. // \A means absolute begins of string. @@ -59,7 +78,7 @@ fn create_stop_regex(stop_words: &[&str]) -> Regex { pub struct IncrementalDecoding { tokenizer: Arc, - stop_re: Option, + stop_re: Vec, token_ids: Vec, prefix_offset: usize, @@ -69,7 +88,7 @@ pub struct IncrementalDecoding { } impl IncrementalDecoding { - pub fn new(tokenizer: Arc, stop_re: Option, input_token_ids: &[u32]) -> Self { + pub fn new(tokenizer: Arc, stop_re: Vec, input_token_ids: &[u32]) -> Self { let text = tokenizer .decode(input_token_ids, /* skip_special_token = */ true) .expect("Cannot decode token from tokenizer."); @@ -110,8 +129,7 @@ impl IncrementalDecoding { if !new_text.is_empty() { self.reversed_text = reverse(new_text) + &self.reversed_text; - - if let Some(re) = &self.stop_re { + for re in &self.stop_re { if re.find(&self.reversed_text).is_some() { return None; } diff --git a/crates/tabby-inference/src/lib.rs b/crates/tabby-inference/src/lib.rs index 495785e..a36f4ad 100644 --- a/crates/tabby-inference/src/lib.rs +++ b/crates/tabby-inference/src/lib.rs @@ -16,7 +16,10 @@ pub struct TextGenerationOptions { pub sampling_temperature: f32, #[builder(default = "&EMPTY_STOP_WORDS")] - pub stop_words: &'static Vec<&'static str>, + pub static_stop_words: &'static Vec<&'static str>, + + #[builder(default = "vec![]")] + pub stop_words: Vec, } static EMPTY_STOP_WORDS: Vec<&'static str> = vec![]; diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index 73113e3..26bcfe7 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -80,7 +80,7 @@ pub async fn completion( .max_input_length(1024 + 512) .max_decoding_length(128) .sampling_temperature(0.1) - .stop_words(get_stop_words(&language)) + .static_stop_words(get_stop_words(&language)) .build() .unwrap(); diff --git a/crates/tabby/src/serve/generate.rs b/crates/tabby/src/serve/generate.rs index cfebe9a..1ec3f1c 100644 --- a/crates/tabby/src/serve/generate.rs +++ b/crates/tabby/src/serve/generate.rs @@ -3,7 +3,6 @@ use std::sync::Arc; use async_stream::stream; use axum::{extract::State, response::IntoResponse, Json}; use axum_streams::StreamBodyAs; -use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder}; use tracing::instrument; @@ -20,11 +19,12 @@ impl GenerateState { } #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] +#[schema(example=json!({ + "prompt": "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n", +}))] pub struct GenerateRequest { - #[schema( - example = "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n" - )] prompt: String, + stop_words: Option>, } #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] @@ -47,9 +47,9 @@ pub async fn generate( State(state): State>, Json(request): Json, ) -> impl IntoResponse { - let options = build_options(&request); + let (prompt, options) = parse_request(request); Json(GenerateResponse { - text: state.engine.generate(&request.prompt, options).await, + text: state.engine.generate(&prompt, options).await, }) } @@ -68,9 +68,9 @@ pub async fn generate_stream( State(state): State>, Json(request): Json, ) -> impl IntoResponse { - let options = build_options(&request); + let (prompt, options) = parse_request(request); let s = stream! { - for await text in state.engine.generate_stream(&request.prompt, options).await { + for await text in state.engine.generate_stream(&prompt, options).await { yield GenerateResponse { text } } }; @@ -78,16 +78,17 @@ pub async fn generate_stream( StreamBodyAs::json_nl(s) } -lazy_static! { - static ref STOP_WORDS: Vec<&'static str> = vec!["\n\n",]; -} +fn parse_request(request: GenerateRequest) -> (String, TextGenerationOptions) { + let mut builder = TextGenerationOptionsBuilder::default(); -fn build_options(_request: &GenerateRequest) -> TextGenerationOptions { - TextGenerationOptionsBuilder::default() + builder .max_input_length(1024) - .max_decoding_length(1024) - .sampling_temperature(0.1) - .stop_words(&STOP_WORDS) - .build() - .unwrap() + .max_decoding_length(968) + .sampling_temperature(0.1); + + if let Some(stop_words) = request.stop_words { + builder.stop_words(stop_words); + }; + + (request.prompt, builder.build().unwrap()) }