tabby/crates/tabby-inference/src/lib.rs

55 lines
1.3 KiB
Rust
Raw Normal View History

use async_trait::async_trait;
use derive_builder::Builder;
use futures::stream::BoxStream;
/// Options passed along with a prompt to a [`TextGeneration`] implementation.
///
/// The `Builder` derive comes from `derive_builder`; each field's
/// `#[builder(default = "...")]` attribute supplies the value used when the
/// field is not set explicitly on the builder.
#[derive(Builder, Debug)]
pub struct TextGenerationOptions {
/// Maximum input length; defaults to `1024`.
/// NOTE(review): the unit (tokens vs. bytes/chars) is not visible here — confirm with implementors.
#[builder(default = "1024")]
pub max_input_length: usize,
/// Maximum decoding length; defaults to `256`.
/// NOTE(review): presumably the cap on generated output, same unit as above — confirm.
#[builder(default = "256")]
pub max_decoding_length: usize,
/// Sampling temperature; defaults to `1.0`.
/// NOTE(review): how implementors interpret this value is not shown in this file.
#[builder(default = "1.0")]
pub sampling_temperature: f32,
/// Stop words for the request; defaults to a reference to the shared, empty
/// `EMPTY_STOP_WORDS` list.
/// NOTE(review): presumably generation stops when one of these is produced — confirm with implementors.
#[builder(default = "&EMPTY_STOP_WORDS")]
pub stop_words: &'static Vec<&'static str>,
}
/// Shared empty stop-word list; the builder default for
/// `TextGenerationOptions::stop_words` borrows it.
static EMPTY_STOP_WORDS: Vec<&'static str> = Vec::new();
/// Interface implemented by text-generation backends.
///
/// `#[async_trait]` enables the `async fn` methods in this trait, and the
/// `Sync + Send` supertraits require every implementor to be usable across
/// threads.
#[async_trait]
pub trait TextGeneration: Sync + Send {
/// Generates text for `prompt` under `options` and returns it as a single
/// `String`.
/// NOTE(review): whether the returned text includes the prompt is up to the implementor — confirm.
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String;
/// Streaming counterpart of `generate`: takes the same arguments and returns
/// a boxed stream of `String` items.
/// NOTE(review): presumably each item is an incremental chunk of the generated text — confirm with implementors.
async fn generate_stream(
&self,
prompt: &str,
options: TextGenerationOptions,
) -> BoxStream<String>;
}
/// Helpers for converting between streams of `String` chunks and whole `String`s.
pub mod helpers {
    use async_stream::stream;
    use futures::{stream::BoxStream, Stream, StreamExt};

    /// Drains `s` to completion and returns every item concatenated in the
    /// order it was yielded. An empty stream produces an empty `String`.
    pub async fn stream_to_string(s: impl Stream<Item = String>) -> String {
        // `String` is `Default + Extend<String>`, so `collect` starts from an
        // empty string and appends each chunk as it arrives.
        s.collect::<String>().await
    }

    /// Wraps `s` in a boxed stream that yields it exactly once and then ends.
    pub async fn string_to_stream(s: String) -> BoxStream<'static, String> {
        let single = stream! {
            yield s
        };
        single.boxed()
    }
}