Browse Source

!4630 Adding wrapper around CreateFromMemory

Merge pull request !4630 from EricZ/md_tensor_from_mem
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
cc23f1d819
3 changed files with 31 additions and 0 deletions
  1. +13
    -0
      mindspore/ccsrc/minddata/dataset/api/de_tensor.cc
  2. +8
    -0
      mindspore/ccsrc/minddata/dataset/include/de_tensor.h
  3. +10
    -0
      mindspore/lite/test/ut/src/dataset/de_tensor_test.cc

+ 13
- 0
mindspore/ccsrc/minddata/dataset/api/de_tensor.cc View File

@@ -95,6 +95,19 @@ MSTensor *DETensor::CreateTensor(const std::string &path) {
return new DETensor(std::move(t)); return new DETensor(std::move(t));
} }


MSTensor *DETensor::CreateFromMemory(TypeId data_type, const std::vector<int> &shape, void *data) {
std::shared_ptr<dataset::Tensor> t;
// prepare shape info
std::vector<dataset::dsize_t> t_shape;

std::transform(shape.begin(), shape.end(), std::back_inserter(t_shape),
[](int s) -> dataset::dsize_t { return static_cast<dataset::dsize_t>(s); });

(void)dataset::Tensor::CreateFromMemory(dataset::TensorShape(t_shape), MSTypeToDEType(data_type),
static_cast<uchar *>(data), &t);
return new DETensor(std::move(t));
}

DETensor::DETensor(TypeId data_type, const std::vector<int> &shape) { DETensor::DETensor(TypeId data_type, const std::vector<int> &shape) {
std::vector<dataset::dsize_t> t_shape; std::vector<dataset::dsize_t> t_shape;
t_shape.reserve(shape.size()); t_shape.reserve(shape.size());


+ 8
- 0
mindspore/ccsrc/minddata/dataset/include/de_tensor.h View File

@@ -37,6 +37,14 @@ class DETensor : public MSTensor {
/// \return - MSTensor pointer. /// \return - MSTensor pointer.
static MSTensor *CreateTensor(const std::string &path); static MSTensor *CreateTensor(const std::string &path);


/// \brief Create a MSTensor pointer.
/// \note This function returns null_ptr if tensor creation fails.
/// \param[data_type] DataTypeId of tensor to be created.
/// \param[shape] Shape of tensor to be created.
/// \param[data] Data pointer.
/// \return - MSTensor pointer.
static MSTensor *CreateFromMemory(TypeId data_type, const std::vector<int> &shape, void *data);

DETensor(TypeId data_type, const std::vector<int> &shape); DETensor(TypeId data_type, const std::vector<int> &shape);


explicit DETensor(std::shared_ptr<dataset::Tensor> tensor_ptr); explicit DETensor(std::shared_ptr<dataset::Tensor> tensor_ptr);


+ 10
- 0
mindspore/lite/test/ut/src/dataset/de_tensor_test.cc View File

@@ -96,3 +96,13 @@ TEST_F(MindDataTestTensorDE, MSTensorHash) {
auto ms_tensor = std::shared_ptr<MSTensor>(new DETensor(t)); auto ms_tensor = std::shared_ptr<MSTensor>(new DETensor(t));
ASSERT_EQ(ms_tensor->hash() == 11093771382437, true); ASSERT_EQ(ms_tensor->hash() == 11093771382437, true);
} }

TEST_F(MindDataTestTensorDE, MSTensorCreateFromMemory) {
std::vector<float> x = {2.5, 2.5, 2.5, 2.5};
auto mem_tensor = DETensor::CreateFromMemory(mindspore::TypeId::kNumberTypeFloat32, {2, 2}, &x[0]);
std::shared_ptr<Tensor> t;
Tensor::CreateFromVector(x, TensorShape({2, 2}), &t);
auto ms_tensor = std::shared_ptr<MSTensor>(new DETensor(t));
ASSERT_EQ(ms_tensor->hash() == mem_tensor->hash(), true);
}


Loading…
Cancel
Save