From: @huangbingjian Reviewed-by: @zh_qh,@ginfung Signed-off-by: @zh_qhpull/14888/MERGE
| @@ -103,8 +103,8 @@ static bool isTraversable(const AnfNodePtr &node) { | |||
| return false; | |||
| } | |||
| static inline AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node, | |||
| const SubstitutionPtr &substitution) { | |||
| static AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node, | |||
| const SubstitutionPtr &substitution) { | |||
| auto manager = optimizer->manager(); | |||
| bool is_match = substitution->predicate_(node); | |||
| if (is_match) { | |||
| @@ -126,8 +126,8 @@ static inline AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNod | |||
| return nullptr; | |||
| } | |||
| static inline void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node, | |||
| std::deque<AnfNodePtr> *todo, bool change, size_t seen) { | |||
| static void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node, std::deque<AnfNodePtr> *todo, | |||
| bool change, size_t seen) { | |||
| if (IsValueNode<FuncGraph>(node)) { | |||
| (*todo).emplace_back(GetValueNode<FuncGraphPtr>(node)->output()); | |||
| } | |||
| @@ -238,6 +238,23 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons | |||
| return changes; | |||
| } | |||
| void SubstitutionList::DisplayStatusOfSubstitution(const std::unordered_map<std::string, std::vector<bool>> &status, | |||
| const OptimizerPtr &optimizer, size_t space) const { | |||
| std::stringstream ss; | |||
| ss << std::endl | |||
| << "Pass: " << optimizer->name() << "(" << optimizer->CurPass_.counter << ")_" << optimizer->CurPass_.name | |||
| << std::endl; | |||
| for (size_t i = 0; i < list_.size(); i++) { | |||
| auto name = list_[i]->name_; | |||
| ss << std::left << std::setw(space + 4) << name << "\t"; | |||
| for (auto change : status.at(name + std::to_string(i))) { | |||
| ss << change << " "; | |||
| } | |||
| ss << std::endl; | |||
| } | |||
| MS_LOG(DEBUG) << ss.str(); | |||
| } | |||
| bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const { | |||
| // Add for substitution status counting | |||
| size_t space = 0; | |||
| @@ -282,19 +299,7 @@ bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, con | |||
| // Display the status of each substitution | |||
| if (optimizer->is_on_debug_) { | |||
| std::stringstream ss; | |||
| ss << std::endl | |||
| << "Pass: " << optimizer->name() << "(" << optimizer->CurPass_.counter << ")_" << optimizer->CurPass_.name | |||
| << std::endl; | |||
| for (size_t i = 0; i < list_.size(); i++) { | |||
| auto name = list_[i]->name_; | |||
| ss << std::left << std::setw(space + 4) << name << "\t"; | |||
| for (auto change : status[name + std::to_string(i)]) { | |||
| ss << change << " "; | |||
| } | |||
| ss << std::endl; | |||
| } | |||
| MS_LOG(DEBUG) << ss.str(); | |||
| DisplayStatusOfSubstitution(status, optimizer, space); | |||
| } | |||
| return changes; | |||
| } | |||
| @@ -20,6 +20,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <unordered_map> | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| @@ -74,6 +75,8 @@ class SubstitutionList { | |||
| bool ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const; | |||
| bool ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &sub) const; | |||
| bool ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const; | |||
| void DisplayStatusOfSubstitution(const std::unordered_map<std::string, std::vector<bool>> &status, | |||
| const OptimizerPtr &optimizer, size_t space) const; | |||
| std::vector<SubstitutionPtr> list_; | |||
| // a flag to mark this list of Substitution can only be executed only once | |||
| @@ -0,0 +1,60 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from mindspore import context | |||
| from mindspore import Tensor, nn | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| grad_all = C.GradOperation(get_all=True) | |||
| context.set_context(device_target="Ascend") | |||
| def test_signle_if(): | |||
| class SignleIfNet(nn.Cell): | |||
| def construct(self, x, y): | |||
| x += 1 | |||
| if x < y: | |||
| y += x | |||
| else: | |||
| y -= x | |||
| y += 5 | |||
| return y | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| def construct(self, *inputs): | |||
| return grad_all(self.net)(*inputs) | |||
| x = Tensor(2, mstype.int32) | |||
| y = Tensor(5, mstype.int32) | |||
| # graph mode | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| if_net = SignleIfNet() | |||
| net = GradNet(if_net) | |||
| graph_forward_res = if_net(x, y) | |||
| graph_backward_res = net(x, y) | |||
| # pynative mode | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| if_net = SignleIfNet() | |||
| net = GradNet(if_net) | |||
| pynative_forward_res = if_net(x, y) | |||
| pynative_backward_res = net(x, y) | |||
| assert graph_forward_res == pynative_forward_res | |||
| assert graph_backward_res == pynative_backward_res | |||
| @@ -0,0 +1,64 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from mindspore import context | |||
| from mindspore import Tensor, nn | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.common.parameter import Parameter | |||
| grad_all = C.GradOperation(get_all=True) | |||
| context.set_context(device_target="Ascend") | |||
| def test_if_in_if(): | |||
| class IfInIfNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.param_a = Parameter(Tensor(5, mstype.int32), name='a') | |||
| self.param_b = Parameter(Tensor(4, mstype.int32), name='b') | |||
| def construct(self, x): | |||
| if self.param_a > self.param_b: | |||
| x += 10 | |||
| if x > self.param_a: | |||
| self.param_b += 1 | |||
| x += self.param_a | |||
| return x | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| def construct(self, *inputs): | |||
| return grad_all(self.net)(*inputs) | |||
| x = Tensor(2, mstype.int32) | |||
| # graph mode | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| if_in_if_net = IfInIfNet() | |||
| net = GradNet(if_in_if_net) | |||
| graph_forward_res = if_in_if_net(x) | |||
| graph_backward_res = net(x) | |||
| # pynative mode | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| if_in_if_net = IfInIfNet() | |||
| net = GradNet(if_in_if_net) | |||
| pynative_forward_res = if_in_if_net(x) | |||
| pynative_backward_res = net(x) | |||
| assert graph_forward_res == pynative_forward_res | |||
| assert graph_backward_res == pynative_backward_res | |||
| @@ -0,0 +1,65 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from mindspore import context | |||
| from mindspore import Tensor, nn | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.common.parameter import Parameter | |||
| grad_all = C.GradOperation(get_all=True) | |||
| context.set_context(device_target="Ascend") | |||
| def test_if_after_if(): | |||
| class IfAfterIfNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.param_a = Parameter(Tensor(5, mstype.int32), name='a') | |||
| self.param_b = Parameter(Tensor(4, mstype.int32), name='b') | |||
| def construct(self, x): | |||
| out = x + self.param_b | |||
| if self.param_a > self.param_b: | |||
| x += 5 | |||
| self.param_b += 4 | |||
| if x < self.param_b: | |||
| out += self.param_b | |||
| return out | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| def construct(self, *inputs): | |||
| return grad_all(self.net)(*inputs) | |||
| x = Tensor(2, mstype.int32) | |||
| # graph mode | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| if_after_if_net = IfAfterIfNet() | |||
| net = GradNet(if_after_if_net) | |||
| graph_forward_res = if_after_if_net(x) | |||
| graph_backward_res = net(x) | |||
| # pynative mode | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| if_after_if_net = IfAfterIfNet() | |||
| net = GradNet(if_after_if_net) | |||
| pynative_forward_res = if_after_if_net(x) | |||
| pynative_backward_res = net(x) | |||
| assert graph_forward_res == pynative_forward_res | |||
| assert graph_backward_res == pynative_backward_res | |||
| @@ -0,0 +1,67 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from mindspore import context | |||
| from mindspore import Tensor, nn | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.common.parameter import Parameter | |||
| grad_all = C.GradOperation(get_all=True) | |||
| context.set_context(device_target="Ascend") | |||
| def test_if_after_if_in_if(): | |||
| class IfAfterIfInIfNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.param_a = Parameter(Tensor(5, mstype.int32), name='a') | |||
| self.param_b = Parameter(Tensor(4, mstype.int32), name='b') | |||
| def construct(self, x): | |||
| out = x + self.param_b | |||
| if self.param_a > self.param_b: | |||
| x += 5 | |||
| if x > self.param_a: | |||
| self.param_b += 1 | |||
| self.param_b += 3 | |||
| if x < self.param_b: | |||
| out += self.param_b | |||
| return out | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| def construct(self, *inputs): | |||
| return grad_all(self.net)(*inputs) | |||
| x = Tensor(2, mstype.int32) | |||
| # graph mode | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| if_after_if_in_if_net = IfAfterIfInIfNet() | |||
| net = GradNet(if_after_if_in_if_net) | |||
| graph_forward_res = if_after_if_in_if_net(x) | |||
| graph_backward_res = net(x) | |||
| # pynative mode | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| if_after_if_in_if_net = IfAfterIfInIfNet() | |||
| net = GradNet(if_after_if_in_if_net) | |||
| pynative_forward_res = if_after_if_in_if_net(x) | |||
| pynative_backward_res = net(x) | |||
| assert graph_forward_res == pynative_forward_res | |||
| assert graph_backward_res == pynative_backward_res | |||
| @@ -0,0 +1,66 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from mindspore import context | |||
| from mindspore import Tensor, nn | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.common.parameter import Parameter | |||
| grad_all = C.GradOperation(get_all=True) | |||
| context.set_context(device_target="Ascend") | |||
| def test_if_after_if_in_for(): | |||
| class IfAfterIfInForNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.param_a = Parameter(Tensor(5, mstype.int32), name='a') | |||
| self.param_b = Parameter(Tensor(4, mstype.int32), name='b') | |||
| def construct(self, x): | |||
| out = x + self.param_b | |||
| for _ in range(4): | |||
| if out <= 20: | |||
| out += self.param_a | |||
| self.param_b += 3 | |||
| if x < self.param_b: | |||
| out -= self.param_b | |||
| return out | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| def construct(self, *inputs): | |||
| return grad_all(self.net)(*inputs) | |||
| x = Tensor(2, mstype.int32) | |||
| # graph mode | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| if_after_if_in_for_net = IfAfterIfInForNet() | |||
| net = GradNet(if_after_if_in_for_net) | |||
| graph_forward_res = if_after_if_in_for_net(x) | |||
| graph_backward_res = net(x) | |||
| # pynative mode | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| if_after_if_in_for_net = IfAfterIfInForNet() | |||
| net = GradNet(if_after_if_in_for_net) | |||
| pynative_forward_res = if_after_if_in_for_net(x) | |||
| pynative_backward_res = net(x) | |||
| assert graph_forward_res == pynative_forward_res | |||
| assert graph_backward_res == pynative_backward_res | |||
| @@ -0,0 +1,67 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from mindspore import context | |||
| from mindspore import Tensor, nn | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.common.parameter import Parameter | |||
| grad_all = C.GradOperation(get_all=True) | |||
| context.set_context(device_target="Ascend") | |||
| def test_if_after_for_in_if(): | |||
| class IfAfterForInIfNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.param_a = Parameter(Tensor(5, mstype.int32), name='a') | |||
| self.param_b = Parameter(Tensor(4, mstype.int32), name='b') | |||
| def construct(self, x): | |||
| out = x + self.param_a | |||
| if self.param_a > self.param_b: | |||
| for _ in range(4): | |||
| self.param_a += 1 | |||
| self.param_b -= 3 | |||
| self.param_b += 15 | |||
| if x < self.param_b: | |||
| out -= self.param_b | |||
| return out | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| def construct(self, *inputs): | |||
| return grad_all(self.net)(*inputs) | |||
| x = Tensor(2, mstype.int32) | |||
| # graph mode | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| if_after_for_in_if_net = IfAfterForInIfNet() | |||
| net = GradNet(if_after_for_in_if_net) | |||
| graph_forward_res = if_after_for_in_if_net(x) | |||
| graph_backward_res = net(x) | |||
| # pynative mode | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| if_after_for_in_if_net = IfAfterForInIfNet() | |||
| net = GradNet(if_after_for_in_if_net) | |||
| pynative_forward_res = if_after_for_in_if_net(x) | |||
| pynative_backward_res = net(x) | |||
| assert graph_forward_res == pynative_forward_res | |||
| assert graph_backward_res == pynative_backward_res | |||
| @@ -0,0 +1,67 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from mindspore import context | |||
| from mindspore import Tensor, nn | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.common.parameter import Parameter | |||
| grad_all = C.GradOperation(get_all=True) | |||
| context.set_context(device_target="Ascend") | |||
| def test_if_after_for_in_while(): | |||
| class IfAfterForInWhileNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.param_a = Parameter(Tensor(5, mstype.int32), name='a') | |||
| self.param_b = Parameter(Tensor(2, mstype.int32), name='b') | |||
| def construct(self, x): | |||
| out = x + self.param_a | |||
| while self.param_a > self.param_b: | |||
| self.param_b += 1 | |||
| for _ in range(4): | |||
| self.param_a += 3 | |||
| self.param_a -= 40 | |||
| if x > self.param_a: | |||
| out += self.param_a * 10 | |||
| return out | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| def construct(self, *inputs): | |||
| return grad_all(self.net)(*inputs) | |||
| x = Tensor(2, mstype.int32) | |||
| # graph mode | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| if_after_for_in_while_net = IfAfterForInWhileNet() | |||
| net = GradNet(if_after_for_in_while_net) | |||
| graph_forward_res = if_after_for_in_while_net(x) | |||
| graph_backward_res = net(x) | |||
| # pynative mode | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| if_after_for_in_while_net = IfAfterForInWhileNet() | |||
| net = GradNet(if_after_for_in_while_net) | |||
| pynative_forward_res = if_after_for_in_while_net(x) | |||
| pynative_backward_res = net(x) | |||
| assert graph_forward_res == pynative_forward_res | |||
| assert graph_backward_res == pynative_backward_res | |||
| @@ -0,0 +1,67 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from mindspore import context | |||
| from mindspore import Tensor, nn | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.common.parameter import Parameter | |||
| grad_all = C.GradOperation(get_all=True) | |||
| context.set_context(device_target="Ascend") | |||
| def test_if_after_for_in_for(): | |||
| class IfAfterForInForNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.param_a = Parameter(Tensor(5, mstype.int32), name='a') | |||
| self.param_b = Parameter(Tensor(2, mstype.int32), name='b') | |||
| def construct(self, x): | |||
| out = x + self.param_a | |||
| for _ in range(0, 10): | |||
| x *= 2 | |||
| for _ in range(0, 5): | |||
| self.param_a += 1 | |||
| x += self.param_b | |||
| if self.param_a > self.param_b: | |||
| out += x | |||
| return out | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| def construct(self, *inputs): | |||
| return grad_all(self.net)(*inputs) | |||
| x = Tensor(2, mstype.int32) | |||
| # graph mode | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| if_after_for_in_for_net = IfAfterForInForNet() | |||
| net = GradNet(if_after_for_in_for_net) | |||
| graph_forward_res = if_after_for_in_for_net(x) | |||
| graph_backward_res = net(x) | |||
| # pynative mode | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| if_after_for_in_for_net = IfAfterForInForNet() | |||
| net = GradNet(if_after_for_in_for_net) | |||
| pynative_forward_res = if_after_for_in_for_net(x) | |||
| pynative_backward_res = net(x) | |||
| assert graph_forward_res == pynative_forward_res | |||
| assert graph_backward_res == pynative_backward_res | |||
| @@ -0,0 +1,66 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from mindspore import context | |||
| from mindspore import Tensor, nn | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.common.parameter import Parameter | |||
| grad_all = C.GradOperation(get_all=True) | |||
| context.set_context(device_target="Ascend") | |||
| def test_for_after_if(): | |||
| class ForAfterIfNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.param_a = Parameter(Tensor(5, mstype.int32), name='a') | |||
| self.param_b = Parameter(Tensor(4, mstype.int32), name='b') | |||
| def construct(self, x): | |||
| out = self.param_a | |||
| if self.param_a > self.param_b: | |||
| x += 3 | |||
| self.param_b += 1 | |||
| for _ in range(0, 5): | |||
| x += self.param_b | |||
| out *= x | |||
| return out | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| def construct(self, *inputs): | |||
| return grad_all(self.net)(*inputs) | |||
| x = Tensor(2, mstype.int32) | |||
| # graph mode | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| for_after_if_net = ForAfterIfNet() | |||
| net = GradNet(for_after_if_net) | |||
| graph_forward_res = for_after_if_net(x) | |||
| graph_backward_res = net(x) | |||
| # pynative mode | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| for_after_if_net = ForAfterIfNet() | |||
| net = GradNet(for_after_if_net) | |||
| pynative_forward_res = for_after_if_net(x) | |||
| pynative_backward_res = net(x) | |||
| assert graph_forward_res == pynative_forward_res | |||
| assert graph_backward_res == pynative_backward_res | |||
| @@ -0,0 +1,69 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from mindspore import context | |||
| from mindspore import Tensor, nn | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.common.parameter import Parameter | |||
| grad_all = C.GradOperation(get_all=True) | |||
| context.set_context(device_target="Ascend") | |||
| def test_for_after_if_in_if(): | |||
| class ForAfterIfInIfNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.param_a = Parameter(Tensor(5, mstype.int32), name='a') | |||
| self.param_b = Parameter(Tensor(4, mstype.int32), name='b') | |||
| def construct(self, x): | |||
| out = self.param_a | |||
| if self.param_a > self.param_b: | |||
| x += 3 | |||
| if x > self.param_a: | |||
| self.param_b += 4 | |||
| x += self.param_a | |||
| self.param_b += 2 | |||
| for _ in range(0, 5): | |||
| x += self.param_b | |||
| out *= x | |||
| return out | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| def construct(self, *inputs): | |||
| return grad_all(self.net)(*inputs) | |||
| x = Tensor(5, mstype.int32) | |||
| # graph mode | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| for_after_if_in_if_net = ForAfterIfInIfNet() | |||
| net = GradNet(for_after_if_in_if_net) | |||
| graph_forward_res = for_after_if_in_if_net(x) | |||
| graph_backward_res = net(x) | |||
| # pynative mode | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| for_after_if_in_if_net = ForAfterIfInIfNet() | |||
| net = GradNet(for_after_if_in_if_net) | |||
| pynative_forward_res = for_after_if_in_if_net(x) | |||
| pynative_backward_res = net(x) | |||
| assert graph_forward_res == pynative_forward_res | |||
| assert graph_backward_res == pynative_backward_res | |||
| @@ -0,0 +1,68 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from mindspore import context | |||
| from mindspore import Tensor, nn | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.common.parameter import Parameter | |||
| grad_all = C.GradOperation(get_all=True) | |||
| context.set_context(device_target="Ascend") | |||
| def test_for_after_for_in_if(): | |||
| class ForAfterForInIfNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.param_a = Parameter(Tensor(5, mstype.int32), name='a') | |||
| self.param_b = Parameter(Tensor(4, mstype.int32), name='b') | |||
| def construct(self, x): | |||
| out = self.param_a | |||
| if self.param_a > self.param_b: | |||
| for _ in range(0, 4): | |||
| self.param_a += 1 | |||
| self.param_b -= 3 | |||
| self.param_b += 10 | |||
| for _ in range(0, 5): | |||
| x += self.param_b | |||
| out *= x | |||
| return out | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| def construct(self, *inputs): | |||
| return grad_all(self.net)(*inputs) | |||
| x = Tensor(5, mstype.int32) | |||
| # graph mode | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| for_after_for_in_if_net = ForAfterForInIfNet() | |||
| net = GradNet(for_after_for_in_if_net) | |||
| graph_forward_res = for_after_for_in_if_net(x) | |||
| graph_backward_res = net(x) | |||
| # pynative mode | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| for_after_for_in_if_net = ForAfterForInIfNet() | |||
| net = GradNet(for_after_for_in_if_net) | |||
| pynative_forward_res = for_after_for_in_if_net(x) | |||
| pynative_backward_res = net(x) | |||
| assert graph_forward_res == pynative_forward_res | |||
| assert graph_backward_res == pynative_backward_res | |||