| @@ -93,7 +93,7 @@ convert_object_map = { | |||||
| T.xor: NO_IMPLEMENT, | T.xor: NO_IMPLEMENT, | ||||
| T.pos: multitype_ops.uadd, | T.pos: multitype_ops.uadd, | ||||
| T.neg: multitype_ops.negative, | T.neg: multitype_ops.negative, | ||||
| T.invert: NO_IMPLEMENT, | |||||
| T.invert: F.logical_not, | |||||
| T.not_: multitype_ops.logical_not, | T.not_: multitype_ops.logical_not, | ||||
| T.eq: multitype_ops.equal, | T.eq: multitype_ops.equal, | ||||
| T.ne: multitype_ops.not_equal, | T.ne: multitype_ops.not_equal, | ||||
| @@ -153,6 +153,10 @@ class Tensor(Tensor_): | |||||
| out = tensor_operator_registry.get('__neg__')(self) | out = tensor_operator_registry.get('__neg__')(self) | ||||
| return out | return out | ||||
| def __invert__(self): | |||||
| out = tensor_operator_registry.get('__logical_not__')(self) | |||||
| return out | |||||
| def __bool__(self): | def __bool__(self): | ||||
| data = self.asnumpy() | data = self.asnumpy() | ||||
| if data.shape == (): | if data.shape == (): | ||||
| @@ -223,6 +223,7 @@ tensor_operator_registry.register('__lt__', tensor_lt) | |||||
| tensor_operator_registry.register('__le__', tensor_le) | tensor_operator_registry.register('__le__', tensor_le) | ||||
| tensor_operator_registry.register('__gt__', tensor_gt) | tensor_operator_registry.register('__gt__', tensor_gt) | ||||
| tensor_operator_registry.register('__ge__', tensor_ge) | tensor_operator_registry.register('__ge__', tensor_ge) | ||||
| tensor_operator_registry.register('__logical_not__', logical_not) | |||||
| tensor_operator_registry.register('shape', shape) | tensor_operator_registry.register('shape', shape) | ||||
| tensor_operator_registry.register('squeeze', squeeze) | tensor_operator_registry.register('squeeze', squeeze) | ||||
| # support GE backend for no compare operators | # support GE backend for no compare operators | ||||
| @@ -3105,9 +3105,15 @@ class LogicalNot(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_dtype_valid("x", x_dtype, [mstype.bool_], self.name) | |||||
| validator.check_tensor_dtype_valid("x", x_dtype, [mstype.bool_], self.name + " or '~' operator") | |||||
| return mstype.tensor_type(mstype.bool_) | return mstype.tensor_type(mstype.bool_) | ||||
| def infer_value(self, x): | |||||
| if x is not None: | |||||
| x = x.asnumpy() | |||||
| return Tensor(np.logical_not(x)) | |||||
| return None | |||||
| class LogicalAnd(_LogicBinaryOp): | class LogicalAnd(_LogicBinaryOp): | ||||
| """ | """ | ||||
| @@ -3146,6 +3152,14 @@ class LogicalAnd(_LogicBinaryOp): | |||||
| def infer_dtype(self, x_dtype, y_dtype): | def infer_dtype(self, x_dtype, y_dtype): | ||||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.name) | return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.name) | ||||
| def infer_value(self, x, y): | |||||
| if x is not None and y is not None: | |||||
| x = x.asnumpy() | |||||
| y = y.asnumpy() | |||||
| out = np.array(np.logical_and(x, y)) | |||||
| return Tensor(out) | |||||
| return None | |||||
| class LogicalOr(_LogicBinaryOp): | class LogicalOr(_LogicBinaryOp): | ||||
| """ | """ | ||||
| @@ -3184,6 +3198,14 @@ class LogicalOr(_LogicBinaryOp): | |||||
| def infer_dtype(self, x_dtype, y_dtype): | def infer_dtype(self, x_dtype, y_dtype): | ||||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.name) | return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.name) | ||||
| def infer_value(self, x, y): | |||||
| if x is not None and y is not None: | |||||
| x = x.asnumpy() | |||||
| y = y.asnumpy() | |||||
| out = np.array(np.logical_or(x, y)) | |||||
| return Tensor(out) | |||||
| return None | |||||
| class IsNan(PrimitiveWithInfer): | class IsNan(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| @@ -0,0 +1,63 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ test '~' """ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore import context | |||||
| class InvertNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(InvertNet, self).__init__() | |||||
| self.t = Tensor(np.array([True, False, True])) | |||||
| def construct(self, x): | |||||
| invert_t = ~self.t | |||||
| invert_x = ~x | |||||
| ret = (invert_t, invert_x) | |||||
| return ret | |||||
| def test_invert_bool_tensor(): | |||||
| net = InvertNet() | |||||
| input_x = Tensor(np.array([False, True, False])) | |||||
| context.set_context(mode=context.PYNATIVE_MODE) | |||||
| ret = net(input_x) | |||||
| assert (ret[0].asnumpy() == np.array([False, True, False])).all() | |||||
| assert (ret[1].asnumpy() == np.array([True, False, True])).all() | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| net(input_x) | |||||
| def test_invert_int_tensor(): | |||||
| net = InvertNet() | |||||
| input_x = Tensor(np.array([1, 2, 3], np.int32)) | |||||
| context.set_context(mode=context.PYNATIVE_MODE) | |||||
| with pytest.raises(TypeError) as err: | |||||
| net(input_x) | |||||
| assert "For 'LogicalNot or '~' operator', the type of `x` should be subclass of Tensor[Bool], " \ | |||||
| "but got Tensor[Int32]" in str(err.value) | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| with pytest.raises(TypeError) as err: | |||||
| net(input_x) | |||||
| assert "For 'LogicalNot or '~' operator', the type of `x` should be subclass of Tensor[Bool], " \ | |||||
| "but got Tensor[Int32]" in str(err.value) | |||||