#include "ctranslate2-bindings/include/ctranslate2.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "ctranslate2/translator.h"

// NOTE(review): the element/pointee types spelled out below (rust::String,
// int32_t, ctranslate2::Translator, TextInferenceEngine) must agree with the
// declarations in ctranslate2-bindings/include/ctranslate2.h — confirm against
// that header.

namespace tabby {

// Out-of-line definition of the interface destructor.
TextInferenceEngine::~TextInferenceEngine() = default;

/// TextInferenceEngine implementation that forwards requests to an owned
/// ctranslate2 translator.
class TextInferenceEngineImpl final : public TextInferenceEngine {
 public:
  /// Takes ownership of `translator`.
  explicit TextInferenceEngineImpl(std::unique_ptr<ctranslate2::Translator> translator)
      : translator_(std::move(translator)) {}

  ~TextInferenceEngineImpl() = default;

  /// Translates one token sequence and returns the output tokens.
  ///
  /// @param tokens               Input tokens, submitted as a batch of one.
  /// @param max_decoding_length  Copied to TranslationOptions::max_decoding_length.
  /// @param sampling_temperature Copied to TranslationOptions::sampling_temperature.
  /// @param beam_size            Copied to TranslationOptions::beam_size.
  /// @return The tokens reported by `output()` on the first translation result.
  rust::Vec<rust::String> inference(
      rust::Slice<const rust::String> tokens,
      size_t max_decoding_length,
      float sampling_temperature,
      size_t beam_size
  ) const override {
    // Create options.
    ctranslate2::TranslationOptions options;
    options.max_decoding_length = max_decoding_length;
    options.sampling_temperature = sampling_temperature;
    options.beam_size = beam_size;

    // Inference. The result vector is kept alive so the first result can be
    // read in place instead of being copied out of a temporary.
    std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
    const auto results = translator_->translate_batch({ input_tokens }, options);
    // at() turns an unexpectedly empty result vector into an exception rather
    // than an out-of-bounds read.
    const auto& output_tokens = results.at(0).output();

    // Convert to rust vec.
    rust::Vec<rust::String> output;
    output.reserve(output_tokens.size());
    std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
    return output;
  }

 private:
  std::unique_ptr<ctranslate2::Translator> translator_;
};

/// Builds a translator from the model at `model_path` and wraps it in an engine.
///
/// @param model_path              Passed to ctranslate2::models::ModelLoader.
/// @param device                  "cuda" or "cpu". Any other value leaves
///                                `loader.device` as ModelLoader initialised it.
/// @param device_indices          Replaces `loader.device_indices`.
/// @param num_replicas_per_device Copied to `loader.num_replicas_per_device`.
/// @return Owning pointer to the newly created engine.
std::unique_ptr<TextInferenceEngine> create_engine(
    rust::Str model_path,
    rust::Str device,
    rust::Slice<const int32_t> device_indices,
    size_t num_replicas_per_device
) {
  // model_path.
  const std::string model_path_string(model_path);
  ctranslate2::models::ModelLoader loader(model_path_string);

  // device.
  const std::string device_string(device);
  if (device_string == "cuda") {
    loader.device = ctranslate2::Device::CUDA;
  } else if (device_string == "cpu") {
    loader.device = ctranslate2::Device::CPU;
  }
  // Unrecognised device names are intentionally not an error here; the
  // loader's existing device setting is kept.

  // device_indices
  loader.device_indices.assign(device_indices.begin(), device_indices.end());

  // num_replicas_per_device
  loader.num_replicas_per_device = num_replicas_per_device;

  auto translator = std::make_unique<ctranslate2::Translator>(loader);
  return std::make_unique<TextInferenceEngineImpl>(std::move(translator));
}

}  // namespace tabby