/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include "session/ascend_control_parser.h" #include "session/anf_runtime_algorithm.h" namespace mindspore { namespace session { static VectorRef GetCallArgs(std::vector::iterator iter_begin, std::vector::iterator iter_end) { VectorRef call_args; for (auto iter = iter_begin; iter != iter_end; ++iter) { if (utils::isa(*iter)) { call_args.push_back(GetValueNode(*iter)); } else { call_args.push_back(*iter); } } return call_args; } void AscendControlParser::LinkGraph(NotNull kg) { std::set memo; ProcessKernelGraph(kg, nullptr, nullptr, {}, NOT_NULL(&memo)); } NotNull AscendControlParser::ProcessKernelGraph(NotNull kg, const CNodePtr &last_node, const CNodePtr &last_label, const VectorRef &args, NotNull *> memo) { MS_LOG(INFO) << "Start process KernelGraph " << kg->ToString(); // 0. recursive condition if (memo->find(kg) != memo->end()) { MS_LOG(INFO) << "KernelGraph has beed processed: " << kg->ToString(); return NOT_NULL(kg->get_start_label()); } // 2. args replace placeholder LinkParentGraph(kg, last_node, last_label, args); // 3. topological sort std::vector nodes = GetCNodes(TopoSort(kg->get_return())); if (nodes.empty()) { MS_LOG(EXCEPTION) << "KernelGraph " << kg->ToString() << " has no cnodes!"; } // 4. insert first_label auto start_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); for (auto node : nodes) { if (!IsPrimitiveCNode(node, prim::kPrimPartial)) { InsertControlDependToGraph(kg, NOT_NULL(start_label), NOT_NULL(node)); break; } } kg->set_start_label(start_label); // 5. traverse for (size_t i = 0; i < nodes.size(); ++i) { auto &cnode = nodes[i]; if (cnode->size() < kCNodePrim + 1) { MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; } AnfNodePtr fn = cnode->input(kCNodePrim); if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) { MS_LOG(DEBUG) << "continue node " << cnode->DebugString(); continue; } AnfNodePtr arg = cnode->input(kCNodeCallArg); if (IsValueNode(arg)) { RecurseCall(kg, NOT_NULL(cnode), (i + 1 < nodes.size() ? nodes[i + 1] : nullptr), memo); } else if (!arg->isa()) { MS_LOG(EXCEPTION) << "Unknown type call node " << cnode->DebugString(); } else if (IsPrimitiveCNode(arg->cast(), prim::kPrimSwitch)) { auto arg_cnode = arg->cast(); cnode->set_inputs(cnode->inputs()); RecurseSwitch(kg, NOT_NULL(cnode), memo); } else if (IsPrimitiveCNode(arg->cast(), prim::kPrimSwitchLayer)) { auto arg_cnode = arg->cast(); cnode->set_inputs(cnode->inputs()); RecurseSwitchLayer(kg, NOT_NULL(cnode), memo); } } MS_LOG(INFO) << "End KernelGraph process: " << kg->ToString(); return NOT_NULL(start_label); } std::vector AscendControlParser::GetCNodes(const std::vector &in) { std::vector out; for (auto &node : in) { if (node->isa()) { out.push_back(node->cast()); } } return out; } void AscendControlParser::InsertDependToGraph(NotNull kg, NotNull attch_node) { std::vector inputs = {NewValueNode(std::make_shared("depend"))}; auto return_node = kg->get_return(); MS_EXCEPTION_IF_NULL(return_node); inputs.push_back(return_node->input(1)); inputs.push_back(attch_node.get()); auto depend_node = kg->NewCNode(inputs); return_node->set_input(1, depend_node); } void AscendControlParser::InsertControlDependToGraph(NotNull kg, NotNull first_node, NotNull second_node) { MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString() << ", the second node is " << second_node->DebugString(); std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimControlDepend->name())), first_node, second_node}; auto control_depend = kg->NewCNode(inputs); InsertDependToGraph(kg, NOT_NULL(control_depend)); } void AscendControlParser::LinkParentGraph(NotNull kg, const CNodePtr &from_graph_call_node, const CNodePtr &last_label, const VectorRef &args) { if (from_graph_call_node != nullptr) { SetSubGraphInput(kg, NOT_NULL(from_graph_call_node), args); } auto origin_return = kg->get_return(); std::vector origin_return_inputs = origin_return->inputs(); // if entry graph, replace return with make_tuple if (from_graph_call_node == nullptr || last_label == nullptr) { MS_LOG(INFO) << kg->ToString() << " is entry graph."; std::vector make_tuple_inputs = {std::make_shared(prim::kPrimMakeTuple)}; make_tuple_inputs.insert(make_tuple_inputs.end(), origin_return_inputs.begin() + 1, origin_return_inputs.end()); auto make_tuple = kg->NewCNode(make_tuple_inputs); origin_return->set_inputs({origin_return->input(kCNodePrim), make_tuple}); } else { // else replace return with label_goto auto label_goto = kg->NewCNode({std::make_shared(std::make_shared(kLabelGotoOpName)), last_label}); InsertDependToGraph(kg, NOT_NULL(label_goto)); } } void AscendControlParser::RecurseCall(NotNull kg, NotNull cur_node, const CNodePtr &next_node, NotNull *> memo) { MS_LOG(INFO) << "process call func " << cur_node->DebugString(); // 1 get kernel graph auto origin_inputs = cur_node->inputs(); std::vector new_inputs = {std::make_shared(std::make_shared(kLabelGotoOpName))}; auto call_args = GetCallArgs(origin_inputs.begin() + 1, origin_inputs.end()); if (!IsValueNode(origin_inputs[kCNodeCallArg])) { MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode"; return; } // 2 return label auto back_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); // 3 add depend relationship InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); if (next_node != nullptr && next_node != kg->get_return()) { InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); } auto call_kg = GetValueNode(origin_inputs[kCNodeCallArg]); // 4 modify call op to goto op cur_node->set_input(kCNodePrim, new_inputs[kCNodePrim]); // 5 recurse sub graph CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, call_args, memo); new_inputs.push_back(sub_label); new_inputs.insert(new_inputs.end(), origin_inputs.begin(), origin_inputs.end()); cur_node->set_inputs(new_inputs); cur_node->set_abstract(nullptr); MS_LOG(INFO) << "success process call func " << cur_node->DebugString(); } void AscendControlParser::RecurseSwitch(NotNull kg, NotNull cur_node, NotNull *> memo) { MS_LOG(INFO) << "process switch node " << cur_node->DebugString(); if (cur_node->size() < kCNodeSwitchLength) { MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLength; } // 1 return label auto back_label = kg->NewCNode({std::make_shared(prim::kPrimLabelSet)}); // 2 recurse sub graph auto origin_switch_inputs = cur_node->inputs(); std::vector new_switch_inputs = { std::make_shared(std::make_shared(kLabelSwitchOpName)), origin_switch_inputs[kCNodeSwitchCond]}; for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) { // 2.1 branch kernel graph and args CNodePtr partial; KernelGraphPtr branch_fg; VectorRef call_args; std::tie(partial, branch_fg, call_args) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); // 2.2 add depend relationship InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); // 2.3 recurse sub graph CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, call_args, memo); new_switch_inputs.push_back(branch_label); } std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]); new_switch_inputs.insert(new_switch_inputs.end(), origin_switch_inputs.begin(), origin_switch_inputs.end()); cur_node->set_inputs(new_switch_inputs); cur_node->set_abstract(nullptr); MS_LOG(INFO) << "success process switch func " << cur_node->DebugString(); } void AscendControlParser::RecurseSwitchLayer(NotNull kg, NotNull cur_node, NotNull *> memo) { MS_LOG(INFO) << "process switch node " << cur_node->DebugString(); if (cur_node->size() < kCNodeSwitchLayerLength) { MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength; } auto branch_tuple = cur_node->input(kCNodeSwitchLayerBranch); MS_EXCEPTION_IF_NULL(branch_tuple); if (!branch_tuple->isa()) { MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength; } auto branch_partial = utils::cast(branch_tuple)->inputs(); // 1 return label auto back_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSwitchOpName))}); // 2 recurse sub graph auto origin_switch_inputs = cur_node->inputs(); std::vector new_switch_inputs = {std::make_shared(prim::kPrimLabelSwitch), origin_switch_inputs[kCNodeSwitchCond]}; for (size_t i = 0; i < branch_partial.size(); ++i) { // 2.1 branch kernel graph and args CNodePtr partial; KernelGraphPtr branch_fg; VectorRef call_args; std::tie(partial, branch_fg, call_args) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); // 2.2 add depend relationship InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); // 2.3 recurse sub graph CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, call_args, memo); new_switch_inputs.push_back(branch_label); } new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end()); cur_node->set_inputs(new_switch_inputs); cur_node->set_abstract(nullptr); MS_LOG(INFO) << "success process switch layer " << cur_node->DebugString(); } std::tuple AscendControlParser::ParsePartial(NotNull node) { if (!node.get()->isa()) { MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString(); } // 2.1 branch kernel graph and args auto partial_cnode = utils::cast(node.get()); if (partial_cnode->size() < kCNodePartialLength) { MS_LOG(EXCEPTION) << "Inputs of partial node must more than " << kCNodePartialLength; } auto partial_inputs = partial_cnode->inputs(); auto branch_kg = GetValueNode(partial_inputs[kCNodePartialFunc]); auto call_args = GetCallArgs(partial_inputs.begin() + kCNodePartialFunc + 1, partial_inputs.end()); return {partial_cnode, branch_kg, call_args}; } void AscendControlParser::InsertAssignToGraph(NotNull kg, NotNull from, NotNull to) { if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) && AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) { return; } if (from.get() == to.get()) { return; } MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to " << to->DebugString(); // config inputs of assign node std::vector inputs = {NewValueNode(std::make_shared("Assign")), to, from}; // generate a new cnode auto assign_node = kg->NewCNode(inputs); MS_EXCEPTION_IF_NULL(assign_node); assign_node->set_abstract(to->abstract()); // append the assign at the end of from graph InsertDependToGraph(kg, NOT_NULL(assign_node)); } size_t AscendControlParser::SetChildGraphInput(NotNull kg, NotNull node, size_t input_index) { auto output_num = AnfAlgo::GetOutputTensorNum(node); if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { return input_index + output_num; } auto &graph_inputs = kg->inputs(); if (input_index >= graph_inputs.size()) { MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size(); } auto backend_parameter = graph_inputs[input_index]; if (node.get()->isa()) { MS_EXCEPTION_IF_NULL(backend_parameter); MS_LOG(INFO) << "Reuse node [" << node->DebugString() << "], old node[" << backend_parameter->DebugString() << "] will be replaced."; kg->ReplaceNode(backend_parameter, node); return input_index; } InsertAssignToGraph(kg, node, NOT_NULL(backend_parameter)); return input_index + 1; } void AscendControlParser::SetSubGraphInput(NotNull kg, NotNull from_graph_call_node, const VectorRef &args) {} } // namespace session } // namespace mindspore