refactor: add experimental-http feature (#750)
* add experimental-http feature, update code * refactor: add experimental-http featureextract-routes
parent
f2ea57bdd6
commit
22592374c1
|
|
@ -5,6 +5,7 @@ edition = "2021"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
cuda = ["llama-cpp-bindings/cuda"]
|
cuda = ["llama-cpp-bindings/cuda"]
|
||||||
|
experimental-http = ["dep:http-api-bindings"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
tabby-common = { path = "../tabby-common" }
|
tabby-common = { path = "../tabby-common" }
|
||||||
|
|
@ -36,7 +37,7 @@ tantivy = { workspace = true }
|
||||||
anyhow = { workspace = true }
|
anyhow = { workspace = true }
|
||||||
sysinfo = "0.29.8"
|
sysinfo = "0.29.8"
|
||||||
nvml-wrapper = "0.9.0"
|
nvml-wrapper = "0.9.0"
|
||||||
http-api-bindings = { path = "../http-api-bindings" }
|
http-api-bindings = { path = "../http-api-bindings", optional = true } # included when build with `experimental-http` feature
|
||||||
async-stream = { workspace = true }
|
async-stream = { workspace = true }
|
||||||
axum-streams = { version = "0.9.1", features = ["json"] }
|
axum-streams = { version = "0.9.1", features = ["json"] }
|
||||||
minijinja = { version = "1.0.8", features = ["loader"] }
|
minijinja = { version = "1.0.8", features = ["loader"] }
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,18 @@ pub async fn create_engine(
|
||||||
model_id: &str,
|
model_id: &str,
|
||||||
args: &crate::serve::ServeArgs,
|
args: &crate::serve::ServeArgs,
|
||||||
) -> (Box<dyn TextGeneration>, EngineInfo) {
|
) -> (Box<dyn TextGeneration>, EngineInfo) {
|
||||||
if args.device != super::Device::ExperimentalHttp {
|
#[cfg(feature = "experimental-http")]
|
||||||
|
if args.device == crate::serve::Device::ExperimentalHttp {
|
||||||
|
let (engine, prompt_template) = http_api_bindings::create(model_id);
|
||||||
|
return (
|
||||||
|
engine,
|
||||||
|
EngineInfo {
|
||||||
|
prompt_template: Some(prompt_template),
|
||||||
|
chat_template: None,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if fs::metadata(model_id).is_ok() {
|
if fs::metadata(model_id).is_ok() {
|
||||||
let path = PathBuf::from(model_id);
|
let path = PathBuf::from(model_id);
|
||||||
let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
|
let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
|
||||||
|
|
@ -35,16 +46,6 @@ pub async fn create_engine(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
let (engine, prompt_template) = http_api_bindings::create(model_id);
|
|
||||||
(
|
|
||||||
engine,
|
|
||||||
EngineInfo {
|
|
||||||
prompt_template: Some(prompt_template),
|
|
||||||
chat_template: None,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ use tabby_common::{
|
||||||
use tabby_download::download_model;
|
use tabby_download::download_model;
|
||||||
use tokio::time::sleep;
|
use tokio::time::sleep;
|
||||||
use tower_http::{cors::CorsLayer, timeout::TimeoutLayer};
|
use tower_http::{cors::CorsLayer, timeout::TimeoutLayer};
|
||||||
use tracing::{info, warn};
|
use tracing::info;
|
||||||
use utoipa::OpenApi;
|
use utoipa::OpenApi;
|
||||||
use utoipa_swagger_ui::SwaggerUi;
|
use utoipa_swagger_ui::SwaggerUi;
|
||||||
|
|
||||||
|
|
@ -86,6 +86,7 @@ pub enum Device {
|
||||||
#[strum(serialize = "metal")]
|
#[strum(serialize = "metal")]
|
||||||
Metal,
|
Metal,
|
||||||
|
|
||||||
|
#[cfg(feature = "experimental-http")]
|
||||||
#[strum(serialize = "experimental_http")]
|
#[strum(serialize = "experimental_http")]
|
||||||
ExperimentalHttp,
|
ExperimentalHttp,
|
||||||
}
|
}
|
||||||
|
|
@ -131,18 +132,14 @@ pub struct ServeArgs {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn main(config: &Config, args: &ServeArgs) {
|
pub async fn main(config: &Config, args: &ServeArgs) {
|
||||||
if args.device != Device::ExperimentalHttp {
|
#[cfg(feature = "experimental-http")]
|
||||||
if fs::metadata(&args.model).is_ok() {
|
if args.device == Device::ExperimentalHttp {
|
||||||
info!("Loading model from local path {}", &args.model);
|
tracing::warn!("HTTP device is unstable and does not comply with semver expectations.");
|
||||||
} else {
|
} else {
|
||||||
download_model(&args.model, true).await;
|
load_model(args).await;
|
||||||
if let Some(chat_model) = &args.chat_model {
|
|
||||||
download_model(chat_model, true).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
warn!("HTTP device is unstable and does not comply with semver expectations.")
|
|
||||||
}
|
}
|
||||||
|
#[cfg(not(feature = "experimental-http"))]
|
||||||
|
load_model(args).await;
|
||||||
|
|
||||||
info!("Starting server, this might takes a few minutes...");
|
info!("Starting server, this might takes a few minutes...");
|
||||||
|
|
||||||
|
|
@ -172,6 +169,17 @@ pub async fn main(config: &Config, args: &ServeArgs) {
|
||||||
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
|
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn load_model(args: &ServeArgs) {
|
||||||
|
if fs::metadata(&args.model).is_ok() {
|
||||||
|
info!("Loading model from local path {}", &args.model);
|
||||||
|
} else {
|
||||||
|
download_model(&args.model, true).await;
|
||||||
|
if let Some(chat_model) = &args.chat_model {
|
||||||
|
download_model(chat_model, true).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
let code = Arc::new(create_code_search());
|
let code = Arc::new(create_code_search());
|
||||||
let completion_state = {
|
let completion_state = {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue