Browse Source

!2946 Added empty tensor support

Merge pull request !2946 from EricZ/emtpy-tensor
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
5355deca04
3 changed files with 27 additions and 0 deletions
  1. +9
    -0
      mindspore/ccsrc/dataset/core/tensor.cc
  2. +4
    -0
      mindspore/ccsrc/dataset/core/tensor.h
  3. +14
    -0
      tests/ut/cpp/dataset/tensor_test.cc

+ 9
- 0
mindspore/ccsrc/dataset/core/tensor.cc View File

@@ -513,6 +513,15 @@ const unsigned char *Tensor::GetBuffer() const {
return data_; return data_;
} }


// check for empty
bool Tensor::HasData() const {
if (data_ == nullptr) {
return true;
} else {
return false;
}
}

unsigned char *Tensor::GetMutableBuffer() { unsigned char *Tensor::GetMutableBuffer() {
if (!shape_.known() || type_ == DataType::DE_UNKNOWN) { if (!shape_.known() || type_ == DataType::DE_UNKNOWN) {
return nullptr; return nullptr;


+ 4
- 0
mindspore/ccsrc/dataset/core/tensor.h View File

@@ -277,6 +277,10 @@ class Tensor {
// @return // @return
const TensorShape &shape() const { return shape_; } const TensorShape &shape() const { return shape_; }


/// Check if tensor has data
/// \return bool - true if tensor is empty
bool HasData() const;

// Reshape the tensor. The given shape should have the same number of elements in the Tensor // Reshape the tensor. The given shape should have the same number of elements in the Tensor
// @param shape // @param shape
virtual Status Reshape(const TensorShape &shape); virtual Status Reshape(const TensorShape &shape);


+ 14
- 0
tests/ut/cpp/dataset/tensor_test.cc View File

@@ -432,3 +432,17 @@ TEST_F(MindDataTestTensorDE, TensorConcatenate) {
s = t1->Concatenate({5}, t2); s = t1->Concatenate({5}, t2);
EXPECT_FALSE(s.IsOk()); EXPECT_FALSE(s.IsOk());
} }

TEST_F(MindDataTestTensorDE, TensorEmpty) {
std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2, 3}), DataType(DataType::DE_UINT64));
ASSERT_TRUE(t->HasData());
}

TEST_F(MindDataTestTensorDE, TensorEmptyInvalidate) {
std::vector<uint32_t> values1 = {1, 2, 3, 0, 0, 0};
std::shared_ptr<Tensor> t;
Tensor::CreateTensor(&t, values1);
t->Invalidate();
ASSERT_TRUE(t->HasData());
}


Loading…
Cancel
Save