Merge pull request !2967 from huanghui/heterogeneous-backend-control-depend-optimize tags/v0.6.0-beta
| @@ -451,10 +451,14 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K | |||||
| } | } | ||||
| auto origin_inputs = cnode->inputs(); | auto origin_inputs = cnode->inputs(); | ||||
| bool optimize_depend = false; | bool optimize_depend = false; | ||||
| bool optimize_control_depend = false; | |||||
| if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 && | if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 && | ||||
| origin_inputs[kRealInputIndexInDepend]->isa<ValueNode>()) { | origin_inputs[kRealInputIndexInDepend]->isa<ValueNode>()) { | ||||
| optimize_depend = true; | optimize_depend = true; | ||||
| } | } | ||||
| if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend) && origin_inputs.size() == 3) { | |||||
| optimize_control_depend = true; | |||||
| } | |||||
| // if has multiple depends,only select first depend as parameter | // if has multiple depends,only select first depend as parameter | ||||
| for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) { | for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) { | ||||
| auto anf = origin_inputs[input_idx]; | auto anf = origin_inputs[input_idx]; | ||||
| @@ -485,6 +489,8 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K | |||||
| } else if (optimize_depend && input_idx == kDependAttachNodeIndex) { | } else if (optimize_depend && input_idx == kDependAttachNodeIndex) { | ||||
| cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]); | cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]); | ||||
| continue; | continue; | ||||
| } else if (optimize_control_depend) { | |||||
| cnode_inputs.push_back(NewValueNode(MakeValue(input_idx))); | |||||
| } else { | } else { | ||||
| *from_other_graph = true; | *from_other_graph = true; | ||||
| // the input node is a cnode from other graph | // the input node is a cnode from other graph | ||||
| @@ -117,6 +117,14 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr | |||||
| eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) { | eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) { | ||||
| args.emplace_back(inps[kRealInputIndexInDepend]); | args.emplace_back(inps[kRealInputIndexInDepend]); | ||||
| args.emplace_back(inps[kRealInputIndexInDepend]); | args.emplace_back(inps[kRealInputIndexInDepend]); | ||||
| } else if (IsPrimitive(fn, prim::kPrimControlDepend) && inps.size() == 3) { | |||||
| for (size_t i = 1; i < inps.size(); ++i) { | |||||
| if (inps[i]->isa<CNode>() && std::find(lst.begin(), lst.end(), inps[i]) == lst.end()) { | |||||
| args.emplace_back(NewValueNode(MakeValue(i))); | |||||
| } else { | |||||
| args.emplace_back(ref(inps[i])); | |||||
| } | |||||
| } | |||||
| } else { | } else { | ||||
| (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref); | (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref); | ||||
| } | } | ||||
| @@ -69,7 +69,91 @@ bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref) { | |||||
| bool ExtractNodes(const FuncGraphPtr &graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node, | |||||
| std::vector<AnfNodePtr> *prior_nodes, std::vector<AnfNodePtr> *depend_nodes) { | |||||
| MS_EXCEPTION_IF_NULL(prior_node); | |||||
| MS_EXCEPTION_IF_NULL(behind_node); | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| auto manager = graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| auto &node_users = manager->node_users(); | |||||
| if (prior_node->isa<Parameter>()) { | |||||
| for (auto &user : node_users[prior_node]) { | |||||
| auto cnode = user.first->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { | |||||
| prior_nodes->emplace_back(cnode); | |||||
| } | |||||
| } | |||||
| } else if (!IsPrimitiveCNode(prior_node, prim::kPrimControlDepend)) { | |||||
| prior_nodes->emplace_back(prior_node); | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| if (behind_node->isa<Parameter>()) { | |||||
| for (auto &user : node_users[behind_node]) { | |||||
| auto cnode = user.first->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { | |||||
| depend_nodes->emplace_back(cnode); | |||||
| } | |||||
| } | |||||
| } else if (!IsPrimitiveCNode(behind_node, prim::kPrimControlDepend)) { | |||||
| depend_nodes->emplace_back(behind_node); | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void AddControlEdge(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||||
| std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges, | |||||
| std::map<AnfNodePtr, size_t> *nodes_ref) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto input_cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(input_cnode); | |||||
| auto prior_node = input_cnode->input(kControlDependPriorIndex); | |||||
| auto depend_node = input_cnode->input(kControlDependBehindIndex); | |||||
| MS_EXCEPTION_IF_NULL(prior_node); | |||||
| MS_EXCEPTION_IF_NULL(depend_node); | |||||
| PrimitivePtr prim_ptr = GetValueNode<PrimitivePtr>(input_cnode->input(0)); | |||||
| MS_EXCEPTION_IF_NULL(prim_ptr); | |||||
| ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode"); | |||||
| int depend_mode = 0; | |||||
| if (mode_ptr != nullptr) { | |||||
| depend_mode = GetValue<int>(mode_ptr); | |||||
| } | |||||
| if ((prior_node->isa<Parameter>() || depend_node->isa<Parameter>()) && depend_mode == 0) { | |||||
| return; | |||||
| } | |||||
| std::vector<AnfNodePtr> prior_nodes; | |||||
| std::vector<AnfNodePtr> behind_nodes; | |||||
| if (!ExtractNodes(graph, prior_node, depend_node, &prior_nodes, &behind_nodes)) { | |||||
| return; | |||||
| } | |||||
| for (auto &first_node : prior_nodes) { | |||||
| for (auto &second_node : behind_nodes) { | |||||
| MS_EXCEPTION_IF_NULL(first_node); | |||||
| MS_EXCEPTION_IF_NULL(second_node); | |||||
| auto iter = control_edges->find(second_node); | |||||
| if (iter == control_edges->end()) { | |||||
| (void)control_edges->insert( | |||||
| std::pair<AnfNodePtr, std::vector<AnfNodePtr>>(second_node, std::vector<AnfNodePtr>{first_node})); | |||||
| } else { | |||||
| iter->second.emplace_back(first_node); | |||||
| } | |||||
| auto ref_iter = nodes_ref->find(first_node); | |||||
| if (ref_iter != nodes_ref->end()) { | |||||
| ref_iter->second++; | |||||
| } else { | |||||
| (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(first_node, 1)); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref, | |||||
| std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges) { | |||||
| std::queue<AnfNodePtr> queue; | std::queue<AnfNodePtr> queue; | ||||
| queue.push(graph->get_return()); | queue.push(graph->get_return()); | ||||
| std::set<AnfNodePtr> visited; | std::set<AnfNodePtr> visited; | ||||
| @@ -83,6 +167,9 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *n | |||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| for (auto &input : cnode->inputs()) { | for (auto &input : cnode->inputs()) { | ||||
| if (IsPrimitiveCNode(input, prim::kPrimControlDepend)) { | |||||
| AddControlEdge(graph, input, control_edges, nodes_ref); | |||||
| } | |||||
| auto iter = nodes_ref->find(input); | auto iter = nodes_ref->find(input); | ||||
| if (iter != nodes_ref->end()) { | if (iter != nodes_ref->end()) { | ||||
| iter->second++; | iter->second++; | ||||
| @@ -142,7 +229,8 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string & | |||||
| std::stack<AnfNodePtr> to_visit; | std::stack<AnfNodePtr> to_visit; | ||||
| std::stack<AnfNodePtr> next_to_visit; | std::stack<AnfNodePtr> next_to_visit; | ||||
| std::map<AnfNodePtr, size_t> nodes_ref; | std::map<AnfNodePtr, size_t> nodes_ref; | ||||
| CalcNodeRefCount(graph, &nodes_ref); | |||||
| std::map<AnfNodePtr, std::vector<AnfNodePtr>> control_edges; | |||||
| CalcNodeRefCount(graph, &nodes_ref, &control_edges); | |||||
| std::string handle_target = default_target; | std::string handle_target = default_target; | ||||
| std::string next_target = ""; | std::string next_target = ""; | ||||
| to_visit.push(graph->get_return()); | to_visit.push(graph->get_return()); | ||||
| @@ -162,6 +250,10 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string & | |||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| auto node_inputs = cnode->inputs(); | auto node_inputs = cnode->inputs(); | ||||
| std::reverse(node_inputs.begin(), node_inputs.end()); | std::reverse(node_inputs.begin(), node_inputs.end()); | ||||
| auto ctrl_inputs = control_edges.find(node); | |||||
| if (ctrl_inputs != control_edges.end()) { | |||||
| node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end()); | |||||
| } | |||||
| for (auto &input : node_inputs) { | for (auto &input : node_inputs) { | ||||
| auto iter = nodes_ref.find(input); | auto iter = nodes_ref.find(input); | ||||
| if (iter != nodes_ref.end()) { | if (iter != nodes_ref.end()) { | ||||
| @@ -26,6 +26,7 @@ | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| #include "utils/context/ms_context.h" | #include "utils/context/ms_context.h" | ||||
| #include "base/core_ops.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| // namespace to support intermediate representation definition | // namespace to support intermediate representation definition | ||||
| @@ -217,6 +218,15 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { | |||||
| auto primitive = value->cast<PrimitivePtr>(); | auto primitive = value->cast<PrimitivePtr>(); | ||||
| auto att_target = primitive->GetAttr("primitive_target"); | auto att_target = primitive->GetAttr("primitive_target"); | ||||
| if (att_target != nullptr) { | if (att_target != nullptr) { | ||||
| if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) || | |||||
| IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) || | |||||
| IsPrimitive(attr_input, prim::kPrimMakeTuple) || IsPrimitive(attr_input, prim::kPrimStateSetItem) || | |||||
| IsPrimitive(attr_input, prim::kPrimDepend) || IsPrimitive(attr_input, prim::kPrimTupleGetItem) || | |||||
| IsPrimitive(attr_input, prim::kPrimControlDepend) || IsPrimitive(attr_input, prim::kPrimReturn) || | |||||
| IsPrimitive(attr_input, prim::kPrimPartial)) { | |||||
| primitive->EraseAttr("primitive_target"); | |||||
| return default_target; | |||||
| } | |||||
| if (!att_target->isa<StringImm>()) { | if (!att_target->isa<StringImm>()) { | ||||
| MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; | MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; | ||||
| } | } | ||||
| @@ -0,0 +1,71 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
class Net1(nn.Cell):
    """Two ReLU branches multiplied together, plus a ControlDepend over the two branch outputs.

    None of the primitives created here is given a ``primitive_target`` attribute.
    """

    def __init__(self):
        super().__init__()
        self.relu1 = P.ReLU()
        self.relu2 = P.ReLU()
        self.mul = P.Mul()
        self.control = P.ControlDepend()

    def construct(self, x, y):
        # Statement order is deliberately unchanged: relu1, relu2, mul, then control.
        left = self.relu1(x)
        right = self.relu2(y)
        product = self.mul(left, right)
        ordering = self.control(left, right)
        return product, ordering
class Net2(nn.Cell):
    """Two ReLU branches multiplied together, plus a ControlDepend over the two branch outputs.

    The second ReLU carries the primitive attribute ``primitive_target="CPU"``; the other
    primitives are created without that attribute.
    """

    def __init__(self):
        super().__init__()
        self.relu1 = P.ReLU()
        self.relu2 = P.ReLU().add_prim_attr("primitive_target", "CPU")
        self.mul = P.Mul()
        self.control = P.ControlDepend()

    def construct(self, x, y):
        # Statement order is deliberately unchanged: relu1, relu2, mul, then control.
        left = self.relu1(x)
        right = self.relu2(y)
        product = self.mul(left, right)
        ordering = self.control(left, right)
        return product, ordering
def test_net():
    """Run Net1 and Net2 on the same random inputs and check that their first outputs match.

    Graph saving is switched on in the context after Net1 has run and before Net2 is built.
    Only element 0 of each output tuple (the Mul result) is compared.
    """
    shape = (2, 3, 3, 4)
    x = np.random.randn(*shape).astype(np.float32)
    y = np.random.randn(*shape).astype(np.float32)

    reference_net = Net1()
    reference_out = reference_net(Tensor(x), Tensor(y))

    context.set_context(save_graphs=True)
    target_net = Net2()
    target_out = target_net(Tensor(x), Tensor(y))

    assert np.allclose(reference_out[0].asnumpy(), target_out[0].asnumpy())
    print("##success##")
| if __name__ == "__main__": | |||||
| test_net() | |||||