From 30d2ba5f008acfdcce5c8985ea4ee4902cfd23cc Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Thu, 30 Nov 2023 16:26:55 +0800 Subject: [PATCH] test: add chat golden test (#919) * test: add chat golden test * update --- Cargo.lock | 15 +++- crates/tabby/Cargo.toml | 1 + crates/tabby/src/serve.rs | 2 +- crates/tabby/tests/golden_chat.json | 24 ++++++ crates/tabby/tests/goldentests_chat.rs | 111 +++++++++++++++++++++++++ 5 files changed, 150 insertions(+), 3 deletions(-) create mode 100644 crates/tabby/tests/golden_chat.json create mode 100644 crates/tabby/tests/goldentests_chat.rs diff --git a/Cargo.lock b/Cargo.lock index bc46ca7..4ae0434 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4127,6 +4127,16 @@ dependencies = [ "serde_json", ] +[[package]] +name = "serde-jsonlines" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e228faf5f94badfe42723177b62cfb9b187351994cb4e852cd4a6a4c96dbeea8" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "serde_derive" version = "1.0.171" @@ -4578,6 +4588,7 @@ dependencies = [ "regex", "reqwest", "serde", + "serde-jsonlines 0.5.0", "serde_json", "serdeconv", "strfmt", @@ -4613,7 +4624,7 @@ dependencies = [ "lazy_static", "reqwest", "serde", - "serde-jsonlines", + "serde-jsonlines 0.4.0", "serde_json", "serdeconv", "tantivy", @@ -4659,7 +4670,7 @@ dependencies = [ "kdam", "lazy_static", "requirements", - "serde-jsonlines", + "serde-jsonlines 0.4.0", "serde_json", "serdeconv", "tabby-common", diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index 9854933..d74b8ad 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -66,3 +66,4 @@ vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] } [dev-dependencies] assert-json-diff = "2.0.2" reqwest.workspace = true +serde-jsonlines = "0.5.0" diff --git a/crates/tabby/src/serve.rs b/crates/tabby/src/serve.rs index 6cea454..fa54cb4 100644 --- a/crates/tabby/src/serve.rs +++ b/crates/tabby/src/serve.rs @@ -45,7 +45,7 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi servers( (url = "/", description = "Server"), ), - paths(routes::log_event, routes::completions, routes::completions, routes::health, routes::search), + paths(routes::log_event, routes::completions, routes::chat_completions, routes::health, routes::search), components(schemas( api::event::LogEventRequest, completion::CompletionRequest, diff --git a/crates/tabby/tests/golden_chat.json b/crates/tabby/tests/golden_chat.json new file mode 100644 index 0000000..e129c80 --- /dev/null +++ b/crates/tabby/tests/golden_chat.json @@ -0,0 +1,24 @@ +[ + { + "request": { + "messages": [ + { + "role": "user", + "content": "How to convert a list of string to numbers in python" + } + ] + }, + "expected": " In Python, you can convert a list of strings to numbers using the `map()` function and the `int()` function. Here's an example:\n```\nstrings = ['1', '2', '3', '4', '5']\nnumbers = list(map(int, strings))\nprint(numbers)\n```\nThis will output:\n```\n[1, 2, 3, 4, 5]\n```\nIn this example, the `map()` function applies the `int()` function to each element of the `strings` list, converting each string to an integer and returning a new list of integers. The `list()` function is used to convert the resulting iterator to a list." + }, + { + "request": { + "messages": [ + { + "role": "user", + "content": "How to parse email address with regex" + } + ] + }, + "expected": " To parse an email address with regex, you can use the following pattern:\n```\n^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$\n```\nThis pattern matches email addresses in the following format:\n\n* `^`: start of the string\n* `[a-zA-Z0-9._%+-]+`: matches one or more characters that are either letters, numbers, periods, underscores, percent signs, plus signs, or hyphens\n* `@`: matches the `@` symbol\n* `[a-zA-Z0-9.-]+`: matches one or more characters that are either letters, numbers, periods, or hyphens\n* `\\.`: matches the `.` symbol\n* `[a-zA-Z]{2,}`: matches two or more characters that are letters\n* `$`: end of the string\n\nYou can use this pattern in a programming language that supports regex, such as Python, JavaScript, or Java, to extract the email address from a string. For example, in Python, you can use the `re` module to find all email addresses in a string:\n```\nimport re\n\nstring = \"Please send your feedback to john.doe@example.com or jane_doe@example.co.uk.\"\n\nemails = re.findall(r\"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}\", string)\n\nprint(emails) # Output: ['john.doe@example.com', 'jane_doe@example.co.uk']\n```\nIn this example, the `re.findall()` function is used to find all occurrences of the email address pattern in the string. The `findall()` function returns a list of all non-overlapping matches." + } +] \ No newline at end of file diff --git a/crates/tabby/tests/goldentests_chat.rs b/crates/tabby/tests/goldentests_chat.rs new file mode 100644 index 0000000..1e464a9 --- /dev/null +++ b/crates/tabby/tests/goldentests_chat.rs @@ -0,0 +1,111 @@ +use std::path::PathBuf; + +use assert_json_diff::assert_json_include; +use lazy_static::lazy_static; +use serde::Deserialize; +use serde_json::json; +use serde_jsonlines::BufReadExt; +use tokio::{ + process::Command, + time::{sleep, Duration}, +}; + +#[derive(Deserialize)] +pub struct ChatCompletionChunk { + content: String, +} + +lazy_static! { + static ref SERVER: bool = { + let mut cmd = Command::new(tabby_path()); + cmd.arg("serve") + .arg("--chat-model") + .arg("TabbyML/Mistral-7B") + .arg("--port") + .arg("9090") + .arg("--device") + .arg("metal") + .kill_on_drop(true); + tokio::task::spawn(async move { + cmd.spawn() + .expect("Failed to start server") + .wait() + .await + .unwrap(); + }); + true + }; + static ref CLIENT: reqwest::Client = reqwest::Client::new(); +} + +fn workspace_dir() -> PathBuf { + let output = std::process::Command::new(env!("CARGO")) + .arg("locate-project") + .arg("--workspace") + .arg("--message-format=plain") + .output() + .unwrap() + .stdout; + let cargo_path = std::path::Path::new(std::str::from_utf8(&output).unwrap().trim()); + cargo_path.parent().unwrap().to_path_buf() +} + +fn tabby_path() -> PathBuf { + workspace_dir().join("target/debug/tabby") +} + +fn golden_path() -> PathBuf { + workspace_dir().join("crates/tabby/tests/golden_chat.json") +} + +async fn wait_for_server() { + lazy_static::initialize(&SERVER); + + loop { + println!("Waiting for server to start..."); + let is_ok = reqwest::get("http://localhost:9090/v1/health") + .await + .is_ok(); + if is_ok { + break; + } else { + sleep(Duration::from_secs(5)).await; + } + } +} + +async fn golden_test(body: serde_json::Value, expected: serde_json::Value) { + let bytes = CLIENT + .post("http://localhost:9090/v1beta/chat/completions") + .json(&body) + .send() + .await + .unwrap() + .bytes() + .await + .unwrap(); + + let lines = bytes.json_lines::(); + let mut actual = "".to_owned(); + for x in lines { + actual += &x.unwrap().content; + } + + assert_json_include!(actual: actual, expected: expected); +} + +#[derive(Deserialize)] +struct TestCase { + request: serde_json::Value, + expected: serde_json::Value, +} + +#[tokio::test] +async fn run_chat_golden_tests() { + wait_for_server().await; + + let cases: Vec = serdeconv::from_json_file(golden_path()).unwrap(); + for case in cases { + golden_test(case.request, case.expected).await; + } +}