Browse Source

change tensor equal bug

tags/v0.3.0-alpha
kpy 5 years ago
parent
commit
e64c755ad6
5 changed files with 19 additions and 18 deletions
  1. +0
    -9
      mindspore/ccsrc/ir/meta_tensor.cc
  2. +0
    -3
      mindspore/ccsrc/ir/meta_tensor.h
  3. +11
    -0
      mindspore/common/tensor.py
  4. +2
    -0
      mindspore/ops/functional.py
  5. +6
    -6
      tests/vm_impl/math_ops_vm_impl.py

+ 0
- 9
mindspore/ccsrc/ir/meta_tensor.cc View File

@@ -185,14 +185,6 @@ bool Tensor::operator==(const Tensor &tensor) const {
return (MetaTensor::operator==(tensor) && data_ == tensor.data_);
}

bool Tensor::ValueEqualPy(const py::object &other) const {
if (!py::isinstance<Tensor>(other)) {
MS_LOG(WARNING) << "compare other not a tensor";
return false;
}
return ValueEqual(py::cast<Tensor>(other));
}

bool Tensor::ValueEqual(const Tensor &other) const {
auto equal = [&other, this]() -> bool {
auto np = py::module::import("numpy");
@@ -542,7 +534,6 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
)mydelimiter")
.def("__str__", &Tensor::ToString)
.def("__repr__", &Tensor::ToStringRepr)
.def("__eq__", &Tensor::ValueEqualPy)
.def(py::pickle(
[](const Tensor &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */


+ 0
- 3
mindspore/ccsrc/ir/meta_tensor.h View File

@@ -329,9 +329,6 @@ class Tensor : public MetaTensor {
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
bool ValueEqual(const Tensor &other) const;

// It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
bool ValueEqualPy(const py::object &other) const;

bool operator==(const Value &other) const override {
if (other.isa<Tensor>()) {
auto other_ = static_cast<const Tensor &>(other);


+ 11
- 0
mindspore/common/tensor.py View File

@@ -74,6 +74,17 @@ class Tensor(Tensor_):
out = tensor_operator_registry.get('__add__')(self, other)
return out

def __eq__(self, other):
if not isinstance(other, Tensor):
return False
x = self.asnumpy()
y = other.asnumpy()
out = np.equal(x, y)
return Tensor(np.array(out))

def __hash__(self):
return hash(id(self))

def __mul__(self, other):
check_type('tensor input_data', other, (Tensor, float, int))
out = tensor_operator_registry.get('__mul__')(self, other)


+ 2
- 0
mindspore/ops/functional.py View File

@@ -144,3 +144,5 @@ stop_gradient = Primitive("stop_gradient")
tensor_operator_registry.register('__add__', tensor_add)
tensor_operator_registry.register('__mul__', tensor_mul)
tensor_operator_registry.register('__div__', tensor_div)
#ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)

+ 6
- 6
tests/vm_impl/math_ops_vm_impl.py View File

@@ -172,7 +172,7 @@ def vm_impl_equal(self):
x = x.asnumpy()
y = y.asnumpy()
out = vm.equal(x, y)
return Tensor(out)
return Tensor(np.array(out))
return vm_impl


@@ -183,7 +183,7 @@ def vm_impl_not_equal(self):
x = x.asnumpy()
y = y.asnumpy()
out = vm.not_equal(x, y)
return Tensor(out)
return Tensor(np.array(out))
return vm_impl


@@ -194,7 +194,7 @@ def vm_impl_greater(self):
x = x.asnumpy()
y = y.asnumpy()
out = vm.greater(x, y)
return Tensor(out)
return Tensor(np.array(out))
return vm_impl

@vm_impl_getters.register(P.Maximum)
@@ -219,17 +219,17 @@ def vm_impl_minimum(self):
return vm_impl

@vm_impl_getters.register(P.Less)
def vm_impl_greater(self):
def vm_impl_less(self):
"""Generate vm_impl function for Less"""
def vm_impl(x, y):
x = x.asnumpy()
y = y.asnumpy()
out = vm.less(x, y)
return Tensor(out)
return Tensor(np.array(out))
return vm_impl

@vm_impl_getters.register(P.ScalarCast)
def vm_impl_greater(self):
def vm_impl_scalar_cast(self):
"""Generate vm_impl function for ScalarCast"""
def vm_impl(x, t):
np_type = dtype_to_nptype(t)


Loading…
Cancel
Save