| @@ -185,14 +185,6 @@ bool Tensor::operator==(const Tensor &tensor) const { | |||||
| return (MetaTensor::operator==(tensor) && data_ == tensor.data_); | return (MetaTensor::operator==(tensor) && data_ == tensor.data_); | ||||
| } | } | ||||
| bool Tensor::ValueEqualPy(const py::object &other) const { | |||||
| if (!py::isinstance<Tensor>(other)) { | |||||
| MS_LOG(WARNING) << "compare other not a tensor"; | |||||
| return false; | |||||
| } | |||||
| return ValueEqual(py::cast<Tensor>(other)); | |||||
| } | |||||
| bool Tensor::ValueEqual(const Tensor &other) const { | bool Tensor::ValueEqual(const Tensor &other) const { | ||||
| auto equal = [&other, this]() -> bool { | auto equal = [&other, this]() -> bool { | ||||
| auto np = py::module::import("numpy"); | auto np = py::module::import("numpy"); | ||||
| @@ -542,7 +534,6 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { | |||||
| )mydelimiter") | )mydelimiter") | ||||
| .def("__str__", &Tensor::ToString) | .def("__str__", &Tensor::ToString) | ||||
| .def("__repr__", &Tensor::ToStringRepr) | .def("__repr__", &Tensor::ToStringRepr) | ||||
| .def("__eq__", &Tensor::ValueEqualPy) | |||||
| .def(py::pickle( | .def(py::pickle( | ||||
| [](const Tensor &t) { // __getstate__ | [](const Tensor &t) { // __getstate__ | ||||
| /* Return a tuple that fully encodes the state of the object */ | /* Return a tuple that fully encodes the state of the object */ | ||||
| @@ -329,9 +329,6 @@ class Tensor : public MetaTensor { | |||||
| // It is different from 'operator==' which just compare shape/type/address, it do real value comparison. | // It is different from 'operator==' which just compare shape/type/address, it do real value comparison. | ||||
| bool ValueEqual(const Tensor &other) const; | bool ValueEqual(const Tensor &other) const; | ||||
| // It is different from 'operator==' which just compare shape/type/address, it do real value comparison. | |||||
| bool ValueEqualPy(const py::object &other) const; | |||||
| bool operator==(const Value &other) const override { | bool operator==(const Value &other) const override { | ||||
| if (other.isa<Tensor>()) { | if (other.isa<Tensor>()) { | ||||
| auto other_ = static_cast<const Tensor &>(other); | auto other_ = static_cast<const Tensor &>(other); | ||||
| @@ -74,6 +74,17 @@ class Tensor(Tensor_): | |||||
| out = tensor_operator_registry.get('__add__')(self, other) | out = tensor_operator_registry.get('__add__')(self, other) | ||||
| return out | return out | ||||
| def __eq__(self, other): | |||||
| if not isinstance(other, Tensor): | |||||
| return False | |||||
| x = self.asnumpy() | |||||
| y = other.asnumpy() | |||||
| out = np.equal(x, y) | |||||
| return Tensor(np.array(out)) | |||||
| def __hash__(self): | |||||
| return hash(id(self)) | |||||
| def __mul__(self, other): | def __mul__(self, other): | ||||
| check_type('tensor input_data', other, (Tensor, float, int)) | check_type('tensor input_data', other, (Tensor, float, int)) | ||||
| out = tensor_operator_registry.get('__mul__')(self, other) | out = tensor_operator_registry.get('__mul__')(self, other) | ||||
| @@ -144,3 +144,5 @@ stop_gradient = Primitive("stop_gradient") | |||||
| tensor_operator_registry.register('__add__', tensor_add) | tensor_operator_registry.register('__add__', tensor_add) | ||||
| tensor_operator_registry.register('__mul__', tensor_mul) | tensor_operator_registry.register('__mul__', tensor_mul) | ||||
| tensor_operator_registry.register('__div__', tensor_div) | tensor_operator_registry.register('__div__', tensor_div) | ||||
| #ms cannot support Tensor(True) compare | |||||
| tensor_operator_registry.register('__eq__', equal) | |||||
| @@ -172,7 +172,7 @@ def vm_impl_equal(self): | |||||
| x = x.asnumpy() | x = x.asnumpy() | ||||
| y = y.asnumpy() | y = y.asnumpy() | ||||
| out = vm.equal(x, y) | out = vm.equal(x, y) | ||||
| return Tensor(out) | |||||
| return Tensor(np.array(out)) | |||||
| return vm_impl | return vm_impl | ||||
| @@ -183,7 +183,7 @@ def vm_impl_not_equal(self): | |||||
| x = x.asnumpy() | x = x.asnumpy() | ||||
| y = y.asnumpy() | y = y.asnumpy() | ||||
| out = vm.not_equal(x, y) | out = vm.not_equal(x, y) | ||||
| return Tensor(out) | |||||
| return Tensor(np.array(out)) | |||||
| return vm_impl | return vm_impl | ||||
| @@ -194,7 +194,7 @@ def vm_impl_greater(self): | |||||
| x = x.asnumpy() | x = x.asnumpy() | ||||
| y = y.asnumpy() | y = y.asnumpy() | ||||
| out = vm.greater(x, y) | out = vm.greater(x, y) | ||||
| return Tensor(out) | |||||
| return Tensor(np.array(out)) | |||||
| return vm_impl | return vm_impl | ||||
| @vm_impl_getters.register(P.Maximum) | @vm_impl_getters.register(P.Maximum) | ||||
| @@ -219,17 +219,17 @@ def vm_impl_minimum(self): | |||||
| return vm_impl | return vm_impl | ||||
| @vm_impl_getters.register(P.Less) | @vm_impl_getters.register(P.Less) | ||||
| def vm_impl_greater(self): | |||||
| def vm_impl_less(self): | |||||
| """Generate vm_impl function for Less""" | """Generate vm_impl function for Less""" | ||||
| def vm_impl(x, y): | def vm_impl(x, y): | ||||
| x = x.asnumpy() | x = x.asnumpy() | ||||
| y = y.asnumpy() | y = y.asnumpy() | ||||
| out = vm.less(x, y) | out = vm.less(x, y) | ||||
| return Tensor(out) | |||||
| return Tensor(np.array(out)) | |||||
| return vm_impl | return vm_impl | ||||
| @vm_impl_getters.register(P.ScalarCast) | @vm_impl_getters.register(P.ScalarCast) | ||||
| def vm_impl_greater(self): | |||||
| def vm_impl_scalar_cast(self): | |||||
| """Generate vm_impl function for ScalarCast""" | """Generate vm_impl function for ScalarCast""" | ||||
| def vm_impl(x, t): | def vm_impl(x, t): | ||||
| np_type = dtype_to_nptype(t) | np_type = dtype_to_nptype(t) | ||||