feat: add ggml fp16 / q8_0 files (#407)

* feat: add ggml fp16 / q8_0 files

* add q8_0.gguf to optional download files

* add download options to split ctranslate2 files and ggml files
release-0.2
Meng Zhang 2023-09-07 01:12:29 +08:00 committed by GitHub
parent 007951b550
commit e780031ed6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 67 additions and 34 deletions

View File

@ -86,7 +86,7 @@ impl ModelDir {
self.path_string("ctranslate2")
}
pub fn ggml_model_file(&self) -> String {
self.path_string("ggml/default.gguf")
/// Path of the 8-bit-quantized GGML model file within this model directory.
pub fn ggml_q8_0_file(&self) -> String {
    let relative_path = "ggml/q8_0.gguf";
    self.path_string(relative_path)
}
}

View File

@ -57,7 +57,12 @@ impl CacheInfo {
}
}
pub async fn download_model(model_id: &str, prefer_local_file: bool) -> Result<()> {
pub async fn download_model(
model_id: &str,
download_ctranslate2_files: bool,
download_ggml_files: bool,
prefer_local_file: bool,
) -> Result<()> {
if fs::metadata(model_id).is_ok() {
// Local path, no need for downloading.
return Ok(());
@ -67,12 +72,18 @@ pub async fn download_model(model_id: &str, prefer_local_file: bool) -> Result<(
let mut cache_info = CacheInfo::from(model_id).await;
let optional_files = vec![
"ctranslate2/vocabulary.txt",
"ctranslate2/shared_vocabulary.txt",
"ctranslate2/vocabulary.json",
"ctranslate2/shared_vocabulary.json",
];
let mut optional_files = vec![];
if download_ctranslate2_files {
optional_files.push("ctranslate2/vocabulary.txt");
optional_files.push("ctranslate2/shared_vocabulary.txt");
optional_files.push("ctranslate2/vocabulary.json");
optional_files.push("ctranslate2/shared_vocabulary.json");
}
if download_ggml_files {
optional_files.push("ggml/q8_0.gguf");
}
for path in optional_files {
cache_info
.download(
@ -84,12 +95,13 @@ pub async fn download_model(model_id: &str, prefer_local_file: bool) -> Result<(
.await?;
}
let required_files = vec![
"tabby.json",
"tokenizer.json",
"ctranslate2/config.json",
"ctranslate2/model.bin",
];
let mut required_files = vec!["tabby.json", "tokenizer.json"];
if download_ctranslate2_files {
required_files.push("ctranslate2/config.json");
required_files.push("ctranslate2/model.bin");
}
for path in required_files {
cache_info
.download(

View File

@ -15,14 +15,19 @@ pub struct DownloadArgs {
}
pub async fn main(args: &DownloadArgs) {
tabby_download::download_model(&args.model, args.prefer_local_file)
.await
.unwrap_or_else(|err| {
fatal!(
"Failed to fetch model due to '{}', is '{}' a valid model id?",
err,
args.model
)
});
tabby_download::download_model(
&args.model,
/* download_ctranslate2_files= */ true,
/* download_ggml_files= */ true,
args.prefer_local_file,
)
.await
.unwrap_or_else(|err| {
fatal!(
"Failed to fetch model due to '{}', is '{}' a valid model id?",
err,
args.model
)
});
info!("model '{}' is ready", args.model);
}

View File

@ -186,7 +186,7 @@ fn create_ctranslate2_engine(
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn create_llama_engine(model_dir: &ModelDir) -> Box<dyn TextGeneration> {
let options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default()
.model_path(model_dir.ggml_model_file())
.model_path(model_dir.ggml_q8_0_file())
.tokenizer_path(model_dir.tokenizer_file())
.build()
.unwrap();

View File

@ -116,19 +116,35 @@ pub struct ServeArgs {
compute_type: ComputeType,
}
/// Whether the GGML (`ggml/q8_0.gguf`) weights should be fetched for `device`.
///
/// Non-Apple-Silicon builds never use the llama.cpp backend, so the answer is
/// always `false` there. The parameter is underscore-prefixed in that variant
/// to avoid an `unused_variables` warning.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn should_download_ggml_files(_device: &Device) -> bool {
    false
}

/// Whether the GGML (`ggml/q8_0.gguf`) weights should be fetched for `device`.
///
/// On Apple Silicon the llama.cpp engine backs the Metal device, so the GGML
/// weights are needed exactly when serving on `Device::Metal`.
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn should_download_ggml_files(device: &Device) -> bool {
    *device == Device::Metal
}
pub async fn main(config: &Config, args: &ServeArgs) {
valid_args(args);
// Ensure model exists.
tabby_download::download_model(&args.model, true)
.await
.unwrap_or_else(|err| {
fatal!(
"Failed to fetch model due to '{}', is '{}' a valid model id?",
err,
args.model
)
});
tabby_download::download_model(
&args.model,
/* download_ctranslate2_files= */
!should_download_ggml_files(&args.device),
/* download_ggml_files= */ should_download_ggml_files(&args.device),
/* prefer_local_file= */ true,
)
.await
.unwrap_or_else(|err| {
fatal!(
"Failed to fetch model due to '{}', is '{}' a valid model id?",
err,
args.model
)
});
info!("Starting server, this might takes a few minutes...");
let app = Router::new()