Merge pull request !1082 from hewei/support_cont_breaktags/v0.3.0-alpha
| @@ -193,6 +193,14 @@ class TraceForAfter : public TraceInfo { | |||
| TraceInfoPtr clone() override { return std::make_shared<TraceForAfter>(*shared_from_base<TraceForAfter>()); } | |||
| }; | |||
| class TraceLoopEnd : public TraceInfo { | |||
| public: | |||
| explicit TraceLoopEnd(const DebugInfoPtr &info) : TraceInfo(info, "loop_end", "↓↓") {} | |||
| MS_DECLARE_PARENT(TraceLoopEnd, TraceInfo); | |||
| ~TraceLoopEnd() override = default; | |||
| TraceInfoPtr clone() override { return std::make_shared<TraceLoopEnd>(*shared_from_base<TraceLoopEnd>()); } | |||
| }; | |||
| class TraceEquiv : public TraceInfo { | |||
| public: | |||
| explicit TraceEquiv(const DebugInfoPtr &info) : TraceInfo(info, "equiv", "equiv") {} | |||
| @@ -89,6 +89,9 @@ void Parser::BuildMethodMap() { | |||
| stmt_method_map_["FunctionDef"] = &Parser::ParseFunctionDef; | |||
| stmt_method_map_["AugAssign"] = &Parser::ParseAugAssign; | |||
| stmt_method_map_["Global"] = &Parser::ParseGlobal; | |||
| stmt_method_map_["Break"] = &Parser::ParseBreak; | |||
| stmt_method_map_["Continue"] = &Parser::ParseContinue; | |||
| stmt_method_map_["Pass"] = &Parser::ParsePass; | |||
| expr_method_map_["NoneType"] = &Parser::ParseNone; | |||
| expr_method_map_["BinOp"] = &Parser::ParseBinOp; | |||
| expr_method_map_["Name"] = &Parser::ParseName; | |||
| @@ -270,6 +273,8 @@ FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr fn_block, const py::ob | |||
| // insert appropriate depended items for the function block if it has a return node | |||
| if (fn_block->func_graph()->get_return() != nullptr) { | |||
| fn_block->InsertDependItemsBeforeReturn(); | |||
| // Skip statements after 'return' (or 'break', 'continue'). | |||
| break; | |||
| } | |||
| } | |||
| return fn_block; | |||
| @@ -966,13 +971,24 @@ FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::obj | |||
| body_block->Mature(); | |||
| header_block->ConditionalJump(condition_node, body_block, after_block); | |||
| // Parse loop body statements with loop context. | |||
| LoopContext loop_context{&loops_, header_block, nullptr}; | |||
| py::object body_node = python_adapter::GetPyObjAttr(node, "body"); | |||
| FunctionBlockPtr after_body = ParseStatements(body_block, body_node); | |||
| if (after_body->func_graph()->get_return() == nullptr) { | |||
| after_body->Jump(header_block, nullptr); | |||
| } | |||
| header_block->Mature(); | |||
| after_block->Mature(); | |||
| auto &end_block = loop_context.EndBlock(); | |||
| if (end_block) { | |||
| // end_block exists if we encounter 'break' in loop body. | |||
| after_block->Jump(end_block, nullptr); | |||
| end_block->Mature(); | |||
| return end_block; | |||
| } | |||
| // No 'break', no end_block. | |||
| return after_block; | |||
| } | |||
| @@ -1049,13 +1065,24 @@ FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::objec | |||
| body_block->Mature(); | |||
| header_block->ConditionalJump(cond_apply, body_block, after_block); | |||
| // Parse loop body statements with loop context. | |||
| LoopContext loop_context{&loops_, header_block, iter2_app}; | |||
| py::object body_node = python_adapter::GetPyObjAttr(node, "body"); | |||
| FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node); | |||
| if (after_body_block->func_graph()->get_return() == nullptr) { | |||
| after_body_block->Jump(header_block, iter2_app); | |||
| } | |||
| header_block->Mature(); | |||
| after_block->Mature(); | |||
| auto &end_block = loop_context.EndBlock(); | |||
| if (end_block) { | |||
| // end_block exists if we encounter 'break' in loop body. | |||
| after_block->Jump(end_block, nullptr); | |||
| end_block->Mature(); | |||
| return end_block; | |||
| } | |||
| // No 'break', no end_block. | |||
| return after_block; | |||
| } | |||
| AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &node) { | |||
| @@ -1222,6 +1249,52 @@ FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::ob | |||
| return block; | |||
| } | |||
| FunctionBlockPtr Parser::ParseBreak(const FunctionBlockPtr &block, const py::object &node) { | |||
| if (loops_.empty()) { | |||
| // Report error if loop context not set for the 'break' statement. | |||
| py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); | |||
| if (location.size() < 2) { | |||
| MS_LOG(EXCEPTION) << "List size should not be less than 2."; | |||
| } | |||
| auto filename = location[0].cast<std::string>(); | |||
| auto line_no = location[1].cast<int>(); | |||
| MS_LOG(EXCEPTION) << "Unexpected 'break' at " << filename << ":" << line_no; | |||
| } | |||
| // Get current loop. | |||
| Loop &loop = loops_.top(); | |||
| if (loop.end == nullptr) { | |||
| // Create end_block if it is not existed. | |||
| TraceManager::DebugTrace(std::make_shared<TraceLoopEnd>(block->func_graph()->debug_info())); | |||
| loop.end = MakeFunctionBlock(*this); | |||
| TraceManager::EndTrace(); | |||
| } | |||
| // Jump to the end_block. | |||
| block->Jump(loop.end, nullptr); | |||
| return block; | |||
| } | |||
| FunctionBlockPtr Parser::ParseContinue(const FunctionBlockPtr &block, const py::object &node) { | |||
| if (loops_.empty()) { | |||
| // Report error if loop context not set for the 'continue' statement. | |||
| py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); | |||
| if (location.size() < 2) { | |||
| MS_LOG(EXCEPTION) << "List size should not be less than 2."; | |||
| } | |||
| auto filename = location[0].cast<std::string>(); | |||
| auto line_no = location[1].cast<int>(); | |||
| MS_LOG(EXCEPTION) << "Unexpected 'continue' at " << filename << ":" << line_no; | |||
| } | |||
| // Jump to the header of the loop with iterator called. | |||
| Loop &loop = loops_.top(); | |||
| block->Jump(loop.header, loop.iterator); | |||
| return block; | |||
| } | |||
| FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::object &node) { | |||
| // We just bypass 'pass' statement. | |||
| return block; | |||
| } | |||
| void Parser::RemoveUnnecessaryPhis() { | |||
| // merge all removable phis to one map; | |||
| std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis; | |||
| @@ -23,6 +23,7 @@ | |||
| #include <string> | |||
| #include <map> | |||
| #include <set> | |||
| #include <stack> | |||
| #include <memory> | |||
| #include "utils/misc.h" | |||
| #include "ir/anf.h" | |||
| @@ -50,6 +51,33 @@ enum ParseStatusCode : int { | |||
| class AstNodeType; | |||
| class ParseAst; | |||
| // Save loop info for 'continue' and 'break' statements. | |||
| struct Loop { | |||
| // Loop header block. | |||
| FunctionBlockPtr header; | |||
| // Loop iterator node, used in 'for loop'. | |||
| AnfNodePtr iterator; | |||
| // Loop end block. | |||
| FunctionBlockPtr end; | |||
| Loop(const FunctionBlockPtr &header, const AnfNodePtr &iterator, const FunctionBlockPtr &end) | |||
| : header(header), iterator(iterator), end(end) {} | |||
| ~Loop() = default; | |||
| }; | |||
| // Loop context for loop stack management. | |||
| class LoopContext { | |||
| public: | |||
| LoopContext(std::stack<Loop> *loops, const FunctionBlockPtr &header, const AnfNodePtr &iterator) : loops_(loops) { | |||
| loops_->emplace(header, iterator, nullptr); | |||
| } | |||
| ~LoopContext() { loops_->pop(); } | |||
| const FunctionBlockPtr &EndBlock() const { return loops_->top().end; } | |||
| private: | |||
| std::stack<Loop> *loops_; | |||
| }; | |||
| // Parser to parse python function | |||
| class Parser { | |||
| public: | |||
| @@ -86,6 +114,12 @@ class Parser { | |||
| FunctionBlockPtr ParseGlobal(const FunctionBlockPtr &block, const py::object &node); | |||
| // process assign statement | |||
| FunctionBlockPtr ParseAssign(const FunctionBlockPtr &block, const py::object &node); | |||
| // process break statement | |||
| FunctionBlockPtr ParseBreak(const FunctionBlockPtr &block, const py::object &node); | |||
| // process continue statement | |||
| FunctionBlockPtr ParseContinue(const FunctionBlockPtr &block, const py::object &node); | |||
| // process pass statement | |||
| FunctionBlockPtr ParsePass(const FunctionBlockPtr &block, const py::object &node); | |||
| // process the expr and slice node method list | |||
| AnfNodePtr ParseBinOp(const FunctionBlockPtr &block, const py::object &node); | |||
| // process a variable name | |||
| @@ -216,6 +250,8 @@ class Parser { | |||
| std::map<std::string, pStmtFunc> stmt_method_map_; | |||
| // define the function map to parse ast expression | |||
| std::map<std::string, pExprFunc> expr_method_map_; | |||
| // Save current loops to support 'continue', 'break' statement. | |||
| std::stack<Loop> loops_; | |||
| }; | |||
| // AST node type define code to ast | |||
| @@ -0,0 +1,162 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test_cont_break """ | |||
| import pytest | |||
| import numpy as np | |||
| from mindspore.nn import Cell | |||
| from mindspore import Tensor, Model, context | |||
| def run_test(netclass, count, dev): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=dev) | |||
| net = netclass() | |||
| model = Model(net) | |||
| for _ in range(count): | |||
| input_np = np.random.randn(2, 3).astype(np.float32) | |||
| input_ms = Tensor(input_np) | |||
| output_np = net.construct(input_np) # run python | |||
| output_ms = model.predict(input_ms) # run graph | |||
| np.testing.assert_array_almost_equal(output_np, output_ms.asnumpy(), decimal=3) | |||
| class for_loop_with_break(Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x): | |||
| for i in range(8): | |||
| if i > 5: | |||
| x *= 3 | |||
| break | |||
| x = x * 2 | |||
| pass | |||
| return x | |||
| class for_loop_with_continue(Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x): | |||
| for i in range(8): | |||
| if i > 5: | |||
| x *= 3 | |||
| continue | |||
| x = x * 2 | |||
| return x | |||
| class for_loop_with_cont_break(Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x): | |||
| for i in range(8): | |||
| if i < 3: | |||
| i *= 2 | |||
| continue | |||
| if i > 5: | |||
| x *= 3 | |||
| break | |||
| x *= 2 | |||
| x = x * 2 | |||
| pass | |||
| return x | |||
| class for_nested_loop_with_break(Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x): | |||
| for i in range(3): | |||
| for j in range(5): | |||
| if j > 3: | |||
| x *= 2 | |||
| break | |||
| x = x * 1.5 | |||
| return x | |||
| class while_with_break(Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x): | |||
| i = 0 | |||
| while i < 5: | |||
| if i > 3: | |||
| x *= 2 | |||
| break | |||
| x = x * 1.5 | |||
| i += 1 | |||
| return x | |||
| class while_with_continue(Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x): | |||
| i = 0 | |||
| while i < 5: | |||
| if i > 3: | |||
| x *= 2 | |||
| i += 1 | |||
| continue | |||
| x = x * 1.5 | |||
| i += 1 | |||
| return x | |||
| class while_for_nested(Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x): | |||
| i = 0 | |||
| while i < 5: | |||
| if i > 3: | |||
| for j in range(3): | |||
| if j > 1: | |||
| break | |||
| x *= 2 | |||
| i += 1 | |||
| continue | |||
| x = x * 1.5 | |||
| i += 1 | |||
| return x | |||
| class pass_branch(Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x): | |||
| i = 0 | |||
| while i < 5: | |||
| if i > 3: | |||
| pass | |||
| else: | |||
| x = x * 1.5 | |||
| i += 1 | |||
| return x | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_cont_break(): | |||
| count = 20 | |||
| dev = 'CPU' | |||
| run_test(for_loop_with_break, count, dev) | |||
| run_test(for_loop_with_continue, count, dev) | |||
| run_test(for_loop_with_cont_break, count, dev) | |||
| run_test(for_nested_loop_with_break, count, dev) | |||
| run_test(while_with_break, count, dev) | |||
| run_test(while_with_continue, count, dev) | |||
| run_test(while_for_nested, count, dev) | |||
| run_test(pass_branch, count, dev) | |||
| @@ -0,0 +1,180 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test_cont_break """ | |||
| import numpy as np | |||
| from mindspore.nn import Cell | |||
| from mindspore import Tensor, Model, context | |||
| from ...ut_filter import non_graph_engine | |||
| def run_test(netclass, count): | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| net = netclass() | |||
| model = Model(net) | |||
| for _ in range(count): | |||
| input_np = np.random.randn(2, 3).astype(np.float32) | |||
| input_ms = Tensor(input_np) | |||
| output_np = net.construct(input_np) # run python | |||
| output_ms = model.predict(input_ms) # run graph | |||
| assert np.shape(output_np) == np.shape(output_ms.asnumpy()) | |||
| # Disable equal assert because UT in CI use fake backend. | |||
| # np.testing.assert_array_almost_equal(output_np, output_ms.asnumpy(), decimal=3) | |||
| class for_loop_with_break(Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x): | |||
| for i in range(8): | |||
| if i > 5: | |||
| x *= 3 | |||
| break | |||
| x = x * 2 | |||
| pass | |||
| return x | |||
| @non_graph_engine | |||
| def test_for_loop_with_break(): | |||
| run_test(for_loop_with_break, 10) | |||
| class for_loop_with_continue(Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x): | |||
| for i in range(8): | |||
| if i > 5: | |||
| x *= 3 | |||
| continue | |||
| x = x * 2 | |||
| return x | |||
| @non_graph_engine | |||
| def test_for_loop_with_continue(): | |||
| run_test(for_loop_with_continue, 10) | |||
| class for_loop_with_cont_break(Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x): | |||
| for i in range(8): | |||
| if i < 3: | |||
| i *= 2 | |||
| continue | |||
| if i > 5: | |||
| x *= 3 | |||
| break | |||
| x *= 2 | |||
| x = x * 2 | |||
| pass | |||
| return x | |||
| @non_graph_engine | |||
| def test_for_loop_with_cont_break(): | |||
| run_test(for_loop_with_cont_break, 10) | |||
| class for_nested_loop_with_break(Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x): | |||
| for i in range(3): | |||
| for j in range(5): | |||
| if j > 3: | |||
| x *= 2 | |||
| break | |||
| x = x * 1.5 | |||
| return x | |||
| @non_graph_engine | |||
| def test_for_nested_loop_with_break(): | |||
| run_test(for_nested_loop_with_break, 10) | |||
| class while_with_break(Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x): | |||
| i = 0 | |||
| while i < 5: | |||
| if i > 3: | |||
| x *= 2 | |||
| break | |||
| x = x * 1.5 | |||
| i += 1 | |||
| return x | |||
| @non_graph_engine | |||
| def test_while_with_break(): | |||
| run_test(while_with_break, 10) | |||
| class while_with_continue(Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x): | |||
| i = 0 | |||
| while i < 5: | |||
| if i > 3: | |||
| x *= 2 | |||
| i += 1 | |||
| continue | |||
| x = x * 1.5 | |||
| i += 1 | |||
| return x | |||
| @non_graph_engine | |||
| def test_while_with_continue(): | |||
| run_test(while_with_continue, 10) | |||
| class while_for_nested(Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x): | |||
| i = 0 | |||
| while i < 5: | |||
| if i > 3: | |||
| for j in range(3): | |||
| if j > 1: | |||
| break | |||
| x *= 2 | |||
| i += 1 | |||
| continue | |||
| x = x * 1.5 | |||
| i += 1 | |||
| return x | |||
| @non_graph_engine | |||
| def test_while_for_nested(): | |||
| run_test(while_for_nested, 10) | |||
| class pass_branch(Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x): | |||
| i = 0 | |||
| while i < 5: | |||
| if i > 3: | |||
| pass | |||
| else: | |||
| x = x * 1.5 | |||
| i += 1 | |||
| return x | |||
| @non_graph_engine | |||
| def test_pass_branch(): | |||
| run_test(pass_branch, 10) | |||