feat: supports FIM inference [TAB-46] (#183)
* Add prefix / suffix * update * feat: support segments in inference * chore: add tabby.json in model repository to store prompt_template * make prompt_template optional. * download tabby.json in downloader support-coreml
parent
950a7a795f
commit
2779da3cba
|
|
@ -2051,6 +2051,12 @@ dependencies = [
|
|||
"unicode-segmentation",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strfmt"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a8348af2d9fc3258c8733b8d9d8db2e56f54b2363a4b5b81585c7875ed65e65"
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.10.0"
|
||||
|
|
@ -2134,6 +2140,7 @@ dependencies = [
|
|||
"serde",
|
||||
"serde_json",
|
||||
"serdeconv",
|
||||
"strfmt",
|
||||
"strum",
|
||||
"tabby-common",
|
||||
"tokio",
|
||||
|
|
|
|||
|
|
@ -32,8 +32,12 @@ impl ModelDir {
|
|||
self.0.join(name).display().to_string()
|
||||
}
|
||||
|
||||
pub fn cache_info_file(&self) -> String {
|
||||
self.path_string(".cache_info.json")
|
||||
}
|
||||
|
||||
pub fn metadata_file(&self) -> String {
|
||||
self.path_string("metadata.json")
|
||||
self.path_string("tabby.json")
|
||||
}
|
||||
|
||||
pub fn tokenizer_file(&self) -> String {
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ indicatif = "0.17.3"
|
|||
futures-util = "0.3.28"
|
||||
tabby-common = { path = "../tabby-common" }
|
||||
anyhow = "1.0.71"
|
||||
strfmt = "0.2.4"
|
||||
|
||||
[dependencies.uuid]
|
||||
version = "1.3.3"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,69 @@
|
|||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use tabby_common::path::ModelDir;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct CacheInfo {
|
||||
etags: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl CacheInfo {
|
||||
pub async fn from(model_id: &str) -> CacheInfo {
|
||||
if let Some(cache_info) = Self::from_local(model_id) {
|
||||
cache_info
|
||||
} else {
|
||||
CacheInfo {
|
||||
etags: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn from_local(model_id: &str) -> Option<CacheInfo> {
|
||||
let cache_info_file = ModelDir::new(model_id).cache_info_file();
|
||||
if fs::metadata(&cache_info_file).is_ok() {
|
||||
serdeconv::from_json_file(cache_info_file).ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn local_cache_key(&self, path: &str) -> Option<&str> {
|
||||
self.etags.get(path).map(|x| x.as_str())
|
||||
}
|
||||
|
||||
pub fn remote_cache_key(res: &reqwest::Response) -> &str {
|
||||
res.headers()
|
||||
.get("etag")
|
||||
.unwrap_or_else(|| panic!("Failed to GET ETAG header from '{}'", res.url()))
|
||||
.to_str()
|
||||
.unwrap_or_else(|_| panic!("Failed to convert ETAG header into string '{}'", res.url()))
|
||||
}
|
||||
|
||||
pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) {
|
||||
self.etags.insert(path.to_string(), cache_key.to_string());
|
||||
}
|
||||
|
||||
pub fn save(&self, model_id: &str) -> Result<()> {
|
||||
let cache_info_file = ModelDir::new(model_id).cache_info_file();
|
||||
let cache_info_file_path = Path::new(&cache_info_file);
|
||||
serdeconv::to_json_file(self, cache_info_file_path)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises the in-memory cache-key bookkeeping of `CacheInfo`.
    ///
    /// The test builds the value directly, so it touches neither the network nor the
    /// model directory on disk.
    #[tokio::test]
    async fn test_local_cache_key() {
        let mut cache_info = CacheInfo {
            etags: HashMap::new(),
        };
        assert_eq!(cache_info.local_cache_key("tokenizer.json"), None);

        cache_info
            .set_local_cache_key("tokenizer.json", "etag-v1")
            .await;
        assert_eq!(
            cache_info.local_cache_key("tokenizer.json"),
            Some("etag-v1")
        );

        // A second write for the same path replaces the stored key.
        cache_info
            .set_local_cache_key("tokenizer.json", "etag-v2")
            .await;
        assert_eq!(
            cache_info.local_cache_key("tokenizer.json"),
            Some("etag-v2")
        );

        // Paths that were never recorded stay absent.
        assert_eq!(cache_info.local_cache_key("ctranslate2/model.bin"), None);
    }
}
|
||||
|
|
@ -1,96 +0,0 @@
|
|||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use tabby_common::path::ModelDir;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct HFTransformersInfo {
|
||||
auto_model: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct HFMetadata {
|
||||
transformers_info: HFTransformersInfo,
|
||||
}
|
||||
|
||||
impl HFMetadata {
|
||||
async fn from(model_id: &str) -> HFMetadata {
|
||||
let api_url = format!("https://huggingface.co/api/models/{}", model_id);
|
||||
reqwest::get(&api_url)
|
||||
.await
|
||||
.unwrap_or_else(|_| panic!("Failed to GET url '{}'", api_url))
|
||||
.json::<HFMetadata>()
|
||||
.await
|
||||
.unwrap_or_else(|_| panic!("Failed to parse HFMetadata '{}'", api_url))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct Metadata {
|
||||
auto_model: String,
|
||||
etags: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl Metadata {
|
||||
pub async fn from(model_id: &str) -> Metadata {
|
||||
if let Some(metadata) = Self::from_local(model_id) {
|
||||
metadata
|
||||
} else {
|
||||
let hf_metadata = HFMetadata::from(model_id).await;
|
||||
Metadata {
|
||||
auto_model: hf_metadata.transformers_info.auto_model,
|
||||
etags: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn from_local(model_id: &str) -> Option<Metadata> {
|
||||
let metadata_file = ModelDir::new(model_id).metadata_file();
|
||||
if fs::metadata(&metadata_file).is_ok() {
|
||||
let metadata = serdeconv::from_json_file(metadata_file);
|
||||
metadata.ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn local_cache_key(&self, path: &str) -> Option<&str> {
|
||||
self.etags.get(path).map(|x| x.as_str())
|
||||
}
|
||||
|
||||
pub fn remote_cache_key(res: &reqwest::Response) -> &str {
|
||||
res.headers()
|
||||
.get("etag")
|
||||
.unwrap_or_else(|| panic!("Failed to GET ETAG header from '{}'", res.url()))
|
||||
.to_str()
|
||||
.unwrap_or_else(|_| panic!("Failed to convert ETAG header into string '{}'", res.url()))
|
||||
}
|
||||
|
||||
pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) {
|
||||
self.etags.insert(path.to_string(), cache_key.to_string());
|
||||
}
|
||||
|
||||
pub fn save(&self, model_id: &str) -> Result<()> {
|
||||
let metadata_file = ModelDir::new(model_id).metadata_file();
|
||||
let metadata_file_path = Path::new(&metadata_file);
|
||||
serdeconv::to_json_file(self, metadata_file_path)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Fetches the metadata of `TabbyML/J-350M` through `HFMetadata::from` (a real HTTP
    /// request) and checks the reported `auto_model`.
    #[tokio::test]
    async fn test_hf() {
        let fetched = HFMetadata::from("TabbyML/J-350M").await;
        let auto_model = fetched.transformers_info.auto_model;
        assert_eq!(auto_model, "AutoModelForCausalLM");
    }
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
mod metadata;
|
||||
mod cache_info;
|
||||
|
||||
use std::cmp;
|
||||
use std::fs;
|
||||
|
|
@ -10,6 +10,8 @@ use futures_util::StreamExt;
|
|||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
use tabby_common::path::ModelDir;
|
||||
|
||||
use cache_info::CacheInfo;
|
||||
|
||||
#[derive(Args)]
|
||||
pub struct DownloadArgs {
|
||||
/// model id to fetch.
|
||||
|
|
@ -26,7 +28,7 @@ pub async fn main(args: &DownloadArgs) {
|
|||
println!("model '{}' is ready", args.model);
|
||||
}
|
||||
|
||||
impl metadata::Metadata {
|
||||
impl CacheInfo {
|
||||
async fn download(&mut self, model_id: &str, path: &str, prefer_local_file: bool) {
|
||||
// Create url.
|
||||
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path);
|
||||
|
|
@ -56,28 +58,31 @@ pub async fn download_model(model_id: &str, prefer_local_file: bool) {
|
|||
return;
|
||||
}
|
||||
|
||||
let mut metadata = metadata::Metadata::from(model_id).await;
|
||||
let mut cache_info = CacheInfo::from(model_id).await;
|
||||
|
||||
metadata
|
||||
cache_info
|
||||
.download(model_id, "tabby.json", prefer_local_file)
|
||||
.await;
|
||||
cache_info
|
||||
.download(model_id, "tokenizer.json", prefer_local_file)
|
||||
.await;
|
||||
metadata
|
||||
cache_info
|
||||
.download(model_id, "ctranslate2/config.json", prefer_local_file)
|
||||
.await;
|
||||
metadata
|
||||
cache_info
|
||||
.download(model_id, "ctranslate2/vocabulary.txt", prefer_local_file)
|
||||
.await;
|
||||
metadata
|
||||
cache_info
|
||||
.download(
|
||||
model_id,
|
||||
"ctranslate2/shared_vocabulary.txt",
|
||||
prefer_local_file,
|
||||
)
|
||||
.await;
|
||||
metadata
|
||||
cache_info
|
||||
.download(model_id, "ctranslate2/model.bin", prefer_local_file)
|
||||
.await;
|
||||
metadata
|
||||
cache_info
|
||||
.save(model_id)
|
||||
.unwrap_or_else(|_| panic!("Failed to save model_id '{}'", model_id));
|
||||
}
|
||||
|
|
@ -91,7 +96,7 @@ async fn download_file(url: &str, path: &str, local_cache_key: Option<&str>) ->
|
|||
.await
|
||||
.unwrap_or_else(|_| panic!("Failed to GET from '{}'", url));
|
||||
|
||||
let remote_cache_key = metadata::Metadata::remote_cache_key(&res).to_string();
|
||||
let remote_cache_key = CacheInfo::remote_cache_key(&res).to_string();
|
||||
if let Some(local_cache_key) = local_cache_key {
|
||||
if local_cache_key == remote_cache_key {
|
||||
return remote_cache_key;
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ use ctranslate2_bindings::{
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use strfmt::{strfmt, strfmt_builder};
|
||||
use tabby_common::{events, path::ModelDir};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
|
|
@ -12,12 +13,27 @@ mod languages;
|
|||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
pub struct CompletionRequest {
|
||||
#[schema(example = "def fib(n):")]
|
||||
#[deprecated]
|
||||
prompt: Option<String>,
|
||||
|
||||
/// https://code.visualstudio.com/docs/languages/identifiers
|
||||
#[schema(example = "python")]
|
||||
language: Option<String>,
|
||||
|
||||
#[schema(example = "def fib(n):")]
|
||||
prompt: String,
|
||||
/// When segments are set, the `prompt` is ignored during the inference.
|
||||
segments: Option<Segments>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
pub struct Segments {
|
||||
/// Content that appears before the cursor in the editor window.
|
||||
#[schema(example = "def fib(n):\n ")]
|
||||
prefix: String,
|
||||
|
||||
/// Content that appears after the cursor in the editor window.
|
||||
#[schema(example = "\n return fib(n - 1) + fib(n - 2)")]
|
||||
suffix: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
|
|
@ -46,7 +62,20 @@ pub async fn completion(
|
|||
.sampling_temperature(0.2)
|
||||
.build()
|
||||
.expect("Invalid TextInferenceOptions");
|
||||
let text = state.engine.inference(&request.prompt, options);
|
||||
|
||||
let prompt = if let Some(Segments { prefix, suffix }) = request.segments {
|
||||
if let Some(prompt_template) = &state.prompt_template {
|
||||
strfmt!(prompt_template, prefix => prefix, suffix => suffix)
|
||||
.expect("Failed to format prompt")
|
||||
} else {
|
||||
// If there's no prompt template, just use prefix.
|
||||
prefix
|
||||
}
|
||||
} else {
|
||||
request.prompt.expect("No prompt is set")
|
||||
};
|
||||
|
||||
let text = state.engine.inference(&prompt, options);
|
||||
let language = request.language.unwrap_or("unknown".into());
|
||||
let filtered_text = languages::remove_stop_words(&language, &text);
|
||||
|
||||
|
|
@ -61,7 +90,7 @@ pub async fn completion(
|
|||
events::Event::Completion {
|
||||
completion_id: &response.id,
|
||||
language: &language,
|
||||
prompt: &request.prompt,
|
||||
prompt: &prompt,
|
||||
choices: vec![events::Choice {
|
||||
index: 0,
|
||||
text: filtered_text,
|
||||
|
|
@ -74,6 +103,7 @@ pub async fn completion(
|
|||
|
||||
pub struct CompletionState {
|
||||
engine: TextInferenceEngine,
|
||||
prompt_template: Option<String>,
|
||||
}
|
||||
|
||||
impl CompletionState {
|
||||
|
|
@ -92,7 +122,10 @@ impl CompletionState {
|
|||
.build()
|
||||
.expect("Invalid TextInferenceEngineCreateOptions");
|
||||
let engine = TextInferenceEngine::create(options);
|
||||
Self { engine }
|
||||
Self {
|
||||
engine,
|
||||
prompt_template: metadata.prompt_template,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -107,6 +140,7 @@ fn get_model_dir(model: &str) -> ModelDir {
|
|||
#[derive(Deserialize)]
|
||||
struct Metadata {
|
||||
auto_model: String,
|
||||
prompt_template: Option<String>,
|
||||
}
|
||||
|
||||
fn read_metadata(model_dir: &ModelDir) -> Metadata {
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ use utoipa_swagger_ui::SwaggerUi;
|
|||
events::LogEventRequest,
|
||||
completions::CompletionRequest,
|
||||
completions::CompletionResponse,
|
||||
completions::Segments,
|
||||
completions::Choice
|
||||
))
|
||||
)]
|
||||
|
|
|
|||
Loading…
Reference in New Issue