From: @huangbingjian
Reviewed-by: @zh_qh, @ginfung
Signed-off-by: @zh_qh
pull/14888/MERGE
@@ -103,8 +103,8 @@ static bool isTraversable(const AnfNodePtr &node) {
   return false;
 }
 
-static inline AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node,
-                                     const SubstitutionPtr &substitution) {
+static AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node,
+                              const SubstitutionPtr &substitution) {
   auto manager = optimizer->manager();
   bool is_match = substitution->predicate_(node);
   if (is_match) {
@@ -126,8 +126,8 @@ static inline AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNod
   return nullptr;
 }
 
-static inline void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node,
-                                          std::deque<AnfNodePtr> *todo, bool change, size_t seen) {
+static void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node, std::deque<AnfNodePtr> *todo,
+                                   bool change, size_t seen) {
   if (IsValueNode<FuncGraph>(node)) {
     (*todo).emplace_back(GetValueNode<FuncGraphPtr>(node)->output());
   }
@@ -238,6 +238,23 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons
   return changes;
 }
 
+void SubstitutionList::DisplayStatusOfSubstitution(const std::unordered_map<std::string, std::vector<bool>> &status,
+                                                   const OptimizerPtr &optimizer, size_t space) const {
+  std::stringstream ss;
+  ss << std::endl
+     << "Pass: " << optimizer->name() << "(" << optimizer->CurPass_.counter << ")_" << optimizer->CurPass_.name
+     << std::endl;
+  for (size_t i = 0; i < list_.size(); i++) {
+    auto name = list_[i]->name_;
+    ss << std::left << std::setw(space + 4) << name << "\t";
+    for (auto change : status.at(name + std::to_string(i))) {
+      ss << change << " ";
+    }
+    ss << std::endl;
+  }
+  MS_LOG(DEBUG) << ss.str();
+}
+
 bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const {
   // Add for substitution status counting
   size_t space = 0;
@@ -282,19 +299,7 @@ bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, con
   // Display the status of each substitution
   if (optimizer->is_on_debug_) {
-    std::stringstream ss;
-    ss << std::endl
-       << "Pass: " << optimizer->name() << "(" << optimizer->CurPass_.counter << ")_" << optimizer->CurPass_.name
-       << std::endl;
-    for (size_t i = 0; i < list_.size(); i++) {
-      auto name = list_[i]->name_;
-      ss << std::left << std::setw(space + 4) << name << "\t";
-      for (auto change : status[name + std::to_string(i)]) {
-        ss << change << " ";
-      }
-      ss << std::endl;
-    }
-    MS_LOG(DEBUG) << ss.str();
+    DisplayStatusOfSubstitution(status, optimizer, space);
   }
 
   return changes;
 }
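Note on the hunks above: `DisplayStatusOfSubstitution` only hoists the existing debug dump out of `ApplySubstitutionsToIR`; the logged output is unchanged. The lookup switches from `status[...]` to `status.at(...)` because the map is now taken by const reference, where `operator[]` is unavailable. For each pass it prints a header `Pass: <optimizer name>(<counter>)_<pass name>`, then one row per substitution: the name padded with `std::setw(space + 4)`, followed by one flag per recorded application (`bool` streams as `1`/`0` by default). A hypothetical trace, with pass and substitution names invented purely for illustration:

    Pass: opt_a(1)_renormalize
    arithmetic_simplify      1 0 0 1
    special_op_eliminate     0 0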
@@ -20,6 +20,7 @@
 #include <memory>
 #include <string>
 #include <vector>
+#include <unordered_map>
 
 #include "ir/anf.h"
 #include "ir/func_graph.h"
@@ -74,6 +75,8 @@ class SubstitutionList {
   bool ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
   bool ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &sub) const;
   bool ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
+  void DisplayStatusOfSubstitution(const std::unordered_map<std::string, std::vector<bool>> &status,
+                                   const OptimizerPtr &optimizer, size_t space) const;
 
   std::vector<SubstitutionPtr> list_;
   // a flag to mark this list of Substitution can only be executed only once
@@ -0,0 +1,60 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+
+def test_single_if():
+    class SingleIfNet(nn.Cell):
+        def construct(self, x, y):
+            x += 1
+            if x < y:
+                y += x
+            else:
+                y -= x
+            y += 5
+            return y
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(2, mstype.int32)
+    y = Tensor(5, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    if_net = SingleIfNet()
+    net = GradNet(if_net)
+    graph_forward_res = if_net(x, y)
+    graph_backward_res = net(x, y)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    if_net = SingleIfNet()
+    net = GradNet(if_net)
+    pynative_forward_res = if_net(x, y)
+    pynative_backward_res = net(x, y)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
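Every new control-flow test below repeats the template above: build the net, run forward and backward in GRAPH_MODE, rebuild both nets, run again in PYNATIVE_MODE, and assert that forward results and gradients match across the two modes. A minimal sketch of that shared harness, using only the imports the test files already have (`compare_graph_vs_pynative` is a hypothetical helper, not part of this PR):

from mindspore import context, nn
from mindspore.ops import composite as C

grad_all = C.GradOperation(get_all=True)


class GradNet(nn.Cell):
    # Same wrapper the test files define: gradients w.r.t. all inputs.
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net

    def construct(self, *inputs):
        return grad_all(self.net)(*inputs)


def compare_graph_vs_pynative(net_cls, *inputs):
    # Hypothetical helper: run the same Cell under both execution modes
    # and collect (forward result, gradients) for each.
    results = []
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode)
        net = net_cls()  # fresh instance per mode, as the tests do
        results.append((net(*inputs), GradNet(net)(*inputs)))
    (graph_fwd, graph_bwd), (pynative_fwd, pynative_bwd) = results
    assert graph_fwd == pynative_fwd
    assert graph_bwd == pynative_bwd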
@@ -0,0 +1,64 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+from mindspore.common.parameter import Parameter
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+
+def test_if_in_if():
+    class IfInIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+        def construct(self, x):
+            if self.param_a > self.param_b:
+                x += 10
+                if x > self.param_a:
+                    self.param_b += 1
+            x += self.param_a
+            return x
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(2, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    if_in_if_net = IfInIfNet()
+    net = GradNet(if_in_if_net)
+    graph_forward_res = if_in_if_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    if_in_if_net = IfInIfNet()
+    net = GradNet(if_in_if_net)
+    pynative_forward_res = if_in_if_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
@@ -0,0 +1,65 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+from mindspore.common.parameter import Parameter
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+
+def test_if_after_if():
+    class IfAfterIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+        def construct(self, x):
+            out = x + self.param_b
+            if self.param_a > self.param_b:
+                x += 5
+            self.param_b += 4
+            if x < self.param_b:
+                out += self.param_b
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(2, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    if_after_if_net = IfAfterIfNet()
+    net = GradNet(if_after_if_net)
+    graph_forward_res = if_after_if_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    if_after_if_net = IfAfterIfNet()
+    net = GradNet(if_after_if_net)
+    pynative_forward_res = if_after_if_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
@@ -0,0 +1,67 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+from mindspore.common.parameter import Parameter
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+
+def test_if_after_if_in_if():
+    class IfAfterIfInIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+        def construct(self, x):
+            out = x + self.param_b
+            if self.param_a > self.param_b:
+                x += 5
+                if x > self.param_a:
+                    self.param_b += 1
+            self.param_b += 3
+            if x < self.param_b:
+                out += self.param_b
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(2, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    if_after_if_in_if_net = IfAfterIfInIfNet()
+    net = GradNet(if_after_if_in_if_net)
+    graph_forward_res = if_after_if_in_if_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    if_after_if_in_if_net = IfAfterIfInIfNet()
+    net = GradNet(if_after_if_in_if_net)
+    pynative_forward_res = if_after_if_in_if_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
@@ -0,0 +1,66 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+from mindspore.common.parameter import Parameter
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+
+def test_if_after_if_in_for():
+    class IfAfterIfInForNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+        def construct(self, x):
+            out = x + self.param_b
+            for _ in range(4):
+                if out <= 20:
+                    out += self.param_a
+            self.param_b += 3
+            if x < self.param_b:
+                out -= self.param_b
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(2, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    if_after_if_in_for_net = IfAfterIfInForNet()
+    net = GradNet(if_after_if_in_for_net)
+    graph_forward_res = if_after_if_in_for_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    if_after_if_in_for_net = IfAfterIfInForNet()
+    net = GradNet(if_after_if_in_for_net)
+    pynative_forward_res = if_after_if_in_for_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
@@ -0,0 +1,67 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+from mindspore.common.parameter import Parameter
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+
+def test_if_after_for_in_if():
+    class IfAfterForInIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+        def construct(self, x):
+            out = x + self.param_a
+            if self.param_a > self.param_b:
+                for _ in range(4):
+                    self.param_a += 1
+                    self.param_b -= 3
+            self.param_b += 15
+            if x < self.param_b:
+                out -= self.param_b
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(2, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    if_after_for_in_if_net = IfAfterForInIfNet()
+    net = GradNet(if_after_for_in_if_net)
+    graph_forward_res = if_after_for_in_if_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    if_after_for_in_if_net = IfAfterForInIfNet()
+    net = GradNet(if_after_for_in_if_net)
+    pynative_forward_res = if_after_for_in_if_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
@@ -0,0 +1,67 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+from mindspore.common.parameter import Parameter
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+
+def test_if_after_for_in_while():
+    class IfAfterForInWhileNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(2, mstype.int32), name='b')
+
+        def construct(self, x):
+            out = x + self.param_a
+            while self.param_a > self.param_b:
+                self.param_b += 1
+                for _ in range(4):
+                    self.param_a += 3
+                self.param_a -= 40
+            if x > self.param_a:
+                out += self.param_a * 10
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(2, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    if_after_for_in_while_net = IfAfterForInWhileNet()
+    net = GradNet(if_after_for_in_while_net)
+    graph_forward_res = if_after_for_in_while_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    if_after_for_in_while_net = IfAfterForInWhileNet()
+    net = GradNet(if_after_for_in_while_net)
+    pynative_forward_res = if_after_for_in_while_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
@@ -0,0 +1,67 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+from mindspore.common.parameter import Parameter
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+
+def test_if_after_for_in_for():
+    class IfAfterForInForNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(2, mstype.int32), name='b')
+
+        def construct(self, x):
+            out = x + self.param_a
+            for _ in range(0, 10):
+                x *= 2
+                for _ in range(0, 5):
+                    self.param_a += 1
+                    x += self.param_b
+            if self.param_a > self.param_b:
+                out += x
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(2, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    if_after_for_in_for_net = IfAfterForInForNet()
+    net = GradNet(if_after_for_in_for_net)
+    graph_forward_res = if_after_for_in_for_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    if_after_for_in_for_net = IfAfterForInForNet()
+    net = GradNet(if_after_for_in_for_net)
+    pynative_forward_res = if_after_for_in_for_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
@@ -0,0 +1,66 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+from mindspore.common.parameter import Parameter
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+
+def test_for_after_if():
+    class ForAfterIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+        def construct(self, x):
+            out = self.param_a
+            if self.param_a > self.param_b:
+                x += 3
+            self.param_b += 1
+            for _ in range(0, 5):
+                x += self.param_b
+                out *= x
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(2, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    for_after_if_net = ForAfterIfNet()
+    net = GradNet(for_after_if_net)
+    graph_forward_res = for_after_if_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    for_after_if_net = ForAfterIfNet()
+    net = GradNet(for_after_if_net)
+    pynative_forward_res = for_after_if_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
@@ -0,0 +1,69 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+from mindspore.common.parameter import Parameter
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+
+def test_for_after_if_in_if():
+    class ForAfterIfInIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+        def construct(self, x):
+            out = self.param_a
+            if self.param_a > self.param_b:
+                x += 3
+                if x > self.param_a:
+                    self.param_b += 4
+                x += self.param_a
+            self.param_b += 2
+            for _ in range(0, 5):
+                x += self.param_b
+                out *= x
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(5, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    for_after_if_in_if_net = ForAfterIfInIfNet()
+    net = GradNet(for_after_if_in_if_net)
+    graph_forward_res = for_after_if_in_if_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    for_after_if_in_if_net = ForAfterIfInIfNet()
+    net = GradNet(for_after_if_in_if_net)
+    pynative_forward_res = for_after_if_in_if_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
@@ -0,0 +1,68 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+from mindspore.common.parameter import Parameter
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+
+def test_for_after_for_in_if():
+    class ForAfterForInIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+        def construct(self, x):
+            out = self.param_a
+            if self.param_a > self.param_b:
+                for _ in range(0, 4):
+                    self.param_a += 1
+                    self.param_b -= 3
+            self.param_b += 10
+            for _ in range(0, 5):
+                x += self.param_b
+                out *= x
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(5, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    for_after_for_in_if_net = ForAfterForInIfNet()
+    net = GradNet(for_after_for_in_if_net)
+    graph_forward_res = for_after_for_in_if_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    for_after_for_in_if_net = ForAfterForInIfNet()
+    net = GradNet(for_after_for_in_if_net)
+    pynative_forward_res = for_after_for_in_if_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res