2023-05-25 21:05:28 +00:00
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include "rust/cxx.h"
|
2023-05-26 01:18:22 +00:00
|
|
|
#include <memory>
|
2023-05-25 21:05:28 +00:00
|
|
|
|
|
|
|
|
namespace tabby {
|
|
|
|
|
|
2023-06-04 22:28:39 +00:00
|
|
|
// Forward declaration only: this header never defines InferenceContext, so it
// is an incomplete (opaque) type here and is only handled by reference or
// through rust::Box.
// NOTE(review): presumably defined on the Rust side of the cxx bridge - confirm.
struct InferenceContext;
|
|
|
|
|
|
2023-06-06 23:28:58 +00:00
|
|
|
// Signature of the Rust function handed to TextInferenceEngine::inference().
// It receives the inference context by reference plus a size_t, a uint32_t and
// a rust::String, and returns a bool.
// NOTE(review): the arguments look like (step index, token id, token text) and
// the bool like a "stop decoding" flag - confirm against the implementation.
using InferenceCallback =
    rust::Fn<bool(InferenceContext&, size_t, uint32_t, rust::String)>;
|
|
|
|
|
|
2023-05-25 21:05:28 +00:00
|
|
|
class TextInferenceEngine {
|
|
|
|
|
public:
|
|
|
|
|
virtual ~TextInferenceEngine();
|
2023-06-06 23:28:58 +00:00
|
|
|
virtual rust::Vec<uint32_t> inference(
|
2023-06-04 22:28:39 +00:00
|
|
|
rust::Box<InferenceContext> context,
|
2023-06-06 23:28:58 +00:00
|
|
|
InferenceCallback callback,
|
2023-05-25 21:05:28 +00:00
|
|
|
rust::Slice<const rust::String> tokens,
|
|
|
|
|
size_t max_decoding_length,
|
2023-06-06 12:46:17 +00:00
|
|
|
float sampling_temperature
|
2023-05-25 21:05:28 +00:00
|
|
|
) const = 0;
|
|
|
|
|
};
|
|
|
|
|
|
2023-06-04 06:23:31 +00:00
|
|
|
// Factory for TextInferenceEngine instances; returns the engine behind a
// std::shared_ptr.
//
//   model_path              - NOTE(review): presumably the location of the
//                             model on disk - confirm.
//   model_type              - NOTE(review): presumably selects which concrete
//                             engine implementation is built - confirm.
//   device                  - NOTE(review): presumably a device name such as
//                             "cpu" or "cuda" - confirm.
//   device_indices          - read-only view of 32-bit signed device indices.
//   num_replicas_per_device - NOTE(review): presumably how many model replicas
//                             to create on each device - confirm.
std::shared_ptr<TextInferenceEngine> create_engine(
    rust::Str model_path,
    rust::Str model_type,
    rust::Str device,
    rust::Slice<const int32_t> device_indices,
    size_t num_replicas_per_device);
|
2023-05-25 21:05:28 +00:00
|
|
|
}  // namespace tabby
|