feat: support FIM inference [TAB-46] (#183)

* Add prefix / suffix

* update

* feat: support segments in inference

* chore: add tabby.json in model repository to store prompt_template

* make prompt_template optional.

* download tabby.json in downloader
support-coreml
Meng Zhang 2023-06-02 16:47:48 -07:00 committed by GitHub
parent 950a7a795f
commit 2779da3cba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 137 additions and 112 deletions

7
Cargo.lock generated
View File

@ -2051,6 +2051,12 @@ dependencies = [
"unicode-segmentation",
]
[[package]]
name = "strfmt"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a8348af2d9fc3258c8733b8d9d8db2e56f54b2363a4b5b81585c7875ed65e65"
[[package]]
name = "strsim"
version = "0.10.0"
@ -2134,6 +2140,7 @@ dependencies = [
"serde",
"serde_json",
"serdeconv",
"strfmt",
"strum",
"tabby-common",
"tokio",

View File

@ -32,8 +32,12 @@ impl ModelDir {
self.0.join(name).display().to_string()
}
pub fn cache_info_file(&self) -> String {
self.path_string(".cache_info.json")
}
pub fn metadata_file(&self) -> String {
self.path_string("metadata.json")
self.path_string("tabby.json")
}
pub fn tokenizer_file(&self) -> String {

View File

@ -28,6 +28,7 @@ indicatif = "0.17.3"
futures-util = "0.3.28"
tabby-common = { path = "../tabby-common" }
anyhow = "1.0.71"
strfmt = "0.2.4"
[dependencies.uuid]
version = "1.3.3"

View File

@ -0,0 +1,69 @@
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use tabby_common::path::ModelDir;
#[derive(Serialize, Deserialize)]
/// Persistent record of download-cache state for a model, serialized to the
/// model directory's `.cache_info.json` file.
pub struct CacheInfo {
// Maps a file path (relative to the model repository) to the ETag value
// recorded for it; used to decide whether a remote file must be re-fetched.
etags: HashMap<String, String>,
}
impl CacheInfo {
    /// Loads the cache info recorded on disk for `model_id`, falling back to
    /// an empty cache when no valid local cache file exists.
    pub async fn from(model_id: &str) -> CacheInfo {
        Self::from_local(model_id).unwrap_or_else(|| CacheInfo {
            etags: HashMap::new(),
        })
    }

    /// Reads the cache info file for `model_id`; returns `None` when the file
    /// is absent or fails to parse as JSON.
    fn from_local(model_id: &str) -> Option<CacheInfo> {
        let cache_info_file = ModelDir::new(model_id).cache_info_file();
        fs::metadata(&cache_info_file)
            .ok()
            .and_then(|_| serdeconv::from_json_file(cache_info_file).ok())
    }

    /// Returns the ETag previously recorded for `path`, if any.
    pub fn local_cache_key(&self, path: &str) -> Option<&str> {
        self.etags.get(path).map(String::as_str)
    }

    /// Extracts the `etag` header from an HTTP response.
    ///
    /// Panics when the header is missing or is not a valid string.
    pub fn remote_cache_key(res: &reqwest::Response) -> &str {
        let etag = res
            .headers()
            .get("etag")
            .unwrap_or_else(|| panic!("Failed to GET ETAG header from '{}'", res.url()));
        etag.to_str()
            .unwrap_or_else(|_| panic!("Failed to convert ETAG header into string '{}'", res.url()))
    }

    /// Records `cache_key` as the local ETag for `path`.
    pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) {
        self.etags.insert(path.to_string(), cache_key.to_string());
    }

    /// Persists this cache info as JSON into the model directory.
    pub fn save(&self, model_id: &str) -> Result<()> {
        let cache_info_file = ModelDir::new(model_id).cache_info_file();
        serdeconv::to_json_file(self, Path::new(&cache_info_file))?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: the previous test referenced `HFCacheInfo` and
    // `transformers_info.auto_model`, which belong to the removed
    // `metadata.rs` module and do not exist here — it could not compile.
    // Exercise `CacheInfo`'s cache-key API instead.
    #[tokio::test]
    async fn test_cache_key_roundtrip() {
        let mut cache_info = CacheInfo {
            etags: HashMap::new(),
        };
        assert_eq!(cache_info.local_cache_key("ctranslate2/model.bin"), None);
        cache_info
            .set_local_cache_key("ctranslate2/model.bin", "etag-123")
            .await;
        assert_eq!(
            cache_info.local_cache_key("ctranslate2/model.bin"),
            Some("etag-123")
        );
    }
}

View File

@ -1,96 +0,0 @@
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use tabby_common::path::ModelDir;
#[derive(Deserialize)]
// Subset of the `transformersInfo` object returned by the Hugging Face
// model API; only the auto-model class name is consumed.
struct HFTransformersInfo {
auto_model: String,
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
// Partial deserialization of a Hugging Face `api/models/{id}` response;
// camelCase renaming matches the API's JSON field names.
struct HFMetadata {
transformers_info: HFTransformersInfo,
}
impl HFMetadata {
// Fetches model metadata from the Hugging Face API for `model_id`.
// Panics if the request fails or the response cannot be parsed —
// NOTE(review): no recoverable-error path; callers cannot handle outages.
async fn from(model_id: &str) -> HFMetadata {
let api_url = format!("https://huggingface.co/api/models/{}", model_id);
reqwest::get(&api_url)
.await
.unwrap_or_else(|_| panic!("Failed to GET url '{}'", api_url))
.json::<HFMetadata>()
.await
.unwrap_or_else(|_| panic!("Failed to parse HFMetadata '{}'", api_url))
}
}
#[derive(Serialize, Deserialize)]
// Locally persisted model metadata: the transformers auto-model class plus
// per-file ETags used as download cache keys.
pub struct Metadata {
auto_model: String,
etags: HashMap<String, String>,
}
impl Metadata {
// Loads metadata from the local model directory, falling back to fetching
// it from the Hugging Face API (with empty ETags) when absent.
pub async fn from(model_id: &str) -> Metadata {
if let Some(metadata) = Self::from_local(model_id) {
metadata
} else {
let hf_metadata = HFMetadata::from(model_id).await;
Metadata {
auto_model: hf_metadata.transformers_info.auto_model,
etags: HashMap::new(),
}
}
}
// Reads the metadata JSON file for `model_id`; returns `None` when the
// file is missing or fails to parse.
fn from_local(model_id: &str) -> Option<Metadata> {
let metadata_file = ModelDir::new(model_id).metadata_file();
if fs::metadata(&metadata_file).is_ok() {
let metadata = serdeconv::from_json_file(metadata_file);
metadata.ok()
} else {
None
}
}
// Returns the ETag previously recorded for `path`, if any.
pub fn local_cache_key(&self, path: &str) -> Option<&str> {
self.etags.get(path).map(|x| x.as_str())
}
// Extracts the `etag` header from an HTTP response; panics when the
// header is missing or not a valid string.
pub fn remote_cache_key(res: &reqwest::Response) -> &str {
res.headers()
.get("etag")
.unwrap_or_else(|| panic!("Failed to GET ETAG header from '{}'", res.url()))
.to_str()
.unwrap_or_else(|_| panic!("Failed to convert ETAG header into string '{}'", res.url()))
}
// Records `cache_key` as the local ETag for `path`.
pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) {
self.etags.insert(path.to_string(), cache_key.to_string());
}
// Persists this metadata as JSON into the model directory.
pub fn save(&self, model_id: &str) -> Result<()> {
let metadata_file = ModelDir::new(model_id).metadata_file();
let metadata_file_path = Path::new(&metadata_file);
serdeconv::to_json_file(self, metadata_file_path)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
// Integration-style test: hits the live Hugging Face API for a known
// model and checks the reported auto-model class.
// NOTE(review): network-dependent; fails offline or if the model moves.
#[tokio::test]
async fn test_hf() {
let hf_metadata = HFMetadata::from("TabbyML/J-350M").await;
assert_eq!(
hf_metadata.transformers_info.auto_model,
"AutoModelForCausalLM"
);
}
}

View File

@ -1,4 +1,4 @@
mod metadata;
mod cache_info;
use std::cmp;
use std::fs;
@ -10,6 +10,8 @@ use futures_util::StreamExt;
use indicatif::{ProgressBar, ProgressStyle};
use tabby_common::path::ModelDir;
use cache_info::CacheInfo;
#[derive(Args)]
pub struct DownloadArgs {
/// model id to fetch.
@ -26,7 +28,7 @@ pub async fn main(args: &DownloadArgs) {
println!("model '{}' is ready", args.model);
}
impl metadata::Metadata {
impl CacheInfo {
async fn download(&mut self, model_id: &str, path: &str, prefer_local_file: bool) {
// Create url.
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path);
@ -56,28 +58,31 @@ pub async fn download_model(model_id: &str, prefer_local_file: bool) {
return;
}
let mut metadata = metadata::Metadata::from(model_id).await;
let mut cache_info = CacheInfo::from(model_id).await;
metadata
cache_info
.download(model_id, "tabby.json", prefer_local_file)
.await;
cache_info
.download(model_id, "tokenizer.json", prefer_local_file)
.await;
metadata
cache_info
.download(model_id, "ctranslate2/config.json", prefer_local_file)
.await;
metadata
cache_info
.download(model_id, "ctranslate2/vocabulary.txt", prefer_local_file)
.await;
metadata
cache_info
.download(
model_id,
"ctranslate2/shared_vocabulary.txt",
prefer_local_file,
)
.await;
metadata
cache_info
.download(model_id, "ctranslate2/model.bin", prefer_local_file)
.await;
metadata
cache_info
.save(model_id)
.unwrap_or_else(|_| panic!("Failed to save model_id '{}'", model_id));
}
@ -91,7 +96,7 @@ async fn download_file(url: &str, path: &str, local_cache_key: Option<&str>) ->
.await
.unwrap_or_else(|_| panic!("Failed to GET from '{}'", url));
let remote_cache_key = metadata::Metadata::remote_cache_key(&res).to_string();
let remote_cache_key = CacheInfo::remote_cache_key(&res).to_string();
if let Some(local_cache_key) = local_cache_key {
if local_cache_key == remote_cache_key {
return remote_cache_key;

View File

@ -5,6 +5,7 @@ use ctranslate2_bindings::{
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::sync::Arc;
use strfmt::{strfmt, strfmt_builder};
use tabby_common::{events, path::ModelDir};
use utoipa::ToSchema;
@ -12,12 +13,27 @@ mod languages;
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct CompletionRequest {
#[schema(example = "def fib(n):")]
#[deprecated]
prompt: Option<String>,
/// https://code.visualstudio.com/docs/languages/identifiers
#[schema(example = "python")]
language: Option<String>,
#[schema(example = "def fib(n):")]
prompt: String,
/// When segments are set, the `prompt` is ignored during the inference.
segments: Option<Segments>,
}
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct Segments {
/// Content that appears before the cursor in the editor window.
#[schema(example = "def fib(n):\n ")]
prefix: String,
/// Content that appears after the cursor in the editor window.
#[schema(example = "\n return fib(n - 1) + fib(n - 2)")]
suffix: String,
}
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
@ -46,7 +62,20 @@ pub async fn completion(
.sampling_temperature(0.2)
.build()
.expect("Invalid TextInferenceOptions");
let text = state.engine.inference(&request.prompt, options);
let prompt = if let Some(Segments { prefix, suffix }) = request.segments {
if let Some(prompt_template) = &state.prompt_template {
strfmt!(prompt_template, prefix => prefix, suffix => suffix)
.expect("Failed to format prompt")
} else {
// If there's no prompt template, just use prefix.
prefix
}
} else {
request.prompt.expect("No prompt is set")
};
let text = state.engine.inference(&prompt, options);
let language = request.language.unwrap_or("unknown".into());
let filtered_text = languages::remove_stop_words(&language, &text);
@ -61,7 +90,7 @@ pub async fn completion(
events::Event::Completion {
completion_id: &response.id,
language: &language,
prompt: &request.prompt,
prompt: &prompt,
choices: vec![events::Choice {
index: 0,
text: filtered_text,
@ -74,6 +103,7 @@ pub async fn completion(
pub struct CompletionState {
engine: TextInferenceEngine,
prompt_template: Option<String>,
}
impl CompletionState {
@ -92,7 +122,10 @@ impl CompletionState {
.build()
.expect("Invalid TextInferenceEngineCreateOptions");
let engine = TextInferenceEngine::create(options);
Self { engine }
Self {
engine,
prompt_template: metadata.prompt_template,
}
}
}
@ -107,6 +140,7 @@ fn get_model_dir(model: &str) -> ModelDir {
#[derive(Deserialize)]
struct Metadata {
auto_model: String,
prompt_template: Option<String>,
}
fn read_metadata(model_dir: &ModelDir) -> Metadata {

View File

@ -20,6 +20,7 @@ use utoipa_swagger_ui::SwaggerUi;
events::LogEventRequest,
completions::CompletionRequest,
completions::CompletionResponse,
completions::Segments,
completions::Choice
))
)]