|
|
|
@@ -25,7 +25,7 @@ |
|
|
|
|
|
|
|
using namespace mindspore::dataset; |
|
|
|
|
|
|
|
class MindDataTestTensorDE : public UT::Common { |
|
|
|
class MindDataTestTensorDE : public mindspore::Common { |
|
|
|
public: |
|
|
|
MindDataTestTensorDE() {} |
|
|
|
}; |
|
|
|
@@ -42,7 +42,7 @@ TEST_F(MindDataTestTensorDE, MSTensorConvertToLiteTensor) { |
|
|
|
std::shared_ptr<mindspore::tensor::MSTensor> lite_ms_tensor = std::shared_ptr<mindspore::tensor::MSTensor>( |
|
|
|
std::dynamic_pointer_cast<mindspore::tensor::DETensor>(ms_tensor)->ConvertToLiteTensor()); |
|
|
|
// check if the lite_ms_tensor is the derived LiteTensor |
|
|
|
mindspore::tensor::LiteTensor * lite_tensor = static_cast<mindspore::tensor::LiteTensor *>(lite_ms_tensor.get()); |
|
|
|
mindspore::lite::tensor::LiteTensor * lite_tensor = static_cast<mindspore::lite::tensor::LiteTensor *>(lite_ms_tensor.get()); |
|
|
|
ASSERT_EQ(lite_tensor != nullptr, true); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -77,7 +77,7 @@ TEST_F(MindDataTestTensorDE, MSTensorDataType) { |
|
|
|
TEST_F(MindDataTestTensorDE, MSTensorMutableData) { |
|
|
|
std::vector<float> x = {2.5, 2.5, 2.5, 2.5}; |
|
|
|
std::shared_ptr<Tensor> t; |
|
|
|
Tensor::CreateTensor(&t, x, TensorShape({2, 2})); |
|
|
|
Tensor::CreateFromVector(x, TensorShape({2, 2}), &t); |
|
|
|
auto ms_tensor = std::shared_ptr<mindspore::tensor::MSTensor>(new mindspore::tensor::DETensor(t)); |
|
|
|
float *data = static_cast<float*>(ms_tensor->MutableData()); |
|
|
|
std::vector<float> tensor_vec(data, data + ms_tensor->ElementsNum()); |
|
|
|
@@ -88,7 +88,7 @@ TEST_F(MindDataTestTensorDE, MSTensorMutableData) { |
|
|
|
TEST_F(MindDataTestTensorDE, MSTensorHash) { |
|
|
|
std::vector<float> x = {2.5, 2.5, 2.5, 2.5}; |
|
|
|
std::shared_ptr<Tensor> t; |
|
|
|
Tensor::CreateTensor(&t, x, TensorShape({2, 2})); |
|
|
|
Tensor::CreateFromVector(x, TensorShape({2, 2}), &t); |
|
|
|
auto ms_tensor = std::shared_ptr<mindspore::tensor::MSTensor>(new mindspore::tensor::DETensor(t)); |
|
|
|
#ifdef ENABLE_ARM64 |
|
|
|
ASSERT_EQ(ms_tensor->hash() == 11093771382437, true); // arm64 |