Merge pull request !4688 from amongo/FixBoolParse
@@ -737,8 +737,7 @@ AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object
   return block->func_graph()->NewCNode({op_node, left_node, right_node});
 }
 
-AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list,
-                                          const py::object &op) {
+AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode) {
   // if there is only one bool op now
   if (value_list.size() == 1) {
     AnfNodePtr first_node = ParseExprNode(block, value_list[0]);
@@ -749,11 +748,41 @@ AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const p
     for (size_t i = 1; i < value_list.size(); i++) {
       rest.append(value_list[i]);
     }
-    AnfNodePtr first_node = ParseExprNode(block, first);
-    AnfNodePtr rest_node = ProcessBoolOpValueList(block, rest, op);
-    auto op_node = block->MakeResolveAstOp(op);
-    return block->func_graph()->NewCNode({op_node, first_node, rest_node});
+    MS_EXCEPTION_IF_NULL(block);
+    TraceManager::DebugTrace(std::make_shared<TraceIfExpTrueBranch>(block->func_graph()->debug_info()));
+    FunctionBlockPtr true_block = MakeFunctionBlock(*this);
+    TraceManager::EndTrace();
+    TraceManager::DebugTrace(std::make_shared<TraceIfExpFalseBranch>(block->func_graph()->debug_info()));
+    FunctionBlockPtr false_block = MakeFunctionBlock(*this);
+    TraceManager::EndTrace();
+    MakeConditionBlocks(block, true_block, false_block);
+    FunctionBlockPtr b1, b2;
+    // For AND, the rest of the list is evaluated in the true branch;
+    // for OR, it is evaluated in the false branch.
+    if (mode == AST_SUB_TYPE_AND) {
+      b1 = true_block;
+      b2 = false_block;
+    } else if (mode == AST_SUB_TYPE_OR) {
+      b2 = true_block;
+      b1 = false_block;
+    } else {
+      MS_LOG(ERROR) << "Unsupported mode: " << mode;
+      return nullptr;
+    }
+    AnfNodePtr test_node = ParseExprNode(block, first);
+    AnfNodePtr rest_node = ProcessBoolOpValueList(b1, rest, mode);
+    b1->func_graph()->set_output(rest_node);
+    b2->func_graph()->set_output(test_node);
+    auto cond_node = block->ForceToBoolNode(test_node);
+    auto switch_app =
+      block->func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), cond_node, NewValueNode(true_block->func_graph()),
+                                     NewValueNode(false_block->func_graph())});
+    std::vector<AnfNodePtr> call_graph_nodes{switch_app};
+    auto switch_app_call = block->func_graph()->NewCNode(call_graph_nodes);
+    return switch_app_call;
   }
 }
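Review note: the new lowering replaces the resolved `BoolOp` primitive call with a graph-level switch that implements Python's short-circuit semantics, returning the operand itself (not a coerced boolean) on the short-circuit path. A minimal plain-Python sketch of the semantics being encoded; the names here are illustrative, not MindSpore API:

```python
# `rest` is a thunk so the untaken branch is never evaluated, mirroring
# the separate true/false FunctionBlocks built above.
def bool_and(test, rest):
    return rest() if test else test   # AND: a falsy operand passes through

def bool_or(test, rest):
    return test if test else rest()   # OR: a truthy operand passes through

assert bool_and(0.0, lambda: 5.0) == 0.0   # short-circuits, rest untouched
assert bool_and(2.0, lambda: 5.0) == 5.0
assert bool_or(0.0, lambda: 5.0) == 5.0
```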
@@ -761,8 +790,13 @@ AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const p
 AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &node) {
   MS_LOG(DEBUG) << "Process ast BoolOp";
   py::object op_node = python_adapter::GetPyObjAttr(node, "op");
+  AstSubType op_type = ast_->GetOpType(op_node);
+  if (op_type == AST_SUB_TYPE_UNKNOWN) {
+    MS_LOG(WARNING) << "ProcessBoolOp, got unknown op type";
+    return nullptr;
+  }
   py::list op_values = python_adapter::GetPyObjAttr(node, "values");
-  return ProcessBoolOpValueList(block, op_values, op_node);
+  return ProcessBoolOpValueList(block, op_values, op_type);
 }
 
 // Process a function def
@@ -206,7 +206,7 @@ class Parser {
   void HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
   // process a bool operation value list
-  AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, const py::object &op);
+  AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode);
   CNodePtr GenerateIteratorInFor(const FunctionBlockPtr &block, const pybind11::object &node,
                                  const AnfNodePtr &op_iter);
@@ -45,7 +45,9 @@ def _logical_not_tensor(x):
     Returns:
         Tensor, Return logical not operation result of x.
     """
-    return F.logical_not(x)
+    if F.isconstant(x):
+        return F.bool_not(x.__bool__())
+    return F.logical_not(x.__bool__())
 
 @logical_not.register("Tuple")
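Review note: `_logical_not_tensor` now splits on `F.isconstant`, folding a compile-time constant with a plain boolean negation while a runtime tensor still goes through the `LogicalNot` kernel. A hedged NumPy stand-in for that dispatch; none of these names are the real MindSpore internals:

```python
import numpy as np

def logical_not_semantics(x):
    # Stand-in for F.isconstant(x): the value is known at graph-compile time.
    if np.isscalar(x):
        return not bool(x)        # folded, like F.bool_not(x.__bool__())
    return np.logical_not(x)      # runtime path, like F.logical_not(...)

assert logical_not_semantics(0.0) is True
assert logical_not_semantics(np.array([0.0, 1.0])).tolist() == [True, False]
```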
@@ -61,8 +61,7 @@ class ControlSimpleIfWithAssign(nn.Cell):
 class ControlIfinIf(nn.Cell):
-    def __init__(self):
-        super().__init__()
+    """pass"""
 
     def construct(self, x, y):
         if x > y:
@@ -151,6 +150,40 @@ class ControlMixedWhileIf(nn.Cell):
         return out
 
 
+class AndOperation(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.reduce_sum = op.ReduceSum()
+
+    def construct(self, x, y):
+        x_sum = self.reduce_sum(x)
+        y_sum = self.reduce_sum(y)
+        out = x_sum and y_sum
+        return out
+
+
+class OrOperation(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.reduce_sum = op.ReduceSum()
+
+    def construct(self, x, y):
+        x_sum = self.reduce_sum(x)
+        y_sum = self.reduce_sum(y)
+        out = x_sum or y_sum
+        return out
+
+
+class NotOperation(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.reduce_sum = op.ReduceSum()
+
+    def construct(self, x):
+        x_sum = self.reduce_sum(x)
+        return not x_sum
+
+
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
@@ -248,3 +281,27 @@ def test_mixed_while_if():
     output = net(Tensor(x), Tensor(y), Tensor(z), c2, c4)
     expect = np.array(3318).astype(np.int32)
     assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_and_or_operation():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    x = np.array([0, 1]).astype(np.float32)
+    y = np.array([0, 0]).astype(np.float32)
+
+    net = AndOperation()
+    output = net(Tensor(x), Tensor(y))
+    expect = np.sum(x) and np.sum(y)
+    assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
+
+    net = OrOperation()
+    output = net(Tensor(x), Tensor(y))
+    expect = np.sum(x) or np.sum(y)
+    assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
+
+    net = NotOperation()
+    output = net(Tensor(x))
+    expect = not np.sum(x)
+    assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
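Review note on the `expect` values: Python's `and`/`or` return one of their operands rather than a strict bool, and the switch lowering reproduces exactly that, so `np.sum(x) and np.sum(y)` is the right reference computation:

```python
# With x = [0, 1] and y = [0, 0]: sum(x) = 1.0, sum(y) = 0.0.
assert (1.0 and 0.0) == 0.0   # truthy left defers to the right operand
assert (1.0 or 0.0) == 1.0    # truthy left short-circuits `or`
assert (not 1.0) is False     # `not` always yields a bool
```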
@@ -103,15 +103,15 @@ class LogicalTensorOpsNet(nn.Cell):
         self.const_true = Tensor(True, dtype=mstype.bool_)
 
     def construct(self, x, y):
-        ret = x and y and (y or self.const_true) and (not self.const_true)
+        ret = x and y and (y or self.const_true) and (not y)
         return ret
 
 
 test_case_ops = [
     ('CompareOpsNet', {
         'block': ComparisonOpsNet(),
-        'desc_inputs': [Tensor(np.ones([6, 9, 10]), dtype=mstype.float32),
-                        Tensor(np.zeros([6, 9, 10]), dtype=mstype.float32)]}),
+        'desc_inputs': [Tensor(1.0, dtype=mstype.float32),
+                        Tensor(1.0, dtype=mstype.float32)]}),
     ('MathOpsNet', {
         'block': MathOpsNet(),
         'desc_inputs': [Tensor(np.ones([6, 9, 10]), dtype=mstype.float32),
@@ -126,8 +126,8 @@ test_case_ops = [
                         Tensor(np.zeros([6, 9, 10]), dtype=mstype.float32)]}),
     ('LogicalTensorOps', {
         'block': LogicalTensorOpsNet(),
-        'desc_inputs': [Tensor(np.ones([6, 9, 10]).astype(np.bool_), dtype=mstype.bool_),
-                        Tensor(np.zeros([6, 9, 10]).astype(np.bool_), dtype=mstype.bool_)]}),
+        'desc_inputs': [Tensor(True, dtype=mstype.bool_),
+                        Tensor(False, dtype=mstype.bool_)]}),
 ]
 
 test_case_lists = [test_case_ops]
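Review note: the test inputs shrink from `[6, 9, 10]` arrays to scalar tensors, presumably because the short-circuit lowering routes operands through `__bool__`, and, as in NumPy, the truth value of a multi-element array is ambiguous:

```python
import numpy as np

try:
    bool(np.ones([6, 9, 10]))    # multi-element: no single truth value
except ValueError as err:
    print(err)                   # "The truth value of an array ... is ambiguous"

print(bool(np.float32(1.0)))     # a scalar converts cleanly: True
```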
@@ -41,10 +41,12 @@ def vm_impl_tensor_add(self):
 # pylint: disable=used-before-assignment
 @vm_impl_getters.register(P.LogicalNot)
 def vm_impl_logical_not(self):
-    x = x.asnumpy()
-    out = vm.logical_not(x)
-    return Tensor(out)
+    def vm_impl(x):
+        x = x.asnumpy()
+        out = vm.logical_not(x)
+        return Tensor(out)
+    return vm_impl
 
 @vm_impl_getters.register(P.MatMul)
 def vm_impl_mat_mul(self):
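Review note: this fix restores the registry contract the other getters follow (the old body ran on an undefined `x`, hence the `used-before-assignment` pylint suppression): the registered function receives the primitive instance and must return the closure the VM later invokes with tensors. A hedged sketch of that pattern; `my_registry` and the `"Neg"` primitive are hypothetical, only the shape matters:

```python
my_registry = {}

def register(key):
    def deco(getter):
        my_registry[key] = getter   # store the getter, not its result
        return getter
    return deco

@register("Neg")
def make_neg_impl(prim):            # called once with the primitive instance
    def vm_impl(x):                 # called later with the actual data
        return -x
    return vm_impl

assert my_registry["Neg"](object())(3) == -3
```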