feat: add ggml fp16 / q8_0 files (#407)

* feat: add ggml fp16 / q8_0 files

* add q8_0.gguf to optional download files

* add download options to split ctranslate2 files and ggml files
release-0.2
Meng Zhang 2023-09-07 01:12:29 +08:00 committed by GitHub
parent 007951b550
commit e780031ed6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 67 additions and 34 deletions

View File

@ -86,7 +86,7 @@ impl ModelDir {
self.path_string("ctranslate2") self.path_string("ctranslate2")
} }
pub fn ggml_model_file(&self) -> String { pub fn ggml_q8_0_file(&self) -> String {
self.path_string("ggml/default.gguf") self.path_string("ggml/q8_0.gguf")
} }
} }

View File

@ -57,7 +57,12 @@ impl CacheInfo {
} }
} }
pub async fn download_model(model_id: &str, prefer_local_file: bool) -> Result<()> { pub async fn download_model(
model_id: &str,
download_ctranslate2_files: bool,
download_ggml_files: bool,
prefer_local_file: bool,
) -> Result<()> {
if fs::metadata(model_id).is_ok() { if fs::metadata(model_id).is_ok() {
// Local path, no need for downloading. // Local path, no need for downloading.
return Ok(()); return Ok(());
@ -67,12 +72,18 @@ pub async fn download_model(model_id: &str, prefer_local_file: bool) -> Result<(
let mut cache_info = CacheInfo::from(model_id).await; let mut cache_info = CacheInfo::from(model_id).await;
let optional_files = vec![ let mut optional_files = vec![];
"ctranslate2/vocabulary.txt", if download_ctranslate2_files {
"ctranslate2/shared_vocabulary.txt", optional_files.push("ctranslate2/vocabulary.txt");
"ctranslate2/vocabulary.json", optional_files.push("ctranslate2/shared_vocabulary.txt");
"ctranslate2/shared_vocabulary.json", optional_files.push("ctranslate2/vocabulary.json");
]; optional_files.push("ctranslate2/shared_vocabulary.json");
}
if download_ggml_files {
optional_files.push("ggml/q8_0.gguf");
}
for path in optional_files { for path in optional_files {
cache_info cache_info
.download( .download(
@ -84,12 +95,13 @@ pub async fn download_model(model_id: &str, prefer_local_file: bool) -> Result<(
.await?; .await?;
} }
let required_files = vec![ let mut required_files = vec!["tabby.json", "tokenizer.json"];
"tabby.json",
"tokenizer.json", if download_ctranslate2_files {
"ctranslate2/config.json", required_files.push("ctranslate2/config.json");
"ctranslate2/model.bin", required_files.push("ctranslate2/model.bin");
]; }
for path in required_files { for path in required_files {
cache_info cache_info
.download( .download(

View File

@ -15,14 +15,19 @@ pub struct DownloadArgs {
} }
pub async fn main(args: &DownloadArgs) { pub async fn main(args: &DownloadArgs) {
tabby_download::download_model(&args.model, args.prefer_local_file) tabby_download::download_model(
.await &args.model,
.unwrap_or_else(|err| { /* download_ctranslate2_files= */ true,
fatal!( /* download_ggml_files= */ true,
"Failed to fetch model due to '{}', is '{}' a valid model id?", args.prefer_local_file,
err, )
args.model .await
) .unwrap_or_else(|err| {
}); fatal!(
"Failed to fetch model due to '{}', is '{}' a valid model id?",
err,
args.model
)
});
info!("model '{}' is ready", args.model); info!("model '{}' is ready", args.model);
} }

View File

@ -186,7 +186,7 @@ fn create_ctranslate2_engine(
#[cfg(all(target_os = "macos", target_arch = "aarch64"))] #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn create_llama_engine(model_dir: &ModelDir) -> Box<dyn TextGeneration> { fn create_llama_engine(model_dir: &ModelDir) -> Box<dyn TextGeneration> {
let options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default() let options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default()
.model_path(model_dir.ggml_model_file()) .model_path(model_dir.ggml_q8_0_file())
.tokenizer_path(model_dir.tokenizer_file()) .tokenizer_path(model_dir.tokenizer_file())
.build() .build()
.unwrap(); .unwrap();

View File

@ -116,19 +116,35 @@ pub struct ServeArgs {
compute_type: ComputeType, compute_type: ComputeType,
} }
/// Reports whether the GGML model files should be downloaded for the given device.
///
/// This variant is compiled for every target other than macOS on aarch64 and
/// always returns `false`, regardless of which device is passed in.
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn should_download_ggml_files(_device: &Device) -> bool {
    // The argument is intentionally unused on these targets; the leading
    // underscore keeps the signature parallel to the macOS/aarch64 variant
    // without triggering the `unused_variables` warning.
    false
}
/// Reports whether the GGML model files should be downloaded for `device`.
///
/// This variant is compiled only for macOS on aarch64 and returns `true`
/// when `device` is `Device::Metal`, `false` otherwise.
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn should_download_ggml_files(device: &Device) -> bool {
    matches!(device, Device::Metal)
}
pub async fn main(config: &Config, args: &ServeArgs) { pub async fn main(config: &Config, args: &ServeArgs) {
valid_args(args); valid_args(args);
// Ensure model exists. // Ensure model exists.
tabby_download::download_model(&args.model, true) tabby_download::download_model(
.await &args.model,
.unwrap_or_else(|err| { /* download_ctranslate2_files= */
fatal!( !should_download_ggml_files(&args.device),
"Failed to fetch model due to '{}', is '{}' a valid model id?", /* download_ggml_files= */ should_download_ggml_files(&args.device),
err, /* prefer_local_file= */ true,
args.model )
) .await
}); .unwrap_or_else(|err| {
fatal!(
"Failed to fetch model due to '{}', is '{}' a valid model id?",
err,
args.model
)
});
info!("Starting server, this might takes a few minutes..."); info!("Starting server, this might takes a few minutes...");
let app = Router::new() let app = Router::new()