Merge pull request !2967 from huanghui/heterogeneous-backend-control-depend-optimize
tags/v0.6.0-beta
@@ -451,10 +451,14 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
   }
   auto origin_inputs = cnode->inputs();
   bool optimize_depend = false;
+  bool optimize_control_depend = false;
   if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 &&
       origin_inputs[kRealInputIndexInDepend]->isa<ValueNode>()) {
     optimize_depend = true;
   }
+  if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend) && origin_inputs.size() == 3) {
+    optimize_control_depend = true;
+  }
   // if there are multiple depends, only select the first depend as parameter
   for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
     auto anf = origin_inputs[input_idx];
@@ -485,6 +489,8 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
     } else if (optimize_depend && input_idx == kDependAttachNodeIndex) {
       cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]);
       continue;
+    } else if (optimize_control_depend) {
+      cnode_inputs.push_back(NewValueNode(MakeValue(input_idx)));
     } else {
       *from_other_graph = true;
       // the input node is a cnode from other graph
@@ -117,6 +117,14 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
                eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) {
       args.emplace_back(inps[kRealInputIndexInDepend]);
       args.emplace_back(inps[kRealInputIndexInDepend]);
+    } else if (IsPrimitive(fn, prim::kPrimControlDepend) && inps.size() == 3) {
+      for (size_t i = 1; i < inps.size(); ++i) {
+        if (inps[i]->isa<CNode>() && std::find(lst.begin(), lst.end(), inps[i]) == lst.end()) {
+          args.emplace_back(NewValueNode(MakeValue(i)));
+        } else {
+          args.emplace_back(ref(inps[i]));
+        }
+      }
     } else {
       (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref);
     }
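Both hunks above apply the same idea from two directions: when a ControlDepend CNode is rebuilt for a backend kernel graph (CreateNewCNode) or pulled into an extracted segment (TransformSegmentToAnfGraph), operands that only express ordering are not turned into real data inputs or extra graph parameters; they are replaced by value nodes that merely record the operand index. Below is a minimal standalone sketch of that substitution on a toy node type — `ToyNode` and `copy_for_segment` are invented names for illustration, and the toy node has no primitive slot, so a ControlDepend carries exactly two operands instead of `inps.size() == 3`.

```python
# Toy model of the ControlDepend operand substitution (not MindSpore classes).
class ToyNode:
    def __init__(self, op, inputs=()):
        self.op = op                  # e.g. "ControlDepend", "ReLU", "Value"
        self.inputs = list(inputs)

def copy_for_segment(cnode, segment_nodes):
    """Copy cnode for a segment; ControlDepend operands that live outside the
    segment are replaced by index value nodes instead of becoming parameters."""
    is_control_depend = cnode.op == "ControlDepend" and len(cnode.inputs) == 2
    new_inputs = []
    for idx, inp in enumerate(cnode.inputs, start=1):
        if is_control_depend and inp not in segment_nodes:
            new_inputs.append(ToyNode("Value", [idx]))   # keep only the index
        else:
            new_inputs.append(inp)
    return ToyNode(cnode.op, new_inputs)
```

The CreateNewCNode variant is even simpler: with `optimize_control_depend` set, every operand slot gets an index value node.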
@@ -69,7 +69,91 @@ bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) {
   return false;
 }
-void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref) {
+bool ExtractNodes(const FuncGraphPtr &graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node,
+                  std::vector<AnfNodePtr> *prior_nodes, std::vector<AnfNodePtr> *depend_nodes) {
+  MS_EXCEPTION_IF_NULL(prior_node);
+  MS_EXCEPTION_IF_NULL(behind_node);
+  MS_EXCEPTION_IF_NULL(graph);
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto &node_users = manager->node_users();
+  if (prior_node->isa<Parameter>()) {
+    for (auto &user : node_users[prior_node]) {
+      auto cnode = user.first->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(cnode);
+      if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
+        prior_nodes->emplace_back(cnode);
+      }
+    }
+  } else if (!IsPrimitiveCNode(prior_node, prim::kPrimControlDepend)) {
+    prior_nodes->emplace_back(prior_node);
+  } else {
+    return false;
+  }
+  if (behind_node->isa<Parameter>()) {
+    for (auto &user : node_users[behind_node]) {
+      auto cnode = user.first->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(cnode);
+      if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
+        depend_nodes->emplace_back(cnode);
+      }
+    }
+  } else if (!IsPrimitiveCNode(behind_node, prim::kPrimControlDepend)) {
+    depend_nodes->emplace_back(behind_node);
+  } else {
+    return false;
+  }
+  return true;
+}
+
+void AddControlEdge(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                    std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges,
+                    std::map<AnfNodePtr, size_t> *nodes_ref) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto input_cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(input_cnode);
+  auto prior_node = input_cnode->input(kControlDependPriorIndex);
+  auto depend_node = input_cnode->input(kControlDependBehindIndex);
+  MS_EXCEPTION_IF_NULL(prior_node);
+  MS_EXCEPTION_IF_NULL(depend_node);
+  PrimitivePtr prim_ptr = GetValueNode<PrimitivePtr>(input_cnode->input(0));
+  MS_EXCEPTION_IF_NULL(prim_ptr);
+  ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode");
+  int depend_mode = 0;
+  if (mode_ptr != nullptr) {
+    depend_mode = GetValue<int>(mode_ptr);
+  }
+  if ((prior_node->isa<Parameter>() || depend_node->isa<Parameter>()) && depend_mode == 0) {
+    return;
+  }
+  std::vector<AnfNodePtr> prior_nodes;
+  std::vector<AnfNodePtr> behind_nodes;
+  if (!ExtractNodes(graph, prior_node, depend_node, &prior_nodes, &behind_nodes)) {
+    return;
+  }
+  for (auto &first_node : prior_nodes) {
+    for (auto &second_node : behind_nodes) {
+      MS_EXCEPTION_IF_NULL(first_node);
+      MS_EXCEPTION_IF_NULL(second_node);
+      auto iter = control_edges->find(second_node);
+      if (iter == control_edges->end()) {
+        (void)control_edges->insert(
+          std::pair<AnfNodePtr, std::vector<AnfNodePtr>>(second_node, std::vector<AnfNodePtr>{first_node}));
+      } else {
+        iter->second.emplace_back(first_node);
+      }
+      auto ref_iter = nodes_ref->find(first_node);
+      if (ref_iter != nodes_ref->end()) {
+        ref_iter->second++;
+      } else {
+        (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(first_node, 1));
+      }
+    }
+  }
+}
+
+void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref,
+                      std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges) {
   std::queue<AnfNodePtr> queue;
   queue.push(graph->get_return());
   std::set<AnfNodePtr> visited;
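The two new helpers above convert each ControlDepend into explicit graph metadata instead of a data edge: ExtractNodes resolves a Parameter operand to that Parameter's non-ControlDepend users, and AddControlEdge records a "behind comes after prior" pair in `control_edges` while bumping the prior node's reference count; a pair that touches a Parameter is skipped entirely unless the primitive carries `depend_mode == 1`. A minimal standalone model of the bookkeeping, using plain dicts and strings in place of the AnfNodePtr maps:

```python
# Standalone model of AddControlEdge's bookkeeping (plain dicts, not MindSpore types).
def add_control_edge(prior, behind, control_edges, nodes_ref):
    """Record that `behind` must be emitted after `prior`, and give `prior`
    one extra reference so the split sort will not release it too early."""
    control_edges.setdefault(behind, []).append(prior)
    nodes_ref[prior] = nodes_ref.get(prior, 0) + 1

control_edges = {}   # behind node -> list of prior nodes
nodes_ref = {}       # node -> reference count consumed by SplitSort
add_control_edge("relu_cpu", "mul", control_edges, nodes_ref)
assert control_edges == {"mul": ["relu_cpu"]}
assert nodes_ref["relu_cpu"] == 1
```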
@@ -83,6 +167,9 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *n
     auto cnode = node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
     for (auto &input : cnode->inputs()) {
+      if (IsPrimitiveCNode(input, prim::kPrimControlDepend)) {
+        AddControlEdge(graph, input, control_edges, nodes_ref);
+      }
       auto iter = nodes_ref->find(input);
       if (iter != nodes_ref->end()) {
         iter->second++;
@@ -142,7 +229,8 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
   std::stack<AnfNodePtr> to_visit;
   std::stack<AnfNodePtr> next_to_visit;
   std::map<AnfNodePtr, size_t> nodes_ref;
-  CalcNodeRefCount(graph, &nodes_ref);
+  std::map<AnfNodePtr, std::vector<AnfNodePtr>> control_edges;
+  CalcNodeRefCount(graph, &nodes_ref, &control_edges);
   std::string handle_target = default_target;
   std::string next_target = "";
   to_visit.push(graph->get_return());
@@ -162,6 +250,10 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
     MS_EXCEPTION_IF_NULL(cnode);
     auto node_inputs = cnode->inputs();
     std::reverse(node_inputs.begin(), node_inputs.end());
+    auto ctrl_inputs = control_edges.find(node);
+    if (ctrl_inputs != control_edges.end()) {
+      node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end());
+    }
    for (auto &input : node_inputs) {
       auto iter = nodes_ref.find(input);
       if (iter != nodes_ref.end()) {
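SplitSort then consumes those edges: after reversing the data inputs of the node being visited, the prior nodes recorded for it are appended as extra "inputs", so a prior node is only released once both its data users and its control users have been emitted. A standalone sketch of that traversal, ignoring the per-target to_visit/next_to_visit stacks of the real implementation (`graph` is a plain dict here, not a FuncGraph):

```python
# Standalone model of the reference-counted traversal in SplitSort,
# with control-edge priors appended to the ordinary inputs.
def split_sort(graph, root, control_edges, nodes_ref):
    result, to_visit = [], [root]
    while to_visit:
        node = to_visit.pop()
        result.append(node)
        inputs = list(graph.get(node, [])) + control_edges.get(node, [])
        for inp in inputs:
            nodes_ref[inp] -= 1
            if nodes_ref[inp] == 0:   # all users (data and control) emitted
                to_visit.append(inp)
    result.reverse()                  # execution order: inputs first
    return result

# mul uses relu1 and relu2; a ControlDepend says relu_cpu must run before mul.
graph = {"return": ["mul"], "mul": ["relu1", "relu2"]}
nodes_ref = {"mul": 1, "relu1": 1, "relu2": 1,
             "relu_cpu": 1}           # relu_cpu's only reference is the control edge
control_edges = {"mul": ["relu_cpu"]}
order = split_sort(graph, "return", control_edges, nodes_ref)
assert order.index("relu_cpu") < order.index("mul")
```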
@@ -26,6 +26,7 @@
 #include "ir/func_graph.h"
 #include "ir/primitive.h"
 #include "utils/context/ms_context.h"
+#include "base/core_ops.h"

 namespace mindspore {
 // namespace to support intermediate representation definition
@@ -217,6 +218,15 @@ std::string GetCNodeTarget(const AnfNodePtr &node) {
   auto primitive = value->cast<PrimitivePtr>();
   auto att_target = primitive->GetAttr("primitive_target");
   if (att_target != nullptr) {
+    if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) ||
+        IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) ||
+        IsPrimitive(attr_input, prim::kPrimMakeTuple) || IsPrimitive(attr_input, prim::kPrimStateSetItem) ||
+        IsPrimitive(attr_input, prim::kPrimDepend) || IsPrimitive(attr_input, prim::kPrimTupleGetItem) ||
+        IsPrimitive(attr_input, prim::kPrimControlDepend) || IsPrimitive(attr_input, prim::kPrimReturn) ||
+        IsPrimitive(attr_input, prim::kPrimPartial)) {
+      primitive->EraseAttr("primitive_target");
+      return default_target;
+    }
     if (!att_target->isa<StringImm>()) {
       MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
     }
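The GetCNodeTarget change keeps "virtual" glue nodes — the summary ops, MakeTuple, TupleGetItem, StateSetItem, Depend, ControlDepend, Return and Partial — from inheriting a primitive_target attribute: the attribute is erased and the default target is returned, so such nodes never force a graph split by themselves. A small standalone model of that rule (the op-name set mirrors the primitives listed in the hunk; `cnode_target` is an invented helper, not the MindSpore function):

```python
# Standalone sketch: virtual ops always resolve to the default target.
VIRTUAL_OPS = {
    "ImageSummary", "ScalarSummary", "TensorSummary", "HistogramSummary",
    "MakeTuple", "StateSetItem", "Depend", "TupleGetItem",
    "ControlDepend", "Return", "Partial",
}

def cnode_target(op_name, attrs, default_target):
    """attrs is the primitive's attribute dict; mutated like EraseAttr."""
    target = attrs.get("primitive_target")
    if target is None or op_name in VIRTUAL_OPS:
        attrs.pop("primitive_target", None)   # drop an inherited target
        return default_target
    return target

assert cnode_target("ControlDepend", {"primitive_target": "CPU"}, "Ascend") == "Ascend"
assert cnode_target("ReLU", {"primitive_target": "CPU"}, "Ascend") == "CPU"
```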
@@ -0,0 +1,71 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class Net1(nn.Cell):
+    def __init__(self):
+        super(Net1, self).__init__()
+        self.relu1 = P.ReLU()
+        self.relu2 = P.ReLU()
+        self.mul = P.Mul()
+        self.control = P.ControlDepend()
+
+    def construct(self, x, y):
+        a = self.relu1(x)
+        b = self.relu2(y)
+        c = self.mul(a, b)
+        e = self.control(a, b)
+        return c, e
+
+
+class Net2(nn.Cell):
+    def __init__(self):
+        super(Net2, self).__init__()
+        self.relu1 = P.ReLU()
+        self.relu2 = P.ReLU().add_prim_attr("primitive_target", "CPU")
+        self.mul = P.Mul()
+        self.control = P.ControlDepend()
+
+    def construct(self, x, y):
+        a = self.relu1(x)
+        b = self.relu2(y)
+        c = self.mul(a, b)
+        e = self.control(a, b)
+        return c, e
+
+
+def test_net():
+    x = np.random.randn(2, 3, 3, 4).astype(np.float32)
+    y = np.random.randn(2, 3, 3, 4).astype(np.float32)
+    net1 = Net1()
+    output1 = net1(Tensor(x), Tensor(y))
+    context.set_context(save_graphs=True)
+    net2 = Net2()
+    output2 = net2(Tensor(x), Tensor(y))
+    assert np.allclose(output1[0].asnumpy(), output2[0].asnumpy())
+    print("##success##")
+
+
+if __name__ == "__main__":
+    test_net()
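The test exercises the default depend_mode: a ControlDepend between two CNodes that end up on different targets (Ascend and CPU). The "depend_mode" attribute read by AddControlEdge comes from the primitive itself; assuming the ControlDepend primitive of this MindSpore generation accepts a depend_mode argument (an assumption, not shown in this diff), a variant that also keeps edges touching graph inputs (Parameters) would look roughly like:

```python
# Hedged sketch: ControlDepend(depend_mode=1) is assumed to set the
# "depend_mode" attribute that AddControlEdge checks, so a control edge whose
# prior operand is a Parameter is kept instead of being skipped.
import mindspore.nn as nn
from mindspore.ops import operations as P


class Net3(nn.Cell):
    def __init__(self):
        super(Net3, self).__init__()
        self.relu = P.ReLU().add_prim_attr("primitive_target", "CPU")
        self.mul = P.Mul()
        self.control = P.ControlDepend(depend_mode=1)

    def construct(self, x, y):
        a = self.relu(x)
        c = self.mul(a, y)
        e = self.control(x, a)   # x is a graph input (Parameter)
        return c, e
```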