#include "ctranslate2-bindings/include/ctranslate2.h"

#include "ctranslate2/translator.h"
#include "ctranslate2/generator.h"

namespace tabby {

TextInferenceEngine::~TextInferenceEngine() {}

/// Common implementation shared by the decoder-only and encoder-decoder
/// engines. `Model` is the concrete ctranslate2 model wrapper
/// (ctranslate2::Generator or ctranslate2::Translator); subclasses supply
/// the model-specific decoding call in `process`.
template <class Model>
class TextInferenceEngineImpl : public TextInferenceEngine {
 protected:
  /// Per-request decoding knobs forwarded from the Rust caller.
  struct Options {
    size_t max_decoding_length;
    float sampling_temperature;
  };

 public:
  /// Runs one inference request.
  ///
  /// @param context  opaque Rust-side state threaded back into `callback`.
  /// @param callback invoked once per generated token; returning true
  ///                 requests cancellation of the generation.
  /// @param tokens   prompt, already tokenized on the Rust side.
  /// @return the ids of the generated tokens.
  rust::Vec<uint32_t> inference(
      rust::Box<InferenceContext> context,
      InferenceCallback callback,
      rust::Slice<const rust::String> tokens,
      size_t max_decoding_length,
      float sampling_temperature
  ) const {
    // Copy the Rust-owned strings into std::strings for ctranslate2.
    std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
    return process(
        std::move(context),
        std::move(callback),
        input_tokens,
        Options{max_decoding_length, sampling_temperature}
    );
  }

  /// Factory: loads the underlying `Model` through `loader` and wraps it.
  static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
    auto impl = std::make_unique<TextInferenceEngineImpl<Model>>();
    impl->model_ = std::make_unique<Model>(loader);
    return impl;
  }

 protected:
  /// Model-specific decoding; implemented by subclasses.
  virtual rust::Vec<uint32_t> process(
      rust::Box<InferenceContext> context,
      InferenceCallback callback,
      const std::vector<std::string>& tokens,
      const Options& options) const = 0;

  std::unique_ptr<Model> model_;
};

/// Engine backed by an encoder-decoder (seq2seq) model.
class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translator> {
 protected:
  virtual rust::Vec<uint32_t> process(
      rust::Box<InferenceContext> context,
      InferenceCallback callback,
      const std::vector<std::string>& tokens,
      const Options& options) const override {
    ctranslate2::TranslationOptions opts;
    opts.max_decoding_length = options.max_decoding_length;
    opts.sampling_temperature = options.sampling_temperature;
    opts.beam_size = 1;  // greedy decoding; step callback requires beam_size == 1

    rust::Vec<uint32_t> output_ids;
    opts.callback = [&](ctranslate2::GenerationStepResult result) {
      const bool stop = callback(*context, result.step, result.token_id, result.token);
      // Record the token unless the caller cancelled before the final step.
      if (!stop || result.is_last) {
        output_ids.push_back(result.token_id);
      }
      return stop;
    };
    // Synchronous; tokens are collected incrementally through the callback,
    // so the returned TranslationResult is intentionally discarded.
    model_->translate_batch({ tokens }, opts);
    return output_ids;
  }
};

/// Engine backed by a decoder-only (causal LM) model.
class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator> {
 protected:
  virtual rust::Vec<uint32_t> process(
      rust::Box<InferenceContext> context,
      InferenceCallback callback,
      const std::vector<std::string>& tokens,
      const Options& options) const override {
    ctranslate2::GenerationOptions opts;
    opts.include_prompt_in_result = false;  // only report newly generated tokens
    opts.max_length = options.max_decoding_length;
    opts.sampling_temperature = options.sampling_temperature;
    opts.beam_size = 1;  // greedy decoding; step callback requires beam_size == 1

    rust::Vec<uint32_t> output_ids;
    opts.callback = [&](ctranslate2::GenerationStepResult result) {
      const bool stop = callback(*context, result.step, result.token_id, result.token);
      // Record the token unless the caller cancelled before the final step.
      if (!stop || result.is_last) {
        output_ids.push_back(result.token_id);
      }
      return stop;
    };
    // generate_batch_async returns futures; .get() blocks until the
    // generation (and therefore every callback invocation) has finished.
    model_->generate_batch_async({ tokens }, opts)[0].get();
    return output_ids;
  }
};

/// Creates an engine for the given model directory.
///
/// @param model_type   "AutoModelForCausalLM" or "AutoModelForSeq2SeqLM";
///                     anything else yields nullptr.
/// @param compute_type "auto" picks int8 on CPU and int8_float16 on CUDA;
///                     otherwise parsed by ctranslate2::str_to_compute_type.
std::shared_ptr<TextInferenceEngine> create_engine(
    rust::Str model_path,
    rust::Str model_type,
    rust::Str device,
    rust::Str compute_type,
    rust::Slice<const int32_t> device_indices,
    size_t num_replicas_per_device
) {
  std::string model_type_str(model_type);
  std::string model_path_str(model_path);
  ctranslate2::models::ModelLoader loader(model_path_str);
  loader.device = ctranslate2::str_to_device(std::string(device));
  loader.device_indices = std::vector<int>(device_indices.begin(), device_indices.end());
  loader.num_replicas_per_device = num_replicas_per_device;

  std::string compute_type_str(compute_type);
  if (compute_type_str == "auto") {
    if (loader.device == ctranslate2::Device::CPU) {
      loader.compute_type = ctranslate2::ComputeType::INT8;
    } else if (loader.device == ctranslate2::Device::CUDA) {
      loader.compute_type = ctranslate2::ComputeType::INT8_FLOAT16;
    }
  } else {
    loader.compute_type = ctranslate2::str_to_compute_type(compute_type_str);
  }

  if (model_type_str == "AutoModelForCausalLM") {
    return DecoderImpl::create(loader);
  } else if (model_type_str == "AutoModelForSeq2SeqLM") {
    return EncoderDecoderImpl::create(loader);
  } else {
    // Unsupported model type; the Rust side treats nullptr as an error.
    return nullptr;
  }
}
}  // namespace tabby