2023-08-02 06:12:51 +00:00
|
|
|
use async_trait::async_trait;
|
|
|
|
|
use derive_builder::Builder;
|
|
|
|
|
|
|
|
|
|
#[derive(Builder, Debug)]
|
|
|
|
|
pub struct TextGenerationOptions {
|
2023-09-08 10:01:03 +00:00
|
|
|
#[builder(default = "1024")]
|
|
|
|
|
pub max_input_length: usize,
|
|
|
|
|
|
2023-08-02 06:12:51 +00:00
|
|
|
#[builder(default = "256")]
|
|
|
|
|
pub max_decoding_length: usize,
|
|
|
|
|
|
|
|
|
|
#[builder(default = "1.0")]
|
|
|
|
|
pub sampling_temperature: f32,
|
|
|
|
|
|
2023-09-09 03:59:42 +00:00
|
|
|
#[builder(default = "&EMPTY_STOP_WORDS")]
|
2023-08-02 06:12:51 +00:00
|
|
|
pub stop_words: &'static Vec<&'static str>,
|
|
|
|
|
}
|
|
|
|
|
|
2023-09-09 03:59:42 +00:00
|
|
|
/// Shared, always-empty stop-word list; the builder default for `stop_words`.
static EMPTY_STOP_WORDS: Vec<&'static str> = Vec::new();
|
|
|
|
|
|
2023-08-02 06:12:51 +00:00
|
|
|
#[async_trait]
|
2023-09-03 01:59:07 +00:00
|
|
|
pub trait TextGeneration: Sync + Send {
|
2023-08-02 06:12:51 +00:00
|
|
|
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String;
|
|
|
|
|
}
|