fix format

Meng Zhang 2023-10-09 16:35:12 -07:00
parent 90e85c79c2
commit 6923d1b90f
1 changed files with 10 additions and 8 deletions

View File

@ -3,7 +3,10 @@ use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use derive_builder::Builder; use derive_builder::Builder;
use futures::stream::BoxStream; use futures::stream::BoxStream;
use tabby_inference::{helpers, TextGeneration, TextGenerationOptions, decoding::{DecodingFactory, IncrementalDecoding}}; use tabby_inference::{
decoding::{DecodingFactory, IncrementalDecoding},
helpers, TextGeneration, TextGenerationOptions,
};
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
@ -73,10 +76,7 @@ pub struct InferenceContext {
impl InferenceContext { impl InferenceContext {
fn new(decoding: IncrementalDecoding, cancel: CancellationToken) -> Self { fn new(decoding: IncrementalDecoding, cancel: CancellationToken) -> Self {
InferenceContext { InferenceContext { decoding, cancel }
decoding,
cancel,
}
} }
} }
@ -115,9 +115,11 @@ impl TextGeneration for CTranslate2Engine {
let cancel_for_inference = cancel.clone(); let cancel_for_inference = cancel.clone();
let _guard = cancel.drop_guard(); let _guard = cancel.drop_guard();
let decoding = self let decoding = self.decoding_factory.create_incremental_decoding(
.decoding_factory self.tokenizer.clone(),
.create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words); truncate_tokens(encoding.get_ids(), options.max_input_length),
options.stop_words,
);
let context = InferenceContext::new(decoding, cancel_for_inference); let context = InferenceContext::new(decoding, cancel_for_inference);
let output_ids = tokio::task::spawn_blocking(move || { let output_ids = tokio::task::spawn_blocking(move || {