| @@ -20,7 +20,8 @@ | |||
| #include "minddata/dataset/core/data_type.h" | |||
| #include "mindspore/core/ir/dtype/type_id.h" | |||
| #include "utils/hashing.h" | |||
| #include "mindspore/lite/src/ir/tensor.h" | |||
| #include "mindspore/lite/internal/include/ms_tensor.h" | |||
| #include "mindspore/core/utils/convert_utils_base.h" | |||
| namespace mindspore { | |||
| namespace tensor { | |||
| @@ -59,7 +60,7 @@ DETensor::DETensor(std::shared_ptr<dataset::Tensor> tensor_ptr) { this->tensor_i | |||
| MSTensor *DETensor::ConvertToLiteTensor() { | |||
| // static MSTensor::CreateTensor is only for the LiteTensor | |||
| MSTensor *tensor = MSTensor::CreateTensor(this->data_type(), this->shape()); | |||
| MSTensor *tensor = CreateTensor(this->data_type(), this->shape()); | |||
| MS_ASSERT(tensor->Size() == this->Size()); | |||
| memcpy_s(tensor->MutableData(), tensor->Size(), this->MutableData(), this->Size()); | |||
| return tensor; | |||
| @@ -141,7 +142,7 @@ size_t DETensor::Size() const { | |||
| return this->tensor_impl_->SizeInBytes(); | |||
| } | |||
| void *DETensor::MutableData() const { | |||
| void *DETensor::MutableData() { | |||
| MS_ASSERT(this->tensor_impl_ != nullptr); | |||
| return this->tensor_impl_->GetMutableBuffer(); | |||
| } | |||
| @@ -24,7 +24,7 @@ | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace tensor { | |||
| class DETensor : public MSTensor { | |||
| class DETensor : public mindspore::tensor::MSTensor { | |||
| public: | |||
| /// \brief Create a MSTensor pointer. | |||
| /// \param[in] data_type DataTypeId of tensor to be created | |||
| @@ -58,21 +58,21 @@ class DETensor : public MSTensor { | |||
| TypeId data_type() const override; | |||
| TypeId set_data_type(const TypeId data_type) override; | |||
| TypeId set_data_type(const TypeId data_type); | |||
| std::vector<int> shape() const override; | |||
| size_t set_shape(const std::vector<int> &shape) override; | |||
| size_t set_shape(const std::vector<int> &shape); | |||
| int DimensionSize(size_t index) const override; | |||
| int ElementsNum() const override; | |||
| std::size_t hash() const override; | |||
| std::size_t hash() const; | |||
| size_t Size() const override; | |||
| void *MutableData() const override; | |||
| void *MutableData() override; | |||
| protected: | |||
| std::shared_ptr<dataset::Tensor> tensor_impl_; | |||
| @@ -65,7 +65,7 @@ extern "C" JNIEXPORT void JNICALL Java_com_example_mindsporepredict_MainActivity | |||
| // Create a Cifar10 Dataset | |||
| std::string folder_path = env->GetStringUTFChars(path, 0); | |||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10)); | |||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); | |||
| // Create an iterator over the result of the above dataset | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| @@ -32,7 +32,7 @@ int main() { | |||
| // Create a Cifar10 Dataset | |||
| std::string folder_path = "./testCifar10Data/"; | |||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10)); | |||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); | |||
| // Create an iterator over the result of the above dataset | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| @@ -25,7 +25,7 @@ | |||
| using MSTensor = mindspore::tensor::MSTensor; | |||
| using DETensor = mindspore::tensor::DETensor; | |||
| using LiteTensor = mindspore::lite::tensor::LiteTensor; | |||
| using LiteTensor = mindspore::lite::Tensor; | |||
| using Tensor = mindspore::dataset::Tensor; | |||
| using DataType = mindspore::dataset::DataType; | |||
| using TensorShape = mindspore::dataset::TensorShape; | |||
| @@ -56,11 +56,6 @@ TEST_F(MindDataTestTensorDE, MSTensorShape) { | |||
| auto ms_tensor = std::shared_ptr<MSTensor>(new DETensor(t)); | |||
| ASSERT_EQ(ms_tensor->DimensionSize(0) == 2, true); | |||
| ASSERT_EQ(ms_tensor->DimensionSize(1) == 3, true); | |||
| ms_tensor->set_shape(std::vector<int>{3, 2}); | |||
| ASSERT_EQ(ms_tensor->DimensionSize(0) == 3, true); | |||
| ASSERT_EQ(ms_tensor->DimensionSize(1) == 2, true); | |||
| ms_tensor->set_shape(std::vector<int>{6}); | |||
| ASSERT_EQ(ms_tensor->DimensionSize(0) == 6, true); | |||
| } | |||
| TEST_F(MindDataTestTensorDE, MSTensorSize) { | |||
| @@ -74,9 +69,6 @@ TEST_F(MindDataTestTensorDE, MSTensorDataType) { | |||
| std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2, 3}), DataType(DataType::DE_FLOAT32)); | |||
| auto ms_tensor = std::shared_ptr<MSTensor>(new DETensor(t)); | |||
| ASSERT_EQ(ms_tensor->data_type() == mindspore::TypeId::kNumberTypeFloat32, true); | |||
| ms_tensor->set_data_type(mindspore::TypeId::kNumberTypeInt32); | |||
| ASSERT_EQ(ms_tensor->data_type() == mindspore::TypeId::kNumberTypeInt32, true); | |||
| ASSERT_EQ(std::dynamic_pointer_cast<DETensor>(ms_tensor)->tensor()->type() == DataType::DE_INT32, true); | |||
| } | |||
| TEST_F(MindDataTestTensorDE, MSTensorMutableData) { | |||
| @@ -89,19 +81,8 @@ TEST_F(MindDataTestTensorDE, MSTensorMutableData) { | |||
| ASSERT_EQ(x == tensor_vec, true); | |||
| } | |||
| TEST_F(MindDataTestTensorDE, MSTensorHash) { | |||
| std::vector<float> x = {2.5, 2.5, 2.5, 2.5}; | |||
| std::shared_ptr<Tensor> t; | |||
| Tensor::CreateFromVector(x, TensorShape({2, 2}), &t); | |||
| auto ms_tensor = std::shared_ptr<MSTensor>(new DETensor(t)); | |||
| ASSERT_EQ(ms_tensor->hash() == 11093771382437, true); | |||
| } | |||
| TEST_F(MindDataTestTensorDE, MSTensorCreateFromMemory) { | |||
| std::vector<float> x = {2.5, 2.5, 2.5, 2.5}; | |||
| auto mem_tensor = DETensor::CreateFromMemory(mindspore::TypeId::kNumberTypeFloat32, {2, 2}, &x[0]); | |||
| std::shared_ptr<Tensor> t; | |||
| Tensor::CreateFromVector(x, TensorShape({2, 2}), &t); | |||
| auto ms_tensor = std::shared_ptr<MSTensor>(new DETensor(t)); | |||
| ASSERT_EQ(ms_tensor->hash() == mem_tensor->hash(), true); | |||
| ASSERT_EQ(mem_tensor->data_type() == mindspore::TypeId::kNumberTypeFloat32, true); | |||
| } | |||