diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs index 0391cc4..f04f8ad 100644 --- a/crates/ctranslate2-bindings/src/lib.rs +++ b/crates/ctranslate2-bindings/src/lib.rs @@ -9,7 +9,7 @@ use tabby_inference::{ helpers, TextGeneration, TextGenerationOptions, }; use tokenizers::tokenizer::Tokenizer; -use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; +use tokio::sync::mpsc::{channel, Sender}; use tokio_util::sync::CancellationToken; #[cxx::bridge(namespace = "tabby")] @@ -72,14 +72,14 @@ pub struct CTranslate2EngineOptions { } pub struct InferenceContext { - sender: UnboundedSender, + sender: Sender, decoding: IncrementalDecoding, cancel: CancellationToken, } impl InferenceContext { fn new( - sender: UnboundedSender, + sender: Sender, decoding: IncrementalDecoding, cancel: CancellationToken, ) -> Self { @@ -137,9 +137,9 @@ impl TextGeneration for CTranslate2Engine { let cancel = CancellationToken::new(); let engine = self.engine.clone(); - let (sender, mut receiver) = unbounded_channel(); + let (sender, mut receiver) = channel::(8); let context = InferenceContext::new(sender, decoding, cancel.clone()); - tokio::task::spawn(async move { + tokio::task::spawn_blocking(move || { let context = Box::new(context); engine.inference( context, @@ -178,7 +178,7 @@ fn inference_callback( if context.cancel.is_cancelled() { true } else if let Some(new_text) = context.decoding.next_token(token_id) { - let _ = context.sender.send(new_text); + let _ = context.sender.blocking_send(new_text); false } else { true diff --git a/tests/default.loadtest.js b/tests/default.loadtest.js index fe465f5..3280376 100644 --- a/tests/default.loadtest.js +++ b/tests/default.loadtest.js @@ -20,7 +20,7 @@ export default () => { prompt: "def binarySearch(arr, left, right, x):\n mid = (left +", }); const headers = { "Content-Type": "application/json" }; - const res = http.post("https://tabbyml-tabby-template-space.hf.space/v1/completions", payload, { + const res = http.post("http://api.tabbyml.com/v1/completions", payload, { headers, }); check(res, { success: (r) => r.status === 200 });