168 lines
5.5 KiB
C++
168 lines
5.5 KiB
C++
#include "module.h"
|
|
|
|
#include <sstream>
|
|
|
|
#include <ctranslate2/storage_view.h>
|
|
|
|
#include "utils.h"
|
|
|
|
using namespace pybind11::literals;
|
|
|
|
namespace ctranslate2 {
|
|
namespace python {
|
|
|
|
// Maps an array-interface type string (e.g. "<f4") to the matching DataType.
//
// The string is read as three characters: a byte-order mark, a type code
// ('i' for signed integers, 'f' for floats) and a one-digit item size in
// bytes. The byte-order character itself is not inspected.
//
// Supported combinations: i1, i2, i4, f2, f4.
//
// Throws std::invalid_argument ("Unsupported type: <typestr>") for any other
// input, including strings that are not exactly three characters long.
//
// NOTE(review): a big-endian marker (">") is accepted like any other first
// character -- confirm whether such arrays should be rejected instead.
static DataType typestr_to_dtype(const std::string& typestr) {
  // Checking the length before indexing keeps short or empty strings from
  // being read past their end, and stops longer strings such as "<i16" from
  // being matched on their first digit only.
  if (typestr.size() == 3) {
    const char type_code = typestr[1];
    const char num_bytes = typestr[2];

    if (type_code == 'i') {
      if (num_bytes == '1')
        return DataType::INT8;
      if (num_bytes == '2')
        return DataType::INT16;
      if (num_bytes == '4')
        return DataType::INT32;
    } else if (type_code == 'f') {
      if (num_bytes == '2')
        return DataType::FLOAT16;
      if (num_bytes == '4')
        return DataType::FLOAT32;
    }
  }

  throw std::invalid_argument("Unsupported type: " + typestr);
}
|
|
|
|
static std::string dtype_to_typestr(const DataType dtype) {
|
|
// Assume little-endian.
|
|
|
|
switch (dtype) {
|
|
case DataType::FLOAT32:
|
|
return "<f4";
|
|
case DataType::FLOAT16:
|
|
return "<f2";
|
|
case DataType::INT8:
|
|
return "|i1";
|
|
case DataType::INT16:
|
|
return "<i2";
|
|
case DataType::INT32:
|
|
return "<i4";
|
|
default:
|
|
return "";
|
|
}
|
|
}
|
|
|
|
// Wraps the memory exposed by a Python array-like object in a StorageView.
//
// `array` must expose `__array_interface__` (mapped to Device::CPU) or,
// failing that, `__cuda_array_interface__` (mapped to Device::CUDA). The
// "shape", "typestr" and "data" entries of that interface are read to
// configure the returned view.
//
// Throws std::invalid_argument when neither attribute is available, when the
// interface carries a non-None "strides" entry, when typestr_to_dtype rejects
// the type string, or when the buffer is flagged read-only.
static StorageView create_view_from_array(py::object array) {
  // Try the host interface first and only then fall back to the CUDA one.
  py::object spec = py::getattr(array, "__array_interface__", py::none());
  const bool on_host = !spec.is_none();
  if (!on_host) {
    spec = py::getattr(array, "__cuda_array_interface__", py::none());
    if (spec.is_none())
      throw std::invalid_argument("Object does not implement the array interface");
  }
  const Device device = on_host ? Device::CPU : Device::CUDA;

  const py::dict fields = spec.cast<py::dict>();

  // Any explicit strides are refused outright; their values are not inspected.
  const bool has_strides = spec.contains("strides") && !spec["strides"].is_none();
  if (has_strides)
    throw std::invalid_argument("StorageView does not support arrays with non contiguous memory");

  Shape shape = fields["shape"].cast<Shape>();
  const DataType dtype = typestr_to_dtype(fields["typestr"].cast<std::string>());

  // "data" is a (address, read-only flag) pair.
  const py::tuple data = fields["data"].cast<py::tuple>();
  const uintptr_t address = data[0].cast<uintptr_t>();
  const bool read_only = data[1].cast<bool>();
  if (read_only)
    throw std::invalid_argument("StorageView does not support read-only arrays");

  StorageView view(dtype, device);
  view.view(reinterpret_cast<void*>(address), std::move(shape));
  return view;
}
|
|
|
|
// Builds the dictionary describing `view` in array-interface form.
//
// Entries:
//   "shape"   - tuple holding view.dim(i) for each of the view.rank() axes
//   "typestr" - result of dtype_to_typestr(view.dtype())
//   "data"    - (buffer address as an integer, false)
//   "version" - 3
static py::dict get_array_interface(const StorageView& view) {
  py::tuple dims(view.rank());
  for (size_t axis = 0; axis < dims.size(); ++axis)
    dims[axis] = view.dim(axis);

  const std::string typestr = dtype_to_typestr(view.dtype());

  // The second element is the read-only flag; it is always reported as false.
  const py::tuple data = py::make_tuple(reinterpret_cast<uintptr_t>(view.buffer()), false);

  return py::dict("shape"_a=dims,
                  "typestr"_a=typestr,
                  "data"_a=data,
                  "version"_a=3);
}
|
|
|
|
// Registers the Python-facing "StorageView" class on module `m`.
//
// The binding exposes:
//   - from_array (static): forwards to create_view_from_array.
//   - __array_interface__ / __cuda_array_interface__: read-only properties
//     built by get_array_interface; each one raises AttributeError when the
//     view lives on the other device.
//   - __str__: the text obtained by streaming the view into an ostringstream.
void register_storage_view(py::module& m) {
|
|
py::class_<StorageView>(
|
|
m, "StorageView",
|
|
R"pbdoc(
|
|
An allocated buffer with shape information.
|
|
|
|
The object implements the
|
|
`Array Interface <https://numpy.org/doc/stable/reference/arrays.interface.html>`_
|
|
and the
|
|
`CUDA Array Interface <https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html>`_
|
|
so that it can be passed to Numpy or PyTorch without copy.
|
|
|
|
Example:
|
|
|
|
>>> x = np.ones((2, 4), dtype=np.int32)
|
|
>>> y = ctranslate2.StorageView.from_array(x)
|
|
>>> print(y)
|
|
1 1 1 ... 1 1 1
|
|
[cpu:0 int32 storage viewed as 2x4]
|
|
>>> z = np.array(y)
|
|
...
|
|
>>> x = torch.ones((2, 4), dtype=torch.int32, device="cuda")
|
|
>>> y = ctranslate2.StorageView.from_array(x)
|
|
>>> print(y)
|
|
1 1 1 ... 1 1 1
|
|
[cuda:0 int32 storage viewed as 2x4]
|
|
>>> z = torch.as_tensor(y, device="cuda")
|
|
|
|
)pbdoc")
|
|
|
|
// keep_alive<0, 1> is pybind11's call policy tying the lifetime of argument 1
// (`array`) to the returned object, so the source array is kept referenced
// while the returned StorageView exists.
.def_static("from_array", &create_view_from_array, py::arg("array"),
|
|
py::keep_alive<0, 1>(),
|
|
R"pbdoc(
|
|
Creates a ``StorageView`` from an object implementing the array interface.
|
|
|
|
Arguments:
|
|
array: An object implementing the array interface (e.g. a Numpy array
|
|
or a PyTorch Tensor).
|
|
|
|
Returns:
|
|
A new ``StorageView`` instance sharing the same data as the input array.
|
|
|
|
Raises:
|
|
ValueError: if the object does not implement the array interface or
|
|
uses an unsupported array specification.
|
|
)pbdoc")
|
|
|
|
// Host-side interface: refused with AttributeError for CUDA views.
// NOTE(review): presumably AttributeError is used so that attribute probing
// on the Python side sees the protocol as unavailable -- confirm.
.def_property_readonly("__array_interface__", [](const StorageView& view) {
|
|
if (view.device() == Device::CUDA)
|
|
throw py::attribute_error("Cannot get __array_interface__ when the StorageView "
|
|
"is viewing a CUDA array");
|
|
return get_array_interface(view);
|
|
})
|
|
|
|
// CUDA-side interface: the mirror case, refused with AttributeError for CPU
// views. The dictionary itself is built by the same helper as above.
.def_property_readonly("__cuda_array_interface__", [](const StorageView& view) {
|
|
if (view.device() == Device::CPU)
|
|
throw py::attribute_error("Cannot get __cuda_array_interface__ when the StorageView "
|
|
"is viewing a CPU array");
|
|
return get_array_interface(view);
|
|
})
|
|
|
|
// str(view) returns whatever the StorageView stream insertion operator writes.
.def("__str__", [](const StorageView& view) {
|
|
std::ostringstream stream;
|
|
stream << view;
|
|
return stream.str();
|
|
})
|
|
|
|
;
|
|
}
|
|
|
|
}
|
|
}
|