2023-05-25 21:05:28 +00:00
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include "rust/cxx.h"
|
2023-05-26 01:18:22 +00:00
|
|
|
#include <memory>
|
2023-05-25 21:05:28 +00:00
|
|
|
|
|
|
|
|
namespace tabby {
|
|
|
|
|
|
|
|
|
|
class TextInferenceEngine {
|
|
|
|
|
public:
|
|
|
|
|
virtual ~TextInferenceEngine();
|
|
|
|
|
virtual rust::Vec<rust::String> inference(
|
|
|
|
|
rust::Slice<const rust::String> tokens,
|
|
|
|
|
size_t max_decoding_length,
|
|
|
|
|
float sampling_temperature,
|
|
|
|
|
size_t beam_size
|
|
|
|
|
) const = 0;
|
|
|
|
|
};
|
|
|
|
|
|
2023-05-26 06:23:07 +00:00
|
|
|
/// Factory for a `TextInferenceEngine`; the caller owns the returned engine.
///
/// @param model_path               Borrowed string naming the model location.
/// @param model_type               Borrowed string selecting the model kind.
///                                 NOTE(review): accepted values are not
///                                 visible in this header; see the definition.
/// @param device                   Borrowed string naming the target device.
///                                 NOTE(review): accepted values are not
///                                 visible in this header; see the definition.
/// @param device_indices           Read-only slice of 32-bit device indices.
/// @param num_replicas_per_device  NOTE(review): presumably the number of
///                                 model replicas created per device; confirm.
/// @return Owning pointer to the created engine.
std::unique_ptr<TextInferenceEngine> create_engine(
    rust::Str model_path,
    rust::Str model_type,
    rust::Str device,
    rust::Slice<const int32_t> device_indices,
    size_t num_replicas_per_device
);
|
2023-05-25 21:05:28 +00:00
|
|
|
} // namespace tabby
|