// tabby/crates/ctranslate2-bindings/ctranslate2/tests/decoding_test.cc

#include <ctranslate2/decoding.h>
#include "test_utils.h"
// Exercises DisableTokens on a 2x5 float storage: three add() calls followed
// by apply() must leave the storage equal to the expected one, where the
// selected entries hold the disable value (0 in this test).
TEST(DecodingTest, DisableTokens) {
  const std::vector<float> initial_values{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  // Same data with zeros at (row, column) = (0,1), (0,2), (1,2) and (1,4).
  const std::vector<float> wanted_values{1, 0, 0, 4, 5, 6, 7, 0, 9, 0};
  const float disable_value = 0;

  StorageView actual({2, 5}, initial_values);
  StorageView wanted({2, 5}, wanted_values);

  DisableTokens disabler(actual, disable_value);
  // Judging by the expected data, the single-argument form targets column 2
  // in every row, while the two-argument form targets one (row, column) cell.
  disabler.add(2);
  disabler.add(0, 1);
  disabler.add(1, 4);
  disabler.apply();

  expect_storage_eq(actual, wanted);
}