Merge pull request !4688 from amongo/FixBoolParsetags/v0.7.0-beta
| @@ -737,8 +737,7 @@ AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object | |||
| return block->func_graph()->NewCNode({op_node, left_node, right_node}); | |||
| } | |||
| AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, | |||
| const py::object &op) { | |||
| AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode) { | |||
| // if there is only one bool op now | |||
| if (value_list.size() == 1) { | |||
| AnfNodePtr first_node = ParseExprNode(block, value_list[0]); | |||
| @@ -749,11 +748,41 @@ AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const p | |||
| for (size_t i = 1; i < value_list.size(); i++) { | |||
| rest.append(value_list[i]); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(block); | |||
| TraceManager::DebugTrace(std::make_shared<TraceIfExpTrueBranch>(block->func_graph()->debug_info())); | |||
| FunctionBlockPtr true_block = MakeFunctionBlock(*this); | |||
| TraceManager::EndTrace(); | |||
| TraceManager::DebugTrace(std::make_shared<TraceIfExpFalseBranch>(block->func_graph()->debug_info())); | |||
| FunctionBlockPtr false_block = MakeFunctionBlock(*this); | |||
| TraceManager::EndTrace(); | |||
| MakeConditionBlocks(block, true_block, false_block); | |||
| FunctionBlockPtr b1, b2; | |||
| // if it is and, we need to process the rest nodes; | |||
| // if it is or, we continue to next | |||
| if (mode == AST_SUB_TYPE_AND) { | |||
| b1 = true_block; | |||
| b2 = false_block; | |||
| } else if (mode == AST_SUB_TYPE_OR) { | |||
| b2 = true_block; | |||
| b1 = false_block; | |||
| } else { | |||
| MS_LOG(ERROR) << "Not supported mode: " << mode; | |||
| return nullptr; | |||
| } | |||
| AnfNodePtr test_node = ParseExprNode(block, first); | |||
| AnfNodePtr rest_node = ProcessBoolOpValueList(b1, rest, mode); | |||
| b1->func_graph()->set_output(rest_node); | |||
| b2->func_graph()->set_output(test_node); | |||
| auto cond_node = block->ForceToBoolNode(test_node); | |||
| auto switch_app = | |||
| block->func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), cond_node, NewValueNode(true_block->func_graph()), | |||
| NewValueNode(false_block->func_graph())}); | |||
| AnfNodePtr first_node = ParseExprNode(block, first); | |||
| AnfNodePtr rest_node = ProcessBoolOpValueList(block, rest, op); | |||
| auto op_node = block->MakeResolveAstOp(op); | |||
| return block->func_graph()->NewCNode({op_node, first_node, rest_node}); | |||
| std::vector<AnfNodePtr> call_graph_nodes{switch_app}; | |||
| auto switch_app_call = block->func_graph()->NewCNode(call_graph_nodes); | |||
| return switch_app_call; | |||
| } | |||
| } | |||
| @@ -761,8 +790,13 @@ AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const p | |||
| AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &node) { | |||
| MS_LOG(DEBUG) << "Process ast BoolOp"; | |||
| py::object op_node = python_adapter::GetPyObjAttr(node, "op"); | |||
| AstSubType op_type = ast_->GetOpType(op_node); | |||
| if (op_type == AST_SUB_TYPE_UNKNOWN) { | |||
| MS_LOG(WARNING) << "ProcessBoolOp, got unkown op type"; | |||
| return nullptr; | |||
| } | |||
| py::list op_values = python_adapter::GetPyObjAttr(node, "values"); | |||
| return ProcessBoolOpValueList(block, op_values, op_node); | |||
| return ProcessBoolOpValueList(block, op_values, op_type); | |||
| } | |||
| // Process a function def | |||
| @@ -206,7 +206,7 @@ class Parser { | |||
| void HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node); | |||
| // process a bool operation value list | |||
| AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, const py::object &op); | |||
| AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode); | |||
| CNodePtr GenerateIteratorInFor(const FunctionBlockPtr &block, const pybind11::object &node, | |||
| const AnfNodePtr &op_iter); | |||
| @@ -45,7 +45,9 @@ def _logical_not_tensor(x): | |||
| Returns: | |||
| Tensor, Return logical not operation result of x. | |||
| """ | |||
| return F.logical_not(x) | |||
| if F.isconstant(x): | |||
| return F.bool_not(x.__bool__()) | |||
| return F.logical_not(x.__bool__()) | |||
| @logical_not.register("Tuple") | |||
| @@ -61,8 +61,7 @@ class ControlSimpleIfWithAssign(nn.Cell): | |||
| class ControlIfinIf(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| """pass""" | |||
| def construct(self, x, y): | |||
| if x > y: | |||
| @@ -151,6 +150,40 @@ class ControlMixedWhileIf(nn.Cell): | |||
| return out | |||
| class AndOperation(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.reduce_sum = op.ReduceSum() | |||
| def construct(self, x, y): | |||
| x_sum = self.reduce_sum(x) | |||
| y_sum = self.reduce_sum(y) | |||
| out = x_sum and y_sum | |||
| return out | |||
| class OrOperation(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.reduce_sum = op.ReduceSum() | |||
| def construct(self, x, y): | |||
| x_sum = self.reduce_sum(x) | |||
| y_sum = self.reduce_sum(y) | |||
| out = x_sum or y_sum | |||
| return out | |||
| class NotOperation(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.reduce_sum = op.ReduceSum() | |||
| def construct(self, x): | |||
| x_sum = self.reduce_sum(x) | |||
| return not x_sum | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @@ -248,3 +281,27 @@ def test_mixed_while_if(): | |||
| output = net(Tensor(x), Tensor(y), Tensor(z), c2, c4) | |||
| expect = np.array(3318).astype(np.int32) | |||
| assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.env_onecard | |||
| def test_and_or_operation(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| x = np.array([0, 1]).astype(np.float32) | |||
| y = np.array([0, 0]).astype(np.float32) | |||
| net = AndOperation() | |||
| output = net(Tensor(x), Tensor(y)) | |||
| expect = np.sum(x) and np.sum(y) | |||
| assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) | |||
| net = OrOperation() | |||
| output = net(Tensor(x), Tensor(y)) | |||
| expect = np.sum(x) or np.sum(y) | |||
| assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) | |||
| net = NotOperation() | |||
| output = net(Tensor(x)) | |||
| expect = not np.sum(x) | |||
| assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) | |||
| @@ -103,15 +103,15 @@ class LogicalTensorOpsNet(nn.Cell): | |||
| self.const_true = Tensor(True, dtype=mstype.bool_) | |||
| def construct(self, x, y): | |||
| ret = x and y and (y or self.const_true) and (not self.const_true) | |||
| ret = x and y and (y or self.const_true) and (not y) | |||
| return ret | |||
| test_case_ops = [ | |||
| ('CompareOpsNet', { | |||
| 'block': ComparisonOpsNet(), | |||
| 'desc_inputs': [Tensor(np.ones([6, 9, 10]), dtype=mstype.float32), | |||
| Tensor(np.zeros([6, 9, 10]), dtype=mstype.float32)]}), | |||
| 'desc_inputs': [Tensor(1.0, dtype=mstype.float32), | |||
| Tensor(1.0, dtype=mstype.float32)]}), | |||
| ('MathOpsNet', { | |||
| 'block': MathOpsNet(), | |||
| 'desc_inputs': [Tensor(np.ones([6, 9, 10]), dtype=mstype.float32), | |||
| @@ -126,8 +126,8 @@ test_case_ops = [ | |||
| Tensor(np.zeros([6, 9, 10]), dtype=mstype.float32)]}), | |||
| ('LogicalTensorOps', { | |||
| 'block': LogicalTensorOpsNet(), | |||
| 'desc_inputs': [Tensor(np.ones([6, 9, 10]).astype(np.bool_), dtype=mstype.bool_), | |||
| Tensor(np.zeros([6, 9, 10]).astype(np.bool_), dtype=mstype.bool_)]}), | |||
| 'desc_inputs': [Tensor(True, dtype=mstype.bool_), | |||
| Tensor(False, dtype=mstype.bool_)]}), | |||
| ] | |||
| test_case_lists = [test_case_ops] | |||
| @@ -41,10 +41,12 @@ def vm_impl_tensor_add(self): | |||
| # pylint: disable=used-before-assignment | |||
| @vm_impl_getters.register(P.LogicalNot) | |||
| def vm_impl_logical_not(self): | |||
| x = x.asnumpy() | |||
| out = vm.logical_not(x) | |||
| return Tensor(out) | |||
| def vm_impl(x): | |||
| x = x.asnumpy() | |||
| out = vm.logical_not(x) | |||
| return Tensor(out) | |||
| return vm_impl | |||
| @vm_impl_getters.register(P.MatMul) | |||
| def vm_impl_mat_mul(self): | |||