From 2779da3cbab0bd90e275aa553489946e0748cebe Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Fri, 2 Jun 2023 16:47:48 -0700 Subject: [PATCH] feat: supports FIM inference [TAB-46] (#183) * Add prefix / suffix * update * feat: support segments in inference * chore: add tabby.json in model repository to store prompt_template * make prompt_template optional. * download tabby.json in downloader --- Cargo.lock | 7 ++ crates/tabby-common/src/path.rs | 6 +- crates/tabby/Cargo.toml | 1 + crates/tabby/src/download/cache_info.rs | 69 ++++++++++++++++++ crates/tabby/src/download/metadata.rs | 96 ------------------------- crates/tabby/src/download/mod.rs | 25 ++++--- crates/tabby/src/serve/completions.rs | 44 ++++++++++-- crates/tabby/src/serve/mod.rs | 1 + 8 files changed, 137 insertions(+), 112 deletions(-) create mode 100644 crates/tabby/src/download/cache_info.rs delete mode 100644 crates/tabby/src/download/metadata.rs diff --git a/Cargo.lock b/Cargo.lock index 4018898..73dc841 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2051,6 +2051,12 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "strfmt" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a8348af2d9fc3258c8733b8d9d8db2e56f54b2363a4b5b81585c7875ed65e65" + [[package]] name = "strsim" version = "0.10.0" @@ -2134,6 +2140,7 @@ dependencies = [ "serde", "serde_json", "serdeconv", + "strfmt", "strum", "tabby-common", "tokio", diff --git a/crates/tabby-common/src/path.rs b/crates/tabby-common/src/path.rs index 67459ca..4306ff2 100644 --- a/crates/tabby-common/src/path.rs +++ b/crates/tabby-common/src/path.rs @@ -32,8 +32,12 @@ impl ModelDir { self.0.join(name).display().to_string() } + pub fn cache_info_file(&self) -> String { + self.path_string(".cache_info.json") + } + pub fn metadata_file(&self) -> String { - self.path_string("metadata.json") + self.path_string("tabby.json") } pub fn tokenizer_file(&self) -> String { diff --git a/crates/tabby/Cargo.toml 
b/crates/tabby/Cargo.toml index 26d25e3..ea0118f 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -28,6 +28,7 @@ indicatif = "0.17.3" futures-util = "0.3.28" tabby-common = { path = "../tabby-common" } anyhow = "1.0.71" +strfmt = "0.2.4" [dependencies.uuid] version = "1.3.3" diff --git a/crates/tabby/src/download/cache_info.rs b/crates/tabby/src/download/cache_info.rs new file mode 100644 index 0000000..d746b9e --- /dev/null +++ b/crates/tabby/src/download/cache_info.rs @@ -0,0 +1,69 @@ +use anyhow::Result; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fs; +use std::path::Path; +use tabby_common::path::ModelDir; + +#[derive(Serialize, Deserialize)] +pub struct CacheInfo { + etags: HashMap, +} + +impl CacheInfo { + pub async fn from(model_id: &str) -> CacheInfo { + if let Some(cache_info) = Self::from_local(model_id) { + cache_info + } else { + CacheInfo { + etags: HashMap::new(), + } + } + } + + fn from_local(model_id: &str) -> Option { + let cache_info_file = ModelDir::new(model_id).cache_info_file(); + if fs::metadata(&cache_info_file).is_ok() { + serdeconv::from_json_file(cache_info_file).ok() + } else { + None + } + } + + pub fn local_cache_key(&self, path: &str) -> Option<&str> { + self.etags.get(path).map(|x| x.as_str()) + } + + pub fn remote_cache_key(res: &reqwest::Response) -> &str { + res.headers() + .get("etag") + .unwrap_or_else(|| panic!("Failed to GET ETAG header from '{}'", res.url())) + .to_str() + .unwrap_or_else(|_| panic!("Failed to convert ETAG header into string '{}'", res.url())) + } + + pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) { + self.etags.insert(path.to_string(), cache_key.to_string()); + } + + pub fn save(&self, model_id: &str) -> Result<()> { + let cache_info_file = ModelDir::new(model_id).cache_info_file(); + let cache_info_file_path = Path::new(&cache_info_file); + serdeconv::to_json_file(self, cache_info_file_path)?; + Ok(()) + } +} + +#[cfg(test)] 
+mod tests { + use super::*; + + #[tokio::test] + async fn test_cache_info() { + let cache_info = CacheInfo::from("TabbyML/J-350M").await; + assert_eq!( + cache_info.local_cache_key("non-existent-path"), + None + ); + } +} diff --git a/crates/tabby/src/download/metadata.rs b/crates/tabby/src/download/metadata.rs deleted file mode 100644 index d46d3b8..0000000 --- a/crates/tabby/src/download/metadata.rs +++ /dev/null @@ -1,96 +0,0 @@ -use anyhow::Result; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use std::fs; -use std::path::Path; -use tabby_common::path::ModelDir; - -#[derive(Deserialize)] -struct HFTransformersInfo { - auto_model: String, -} - -#[derive(Deserialize)] -#[serde(rename_all = "camelCase")] -struct HFMetadata { - transformers_info: HFTransformersInfo, -} - -impl HFMetadata { - async fn from(model_id: &str) -> HFMetadata { - let api_url = format!("https://huggingface.co/api/models/{}", model_id); - reqwest::get(&api_url) - .await - .unwrap_or_else(|_| panic!("Failed to GET url '{}'", api_url)) - .json::() - .await - .unwrap_or_else(|_| panic!("Failed to parse HFMetadata '{}'", api_url)) - } -} - -#[derive(Serialize, Deserialize)] -pub struct Metadata { - auto_model: String, - etags: HashMap, -} - -impl Metadata { - pub async fn from(model_id: &str) -> Metadata { - if let Some(metadata) = Self::from_local(model_id) { - metadata - } else { - let hf_metadata = HFMetadata::from(model_id).await; - Metadata { - auto_model: hf_metadata.transformers_info.auto_model, - etags: HashMap::new(), - } - } - } - - fn from_local(model_id: &str) -> Option { - let metadata_file = ModelDir::new(model_id).metadata_file(); - if fs::metadata(&metadata_file).is_ok() { - let metadata = serdeconv::from_json_file(metadata_file); - metadata.ok() - } else { - None - } - } - - pub fn local_cache_key(&self, path: &str) -> Option<&str> { - self.etags.get(path).map(|x| x.as_str()) - } - - pub fn remote_cache_key(res: &reqwest::Response) -> &str { - 
res.headers() - .get("etag") - .unwrap_or_else(|| panic!("Failed to GET ETAG header from '{}'", res.url())) - .to_str() - .unwrap_or_else(|_| panic!("Failed to convert ETAG header into string '{}'", res.url())) - } - - pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) { - self.etags.insert(path.to_string(), cache_key.to_string()); - } - - pub fn save(&self, model_id: &str) -> Result<()> { - let metadata_file = ModelDir::new(model_id).metadata_file(); - let metadata_file_path = Path::new(&metadata_file); - serdeconv::to_json_file(self, metadata_file_path)?; - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_hf() { - let hf_metadata = HFMetadata::from("TabbyML/J-350M").await; - assert_eq!( - hf_metadata.transformers_info.auto_model, - "AutoModelForCausalLM" - ); - } -} diff --git a/crates/tabby/src/download/mod.rs b/crates/tabby/src/download/mod.rs index bc0a7ac..87c995e 100644 --- a/crates/tabby/src/download/mod.rs +++ b/crates/tabby/src/download/mod.rs @@ -1,4 +1,4 @@ -mod metadata; +mod cache_info; use std::cmp; use std::fs; @@ -10,6 +10,8 @@ use futures_util::StreamExt; use indicatif::{ProgressBar, ProgressStyle}; use tabby_common::path::ModelDir; +use cache_info::CacheInfo; + #[derive(Args)] pub struct DownloadArgs { /// model id to fetch. @@ -26,7 +28,7 @@ pub async fn main(args: &DownloadArgs) { println!("model '{}' is ready", args.model); } -impl metadata::Metadata { +impl CacheInfo { async fn download(&mut self, model_id: &str, path: &str, prefer_local_file: bool) { // Create url. 
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path); @@ -56,28 +58,31 @@ pub async fn download_model(model_id: &str, prefer_local_file: bool) { return; } - let mut metadata = metadata::Metadata::from(model_id).await; + let mut cache_info = CacheInfo::from(model_id).await; - metadata + cache_info + .download(model_id, "tabby.json", prefer_local_file) + .await; + cache_info .download(model_id, "tokenizer.json", prefer_local_file) .await; - metadata + cache_info .download(model_id, "ctranslate2/config.json", prefer_local_file) .await; - metadata + cache_info .download(model_id, "ctranslate2/vocabulary.txt", prefer_local_file) .await; - metadata + cache_info .download( model_id, "ctranslate2/shared_vocabulary.txt", prefer_local_file, ) .await; - metadata + cache_info .download(model_id, "ctranslate2/model.bin", prefer_local_file) .await; - metadata + cache_info .save(model_id) .unwrap_or_else(|_| panic!("Failed to save model_id '{}'", model_id)); } @@ -91,7 +96,7 @@ async fn download_file(url: &str, path: &str, local_cache_key: Option<&str>) -> .await .unwrap_or_else(|_| panic!("Failed to GET from '{}'", url)); - let remote_cache_key = metadata::Metadata::remote_cache_key(&res).to_string(); + let remote_cache_key = CacheInfo::remote_cache_key(&res).to_string(); if let Some(local_cache_key) = local_cache_key { if local_cache_key == remote_cache_key { return remote_cache_key; diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index fe0ffea..2c5b5c2 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -5,6 +5,7 @@ use ctranslate2_bindings::{ use serde::{Deserialize, Serialize}; use std::path::Path; use std::sync::Arc; +use strfmt::{strfmt, strfmt_builder}; use tabby_common::{events, path::ModelDir}; use utoipa::ToSchema; @@ -12,12 +13,27 @@ mod languages; #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] pub struct CompletionRequest { + #[schema(example 
= "def fib(n):")] + #[deprecated] + prompt: Option, + /// https://code.visualstudio.com/docs/languages/identifiers #[schema(example = "python")] language: Option, - #[schema(example = "def fib(n):")] - prompt: String, + /// When segments are set, the `prompt` is ignored during the inference. + segments: Option, +} + +#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] +pub struct Segments { + /// Content that appears before the cursor in the editor window. + #[schema(example = "def fib(n):\n ")] + prefix: String, + + /// Content that appears after the cursor in the editor window. + #[schema(example = "\n return fib(n - 1) + fib(n - 2)")] + suffix: String, } #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] @@ -46,7 +62,20 @@ pub async fn completion( .sampling_temperature(0.2) .build() .expect("Invalid TextInferenceOptions"); - let text = state.engine.inference(&request.prompt, options); + + let prompt = if let Some(Segments { prefix, suffix }) = request.segments { + if let Some(prompt_template) = &state.prompt_template { + strfmt!(prompt_template, prefix => prefix, suffix => suffix) + .expect("Failed to format prompt") + } else { + // If there's no prompt template, just use prefix. 
+ prefix + } + } else { + request.prompt.expect("No prompt is set") + }; + + let text = state.engine.inference(&prompt, options); let language = request.language.unwrap_or("unknown".into()); let filtered_text = languages::remove_stop_words(&language, &text); @@ -61,7 +90,7 @@ pub async fn completion( events::Event::Completion { completion_id: &response.id, language: &language, - prompt: &request.prompt, + prompt: &prompt, choices: vec![events::Choice { index: 0, text: filtered_text, @@ -74,6 +103,7 @@ pub async fn completion( pub struct CompletionState { engine: TextInferenceEngine, + prompt_template: Option, } impl CompletionState { @@ -92,7 +122,10 @@ impl CompletionState { .build() .expect("Invalid TextInferenceEngineCreateOptions"); let engine = TextInferenceEngine::create(options); - Self { engine } + Self { + engine, + prompt_template: metadata.prompt_template, + } } } @@ -107,6 +140,7 @@ fn get_model_dir(model: &str) -> ModelDir { #[derive(Deserialize)] struct Metadata { auto_model: String, + prompt_template: Option, } fn read_metadata(model_dir: &ModelDir) -> Metadata { diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index 895a133..81e0ddb 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -20,6 +20,7 @@ use utoipa_swagger_ui::SwaggerUi; events::LogEventRequest, completions::CompletionRequest, completions::CompletionResponse, + completions::Segments, completions::Choice )) )]