feat: add server.completion_timeout to control timeout of /v1/completions (#637)

* feat: add server.completion_timeout to control timeout of /v1/completions

* Update config.rs
add-llama-model-converter
Meng Zhang 2023-10-25 15:05:23 -07:00 committed by GitHub
parent d6296bb121
commit 21ec60eddf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 36 additions and 16 deletions

View File

@ -11,7 +11,10 @@ use crate::path::{config_file, repositories_dir};
#[derive(Serialize, Deserialize, Default)] #[derive(Serialize, Deserialize, Default)]
pub struct Config { pub struct Config {
#[serde(default)] #[serde(default)]
pub repositories: Vec<Repository>, pub repositories: Vec<RepositoryConfig>,
#[serde(default)]
pub server: ServerConfig,
} }
impl Config { impl Config {
@ -37,11 +40,11 @@ impl Config {
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct Repository { pub struct RepositoryConfig {
pub git_url: String, pub git_url: String,
} }
impl Repository { impl RepositoryConfig {
pub fn dir(&self) -> PathBuf { pub fn dir(&self) -> PathBuf {
if self.is_local_dir() { if self.is_local_dir() {
let path = self.git_url.strip_prefix("file://").unwrap(); let path = self.git_url.strip_prefix("file://").unwrap();
@ -56,9 +59,23 @@ impl Repository {
} }
} }
#[derive(Serialize, Deserialize)]
pub struct ServerConfig {
/// The timeout in seconds for the /v1/completion api.
pub completion_timeout: u64,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
completion_timeout: 30,
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{Config, Repository}; use super::{Config, RepositoryConfig};
#[test] #[test]
fn it_parses_empty_config() { fn it_parses_empty_config() {
@ -68,13 +85,13 @@ mod tests {
#[test] #[test]
fn it_parses_local_dir() { fn it_parses_local_dir() {
let repo = Repository { let repo = RepositoryConfig {
git_url: "file:///home/user".to_owned(), git_url: "file:///home/user".to_owned(),
}; };
assert!(repo.is_local_dir()); assert!(repo.is_local_dir());
assert_eq!(repo.dir().display().to_string(), "/home/user"); assert_eq!(repo.dir().display().to_string(), "/home/user");
let repo = Repository { let repo = RepositoryConfig {
git_url: "https://github.com/TabbyML/tabby".to_owned(), git_url: "https://github.com/TabbyML/tabby".to_owned(),
}; };
assert!(!repo.is_local_dir()); assert!(!repo.is_local_dir());

View File

@ -11,7 +11,7 @@ use ignore::{DirEntry, Walk};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use serde_jsonlines::WriteExt; use serde_jsonlines::WriteExt;
use tabby_common::{ use tabby_common::{
config::{Config, Repository}, config::{Config, RepositoryConfig},
path::dataset_dir, path::dataset_dir,
SourceFile, SourceFile,
}; };
@ -22,7 +22,7 @@ trait RepositoryExt {
fn create_dataset(&self, writer: &mut impl Write) -> Result<()>; fn create_dataset(&self, writer: &mut impl Write) -> Result<()>;
} }
impl RepositoryExt for Repository { impl RepositoryExt for RepositoryConfig {
fn create_dataset(&self, writer: &mut impl Write) -> Result<()> { fn create_dataset(&self, writer: &mut impl Write) -> Result<()> {
let dir = self.dir(); let dir = self.dir();

View File

@ -1,7 +1,7 @@
use std::process::Command; use std::process::Command;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use tabby_common::config::{Config, Repository}; use tabby_common::config::{Config, RepositoryConfig};
trait ConfigExt { trait ConfigExt {
fn sync_repositories(&self) -> Result<()>; fn sync_repositories(&self) -> Result<()>;
@ -27,7 +27,7 @@ trait RepositoryExt {
fn sync(&self) -> Result<()>; fn sync(&self) -> Result<()>;
} }
impl RepositoryExt for Repository { impl RepositoryExt for RepositoryConfig {
fn sync(&self) -> Result<()> { fn sync(&self) -> Result<()> {
let dir = self.dir(); let dir = self.dir();
let dir_string = dir.display().to_string(); let dir_string = dir.display().to_string();

View File

@ -3,7 +3,7 @@ mod tests {
use std::fs::create_dir_all; use std::fs::create_dir_all;
use tabby_common::{ use tabby_common::{
config::{Config, Repository}, config::{Config, RepositoryConfig, ServerConfig},
path::set_tabby_root, path::set_tabby_root,
}; };
use temp_testdir::*; use temp_testdir::*;
@ -17,9 +17,10 @@ mod tests {
set_tabby_root(root.to_path_buf()); set_tabby_root(root.to_path_buf());
let config = Config { let config = Config {
repositories: vec![Repository { repositories: vec![RepositoryConfig {
git_url: "https://github.com/TabbyML/interview-questions".to_owned(), git_url: "https://github.com/TabbyML/interview-questions".to_owned(),
}], }],
server: ServerConfig::default(),
}; };
config.save(); config.save();

View File

@ -125,7 +125,7 @@ fn should_download_ggml_files(device: &Device) -> bool {
*device == Device::Metal *device == Device::Metal
} }
pub async fn main(_config: &Config, args: &ServeArgs) { pub async fn main(config: &Config, args: &ServeArgs) {
valid_args(args); valid_args(args);
if args.device != Device::ExperimentalHttp { if args.device != Device::ExperimentalHttp {
@ -146,7 +146,7 @@ pub async fn main(_config: &Config, args: &ServeArgs) {
.route("/", routing::get(playground::handler)) .route("/", routing::get(playground::handler))
.route("/index.txt", routing::get(playground::handler)) .route("/index.txt", routing::get(playground::handler))
.route("/_next/*path", routing::get(playground::handler)) .route("/_next/*path", routing::get(playground::handler))
.merge(api_router(args)) .merge(api_router(args, config))
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc)); .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc));
let app = if args.chat_model.is_some() { let app = if args.chat_model.is_some() {
@ -166,7 +166,7 @@ pub async fn main(_config: &Config, args: &ServeArgs) {
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err)) .unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
} }
fn api_router(args: &ServeArgs) -> Router { fn api_router(args: &ServeArgs, config: &Config) -> Router {
let index_server = Arc::new(IndexServer::new()); let index_server = Arc::new(IndexServer::new());
let completion_state = { let completion_state = {
let ( let (
@ -218,7 +218,9 @@ fn api_router(args: &ServeArgs) -> Router {
"/v1/completions", "/v1/completions",
routing::post(completions::completions).with_state(completion_state), routing::post(completions::completions).with_state(completion_state),
) )
.layer(TimeoutLayer::new(Duration::from_secs(3))) .layer(TimeoutLayer::new(Duration::from_secs(
config.server.completion_timeout,
)))
}); });
if let Some(chat_state) = chat_state { if let Some(chat_state) = chat_state {