| @@ -49,11 +49,11 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) { | |||||
| if (vars_.count(var)) { | if (vars_.count(var)) { | ||||
| AnfNodePtr node = vars_[var]; | AnfNodePtr node = vars_[var]; | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (node->isa<ValueNode>()) { | |||||
| return NewValueNode(GetValueNode(node)); | |||||
| } else { | |||||
| return node; | |||||
| auto iter = resolve_to_removable_phis_.find(node); | |||||
| if (iter != resolve_to_removable_phis_.end()) { | |||||
| return iter->second; | |||||
| } | } | ||||
| return node; | |||||
| } | } | ||||
| // get var from predecessor block ,if can't get the make a resolve node to it | // get var from predecessor block ,if can't get the make a resolve node to it | ||||
| if (matured_) { | if (matured_) { | ||||
| @@ -64,7 +64,13 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) { | |||||
| return block->ReadVariable(var); | return block->ReadVariable(var); | ||||
| } else if (prev_blocks_.empty()) { | } else if (prev_blocks_.empty()) { | ||||
| // get namespace and make Reslove | // get namespace and make Reslove | ||||
| return MakeResolveSymbol(var); | |||||
| auto it = var_to_resolve_.find(var); | |||||
| if (it != var_to_resolve_.end()) { | |||||
| return it->second; | |||||
| } | |||||
| auto tmp_node = MakeResolveSymbol(var); | |||||
| var_to_resolve_[var] = tmp_node; | |||||
| return tmp_node; | |||||
| } | } | ||||
| } | } | ||||
| // If have more than one predecessor blocks then build a phi node. | // If have more than one predecessor blocks then build a phi node. | ||||
| @@ -217,6 +223,7 @@ bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) { | |||||
| // replace var with new one. This equal to statement in TR "v0 is immediately replaced by v1." | // replace var with new one. This equal to statement in TR "v0 is immediately replaced by v1." | ||||
| WriteVariable(var, arg_node); | WriteVariable(var, arg_node); | ||||
| removable_phis_[phi] = arg_node; | removable_phis_[phi] = arg_node; | ||||
| resolve_to_removable_phis_[arg_node] = phi; | |||||
| // The following equal to statement "The φ-function defining v1, which now reads φ(v2, v1), is optimized | // The following equal to statement "The φ-function defining v1, which now reads φ(v2, v1), is optimized | ||||
| // recursively". check if phi1 is assigned with this phi before, then phi1 can be replaced with arg_node. | // recursively". check if phi1 is assigned with this phi before, then phi1 can be replaced with arg_node. | ||||
| for (auto &prev : prev_blocks_) { | for (auto &prev : prev_blocks_) { | ||||
| @@ -101,12 +101,21 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> { | |||||
| // keeps all removable phis which will be removed in one pass. | // keeps all removable phis which will be removed in one pass. | ||||
| std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_; | std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_; | ||||
| // Keeps the map for the resolve node to the removable phi node. | |||||
| // For the case that ReadVariable returns a phi node although this phi node | |||||
| // generated in the prev block is identified as removable. The other blocks | |||||
| // should find this phi node. | |||||
| std::unordered_map<AnfNodePtr, ParameterPtr> resolve_to_removable_phis_; | |||||
| // hold declared global variables in function | // hold declared global variables in function | ||||
| std::set<std::string> global_vars_; | std::set<std::string> global_vars_; | ||||
| // other depend need to insert before function return nodes. | // other depend need to insert before function return nodes. | ||||
| // summary or some other node | // summary or some other node | ||||
| std::vector<AnfNodePtr> auto_depends_; | std::vector<AnfNodePtr> auto_depends_; | ||||
| // keeps the new made resolve symbol for the variable not found in vars_. | |||||
| std::unordered_map<std::string, AnfNodePtr> var_to_resolve_; | |||||
| }; | }; | ||||
| } // namespace parse | } // namespace parse | ||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """ test control ops """ | """ test control ops """ | ||||
| import numpy as np | import numpy as np | ||||
| import pytest | |||||
| from mindspore import dtype as ms | from mindspore import dtype as ms | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| @@ -1150,3 +1151,147 @@ def test_if_by_if_forward_all_const_branch(): | |||||
| end = Tensor(np.array(3), dtype=ms.float32) | end = Tensor(np.array(3), dtype=ms.float32) | ||||
| x = Tensor(np.array(0), dtype=ms.float32) | x = Tensor(np.array(0), dtype=ms.float32) | ||||
| net(idx, end, x) | net(idx, end, x) | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.env_onecard | |||||
| def test_if_const_grad(): | |||||
| class MyNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.add = P.TensorAdd() | |||||
| def construct(self, *inputs): | |||||
| out = self.add(*inputs) | |||||
| return out | |||||
| class GradNet(nn.Cell): | |||||
| def __init__(self, net): | |||||
| super(GradNet, self).__init__() | |||||
| self.net = net | |||||
| self.weights = ParameterTuple(net.trainable_params()) | |||||
| def construct(self, *inputs): | |||||
| a = 1 | |||||
| b = 2 | |||||
| if a > 0: | |||||
| b = 1 | |||||
| a += b | |||||
| return grad_by_list(self.net, self.weights)(*inputs) | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| my_net = MyNet() | |||||
| net = GradNet(my_net) | |||||
| a = Tensor(np.array(0), dtype=ms.int32) | |||||
| b = Tensor(np.array(1), dtype=ms.int32) | |||||
| net(a, b) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.env_onecard | |||||
| def test_if_by_if_const_grad(): | |||||
| class MyNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.add = P.TensorAdd() | |||||
| def construct(self, *inputs): | |||||
| out = self.add(*inputs) | |||||
| return out | |||||
| class GradNet(nn.Cell): | |||||
| def __init__(self, net): | |||||
| super(GradNet, self).__init__() | |||||
| self.net = net | |||||
| self.weights = ParameterTuple(net.trainable_params()) | |||||
| def construct(self, *inputs): | |||||
| a = 1 | |||||
| b = 2 | |||||
| if a > 0: | |||||
| b = 1 | |||||
| if a < 0: | |||||
| b = 0 | |||||
| if a == 0: | |||||
| b = 3 | |||||
| a += b | |||||
| return grad_by_list(self.net, self.weights)(*inputs) | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| my_net = MyNet() | |||||
| net = GradNet(my_net) | |||||
| a = Tensor(np.array(0), dtype=ms.int32) | |||||
| b = Tensor(np.array(1), dtype=ms.int32) | |||||
| net(a, b) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.env_onecard | |||||
| def test_while_const_grad(): | |||||
| class MyNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.add = P.TensorAdd() | |||||
| def construct(self, *inputs): | |||||
| out = self.add(*inputs) | |||||
| return out | |||||
| class GradNet(nn.Cell): | |||||
| def __init__(self, net): | |||||
| super(GradNet, self).__init__() | |||||
| self.net = net | |||||
| self.weights = ParameterTuple(net.trainable_params()) | |||||
| def construct(self, *inputs): | |||||
| a = 1 | |||||
| while a > 1: | |||||
| a = a - 1 | |||||
| return grad_by_list(self.net, self.weights)(*inputs) | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| my_net = MyNet() | |||||
| net = GradNet(my_net) | |||||
| a = Tensor(np.array(0), dtype=ms.int32) | |||||
| b = Tensor(np.array(1), dtype=ms.int32) | |||||
| net(a, b) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.env_onecard | |||||
| def test_if_by_while_const_grad(): | |||||
| class MyNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.add = P.TensorAdd() | |||||
| def construct(self, *inputs): | |||||
| out = self.add(*inputs) | |||||
| return out | |||||
| class GradNet(nn.Cell): | |||||
| def __init__(self, net): | |||||
| super(GradNet, self).__init__() | |||||
| self.net = net | |||||
| self.weights = ParameterTuple(net.trainable_params()) | |||||
| def construct(self, *inputs): | |||||
| a = 1 | |||||
| b = 2 | |||||
| if a > 0: | |||||
| b = 0 | |||||
| while a > 1: | |||||
| a = a - 1 | |||||
| a += b | |||||
| return grad_by_list(self.net, self.weights)(*inputs) | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| my_net = MyNet() | |||||
| net = GradNet(my_net) | |||||
| a = Tensor(np.array(0), dtype=ms.int32) | |||||
| b = Tensor(np.array(1), dtype=ms.int32) | |||||
| net(a, b) | |||||
| @@ -187,7 +187,7 @@ TEST_F(TestOptOpt, CSE) { | |||||
| FuncGraphManagerPtr manager1 = Manage(test_graph1); | FuncGraphManagerPtr manager1 = Manage(test_graph1); | ||||
| draw::Draw("opt_cse_before_1.dot", test_graph1); | draw::Draw("opt_cse_before_1.dot", test_graph1); | ||||
| ASSERT_EQ(manager1->all_nodes().size(), 10); | |||||
| ASSERT_EQ(manager1->all_nodes().size(), 9); | |||||
| auto cse = std::make_shared<CSE>(); | auto cse = std::make_shared<CSE>(); | ||||
| ASSERT_TRUE(cse != nullptr); | ASSERT_TRUE(cse != nullptr); | ||||
| @@ -205,7 +205,7 @@ TEST_F(TestOptOpt, CSE) { | |||||
| FuncGraphManagerPtr manager2 = Manage(test_graph2); | FuncGraphManagerPtr manager2 = Manage(test_graph2); | ||||
| draw::Draw("opt_cse_before_2.dot", test_graph2); | draw::Draw("opt_cse_before_2.dot", test_graph2); | ||||
| ASSERT_EQ(manager2->all_nodes().size(), 22); | |||||
| ASSERT_EQ(manager2->all_nodes().size(), 16); | |||||
| is_changed = cse->Cse(test_graph2, manager2); | is_changed = cse->Cse(test_graph2, manager2); | ||||
| ASSERT_TRUE(is_changed); | ASSERT_TRUE(is_changed); | ||||
| ASSERT_EQ(manager2->all_nodes().size(), 12); | ASSERT_EQ(manager2->all_nodes().size(), 12); | ||||