tabby/crates/ctranslate2-bindings/ctranslate2/examples/wngt2020/main.cc

76 lines
2.7 KiB
C++

#include <iostream>
#include <regex>
#include <stdexcept>
#include <string>
#include <vector>
#include <sentencepiece_processor.h>
#include <ctranslate2/translator_pool.h>
#include <ctranslate2/models/sequence_to_sequence.h>
// Collects every token of `vocabulary` into a vector, ordered by token index
// (element i of the result is vocabulary.to_token(i)).
static std::vector<std::string> get_vocabulary_tokens(const ctranslate2::Vocabulary& vocabulary) {
  const size_t token_count = vocabulary.size();

  std::vector<std::string> result;
  result.reserve(token_count);  // one allocation up front, no regrowth while filling

  size_t index = 0;
  while (index < token_count) {
    result.emplace_back(vocabulary.to_token(index));
    ++index;
  }

  return result;
}
// Translates a raw text file with a CTranslate2 sequence-to-sequence model,
// using SentencePiece for tokenization and detokenization.
//
// Usage: <program> <in_file> <out_file> <num_cores>
//   in_file    path of the raw text file to translate
//   out_file   path of the file receiving the detokenized output
//   num_cores  integer used to size the CPU replica pool (unused when the
//              resolved device is CUDA)
//
// The model directory is fixed to "/model" and the SentencePiece model is
// expected at "/model/sp.model".
//
// Returns 0 on success and 1 when too few arguments are given.
// Throws std::invalid_argument / std::out_of_range (from std::stoi) when
// num_cores is not a valid integer, std::invalid_argument when the
// SentencePiece model cannot be loaded, and std::runtime_error when the loaded
// model is not a sequence-to-sequence model or the SentencePiece vocabulary
// cannot be set.
int main(int argc, char* argv[]) {
  // Reading argv[1..3] without this check is undefined behaviour when the
  // arguments are missing.
  if (argc < 4) {
    const char* program = (argc > 0 && argv[0] != nullptr) ? argv[0] : "main";
    std::cerr << "Usage: " << program << " <in_file> <out_file> <num_cores>\n";
    return 1;
  }

  const std::string in_file = argv[1];
  const std::string out_file = argv[2];
  const int num_cores = std::stoi(argv[3]);

  const std::string model_path = "/model";
  const std::string sp_model_path = model_path + "/sp.model";

  // "auto" leaves the device choice to the library.
  const ctranslate2::Device device = ctranslate2::str_to_device("auto");
  ctranslate2::ComputeType compute_type = ctranslate2::ComputeType::INT8;

  const size_t num_threads_per_replica = 1;
  size_t num_replicas = 0;
  size_t max_batch_size = 0;

  if (device == ctranslate2::Device::CUDA) {
    num_replicas = 1;
    max_batch_size = 6000;
    compute_type = ctranslate2::ComputeType::FLOAT16;
  } else if (num_cores == 1) {
    num_replicas = 1;
    max_batch_size = 512;
  } else {
    // One replica per two cores. Guard against num_cores <= 0, which would
    // otherwise yield zero replicas or a negative value wrapping to a huge
    // size_t; fall back to a single replica in that case.
    const int half_cores = num_cores / 2;
    num_replicas = half_cores > 0 ? static_cast<size_t>(half_cores) : 1;
    max_batch_size = 256;
  }

  const auto model = ctranslate2::models::Model::load(model_path, device, 0, compute_type);
  ctranslate2::TranslatorPool pool(num_replicas, num_threads_per_replica, model);

  sentencepiece::SentencePieceProcessor sp_processor;
  auto status = sp_processor.Load(sp_model_path);
  if (!status.ok())
    throw std::invalid_argument("Unable to open SentencePiece model " + sp_model_path);

  // The source vocabulary is only exposed by sequence-to-sequence models, so
  // the downcast must be verified before it is dereferenced.
  const auto* seq2seq_model =
      dynamic_cast<const ctranslate2::models::SequenceToSequenceModel*>(model.get());
  if (seq2seq_model == nullptr)
    throw std::runtime_error("The model in " + model_path + " is not a sequence-to-sequence model");

  // Restrict SentencePiece to the tokens present in the model's source vocabulary.
  status = sp_processor.SetVocabulary(get_vocabulary_tokens(seq2seq_model->get_source_vocabulary()));
  if (!status.ok())
    throw std::runtime_error("Failed to set the SentencePiece vocabulary");

  auto tokenizer = [&sp_processor](const std::string& text) {
    std::vector<std::string> tokens;
    sp_processor.Encode(text, &tokens);
    return tokens;
  };

  // Compiled once here instead of on every detokenizer call; std::regex
  // construction is expensive and the pattern never changes.
  const std::regex unk_pattern("<unk>");
  auto detokenizer = [&sp_processor, &unk_pattern](const std::vector<std::string>& tokens) {
    std::string text;
    sp_processor.Decode(tokens, &text);
    // Rewrite the literal "<unk>" marker as "UNK" in the emitted text.
    return std::regex_replace(text, unk_pattern, "UNK");
  };

  ctranslate2::TranslationOptions options;
  options.beam_size = 1;
  options.max_decoding_length = 150;
  options.max_batch_size = max_batch_size;
  options.batch_type = ctranslate2::BatchType::Tokens;
  options.use_vmap = true;
  options.return_scores = false;

  // Read 16 batches' worth of input at a time.
  const size_t read_batch_size = max_batch_size * 16;
  pool.consume_raw_text_file(in_file, out_file, tokenizer, detokenizer, read_batch_size, options);
  return 0;
}