2023-05-25 21:05:28 +00:00
|
|
|
#include "ctranslate2-bindings/include/ctranslate2.h"
|
|
|
|
|
|
|
|
|
|
#include "ctranslate2/translator.h"
|
2023-05-27 08:26:33 +00:00
|
|
|
#include "ctranslate2/generator.h"
|
2023-05-25 21:05:28 +00:00
|
|
|
|
|
|
|
|
namespace tabby {
|
|
|
|
|
// Out-of-line destructor definition: anchors the vtable for the abstract
// interface in this translation unit.
TextInferenceEngine::~TextInferenceEngine() = default;
|
|
|
|
|
|
2023-06-04 22:28:39 +00:00
|
|
|
template <class Model, class Child>
|
|
|
|
|
class TextInferenceEngineImpl : public TextInferenceEngine {
|
|
|
|
|
protected:
|
|
|
|
|
struct Options {
|
|
|
|
|
size_t max_decoding_length;
|
|
|
|
|
float sampling_temperature;
|
|
|
|
|
};
|
|
|
|
|
|
2023-05-25 21:05:28 +00:00
|
|
|
public:
|
2023-06-06 23:28:58 +00:00
|
|
|
rust::Vec<uint32_t> inference(
|
2023-06-04 22:28:39 +00:00
|
|
|
rust::Box<InferenceContext> context,
|
2023-06-06 23:28:58 +00:00
|
|
|
InferenceCallback callback,
|
2023-05-25 21:05:28 +00:00
|
|
|
rust::Slice<const rust::String> tokens,
|
|
|
|
|
size_t max_decoding_length,
|
2023-06-06 12:46:17 +00:00
|
|
|
float sampling_temperature
|
2023-05-25 21:05:28 +00:00
|
|
|
) const {
|
|
|
|
|
// Inference.
|
|
|
|
|
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
|
2023-06-06 23:28:58 +00:00
|
|
|
return process(
|
2023-06-06 12:46:17 +00:00
|
|
|
std::move(context),
|
2023-06-06 23:28:58 +00:00
|
|
|
std::move(callback),
|
2023-06-06 12:46:17 +00:00
|
|
|
input_tokens,
|
|
|
|
|
Options{max_decoding_length, sampling_temperature}
|
|
|
|
|
);
|
2023-05-25 21:05:28 +00:00
|
|
|
}
|
2023-05-27 08:26:33 +00:00
|
|
|
|
|
|
|
|
static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
|
2023-06-04 22:28:39 +00:00
|
|
|
auto impl = std::make_unique<Child>();
|
|
|
|
|
impl->model_ = std::make_unique<Model>(loader);
|
2023-05-27 08:26:33 +00:00
|
|
|
return impl;
|
|
|
|
|
}
|
2023-05-25 21:05:28 +00:00
|
|
|
|
2023-06-04 22:28:39 +00:00
|
|
|
protected:
|
2023-06-06 23:28:58 +00:00
|
|
|
virtual rust::Vec<uint32_t> process(
|
2023-06-06 12:46:17 +00:00
|
|
|
rust::Box<InferenceContext> context,
|
2023-06-06 23:28:58 +00:00
|
|
|
InferenceCallback callback,
|
2023-06-06 12:46:17 +00:00
|
|
|
const std::vector<std::string>& tokens,
|
|
|
|
|
const Options& options) const = 0;
|
2023-06-04 22:28:39 +00:00
|
|
|
std::unique_ptr<Model> model_;
|
|
|
|
|
};
|
2023-05-27 08:26:33 +00:00
|
|
|
|
2023-06-04 22:28:39 +00:00
|
|
|
// Engine for encoder-decoder (seq2seq) models backed by ctranslate2::Translator.
class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translator, EncoderDecoderImpl> {
 protected:
  // Runs greedy translation (beam_size = 1) over `tokens`, streaming each
  // generated token to `callback` and collecting the accepted ids.
  virtual rust::Vec<uint32_t> process(
    rust::Box<InferenceContext> context,
    InferenceCallback callback,
    const std::vector<std::string>& tokens,
    const Options& options) const override {
    ctranslate2::TranslationOptions x;
    x.max_decoding_length = options.max_decoding_length;
    x.sampling_temperature = options.sampling_temperature;
    x.beam_size = 1;

    rust::Vec<uint32_t> output_ids;
    x.callback = [&](ctranslate2::GenerationStepResult result) {
      // `callback` returning true requests cancellation of the generation.
      bool stop = callback(*context, result.step, result.token_id, result.token);
      // Keep the token unless the caller stopped mid-stream; the final token
      // is kept even on stop so the output ends on a complete step.
      if (!stop || result.is_last) {
        output_ids.push_back(result.token_id);
      }
      return stop;
    };
    // translate_batch blocks until decoding finishes; all output we need is
    // gathered through the step callback, so the returned TranslationResult
    // is intentionally discarded (previously stored in an unused local).
    model_->translate_batch({ tokens }, x);
    return output_ids;
  }
};
|
2023-05-27 08:26:33 +00:00
|
|
|
|
2023-06-04 22:28:39 +00:00
|
|
|
// Engine for decoder-only (causal LM) models backed by ctranslate2::Generator.
class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, DecoderImpl> {
 protected:
  // Runs greedy generation (beam_size = 1) continuing from the prompt
  // `tokens`, streaming each new token to `callback` and collecting the
  // accepted ids. The prompt itself is excluded from the result.
  virtual rust::Vec<uint32_t> process(
    rust::Box<InferenceContext> context,
    InferenceCallback callback,
    const std::vector<std::string>& tokens,
    const Options& options) const override {
    ctranslate2::GenerationOptions x;
    x.include_prompt_in_result = false;
    x.max_length = options.max_decoding_length;
    x.sampling_temperature = options.sampling_temperature;
    x.beam_size = 1;

    rust::Vec<uint32_t> output_ids;
    x.callback = [&](ctranslate2::GenerationStepResult result) {
      // `callback` returning true requests cancellation of the generation.
      bool stop = callback(*context, result.step, result.token_id, result.token);
      // Keep the token unless the caller stopped mid-stream; the final token
      // is kept even on stop so the output ends on a complete step.
      if (!stop || result.is_last) {
        output_ids.push_back(result.token_id);
      }
      return stop;
    };
    // generate_batch_async returns futures: .get() must be called to block
    // until generation completes, but the GenerationResult itself is unused
    // (output is gathered via the callback), so it is discarded rather than
    // stored in a dead local as before.
    model_->generate_batch_async({ tokens }, x)[0].get();
    return output_ids;
  }
};
|
|
|
|
|
|
2023-06-04 06:23:31 +00:00
|
|
|
std::shared_ptr<TextInferenceEngine> create_engine(
|
2023-05-26 06:23:07 +00:00
|
|
|
rust::Str model_path,
|
2023-05-27 08:26:33 +00:00
|
|
|
rust::Str model_type,
|
2023-05-26 06:23:07 +00:00
|
|
|
rust::Str device,
|
2023-06-13 19:04:07 +00:00
|
|
|
rust::Str compute_type,
|
2023-05-26 06:23:07 +00:00
|
|
|
rust::Slice<const int32_t> device_indices,
|
|
|
|
|
size_t num_replicas_per_device
|
|
|
|
|
) {
|
2023-05-27 08:26:33 +00:00
|
|
|
std::string model_type_str(model_type);
|
|
|
|
|
std::string model_path_str(model_path);
|
|
|
|
|
ctranslate2::models::ModelLoader loader(model_path_str);
|
|
|
|
|
loader.device = ctranslate2::str_to_device(std::string(device));
|
|
|
|
|
loader.device_indices = std::vector<int>(device_indices.begin(), device_indices.end());
|
2023-05-26 06:23:07 +00:00
|
|
|
loader.num_replicas_per_device = num_replicas_per_device;
|
|
|
|
|
|
2023-06-13 19:04:07 +00:00
|
|
|
std::string compute_type_str(compute_type);
|
|
|
|
|
if (compute_type_str == "auto") {
|
|
|
|
|
if (loader.device == ctranslate2::Device::CPU) {
|
|
|
|
|
loader.compute_type = ctranslate2::ComputeType::INT8;
|
|
|
|
|
} else if (loader.device == ctranslate2::Device::CUDA) {
|
|
|
|
|
loader.compute_type = ctranslate2::ComputeType::INT8_FLOAT16;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
loader.compute_type = ctranslate2::str_to_compute_type(compute_type_str);
|
2023-05-28 21:36:11 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (model_type_str == "AutoModelForCausalLM") {
|
2023-05-27 08:26:33 +00:00
|
|
|
return DecoderImpl::create(loader);
|
2023-05-28 21:36:11 +00:00
|
|
|
} else if (model_type_str == "AutoModelForSeq2SeqLM") {
|
2023-05-27 08:26:33 +00:00
|
|
|
return EncoderDecoderImpl::create(loader);
|
|
|
|
|
} else {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
2023-05-25 21:05:28 +00:00
|
|
|
}
|
|
|
|
|
} // namespace tabby
|