@@ -120,51 +120,24 @@ static bool IsCContiguous(const py::array &input) {
// TensorDataNumpy implements TensorData using numpy array.
class TensorDataNumpy : public TensorData {
public:
explicit TensorDataNumpy(const py::array &input) : data_(input) {
if (!IsCContiguous(data_)) {
// Call numpy.ascontiguousarray() to convert data to C contiguous if it is not.
auto np = py::module::import("numpy");
auto convert = np.attr("ascontiguousarray");
data_ = convert(data_);
}
}
explicit TensorDataNumpy(py::buffer_info &&buffer) : buffer_(std::move(buffer)) {}
/// Total number of elements.
ssize_t size() const override { return data_.size() ; }
ssize_t size() const override { return buffer_.size; }
/// Byte size of a single element.
ssize_t itemsize() const override { return data_.itemsize() ; }
ssize_t itemsize() const override { return buffer_.itemsize; }
/// Total number of bytes.
ssize_t nbytes() const override { return data_.nbytes() ; }
ssize_t nbytes() const override { return buffer_.itemsize * buffer_.size ; }
/// Number of dimensions.
ssize_t ndim() const override { return data_.ndim() ; }
ssize_t ndim() const override { return buffer_.ndim ; }
/// Data pointer.
void *data() override { return data_.request().ptr; }
const void *const_data() const override { return data_.request().ptr; }
/// Is data equals.
bool equals(const TensorData &other) const override {
auto ptr = dynamic_cast<const TensorDataNumpy *>(&other);
if (ptr == nullptr) {
// Not same type, compare data byte by byte.
return TensorData::equals(other);
}
return NumpyEquals(*ptr);
}
void *data() override { return buffer_.ptr; }
bool NumpyEquals(const TensorDataNumpy &other) const {
auto all_data_equal = [&other, this]() -> bool {
auto np = py::module::import("numpy");
auto equal = np.attr("equal")(data_, other.data_);
auto all_equal = np.attr("all")(equal);
return all_equal.cast<bool>();
};
return this == &other || data_.is(other.data_) || all_data_equal();
}
const void *const_data() const override { return buffer_.ptr; }
/// To string.
std::string ToString(const TypeId type, const ShapeVector &shape, bool use_comma) const override {
@@ -174,17 +147,21 @@ class TensorDataNumpy : public TensorData {
kwargs["separator"] = ", ";
auto np = py::module::import("numpy");
auto array2string = np.attr("array2string");
return py::str(array2string(data_ , **kwargs));
return py::str(array2string(py_array() , **kwargs));
}
// without comma.
return py::str(data_ );
return py::str(py_array() );
}
/// py::array object.
py::array py_array() const { return data_; }
py::array py_array() const {
// Use dummy owner to avoid copy data.
py::str dummyOwner;
return py::array(py::dtype(buffer_), buffer_.shape, buffer_.strides, buffer_.ptr, dummyOwner);
}
private:
mutable py::array data_;
py::buffer_info buffer _;
};
TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) {
@@ -226,6 +203,10 @@ TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr)
/// Creates a Tensor from a numpy array without copy
TensorPtr TensorPy::MakeTensorNoCopy(const py::array &input) {
// Check format.
if (!IsCContiguous(input)) {
MS_LOG(EXCEPTION) << "Array should be C contiguous.";
}
// Get input buffer info.
py::buffer_info buf = input.request();
// Get tensor dtype and check it.
@@ -236,7 +217,7 @@ TensorPtr TensorPy::MakeTensorNoCopy(const py::array &input) {
// Get tensor shape.
ShapeVector shape(buf.shape.begin(), buf.shape.end());
// Make a tensor with shared data with numpy array.
auto tensor_data = std::make_shared<TensorDataNumpy>(input );
auto tensor_data = std::make_shared<TensorDataNumpy>(std::move(buf) );
return std::make_shared<Tensor>(dtype, shape, tensor_data);
}