feat: validate token during worker registration (#803)

* feat: validate token during worker registration

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* resolve comments

* reslove comments

* format file, update schema file

* resolve comment

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
release-fix-intellij-update-support-version-range
Eric 2023-11-17 15:05:39 +08:00 committed by GitHub
parent 97f4989905
commit ce338c7436
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 318 additions and 34 deletions

125
Cargo.lock generated
View File

@ -1415,6 +1415,18 @@ version = "2.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0"
[[package]]
name = "fallible-iterator"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
[[package]]
name = "fallible-streaming-iterator"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
[[package]]
name = "fastdivide"
version = "0.4.0"
@ -1771,6 +1783,15 @@ dependencies = [
"allocator-api2",
]
[[package]]
name = "hashlink"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7"
dependencies = [
"hashbrown 0.14.0",
]
[[package]]
name = "headers"
version = "0.3.8"
@ -1966,9 +1987,9 @@ dependencies = [
[[package]]
name = "iana-time-zone"
version = "0.1.56"
version = "0.1.57"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0722cd7114b7de04316e7ea5456a0bbb20e4adb46fd27a3697adb812cff0f37c"
checksum = "2fad5b825842d2b38bd206f3e81d6957625fd7f0a361e345c30e01a0ae2dd613"
dependencies = [
"android_system_properties",
"core-foundation-sys",
@ -2141,9 +2162,9 @@ dependencies = [
[[package]]
name = "js-sys"
version = "0.3.63"
version = "0.3.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f37a4a5928311ac501dee68b3c7613a1037d0edb30c8e5427bd832d55d1b790"
checksum = "c5f195fe497f702db0f318b07fdd68edb16955aed830df8363d837542f8f935a"
dependencies = [
"wasm-bindgen",
]
@ -2262,6 +2283,17 @@ dependencies = [
"winapi",
]
[[package]]
name = "libsqlite3-sys"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "afc22eff61b133b115c6e8c74e818c628d6d5e7a502afea6f64dee076dd94326"
dependencies = [
"cc",
"pkg-config",
"vcpkg",
]
[[package]]
name = "libssh2-sys"
version = "0.3.0"
@ -2311,9 +2343,9 @@ checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519"
[[package]]
name = "linux-raw-sys"
version = "0.4.8"
version = "0.4.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3852614a3bd9ca9804678ba6be5e3b8ce76dfc902cae004e3e0c44051b6e88db"
checksum = "da2479e8c062e40bf0066ffa0bc823de0a9368974af99c9f6df941d2c231e03f"
[[package]]
name = "llama-cpp-bindings"
@ -3503,6 +3535,32 @@ dependencies = [
"serde",
]
[[package]]
name = "rusqlite"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "549b9d036d571d42e6e85d1c1425e2ac83491075078ca9a15be021c56b1641f2"
dependencies = [
"bitflags 2.4.0",
"fallible-iterator",
"fallible-streaming-iterator",
"hashlink",
"libsqlite3-sys",
"smallvec",
]
[[package]]
name = "rusqlite_migration"
version = "1.1.0-alpha.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ef119690ca6bac53498f4478badf364840780248132ab5097891d0cfdf42eda"
dependencies = [
"log",
"rusqlite",
"tokio",
"tokio-rusqlite",
]
[[package]]
name = "rust-embed"
version = "6.6.1"
@ -3661,7 +3719,7 @@ dependencies = [
"bitflags 2.4.0",
"errno",
"libc",
"linux-raw-sys 0.4.8",
"linux-raw-sys 0.4.10",
"windows-sys 0.48.0",
]
@ -3773,9 +3831,9 @@ dependencies = [
[[package]]
name = "security-framework"
version = "2.9.1"
version = "2.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fc758eb7bffce5b308734e9b0c1468893cae9ff70ebf13e7090be8dcbcc83a8"
checksum = "05b64fb303737d99b81884b2c63433e9ae28abebe5eb5045dcdd175dc2ecf4de"
dependencies = [
"bitflags 1.3.2",
"core-foundation",
@ -3786,9 +3844,9 @@ dependencies = [
[[package]]
name = "security-framework-sys"
version = "2.9.0"
version = "2.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f51d0c0d83bec45f16480d0ce0058397a69e48fcdc52d1dc8855fb68acbd31a7"
checksum = "e932934257d3b408ed8f30db49d85ea163bfe74961f017f405b025af298f0c7a"
dependencies = [
"core-foundation-sys",
"libc",
@ -4406,6 +4464,7 @@ dependencies = [
"anyhow",
"axum",
"bincode",
"chrono",
"futures",
"hyper",
"juniper",
@ -4413,14 +4472,19 @@ dependencies = [
"lazy_static",
"mime_guess",
"pin-project",
"rusqlite",
"rusqlite_migration",
"rust-embed 8.0.0",
"serde",
"tabby-common",
"tarpc",
"thiserror",
"tokio",
"tokio-rusqlite",
"tokio-tungstenite",
"tracing",
"unicase",
"uuid 1.4.1",
]
[[package]]
@ -4824,6 +4888,17 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-rusqlite"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7aa66395f5ff117faee90c9458232c936405f9227ad902038000b74b3bc1feac"
dependencies = [
"crossbeam-channel",
"rusqlite",
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.24.1"
@ -5670,9 +5745,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "wasm-bindgen"
version = "0.2.86"
version = "0.2.87"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5bba0e8cb82ba49ff4e229459ff22a191bbe9a1cb3a341610c9c33efc27ddf73"
checksum = "7706a72ab36d8cb1f80ffbf0e071533974a60d0a308d01a5d0375bf60499a342"
dependencies = [
"cfg-if",
"wasm-bindgen-macro",
@ -5680,9 +5755,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-backend"
version = "0.2.86"
version = "0.2.87"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19b04bc93f9d6bdee709f6bd2118f57dd6679cf1176a1af464fca3ab0d66d8fb"
checksum = "5ef2b6d3c510e9625e5fe6f509ab07d66a760f0885d858736483c32ed7809abd"
dependencies = [
"bumpalo",
"log",
@ -5695,9 +5770,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-futures"
version = "0.4.36"
version = "0.4.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d1985d03709c53167ce907ff394f5316aa22cb4e12761295c5dc57dacb6297e"
checksum = "c02dbc21516f9f1f04f187958890d7e6026df8d16540b7ad9492bc34a67cea03"
dependencies = [
"cfg-if",
"js-sys",
@ -5707,9 +5782,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.86"
version = "0.2.87"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14d6b024f1a526bb0234f52840389927257beb670610081360e5a03c5df9c258"
checksum = "dee495e55982a3bd48105a7b947fd2a9b4a8ae3010041b9e0faab3f9cd028f1d"
dependencies = [
"quote",
"wasm-bindgen-macro-support",
@ -5717,9 +5792,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.86"
version = "0.2.87"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e128beba882dd1eb6200e1dc92ae6c5dbaa4311aa7bb211ca035779e5efc39f8"
checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b"
dependencies = [
"proc-macro2",
"quote",
@ -5730,9 +5805,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-shared"
version = "0.2.86"
version = "0.2.87"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed9d5b4305409d1fc9482fee2d7f9bcbf24b3972bf59817ef757e23982242a93"
checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1"
[[package]]
name = "wasm-streams"
@ -5749,9 +5824,9 @@ dependencies = [
[[package]]
name = "web-sys"
version = "0.3.63"
version = "0.3.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3bdd9ef4e984da1187bf8110c5cf5b845fbc87a23602cdf912386a76fcd3a7c2"
checksum = "9b85cbef8c220a6abc02aefd892dfc0fc23afb1c6a426316ec33253a3877249b"
dependencies = [
"js-sys",
"wasm-bindgen",

View File

@ -18,7 +18,7 @@ pub fn set_tabby_root(path: PathBuf) {
cell.replace(path);
}
fn tabby_root() -> PathBuf {
pub fn tabby_root() -> PathBuf {
let mut cell = TABBY_ROOT.lock().unwrap();
cell.get_mut().clone()
}

View File

@ -33,6 +33,10 @@ pub struct WorkerArgs {
#[clap(long, default_value_t = 8080)]
port: u16,
/// Server token to register this worker to.
#[clap(long)]
token: String,
/// Model id
#[clap(long, help_heading=Some("Model Options"))]
model: String,
@ -99,6 +103,7 @@ async fn request_register(kind: WorkerKind, args: &WorkerArgs) {
args.port,
args.model.to_owned(),
args.device.to_string(),
args.token.clone(),
)
.await
{
@ -112,6 +117,7 @@ async fn request_register_impl(
port: u16,
name: String,
device: String,
token: String,
) -> Result<()> {
let client = tabby_webserver::api::create_client(url).await;
let (cpu_info, cpu_count) = read_cpu_info();
@ -127,6 +133,7 @@ async fn request_register_impl(
cpu_info,
cpu_count as i32,
cuda_devices,
token,
)
.await??;

View File

@ -9,6 +9,7 @@ homepage.workspace = true
anyhow.workspace = true
axum = { workspace = true, features = ["ws"] }
bincode = "1.3.3"
chrono = "0.4"
futures.workspace = true
hyper = { workspace = true, features=["client"]}
juniper.workspace = true
@ -16,14 +17,27 @@ juniper-axum = { path = "../../crates/juniper-axum" }
lazy_static = "1.4.0"
mime_guess = "2.0.4"
pin-project = "1.1.3"
rusqlite = { version = "0.29.0", features = ["bundled"] }
# `async-tokio-rusqlite` is only available from 1.1.0-alpha.2, will bump up version when it's stable
rusqlite_migration = { version = "1.1.0-alpha.2", features = ["async-tokio-rusqlite"] }
rust-embed = "8.0.0"
serde.workspace = true
tabby-common = { path = "../../crates/tabby-common" }
tarpc = { version = "0.33.0", features = ["serde-transport"] }
thiserror.workspace = true
tokio.workspace = true
tokio-rusqlite = "0.4.0"
tokio-tungstenite = "0.20.1"
tracing.workspace = true
unicase = "2.7.0"
[dependencies.uuid]
version = "1.3.3"
features = [
"v4", # Lets you generate random UUIDs
"fast-rng", # Use a faster (but still sufficiently random) RNG
"macro-diagnostics", # Enable better diagnostics for compile-time UUIDs
]
[dev-dependencies]
tokio = { workspace = true, features = ["macros"] }

View File

@ -3,6 +3,10 @@ enum WorkerKind {
CHAT
}
type Mutation {
resetRegistrationToken: String!
}
type Query {
workers: [Worker!]!
}
@ -20,4 +24,5 @@ type Worker {
schema {
query: Query
mutation: Mutation
}

View File

@ -25,7 +25,7 @@ pub struct Worker {
#[derive(Serialize, Deserialize, Error, Debug)]
pub enum HubError {
#[error("Invalid worker token")]
#[error("Invalid token")]
InvalidToken(String),
#[error("Feature requires enterprise license")]
@ -43,6 +43,7 @@ pub trait Hub {
cpu_info: String,
cpu_count: i32,
cuda_devices: Vec<String>,
token: String,
) -> Result<Worker, HubError>;
}

View File

@ -0,0 +1,127 @@
use std::{path::PathBuf, sync::Arc};
use anyhow::Result;
use lazy_static::lazy_static;
use rusqlite::params;
use rusqlite_migration::{AsyncMigrations, M};
use tabby_common::path::tabby_root;
use tokio_rusqlite::Connection;
lazy_static! {
static ref MIGRATIONS: AsyncMigrations = AsyncMigrations::new(vec![M::up(
r#"
CREATE TABLE IF NOT EXISTS registration_token (
id INTEGER PRIMARY KEY AUTOINCREMENT,
token VARCHAR(255) NOT NULL,
created_at TIMESTAMP DEFAULT (DATETIME('now')),
updated_at TIMESTAMP DEFAULT (DATETIME('now')),
CONSTRAINT `idx_token` UNIQUE (`token`)
);
"#
),]);
}
fn db_file() -> PathBuf {
tabby_root().join("db.sqlite3")
}
pub struct DbConn {
conn: Arc<Connection>,
}
impl DbConn {
pub async fn new() -> Result<Self> {
let conn = Connection::open(db_file()).await?;
Self::init_db(conn).await
}
/// Initialize database, create tables and insert first token if not exist
async fn init_db(mut conn: Connection) -> Result<Self> {
MIGRATIONS.to_latest(&mut conn).await?;
let token = uuid::Uuid::new_v4().to_string();
conn.call(move |c| {
c.execute(
r#"INSERT OR IGNORE INTO registration_token (id, token) VALUES (1, ?)"#,
params![token],
)
})
.await?;
Ok(Self {
conn: Arc::new(conn),
})
}
/// Query token from database.
/// Since token is global unique for each tabby server, by right there's only one row in the table.
pub async fn read_registration_token(&self) -> Result<String> {
let token = self
.conn
.call(|conn| {
conn.query_row(
r#"SELECT token FROM registration_token WHERE id = 1"#,
[],
|row| row.get(0),
)
})
.await?;
Ok(token)
}
/// Update token in database.
pub async fn reset_registration_token(&self) -> Result<String> {
let token = uuid::Uuid::new_v4().to_string();
let result = token.clone();
let updated_at = chrono::Utc::now().timestamp() as u32;
let res = self
.conn
.call(move |conn| {
conn.execute(
r#"UPDATE registration_token SET token = ?, updated_at = ? WHERE id = 1"#,
params![token, updated_at],
)
})
.await?;
if res != 1 {
return Err(anyhow::anyhow!("failed to update token"));
}
Ok(result)
}
}
#[cfg(test)]
mod tests {
use super::*;
async fn new_in_memory() -> Result<DbConn> {
let conn = Connection::open_in_memory().await?;
DbConn::init_db(conn).await
}
#[tokio::test]
async fn migrations_test() {
assert!(MIGRATIONS.validate().await.is_ok());
}
#[tokio::test]
async fn test_token() {
let conn = new_in_memory().await.unwrap();
let token = conn.read_registration_token().await.unwrap();
assert_eq!(token.len(), 36);
}
#[tokio::test]
async fn test_update_token() {
let conn = new_in_memory().await.unwrap();
let old_token = conn.read_registration_token().await.unwrap();
conn.reset_registration_token().await.unwrap();
let new_token = conn.read_registration_token().await.unwrap();
assert_eq!(new_token.len(), 36);
assert_ne!(old_token, new_token);
}
}

View File

@ -2,8 +2,10 @@ pub mod api;
mod schema;
pub use schema::create_schema;
use tracing::error;
use websocket::WebSocketTransport;
mod db;
mod server;
mod ui;
mod websocket;
@ -25,7 +27,8 @@ use server::ServerContext;
use tarpc::server::{BaseChannel, Channel};
pub async fn attach_webserver(router: Router) -> Router {
let ctx = Arc::new(ServerContext::default());
let conn = db::DbConn::new().await.unwrap();
let ctx = Arc::new(ServerContext::new(conn));
let schema = Arc::new(create_schema());
let app = Router::new()
@ -91,7 +94,24 @@ impl Hub for Arc<HubImpl> {
cpu_info: String,
cpu_count: i32,
cuda_devices: Vec<String>,
token: String,
) -> Result<Worker, HubError> {
if token.is_empty() {
return Err(HubError::InvalidToken("Empty worker token".to_string()));
}
let server_token = match self.ctx.read_registration_token().await {
Ok(t) => t,
Err(err) => {
error!("fetch server token: {}", err.to_string());
return Err(HubError::InvalidToken(
"Failed to fetch server token".to_string(),
));
}
};
if server_token != token {
return Err(HubError::InvalidToken("Token mismatch".to_string()));
}
let worker = Worker {
name,
kind,

View File

@ -1,4 +1,4 @@
use juniper::{graphql_object, EmptyMutation, EmptySubscription, RootNode};
use juniper::{graphql_object, EmptySubscription, FieldResult, RootNode};
use crate::{api::Worker, server::ServerContext};
@ -15,9 +15,19 @@ impl Query {
}
}
pub type Schema =
RootNode<'static, Query, EmptyMutation<ServerContext>, EmptySubscription<ServerContext>>;
#[derive(Default)]
pub struct Mutation;
#[graphql_object(context = ServerContext)]
impl Mutation {
async fn reset_registration_token(ctx: &ServerContext) -> FieldResult<String> {
let token = ctx.reset_registration_token().await?;
Ok(token)
}
}
pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<ServerContext>>;
pub fn create_schema() -> Schema {
Schema::new(Query, EmptyMutation::new(), EmptySubscription::new())
Schema::new(Query, Mutation, EmptySubscription::new())
}

View File

@ -3,19 +3,44 @@ mod worker;
use std::net::SocketAddr;
use anyhow::Result;
use axum::{http::Request, middleware::Next, response::IntoResponse};
use hyper::{client::HttpConnector, Body, Client, StatusCode};
use tracing::{info, warn};
use crate::api::{HubError, Worker, WorkerKind};
#[derive(Default)]
use crate::{
api::{HubError, Worker, WorkerKind},
db::DbConn,
};
pub struct ServerContext {
client: Client<HttpConnector>,
completion: worker::WorkerGroup,
chat: worker::WorkerGroup,
db_conn: DbConn,
}
impl ServerContext {
pub fn new(db_conn: DbConn) -> Self {
Self {
client: Client::default(),
completion: worker::WorkerGroup::default(),
chat: worker::WorkerGroup::default(),
db_conn,
}
}
/// Query current token from the database.
pub async fn read_registration_token(&self) -> Result<String> {
self.db_conn.read_registration_token().await
}
/// Generate new token, and update it in the database.
/// Return new token after update is done
pub async fn reset_registration_token(&self) -> Result<String> {
self.db_conn.reset_registration_token().await
}
pub async fn register_worker(&self, worker: Worker) -> Result<Worker, HubError> {
let worker = match worker.kind {
WorkerKind::Completion => self.completion.register(worker).await,