|
|
|
@@ -122,6 +122,8 @@ class TensorDataNumpy : public TensorData { |
|
|
|
public: |
|
|
|
explicit TensorDataNumpy(py::buffer_info &&buffer) : buffer_(std::move(buffer)) {} |
|
|
|
|
|
|
|
~TensorDataNumpy() override = default; |
|
|
|
|
|
|
|
/// Total number of elements. |
|
|
|
ssize_t size() const override { return buffer_.size; } |
|
|
|
|
|
|
|
@@ -160,7 +162,7 @@ class TensorDataNumpy : public TensorData { |
|
|
|
return py::array(py::dtype(buffer_), buffer_.shape, buffer_.strides, buffer_.ptr, dummyOwner); |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
// The internal buffer. |
|
|
|
py::buffer_info buffer_; |
|
|
|
}; |
|
|
|
|
|
|
|
@@ -258,7 +260,7 @@ py::array TensorPy::SyncAsNumpy(const Tensor &tensor) { |
|
|
|
|
|
|
|
py::array TensorPy::AsNumpy(const Tensor &tensor) { |
|
|
|
auto data_numpy = dynamic_cast<const TensorDataNumpy *>(&tensor.data()); |
|
|
|
if (data_numpy) { |
|
|
|
if (data_numpy != nullptr) { |
|
|
|
// Return internal numpy array if tensor data is implemented base on it. |
|
|
|
return data_numpy->py_array(); |
|
|
|
} |
|
|
|
|