Browse Source

!4072 Fix a bug in Tensor::equals()

Merge pull request !4072 from hewei/fixbug_tensor_equal
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
3b7df4e512
2 changed files with 30 additions and 8 deletions
  1. +9
    -8
      mindspore/core/ir/tensor.cc
  2. +21
    -0
      tests/ut/cpp/ir/meta_tensor_test.cc

+ 9
- 8
mindspore/core/ir/tensor.cc View File

@@ -176,13 +176,8 @@ class TensorDataImpl : public TensorData {
ssize_t ndim() const override { return static_cast<ssize_t>(ndim_); }

void *data() override {
static T empty_data = static_cast<T>(0);
if (data_size_ == 0) {
// Prevent null pointer for empty shape.
return &empty_data;
}
// Lazy allocation.
if (data_ == nullptr) {
// Lazy allocation.
data_ = std::make_unique<T[]>(data_size_);
}
return data_.get();
@@ -193,8 +188,14 @@ class TensorDataImpl : public TensorData {
if (ptr == nullptr) {
return false;
}
return (ptr == this) || ((ndim_ == ptr->ndim_) && (data_size_ == ptr->data_size_) &&
(std::equal(data_.get(), data_.get() + data_size_, ptr->data_.get())));
if (ptr == this) {
return true;
}
if (data_ == nullptr || ptr->data_ == nullptr) {
return false;
}
return (ndim_ == ptr->ndim_) && (data_size_ == ptr->data_size_) &&
std::equal(data_.get(), data_.get() + data_size_, ptr->data_.get());
}

std::string ToString(const TypeId type, const std::vector<int> &shape) const override {


+ 21
- 0
tests/ut/cpp/ir/meta_tensor_test.cc View File

@@ -225,6 +225,27 @@ TEST_F(TestTensor, EqualTest) {
ASSERT_EQ(TypeId::kNumberTypeFloat64, tensor_float64->data_type_c());
}

TEST_F(TestTensor, ValueEqualTest) {
py::tuple tuple = py::make_tuple(1, 2, 3, 4, 5, 6);
TensorPtr t1 = TensorPy::MakeTensor(py::array(tuple), kInt32);
TensorPtr t2 = TensorPy::MakeTensor(py::array(tuple), kInt32);
ASSERT_TRUE(t1->ValueEqual(*t1));
ASSERT_TRUE(t1->ValueEqual(*t2));

std::vector<int> shape = {6};
TensorPtr t3 = std::make_shared<Tensor>(kInt32->type_id(), shape);
TensorPtr t4 = std::make_shared<Tensor>(kInt32->type_id(), shape);
ASSERT_TRUE(t3->ValueEqual(*t3));
ASSERT_FALSE(t3->ValueEqual(*t4));
ASSERT_FALSE(t3->ValueEqual(*t1));
ASSERT_FALSE(t1->ValueEqual(*t3));

memcpy_s(t3->data_c(), t3->data().nbytes(), t1->data_c(), t1->data().nbytes());
ASSERT_TRUE(t1->ValueEqual(*t3));
ASSERT_FALSE(t3->ValueEqual(*t4));
ASSERT_FALSE(t4->ValueEqual(*t3));
}

TEST_F(TestTensor, PyArrayTest) {
py::array_t<float, py::array::c_style> input({2, 3});
auto array = input.mutable_unchecked();


Loading…
Cancel
Save