Merge pull request !1385 from amongo/SupportMultiSwitch
tags/v0.5.0-beta
@@ -52,13 +52,17 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
   // Example: when converting CNode(kPrimReduceSum, x, axis), the node at index 2 of CNode->inputs is axis, which
   // should not be converted to the switch-guarded form.
   std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list(
-      {{prim::kPrimApplyMomentum, {1, 2}}, {prim::kPrimMomentum, {2, 3}}, {prim::kPrimStateSetItem, {1}},
-       {prim::kPrimEnvGetItem, {1}}, {prim::kPrimEnvSetItem, {1}}, {prim::kPrimReduceSum, {2}},
-       {prim::kPrimReduceMean, {2}}, {prim::kPrimReduceAll, {2}}, {prim::kPrimCast, {2}},
-       {prim::kPrimTranspose, {2}}, {prim::kPrimOneHot, {2}}, {prim::kPrimGatherV2, {3}},
-       {prim::kPrimReshape, {2}}, {prim::kPrimAssign, {1}}, {prim::kPrimAssignAdd, {1}},
-       {prim::kPrimAssignSub, {1}}, {prim::kPrimTensorSummary, {1}}, {prim::kPrimImageSummary, {1}},
-       {prim::kPrimScalarSummary, {1}}, {prim::kPrimHistogramSummary, {1}}});
+      {{prim::kPrimApplyMomentum, {1, 2}}, {prim::kPrimMomentum, {2, 3}},
+       {prim::kPrimStateSetItem, {1}}, {prim::kPrimTupleGetItem, {2}},
+       {prim::kPrimEnvGetItem, {1}}, {prim::kPrimEnvSetItem, {1}},
+       {prim::kPrimReduceSum, {2}}, {prim::kPrimReduceMean, {2}},
+       {prim::kPrimReduceAll, {2}}, {prim::kPrimCast, {2}},
+       {prim::kPrimTranspose, {2}}, {prim::kPrimOneHot, {2}},
+       {prim::kPrimGatherV2, {3}}, {prim::kPrimReshape, {2}},
+       {prim::kPrimAssign, {1}}, {prim::kPrimAssignAdd, {1}},
+       {prim::kPrimAssignSub, {1}}, {prim::kPrimTensorSummary, {1}},
+       {prim::kPrimImageSummary, {1}}, {prim::kPrimScalarSummary, {1}},
+       {prim::kPrimHistogramSummary, {1}}});
   for (auto &item : white_list) {
     auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) {
       return IsPrimitiveCNode(node, item.first) && idx == index;
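
The white list pins down operator inputs that are attribute-like constants (reduction axes, cast types, summary tags and so on); wrapping them in a switch guard would turn a compile-time constant into a runtime value. An illustrative sketch (not part of this patch) of the kind of graph that relies on this:

# Illustrative sketch only: the axis argument (CNode input index 2) of
# ReduceSum must remain a compile-time constant even though the call sits
# inside a conditional branch, which is why kPrimReduceSum carries {2} in
# the white list above.
import numpy as np
import mindspore as ms
from mindspore import nn, Tensor
from mindspore.ops import operations as P

class ReduceInBranch(nn.Cell):
    def __init__(self):
        super(ReduceInBranch, self).__init__()
        self.reduce_sum = P.ReduceSum()

    def construct(self, x, a, b):
        if a < b:
            out = self.reduce_sum(x, 1)  # axis=1 stays constant
        else:
            out = self.reduce_sum(x, 0)  # axis=0 stays constant
        return out

net = ReduceInBranch()
out = net(Tensor(np.ones((2, 3)), ms.float32), Tensor(1, ms.int32), Tensor(2, ms.int32))
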
@@ -80,7 +84,8 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
 using NodeInputReplMap = std::unordered_map<std::pair<AnfNodePtr, size_t>, AnfNodePtr, PairHasher>;
 // replace the nodes which should be changed
 void RunSwitchNodeReplace(const FuncGraphManagerPtr &manager, std::vector<std::pair<CNodePtr, CNodePtr>> nodes_changed,
-                          std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node, NodeInputReplMap repl_node_inputs) {
+                          std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node, NodeInputReplMap repl_node_inputs,
+                          const FuncGraphPtr &func_graph) {
   for (auto &node_pair : nodes_changed) {
     CNodePtr old_node = node_pair.first;
     CNodePtr new_node = node_pair.second;
@@ -99,9 +104,11 @@ void RunSwitchNodeReplace(const FuncGraphManagerPtr &manager, std::vector<std::p
   }
   for (auto &item : repl_node) {
-    if (!manager->Replace(item.first, item.second)) {
-      MS_LOG(EXCEPTION) << "TransformGraphDependNode replace node failed original:" << item.first->DebugString()
-                        << " to new: " << item.second->DebugString();
+    if (IsPrimitiveCNode(item.second, prim::kPrimReturn)) {
+      func_graph->set_output(item.second->cast<CNodePtr>()->input(1));
+    } else if (!manager->Replace(item.first, item.second)) {
+      MS_LOG(EXCEPTION) << "TransformGraphDependNode replace node failed original:" << item.first->DebugString(2)
+                        << " to new: " << item.second->DebugString(2);
     }
   }
 }
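
When a rewritten node is itself a Return, the pass cannot splice it in through manager->Replace (a graph has exactly one Return node); the new branch instead points the graph's output at Return's single operand, input index 1. A rough, self-contained Python model of this dispatch (names are ours, not MindSpore API):

# Rough Python model (names are ours, not MindSpore API) of the new
# replacement dispatch in RunSwitchNodeReplace.
class Node:
    def __init__(self, prim, inputs=()):
        self.prim = prim
        self.inputs = list(inputs)  # inputs[0] is the primitive slot

class Graph:
    def __init__(self):
        self.output = None

    def set_output(self, node):
        self.output = node

def run_replace(repl_node, manager_replace, func_graph):
    for old, new in repl_node.items():
        if new.prim == "Return":
            # A Return cannot be spliced in via the manager; redirect the
            # graph output to Return's operand (input index 1) instead.
            func_graph.set_output(new.inputs[1])
        elif not manager_replace(old, new):
            raise RuntimeError("replace node failed")

# Usage sketch: a node rewritten into Return(y) resets the graph output to y.
g = Graph()
y = Node("Add", [None, 1, 2])
run_replace({Node("old"): Node("Return", [None, y])}, lambda old, new: True, g)
assert g.output is y
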
@@ -154,7 +161,7 @@ FuncGraphPtr TransformGraphCondBranchNodes(
       nodes_changed.emplace_back(node->cast<CNodePtr>(), new_node);
     }
   }
-  RunSwitchNodeReplace(manager, nodes_changed, repl_node, repl_node_inputs);
+  RunSwitchNodeReplace(manager, nodes_changed, repl_node, repl_node_inputs, graph);
   return graph;
 }
@@ -508,11 +515,12 @@ bool GraphOutputCompatible(const AbstractBasePtr &true_branch_abs, const Abstrac
 AnfNodePtr GenerateMergeNodes(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node,
                               const AbstractBasePtr &true_graph_output_abs,
-                              const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond) {
+                              const AbstractBasePtr &false_graph_output_abs, const FuncGraphPtr &switch_graph,
+                              const AnfNodePtr &cond) {
   MS_EXCEPTION_IF_NULL(true_graph_output_abs);
   MS_EXCEPTION_IF_NULL(false_graph_output_abs);
   MS_EXCEPTION_IF_NULL(cond);
-  MS_EXCEPTION_IF_NULL(cond->func_graph());
+  MS_EXCEPTION_IF_NULL(switch_graph);
   auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast<PrimitivePtr>();
   MS_EXCEPTION_IF_NULL(PrimMerge);
@@ -520,10 +528,10 @@ AnfNodePtr GenerateMergeNodes(const AnfNodePtr &true_output_node, const AnfNodeP
     std::vector<AnfNodePtr> merge_nodes;
     merge_nodes.push_back(NewValueNode(PrimMerge));
     std::vector<AnfNodePtr> make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), true_output_node, false_output_node};
-    merge_nodes.push_back(cond->func_graph()->NewCNode(make_tuple_nodes));
+    merge_nodes.push_back(switch_graph->NewCNode(make_tuple_nodes));
     std::vector<AnfNodePtr> tuple_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem),
-                                                cond->func_graph()->NewCNode(merge_nodes), NewValueNode(MakeValue(0))};
-    return cond->func_graph()->NewCNode(tuple_getitem_nodes);
+                                                switch_graph->NewCNode(merge_nodes), NewValueNode(MakeValue(0))};
+    return switch_graph->NewCNode(tuple_getitem_nodes);
   } else {
     abstract::AbstractTuplePtr true_branch_tuple = true_graph_output_abs->cast<abstract::AbstractTuplePtr>();
     abstract::AbstractTuplePtr false_branch_tuple = false_graph_output_abs->cast<abstract::AbstractTuplePtr>();
@@ -533,27 +541,29 @@ AnfNodePtr GenerateMergeNodes(const AnfNodePtr &true_output_node, const AnfNodeP
     for (size_t i = 0; i < true_branch_tuple->elements().size(); i++) {
       std::vector<AnfNodePtr> true_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), true_output_node,
                                                  NewValueNode(MakeValue(SizeToInt(i)))};
-      auto true_node = cond->func_graph()->NewCNode(true_getitem_nodes);
+      auto true_node = switch_graph->NewCNode(true_getitem_nodes);
       std::vector<AnfNodePtr> false_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), false_output_node,
                                                   NewValueNode(MakeValue(SizeToInt(i)))};
-      auto false_node = cond->func_graph()->NewCNode(false_getitem_nodes);
+      auto false_node = switch_graph->NewCNode(false_getitem_nodes);
       auto merge_node = GenerateMergeNodes(true_node, false_node, true_branch_tuple->elements()[i],
-                                           false_branch_tuple->elements()[i], cond);
+                                           false_branch_tuple->elements()[i], switch_graph, cond);
       make_tuple_nodes.push_back(merge_node);
     }
-    return cond->func_graph()->NewCNode(make_tuple_nodes);
+    return switch_graph->NewCNode(make_tuple_nodes);
   }
 }

 AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node,
                                   const AbstractBasePtr &true_graph_output_abs,
-                                  const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond) {
+                                  const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond,
+                                  const FuncGraphPtr &switch_graph) {
   if (!GraphOutputCompatible(true_graph_output_abs, false_graph_output_abs)) {
     MS_LOG(EXCEPTION) << "Switch output branch not compatible, true:" << true_graph_output_abs->ToString()
                       << ", false:" << false_graph_output_abs->ToString();
   }
-  return GenerateMergeNodes(true_output_node, false_output_node, true_graph_output_abs, false_graph_output_abs, cond);
+  return GenerateMergeNodes(true_output_node, false_output_node, true_graph_output_abs, false_graph_output_abs,
+                            switch_graph, cond);
 }
 }  // namespace internal
 }  // namespace irpass
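
GenerateMergeNodes recurses over tuple outputs: a scalar or tensor output is merged directly through Merge plus a tuple_getitem on index 0, while a tuple output is merged element by element so the nested structure of both branches is preserved. A minimal Python model of that recursion (illustrative only, not MindSpore API):

# Minimal Python model (illustrative only) of the GenerateMergeNodes
# recursion: leaves are merged directly, tuples are merged element-wise.
def generate_merge(true_out, false_out, cond):
    if not isinstance(true_out, tuple):
        # Stands in for tuple_getitem(Merge(make_tuple(t, f)), 0): pick the
        # value produced by the branch selected at runtime.
        return true_out if cond else false_out
    # GraphOutputCompatible has already checked that both tuples line up.
    return tuple(generate_merge(t, f, cond)
                 for t, f in zip(true_out, false_out))

# Nested tuple outputs keep their structure after merging:
assert generate_merge((1, (2, 3)), (4, (5, 6)), True) == (1, (2, 3))
assert generate_merge((1, (2, 3)), (4, (5, 6)), False) == (4, (5, 6))
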
@@ -168,7 +168,8 @@ FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const
 FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond);
 AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node,
                                   const AbstractBasePtr &true_graph_output_abs,
-                                  const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond);
+                                  const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond,
+                                  const FuncGraphPtr &func_graph);
 }  // namespace internal

 // {{prim::kPrimSwitch, X, G1, G2}, Xs}
@@ -190,6 +191,20 @@ class ConvertSwitchReplacement : public AnfVisitor {
     if (g2_ == nullptr || g1_->output() == nullptr || g2_->output() == nullptr) {
       return nullptr;
     }
+    // For the switch-replacement method, only graphs that contain no nested graph can be replaced.
+    for (auto &item : g1_->value_nodes()) {
+      auto value_node = item.first;
+      if (IsValueNode<FuncGraph>(value_node)) {
+        return nullptr;
+      }
+    }
+    for (auto &item : g2_->value_nodes()) {
+      auto value_node = item.first;
+      if (IsValueNode<FuncGraph>(value_node)) {
+        return nullptr;
+      }
+    }
     auto true_output = g1_->output()->abstract();
     auto false_output = g2_->output()->abstract();
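
The new guard keeps the replacement conservative: if either branch graph holds another FuncGraph as a value node, which happens when a branch itself contains control flow, the pattern is left untouched. The new test_if_nested_compile below exercises exactly this shape; an illustrative sketch (not part of this patch) of a branch the guard skips:

# Illustrative sketch only: the else-branch below contains its own `if`,
# so its FuncGraph holds a nested FuncGraph value node, and
# ConvertSwitchReplacement now returns nullptr for the outer switch.
import mindspore as ms
from mindspore import nn, Tensor

class NestedBranch(nn.Cell):
    def construct(self, x, y):
        if x < y:
            out = x + y
        else:
            if x == y:  # inner control flow => nested FuncGraph
                out = x - y
            else:
                out = x * y
        return out

net = NestedBranch()
out = net(Tensor(1.0, ms.float32), Tensor(2.0, ms.float32))
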
@@ -200,8 +215,8 @@ class ConvertSwitchReplacement : public AnfVisitor {
     auto fg = node->func_graph();
     auto cloned_g1 = InlineClone(trans_g1, fg, params);
     auto cloned_g2 = InlineClone(trans_g2, fg, params);
-    return internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_);
+    auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg);
+    return nnode;
   }

   void Visit(const AnfNodePtr &node) override {
@@ -162,7 +162,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
 }

 OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) {
-  opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_});
+  opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_}, true);
   OptPassGroupMap map({
     {"control_group", control_group},
     {"renormalize", opt::OptPassConfig::Renormalize()},
@@ -346,7 +346,7 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
     if ((*value == *kAnyValue)) {
       auto value_desc = abs_base->value_desc();
       MS_EXCEPTION(TypeError) << "Unsupported parameter " << (value_desc.empty() ? "type" : value_desc)
-                              << " for python primitive.";
+                              << " for python primitive." << abs_base->ToString();
     }
     MS_EXCEPTION(TypeError) << "Unsupported parameter type for python primitive, the parameter value is "
                             << value->ToString();
@@ -24,6 +24,8 @@ from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
+from mindspore.common.parameter import Parameter, ParameterTuple
+from mindspore.common import ms_function

 context.set_context(mode=context.GRAPH_MODE)
@@ -371,7 +373,8 @@ def test_switch_layer():
     class Layer1(nn.Cell):
         def __init__(self):
             super(Layer1, self).__init__()
-            self.z1 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')
+            self.z1 = Parameter(
+                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')

         def construct(self, x):
             return x * self.z1
@@ -379,7 +382,8 @@ def test_switch_layer():
     class Layer2(nn.Cell):
         def __init__(self):
             super(Layer2, self).__init__()
-            self.z2 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')
+            self.z2 = Parameter(
+                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')

         def construct(self, x):
             return x * self.z2
@@ -388,7 +392,8 @@ def test_switch_layer():
         def __init__(self):
             super(SwitchLayerCell, self).__init__()
             self.layers = (Layer1(), Layer2())
-            self.z3 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
+            self.z3 = Parameter(
+                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')

         def construct(self, index, x):
             ret = F.switch_layer(index, self.layers)(x) * self.z3
@@ -406,7 +411,8 @@ def test_index_to_switch_layer():
     class Layer1(nn.Cell):
         def __init__(self):
             super(Layer1, self).__init__()
-            self.z1 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')
+            self.z1 = Parameter(
+                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')

         def construct(self, x):
             return x * self.z1
@@ -414,7 +420,8 @@ def test_index_to_switch_layer():
     class Layer2(nn.Cell):
         def __init__(self):
             super(Layer2, self).__init__()
-            self.z2 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')
+            self.z2 = Parameter(
+                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')

         def construct(self, x):
             return x * self.z2
@@ -423,7 +430,8 @@ def test_index_to_switch_layer():
         def __init__(self):
             super(SwitchLayerCell, self).__init__()
             self.layers = (Layer1(), Layer2())
-            self.z3 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
+            self.z3 = Parameter(
+                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')

         def construct(self, index, x):
             ret = self.layers[index](x) * self.z3
@@ -444,3 +452,69 @@ def test_control_depend_check():
         depend = P.ControlDepend(2)
     with pytest.raises(TypeError) as e:
         depend = P.ControlDepend((2,))
+
+
+def test_if_nested_compile():
+    class Net(nn.Cell):
+        def __init__(self, auto_prefix=True):
+            super().__init__(auto_prefix=auto_prefix)
+            self.squre = P.Square()
+            self.value = Tensor(3, dtype=ms.float32)
+
+        def construct(self, x, y):
+            res = self.value
+            if x <= y:
+                res = x + res
+                res = y + res
+            else:
+                if x == y:
+                    res = self.squre(self.value * y)
+                else:
+                    res = self.squre(self.value)
+            return res
+    x = Tensor(1.0, dtype=ms.float32)
+    y = Tensor(2.0, dtype=ms.float32)
+    net = Net()
+    net(x, y)
+
+
+def test_if_inside_for():
+    class Net(nn.Cell):
+        def __init__(self, auto_prefix=True):
+            super().__init__(auto_prefix=auto_prefix)
+            self.squre = P.Square()
+            self.value = Tensor(3, dtype=ms.float32)
+            self.count = 4
+
+        def construct(self, x, y):
+            res = 0
+            for i in range(self.count):
+                if i == x:
+                    res = res + x
+                else:
+                    res = res - y
+            return res
+    c1 = Tensor(1, dtype=ms.int32)
+    c2 = Tensor(1, dtype=ms.int32)
+    net = Net()
+    out = net(c1, c2)
+
+
+def test_while_in_while():
+    c1 = Tensor(1, dtype=ms.int32)
+    c2 = Tensor(2, dtype=ms.int32)
+    c3 = Tensor(3, dtype=ms.int32)
+    c4 = Tensor(4, dtype=ms.int32)
+
+    @ms_function
+    def while_in_while(x, y, z, u):
+        out = c4
+        while x < y:
+            z = c4 + c4
+            while z < y:
+                z = z + 1
+                out = out + 1
+            x = x + 1
+        out = out + 3
+        return out
+    while_in_while(c1, c2, c3, c4)
@@ -0,0 +1,81 @@
+import numpy as np
+
+import mindspore
+from mindspore import nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE)
+
+
+class Layer1(nn.Cell):
+    def __init__(self):
+        super(Layer1, self).__init__()
+        self.net = nn.Conv2d(3, 1, 3, pad_mode='same')
+        self.pad = nn.Pad(
+            paddings=((0, 0), (0, 2), (0, 0), (0, 0)), mode="CONSTANT")
+
+    def construct(self, x):
+        y = self.net(x)
+        return self.pad(y)
+
+
+class Layer2(nn.Cell):
+    def __init__(self):
+        super(Layer2, self).__init__()
+        self.net = nn.Conv2d(3, 1, 7, pad_mode='same')
+        self.pad = nn.Pad(
+            paddings=((0, 0), (0, 2), (0, 0), (0, 0)), mode="CONSTANT")
+
+    def construct(self, x):
+        y = self.net(x)
+        return self.pad(y)
+
+
+class Layer3(nn.Cell):
+    def __init__(self):
+        super(Layer3, self).__init__()
+        self.net = nn.Conv2d(3, 3, 3, pad_mode='same')
+
+    def construct(self, x):
+        return self.net(x)
+
+
+class SwitchNet(nn.Cell):
+    def __init__(self):
+        super(SwitchNet, self).__init__()
+        self.layer1 = Layer1()
+        self.layer2 = Layer2()
+        self.layer3 = Layer3()
+        self.layers = (self.layer1, self.layer2, self.layer3)
+        self.fill = P.Fill()
+
+    def construct(self, x, index):
+        y = self.layers[index](x)
+        return y
+
+
+class MySwitchNet(nn.Cell):
+    def __init__(self):
+        super(MySwitchNet, self).__init__()
+        self.layer1 = Layer1()
+        self.layer2 = Layer2()
+        self.layer3 = Layer3()
+        self.layers = (self.layer1, self.layer2, self.layer3)
+        self.fill = P.Fill()
+
+    def construct(self, x, index):
+        y = self.layers[0](x)
+        for i in range(len(self.layers)):
+            if i == index:
+                y = self.layers[i](x)
+        return y
+
+
+def test_layer_switch():
+    net = MySwitchNet()
+    x = Tensor(np.ones((3, 3, 24, 24)), mindspore.float32)
+    index = Tensor(0, dtype=mindspore.int32)
+    y = net(x, index)
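
SwitchNet above routes through a tensor-valued index directly, but only MySwitchNet is exercised by test_layer_switch. A hedged companion sketch (our addition, not part of this patch) that drives the same shapes through SwitchNet:

# Hedged sketch (not part of this patch): exercise SwitchNet, whose
# tensor-valued index selects a layer via the switch_layer path, on the
# same input shapes as test_layer_switch.
def test_tensor_index_layer_switch():
    net = SwitchNet()
    x = Tensor(np.ones((3, 3, 24, 24)), mindspore.float32)
    index = Tensor(1, dtype=mindspore.int32)
    y = net(x, index)
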