fix: use spawn_blocking for sub task without await

Meng Zhang 2023-10-09 19:12:22 -07:00
parent e466c1d6cb
commit 43009235c6
2 changed files with 7 additions and 7 deletions

View File

@@ -9,7 +9,7 @@ use tabby_inference::{
helpers, TextGeneration, TextGenerationOptions, helpers, TextGeneration, TextGenerationOptions,
}; };
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; use tokio::sync::mpsc::{channel, Sender};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
#[cxx::bridge(namespace = "tabby")] #[cxx::bridge(namespace = "tabby")]
@@ -72,14 +72,14 @@ pub struct CTranslate2EngineOptions {
} }
pub struct InferenceContext { pub struct InferenceContext {
sender: UnboundedSender<String>, sender: Sender<String>,
decoding: IncrementalDecoding, decoding: IncrementalDecoding,
cancel: CancellationToken, cancel: CancellationToken,
} }
impl InferenceContext { impl InferenceContext {
fn new( fn new(
sender: UnboundedSender<String>, sender: Sender<String>,
decoding: IncrementalDecoding, decoding: IncrementalDecoding,
cancel: CancellationToken, cancel: CancellationToken,
) -> Self { ) -> Self {
@@ -137,9 +137,9 @@ impl TextGeneration for CTranslate2Engine {
let cancel = CancellationToken::new(); let cancel = CancellationToken::new();
let engine = self.engine.clone(); let engine = self.engine.clone();
let (sender, mut receiver) = unbounded_channel(); let (sender, mut receiver) = channel::<String>(8);
let context = InferenceContext::new(sender, decoding, cancel.clone()); let context = InferenceContext::new(sender, decoding, cancel.clone());
tokio::task::spawn(async move { tokio::task::spawn_blocking(move || {
let context = Box::new(context); let context = Box::new(context);
engine.inference( engine.inference(
context, context,
@@ -178,7 +178,7 @@ fn inference_callback(
if context.cancel.is_cancelled() { if context.cancel.is_cancelled() {
true true
} else if let Some(new_text) = context.decoding.next_token(token_id) { } else if let Some(new_text) = context.decoding.next_token(token_id) {
let _ = context.sender.send(new_text); let _ = context.sender.blocking_send(new_text);
false false
} else { } else {
true true

View File

@@ -20,7 +20,7 @@ export default () => {
prompt: "def binarySearch(arr, left, right, x):\n mid = (left +", prompt: "def binarySearch(arr, left, right, x):\n mid = (left +",
}); });
const headers = { "Content-Type": "application/json" }; const headers = { "Content-Type": "application/json" };
const res = http.post("https://tabbyml-tabby-template-space.hf.space/v1/completions", payload, { const res = http.post("http://api.tabbyml.com/v1/completions", payload, {
headers, headers,
}); });
check(res, { success: (r) => r.status === 200 }); check(res, { success: (r) => r.status === 200 });