diff --git a/.gitmodules b/.gitmodules index c510600..e150ac4 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,3 @@ -[submodule "crates/ctranslate2-bindings/CTranslate2"] - path = crates/ctranslate2-bindings/CTranslate2 - url = https://github.com/OpenNMT/CTranslate2.git [submodule "crates/llama-cpp-bindings/llama.cpp"] path = crates/llama-cpp-bindings/llama.cpp url = https://github.com/TabbyML/llama.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 49c64f1..ae2666f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ * Switch cpu backend to llama.cpp: https://github.com/TabbyML/tabby/pull/638 * add `server.completion_timeout` to control the code completion interface timeout: https://github.com/TabbyML/tabby/pull/637 * Switch cuda backend to llama.cpp: https://github.com/TabbyML/tabby/pull/656 +* Switch tokenizer to llama.cpp, so tabby no longer need to download additional tokenizer file: https://github.com/TabbyML/tabby/pull/683 # v0.4.0 diff --git a/Cargo.lock b/Cargo.lock index 0009ec5..d7c95dc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,17 +17,6 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" -[[package]] -name = "aes" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "433cfd6710c9986c576a25ca913c39d66a6474107b406f34f91d4a8923395241" -dependencies = [ - "cfg-if", - "cipher", - "cpufeatures", -] - [[package]] name = "ahash" version = "0.8.3" @@ -39,15 +28,6 @@ dependencies = [ "version_check", ] -[[package]] -name = "aho-corasick" -version = "0.7.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac" -dependencies = [ - "memchr", -] - [[package]] name = "aho-corasick" version = "1.1.2" @@ -309,12 +289,6 @@ version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" -[[package]] -name = "base64ct" -version = "1.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" - [[package]] name = "bitflags" version = "1.3.2" @@ -373,27 +347,6 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" -[[package]] -name = "bzip2" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" -dependencies = [ - "bzip2-sys", - "libc", -] - -[[package]] -name = "bzip2-sys" -version = "0.1.11+1.0.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" -dependencies = [ - "cc", - "libc", - "pkg-config", -] - [[package]] name = "cached" version = "0.46.0" @@ -412,28 +365,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "cached-path" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "097968e38f1319207f057d0f4d76452e4f4f847a5de61c5215379f297fa034f3" -dependencies = [ - "flate2", - "fs2", - "glob", - "indicatif 0.16.2", - "log", - "rand", - "reqwest", - "serde", - "serde_json", - "sha2", - "tar", - "tempfile", - "thiserror", - "zip", -] - [[package]] name = "cached_proc_macro" version = "0.18.0" @@ -494,16 +425,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "cipher" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" -dependencies = [ - "crypto-common", - "inout", -] - [[package]] name = "clap" version = "4.3.0" @@ -584,12 +505,6 @@ dependencies = [ "windows-sys 0.45.0", ] -[[package]] -name = "constant_time_eq" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" - [[package]] name = "core-foundation" version = "0.9.3" @@ -694,24 +609,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "ctranslate2-bindings" -version = "0.5.0-dev" -dependencies = [ - "async-stream", - "async-trait", - "cmake", - "cxx", - "cxx-build", - "derive_builder", - "futures", - "rust-cxx-cmake-bridge", - "tabby-inference", - "tokenizers", - "tokio", - "tokio-util", -] - [[package]] name = "cxx" version = "1.0.95" @@ -887,7 +784,6 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", - "subtle", ] [[package]] @@ -967,15 +863,6 @@ dependencies = [ "backtrace", ] -[[package]] -name = "esaxx-rs" -version = "0.1.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f748b253ceca9fed5f42f8b5ceb3851e93102199bc25b64b65369f76e5c0a35" -dependencies = [ - "cc", -] - [[package]] name = "fastdivide" version = "0.4.0" @@ -1010,18 +897,6 @@ dependencies = [ "regex", ] -[[package]] -name = "filetime" -version = "0.2.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cbc844cecaee9d4443931972e1289c8ff485cb4cc2767cb03ca139ed6885153" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall 0.2.16", - "windows-sys 0.48.0", -] - [[package]] name = "fixedbitset" version = "0.4.2" @@ -1068,16 +943,6 @@ dependencies = [ "percent-encoding", ] -[[package]] -name = "fs2" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213" -dependencies = [ - "libc", - "winapi", -] - [[package]] name = "fs4" version = "0.6.6" @@ -1217,19 +1082,13 @@ version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ad0a93d233ebf96623465aad4046a8d3aa4da22d4f4beba5388838c8a434bbb4" -[[package]] -name = "glob" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" - [[package]] name = "globset" version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "759c97c1e17c55525b57192c06a267cda0ac5210b222d6b82189a2338fa1c13d" dependencies = [ - "aho-corasick 1.1.2", + "aho-corasick", "bstr", "fnv", "log", @@ -1292,15 +1151,6 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" -[[package]] -name = "hmac" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" -dependencies = [ - "digest", -] - [[package]] name = "htmlescape" version = "0.3.1" @@ -1476,30 +1326,6 @@ dependencies = [ "serde", ] -[[package]] -name = "indicatif" -version = "0.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7baab56125e25686df467fe470785512329883aab42696d661247aca2a2896e4" -dependencies = [ - "console", - "lazy_static", - "number_prefix 0.3.0", - "regex", -] - -[[package]] -name = "indicatif" -version = "0.16.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d207dc617c7a380ab07ff572a6e52fa202a2a8f355860ac9c38e23f8196be1b" -dependencies = [ - "console", - "lazy_static", - "number_prefix 0.4.0", - "regex", -] - [[package]] name = "indicatif" version = "0.17.3" @@ -1507,20 +1333,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cef509aa9bc73864d6756f0d34d35504af3cf0844373afe9b8669a5b8005a729" dependencies = [ "console", - "number_prefix 0.4.0", + "number_prefix", "portable-atomic 0.3.20", "unicode-width", ] -[[package]] -name = "inout" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" -dependencies = [ - "generic-array", -] - [[package]] name = "instant" version = "0.1.12" @@ -1562,24 +1379,6 @@ dependencies = [ "windows-sys 0.48.0", ] -[[package]] -name = "itertools" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f56a2d0bc861f9165be4eb3442afd3c236d8a98afd426f65d92324ae1091a484" -dependencies = [ - "either", -] - -[[package]] -name = "itertools" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "284f18f85651fe11e8a991b2adb42cb078325c996ed026d994719efcfca1d54b" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.10.5" @@ -1694,7 +1493,6 @@ dependencies = [ "derive_builder", "futures", "tabby-inference", - "tokenizers", "tokio", "tokio-util", ] @@ -1747,22 +1545,6 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3ea9b256699eda7b0387ffbc776dd625e28bde3918446381781245b7a50349d8" -[[package]] -name = "macro_rules_attribute" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf0c9b980bf4f3a37fd7b1c066941dd1b1d0152ce6ee6e8fe8c49b9f6810d862" -dependencies = [ - "macro_rules_attribute-proc_macro", - "paste", -] - -[[package]] -name = "macro_rules_attribute-proc_macro" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d" - [[package]] name = "matchers" version = "0.0.1" @@ -1890,27 +1672,6 @@ dependencies = [ "windows-sys 0.45.0", ] -[[package]] -name = "monostate" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0230b703f1ac35df1e24f6d0d2255472bcccaf657ecdfa4f1fcbcad1ad5bb98a" -dependencies = [ - "monostate-impl", - "serde", -] - -[[package]] -name = "monostate-impl" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8795add3e14028f11f8e848bd3294898a8294767b3776b6f733560d33bd2530b" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.28", -] - [[package]] name = "multimap" version = "0.8.3" @@ -2007,12 +1768,6 @@ dependencies = [ "libc", ] -[[package]] -name = "number_prefix" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b02fc0ff9a9e4b35b3342880f48e896ebf69f2967921fe8646bf5b7125956a" - [[package]] name = "number_prefix" version = "0.4.0" @@ -2066,28 +1821,6 @@ dependencies = [ "loom", ] -[[package]] -name = "onig" -version = "6.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c4b31c8722ad9171c6d77d3557db078cab2bd50afcc9d09c8b315c59df8ca4f" -dependencies = [ - "bitflags 1.3.2", - "libc", - "once_cell", - "onig_sys", -] - -[[package]] -name = "onig_sys" -version = "69.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b829e3d7e9cc74c7e315ee8edb185bf4190da5acde74afd7fc59c35b1f086e7" -dependencies = [ - "cc", - "pkg-config", -] - [[package]] name = "openssl" version = "0.10.52" @@ -2250,35 +1983,12 @@ dependencies = [ "windows-targets 0.48.0", ] -[[package]] -name = "password-hash" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700" -dependencies = [ - "base64ct", - "rand_core", - "subtle", -] - [[package]] name = "paste" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79" -[[package]] -name = "pbkdf2" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917" -dependencies = [ - "digest", - "hmac", - "password-hash", - "sha2", -] - [[package]] name = "percent-encoding" version = "2.2.0" @@ -2500,17 +2210,6 @@ dependencies = [ "rayon-core", ] -[[package]] -name = "rayon-cond" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd1259362c9065e5ea39a789ef40b1e3fd934c94beb7b5ab3ac6629d3b5e7cb7" -dependencies = [ - "either", - "itertools 0.8.2", - "rayon", -] - [[package]] name = "rayon-core" version = "1.11.0" @@ -2558,7 +2257,7 @@ version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d119d7c7ca818f8a53c300863d4f87566aac09943aef5b355bb83969dae75d87" dependencies = [ - "aho-corasick 1.1.2", + "aho-corasick", "memchr", "regex-automata 0.4.1", "regex-syntax 0.8.1", @@ -2579,7 +2278,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "465c6fc0621e4abc4187a2bda0937bfd4f722c2730b29562e19689ea796c9a4b" dependencies = [ - "aho-corasick 1.1.2", + "aho-corasick", "memchr", "regex-syntax 0.8.1", ] @@ -2590,12 +2289,6 @@ version = "0.6.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" -[[package]] -name = "regex-syntax" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" - [[package]] name = "regex-syntax" version = "0.8.1" @@ -2946,17 +2639,6 @@ dependencies = [ "trackable", ] -[[package]] -name = "sha1" -version = "0.10.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - [[package]] name = "sha2" version = "0.10.6" @@ -3029,18 +2711,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "spm_precompiled" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326" -dependencies = [ - "base64 0.13.1", - "nom 7.1.3", - "serde", - "unicode-segmentation", -] - [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -3093,12 +2763,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "subtle" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" - [[package]] name = "syn" version = "1.0.109" @@ -3215,7 +2879,7 @@ dependencies = [ "async-trait", "cached", "futures-util", - "indicatif 0.17.3", + "indicatif", "reqwest", "serde", "serde_json", @@ -3237,7 +2901,6 @@ dependencies = [ "futures", "regex", "tabby-common", - "tokenizers", ] [[package]] @@ -3273,7 +2936,7 @@ version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c1d4675fed6fe2218ce11445374e181e864a8ffd0f28e7e0591ccfc38cd000ae" dependencies = [ - "aho-corasick 1.1.2", + "aho-corasick", "arc-swap", "async-trait", "base64 0.21.2", @@ -3385,7 +3048,7 @@ checksum = "fc0c1bb43e5e8b8e05eb8009610344dbf285f06066c844032fbb3e546b3c71df" dependencies = [ "tantivy-common", "tantivy-fst", - "zstd 0.12.4", + "zstd", ] [[package]] @@ -3407,17 +3070,6 @@ dependencies = [ "serde", ] -[[package]] -name = "tar" -version = "0.4.38" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b55807c0344e1e6c04d7c965f5289c39a8d94ae23ed5c0b57aabac549f871c6" -dependencies = [ - "filetime", - "libc", - "xattr", -] - [[package]] name = "temp_testdir" version = "0.2.3" @@ -3538,42 +3190,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" -[[package]] -name = "tokenizers" -version = "0.13.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aea68938177975ab09da68552b720eac941779ff386baceaf77e0f5f9cea645f" -dependencies = [ - "aho-corasick 0.7.20", - "cached-path", - "clap", - "derive_builder", - "dirs", - "esaxx-rs", - "getrandom", - "indicatif 0.15.0", - "itertools 0.9.0", - "lazy_static", - "log", - "macro_rules_attribute", - "monostate", - "onig", - "paste", - "rand", - "rayon", - "rayon-cond", - "regex", - "regex-syntax 0.7.5", - "reqwest", - "serde", - "serde_json", - "spm_precompiled", - "thiserror", - "unicode-normalization-alignments", - "unicode-segmentation", - "unicode_categories", -] - [[package]] name = "tokio" version = "1.28.2" @@ -4084,33 +3700,12 @@ dependencies = [ "tinyvec", ] -[[package]] -name = "unicode-normalization-alignments" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" -dependencies = [ - "smallvec", -] - -[[package]] -name = "unicode-segmentation" -version = "1.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" - [[package]] name = "unicode-width" version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" -[[package]] -name = "unicode_categories" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" - [[package]] name = "url" version = "2.3.1" @@ -4590,42 +4185,16 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "xattr" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d1526bbe5aaeb5eb06885f4d987bcdfa5e23187055de9b83fe00156a821fabc" -dependencies = [ - "libc", -] - [[package]] name = "zip" version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" dependencies = [ - "aes", "byteorder", - "bzip2", - "constant_time_eq", "crc32fast", "crossbeam-utils", "flate2", - "hmac", - "pbkdf2", - "sha1", - "time 0.3.26", - "zstd 0.11.2+zstd.1.5.2", -] - -[[package]] -name = "zstd" -version = "0.11.2+zstd.1.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" -dependencies = [ - "zstd-safe 5.0.2+zstd.1.5.2", ] [[package]] @@ -4634,17 +4203,7 @@ version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a27595e173641171fc74a1232b7b1c7a7cb6e18222c11e9dfb9888fa424c53c" dependencies = [ - "zstd-safe 6.0.6", -] - -[[package]] -name = "zstd-safe" -version = "5.0.2+zstd.1.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db" -dependencies = [ - "libc", - "zstd-sys", + "zstd-safe", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index c0f91bc..8dbdbe6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,6 @@ members = [ "crates/tabby-scheduler", "crates/tabby-download", "crates/tabby-inference", - "crates/ctranslate2-bindings", "crates/rust-cxx-cmake-bridge", "crates/llama-cpp-bindings", "crates/http-api-bindings", @@ -33,7 +32,6 @@ tantivy = "0.21.0" async-trait = "0.1.72" reqwest = { version = "0.11.18" } derive_builder = "0.12.0" -tokenizers = "0.13.4-rc3" futures = "0.3.28" async-stream = "0.3.5" regex = "1.10.0" diff --git a/crates/ctranslate2-bindings/.gitignore b/crates/ctranslate2-bindings/.gitignore deleted file mode 100644 index ca58cab..0000000 --- a/crates/ctranslate2-bindings/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -/target -/Cargo.lock -/build diff --git a/crates/ctranslate2-bindings/CMakeLists.txt b/crates/ctranslate2-bindings/CMakeLists.txt deleted file mode 100644 index 344bd6d..0000000 --- a/crates/ctranslate2-bindings/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -cmake_minimum_required(VERSION 3.22) - -project(ctranslate2_bindings) - -add_subdirectory(CTranslate2) - -add_library(dummy - src/dummy.cc -) - -target_link_libraries(dummy - PRIVATE ctranslate2 -) - -include(cmake/export_libs.cmake) -export_all_target_libs(dummy) diff --git a/crates/ctranslate2-bindings/CTranslate2 b/crates/ctranslate2-bindings/CTranslate2 deleted file mode 160000 index 8bcbeb6..0000000 --- a/crates/ctranslate2-bindings/CTranslate2 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 8bcbeb6ff95b6906c9d5f7740fa9491431fa3e30 diff --git a/crates/ctranslate2-bindings/Cargo.toml b/crates/ctranslate2-bindings/Cargo.toml deleted file mode 100644 index 8f3e467..0000000 --- a/crates/ctranslate2-bindings/Cargo.toml +++ /dev/null @@ -1,25 +0,0 @@ -[package] -name = "ctranslate2-bindings" -version = "0.5.0-dev" -edition = "2021" - -[dependencies] -cxx = "1.0" -derive_builder = { workspace = true } -tokenizers = { workspace = true } -tokio = { workspace = true, features = ["rt"] } -tokio-util = { workspace = true } -tabby-inference = { path = "../tabby-inference" } -async-trait = { workspace = true } -futures.workspace = true -async-stream.workspace = true - -[build-dependencies] -cxx-build = "1.0" -cmake = { version = "0.1", optional = true } -rust-cxx-cmake-bridge = { path = "../rust-cxx-cmake-bridge", optional = true } - -[features] -default = ["dep:cmake", "dep:rust-cxx-cmake-bridge"] -link_shared = [] -link_static_cuda = [] diff --git a/crates/ctranslate2-bindings/build.rs b/crates/ctranslate2-bindings/build.rs deleted file mode 100644 index b805e10..0000000 --- a/crates/ctranslate2-bindings/build.rs +++ /dev/null @@ -1,74 +0,0 @@ -use std::{env, path::PathBuf}; - -use cmake::Config; -use rust_cxx_cmake_bridge::read_cmake_generated; - -fn main() { - // Tell cargo to invalidate the built crate whenever the wrapper changes - println!("cargo:rerun-if-changed=include/ctranslate2.h"); - println!("cargo:rerun-if-changed=src/ctranslate2.cc"); - println!("cargo:rerun-if-changed=src/lib.rs"); - - let mut lib = cxx_build::bridge("src/lib.rs"); - lib.file("src/ctranslate2.cc") - .flag_if_supported("-std=c++17"); - - if cfg!(feature = "link_shared") { - let dir = env::var("CTRANSLATE2_ROOT").unwrap(); - println!("cargo:rustc-link-search=native={}/lib", dir); - println!("cargo:rustc-link-lib=ctranslate2"); - lib.flag_if_supported(&format!("-I{}/include", dir)); - } else { - let dst = link_static(); - lib.flag_if_supported(&format!("-I{}", dst.join("include").display())); - } - - lib.compile("cxxbridge"); -} - -fn link_static() -> PathBuf { - let mut config = Config::new("."); - config - .define("CMAKE_BUILD_TYPE", "Release") - .define("BUILD_CLI", "OFF") - .define("CMAKE_INSTALL_RPATH_USE_LINK_PATH", "ON") - .define("BUILD_SHARED_LIBS", "OFF"); - - if cfg!(target_os = "linux") { - config - .define("WITH_MKL", "OFF") - .define("OPENMP_RUNTIME", "NONE"); - - if cfg!(target_feature = "sse4.1") { - config.cxxflag("-msse4.1"); - } - - if cfg!(feature = "link_static_cuda") { - config.define("WITH_CUDA", "ON").define("WITH_CUDNN", "ON"); - - if cfg!(target_arch = "aarch64") { - config.cxxflag("-mcpu=native"); - } - } else { - config.define("WITH_OPENBLAS", "ON"); - } - } else if cfg!(target_os = "macos") { - config - .define("CMAKE_OSX_ARCHITECTURES", "arm64") - .define("WITH_ACCELERATE", "ON") - .define("WITH_MKL", "OFF") - .define("OPENMP_RUNTIME", "NONE") - .define("WITH_RUY", "ON"); - } else { - panic!("Invalid target") - }; - - let dst = config.build(); - - // Read static lib from generated deps. - let cmake_generated_libs_str = - std::fs::read_to_string(format!("/{}/build/cmake_generated_libs", dst.display())).unwrap(); - read_cmake_generated(&cmake_generated_libs_str); - - dst -} diff --git a/crates/ctranslate2-bindings/cmake/debug_cmake.sh b/crates/ctranslate2-bindings/cmake/debug_cmake.sh deleted file mode 100755 index 0f4b5a8..0000000 --- a/crates/ctranslate2-bindings/cmake/debug_cmake.sh +++ /dev/null @@ -1,25 +0,0 @@ -#! /bin/bash - -set -e -set -x - -UNAME="$(uname -s)" -case "${UNAME}" in - Linux*) MACHINE=linux;; - Darwin*) MACHINE=macos;; - *) exit 1;; -esac - -rm -rf build -mkdir build && cd build - -if [[ "$MACHINE" == "macos" ]]; then -CMAKE_EXTRA_OPTIONS='-DCMAKE_OSX_ARCHITECTURES=arm64 -DWITH_ACCELERATE=ON -DWITH_MKL=OFF -DOPENMP_RUNTIME=NONE -DWITH_RUY=ON' -elif [[ "$MACHINE" == "linux" ]]; then -CMAKE_EXTRA_OPTIONS='-DWITH_CUDA=ON -DWITH_CUDNN=ON -DWITH_MKL=ON -DWITH_DNNL=ON -DOPENMP_RUNTIME=COMP -DCUDA_NVCC_FLAGS=-Xfatbin=-compress-all -DCUDA_ARCH_LIST=Common -DCXXFLAGS=-msse4.1' -fi - - -cmake -DBULID_SHARED_LIBS=OFF -DCMAKE_BUILD_TYPE=Release -DBUILD_CLI=OFF -DCMAKE_INSTALL_RPATH_USE_LINK_PATH=ON $CMAKE_EXTRA_OPTIONS .. - -"$@" diff --git a/crates/ctranslate2-bindings/cmake/export_libs.cmake b/crates/ctranslate2-bindings/cmake/export_libs.cmake deleted file mode 100644 index 9d30a01..0000000 --- a/crates/ctranslate2-bindings/cmake/export_libs.cmake +++ /dev/null @@ -1,98 +0,0 @@ -################################################################################ - -# WARNING: to list the system libraries(ie IMPORTED) you MUST set: -# set_target_properties(your_lib PROPERTIES IMPORTED_GLOBAL TRUE) -# just after the find_package call -# cf https://gitlab.kitware.com/cmake/cmake/-/issues/17256 -# -# https://stackoverflow.com/questions/32756195/recursive-list-of-link-libraries-in-cmake -# https://stackoverflow.com/questions/32197663/how-can-i-remove-the-the-location-property-may-not-be-read-from-target-error-i -function(_get_link_libraries OUTPUT_LIST TARGET) - list(APPEND VISITED_TARGETS ${TARGET}) - - # DO NOT switch on IMPORTED or not - # An INTERFACE library CAN have LINK_LIBRARIES! - # get_target_property(IMPORTED ${TARGET} IMPORTED) - set(LIBS "") - get_target_property(LIBS_1 ${TARGET} INTERFACE_LINK_LIBRARIES) - get_target_property(LIBS_2 ${TARGET} LINK_LIBRARIES) - list(APPEND LIBS ${LIBS_1} ${LIBS_2}) - - set(LIB_FILES "") - - foreach(LIB ${LIBS}) - if (TARGET ${LIB}) - list(FIND VISITED_TARGETS ${LIB} VISITED) - if (${VISITED} EQUAL -1) - # OLD: get_target_property(LIB_FILE ${LIB} LOCATION) - # NEW: - _get_link_libraries(LINK_LIB_FILES ${LIB}) - set(LIB_FILE ${LIB}) - list(APPEND LIB_FILES ${LINK_LIB_FILES}) - list(APPEND LIB_FILES ${LIB_FILE}) - endif() - elseif(EXISTS ${LIB}) - set(LIB_FILE ${LIB}) - list(APPEND LIB_FILES ${LIB_FILE}) - endif() - endforeach() - - set(VISITED_TARGETS ${VISITED_TARGETS} PARENT_SCOPE) - set(${OUTPUT_LIST} ${LIB_FILES} PARENT_SCOPE) -endfunction() - -################################################################################ - -function(export_all_target_libs TARGET) - # NOTE: get_target_property(CIRCUIT_LIB_LINK_LIBRARIES a_target LINK_LIBRARIES) is NOT transitive - # This function will return eg: "$;$;" - # b/c generator expression are evaluated LATER - # cf https://stackoverflow.com/questions/59226127/cmake-generator-expression-how-to-get-target-file-property-on-list-of-targets - set(ALL_LINK_LIBRARIES "") - _get_link_libraries(ALL_LINK_LIBRARIES ${TARGET}) - - message(STATUS "ALL_LINK_LIBRARIES : ${ALL_LINK_LIBRARIES}") - - set(ALL_LIBS "") - set(ALL_EXTERNAL_LIBS "") - # TODO move that back into get_link_libraries - # NOTE: we MUST do it in 2 steps: - # - collect all the LINK_LIBRARIES recursively - # - loop on those and get their TARGET_FILE (if not INTERFACE_LIBRARY) - # That is b/c in get_link_libraries a INTERFACE_LIBRARY CAN have link_libraries - # but we CAN NOT evaluate generator expressions at this time. - foreach(LIB ${ALL_LINK_LIBRARIES}) - # MUST skip INTERFACE else: - # CMake Error at src/CMakeLists.txt:136 (add_custom_command): - # Error evaluating generator expression: - # $ - # Target "rust_cxx" is not an executable or library. - # SHARED_LIBRARY,INTERFACE_LIBRARY,STATIC_LIBRARY - # - if (TARGET ${LIB}) - get_target_property(LIB_TYPE ${LIB} TYPE) - message(STATUS "LIB_TYPE : ${LIB} = ${LIB_TYPE}") - - if(NOT ${LIB_TYPE} STREQUAL "INTERFACE_LIBRARY") - set(LIB_FILE $) - list(APPEND ALL_LIBS ${LIB_FILE}) - endif() - elseif(EXISTS ${LIB}) - set(LIB_FILE ${LIB}) - message(STATUS "LIB_TYPE : ${LIB} = EXTERNAL") - list(APPEND ALL_LIBS ${LIB_FILE}) - endif() - endforeach() # LIB ${ALL_LIBS} - - message(STATUS "ALL_LIBS : ${ALL_LIBS}") - - # add_custom_command(ie echoing only to stdout) works but more difficult to get from build.rs - # b/c when there is "ninja: no work to do" it will NOT echo on the console - add_custom_command( - TARGET ${TARGET} - POST_BUILD - COMMAND ${CMAKE_COMMAND} -E echo ${ALL_LIBS} > ${CMAKE_CURRENT_BINARY_DIR}/cmake_generated_libs - # OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/cmake_generated_libs - VERBATIM - ) -endfunction(export_all_target_libs) diff --git a/crates/ctranslate2-bindings/include/ctranslate2.h b/crates/ctranslate2-bindings/include/ctranslate2.h deleted file mode 100644 index 473118a..0000000 --- a/crates/ctranslate2-bindings/include/ctranslate2.h +++ /dev/null @@ -1,30 +0,0 @@ -#pragma once - -#include "rust/cxx.h" -#include - -namespace tabby { - -struct InferenceContext; - -typedef rust::Fn InferenceCallback; - -class TextInferenceEngine { - public: - virtual ~TextInferenceEngine(); - virtual rust::Vec inference( - rust::Box context, - InferenceCallback callback, - rust::Slice tokens, - size_t max_decoding_length, - float sampling_temperature - ) const = 0; -}; - -std::shared_ptr create_engine( - rust::Str model_path, - rust::Str model_type, - rust::Str device, - rust::Slice device_indices -); -} // namespace diff --git a/crates/ctranslate2-bindings/src/ctranslate2.cc b/crates/ctranslate2-bindings/src/ctranslate2.cc deleted file mode 100644 index b9675a8..0000000 --- a/crates/ctranslate2-bindings/src/ctranslate2.cc +++ /dev/null @@ -1,135 +0,0 @@ -#include "ctranslate2-bindings/include/ctranslate2.h" - -#include "ctranslate2/translator.h" -#include "ctranslate2/generator.h" - -namespace tabby { -TextInferenceEngine::~TextInferenceEngine() {} - -template -class TextInferenceEngineImpl : public TextInferenceEngine { - protected: - struct Options { - size_t max_decoding_length; - float sampling_temperature; - }; - - public: - rust::Vec inference( - rust::Box context, - InferenceCallback callback, - rust::Slice tokens, - size_t max_decoding_length, - float sampling_temperature - ) const { - // Inference. - std::vector input_tokens(tokens.begin(), tokens.end()); - return process( - std::move(context), - std::move(callback), - input_tokens, - Options{max_decoding_length, sampling_temperature} - ); - } - - static std::unique_ptr create(const ctranslate2::models::ModelLoader& loader) { - auto impl = std::make_unique(); - impl->model_ = std::make_unique(loader); - return impl; - } - - protected: - virtual rust::Vec process( - rust::Box context, - InferenceCallback callback, - const std::vector& tokens, - const Options& options) const = 0; - std::unique_ptr model_; -}; - -class EncoderDecoderImpl : public TextInferenceEngineImpl { - protected: - virtual rust::Vec process( - rust::Box context, - InferenceCallback callback, - const std::vector& tokens, - const Options& options) const override { - ctranslate2::TranslationOptions x; - x.max_decoding_length = options.max_decoding_length; - x.sampling_temperature = options.sampling_temperature; - x.beam_size = 1; - rust::Vec output_ids; - x.callback = [&](ctranslate2::GenerationStepResult result) { - bool stop = callback(*context, result.step, result.token_id, result.token); - if (!stop) { - output_ids.push_back(result.token_id); - } else if (result.is_last) { - output_ids.push_back(result.token_id); - } - return stop; - }; - ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0]; - return output_ids; - } -}; - -class DecoderImpl : public TextInferenceEngineImpl { - protected: - virtual rust::Vec process( - rust::Box context, - InferenceCallback callback, - const std::vector& tokens, - const Options& options) const override { - ctranslate2::GenerationOptions x; - x.include_prompt_in_result = false; - x.max_length = options.max_decoding_length; - x.sampling_temperature = options.sampling_temperature; - x.beam_size = 1; - - rust::Vec output_ids; - x.callback = [&](ctranslate2::GenerationStepResult result) { - bool stop = callback(*context, result.step, result.token_id, result.token); - if (!stop) { - output_ids.push_back(result.token_id); - } else if (result.is_last) { - output_ids.push_back(result.token_id); - } - return stop; - }; - ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get(); - return output_ids; - } -}; - -std::shared_ptr create_engine( - rust::Str model_path, - rust::Str model_type, - rust::Str device, - rust::Slice device_indices -) { - std::string model_type_str(model_type); - std::string model_path_str(model_path); - ctranslate2::models::ModelLoader loader(model_path_str); - loader.device = ctranslate2::str_to_device(std::string(device)); - loader.device_indices = std::vector(device_indices.begin(), device_indices.end()); - loader.compute_type = ctranslate2::ComputeType::AUTO; - - const size_t num_cpus = std::thread::hardware_concurrency(); - if (loader.device == ctranslate2::Device::CUDA) { - // When device is cuda, set parallelism to be number of thread, capped to 4 to avoid VRAM oom. - loader.num_replicas_per_device = std::min(num_cpus, 4); - } else if (loader.device == ctranslate2::Device::CPU){ - // When device is cpu, adjust the number based on threads per replica. - // https://github.com/OpenNMT/CTranslate2/blob/master/src/utils.cc#L77 - loader.num_replicas_per_device = std::max(num_cpus / 4, 1); - } - - if (model_type_str == "AutoModelForCausalLM") { - return DecoderImpl::create(loader); - } else if (model_type_str == "AutoModelForSeq2SeqLM") { - return EncoderDecoderImpl::create(loader); - } else { - return nullptr; - } -} -} // namespace tabby diff --git a/crates/ctranslate2-bindings/src/dummy.cc b/crates/ctranslate2-bindings/src/dummy.cc deleted file mode 100644 index e69de29..0000000 diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs deleted file mode 100644 index 8b8d7ed..0000000 --- a/crates/ctranslate2-bindings/src/lib.rs +++ /dev/null @@ -1,178 +0,0 @@ -use std::sync::Arc; - -use async_stream::stream; -use async_trait::async_trait; -use derive_builder::Builder; -use futures::stream::BoxStream; -use tabby_inference::{ - decoding::{DecodingFactory, IncrementalDecoding}, - helpers, TextGeneration, TextGenerationOptions, -}; -use tokenizers::tokenizer::Tokenizer; -use tokio::sync::mpsc::{channel, Sender}; -use tokio_util::sync::CancellationToken; - -#[cxx::bridge(namespace = "tabby")] -mod ffi { - extern "Rust" { - type InferenceContext; - } - - unsafe extern "C++" { - include!("ctranslate2-bindings/include/ctranslate2.h"); - - type TextInferenceEngine; - - fn create_engine( - model_path: &str, - model_type: &str, - device: &str, - device_indices: &[i32], - ) -> SharedPtr; - - fn inference( - &self, - context: Box, - callback: fn( - &mut InferenceContext, - // step - usize, - // token_id - u32, - // token - String, - ) -> bool, - tokens: &[String], - max_decoding_length: usize, - sampling_temperature: f32, - ) -> Vec; - } -} - -unsafe impl Send for ffi::TextInferenceEngine {} -unsafe impl Sync for ffi::TextInferenceEngine {} - -#[derive(Builder, Debug)] -pub struct CTranslate2EngineOptions { - model_path: String, - - model_type: String, - - tokenizer_path: String, - - device: String, - - device_indices: Vec, -} - -pub struct InferenceContext { - sender: Sender, - decoding: IncrementalDecoding, - cancel: CancellationToken, -} - -impl InferenceContext { - fn new( - sender: Sender, - decoding: IncrementalDecoding, - cancel: CancellationToken, - ) -> Self { - InferenceContext { - sender, - decoding, - cancel, - } - } -} - -pub struct CTranslate2Engine { - engine: cxx::SharedPtr, - decoding_factory: DecodingFactory, - tokenizer: Arc, -} - -impl CTranslate2Engine { - pub fn create(options: CTranslate2EngineOptions) -> Self where { - let engine = ffi::create_engine( - &options.model_path, - &options.model_type, - &options.device, - &options.device_indices, - ); - - return Self { - engine, - decoding_factory: DecodingFactory::default(), - tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()), - }; - } -} - -#[async_trait] -impl TextGeneration for CTranslate2Engine { - async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { - let s = self.generate_stream(prompt, options).await; - helpers::stream_to_string(s).await - } - - async fn generate_stream( - &self, - prompt: &str, - options: TextGenerationOptions, - ) -> BoxStream { - let encoding = self.tokenizer.encode(prompt, true).unwrap(); - let decoding = self.decoding_factory.create_incremental_decoding( - self.tokenizer.clone(), - truncate_tokens(encoding.get_ids(), options.max_input_length), - options.language, - ); - - let cancel = CancellationToken::new(); - let engine = self.engine.clone(); - let (sender, mut receiver) = channel::(8); - let context = InferenceContext::new(sender, decoding, cancel.clone()); - tokio::task::spawn_blocking(move || { - let context = Box::new(context); - engine.inference( - context, - inference_callback, - truncate_tokens(encoding.get_tokens(), options.max_input_length), - options.max_decoding_length, - options.sampling_temperature, - ); - }); - - let s = stream! { - let _guard = cancel.drop_guard(); - while let Some(text) = receiver.recv().await { - yield text; - } - }; - Box::pin(s) - } -} - -fn truncate_tokens(tokens: &[T], max_length: usize) -> &[T] { - if max_length < tokens.len() { - let start = tokens.len() - max_length; - &tokens[start..] - } else { - tokens - } -} - -fn inference_callback( - context: &mut InferenceContext, - _step: usize, - token_id: u32, - _token: String, -) -> bool { - if context.cancel.is_cancelled() { - true - } else if let Some(new_text) = context.decoding.next_token(token_id) { - let _ = context.sender.blocking_send(new_text); - false - } else { - true - } -} diff --git a/crates/llama-cpp-bindings/Cargo.toml b/crates/llama-cpp-bindings/Cargo.toml index d1afeb0..1c706d4 100644 --- a/crates/llama-cpp-bindings/Cargo.toml +++ b/crates/llama-cpp-bindings/Cargo.toml @@ -16,7 +16,6 @@ async-trait = { workspace = true } tokio = { workspace = true, features = ["rt"] } tabby-inference = { path = "../tabby-inference" } derive_builder = { workspace = true } -tokenizers = { workspace = true } tokio-util = { workspace = true } futures.workspace = true async-stream.workspace = true diff --git a/crates/llama-cpp-bindings/include/engine.h b/crates/llama-cpp-bindings/include/engine.h index e6a9c4d..ae75c50 100644 --- a/crates/llama-cpp-bindings/include/engine.h +++ b/crates/llama-cpp-bindings/include/engine.h @@ -4,16 +4,15 @@ #include namespace llama { +struct StepOutput; class TextInferenceEngine { public: virtual ~TextInferenceEngine(); - virtual void add_request(uint32_t request_id, rust::Slice input_token_ids) = 0; + virtual void add_request(uint32_t request_id, rust::Str text, size_t max_input_length) = 0; virtual void stop_request(uint32_t request_id) = 0; - virtual rust::Vec step() = 0; - - virtual uint32_t eos_token_id() const = 0; + virtual rust::Vec step() = 0; }; std::unique_ptr create_engine(bool use_gpu, rust::Str model_path); diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc index 375553a..5a1f9f7 100644 --- a/crates/llama-cpp-bindings/src/engine.cc +++ b/crates/llama-cpp-bindings/src/engine.cc @@ -8,6 +8,8 @@ #include #include +#include "llama-cpp-bindings/src/lib.rs.h" + namespace llama { TextInferenceEngine::~TextInferenceEngine() {} @@ -27,20 +29,56 @@ constexpr size_t N_BATCH = 512; // # per batch inference. constexpr size_t N_CTX = 4096; // # max kv history. struct Request { - Request(size_t request_id, rust::Slice input_token_ids) : + Request(size_t request_id, std::vector input_token_ids) : id(request_id), tokens(input_token_ids.begin(), input_token_ids.end()) { } - size_t id = -1; + uint32_t id = -1; llama_seq_id seq_id = -1; std::vector tokens; size_t i_batch = -1; size_t n_past = 0; + + int32_t multibyte_pending = 0; + std::string generated_text; }; +std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) { + std::vector result(8, 0); + const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); + if (n_tokens < 0) { + result.resize(-n_tokens); + int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); + GGML_ASSERT(check == -n_tokens); + } else { + result.resize(n_tokens); + } + + return std::string(result.data(), result.size()); +} + +std::vector llama_tokenize( + const struct llama_model * model, + const rust::Str & text, + bool add_bos, + bool special) { + // upper limit for the number of tokens + int n_tokens = text.length() + add_bos; + std::vector result(n_tokens); + n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special); + if (n_tokens < 0) { + result.resize(-n_tokens); + int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special); + GGML_ASSERT(check == -n_tokens); + } else { + result.resize(n_tokens); + } + return result; +} + template using owned = std::unique_ptr>; @@ -56,15 +94,20 @@ class TextInferenceEngineImpl : public TextInferenceEngine { llama_batch_free(batch_); } - void add_request(uint32_t request_id, rust::Slice input_token_ids) override { - pending_requests_.push_back(Request(request_id, input_token_ids)); + virtual void add_request(uint32_t request_id, rust::Str text, size_t max_input_length) override { + auto tokens = llama_tokenize(llama_get_model(ctx_.get()), text, false, true); + if (tokens.size() > max_input_length) { + int start = tokens.size() - max_input_length; + tokens = std::vector(tokens.begin() + start, tokens.end()); + } + pending_requests_.push_back(Request(request_id, tokens)); } void stop_request(uint32_t request_id) override { stopped_requests_.insert(request_id); } - rust::Vec step() override { + rust::Vec step() override { auto* ctx = ctx_.get(); auto n_vocab = llama_n_vocab(llama_get_model(ctx)); @@ -123,28 +166,29 @@ class TextInferenceEngineImpl : public TextInferenceEngine { request.i_batch = batch_.n_tokens - 1; } - rust::Vec result; - result.reserve(requests_.size() * 2); + rust::Vec result; + result.reserve(requests_.size()); // Decode tokens in chunks for (size_t i = 0; i < static_cast(batch_.n_tokens); i += N_BATCH) { const int32_t n_tokens = std::min(N_BATCH, batch_.n_tokens - i); - llama_batch batch_view = { - n_tokens, - batch_.token + i, - nullptr, - batch_.pos + i, - batch_.n_seq_id + i, - batch_.seq_id + i, - batch_.logits + i, - 0, 0, 0, // unused - }; + llama_batch batch_view = { + n_tokens, + batch_.token + i, + nullptr, + batch_.pos + i, + batch_.n_seq_id + i, + batch_.seq_id + i, + batch_.logits + i, + 0, 0, 0, // unused + }; - const int ret = llama_decode(ctx, batch_view); + const int ret = llama_decode(ctx, batch_view); if (ret != 0) { throw std::runtime_error("Failed to eval"); } + const auto eos_id = llama_token_eos(llama_get_model(ctx)); for (auto& request : requests_) { if ((request.i_batch < i) || (request.i_batch >= (i + n_tokens))) { continue; @@ -159,18 +203,44 @@ class TextInferenceEngineImpl : public TextInferenceEngine { request.tokens.clear(); request.tokens.push_back(next_token); - result.push_back(request.id); - result.push_back(next_token); + const auto token_str = llama_token_to_piece(ctx, next_token); + request.generated_text += token_str; + + // FIXME: Hack for codellama to simplify tabby's implementation. + const bool is_eos = next_token == eos_id || token_str == " "; + + if (request.multibyte_pending > 0) { + request.multibyte_pending -= token_str.size(); + } else if (token_str.size() == 1) { + const char c = token_str[0]; + // 2-byte characters: 110xxxxx 10xxxxxx + if ((c & 0xE0) == 0xC0) { + request.multibyte_pending = 1; + // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx + } + else if ((c & 0xF0) == 0xE0) { + request.multibyte_pending = 2; + // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + } else if ((c & 0xF8) == 0xF0) { + request.multibyte_pending = 3; + } + else { + request.multibyte_pending = 0; + } + } + + if (request.multibyte_pending == 0) { + rust::String generated_text = is_eos ? "" : request.generated_text; + result.push_back({request.id, generated_text}); + + request.generated_text.clear(); + } } } return result; } - uint32_t eos_token_id() const override { - return llama_token_eos(llama_get_model(ctx_.get())); - } - private: owned model_; owned ctx_; diff --git a/crates/llama-cpp-bindings/src/lib.rs b/crates/llama-cpp-bindings/src/lib.rs index 1e0ad2a..1c96016 100644 --- a/crates/llama-cpp-bindings/src/lib.rs +++ b/crates/llama-cpp-bindings/src/lib.rs @@ -7,10 +7,9 @@ use derive_builder::Builder; use ffi::create_engine; use futures::{lock::Mutex, stream::BoxStream}; use tabby_inference::{ - decoding::{DecodingFactory, IncrementalDecoding}, + decoding::{StopCondition, StopConditionFactory}, helpers, TextGeneration, TextGenerationOptions, }; -use tokenizers::tokenizer::Tokenizer; use tokio::{ sync::mpsc::{channel, Sender}, task::yield_now, @@ -18,6 +17,11 @@ use tokio::{ #[cxx::bridge(namespace = "llama")] mod ffi { + struct StepOutput { + request_id: u32, + text: String, + } + unsafe extern "C++" { include!("llama-cpp-bindings/include/engine.h"); @@ -28,12 +32,11 @@ mod ffi { fn add_request( self: Pin<&mut TextInferenceEngine>, request_id: u32, - input_token_ids: &[u32], + prompt: &str, + max_input_length: usize, ); fn stop_request(self: Pin<&mut TextInferenceEngine>, request_id: u32); - fn step(self: Pin<&mut TextInferenceEngine>) -> Result>; - - fn eos_token_id(&self) -> u32; + fn step(self: Pin<&mut TextInferenceEngine>) -> Result>; } } @@ -42,26 +45,22 @@ unsafe impl Sync for ffi::TextInferenceEngine {} struct InferenceRequest { tx: Sender, - decoding: IncrementalDecoding, + stop_condition: StopCondition, } struct AsyncTextInferenceEngine { engine: Mutex>, - tokenizer: Arc, - decoding_factory: DecodingFactory, + stop_condition_factory: StopConditionFactory, requests: Mutex>, next_request_id: Mutex, - eos_token_id: u32, } impl AsyncTextInferenceEngine { - fn create(engine: UniquePtr, tokenizer: Tokenizer) -> Self { + fn create(engine: UniquePtr) -> Self { Self { - eos_token_id: engine.eos_token_id(), engine: Mutex::new(engine), - tokenizer: Arc::new(tokenizer), - decoding_factory: DecodingFactory::default(), + stop_condition_factory: StopConditionFactory::default(), requests: Mutex::new(HashMap::new()), next_request_id: Mutex::new(0), } @@ -79,18 +78,15 @@ impl AsyncTextInferenceEngine { panic!("Failed to evaluation"); }; - for i in (0..result.len()).step_by(2) { - let request_id = result[i]; - let token_id = result[i + 1]; - - let InferenceRequest { tx, decoding } = requests.get_mut(&request_id).unwrap(); + for ffi::StepOutput { request_id, text } in result { let mut stopped = false; + let InferenceRequest { tx, stop_condition } = requests.get_mut(&request_id).unwrap(); - if tx.is_closed() || token_id == self.eos_token_id { + if tx.is_closed() || text.is_empty() { // Cancelled by client side or hit eos. stopped = true; - } else if let Some(new_text) = decoding.next_token(token_id) { - match tx.send(new_text).await { + } else if !stop_condition.should_stop(&text) { + match tx.send(text).await { Ok(_) => (), Err(_) => stopped = true, } @@ -111,25 +107,21 @@ impl AsyncTextInferenceEngine { prompt: &str, options: TextGenerationOptions, ) -> BoxStream { - let encoding = self.tokenizer.encode(prompt, true).unwrap(); - let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length); - let decoding = self.decoding_factory.create_incremental_decoding( - self.tokenizer.clone(), - input_token_ids, - options.language, - ); + let stop_condition = self.stop_condition_factory.create(prompt, options.language); let (tx, mut rx) = channel::(4); { let mut engine = self.engine.lock().await; - let engine = engine.as_mut().unwrap(); let mut request_id = self.next_request_id.lock().await; self.requests .lock() .await - .insert(*request_id, InferenceRequest { tx, decoding }); - engine.add_request(*request_id, input_token_ids); + .insert(*request_id, InferenceRequest { tx, stop_condition }); + engine + .as_mut() + .unwrap() + .add_request(*request_id, prompt, options.max_input_length); // 2048 should be large enough to avoid collision. *request_id = (*request_id + 1) % 2048; @@ -155,7 +147,6 @@ impl AsyncTextInferenceEngine { #[derive(Builder, Debug)] pub struct LlamaTextGenerationOptions { model_path: String, - tokenizer_path: String, use_gpu: bool, } @@ -169,9 +160,8 @@ impl LlamaTextGeneration { if engine.is_null() { panic!("Unable to load model: {}", options.model_path); } - let tokenizer = Tokenizer::from_file(&options.tokenizer_path).unwrap(); let ret = LlamaTextGeneration { - engine: Arc::new(AsyncTextInferenceEngine::create(engine, tokenizer)), + engine: Arc::new(AsyncTextInferenceEngine::create(engine)), }; ret.start_background_job(); ret @@ -203,12 +193,3 @@ impl TextGeneration for LlamaTextGeneration { self.engine.generate_stream(prompt, options).await } } - -fn truncate_tokens(tokens: &[u32], max_length: usize) -> &[u32] { - if max_length < tokens.len() { - let start = tokens.len() - max_length; - &tokens[start..] - } else { - tokens - } -} diff --git a/crates/tabby-common/src/path.rs b/crates/tabby-common/src/path.rs index 17717a4..4a14392 100644 --- a/crates/tabby-common/src/path.rs +++ b/crates/tabby-common/src/path.rs @@ -78,14 +78,6 @@ impl ModelDir { self.path_string("tabby.json") } - pub fn tokenizer_file(&self) -> String { - self.path_string("tokenizer.json") - } - - pub fn ctranslate2_dir(&self) -> String { - self.path_string("ctranslate2") - } - pub fn ggml_q8_0_file(&self) -> String { self.path_string("ggml/q8_0.gguf") } diff --git a/crates/tabby-download/src/lib.rs b/crates/tabby-download/src/lib.rs index 16cf31a..99ebc3b 100644 --- a/crates/tabby-download/src/lib.rs +++ b/crates/tabby-download/src/lib.rs @@ -29,27 +29,8 @@ impl Downloader { } } - pub async fn download_ctranslate2_files(&self) -> Result<()> { - let files = vec![ - ("tabby.json", true), - ("tokenizer.json", true), - ("ctranslate2/vocabulary.txt", false), - ("ctranslate2/shared_vocabulary.txt", false), - ("ctranslate2/vocabulary.json", false), - ("ctranslate2/shared_vocabulary.json", false), - ("ctranslate2/config.json", true), - ("ctranslate2/model.bin", true), - ]; - - self.download_files(&files).await - } - pub async fn download_ggml_files(&self) -> Result<()> { - let files = vec![ - ("tabby.json", true), - ("tokenizer.json", true), - ("ggml/q8_0.v2.gguf", true), - ]; + let files = vec![("tabby.json", true), ("ggml/q8_0.v2.gguf", true)]; self.download_files(&files).await } diff --git a/crates/tabby-inference/Cargo.toml b/crates/tabby-inference/Cargo.toml index f0cd8ce..7df2574 100644 --- a/crates/tabby-inference/Cargo.toml +++ b/crates/tabby-inference/Cargo.toml @@ -12,5 +12,4 @@ dashmap = "5.5.3" derive_builder = "0.12.0" futures = { workspace = true } regex.workspace = true -tokenizers.workspace = true tabby-common = { path = "../tabby-common" } diff --git a/crates/tabby-inference/src/decoding.rs b/crates/tabby-inference/src/decoding.rs index 158fab8..5bf202a 100644 --- a/crates/tabby-inference/src/decoding.rs +++ b/crates/tabby-inference/src/decoding.rs @@ -1,11 +1,8 @@ -use std::sync::Arc; - use dashmap::DashMap; use regex::Regex; use tabby_common::languages::Language; -use tokenizers::tokenizer::Tokenizer; -pub struct DecodingFactory { +pub struct StopConditionFactory { stop_regex_cache: DashMap, } @@ -16,7 +13,7 @@ where s.into().chars().rev().collect() } -impl Default for DecodingFactory { +impl Default for StopConditionFactory { fn default() -> Self { Self { stop_regex_cache: DashMap::new(), @@ -24,14 +21,9 @@ impl Default for DecodingFactory { } } -impl DecodingFactory { - pub fn create_incremental_decoding( - &self, - tokenizer: Arc, - input_token_ids: &[u32], - language: &'static Language, - ) -> IncrementalDecoding { - IncrementalDecoding::new(tokenizer, self.get_re(language), input_token_ids) +impl StopConditionFactory { + pub fn create(&self, text: &str, language: &'static Language) -> StopCondition { + StopCondition::new(self.get_re(language), text) } fn get_re(&self, language: &'static Language) -> Option { @@ -62,68 +54,31 @@ fn create_stop_regex(stop_words: Vec) -> Regex { Regex::new(®ex_string).expect("Failed to create regex") } -pub struct IncrementalDecoding { - tokenizer: Arc, +pub struct StopCondition { stop_re: Option, - - token_ids: Vec, - prefix_offset: usize, - read_offset: usize, - reversed_text: String, } -impl IncrementalDecoding { - pub fn new(tokenizer: Arc, stop_re: Option, input_token_ids: &[u32]) -> Self { - let text = tokenizer - .decode(input_token_ids, /* skip_special_token = */ true) - .expect("Cannot decode token from tokenizer."); +impl StopCondition { + pub fn new(stop_re: Option, text: &str) -> Self { Self { - tokenizer, stop_re, - token_ids: input_token_ids.to_owned(), - prefix_offset: 0, - read_offset: input_token_ids.len(), reversed_text: reverse(text), } } - pub fn next_token(&mut self, token_id: u32) -> Option { - let skip_special_token = true; - self.token_ids.push(token_id); - - let prefix_text = self - .tokenizer - .decode( - &self.token_ids[self.prefix_offset..self.read_offset], - skip_special_token, - ) - .expect("Cannot decode token from tokenizer."); - - let new_text = self - .tokenizer - .decode(&self.token_ids[self.prefix_offset..], skip_special_token) - .expect("Cannot decode token from tokenizer."); - - let new_text = if new_text.len() > prefix_text.len() && !new_text.ends_with('�') { - self.prefix_offset = self.read_offset; - self.read_offset = self.token_ids.len(); - &new_text[prefix_text.len()..] - } else { - "" - }; - + pub fn should_stop(&mut self, new_text: &str) -> bool { if !new_text.is_empty() { self.reversed_text = reverse(new_text) + &self.reversed_text; if let Some(re) = &self.stop_re { if re.is_match(&self.reversed_text) { - return None; + return true; } } } - Some(new_text.to_owned()) + false } } diff --git a/crates/tabby/src/serve/engine.rs b/crates/tabby/src/serve/engine.rs index 0b89ea5..eff29f0 100644 --- a/crates/tabby/src/serve/engine.rs +++ b/crates/tabby/src/serve/engine.rs @@ -41,7 +41,6 @@ pub struct EngineInfo { fn create_ggml_engine(device: &super::Device, model_dir: &ModelDir) -> Box { let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default() .model_path(model_dir.ggml_q8_0_v2_file()) - .tokenizer_path(model_dir.tokenizer_file()) .use_gpu(device.ggml_use_gpu()) .build() .unwrap();