feat: supports FIM inference [TAB-46] (#183)
* Add prefix / suffix
* update
* feat: support segments in inference
* chore: add tabby.json in model repository to store prompt_template
* make prompt_template optional.
* download tabby.json in downloader

support-coreml
parent
950a7a795f
commit
2779da3cba
|
|
@ -2051,6 +2051,12 @@ dependencies = [
|
||||||
"unicode-segmentation",
|
"unicode-segmentation",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "strfmt"
|
||||||
|
version = "0.2.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7a8348af2d9fc3258c8733b8d9d8db2e56f54b2363a4b5b81585c7875ed65e65"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "strsim"
|
name = "strsim"
|
||||||
version = "0.10.0"
|
version = "0.10.0"
|
||||||
|
|
@ -2134,6 +2140,7 @@ dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"serdeconv",
|
"serdeconv",
|
||||||
|
"strfmt",
|
||||||
"strum",
|
"strum",
|
||||||
"tabby-common",
|
"tabby-common",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
|
|
||||||
|
|
@ -32,8 +32,12 @@ impl ModelDir {
|
||||||
self.0.join(name).display().to_string()
|
self.0.join(name).display().to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn cache_info_file(&self) -> String {
|
||||||
|
self.path_string(".cache_info.json")
|
||||||
|
}
|
||||||
|
|
||||||
pub fn metadata_file(&self) -> String {
|
pub fn metadata_file(&self) -> String {
|
||||||
self.path_string("metadata.json")
|
self.path_string("tabby.json")
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tokenizer_file(&self) -> String {
|
pub fn tokenizer_file(&self) -> String {
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ indicatif = "0.17.3"
|
||||||
futures-util = "0.3.28"
|
futures-util = "0.3.28"
|
||||||
tabby-common = { path = "../tabby-common" }
|
tabby-common = { path = "../tabby-common" }
|
||||||
anyhow = "1.0.71"
|
anyhow = "1.0.71"
|
||||||
|
strfmt = "0.2.4"
|
||||||
|
|
||||||
[dependencies.uuid]
|
[dependencies.uuid]
|
||||||
version = "1.3.3"
|
version = "1.3.3"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
use anyhow::Result;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::fs;
|
||||||
|
use std::path::Path;
|
||||||
|
use tabby_common::path::ModelDir;
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
pub struct CacheInfo {
|
||||||
|
etags: HashMap<String, String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CacheInfo {
|
||||||
|
pub async fn from(model_id: &str) -> CacheInfo {
|
||||||
|
if let Some(cache_info) = Self::from_local(model_id) {
|
||||||
|
cache_info
|
||||||
|
} else {
|
||||||
|
CacheInfo {
|
||||||
|
etags: HashMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn from_local(model_id: &str) -> Option<CacheInfo> {
|
||||||
|
let cache_info_file = ModelDir::new(model_id).cache_info_file();
|
||||||
|
if fs::metadata(&cache_info_file).is_ok() {
|
||||||
|
serdeconv::from_json_file(cache_info_file).ok()
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn local_cache_key(&self, path: &str) -> Option<&str> {
|
||||||
|
self.etags.get(path).map(|x| x.as_str())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn remote_cache_key(res: &reqwest::Response) -> &str {
|
||||||
|
res.headers()
|
||||||
|
.get("etag")
|
||||||
|
.unwrap_or_else(|| panic!("Failed to GET ETAG header from '{}'", res.url()))
|
||||||
|
.to_str()
|
||||||
|
.unwrap_or_else(|_| panic!("Failed to convert ETAG header into string '{}'", res.url()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) {
|
||||||
|
self.etags.insert(path.to_string(), cache_key.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn save(&self, model_id: &str) -> Result<()> {
|
||||||
|
let cache_info_file = ModelDir::new(model_id).cache_info_file();
|
||||||
|
let cache_info_file_path = Path::new(&cache_info_file);
|
||||||
|
serdeconv::to_json_file(self, cache_info_file_path)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    // The previous test called `HFCacheInfo::from(..)`, a type that is not
    // defined in this module, so the test build could not compile. This
    // test exercises `CacheInfo` itself and needs no network access.
    #[tokio::test]
    async fn test_local_cache_key_roundtrip() {
        let mut cache_info = CacheInfo {
            etags: HashMap::new(),
        };

        // Nothing has been recorded yet.
        assert_eq!(cache_info.local_cache_key("tokenizer.json"), None);

        cache_info
            .set_local_cache_key("tokenizer.json", "\"etag-1\"")
            .await;
        assert_eq!(
            cache_info.local_cache_key("tokenizer.json"),
            Some("\"etag-1\"")
        );

        // Recording a key for the same path again replaces the old one.
        cache_info
            .set_local_cache_key("tokenizer.json", "\"etag-2\"")
            .await;
        assert_eq!(
            cache_info.local_cache_key("tokenizer.json"),
            Some("\"etag-2\"")
        );

        // Other paths stay unaffected.
        assert_eq!(cache_info.local_cache_key("ctranslate2/model.bin"), None);
    }
}
|
||||||
|
|
@ -1,96 +0,0 @@
|
||||||
use anyhow::Result;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::fs;
|
|
||||||
use std::path::Path;
|
|
||||||
use tabby_common::path::ModelDir;
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct HFTransformersInfo {
|
|
||||||
auto_model: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
#[serde(rename_all = "camelCase")]
|
|
||||||
struct HFMetadata {
|
|
||||||
transformers_info: HFTransformersInfo,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl HFMetadata {
|
|
||||||
async fn from(model_id: &str) -> HFMetadata {
|
|
||||||
let api_url = format!("https://huggingface.co/api/models/{}", model_id);
|
|
||||||
reqwest::get(&api_url)
|
|
||||||
.await
|
|
||||||
.unwrap_or_else(|_| panic!("Failed to GET url '{}'", api_url))
|
|
||||||
.json::<HFMetadata>()
|
|
||||||
.await
|
|
||||||
.unwrap_or_else(|_| panic!("Failed to parse HFMetadata '{}'", api_url))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
|
||||||
pub struct Metadata {
|
|
||||||
auto_model: String,
|
|
||||||
etags: HashMap<String, String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Metadata {
|
|
||||||
pub async fn from(model_id: &str) -> Metadata {
|
|
||||||
if let Some(metadata) = Self::from_local(model_id) {
|
|
||||||
metadata
|
|
||||||
} else {
|
|
||||||
let hf_metadata = HFMetadata::from(model_id).await;
|
|
||||||
Metadata {
|
|
||||||
auto_model: hf_metadata.transformers_info.auto_model,
|
|
||||||
etags: HashMap::new(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn from_local(model_id: &str) -> Option<Metadata> {
|
|
||||||
let metadata_file = ModelDir::new(model_id).metadata_file();
|
|
||||||
if fs::metadata(&metadata_file).is_ok() {
|
|
||||||
let metadata = serdeconv::from_json_file(metadata_file);
|
|
||||||
metadata.ok()
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn local_cache_key(&self, path: &str) -> Option<&str> {
|
|
||||||
self.etags.get(path).map(|x| x.as_str())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn remote_cache_key(res: &reqwest::Response) -> &str {
|
|
||||||
res.headers()
|
|
||||||
.get("etag")
|
|
||||||
.unwrap_or_else(|| panic!("Failed to GET ETAG header from '{}'", res.url()))
|
|
||||||
.to_str()
|
|
||||||
.unwrap_or_else(|_| panic!("Failed to convert ETAG header into string '{}'", res.url()))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) {
|
|
||||||
self.etags.insert(path.to_string(), cache_key.to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn save(&self, model_id: &str) -> Result<()> {
|
|
||||||
let metadata_file = ModelDir::new(model_id).metadata_file();
|
|
||||||
let metadata_file_path = Path::new(&metadata_file);
|
|
||||||
serdeconv::to_json_file(self, metadata_file_path)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_hf() {
|
|
||||||
let hf_metadata = HFMetadata::from("TabbyML/J-350M").await;
|
|
||||||
assert_eq!(
|
|
||||||
hf_metadata.transformers_info.auto_model,
|
|
||||||
"AutoModelForCausalLM"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
mod metadata;
|
mod cache_info;
|
||||||
|
|
||||||
use std::cmp;
|
use std::cmp;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
|
|
@ -10,6 +10,8 @@ use futures_util::StreamExt;
|
||||||
use indicatif::{ProgressBar, ProgressStyle};
|
use indicatif::{ProgressBar, ProgressStyle};
|
||||||
use tabby_common::path::ModelDir;
|
use tabby_common::path::ModelDir;
|
||||||
|
|
||||||
|
use cache_info::CacheInfo;
|
||||||
|
|
||||||
#[derive(Args)]
|
#[derive(Args)]
|
||||||
pub struct DownloadArgs {
|
pub struct DownloadArgs {
|
||||||
/// model id to fetch.
|
/// model id to fetch.
|
||||||
|
|
@ -26,7 +28,7 @@ pub async fn main(args: &DownloadArgs) {
|
||||||
println!("model '{}' is ready", args.model);
|
println!("model '{}' is ready", args.model);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl metadata::Metadata {
|
impl CacheInfo {
|
||||||
async fn download(&mut self, model_id: &str, path: &str, prefer_local_file: bool) {
|
async fn download(&mut self, model_id: &str, path: &str, prefer_local_file: bool) {
|
||||||
// Create url.
|
// Create url.
|
||||||
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path);
|
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path);
|
||||||
|
|
@ -56,28 +58,31 @@ pub async fn download_model(model_id: &str, prefer_local_file: bool) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut metadata = metadata::Metadata::from(model_id).await;
|
let mut cache_info = CacheInfo::from(model_id).await;
|
||||||
|
|
||||||
metadata
|
cache_info
|
||||||
|
.download(model_id, "tabby.json", prefer_local_file)
|
||||||
|
.await;
|
||||||
|
cache_info
|
||||||
.download(model_id, "tokenizer.json", prefer_local_file)
|
.download(model_id, "tokenizer.json", prefer_local_file)
|
||||||
.await;
|
.await;
|
||||||
metadata
|
cache_info
|
||||||
.download(model_id, "ctranslate2/config.json", prefer_local_file)
|
.download(model_id, "ctranslate2/config.json", prefer_local_file)
|
||||||
.await;
|
.await;
|
||||||
metadata
|
cache_info
|
||||||
.download(model_id, "ctranslate2/vocabulary.txt", prefer_local_file)
|
.download(model_id, "ctranslate2/vocabulary.txt", prefer_local_file)
|
||||||
.await;
|
.await;
|
||||||
metadata
|
cache_info
|
||||||
.download(
|
.download(
|
||||||
model_id,
|
model_id,
|
||||||
"ctranslate2/shared_vocabulary.txt",
|
"ctranslate2/shared_vocabulary.txt",
|
||||||
prefer_local_file,
|
prefer_local_file,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
metadata
|
cache_info
|
||||||
.download(model_id, "ctranslate2/model.bin", prefer_local_file)
|
.download(model_id, "ctranslate2/model.bin", prefer_local_file)
|
||||||
.await;
|
.await;
|
||||||
metadata
|
cache_info
|
||||||
.save(model_id)
|
.save(model_id)
|
||||||
.unwrap_or_else(|_| panic!("Failed to save model_id '{}'", model_id));
|
.unwrap_or_else(|_| panic!("Failed to save model_id '{}'", model_id));
|
||||||
}
|
}
|
||||||
|
|
@ -91,7 +96,7 @@ async fn download_file(url: &str, path: &str, local_cache_key: Option<&str>) ->
|
||||||
.await
|
.await
|
||||||
.unwrap_or_else(|_| panic!("Failed to GET from '{}'", url));
|
.unwrap_or_else(|_| panic!("Failed to GET from '{}'", url));
|
||||||
|
|
||||||
let remote_cache_key = metadata::Metadata::remote_cache_key(&res).to_string();
|
let remote_cache_key = CacheInfo::remote_cache_key(&res).to_string();
|
||||||
if let Some(local_cache_key) = local_cache_key {
|
if let Some(local_cache_key) = local_cache_key {
|
||||||
if local_cache_key == remote_cache_key {
|
if local_cache_key == remote_cache_key {
|
||||||
return remote_cache_key;
|
return remote_cache_key;
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ use ctranslate2_bindings::{
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use strfmt::{strfmt, strfmt_builder};
|
||||||
use tabby_common::{events, path::ModelDir};
|
use tabby_common::{events, path::ModelDir};
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
|
|
@ -12,12 +13,27 @@ mod languages;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
pub struct CompletionRequest {
|
pub struct CompletionRequest {
|
||||||
|
#[schema(example = "def fib(n):")]
|
||||||
|
#[deprecated]
|
||||||
|
prompt: Option<String>,
|
||||||
|
|
||||||
/// https://code.visualstudio.com/docs/languages/identifiers
|
/// https://code.visualstudio.com/docs/languages/identifiers
|
||||||
#[schema(example = "python")]
|
#[schema(example = "python")]
|
||||||
language: Option<String>,
|
language: Option<String>,
|
||||||
|
|
||||||
#[schema(example = "def fib(n):")]
|
/// When segments are set, the `prompt` is ignored during the inference.
|
||||||
prompt: String,
|
segments: Option<Segments>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
|
pub struct Segments {
|
||||||
|
/// Content that appears before the cursor in the editor window.
|
||||||
|
#[schema(example = "def fib(n):\n ")]
|
||||||
|
prefix: String,
|
||||||
|
|
||||||
|
/// Content that appears after the cursor in the editor window.
|
||||||
|
#[schema(example = "\n return fib(n - 1) + fib(n - 2)")]
|
||||||
|
suffix: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
|
|
@ -46,7 +62,20 @@ pub async fn completion(
|
||||||
.sampling_temperature(0.2)
|
.sampling_temperature(0.2)
|
||||||
.build()
|
.build()
|
||||||
.expect("Invalid TextInferenceOptions");
|
.expect("Invalid TextInferenceOptions");
|
||||||
let text = state.engine.inference(&request.prompt, options);
|
|
||||||
|
let prompt = if let Some(Segments { prefix, suffix }) = request.segments {
|
||||||
|
if let Some(prompt_template) = &state.prompt_template {
|
||||||
|
strfmt!(prompt_template, prefix => prefix, suffix => suffix)
|
||||||
|
.expect("Failed to format prompt")
|
||||||
|
} else {
|
||||||
|
// If there's no prompt template, just use prefix.
|
||||||
|
prefix
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
request.prompt.expect("No prompt is set")
|
||||||
|
};
|
||||||
|
|
||||||
|
let text = state.engine.inference(&prompt, options);
|
||||||
let language = request.language.unwrap_or("unknown".into());
|
let language = request.language.unwrap_or("unknown".into());
|
||||||
let filtered_text = languages::remove_stop_words(&language, &text);
|
let filtered_text = languages::remove_stop_words(&language, &text);
|
||||||
|
|
||||||
|
|
@ -61,7 +90,7 @@ pub async fn completion(
|
||||||
events::Event::Completion {
|
events::Event::Completion {
|
||||||
completion_id: &response.id,
|
completion_id: &response.id,
|
||||||
language: &language,
|
language: &language,
|
||||||
prompt: &request.prompt,
|
prompt: &prompt,
|
||||||
choices: vec![events::Choice {
|
choices: vec![events::Choice {
|
||||||
index: 0,
|
index: 0,
|
||||||
text: filtered_text,
|
text: filtered_text,
|
||||||
|
|
@ -74,6 +103,7 @@ pub async fn completion(
|
||||||
|
|
||||||
pub struct CompletionState {
|
pub struct CompletionState {
|
||||||
engine: TextInferenceEngine,
|
engine: TextInferenceEngine,
|
||||||
|
prompt_template: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CompletionState {
|
impl CompletionState {
|
||||||
|
|
@ -92,7 +122,10 @@ impl CompletionState {
|
||||||
.build()
|
.build()
|
||||||
.expect("Invalid TextInferenceEngineCreateOptions");
|
.expect("Invalid TextInferenceEngineCreateOptions");
|
||||||
let engine = TextInferenceEngine::create(options);
|
let engine = TextInferenceEngine::create(options);
|
||||||
Self { engine }
|
Self {
|
||||||
|
engine,
|
||||||
|
prompt_template: metadata.prompt_template,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -107,6 +140,7 @@ fn get_model_dir(model: &str) -> ModelDir {
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct Metadata {
|
struct Metadata {
|
||||||
auto_model: String,
|
auto_model: String,
|
||||||
|
prompt_template: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_metadata(model_dir: &ModelDir) -> Metadata {
|
fn read_metadata(model_dir: &ModelDir) -> Metadata {
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ use utoipa_swagger_ui::SwaggerUi;
|
||||||
events::LogEventRequest,
|
events::LogEventRequest,
|
||||||
completions::CompletionRequest,
|
completions::CompletionRequest,
|
||||||
completions::CompletionResponse,
|
completions::CompletionResponse,
|
||||||
|
completions::Segments,
|
||||||
completions::Choice
|
completions::Choice
|
||||||
))
|
))
|
||||||
)]
|
)]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue