2023-09-29 13:06:47 +00:00
|
|
|
pub mod decoding;
|
|
|
|
|
|
2023-08-02 06:12:51 +00:00
|
|
|
use async_trait::async_trait;
|
|
|
|
|
use derive_builder::Builder;
|
2023-09-28 17:20:50 +00:00
|
|
|
use futures::stream::BoxStream;
|
2023-10-16 00:24:44 +00:00
|
|
|
use tabby_common::languages::Language;
|
2023-08-02 06:12:51 +00:00
|
|
|
|
|
|
|
|
#[derive(Builder, Debug)]
|
|
|
|
|
pub struct TextGenerationOptions {
|
2023-09-08 10:01:03 +00:00
|
|
|
#[builder(default = "1024")]
|
|
|
|
|
pub max_input_length: usize,
|
|
|
|
|
|
2023-08-02 06:12:51 +00:00
|
|
|
#[builder(default = "256")]
|
|
|
|
|
pub max_decoding_length: usize,
|
|
|
|
|
|
|
|
|
|
#[builder(default = "1.0")]
|
|
|
|
|
pub sampling_temperature: f32,
|
|
|
|
|
|
2023-10-16 00:24:44 +00:00
|
|
|
#[builder(default = "&tabby_common::languages::UNKNOWN_LANGUAGE")]
|
|
|
|
|
pub language: &'static Language,
|
2023-08-02 06:12:51 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[async_trait]
|
2023-09-03 01:59:07 +00:00
|
|
|
pub trait TextGeneration: Sync + Send {
|
2023-08-02 06:12:51 +00:00
|
|
|
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String;
|
2023-09-28 17:20:50 +00:00
|
|
|
async fn generate_stream(
|
|
|
|
|
&self,
|
|
|
|
|
prompt: &str,
|
|
|
|
|
options: TextGenerationOptions,
|
|
|
|
|
) -> BoxStream<String>;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub mod helpers {
|
|
|
|
|
use async_stream::stream;
|
|
|
|
|
use futures::{pin_mut, stream::BoxStream, Stream, StreamExt};
|
|
|
|
|
|
|
|
|
|
pub async fn stream_to_string(s: impl Stream<Item = String>) -> String {
|
|
|
|
|
pin_mut!(s);
|
|
|
|
|
|
|
|
|
|
let mut text = "".to_owned();
|
|
|
|
|
while let Some(value) = s.next().await {
|
|
|
|
|
text += &value;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
text
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub async fn string_to_stream(s: String) -> BoxStream<'static, String> {
|
|
|
|
|
let stream = stream! {
|
|
|
|
|
yield s
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
Box::pin(stream)
|
|
|
|
|
}
|
2023-08-02 06:12:51 +00:00
|
|
|
}
|