Merge pull request !4630 from EricZ/md_tensor_from_memtags/v0.7.0-beta
| @@ -95,6 +95,19 @@ MSTensor *DETensor::CreateTensor(const std::string &path) { | |||||
| return new DETensor(std::move(t)); | return new DETensor(std::move(t)); | ||||
| } | } | ||||
| MSTensor *DETensor::CreateFromMemory(TypeId data_type, const std::vector<int> &shape, void *data) { | |||||
| std::shared_ptr<dataset::Tensor> t; | |||||
| // prepare shape info | |||||
| std::vector<dataset::dsize_t> t_shape; | |||||
| std::transform(shape.begin(), shape.end(), std::back_inserter(t_shape), | |||||
| [](int s) -> dataset::dsize_t { return static_cast<dataset::dsize_t>(s); }); | |||||
| (void)dataset::Tensor::CreateFromMemory(dataset::TensorShape(t_shape), MSTypeToDEType(data_type), | |||||
| static_cast<uchar *>(data), &t); | |||||
| return new DETensor(std::move(t)); | |||||
| } | |||||
| DETensor::DETensor(TypeId data_type, const std::vector<int> &shape) { | DETensor::DETensor(TypeId data_type, const std::vector<int> &shape) { | ||||
| std::vector<dataset::dsize_t> t_shape; | std::vector<dataset::dsize_t> t_shape; | ||||
| t_shape.reserve(shape.size()); | t_shape.reserve(shape.size()); | ||||
| @@ -37,6 +37,14 @@ class DETensor : public MSTensor { | |||||
| /// \return - MSTensor pointer. | /// \return - MSTensor pointer. | ||||
| static MSTensor *CreateTensor(const std::string &path); | static MSTensor *CreateTensor(const std::string &path); | ||||
| /// \brief Create a MSTensor pointer. | |||||
| /// \note This function returns null_ptr if tensor creation fails. | |||||
| /// \param[data_type] DataTypeId of tensor to be created. | |||||
| /// \param[shape] Shape of tensor to be created. | |||||
| /// \param[data] Data pointer. | |||||
| /// \return - MSTensor pointer. | |||||
| static MSTensor *CreateFromMemory(TypeId data_type, const std::vector<int> &shape, void *data); | |||||
| DETensor(TypeId data_type, const std::vector<int> &shape); | DETensor(TypeId data_type, const std::vector<int> &shape); | ||||
| explicit DETensor(std::shared_ptr<dataset::Tensor> tensor_ptr); | explicit DETensor(std::shared_ptr<dataset::Tensor> tensor_ptr); | ||||
| @@ -96,3 +96,13 @@ TEST_F(MindDataTestTensorDE, MSTensorHash) { | |||||
| auto ms_tensor = std::shared_ptr<MSTensor>(new DETensor(t)); | auto ms_tensor = std::shared_ptr<MSTensor>(new DETensor(t)); | ||||
| ASSERT_EQ(ms_tensor->hash() == 11093771382437, true); | ASSERT_EQ(ms_tensor->hash() == 11093771382437, true); | ||||
| } | } | ||||
| TEST_F(MindDataTestTensorDE, MSTensorCreateFromMemory) { | |||||
| std::vector<float> x = {2.5, 2.5, 2.5, 2.5}; | |||||
| auto mem_tensor = DETensor::CreateFromMemory(mindspore::TypeId::kNumberTypeFloat32, {2, 2}, &x[0]); | |||||
| std::shared_ptr<Tensor> t; | |||||
| Tensor::CreateFromVector(x, TensorShape({2, 2}), &t); | |||||
| auto ms_tensor = std::shared_ptr<MSTensor>(new DETensor(t)); | |||||
| ASSERT_EQ(ms_tensor->hash() == mem_tensor->hash(), true); | |||||
| } | |||||