| @@ -880,7 +880,7 @@ static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGrap | |||||
| } else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) { | } else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) { | ||||
| // To find J user. | // To find J user. | ||||
| auto j_user = GetJUser(node_user_map, cnode, index); | auto j_user = GetJUser(node_user_map, cnode, index); | ||||
| primal_j_pair.emplace_back(std::pair<CNodePtr, CNodePtr>(nullptr, j_user)); | |||||
| (void)primal_j_pair.emplace_back(std::pair<CNodePtr, CNodePtr>(nullptr, j_user)); | |||||
| } | } | ||||
| } | } | ||||
| @@ -85,6 +85,11 @@ std::vector<std::vector<size_t>> SplitGroup(const std::vector<AnfNodePtr> &topos | |||||
| if (IsPrimitiveCNode(node, prim::kPrimLoad)) { | if (IsPrimitiveCNode(node, prim::kPrimLoad)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| // if Call/Switch/SwitchLayer, do not replace load. | |||||
| if (IsPrimitiveCNode(node, prim::kPrimCall) || IsPrimitiveCNode(node, prim::kPrimSwitch) || | |||||
| IsPrimitiveCNode(node, prim::kPrimSwitchLayer)) { | |||||
| return true; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| auto &inputs = cnode->inputs(); | auto &inputs = cnode->inputs(); | ||||
| return std::any_of(inputs.begin(), inputs.end(), | return std::any_of(inputs.begin(), inputs.end(), | ||||
| @@ -28,6 +28,7 @@ namespace { | |||||
| // data = Load(input, attach) | // data = Load(input, attach) | ||||
| // data = Depend(input, attach) | // data = Depend(input, attach) | ||||
| // monad = UpdateState(input, attach) | // monad = UpdateState(input, attach) | ||||
| constexpr size_t kFirstInputIndex = 0; | |||||
| constexpr size_t kInputIndex = 1; | constexpr size_t kInputIndex = 1; | ||||
| constexpr size_t kAttachIndex = 2; | constexpr size_t kAttachIndex = 2; | ||||
| constexpr size_t kMakeTupleSize = 3; | constexpr size_t kMakeTupleSize = 3; | ||||
| @@ -120,6 +121,13 @@ AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const A | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| } | } | ||||
| // Skip Call/Switch/SwitchLayer. | |||||
| auto first_input_node = cnode->input(kFirstInputIndex); | |||||
| if (IsPrimitiveCNode(first_input_node, prim::kPrimCall) || IsPrimitiveCNode(first_input_node, prim::kPrimSwitch) || | |||||
| IsPrimitiveCNode(first_input_node, prim::kPrimSwitchLayer)) { | |||||
| return nullptr; | |||||
| } | |||||
| // Remove UpdateState by replace it with its input monad. | // Remove UpdateState by replace it with its input monad. | ||||
| return update_state->input(kInputIndex); | return update_state->input(kInputIndex); | ||||
| } | } | ||||
| @@ -246,7 +246,7 @@ void SubstitutionList::DisplayStatusOfSubstitution(const std::unordered_map<std: | |||||
| << std::endl; | << std::endl; | ||||
| for (size_t i = 0; i < list_.size(); i++) { | for (size_t i = 0; i < list_.size(); i++) { | ||||
| auto name = list_[i]->name_; | auto name = list_[i]->name_; | ||||
| ss << std::left << std::setw(space + 4) << name << "\t"; | |||||
| ss << std::left << std::setw(SizeToInt(space) + 4) << name << "\t"; | |||||
| for (auto change : status.at(name + std::to_string(i))) { | for (auto change : status.at(name + std::to_string(i))) { | ||||
| ss << change << " "; | ss << change << " "; | ||||
| } | } | ||||
| @@ -393,7 +393,7 @@ class TensorDataImpl : public TensorData { | |||||
| pos++; | pos++; | ||||
| } | } | ||||
| size_t len = pos - index; | size_t len = pos - index; | ||||
| std::string space(max_width - len, ' '); | |||||
| std::string space(max_width - SizeToInt(len), ' '); | |||||
| str = str.replace(index, len, space); | str = str.replace(index, len, space); | ||||
| index = str.find('#', index); | index = str.find('#', index); | ||||
| } | } | ||||
| @@ -170,3 +170,52 @@ def test_for_in_if_03(): | |||||
| assert graph_forward_res == pynative_forward_res | assert graph_forward_res == pynative_forward_res | ||||
| assert graph_backward_res == pynative_backward_res | assert graph_backward_res == pynative_backward_res | ||||
| def test_for_in_if_04(): | |||||
| class ForInIfNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.param_a = Parameter(Tensor(5, mstype.int32), name='a') | |||||
| self.param_b = Parameter(Tensor(4, mstype.int32), name='b') | |||||
| def construct(self, x): | |||||
| out = self.param_a | |||||
| x = self.func(x) | |||||
| out *= x | |||||
| return out | |||||
| def func(self, x): | |||||
| if self.param_a > self.param_b: | |||||
| for _ in range(0, 4): | |||||
| self.param_a += 1 | |||||
| self.param_b -= 3 | |||||
| self.param_b += 10 | |||||
| return x | |||||
| class GradNet(nn.Cell): | |||||
| def __init__(self, net): | |||||
| super(GradNet, self).__init__() | |||||
| self.net = net | |||||
| def construct(self, *inputs): | |||||
| return grad_all(self.net)(*inputs) | |||||
| x = Tensor(5, mstype.int32) | |||||
| # graph mode | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| for_in_if_net = ForInIfNet() | |||||
| net = GradNet(for_in_if_net) | |||||
| graph_forward_res = for_in_if_net(x) | |||||
| graph_backward_res = net(x) | |||||
| # pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE) | |||||
| for_in_if_net = ForInIfNet() | |||||
| net = GradNet(for_in_if_net) | |||||
| pynative_forward_res = for_in_if_net(x) | |||||
| pynative_backward_res = net(x) | |||||
| assert graph_forward_res == pynative_forward_res | |||||
| assert graph_backward_res == pynative_backward_res | |||||
| @@ -74,6 +74,25 @@ class IfAfterIfNet2(nn.Cell): | |||||
| return y | return y | ||||
| class IfAfterIfNet3(nn.Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.param_a = Parameter(Tensor(5, mstype.int32), name='a') | |||||
| self.param_b = Parameter(Tensor(4, mstype.int32), name='b') | |||||
| def construct(self, x, y): | |||||
| out = x * y + self.func(self.param_b) | |||||
| if self.param_a > self.param_b: | |||||
| out += 5 | |||||
| return out | |||||
| def func(self, x): | |||||
| if self.param_a > self.param_b: | |||||
| x += 5 | |||||
| self.param_b += 4 | |||||
| return x | |||||
| class GradNet(nn.Cell): | class GradNet(nn.Cell): | ||||
| def __init__(self, net): | def __init__(self, net): | ||||
| super(GradNet, self).__init__() | super(GradNet, self).__init__() | ||||
| @@ -118,3 +137,9 @@ def test_if_after_if_02(): | |||||
| x = Tensor(2, mstype.int32) | x = Tensor(2, mstype.int32) | ||||
| y = Tensor(5, mstype.int32) | y = Tensor(5, mstype.int32) | ||||
| control_flow_if_after_if(IfAfterIfNet2, x, y) | control_flow_if_after_if(IfAfterIfNet2, x, y) | ||||
| def test_if_after_if_03(): | |||||
| x = Tensor(2, mstype.int32) | |||||
| y = Tensor(5, mstype.int32) | |||||
| control_flow_if_after_if(IfAfterIfNet3, x, y) | |||||