feat: improve error handling and messages [TAB-58] (#213)
* add fatal macro
* switch expect to fatal
* improve error handling of serve
* improve error handling on download module
* improve error handling in scheduler
* improve error handling
* fmt
* fmt
improve-workflow
parent
c0106ad774
commit
4cb672ec39
|
|
@ -172,8 +172,10 @@ fn reverse(s: String) -> String {
|
||||||
s.chars().rev().collect()
|
s.chars().rev().collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_stop_regex(tokenizer: &Tokenizer, stop_words: &Vec<&str>) -> Regex {
|
fn create_stop_regex(tokenizer: &Tokenizer, stop_words: &[&str]) -> Regex {
|
||||||
let encodings = tokenizer.encode_batch(stop_words.clone(), false).unwrap();
|
let encodings = tokenizer
|
||||||
|
.encode_batch(stop_words.to_owned(), false)
|
||||||
|
.unwrap();
|
||||||
let stop_tokens: Vec<String> = encodings
|
let stop_tokens: Vec<String> = encodings
|
||||||
.iter()
|
.iter()
|
||||||
.map(|x| x.get_tokens().join(""))
|
.map(|x| x.get_tokens().join(""))
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
use std::path::PathBuf;
|
use std::{
|
||||||
|
io::{Error, ErrorKind},
|
||||||
|
path::PathBuf,
|
||||||
|
};
|
||||||
|
|
||||||
use filenamify::filenamify;
|
use filenamify::filenamify;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
use crate::path::repositories_dir;
|
use crate::path::{config_file, repositories_dir};
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
|
|
@ -11,8 +14,14 @@ pub struct Config {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
pub fn load() -> Result<Self, serdeconv::Error> {
|
pub fn load() -> Result<Self, Error> {
|
||||||
serdeconv::from_toml_file(crate::path::config_file().as_path())
|
let file = serdeconv::from_toml_file(crate::path::config_file().as_path());
|
||||||
|
file.map_err(|_| {
|
||||||
|
Error::new(
|
||||||
|
ErrorKind::InvalidData,
|
||||||
|
format!("Config {:?} doesn't exist or is not valid", config_file()),
|
||||||
|
)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
use std::{collections::HashMap, fs, path::Path};
|
use std::{collections::HashMap, fs, path::Path};
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::{anyhow, Result};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tabby_common::path::ModelDir;
|
use tabby_common::path::ModelDir;
|
||||||
|
|
||||||
|
|
@ -33,12 +33,13 @@ impl CacheInfo {
|
||||||
self.etags.get(path).map(|x| x.as_str())
|
self.etags.get(path).map(|x| x.as_str())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn remote_cache_key(res: &reqwest::Response) -> &str {
|
pub fn remote_cache_key(res: &reqwest::Response) -> Result<&str> {
|
||||||
res.headers()
|
let key = res
|
||||||
|
.headers()
|
||||||
.get("etag")
|
.get("etag")
|
||||||
.unwrap_or_else(|| panic!("Failed to GET ETAG header from '{}'", res.url()))
|
.ok_or(anyhow!("etag key missing"))?
|
||||||
.to_str()
|
.to_str()?;
|
||||||
.unwrap_or_else(|_| panic!("Failed to convert ETAG header into string '{}'", res.url()))
|
Ok(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) {
|
pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) {
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,19 @@ mod cache_info;
|
||||||
|
|
||||||
use std::{cmp, fs, io::Write, path::Path};
|
use std::{cmp, fs, io::Write, path::Path};
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
use cache_info::CacheInfo;
|
use cache_info::CacheInfo;
|
||||||
use futures_util::StreamExt;
|
use futures_util::StreamExt;
|
||||||
use indicatif::{ProgressBar, ProgressStyle};
|
use indicatif::{ProgressBar, ProgressStyle};
|
||||||
use tabby_common::path::ModelDir;
|
use tabby_common::path::ModelDir;
|
||||||
|
|
||||||
impl CacheInfo {
|
impl CacheInfo {
|
||||||
async fn download(&mut self, model_id: &str, path: &str, prefer_local_file: bool) {
|
async fn download(
|
||||||
|
&mut self,
|
||||||
|
model_id: &str,
|
||||||
|
path: &str,
|
||||||
|
prefer_local_file: bool,
|
||||||
|
) -> Result<()> {
|
||||||
// Create url.
|
// Create url.
|
||||||
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path);
|
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path);
|
||||||
|
|
||||||
|
|
@ -25,89 +31,89 @@ impl CacheInfo {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !local_file_ready {
|
if !local_file_ready {
|
||||||
let etag = download_file(&url, &filepath, local_cache_key).await;
|
let etag = download_file(&url, &filepath, local_cache_key).await?;
|
||||||
self.set_local_cache_key(path, &etag).await
|
self.set_local_cache_key(path, &etag).await;
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn download_model(model_id: &str, prefer_local_file: bool) {
|
pub async fn download_model(model_id: &str, prefer_local_file: bool) -> Result<()> {
|
||||||
if fs::metadata(model_id).is_ok() {
|
if fs::metadata(model_id).is_ok() {
|
||||||
// Local path, no need for downloading.
|
// Local path, no need for downloading.
|
||||||
return;
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut cache_info = CacheInfo::from(model_id).await;
|
let mut cache_info = CacheInfo::from(model_id).await;
|
||||||
|
|
||||||
cache_info
|
cache_info
|
||||||
.download(model_id, "tabby.json", prefer_local_file)
|
.download(model_id, "tabby.json", prefer_local_file)
|
||||||
.await;
|
.await?;
|
||||||
cache_info
|
cache_info
|
||||||
.download(model_id, "tokenizer.json", prefer_local_file)
|
.download(model_id, "tokenizer.json", prefer_local_file)
|
||||||
.await;
|
.await?;
|
||||||
cache_info
|
cache_info
|
||||||
.download(model_id, "ctranslate2/config.json", prefer_local_file)
|
.download(model_id, "ctranslate2/config.json", prefer_local_file)
|
||||||
.await;
|
.await?;
|
||||||
cache_info
|
cache_info
|
||||||
.download(model_id, "ctranslate2/vocabulary.txt", prefer_local_file)
|
.download(model_id, "ctranslate2/vocabulary.txt", prefer_local_file)
|
||||||
.await;
|
.await?;
|
||||||
cache_info
|
cache_info
|
||||||
.download(
|
.download(
|
||||||
model_id,
|
model_id,
|
||||||
"ctranslate2/shared_vocabulary.txt",
|
"ctranslate2/shared_vocabulary.txt",
|
||||||
prefer_local_file,
|
prefer_local_file,
|
||||||
)
|
)
|
||||||
.await;
|
.await?;
|
||||||
cache_info
|
cache_info
|
||||||
.download(model_id, "ctranslate2/model.bin", prefer_local_file)
|
.download(model_id, "ctranslate2/model.bin", prefer_local_file)
|
||||||
.await;
|
.await?;
|
||||||
cache_info
|
cache_info.save(model_id)?;
|
||||||
.save(model_id)
|
|
||||||
.unwrap_or_else(|_| panic!("Failed to save model_id '{}'", model_id));
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn download_file(url: &str, path: &str, local_cache_key: Option<&str>) -> String {
|
async fn download_file(url: &str, path: &str, local_cache_key: Option<&str>) -> Result<String> {
|
||||||
fs::create_dir_all(Path::new(path).parent().unwrap())
|
fs::create_dir_all(Path::new(path).parent().unwrap())?;
|
||||||
.unwrap_or_else(|_| panic!("Failed to create path '{}'", path));
|
|
||||||
|
|
||||||
// Reqwest setup
|
// Reqwest setup
|
||||||
let res = reqwest::get(url)
|
let res = reqwest::get(url).await?;
|
||||||
.await
|
|
||||||
.unwrap_or_else(|_| panic!("Failed to GET from '{}'", url));
|
|
||||||
|
|
||||||
let remote_cache_key = CacheInfo::remote_cache_key(&res).to_string();
|
if !res.status().is_success() {
|
||||||
|
return Err(anyhow!(format!("Invalid url: {}", url)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let remote_cache_key = CacheInfo::remote_cache_key(&res)?.to_string();
|
||||||
if let Some(local_cache_key) = local_cache_key {
|
if let Some(local_cache_key) = local_cache_key {
|
||||||
if local_cache_key == remote_cache_key {
|
if local_cache_key == remote_cache_key {
|
||||||
return remote_cache_key;
|
return Ok(remote_cache_key);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let total_size = res
|
let total_size = res
|
||||||
.content_length()
|
.content_length()
|
||||||
.unwrap_or_else(|| panic!("Failed to get content length from '{}'", url));
|
.ok_or(anyhow!("No content length in headers"))?;
|
||||||
|
|
||||||
// Indicatif setup
|
// Indicatif setup
|
||||||
let pb = ProgressBar::new(total_size);
|
let pb = ProgressBar::new(total_size);
|
||||||
pb.set_style(ProgressStyle::default_bar()
|
pb.set_style(ProgressStyle::default_bar()
|
||||||
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")
|
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
|
||||||
.expect("Invalid progress style")
|
|
||||||
.progress_chars("#>-"));
|
.progress_chars("#>-"));
|
||||||
pb.set_message(format!("Downloading {}", path));
|
pb.set_message(format!("Downloading {}", path));
|
||||||
|
|
||||||
// download chunks
|
// download chunks
|
||||||
let mut file =
|
let mut file = fs::File::create(path)?;
|
||||||
fs::File::create(path).unwrap_or_else(|_| panic!("Failed to create file '{}'", path));
|
|
||||||
let mut downloaded: u64 = 0;
|
let mut downloaded: u64 = 0;
|
||||||
let mut stream = res.bytes_stream();
|
let mut stream = res.bytes_stream();
|
||||||
|
|
||||||
while let Some(item) = stream.next().await {
|
while let Some(item) = stream.next().await {
|
||||||
let chunk = item.expect("Error while downloading file");
|
let chunk = item?;
|
||||||
file.write_all(&chunk).expect("Error while writing to file");
|
file.write_all(&chunk)?;
|
||||||
let new = cmp::min(downloaded + (chunk.len() as u64), total_size);
|
let new = cmp::min(downloaded + (chunk.len() as u64), total_size);
|
||||||
downloaded = new;
|
downloaded = new;
|
||||||
pb.set_position(new);
|
pb.set_position(new);
|
||||||
}
|
}
|
||||||
|
|
||||||
pb.finish_with_message(format!("Downloaded {}", path));
|
pb.finish_with_message(format!("Downloaded {}", path));
|
||||||
remote_cache_key
|
Ok(remote_cache_key)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,11 +15,11 @@ use tracing::{info, warn};
|
||||||
use walkdir::{DirEntry, WalkDir};
|
use walkdir::{DirEntry, WalkDir};
|
||||||
|
|
||||||
trait RepositoryExt {
|
trait RepositoryExt {
|
||||||
fn index(&self, schema: &Schema, writer: &mut IndexWriter);
|
fn index(&self, schema: &Schema, writer: &mut IndexWriter) -> Result<()>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RepositoryExt for Repository {
|
impl RepositoryExt for Repository {
|
||||||
fn index(&self, schema: &Schema, writer: &mut IndexWriter) {
|
fn index(&self, schema: &Schema, writer: &mut IndexWriter) -> Result<()> {
|
||||||
let git_url = schema.get_field("git_url").unwrap();
|
let git_url = schema.get_field("git_url").unwrap();
|
||||||
let filepath = schema.get_field("filepath").unwrap();
|
let filepath = schema.get_field("filepath").unwrap();
|
||||||
let content = schema.get_field("content").unwrap();
|
let content = schema.get_field("content").unwrap();
|
||||||
|
|
@ -36,17 +36,17 @@ impl RepositoryExt for Repository {
|
||||||
let relative_path = entry.path().strip_prefix(dir.as_path()).unwrap();
|
let relative_path = entry.path().strip_prefix(dir.as_path()).unwrap();
|
||||||
if let Ok(file_content) = read_to_string(entry.path()) {
|
if let Ok(file_content) = read_to_string(entry.path()) {
|
||||||
info!("Indexing {:?}", relative_path);
|
info!("Indexing {:?}", relative_path);
|
||||||
writer
|
writer.add_document(doc!(
|
||||||
.add_document(doc!(
|
git_url => self.git_url.clone(),
|
||||||
git_url => self.git_url.clone(),
|
filepath => relative_path.display().to_string(),
|
||||||
filepath => relative_path.display().to_string(),
|
content => file_content,
|
||||||
content => file_content,
|
))?;
|
||||||
))
|
|
||||||
.unwrap();
|
|
||||||
} else {
|
} else {
|
||||||
warn!("Skip {:?}", relative_path);
|
warn!("Skip {:?}", relative_path);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -66,18 +66,20 @@ fn create_schema() -> Schema {
|
||||||
builder.build()
|
builder.build()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn index_repositories(config: &Config) {
|
pub fn index_repositories(config: &Config) -> Result<()> {
|
||||||
let schema = create_schema();
|
let schema = create_schema();
|
||||||
|
|
||||||
fs::create_dir_all(index_dir()).unwrap();
|
fs::create_dir_all(index_dir())?;
|
||||||
let directory = MmapDirectory::open(index_dir()).unwrap();
|
let directory = MmapDirectory::open(index_dir())?;
|
||||||
let index = Index::open_or_create(directory, schema.clone()).unwrap();
|
let index = Index::open_or_create(directory, schema.clone())?;
|
||||||
let mut writer = index.writer(10_000_000).unwrap();
|
let mut writer = index.writer(10_000_000)?;
|
||||||
|
|
||||||
writer.delete_all_documents().unwrap();
|
writer.delete_all_documents()?;
|
||||||
for repository in config.repositories.as_slice() {
|
for repository in config.repositories.as_slice() {
|
||||||
repository.index(&schema, &mut writer);
|
repository.index(&schema, &mut writer)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
writer.commit().unwrap();
|
writer.commit()?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,30 +1,32 @@
|
||||||
mod index;
|
mod index;
|
||||||
mod repository;
|
mod repository;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
use job_scheduler::{Job, JobScheduler};
|
use job_scheduler::{Job, JobScheduler};
|
||||||
use tabby_common::config::Config;
|
use tabby_common::config::Config;
|
||||||
use tracing::{error, info};
|
use tracing::{error, info};
|
||||||
|
|
||||||
pub fn scheduler(now: bool) {
|
pub async fn scheduler(now: bool) -> Result<()> {
|
||||||
let config = Config::load();
|
let config = Config::load()?;
|
||||||
if config.is_err() {
|
|
||||||
error!("Please create config.toml before using scheduler");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let config = config.unwrap();
|
|
||||||
let mut scheduler = JobScheduler::new();
|
let mut scheduler = JobScheduler::new();
|
||||||
|
|
||||||
let job = || {
|
let job = || {
|
||||||
info!("Syncing repositories...");
|
info!("Syncing repositories...");
|
||||||
repository::sync_repositories(&config);
|
let ret = repository::sync_repositories(&config);
|
||||||
|
if let Err(err) = ret {
|
||||||
|
error!("Failed to sync repositories, err: '{}'", err);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
info!("Indexing repositories...");
|
info!("Indexing repositories...");
|
||||||
index::index_repositories(&config);
|
let ret = index::index_repositories(&config);
|
||||||
|
if let Err(err) = ret {
|
||||||
|
error!("Failed to index repositories, err: '{}'", err);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if now {
|
if now {
|
||||||
job()
|
job();
|
||||||
} else {
|
} else {
|
||||||
// Every 5 hours.
|
// Every 5 hours.
|
||||||
scheduler.add(Job::new("0 0 1/5 * * * *".parse().unwrap(), job));
|
scheduler.add(Job::new("0 0 1/5 * * * *".parse().unwrap(), job));
|
||||||
|
|
@ -37,6 +39,8 @@ pub fn scheduler(now: bool) {
|
||||||
std::thread::sleep(duration);
|
std::thread::sleep(duration);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -61,7 +65,7 @@ mod tests {
|
||||||
}],
|
}],
|
||||||
};
|
};
|
||||||
|
|
||||||
repository::sync_repositories(&config);
|
repository::sync_repositories(&config).unwrap();
|
||||||
index::index_repositories(&config);
|
index::index_repositories(&config).unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,33 +1,32 @@
|
||||||
use std::process::Command;
|
use std::process::Command;
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
use tabby_common::config::{Config, Repository};
|
use tabby_common::config::{Config, Repository};
|
||||||
|
|
||||||
trait ConfigExt {
|
trait ConfigExt {
|
||||||
fn sync_repositories(&self);
|
fn sync_repositories(&self) -> Result<()>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ConfigExt for Config {
|
impl ConfigExt for Config {
|
||||||
fn sync_repositories(&self) {
|
fn sync_repositories(&self) -> Result<()> {
|
||||||
for repository in self.repositories.iter() {
|
for repository in self.repositories.iter() {
|
||||||
repository.sync()
|
repository.sync()?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
trait RepositoryExt {
|
trait RepositoryExt {
|
||||||
fn sync(&self);
|
fn sync(&self) -> Result<()>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RepositoryExt for Repository {
|
impl RepositoryExt for Repository {
|
||||||
fn sync(&self) {
|
fn sync(&self) -> Result<()> {
|
||||||
let dir = self.dir();
|
let dir = self.dir();
|
||||||
let dir_string = dir.display().to_string();
|
let dir_string = dir.display().to_string();
|
||||||
let status = if dir.exists() {
|
let status = if dir.exists() {
|
||||||
Command::new("git")
|
Command::new("git").current_dir(&dir).arg("pull").status()
|
||||||
.current_dir(&dir)
|
|
||||||
.arg("pull")
|
|
||||||
.status()
|
|
||||||
.expect("git could not be executed")
|
|
||||||
} else {
|
} else {
|
||||||
std::fs::create_dir_all(&dir)
|
std::fs::create_dir_all(&dir)
|
||||||
.unwrap_or_else(|_| panic!("Failed to create dir {}", dir_string));
|
.unwrap_or_else(|_| panic!("Failed to create dir {}", dir_string));
|
||||||
|
|
@ -39,20 +38,22 @@ impl RepositoryExt for Repository {
|
||||||
.arg(&self.git_url)
|
.arg(&self.git_url)
|
||||||
.arg(dir)
|
.arg(dir)
|
||||||
.status()
|
.status()
|
||||||
.expect("git could not be executed")
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(code) = status.code() {
|
if let Some(code) = status?.code() {
|
||||||
if code != 0 {
|
if code != 0 {
|
||||||
panic!(
|
return Err(anyhow!(
|
||||||
"Failed to pull remote '{}'\nConsider remove dir '{}' and retry",
|
"Failed to pull remote '{}'. Consider remove dir '{}' and retry",
|
||||||
&self.git_url, &dir_string
|
&self.git_url,
|
||||||
);
|
&dir_string
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn sync_repositories(config: &Config) {
|
pub fn sync_repositories(config: &Config) -> Result<()> {
|
||||||
config.sync_repositories();
|
config.sync_repositories()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
use clap::Args;
|
use clap::Args;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
|
use crate::fatal;
|
||||||
|
|
||||||
#[derive(Args)]
|
#[derive(Args)]
|
||||||
pub struct DownloadArgs {
|
pub struct DownloadArgs {
|
||||||
/// model id to fetch.
|
/// model id to fetch.
|
||||||
|
|
@ -13,6 +15,14 @@ pub struct DownloadArgs {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn main(args: &DownloadArgs) {
|
pub async fn main(args: &DownloadArgs) {
|
||||||
tabby_download::download_model(&args.model, args.prefer_local_file).await;
|
tabby_download::download_model(&args.model, args.prefer_local_file)
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|err| {
|
||||||
|
fatal!(
|
||||||
|
"Failed to fetch model due to '{}', is '{}' a valid model id?",
|
||||||
|
err,
|
||||||
|
args.model
|
||||||
|
)
|
||||||
|
});
|
||||||
info!("model '{}' is ready", args.model);
|
info!("model '{}' is ready", args.model);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -14,13 +14,13 @@ struct Cli {
|
||||||
|
|
||||||
#[derive(Subcommand)]
|
#[derive(Subcommand)]
|
||||||
pub enum Commands {
|
pub enum Commands {
|
||||||
/// Serve the model
|
/// Starts the api endpoint for IDE / Editor extensions.
|
||||||
Serve(serve::ServeArgs),
|
Serve(serve::ServeArgs),
|
||||||
|
|
||||||
/// Download the model
|
/// Download the language model for serving.
|
||||||
Download(download::DownloadArgs),
|
Download(download::DownloadArgs),
|
||||||
|
|
||||||
/// Starts the scheduler process.
|
/// Run scheduler progress for cron jobs integrating external code repositories.
|
||||||
Scheduler(SchedulerArgs),
|
Scheduler(SchedulerArgs),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -42,6 +42,25 @@ async fn main() {
|
||||||
match &cli.command {
|
match &cli.command {
|
||||||
Commands::Serve(args) => serve::main(args).await,
|
Commands::Serve(args) => serve::main(args).await,
|
||||||
Commands::Download(args) => download::main(args).await,
|
Commands::Download(args) => download::main(args).await,
|
||||||
Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now),
|
Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now)
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|err| fatal!("Scheduler failed due to '{}'", err)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! fatal {
|
||||||
|
($msg:expr) => {
|
||||||
|
({
|
||||||
|
tracing::error!($msg);
|
||||||
|
std::process::exit(1);
|
||||||
|
})
|
||||||
|
};
|
||||||
|
|
||||||
|
($fmt:expr, $($arg:tt)*) => {
|
||||||
|
({
|
||||||
|
tracing::error!($fmt, $($arg)*);
|
||||||
|
std::process::exit(1);
|
||||||
|
})
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,8 @@ use axum::{
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use crate::fatal;
|
||||||
|
|
||||||
#[derive(rust_embed::RustEmbed)]
|
#[derive(rust_embed::RustEmbed)]
|
||||||
#[folder = "../tabby-admin/dist/"]
|
#[folder = "../tabby-admin/dist/"]
|
||||||
struct AdminAssets;
|
struct AdminAssets;
|
||||||
|
|
@ -26,12 +28,12 @@ where
|
||||||
Response::builder()
|
Response::builder()
|
||||||
.header(header::CONTENT_TYPE, mime.as_ref())
|
.header(header::CONTENT_TYPE, mime.as_ref())
|
||||||
.body(body)
|
.body(body)
|
||||||
.expect("Invalid response")
|
.unwrap_or_else(|_| fatal!("Invalid response"))
|
||||||
}
|
}
|
||||||
None => Response::builder()
|
None => Response::builder()
|
||||||
.status(StatusCode::NOT_FOUND)
|
.status(StatusCode::NOT_FOUND)
|
||||||
.body(boxed(Full::from("404")))
|
.body(boxed(Full::from("404")))
|
||||||
.expect("Invalid response"),
|
.unwrap_or_else(|_| fatal!("Invalid response")),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,12 +4,14 @@ use axum::{extract::State, Json};
|
||||||
use ctranslate2_bindings::{
|
use ctranslate2_bindings::{
|
||||||
TextInferenceEngine, TextInferenceEngineCreateOptionsBuilder, TextInferenceOptionsBuilder,
|
TextInferenceEngine, TextInferenceEngineCreateOptionsBuilder, TextInferenceOptionsBuilder,
|
||||||
};
|
};
|
||||||
|
use hyper::StatusCode;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use strfmt::{strfmt, strfmt_builder};
|
use strfmt::{strfmt, strfmt_builder};
|
||||||
use tabby_common::{events, path::ModelDir};
|
use tabby_common::{events, path::ModelDir};
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
use self::languages::get_stop_words;
|
use self::languages::get_stop_words;
|
||||||
|
use crate::fatal;
|
||||||
|
|
||||||
mod languages;
|
mod languages;
|
||||||
|
|
||||||
|
|
@ -58,20 +60,19 @@ pub struct CompletionResponse {
|
||||||
pub async fn completion(
|
pub async fn completion(
|
||||||
State(state): State<Arc<CompletionState>>,
|
State(state): State<Arc<CompletionState>>,
|
||||||
Json(request): Json<CompletionRequest>,
|
Json(request): Json<CompletionRequest>,
|
||||||
) -> Json<CompletionResponse> {
|
) -> Result<Json<CompletionResponse>, StatusCode> {
|
||||||
let language = request.language.unwrap_or("unknown".into());
|
let language = request.language.unwrap_or("unknown".to_string());
|
||||||
let options = TextInferenceOptionsBuilder::default()
|
let options = TextInferenceOptionsBuilder::default()
|
||||||
.max_decoding_length(128)
|
.max_decoding_length(128)
|
||||||
.sampling_temperature(0.1)
|
.sampling_temperature(0.1)
|
||||||
.stop_words(get_stop_words(&language))
|
.stop_words(get_stop_words(&language))
|
||||||
.build()
|
.build()
|
||||||
.expect("Invalid TextInferenceOptions");
|
.unwrap();
|
||||||
|
|
||||||
let prompt = if let Some(Segments { prefix, suffix }) = request.segments {
|
let prompt = if let Some(Segments { prefix, suffix }) = request.segments {
|
||||||
if let Some(prompt_template) = &state.prompt_template {
|
if let Some(prompt_template) = &state.prompt_template {
|
||||||
if let Some(suffix) = suffix {
|
if let Some(suffix) = suffix {
|
||||||
strfmt!(prompt_template, prefix => prefix, suffix => suffix)
|
strfmt!(prompt_template, prefix => prefix, suffix => suffix).unwrap()
|
||||||
.expect("Failed to format prompt")
|
|
||||||
} else {
|
} else {
|
||||||
// If suffix is empty, just returns prefix.
|
// If suffix is empty, just returns prefix.
|
||||||
prefix
|
prefix
|
||||||
|
|
@ -81,7 +82,7 @@ pub async fn completion(
|
||||||
prefix
|
prefix
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
request.prompt.expect("No prompt is set")
|
return Err(StatusCode::BAD_REQUEST);
|
||||||
};
|
};
|
||||||
|
|
||||||
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
|
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
|
||||||
|
|
@ -98,10 +99,10 @@ pub async fn completion(
|
||||||
}
|
}
|
||||||
.log();
|
.log();
|
||||||
|
|
||||||
Json(CompletionResponse {
|
Ok(Json(CompletionResponse {
|
||||||
id: completion_id,
|
id: completion_id,
|
||||||
choices: vec![Choice { index: 0, text }],
|
choices: vec![Choice { index: 0, text }],
|
||||||
})
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct CompletionState {
|
pub struct CompletionState {
|
||||||
|
|
@ -123,7 +124,7 @@ impl CompletionState {
|
||||||
.device_indices(args.device_indices.clone())
|
.device_indices(args.device_indices.clone())
|
||||||
.num_replicas_per_device(args.num_replicas_per_device)
|
.num_replicas_per_device(args.num_replicas_per_device)
|
||||||
.build()
|
.build()
|
||||||
.expect("Invalid TextInferenceEngineCreateOptions");
|
.unwrap();
|
||||||
let engine = TextInferenceEngine::create(options);
|
let engine = TextInferenceEngine::create(options);
|
||||||
Self {
|
Self {
|
||||||
engine,
|
engine,
|
||||||
|
|
@ -147,5 +148,6 @@ struct Metadata {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_metadata(model_dir: &ModelDir) -> Metadata {
|
fn read_metadata(model_dir: &ModelDir) -> Metadata {
|
||||||
serdeconv::from_json_file(model_dir.metadata_file()).expect("Invalid metadata")
|
serdeconv::from_json_file(model_dir.metadata_file())
|
||||||
|
.unwrap_or_else(|_| fatal!("Invalid metadata file: {}", model_dir.metadata_file()))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -8,13 +8,13 @@ use std::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use axum::{routing, Router, Server};
|
use axum::{routing, Router, Server};
|
||||||
use clap::{error::ErrorKind, Args, CommandFactory};
|
use clap::Args;
|
||||||
use tower_http::cors::CorsLayer;
|
use tower_http::cors::CorsLayer;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
use utoipa::OpenApi;
|
use utoipa::OpenApi;
|
||||||
use utoipa_swagger_ui::SwaggerUi;
|
use utoipa_swagger_ui::SwaggerUi;
|
||||||
|
|
||||||
use crate::Cli;
|
use crate::fatal;
|
||||||
|
|
||||||
#[derive(OpenApi)]
|
#[derive(OpenApi)]
|
||||||
#[openapi(
|
#[openapi(
|
||||||
|
|
@ -68,7 +68,15 @@ pub async fn main(args: &ServeArgs) {
|
||||||
valid_args(args);
|
valid_args(args);
|
||||||
|
|
||||||
// Ensure model exists.
|
// Ensure model exists.
|
||||||
tabby_download::download_model(&args.model, true).await;
|
tabby_download::download_model(&args.model, true)
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|err| {
|
||||||
|
fatal!(
|
||||||
|
"Failed to fetch model due to '{}', is '{}' a valid model id?",
|
||||||
|
err,
|
||||||
|
args.model
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
|
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
|
||||||
|
|
@ -81,7 +89,7 @@ pub async fn main(args: &ServeArgs) {
|
||||||
Server::bind(&address)
|
Server::bind(&address)
|
||||||
.serve(app.into_make_service())
|
.serve(app.into_make_service())
|
||||||
.await
|
.await
|
||||||
.expect("Error happends during model serving")
|
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn api_router(args: &ServeArgs) -> Router {
|
fn api_router(args: &ServeArgs) -> Router {
|
||||||
|
|
@ -104,21 +112,11 @@ fn fallback(experimental_admin_panel: bool) -> routing::MethodRouter {
|
||||||
|
|
||||||
fn valid_args(args: &ServeArgs) {
|
fn valid_args(args: &ServeArgs) {
|
||||||
if args.device == Device::Cuda && args.num_replicas_per_device != 1 {
|
if args.device == Device::Cuda && args.num_replicas_per_device != 1 {
|
||||||
Cli::command()
|
fatal!("CUDA device only supports 1 replicas per device");
|
||||||
.error(
|
|
||||||
ErrorKind::ValueValidation,
|
|
||||||
"CUDA device only supports 1 replicas per device",
|
|
||||||
)
|
|
||||||
.exit();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if args.device == Device::Cpu && (args.device_indices.len() != 1 || args.device_indices[0] != 0)
|
if args.device == Device::Cpu && (args.device_indices.len() != 1 || args.device_indices[0] != 0)
|
||||||
{
|
{
|
||||||
Cli::command()
|
fatal!("CPU device only supports device indices = [0]");
|
||||||
.error(
|
|
||||||
ErrorKind::ValueValidation,
|
|
||||||
"CPU device only supports device indices = [0]",
|
|
||||||
)
|
|
||||||
.exit();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue