diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.cc b/mindspore/ccsrc/backend/session/ascend_control_parser.cc
deleted file mode 100644
index 0fc701f8b4..0000000000
--- a/mindspore/ccsrc/backend/session/ascend_control_parser.cc
+++ /dev/null
@@ -1,908 +0,0 @@
-/**
- * Copyright 2019-2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "backend/session/ascend_control_parser.h"
-#include
-#include
-#include
-#include "backend/session/anf_runtime_algorithm.h"
-#include "utils/union_find_set.h"
-#include "runtime/device/ascend/ascend_label_assign.h"
-#include "utils/ms_context.h"
-#include "debug/anf_ir_dump.h"
-
-static constexpr size_t kCNodePrim = 0;
-static constexpr size_t kCNodeCallArg = 1;
-static constexpr size_t kCNodeSwitchCond = 1;
-static constexpr size_t kCNodeSwitchTrue = 2;
-static constexpr size_t kCNodeSwitchFalse = 3;
-static constexpr size_t kCNodeSwitchLength = 4;
-static constexpr size_t kCNodePartialLength = 2;
-static constexpr size_t kCNodePartialFunc = 1;
-static constexpr size_t kCNodeSwitchLayerBranch = 2;
-static constexpr size_t kCNodeSwitchLayerLength = 3;
-static constexpr size_t kCNodeAssignTarget = 1;
-static constexpr size_t kCNodeAssignSource = 2;
-static constexpr size_t kCNodeAssignDestination = 1;
-
-namespace mindspore {
-namespace session {
-static void RecursiveReplaceNode(NotNull kg, NotNull main_parameter,
-                                 const std::set &parameter_reuse_set,
-                                 const NotNull *> memo) {
-  if (parameter_reuse_set.empty()) {
-    MS_LOG(EXCEPTION) << "Parameter_reuse_set is empty.";
-  }
-  if (memo->find(kg.get()) != memo->end()) {
-    return;
-  }
-  memo->insert(kg.get());
-
-  for (auto &para : parameter_reuse_set) {
-    if (para == main_parameter.get()) {
-      continue;
-    }
-    MS_EXCEPTION_IF_NULL(para);
-    MS_LOG(INFO) << "In " << kg->ToString() << " replace " << para->DebugString() << " of graph "
-                 << AnfAlgo::GetGraphId(para.get()) << " to " << main_parameter->DebugString() << " of graph "
-                 << AnfAlgo::GetGraphId(main_parameter.get().get());
-    kg->ReplaceNode(NOT_NULL(para), main_parameter);
-  }
-
-  for (auto &child : kg->child_graph_order()) {
-    RecursiveReplaceNode(NOT_NULL(child.lock()), main_parameter, parameter_reuse_set, memo);
-  }
-}
-
-static AnfNodePtr GetMainParameter(NotNull root_kg, const AnfNodePtr &key,
-                                   const std::set &parameter_reuse_set) {
-  AnfNodePtr main_parameter = key;
-  std::set root_inputs_set;
-  const auto &root_inputs_vector = root_kg->inputs();
-  root_inputs_set.insert(root_inputs_vector.begin(), root_inputs_vector.end());
-  for (auto &node : parameter_reuse_set) {
-    if (root_inputs_set.find(node) != root_inputs_set.end()) {
-      main_parameter = node;
-      break;
-    }
-  }
-  return main_parameter;
-}
-
-static void ReuseParameter(NotNull root_kg,
-                           const std::vector> &link_list) {
-  // make union find set
-  UnionFindSet union_find_set;
-  for (auto &[param, arg] : link_list) {
-    union_find_set.Add(param);
-    union_find_set.Add(arg);
-  }
-  for (auto &[param, arg] : link_list) {
union_find_set.Union(param, arg); - } - auto parameter_reuse_sets = union_find_set.GetSets(); - - for (auto &[key, parameter_reuse_set] : parameter_reuse_sets) { - if (parameter_reuse_set.size() <= 1) { - continue; - } - auto main_parameter = GetMainParameter(root_kg, key, parameter_reuse_set); - std::set memo; - RecursiveReplaceNode(root_kg, NOT_NULL(main_parameter), parameter_reuse_set, NOT_NULL(&memo)); - } -} - -static CNodePtr GetNextRealKernel(const std::vector &list, size_t start) { - for (size_t i = start; i < list.size() - 1; ++i) { - if (AnfAlgo::IsRealKernel(list[i])) { - return list[i]; - } - } - return nullptr; -} - -static void UpdateLabelIdToLabelSetMap(const std::vector &exec_order, - const NotNull *> label_id_to_label_set) { - for (auto &node : exec_order) { - MS_EXCEPTION_IF_NULL(node); - if (!IsPrimitiveCNode(node, prim::kPrimLabelSet)) { - continue; - } - if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) { - MS_LOG(EXCEPTION) << node->DebugString() << " has no attr kAttrLabelIndex"; - } - uint32_t label_id = AnfAlgo::GetNodeAttr(node, kAttrLabelIndex); - if (auto iter = label_id_to_label_set->find(label_id); iter != label_id_to_label_set->end()) { - MS_LOG(EXCEPTION) << "There are more than one node has same label id " << label_id - << ", node: " << iter->second->DebugString() << " and " << node->DebugString(); - } - (*label_id_to_label_set)[label_id] = node; - } -} - -static std::vector GetTargetLabelSetNodes(NotNull jump_node, - const std::map &label_id_to_label_set) { - std::vector target_label_list; - std::vector target_labelset_nodes; - if (IsPrimitiveCNode(jump_node.get(), prim::kPrimLabelGoto)) { - if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, jump_node)) { - MS_LOG(EXCEPTION) << jump_node->DebugString() << " has no attr kAttrLabelIndex"; - } - uint32_t label_id = AnfAlgo::GetNodeAttr(jump_node.get(), kAttrLabelIndex); - target_label_list.push_back(label_id); - } else if (IsPrimitiveCNode(jump_node.get(), prim::kPrimLabelSwitch)) { - if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, jump_node)) { - MS_LOG(EXCEPTION) << jump_node->DebugString() << " has no attr kPrimLabelSwitch"; - } - target_label_list = AnfAlgo::GetNodeAttr>(jump_node.get(), kAttrLabelSwitchList); - } else { - MS_LOG(EXCEPTION) << "Unknown type jump node " << jump_node->DebugString(); - } - - for (auto label_id : target_label_list) { - auto iter = label_id_to_label_set.find(label_id); - if (iter == label_id_to_label_set.end()) { - MS_LOG(EXCEPTION) << "Cannot find LabelSet node has label id " << label_id; - } - target_labelset_nodes.push_back(iter->second); - } - return target_labelset_nodes; -} - -static void EraseNodeFromExecOrder(const AnfNodePtr &node, const NotNull *> exec_order) { - MS_EXCEPTION_IF_NULL(node); - auto exec_iter = std::find(exec_order->begin(), exec_order->end(), node); - if (exec_iter == exec_order->end()) { - MS_LOG(EXCEPTION) << "Cannot find " << node->DebugString() << " in exec order."; - } - exec_order->erase(exec_iter); -} - -void AscendControlParser::AttachChildGraphToReturnNode(NotNull graph, - const NotNull *> memo) { - if (memo->find(graph) != memo->end()) { - return; - } - memo->insert(graph.get()); - const std::vector> &child_graph_order = graph->child_graph_order(); - if (child_graph_order.empty()) { - return; - } - - std::vector depend_inputs = {NewValueNode(std::make_shared(prim::kPrimPartial->name()))}; - for (auto &kg : child_graph_order) { - std::shared_ptr cg = kg.lock(); - MS_EXCEPTION_IF_NULL(cg); - auto fg = cg->cast(); - MS_EXCEPTION_IF_NULL(fg); - 
depend_inputs.emplace_back(NewValueNode(fg)); - AttachChildGraphToReturnNode(NOT_NULL(cg), memo); - } - auto child_graphs = graph->NewCNode(depend_inputs); - InsertDependToGraph(graph, NOT_NULL(child_graphs)); -} - -void AscendControlParser::LinkGraph(NotNull kg) { - std::set memo; - std::vector> link_list; - // Insert Assign - ChildGraphDataAssign(kg, NOT_NULL(&link_list), NOT_NULL(&memo)); - memo.clear(); - // Reuse Parameter - ReuseParameter(kg, link_list); - // replace call by label goto / label switch - (void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo)); - memo.clear(); - // assign label resource - device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kg); -} - -void AscendControlParser::EraseParameter(NotNull root_graph, - const std::set &graph_list) { - std::vector exec_order = root_graph->execution_order(); - std::set search_list(exec_order.begin(), exec_order.end()); - std::set root_inputs(root_graph->inputs().begin(), root_graph->inputs().end()); - auto ref_map = root_graph->GetRefMap(); - ReferenceCounter parameter_count([](int64_t read, int64_t write) -> bool { return write == 1; }); - std::multimap> ref_multimap; - std::transform(ref_map.begin(), ref_map.end(), std::inserter(ref_multimap, ref_multimap.end()), - [](const std::pair, std::pair> &p) - -> std::pair> { - return {p.first.first, {p.first.second, p.second.first, p.second.second}}; - }); - std::set all_nodes; - std::map para_to_written_node; - for (auto &graph : graph_list) { - auto out = graph->get_return(); - MS_EXCEPTION_IF_NULL(out); - search_list.insert(out->cast()); - auto nodes = TopoSort(out); - for (auto &node : nodes) { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - if (cnode != nullptr) { - all_nodes.insert(cnode); - } - } - } - // parameter->transdata->assign<-5d node, ref parameter would get from transdata input - auto validate_ref_parameter = [](AnfNodePtr node) -> AnfNodePtr { - if (node->isa() && AnfAlgo::CheckPrimitiveType(node, prim::KPrimTransData)) { - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto first_input = cnode->input(kFirstDataInputIndex); - MS_EXCEPTION_IF_NULL(first_input); - return first_input; - } - return node; - }; - // prepare referance count - for (auto &node : search_list) { - MS_EXCEPTION_IF_NULL(node); - // if assign node - std::set refed_parameters; - for (auto [iter, end] = ref_multimap.equal_range(node); iter != end; ++iter) { - refed_parameters.insert(validate_ref_parameter(std::get<1>(iter->second))); - } - - for (auto &in : node->inputs()) { - auto visit_node = AnfAlgo::VisitKernelWithReturnType(in, 0).first; - visit_node = validate_ref_parameter(visit_node); - if (!visit_node->isa() || root_inputs.find(visit_node) != root_inputs.end()) { - continue; - } - if (refed_parameters.find(visit_node) != refed_parameters.end()) { - parameter_count.AddWriteCount(visit_node, 1); - para_to_written_node[visit_node] = node; - } else { - parameter_count.AddReadCount(visit_node, 1); - } - } - } - - EraseAssign(std::make_shared(parameter_count), all_nodes, para_to_written_node, root_graph, - graph_list); -} - -void AscendControlParser::EraseAssign(std::shared_ptr parameter_count, - const std::set &all_nodes, - const std::map ¶_to_written_node, - NotNull root_graph, const std::set &graph_list) { - std::vector exec_order = root_graph->execution_order(); - while (parameter_count->HasValidElem()) { - auto [para, read, written] = parameter_count->GetOneValidElem(); - MS_LOG(INFO) << para->DebugString() << " was read " << read << " times, written 
" << written << " times."; - auto assign_iter = para_to_written_node.find(para); - if (assign_iter == para_to_written_node.end()) { - MS_LOG(EXCEPTION) << "Cannot find assign node that write " << para->DebugString(); - } - auto &assign_node = assign_iter->second; - MS_EXCEPTION_IF_NULL(assign_node); - auto source = assign_node->input(kCNodeAssignSource); - auto destination = assign_node->input(kCNodeAssignDestination); - // not assign node or assign destination is transdata which for ref parameter(write 2 times) -> continue - if (!IsPrimitiveCNode(assign_node, prim::kPrimAssign) || IsPrimitiveCNode(destination, prim::KPrimTransData)) { - parameter_count->EraseElem(para); - continue; - } - MS_LOG(INFO) << "Erase " << assign_node->DebugString(5); - EraseNodeFromExecOrder(assign_node, NOT_NULL(&exec_order)); - MS_EXCEPTION_IF_NULL(source); - auto visit_source = AnfAlgo::VisitKernelWithReturnType(source, 0).first; - parameter_count->AddWriteCount(para, -1); - parameter_count->AddReadCount(para, -1); - if (visit_source->isa()) { - parameter_count->AddReadCount(visit_source, read - 1); - } - - // replace parameter in node - for (auto &node : all_nodes) { - for (size_t i = 0; i < node->size(); ++i) { - if (node->input(i) == para) { - MS_LOG_INFO << "Replace " << node->DebugString() << " input " << i << " by " << source->DebugString(); - node->set_input(i, source); - } - } - } - - // replace parameter in graph input - for (auto &g : graph_list) { - auto child_graph_inputs = g->MutableInputs(); - std::replace(child_graph_inputs->begin(), child_graph_inputs->end(), para, source); - MS_LOG_INFO << "Replace parameter " << para->DebugString() << " by " << source->DebugString() << " in graph " - << g->graph_id() << " inputs"; - } - } - root_graph->set_execution_order(exec_order); -} - -void AscendControlParser::EraseLabel(NotNull root_graph) { - std::vector exec_order = root_graph->execution_order(); - ReferenceCounter label_count([](int32_t read, int32_t write) -> bool { return read <= 1; }); - std::map label_to_written_node; - std::map label_id_to_label_set; - UpdateLabelIdToLabelSetMap(exec_order, NOT_NULL(&label_id_to_label_set)); - CNodePtr last_node = nullptr; - for (auto &cur_node : exec_order) { - MS_EXCEPTION_IF_NULL(cur_node); - if (AnfAlgo::IsCondControlKernel(cur_node)) { - std::vector target_labelset_nodes = GetTargetLabelSetNodes(NOT_NULL(cur_node), label_id_to_label_set); - for (auto &label_set : target_labelset_nodes) { - label_count.AddReadCount(label_set, 1); - label_to_written_node[label_set] = cur_node; - } - } else if (IsPrimitiveCNode(cur_node, prim::kPrimLabelSet)) { - label_count.AddWriteCount(cur_node, 1); - if (last_node != nullptr && !AnfAlgo::IsCondControlKernel(last_node)) { - label_count.AddReadCount(cur_node, 1); - label_to_written_node[cur_node] = last_node; - } - } - last_node = cur_node; - } - - while (label_count.HasValidElem()) { - auto [label_set, read, written] = label_count.GetOneValidElem(); - MS_LOG(INFO) << label_set->DebugString() << " was read " << read << " times, written " << written << " times."; - auto iter = label_to_written_node.find(label_set); - if (read > 0 && iter == label_to_written_node.end()) { - MS_LOG(EXCEPTION) << "Cannot find node jump to " << label_set->DebugString(); - } - CNodePtr jump_node = read > 0 ? 
iter->second : nullptr; - if (jump_node == nullptr || IsPrimitiveCNode(jump_node, prim::kPrimLabelGoto)) { - MS_LOG(INFO) << "Erase node " << label_set->DebugString(); - EraseNodeFromExecOrder(label_set, NOT_NULL(&exec_order)); - } - if (jump_node != nullptr && IsPrimitiveCNode(jump_node, prim::kPrimLabelGoto)) { - MS_LOG(INFO) << "Erase node " << jump_node->DebugString(); - EraseNodeFromExecOrder(jump_node, NOT_NULL(&exec_order)); - } - label_count.EraseElem(label_set); - } - - root_graph->set_execution_order(exec_order); -} - -void AscendControlParser::ExecutorValidate(NotNull root_graph) { - std::set memo; - (void)RecurseGraph(root_graph, NOT_NULL(&memo)); - EraseParameter(root_graph, memo); - EraseLabel(root_graph); - - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { - std::string file_name = "after_erase_label_and_parameter.ir"; - DumpIR(file_name, root_graph.get()); - } -} - -std::vector>> AscendControlParser::ParseCallSwitchNode( - NotNull cnode) { - std::vector>> ret; - - if (IsPrimitiveCNode(cnode.get(), prim::kPrimCall)) { - if (cnode->size() <= kCNodeCallArg) { - MS_LOG(EXCEPTION) << "Call node " << cnode->DebugString() << " has invalid inputs size " << cnode->size(); - } - auto call_arg = cnode->input(kCNodeCallArg); - MS_EXCEPTION_IF_NULL(call_arg); - ret.emplace_back(GetValueNode(call_arg), - std::vector(cnode->inputs().begin() + kCNodeCallArg + 1, cnode->inputs().end())); - } else if (IsPrimitiveCNode(cnode.get(), prim::kPrimSwitch)) { - const std::vector &switch_inputs = cnode->inputs(); - if (switch_inputs.size() < kCNodeSwitchLength) { - MS_LOG(EXCEPTION) << "Switch node " << cnode->DebugString() << " has invalid inputs size " - << switch_inputs.size(); - } - for (auto iter = switch_inputs.begin() + kCNodeSwitchCond + 1; iter != switch_inputs.end(); ++iter) { - const auto &[target_graph, args] = ParsePartial(NOT_NULL(*iter)); - ret.emplace_back(target_graph, args); - } - } else if (IsPrimitiveCNode(cnode.get(), prim::kPrimSwitchLayer)) { - const std::vector &switch_layer_inputs = cnode->inputs(); - if (switch_layer_inputs.size() <= kCNodeSwitchLayerBranch) { - MS_LOG(EXCEPTION) << "Switch layer node " << cnode->DebugString() << " has invalid inputs size " - << switch_layer_inputs.size(); - } - for (auto iter = switch_layer_inputs.begin() + kCNodeSwitchLayerBranch; iter != switch_layer_inputs.end(); ++iter) { - const auto &[target_graph, args] = ParsePartial(NOT_NULL(*iter)); - ret.emplace_back(target_graph, args); - } - } else { - MS_LOG(EXCEPTION) << "Unsupported call node: " << cnode->DebugString(5); - } - return ret; -} - -void AscendControlParser::ChildGraphDataAssign( - NotNull kg, const NotNull> *> link_list, - const NotNull *> memo) { - if (memo->find(kg) != memo->end()) { - return; - } - memo->insert(kg.get()); - - MS_LOG(INFO) << "Start link data for " << kg->ToString(); - const std::vector &nodes = kg->execution_order(); - - for (auto &node : nodes) { - if (!(IsPrimitiveCNode(node, prim::kPrimCall) || IsPrimitiveCNode(node, prim::kPrimSwitch) || - IsPrimitiveCNode(node, prim::kPrimSwitchLayer))) { - continue; - } - - auto child_graph_list = ParseCallSwitchNode(NOT_NULL(node)); - for (auto &[child_graph, args] : child_graph_list) { - MS_EXCEPTION_IF_NULL(child_graph); - const std::vector ¶ms = child_graph->inputs(); - if (args.size() != params.size()) { - MS_LOG(EXCEPTION) << child_graph->ToString() << " needs " << params.size() << " inputs but call node " - << 
node->DebugString(5) << " gives " << args.size(); - } - for (size_t i = 0; i < args.size(); ++i) { - InsertMultipleAssignToGraph(kg, node, NOT_NULL(args[i]), NOT_NULL(params[i])); - } - } - } - kg->SetExecOrderByDefault(); - for (auto &child_graph : kg->child_graph_order()) { - ChildGraphDataAssign(NOT_NULL(child_graph.lock()), link_list, memo); - } -} - -NotNull AscendControlParser::GetStartLabel(NotNull kg, const CNodePtr &last_node, - const CNodePtr &last_label) { - CNodePtr start_label; - if (last_node != nullptr && last_label != nullptr) { - start_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); - MS_LOG(INFO) << "Insert start label " << start_label->DebugString() << " to " << kg->ToString(); - kg->set_start_label(start_label); - } else { - // no goto node will jump to start label of root graph, so return a fake label - start_label = std::make_shared(std::vector(), FuncGraphPtr(nullptr)); - } - return NOT_NULL(start_label); -} - -NotNull AscendControlParser::ProcessKernelGraph(NotNull kg, const CNodePtr &last_node, - const CNodePtr &last_label, - const NotNull *> memo) { - MS_LOG(INFO) << "Start process KernelGraph " << kg->ToString(); - - // 1. recursive condition - if (memo->find(kg) != memo->end()) { - MS_LOG(INFO) << "KernelGraph has beed processed: " << kg->ToString(); - return NOT_NULL(kg->get_start_label()); - } - memo->insert(kg.get()); - - // 2. args replace placeholder - LinkParentGraph(kg, last_node, last_label); - - // 3. topological sort - kg->SetExecOrderByDefault(); - const std::vector &nodes = kg->execution_order(); - // 4. insert first_label - CNodePtr start_label = GetStartLabel(kg, last_node, last_label); - - // 5. traverse - for (size_t i = 0; i < nodes.size(); ++i) { - auto &cnode = nodes[i]; - MS_EXCEPTION_IF_NULL(cnode); - if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || - AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) || - AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer))) { - continue; - } - - if (IsPrimitiveCNode(cnode, prim::kPrimCall)) { - RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); - } else if (IsPrimitiveCNode(cnode, prim::kPrimSwitch)) { - RecurseSwitch(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); - } else if (IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer)) { - RecurseSwitchLayer(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); - } else { - MS_LOG(EXCEPTION) << "Unexpected node: " << cnode->DebugString(); - } - } - kg->SetExecOrderByDefault(); - MS_LOG(INFO) << "End KernelGraph process: " << kg->ToString(); - return NOT_NULL(start_label); -} - -void AscendControlParser::InsertDependToGraph(NotNull kg, NotNull attch_node) { - auto return_node = kg->get_return(); - MS_EXCEPTION_IF_NULL(return_node); - std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimDepend->name())), - return_node->input(kFirstDataInputIndex), attch_node.get()}; - auto depend_node = kg->NewCNode(inputs); - return_node->set_input(kFirstDataInputIndex, depend_node); -} - -void AscendControlParser::InsertControlDependToGraph(NotNull kg, NotNull prior_node, - NotNull behind_node) { - MS_LOG(INFO) << "Insert control dependence at the end of graph, the prior node is " << prior_node->DebugString() - << ", the behind node is " << behind_node->DebugString(); - auto manager = kg->manager(); - MS_EXCEPTION_IF_NULL(manager); - AnfNodePtrList inputs = {NewValueNode(prim::kPrimDepend), behind_node, prior_node}; - auto depend_cnode = kg->NewCNode(inputs); - if 
(!manager->Replace(behind_node, depend_cnode)) { - MS_LOG(EXCEPTION) << behind_node->DebugString() << ", replace node failed."; - } -} - -void AscendControlParser::LinkParentGraph(NotNull kg, const CNodePtr &from_graph_call_node, - const CNodePtr &last_label) { - // if not entry graph, replace return with label_goto - if (from_graph_call_node != nullptr && last_label != nullptr) { - auto label_goto = - kg->NewCNode({std::make_shared(std::make_shared(kLabelGotoOpName)), last_label}); - MS_EXCEPTION_IF_NULL(label_goto); - MS_LOG(INFO) << "Insert end goto " << label_goto->DebugString() << " to " << kg->ToString(); - kg->set_end_goto(label_goto); - } -} - -void AscendControlParser::AttachOriginalInputsToGraph(NotNull graph, - const std::vector orig_inputs) { - std::vector make_tuple_inputs = { - mindspore::NewValueNode(std::make_shared(prim::kPrimMakeTuple->name()))}; - std::copy(orig_inputs.begin(), orig_inputs.end(), std::back_inserter(make_tuple_inputs)); - auto make_tuple = graph->NewCNode(make_tuple_inputs); - - InsertDependToGraph(graph, NOT_NULL(make_tuple)); -} - -void AscendControlParser::RecurseCall(NotNull kg, NotNull cur_node, const CNodePtr &next_node, - const NotNull *> memo) { - MS_LOG(INFO) << "Process call func " << cur_node->DebugString(); - - // 1 get kernel graph - std::vector origin_inputs = cur_node->inputs(); - if (kCNodeCallArg >= origin_inputs.size()) { - MS_LOG(EXCEPTION) << "Index out of range,size:" << origin_inputs.size(); - } - std::vector new_inputs = {std::make_shared(std::make_shared(kLabelGotoOpName))}; - if (!IsValueNode(origin_inputs[kCNodeCallArg])) { - MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode"; - return; - } - // 2 return label - auto back_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); - MS_LOG(INFO) << "Insert back label " << back_label->DebugString() << " to " << kg->ToString() << " call node " - << cur_node->DebugString(); - // 3 add depend relationship - InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); - if (next_node != nullptr && next_node != kg->get_return()) { - InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); - } - auto call_kg = GetValueNode(origin_inputs[kCNodeCallArg]); - // 4 modify call op to goto op - cur_node->set_input(kCNodePrim, new_inputs[kCNodePrim]); - // 5 recurse sub graph - CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, memo); - new_inputs.push_back(sub_label); - cur_node->set_inputs(new_inputs); - cur_node->set_abstract(nullptr); - AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue>({call_kg}), cur_node.get()); - kg->RemoveNodeFromGraph(origin_inputs[kCNodeCallArg]); - origin_inputs.assign(origin_inputs.begin() + kCNodeCallArg + 1, origin_inputs.end()); - AttachOriginalInputsToGraph(kg, origin_inputs); - MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString(); -} - -void AscendControlParser::RecurseSwitch(NotNull kg, NotNull cur_node, - const CNodePtr &next_node, const NotNull *> memo) { - MS_LOG(INFO) << "Process switch node " << cur_node->DebugString(); - - if (cur_node->size() < kCNodeSwitchLength) { - MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLength; - } - // 1 return label - auto back_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); - MS_EXCEPTION_IF_NULL(back_label); - MS_LOG(INFO) << "Insert back label " << back_label->DebugString() << " to " << kg->ToString() << " switch node 
" - << cur_node->DebugString(); - // 2 add depend relationship - InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); - if (next_node != nullptr && next_node != kg->get_return()) { - InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); - } - // 3 recurse sub graph - const std::vector &origin_switch_inputs = cur_node->inputs(); - if (kCNodeSwitchCond >= origin_switch_inputs.size()) { - MS_LOG(EXCEPTION) << "The size of origin_switch_inputs is not more than " << kCNodeSwitchCond; - } - std::vector new_switch_inputs = { - std::make_shared(std::make_shared(kLabelSwitchOpName)), - origin_switch_inputs[kCNodeSwitchCond]}; - std::vector child_graphs; - for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) { - // 3.1 branch kernel graph and args - KernelGraphPtr branch_fg; - std::vector origin_inputs; - std::tie(branch_fg, origin_inputs) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); - child_graphs.push_back(branch_fg); - // 3.2 recurse sub graph - CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); - new_switch_inputs.push_back(branch_label); - AttachOriginalInputsToGraph(kg, origin_inputs); - } - std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]); - - cur_node->set_inputs(new_switch_inputs); - cur_node->set_abstract(nullptr); - AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue>(child_graphs), cur_node.get()); - MS_LOG(INFO) << "Succeed processing switch func " << cur_node->DebugString(); -} - -void AscendControlParser::RecurseSwitchLayer(NotNull kg, NotNull cur_node, - const CNodePtr &next_node, - const NotNull *> memo) { - MS_LOG(INFO) << "Process switch node " << cur_node->DebugString(); - - if (cur_node->size() < kCNodeSwitchLayerLength) { - MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength; - } - - std::vector branch_partial; - for (size_t idx = kCNodeSwitchLayerBranch; idx < cur_node->inputs().size(); idx++) { - branch_partial.emplace_back(cur_node->input(idx)); - } - // 1 return label - auto back_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); - // 2 add depend relationship - InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); - if (next_node != nullptr && next_node != kg->get_return()) { - InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); - } - // 3 recurse sub graph - const std::vector &origin_switch_inputs = cur_node->inputs(); - if (kCNodeSwitchCond >= origin_switch_inputs.size()) { - MS_LOG(EXCEPTION) << "Index out of range:" << origin_switch_inputs.size() << "."; - } - std::vector new_switch_inputs = { - std::make_shared(std::make_shared(kLabelSwitchOpName)), - origin_switch_inputs[kCNodeSwitchCond]}; - std::vector child_graphs; - for (size_t i = 0; i < branch_partial.size(); ++i) { - // 3.1 branch kernel graph and args - KernelGraphPtr branch_fg; - std::vector origin_inputs; - std::tie(branch_fg, origin_inputs) = ParsePartial(NOT_NULL(origin_switch_inputs[i + kCNodeSwitchLayerBranch])); - child_graphs.push_back(branch_fg); - // 3.2 recurse sub graph - CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); - new_switch_inputs.push_back(branch_label); - AttachOriginalInputsToGraph(kg, origin_inputs); - } - cur_node->set_inputs(new_switch_inputs); - cur_node->set_abstract(std::make_shared()); - // To adapt to the true and false branches of the switch, the sequence of the branches is reversed. 
- std::reverse(child_graphs.begin(), child_graphs.end()); - AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue>(child_graphs), cur_node.get()); - MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString(); -} - -std::tuple> AscendControlParser::ParsePartial(NotNull node) { - if (!node.get()->isa()) { - if (IsValueNode(node)) { - return {GetValueNode(node), {}}; - } - MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString(); - } - // 2.1 branch kernel graph and args - auto partial_cnode = utils::cast(node.get()); - MS_EXCEPTION_IF_NULL(partial_cnode); - if (partial_cnode->size() < kCNodePartialLength) { - MS_LOG(EXCEPTION) << "Inputs of partial node must more than " << kCNodePartialLength; - } - - const auto &partial_inputs = partial_cnode->inputs(); - if (kCNodePartialFunc >= partial_inputs.size()) { - MS_LOG(EXCEPTION) << "Index out of range:" << partial_inputs.size() << "."; - } - auto branch_kg = GetValueNode(partial_inputs[kCNodePartialFunc]); - return {branch_kg, std::vector(partial_inputs.begin() + kCNodePartialFunc + 1, partial_inputs.end())}; -} - -void AscendControlParser::InsertMultipleAssignToGraph(NotNull from_graph, const AnfNodePtr &jump_node, - NotNull from, NotNull to) { - std::vector from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem}); - std::vector to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem}); - MS_LOG(INFO) << "Insert multi-assign from [" << from->DebugString() << "] to [" << to->DebugString() << "]"; - if (from_outputs.size() != to_outputs.size()) { - MS_LOG(EXCEPTION) << "From outputs size[" << from_outputs.size() << "] is not equal to to outputs size[" - << to_outputs.size() << "]"; - } - for (size_t i = 0; i < from_outputs.size(); i++) { - auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); - if (assign_node == nullptr) { - continue; - } - const auto &from_graph_exe_order = from_graph->execution_order(); - if (jump_node == nullptr) { - if (!from_graph_exe_order.empty()) { - InsertControlDependToGraph(from_graph, NOT_NULL(*(from_graph_exe_order.rbegin())), NOT_NULL(assign_node)); - } else { - InsertDependToGraph(from_graph, NOT_NULL(assign_node)); - } - continue; - } - - auto jump_node_iter = std::find(from_graph_exe_order.begin(), from_graph_exe_order.end(), jump_node); - if (jump_node_iter == from_graph_exe_order.end()) { - MS_LOG(EXCEPTION) << "Cannot find jump node " << jump_node->DebugString() << " in graph " - << from_graph->ToString(); - } - // insert assign between jump_node -1 and jump_node - while (jump_node_iter != from_graph_exe_order.begin()) { - CNodePtr node = *(jump_node_iter - 1); - if (AnfAlgo::GetGraphId(node.get()) == from_graph->graph_id()) { - InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node)); - break; - } else { - jump_node_iter--; - } - } - InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); - } -} - -AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull kg, NotNull from, - NotNull to) { - if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) && - AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) { - return nullptr; - } - if (from.get() == to.get()) { - return nullptr; - } - MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to " - << to->DebugString(); - // config inputs of assign node - std::vector inputs = 
{NewValueNode(std::make_shared(prim::kPrimAssign->name())), to, from}; - // generate a new cnode - auto assign_node = kg->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(assign_node); - assign_node->set_abstract(to->abstract()); - return assign_node; -} - -std::vector AscendControlParser::RecurseGraph(NotNull graph, - const NotNull *> memo) { - MS_LOG(INFO) << "Graph:" << graph->graph_id() << " start"; - if (memo->find(graph) != memo->end()) { - return {}; - } - memo->insert(graph.get()); - graph->SetExecOrderByDefault(); - std::vector cnodes = graph->execution_order(); - - auto end_label_goto = graph->get_end_goto(); - if (cnodes.rbegin() != cnodes.rend() && *cnodes.rbegin() == end_label_goto) { - cnodes.pop_back(); - } - AnfAlgo::ReorderOptimizerExecList(NOT_NULL(&cnodes)); - if (end_label_goto != nullptr) { - cnodes.push_back(end_label_goto); - } - - std::vector execution_order; - auto recurse_child_graph = [&](uint32_t index, uint32_t label_index, const CNodePtr &node) { - KernelGraphPtr cur_child_graph; - if (!CheckLabelIndex(index, label_index, node, &cur_child_graph)) { - MS_LOG(EXCEPTION) << "Check label index fail"; - } - MS_EXCEPTION_IF_NULL(cur_child_graph); - auto child_execution_order = RecurseGraph(NOT_NULL(cur_child_graph), memo); - execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); - }; - - for (auto &node : cnodes) { - uint32_t child_graph_index = 0; - execution_order.push_back(node); - if (node == graph->get_end_goto()) { - continue; - } - if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { - std::vector label_switch_list = AnfAlgo::GetNodeAttr>(node, kAttrLabelSwitchList); - for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) { - recurse_child_graph(child_graph_index++, *iter, node); - } - } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { - uint32_t label_index = AnfAlgo::GetNodeAttr(node, kAttrLabelIndex); - recurse_child_graph(child_graph_index, label_index, node); - } - // erase kAttrChildGraph after finish using - if (AnfAlgo::HasNodeAttr(kAttrChildGraph, node)) { - AnfAlgo::EraseNodeAttr(kAttrChildGraph, node); - } - } - graph->set_execution_order(execution_order); - return execution_order; -} - -bool AscendControlParser::CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cur_label, - KernelGraphPtr *cur_child_graph) { - auto child_graphs = AnfAlgo::GetNodeAttr>(cur_label, kAttrChildGraph); - // check index and child order size - if (child_graphs.size() <= IntToSize(index)) { - MS_LOG(EXCEPTION) << "Child graph index is wrong, current node " << cur_label->ToString() << " child graph size " - << child_graphs.size() << " goto index " << index; - } - *cur_child_graph = child_graphs[index]; - MS_EXCEPTION_IF_NULL(*cur_child_graph); - - // get start_label_set_index of child graph - auto start_label_set = (*cur_child_graph)->get_start_label(); - uint32_t start_label_set_index = AnfAlgo::GetNodeAttr(start_label_set, kAttrLabelIndex); - if (label_index != start_label_set_index) { - MS_EXCEPTION_IF_NULL(cur_label); - MS_EXCEPTION_IF_NULL(start_label_set); - MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString() - << " index " << start_label_set_index; - return false; - } else { - return true; - } -} - -void AscendControlParser::ReferenceCounter::AddReadCount(const AnfNodePtr &key, int64_t num) { - auto iter = count_.find(key); - if (iter != count_.end()) { - iter->second.first += 
num; - } else { - count_[key] = {num, 0}; - } -} - -void AscendControlParser::ReferenceCounter::AddWriteCount(const AnfNodePtr &key, int64_t num) { - auto iter = count_.find(key); - if (iter != count_.end()) { - iter->second.second += num; - } else { - count_[key] = {0, num}; - } -} - -void AscendControlParser::ReferenceCounter::EraseElem(const AnfNodePtr &key) { count_.erase(key); } - -bool AscendControlParser::ReferenceCounter::HasValidElem() const { - auto it = std::find_if(count_.begin(), count_.end(), - [this](const std::pair> &p) -> bool { - auto &[read, written] = p.second; - return predicate_(read, written); - }); - return it != count_.end(); -} - -std::tuple AscendControlParser::ReferenceCounter::GetOneValidElem() const { - auto it = std::find_if(count_.begin(), count_.end(), - [this](const std::pair> &p) -> bool { - auto &[read, written] = p.second; - return predicate_(read, written); - }); - if (it == count_.end()) { - MS_LOG(EXCEPTION) << "No valid parameter."; - } - return {it->first, it->second.first, it->second.second}; -} -} // namespace session -} // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.h b/mindspore/ccsrc/backend/session/ascend_control_parser.h deleted file mode 100644 index edc15b4a1d..0000000000 --- a/mindspore/ccsrc/backend/session/ascend_control_parser.h +++ /dev/null @@ -1,101 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_CONTROL_PARSER_H
-#define MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_CONTROL_PARSER_H
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include "backend/session/kernel_graph.h"
-#include "base/base_ref.h"
-#include "utils/contract.h"
-#include "utils/union_find_set.h"
-
-namespace mindspore {
-namespace session {
-class AscendControlParser {
- public:
-  static void LinkGraph(NotNull kg);
-
-  static void InsertDependToGraph(NotNull kg, NotNull attch_node);
-  static void InsertControlDependToGraph(NotNull kg, NotNull first_node,
-                                         NotNull second_node);
-  static void ExecutorValidate(NotNull root_graph);
-  static void InsertMultipleAssignToGraph(NotNull from_graph, const AnfNodePtr &jump_node,
-                                          NotNull from, NotNull to);
-
- private:
-  class ReferenceCounter;
-
-  static void EraseParameter(NotNull root_graph, const std::set &graph_list);
-  static void EraseAssign(std::shared_ptr parameter_count, const std::set &all_nodes,
-                          const std::map &para_to_written_node,
-                          NotNull root_graph, const std::set &graph_list);
-  static void EraseLabel(NotNull root_graph);
-  static void ChildGraphDataAssign(NotNull kg,
-                                   const NotNull> *> link_list,
-                                   const NotNull *> memo);
-  static NotNull GetStartLabel(NotNull kg, const CNodePtr &last_node,
-                               const CNodePtr &last_label);
-  static NotNull ProcessKernelGraph(NotNull kg, const CNodePtr &last_node,
-                                    const CNodePtr &last_label,
-                                    const NotNull *> memo);
-  static void RecurseCall(NotNull kg, NotNull cur_node, const CNodePtr &next_node,
-                          const NotNull *> memo);
-  static void RecurseSwitch(NotNull kg, NotNull cur_node, const CNodePtr &next_node,
-                            const NotNull *> memo);
-  static void RecurseSwitchLayer(NotNull kg, NotNull cur_node, const CNodePtr &next_node,
-                                 const NotNull *> memo);
-
-  static void LinkParentGraph(NotNull kg, const CNodePtr &from_graph_call_node,
-                              const CNodePtr &last_label);
-
-  static AnfNodePtr InsertAssignToGraph(NotNull kg, NotNull from, NotNull to);
-  static std::vector>> ParseCallSwitchNode(
-    NotNull call_node);
-  static std::tuple> ParsePartial(NotNull node);
-  static void AttachChildGraphToReturnNode(NotNull graph,
-                                           const NotNull *> memo);
-  // root graph order
-  static bool CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cnode,
-                              KernelGraphPtr *cur_child_graph);
-  static std::vector RecurseGraph(NotNull graph,
-                                  const NotNull *> memo);
-  static void AttachOriginalInputsToGraph(NotNull graph, const std::vector orig_inputs);
-};
-class AscendControlParser::ReferenceCounter {
- public:
-  explicit ReferenceCounter(std::function func) : predicate_(func), count_() {}
-  ~ReferenceCounter() = default;
-  void AddReadCount(const AnfNodePtr &key, int64_t num);
-  void AddWriteCount(const AnfNodePtr &key, int64_t num);
-  void EraseElem(const AnfNodePtr &key);
-  bool HasValidElem() const;
-  std::tuple GetOneValidElem() const;
-
- private:
-  std::function predicate_;
-  std::map> count_;
-};
-}  // namespace session
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_CONTROL_PARSER_H
diff --git a/mindspore/ccsrc/backend/session/ascend_inference_session.h b/mindspore/ccsrc/backend/session/ascend_inference_session.h
index 671d2e09c7..10cbbc74c6 100644
--- a/mindspore/ccsrc/backend/session/ascend_inference_session.h
+++ b/mindspore/ccsrc/backend/session/ascend_inference_session.h
@@ -28,7 +28,6 @@
 #include "backend/session/kernel_graph.h"
 #include "backend/kernel_compiler/kernel.h"
 #include "backend/session/session_factory.h"
-#include "backend/session/ascend_control_parser.h"
 
 namespace mindspore {
 namespace session {
diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc
index 6b62316c8e..26a052b48c 100644
--- a/mindspore/ccsrc/backend/session/ascend_session.cc
+++ b/mindspore/ccsrc/backend/session/ascend_session.cc
@@ -132,17 +132,6 @@ void SetStreamDistinctionLabel(const KernelGraphPtr &graph, uint32_t label, bool
   }
 }
 
-std::vector GetCNodes(const std::vector &anf_nodes) {
-  std::vector cnodes = {};
-  for (const auto &anf : anf_nodes) {
-    MS_EXCEPTION_IF_NULL(anf);
-    if (anf->isa()) {
-      cnodes.push_back(anf->cast());
-    }
-  }
-  return cnodes;
-}
-
 TensorPtr GetCNodeOutputStubTensor(const KernelWithIndex &kernel_with_index,
                                    const std::map &node_output_info,
                                    bool *output_is_weight) {
@@ -408,10 +397,9 @@ void AscendSession::LoadInputData(const std::shared_ptr &kernel_gra
     input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
   }
   auto &input_nodes = kernel_graph->input_nodes();
-  auto extra_param_size = kernel_graph->GetExtraParamAndTensor().size();
-  if ((inputs.size() + input_ctrl_size) - 3 != input_nodes.size() - extra_param_size) {
+  if ((inputs.size() + input_ctrl_size) - 3 != input_nodes.size()) {
     MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
-                      << ", input_ctrl_size:" << input_ctrl_size << ", extra_param_size:" << extra_param_size;
+                      << ", input_ctrl_size:" << input_ctrl_size;
   }
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
@@ -663,10 +651,6 @@ bool AscendSession::IsSupportSummary() { return !device::KernelAdjust::NeedInser
 void AscendSession::PreExecuteGraph(const std::shared_ptr &kernel_graph,
                                     const std::vector &inputs, VectorRef *const outputs) {
-  // load data to extra params
-  std::set memo;
-  SyncDataToExtraParams(NOT_NULL(kernel_graph), NOT_NULL(&memo));
-  memo.clear();
   if (debugger_) {
     debugger_->PreExecute(kernel_graph, graph_sum_);
   }
@@ -1419,168 +1403,11 @@ void AscendSession::BackendOptimization(const std::vector &all_g
   MS_LOG(INFO) << "End.";
 }
 
-void AscendSession::LinkChildGraphs(NotNull graph) { AscendControlParser::LinkGraph(graph); }
-
-bool AscendSession::IsMultiCallGraph(NotNull graph, std::vector parent_graphs) {
-  std::stack post_graph;
-  std::set memo;
-  post_graph.push(graph->graph_id());
-  while (!post_graph.empty()) {
-    auto graph_id = post_graph.top();
-    post_graph.pop();
-    memo.insert(graph_id);
-    for (auto child_graph : graphs_[graph_id]->child_graph_order()) {
-      std::shared_ptr child_graph_ptr = child_graph.lock();
-      MS_EXCEPTION_IF_NULL(child_graph_ptr);
-      if (std::find(parent_graphs.begin(), parent_graphs.end(), child_graph_ptr->graph_id()) != parent_graphs.end()) {
-        MS_LOG(DEBUG) << "graph:" << graph->graph_id() << " will call its parent graph:" << child_graph_ptr->graph_id();
-        return false;
-      } else if (memo.find(child_graph_ptr->graph_id()) == memo.end()) {
-        MS_LOG(DEBUG) << "child graph:" << child_graph_ptr->graph_id() << " into deque, wait for check.";
-        post_graph.push(child_graph_ptr->graph_id());
-      }
-    }
-  }
-  return true;
-}
-
-void AscendSession::MultiCallGraphOptimize(NotNull root_graph) {
-  for (auto current : parent_graphs_) {
-    if (current.second.size() < 2) {
-      continue;
-    }
-    auto graph = graphs_[current.first];
-    auto parent_kernel_graphs = current.second;
-    if (!IsMultiCallGraph(NOT_NULL(graph), parent_kernel_graphs)) {
-      MS_LOG(DEBUG) << "graph:" <<
graph->graph_id() << " with it's parent graphs make up a cycle"; - continue; - } - MS_LOG(INFO) << "graph: " << graph->graph_id() << " has been called by more than two graphs"; - int32_t index = 0; - std::vector child_graphs; - auto start_label_id = AnfAlgo::GetNodeAttr(graph->get_start_label(), kAttrLabelIndex); - auto end_node = graph->get_end_goto(); - ParameterPtr post_label_param = graph->AddExtraParamAndTensor("label_param", 0); - std::vector new_inputs = {std::make_shared(std::make_shared(kLabelSwitchOpName)), - post_label_param}; - for (auto graph_id : parent_kernel_graphs) { - auto kg = graphs_[graph_id]; - auto nodes = kg->execution_order(); - for (uint32_t i = 0; i < nodes.size(); i++) { - if (AnfAlgo::IsLabelIndexInNode(nodes[i], start_label_id)) { - if (i < (nodes.size() - 1)) { - new_inputs.push_back(nodes[i + 1]); - } else { - MS_LOG(EXCEPTION) << "No labelset after labelgoto"; - } - ParameterPtr pre_label_param = kg->AddExtraParamAndTensor("label_param", index++); - AscendControlParser::InsertMultipleAssignToGraph(NOT_NULL(kg), nodes[i], NOT_NULL(pre_label_param), - NOT_NULL(post_label_param)); - } - } - kg->SetExecOrderByDefault(); - child_graphs.push_back(kg); - } - end_node->set_inputs(new_inputs); - AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue>(child_graphs), end_node); - std::vector label_list; - for (size_t i = kLabelSwitchLabelId; i < end_node->size(); ++i) { - auto input = end_node->input(i); - MS_EXCEPTION_IF_NULL(input); - if (!input->isa() || AnfAlgo::GetCNodeName(input) != kLabelSetOpName) { - break; - } - uint32_t goto_label_id = AnfAlgo::GetNodeAttr(input, kAttrLabelIndex); - label_list.push_back(goto_label_id); - MS_LOG(INFO) << "Switch " << end_node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " - << goto_label_id; - } - AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue>(label_list), end_node); - end_node->set_inputs({end_node->input(kAnfPrimitiveIndex), end_node->input(kFirstDataInputIndex)}); - graph->SetExecOrderByDefault(); - } -} - -void AscendSession::SyncDataToExtraParams(NotNull graph, NotNull *> memo) { - if (memo->find(graph.get()) != memo->end()) { - return; - } - memo->insert(graph.get()); - auto extra_param_tensor = graph->GetExtraParamAndTensor(); - for (uint32_t i = 0; i < extra_param_tensor.size(); i++) { - auto param = extra_param_tensor[i].first; - auto tensor = extra_param_tensor[i].second; - auto device_address = AnfAlgo::GetMutableOutputAddr(param, 0); - MS_EXCEPTION_IF_NULL(device_address); - tensor->set_device_address(device_address); - if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(param, 0), LongToSize(tensor->data().nbytes()), - tensor->data_type(), tensor->data_c())) { - MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; - } - } - for (auto &child_graph : graph->child_graph_order()) { - SyncDataToExtraParams(NOT_NULL(child_graph.lock()), memo); - } -} - void AscendSession::RootGraphExecutorValidate(NotNull graph) { AscendAutoMonad auto_monad(graph); auto_monad.GenerateExecuteOrder(); } -void AscendSession::CreateMultiBranchOutput(NotNull graph, NotNull *> memo) { - if (memo->find(graph.get()) != memo->end()) { - return; - } - memo->insert(graph.get()); - graph->UpdateChildGraphOrder(); - for (auto &child_graph : graph->child_graph_order()) { - CreateMultiBranchOutput(NOT_NULL(child_graph.lock()), memo); - } - std::map need_replace_list; - auto node_list = GetCNodes(TopoSort(graph->get_return())); - for (auto &node : node_list) { - if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall) || 
AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) || - AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitchLayer)) { - // create a parameter to store the output of multiple branch and set the parameter as the condition graph's output - auto output_param = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract())); - MS_EXCEPTION_IF_NULL(graph->MutableInputs()); - graph->AddChildGraphResult(output_param); - - std::vector depend_inputs = { - graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimDepend->name()))), output_param, node}; - auto depend = graph->NewCNode(depend_inputs); - depend->set_abstract(output_param->abstract()); - need_replace_list.emplace(node, depend); - MS_LOG(INFO) << "Create parameter " << output_param->DebugString() << " for call node " << node->DebugString() - << ", depend node is " << depend->DebugString(); - // insert assign in order to transfer child graph output to parameter - auto child_graphs = AnfAlgo::GetCallSwitchKernelGraph(node); - for (auto &child_graph : child_graphs) { - MS_EXCEPTION_IF_NULL(child_graph); - // If graph has no output, the graph is the true graph of while and will call condition graph, no need insert - // assign from condition to true graph - if (memo->find(child_graph) != memo->end()) { - continue; - } - AscendControlParser::InsertMultipleAssignToGraph(NOT_NULL(child_graph), nullptr, - NOT_NULL(child_graph->output()), NOT_NULL(output_param)); - } - } - } - // searching for nodes' input to replace call by depend(parameter, call) - for (auto &node : node_list) { - for (size_t i = 0; i < node->size(); ++i) { - auto input = node->input(i); - auto iter = need_replace_list.find(input); - if (iter != need_replace_list.end()) { - node->set_input(i, iter->second); - } - } - } - memo->erase(graph.get()); -} - void AscendSession::IrFusionPass(const NotNull graph, NotNull *> memo) { if (memo->find(graph) != memo->end()) { return; diff --git a/mindspore/ccsrc/backend/session/ascend_session.h b/mindspore/ccsrc/backend/session/ascend_session.h index 0d23e58fe3..08e65898f6 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.h +++ b/mindspore/ccsrc/backend/session/ascend_session.h @@ -30,7 +30,6 @@ #include "backend/session/kernel_graph.h" #include "backend/kernel_compiler/kernel.h" #include "backend/session/session_factory.h" -#include "backend/session/ascend_control_parser.h" namespace mindspore { namespace session { @@ -95,11 +94,6 @@ class AscendSession : public SessionBasic { void RunOpHardwareOptimize(const std::shared_ptr &kernel_graph) const; static void BackendOptimization(const std::vector &all_graphs); - static void LinkChildGraphs(NotNull graph); - // replace labelgoto with labelswitch in subgraph called multiple times - void MultiCallGraphOptimize(NotNull root_graph); - bool IsMultiCallGraph(NotNull graph, std::vector parent_graphs); - void SyncDataToExtraParams(NotNull graph, NotNull *> memo); void RootGraphExecutorValidate(NotNull graph); // merge execution order list of child graphs void MergeGraphExecOrder(); diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index 30612d7fbe..3a7af7036d 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -1186,33 +1186,6 @@ void KernelGraph::RemoveNodeFromGraph(const AnfNodePtr &node) { } } -ParameterPtr KernelGraph::AddExtraParamAndTensor(std::string param_name, int32_t value) { - ParameterPtr param; - ShapeVector shp = {1}; - tensor::TensorPtr 
tensor_ptr = std::make_shared(kInt32->type_id(), shp);
-  MS_EXCEPTION_IF_NULL(tensor_ptr);
-  mindspore::abstract::AbstractBasePtr paremeter_abstract_ptr = tensor_ptr->ToAbstract();
-  ParameterPtr new_param = std::make_shared(shared_from_this()->cast());
-  MS_EXCEPTION_IF_NULL(new_param);
-  new_param->set_name(param_name);
-  new_param->set_abstract(paremeter_abstract_ptr);
-  param = NewParameter(new_param);
-  // ensure alloc mem for this param
-  std::vector *mute_inputs = MutableInputs();
-  MS_EXCEPTION_IF_NULL(mute_inputs);
-  mute_inputs->push_back(param);
-
-  tensor::TensorPtr data_tensor_ptr = std::make_shared(kInt32->type_id(), shp);
-  MS_EXCEPTION_IF_NULL(data_tensor_ptr);
-  int32_t *val = nullptr;
-  val = static_cast(data_tensor_ptr->data_c());
-  *val = value;
-
-  extra_param_tensor_.push_back(std::make_pair(param, data_tensor_ptr));
-  MS_LOG(INFO) << "Create new param: " << param->DebugString();
-  return param;
-}
-
 void KernelGraph::UpdateGraphDynamicAttr() {
   for (const auto &cnode : execution_order_) {
     if (AnfAlgo::IsDynamicShape(cnode)) {
diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h
index b299758208..294969af32 100644
--- a/mindspore/ccsrc/backend/session/kernel_graph.h
+++ b/mindspore/ccsrc/backend/session/kernel_graph.h
@@ -45,7 +45,6 @@ class KernelGraph : public FuncGraph {
     executable_ = true;
     summary_node_exist_ = false;
     stream_distinction_label_ = kInvalidDistincLabel;
-    extra_param_tensor_ = {};
   }
 
   KernelGraph(const KernelGraph &graph) : FuncGraph(graph) {
@@ -90,7 +89,6 @@ class KernelGraph : public FuncGraph {
     first_step_ = graph.first_step_;
     has_optimizer_ = graph.has_optimizer_;
     is_dynamic_shape_ = graph.is_dynamic_shape_;
-    extra_param_tensor_ = graph.extra_param_tensor_;
   }
 
   ~KernelGraph() override;
@@ -230,9 +228,6 @@ class KernelGraph : public FuncGraph {
     }
   }
   void RemoveNodeFromGraph(const AnfNodePtr &node);
-  // Add Param which pass callback point
-  ParameterPtr AddExtraParamAndTensor(std::string param_name, int32_t value);
-  const std::vector> GetExtraParamAndTensor() { return extra_param_tensor_; }
   void UpdateGraphDynamicAttr();
   bool is_dynamic_shape() const { return is_dynamic_shape_; }
   void SetOptimizerFlag();
@@ -324,8 +319,6 @@ class KernelGraph : public FuncGraph {
   std::vector child_graph_result_;
   std::vector execution_order_;
   std::vector mem_reuse_exec_order_;
-  // extra params and tensors for control flow
-  std::vector> extra_param_tensor_;
   uint32_t graph_id_;
   uint32_t stream_distinction_label_;
 
diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc
index 0ec4ae7d54..9dd63fa7e7 100644
--- a/mindspore/ccsrc/backend/session/session_basic.cc
+++ b/mindspore/ccsrc/backend/session/session_basic.cc
@@ -1448,12 +1448,6 @@ std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphP
       (void)ConstructKernelGraph(child_graph, all_out_graph);
     }
     (void)CreateValueNodeKernelGraph(node, graph.get());
-    auto &parent_graph = parent_graphs_[front_backend_graph_map_[child_graph.get()]->graph_id()];
-    auto parent_graph_it =
-      std::find(parent_graph.begin(), parent_graph.end(), front_backend_graph_map_[func_graph.get()]->graph_id());
-    if (parent_graph_it == parent_graph.end()) {
-      parent_graph.push_back(front_backend_graph_map_[func_graph.get()]->graph_id());
-    }
     continue;
   }
   // Create cnode
diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h
index 4ddbd5df21..dfbfe89f88 100644
--- a/mindspore/ccsrc/backend/session/session_basic.h
+++ b/mindspore/ccsrc/backend/session/session_basic.h
@@ -255,7 +255,6 @@ class SessionBasic : public std::enable_shared_from_this {
   std::unordered_map> graphs_;
   std::unordered_map> run_op_graphs_;
   std::unordered_map front_backend_graph_map_;
-  std::unordered_map> parent_graphs_;
   std::shared_ptr context_;
   CallBackFunc summary_callback_;
   static GraphId graph_sum_;