/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "backend/session/kernel_graph.h" #include #include #include #include #include "frontend/operator/ops.h" #include "ir/param_value.h" #include "backend/session/anf_runtime_algorithm.h" #include "runtime/device/kernel_info.h" #include "backend/kernel_compiler/kernel_build_info.h" #include "runtime/device/kernel_runtime_manager.h" #include "backend/kernel_compiler/common_utils.h" namespace mindspore { namespace session { namespace { constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput"; constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList"; void PushNoVisitedNode(const AnfNodePtr &node, std::queue *que, std::unordered_set *visited_nodes) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(que); MS_EXCEPTION_IF_NULL(visited_nodes); if (visited_nodes->find(node) == visited_nodes->end()) { que->push(node); (void)visited_nodes->insert(node); MS_LOG(DEBUG) << "Push que:" << node->DebugString(); } } std::vector GetCallRealOutputs(const AnfNodePtr &call_node) { auto item_with_index = AnfAlgo::VisitKernelWithReturnType(call_node, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple}); AnfNodePtr node = item_with_index.first; MS_EXCEPTION_IF_NULL(node); if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) { auto outputs = AnfAlgo::GetAllOutput(node); std::set memo; std::vector new_output; for (auto &output : outputs) { if (memo.find(output) != memo.end()) { 
continue; } memo.insert(output); new_output.push_back(output); } if (new_output.size() == 1 && AnfAlgo::CheckPrimitiveType(new_output[0], prim::kPrimCall)) { node = new_output[0]; } } if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) { return {node}; } std::vector real_inputs; auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node->cast()); for (const auto &child_graph : child_graphs) { if (child_graph->get_output_null()) { continue; } auto real_input = child_graph->output(); auto child_real_inputs = GetCallRealOutputs(real_input); std::copy(child_real_inputs.begin(), child_real_inputs.end(), std::back_inserter(real_inputs)); } return real_inputs; } bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) { if (left == right) { return true; } if (left == nullptr || right == nullptr) { return false; } if (!IsPrimitiveCNode(left, GetCNodePrimitive(right))) { return false; } if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, left) && AnfAlgo::HasNodeAttr(kAttrLabelIndex, right)) { return AnfAlgo::GetNodeAttr(left, kAttrLabelIndex) == AnfAlgo::GetNodeAttr(right, kAttrLabelIndex); } return false; } } // namespace AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) { auto value_node = node->cast(); if (value_node == nullptr) { return nullptr; } ValueNodePtr new_value_node = std::make_shared(value_node->value()); new_value_node->set_abstract(value_node->abstract()); this->SetKernelInfoForNode(new_value_node); return new_value_node; } std::vector KernelGraph::outputs() const { auto graph_output = output(); if (IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) { auto make_tuple = output()->cast(); MS_EXCEPTION_IF_NULL(make_tuple); auto &inputs = make_tuple->inputs(); return std::vector(inputs.begin() + 1, inputs.end()); } return std::vector(1, graph_output); } void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue *visit_queue, std::unordered_set *visited_nodes) { MS_EXCEPTION_IF_NULL(visit_queue); MS_EXCEPTION_IF_NULL(visited_nodes); 
auto it = node_output_edges_.find(node); if (it == node_output_edges_.end()) { // value node and parameter has no input,no need to print log if (node->isa()) { MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]"; } return; } // visit all reduce node first, then other nodes std::vector active_nodes; for (const auto &output_edge : it->second) { auto next_node = output_edge.first; MS_EXCEPTION_IF_NULL(next_node); if (node_input_num_.find(next_node) == node_input_num_.end()) { MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]"; } MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString() << ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second; if (node_input_num_[next_node] < output_edge.second) { MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num" << node_input_num_[next_node] << ",depend edge:" << output_edge.second; } node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second; // allreduce first if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) { (void)visited_nodes->insert(next_node); if (AnfAlgo::IsCommunicationOp(next_node)) { MS_LOG(DEBUG) << "Visit node:" << next_node->DebugString(); visit_queue->push(next_node); } else { active_nodes.emplace_back(next_node); } } } for (auto &node : active_nodes) { MS_EXCEPTION_IF_NULL(node); MS_LOG(DEBUG) << "Visit node:" << node->DebugString(); visit_queue->push(node); } } void KernelGraph::SetExecOrderByDefault() { std::queue seed_nodes; UpdateNodeEdgeList(&seed_nodes); execution_order_.clear(); std::unordered_set visited_nodes; std::queue zero_input_nodes; AnfNodePtr last_communication_node = nullptr; std::queue communication_descendants; while (!seed_nodes.empty() || last_communication_node != nullptr) { // seed nodes first, then visit last all reduce node descendant if (seed_nodes.empty()) { 
VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes); last_communication_node = nullptr; } else { zero_input_nodes.push(seed_nodes.front()); seed_nodes.pop(); } // all reduce node descendant first, then common queue while (!zero_input_nodes.empty() || !communication_descendants.empty()) { AnfNodePtr node = nullptr; bool is_communication_descendant = false; if (communication_descendants.empty()) { node = zero_input_nodes.front(); zero_input_nodes.pop(); } else { node = communication_descendants.front(); communication_descendants.pop(); is_communication_descendant = true; } // add execute node MS_EXCEPTION_IF_NULL(node); if (node->isa() && AnfAlgo::IsRealKernel(node)) { execution_order_.push_back(node->cast()); } // for all reduce node, visit last all reduce node descendant if (AnfAlgo::IsCommunicationOp(node)) { if (last_communication_node != nullptr) { VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes); } last_communication_node = node; } else if (is_communication_descendant) { VisitNodeDescendants(node, &communication_descendants, &visited_nodes); } else { VisitNodeDescendants(node, &zero_input_nodes, &visited_nodes); } } } CheckLoop(); // resort start label / end goto std::vector re_order; if (start_label_ != nullptr) { re_order.push_back(start_label_); } for (auto &node : execution_order_) { if (node == start_label_ || node == end_goto_) { continue; } if (IsSameLabel(node, end_goto_)) { end_goto_ = node; MS_LOG(INFO) << "Replace end_goto_ in kernel graph:" << graph_id(); continue; } if (IsSameLabel(node, start_label_)) { start_label_ = node; MS_LOG(INFO) << "Replace start_label_ in kernel graph:" << graph_id(); continue; } re_order.push_back(node); } if (end_goto_ != nullptr) { re_order.push_back(end_goto_); } execution_order_ = re_order; } void KernelGraph::CheckLoop() { std::map none_zero_nodes; if (node_input_edges_.size() != node_input_num_.size()) { MS_LOG(EXCEPTION) << 
"node_input_edges_ size :" << node_input_edges_.size() << "not equal to node_input_num_ size:" << node_input_num_.size(); } for (auto &it : node_input_num_) { MS_EXCEPTION_IF_NULL(it.first); string str; auto node_input_it = node_input_edges_.find(it.first); if (node_input_it == node_input_edges_.end()) { MS_LOG(EXCEPTION) << "Can't find node [" << it.first->DebugString() << "]"; } for (const auto &input_edge : node_input_edges_[it.first]) { MS_EXCEPTION_IF_NULL(input_edge.first); str = str.append(input_edge.first->DebugString()).append("|"); } if (it.second != 0) { MS_LOG(WARNING) << "Node:" << it.first->DebugString() << ",inputs:" << str << ",input num:" << it.second; none_zero_nodes[it.first] = it.second; } } // if don't consider control depend and loop exit,a exception will be throw if (!none_zero_nodes.empty()) { MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size(); } } CNodePtr KernelGraph::NewCNode(const std::vector &inputs) { auto cnode = FuncGraph::NewCNode(inputs); MS_EXCEPTION_IF_NULL(cnode); cnode->set_abstract(std::make_shared()); CreateKernelInfoFromNewParameter(cnode); if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) { AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode); } SetKernelInfoForNode(cnode); AnfAlgo::SetGraphId(graph_id_, cnode.get()); return cnode; } void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) { if (!AnfAlgo::IsGraphKernel(cnode)) { return; } auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode); MS_EXCEPTION_IF_NULL(func_graph); std::vector node_list; std::vector input_list; std::vector output_list; kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); for (auto &anf_node : node_list) { MS_EXCEPTION_IF_NULL(anf_node); auto kernel_info = std::make_shared(); anf_node->set_kernel_info(kernel_info); auto anf_cnode = anf_node->cast(); MS_EXCEPTION_IF_NULL(anf_cnode); for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_cnode); ++i) { auto 
input_node = anf_cnode->input(i + 1); MS_EXCEPTION_IF_NULL(input_node); if (IsValueNode(input_node)) { auto new_input_node = MakeValueNode(input_node); if (new_input_node != nullptr) { anf_cnode->set_input(i + 1, new_input_node); } } } } for (auto &anf_node : input_list) { MS_EXCEPTION_IF_NULL(anf_node); auto kernel_info = std::make_shared(); anf_node->set_kernel_info(kernel_info); } } void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const { MS_EXCEPTION_IF_NULL(node); auto kernel_info = std::make_shared(); node->set_kernel_info(kernel_info); if (node->isa()) { std::vector feature_map_input_indexs; kernel_info->SetFeatureMapFlag(false); for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(node); ++index) { if (AnfAlgo::IsFeatureMapInput(node, index)) { kernel_info->SetFeatureMapFlag(true); feature_map_input_indexs.push_back(index); } } if (AnfAlgo::GetInputTensorNum(node) == 0) { kernel_info->SetFeatureMapFlag(true); } if (AnfAlgo::IsRealKernel(node)) { // if the node only has the primitive(such as getNext) or the node's input has a feature map input // then the node's output is a feature map output AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), node); AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), node); } return; } auto kernel_build_info_builder = std::make_shared(); // set the format of value_node to DEFAULT_FORMAT std::vector types; kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); if (node->isa()) { kernel_info->SetFeatureMapFlag(false); types.emplace_back(kTypeUnknown); } if (node->isa()) { auto parameter = node->cast(); MS_EXCEPTION_IF_NULL(parameter); bool is_weight = AnfAlgo ::IsParameterWeight(parameter); kernel_info->SetFeatureMapFlag(!is_weight); types.push_back(is_weight ? 
kTypeUnknown : AnfAlgo::GetOutputInferDataType(parameter, 0)); } // set parameter initaial device data type kernel_build_info_builder->SetOutputsDeviceType(types); AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), node.get()); } CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); auto new_cnode = std::make_shared(*cnode); // if a cnode is created not from front,this cnode won't be in map,so when replace it,we shouldn't update map if (BackendNodeExistInFrontBackendMap(cnode)) { FrontBackendlMapUpdate(cnode, new_cnode); } AnfAlgo::SetGraphId(graph_id_, cnode.get()); if (IsInternalOutput(cnode)) { ReplaceInternalOutput(cnode, new_cnode); } return new_cnode; } ParameterPtr KernelGraph::NewParameter(const ParameterPtr ¶meter) { auto abstract = parameter == nullptr ? std::make_shared() : parameter->abstract(); auto new_parameter = NewParameter(abstract); MS_EXCEPTION_IF_NULL(new_parameter); // if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter if (parameter != nullptr) { new_parameter->set_name(parameter->name()); if (AnfAlgo::IsParameterWeight(parameter)) { new_parameter->set_default_param(parameter->default_param()); } } // create kernel_info form new parameter SetKernelInfoForNode(new_parameter); AnfAlgo::SetGraphId(graph_id_, new_parameter.get()); return new_parameter; } ParameterPtr KernelGraph::NewParameter(const abstract::AbstractBasePtr &abstract) { ParameterPtr new_parameter = add_parameter(); new_parameter->set_abstract(abstract); MS_EXCEPTION_IF_NULL(new_parameter); // create kernel_info form new parameter SetKernelInfoForNode(new_parameter); AnfAlgo::SetGraphId(graph_id_, new_parameter.get()); return new_parameter; } std::vector KernelGraph::SplitTupleParameterToNodeList(const ParameterPtr ¶meter) { MS_EXCEPTION_IF_NULL(parameter); std::vector convert_nodes_list; auto abstract = parameter->abstract(); MS_EXCEPTION_IF_NULL(abstract); if (!abstract->isa()) { 
MS_LOG(EXCEPTION) << "Multiple output Parameter's output must be a tuple abstract but got " << abstract->ToString(); } auto tuple_abstract = abstract->cast(); MS_EXCEPTION_IF_NULL(tuple_abstract); for (size_t index = 0; index < tuple_abstract->size(); ++index) { auto new_parameter = this->NewParameter((*tuple_abstract)[index]); SetKernelInfoForNode(new_parameter); convert_nodes_list.emplace_back(new_parameter); } auto new_inputs = std::make_shared>(); auto old_inputs = inputs(); for (const auto &input_node : old_inputs) { if (input_node != parameter) { new_inputs->emplace_back(input_node); continue; } std::copy(convert_nodes_list.begin(), convert_nodes_list.end(), std::back_inserter(*new_inputs)); } inputs_ = new_inputs; return convert_nodes_list; } std::vector KernelGraph::SplitTupleOutputNodeToNodeList(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); if (node->isa()) { MS_LOG(EXCEPTION) << "The function can only split a parameter or valuenode bug got " << node->DebugString(); } if (node->isa()) { return SplitTupleParameterToNodeList(node->cast()); } return SplitTupleValueNodeToNodeList(node->cast()); } std::vector KernelGraph::SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node) { MS_EXCEPTION_IF_NULL(value_node); auto node_value = value_node->value(); std::vector convert_inputs; if (!node_value->isa()) { MS_LOG(EXCEPTION) << "Multiple output valuenode's value must be a value tuple but got " << node_value->ToString(); } auto value_tuple = node_value->cast(); MS_EXCEPTION_IF_NULL(value_tuple); auto abstract = value_node->abstract(); if (!abstract->isa()) { MS_LOG(EXCEPTION) << "Spilted node's output abstract is not type tuple"; } auto tuple_abstract = abstract->cast(); MS_EXCEPTION_IF_NULL(tuple_abstract); if (tuple_abstract->size() != value_tuple->size()) { MS_LOG(EXCEPTION) << "The node output index [" << value_tuple->size() << "]is outof range " << tuple_abstract->size(); } for (size_t index = 0; index < value_tuple->value().size(); ++index) { 
auto new_value_node = std::make_shared(value_tuple->value()[index]); new_value_node->set_abstract((*tuple_abstract)[index]); AddValueNodeToGraph(new_value_node); SetKernelInfoForNode(new_value_node); AnfAlgo::SetGraphId(graph_id_, new_value_node.get()); convert_inputs.emplace_back(new_value_node); } if (!RemoveValueNodeFromGraph(value_node)) { MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString(); } return convert_inputs; } ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) { MS_EXCEPTION_IF_NULL(value_node); auto new_value_node = MakeValueNode(value_node)->cast(); AnfAlgo::SetGraphId(graph_id_, new_value_node.get()); return new_value_node; } const std::vector &KernelGraph::inputs() const { MS_EXCEPTION_IF_NULL(inputs_); return *inputs_; } void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf) { MS_EXCEPTION_IF_NULL(front_anf); MS_EXCEPTION_IF_NULL(backend_anf); if (front_backend_anf_map_.find(front_anf) != front_backend_anf_map_.end()) { MS_LOG(EXCEPTION) << "Anf " << front_anf->DebugString() << " has been exist in the front_backend_anf_map_"; } if (backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end()) { MS_LOG(EXCEPTION) << "Kernel " << backend_anf->DebugString() << "has been exist in the backend_front_anf_map_"; } front_backend_anf_map_[front_anf] = backend_anf; backend_front_anf_map_[backend_anf] = front_anf; } void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf) { MS_EXCEPTION_IF_NULL(old_backend_anf); MS_EXCEPTION_IF_NULL(new_backend_anf); if (old_backend_anf == new_backend_anf) { MS_LOG(DEBUG) << "Old same with new:" << old_backend_anf->DebugString(); return; } if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) { MS_LOG(DEBUG) << "Old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map"; return; } if 
(front_backend_anf_map_.find(backend_front_anf_map_[old_backend_anf]) == front_backend_anf_map_.end()) { MS_LOG(EXCEPTION) << "Anf is not exist in the map ,old " << old_backend_anf->DebugString(); } front_backend_anf_map_[backend_front_anf_map_[old_backend_anf]] = new_backend_anf; backend_front_anf_map_[new_backend_anf] = backend_front_anf_map_[old_backend_anf]; // delete old kernel (void)backend_front_anf_map_.erase(old_backend_anf); } // get kernel by anf AnfNodePtr KernelGraph::GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf) { if (front_backend_anf_map_.find(front_anf) == front_backend_anf_map_.end()) { return nullptr; } return front_backend_anf_map_[front_anf]; } bool KernelGraph::BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf) { return backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end(); } ValueNodePtr KernelGraph::GetValueNodeByTensor(const mindspore::tensor::TensorPtr &tensor) { if (tensor_to_value_node_map_.find(tensor) == tensor_to_value_node_map_.end()) { return nullptr; } return tensor_to_value_node_map_[tensor]; } void KernelGraph::TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node) { MS_EXCEPTION_IF_NULL(tensor); MS_EXCEPTION_IF_NULL(value_node); tensor_to_value_node_map_[tensor] = value_node; } void KernelGraph::AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(input); MS_LOG(DEBUG) << "Input:" << input->DebugString() << ", node:" << node->DebugString() << ",num:" << depend_edge_num; auto output_depend_edge = std::pair(node, depend_edge_num); // add output depend edge of input auto output_it = node_output_edges_.find(input); if (output_it == node_output_edges_.end()) { node_output_edges_[input] = std::vector>{output_depend_edge}; } else { output_it->second.push_back(output_depend_edge); } // add input depend edge of output auto input_depend_edge = std::pair(input, depend_edge_num); auto 
input_it = node_input_edges_.find(node); if (input_it == node_input_edges_.end()) { node_input_edges_[node] = std::vector>{input_depend_edge}; } else { input_it->second.push_back(input_depend_edge); } // add node input depend num auto depend_it = node_input_num_.find(node); if (depend_it == node_input_num_.end()) { node_input_num_[node] = depend_edge_num; } else { depend_it->second += depend_edge_num; } } std::vector KernelGraph::GetOutputNodes(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); auto it = node_output_edges_.find(node); if (it == node_output_edges_.end()) { MS_LOG(EXCEPTION) << "Can't find node[" << node->DebugString() << "]"; } std::vector output_nodes; auto trans = [](const std::pair &pair) -> AnfNodePtr { return pair.first; }; (void)std::transform(it->second.begin(), it->second.end(), std::back_inserter(output_nodes), trans); return output_nodes; } // Find control_depend real input nodes. void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector *result, std::set *visited) { MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(result); MS_EXCEPTION_IF_NULL(visited); if (visited->find(anf_node) != visited->end()) { MS_LOG(WARNING) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited"; return; } visited->insert(anf_node); if (AnfAlgo::IsRealKernel(anf_node)) { result->emplace_back(anf_node); return; } if (!anf_node->isa()) { return; } auto cnode = anf_node->cast(); MS_EXCEPTION_IF_NULL(cnode); if (cnode->inputs().empty()) { MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString(); } auto input0 = cnode->input(0); if (IsPrimitive(input0, prim::kPrimMakeTuple)) { for (size_t i = 1; i < cnode->inputs().size(); ++i) { GetAllFatherRealNode(cnode->input(i), result, visited); } } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { if (cnode->inputs().size() != kTupleGetItemInputSize) { MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!"; } 
GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited); } else if (IsPrimitive(input0, prim::kPrimDepend)) { if (cnode->inputs().size() != kDependInputSize) { MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!"; } GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited); GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited); } } // update the depend relations of control depend void KernelGraph::UpdateControlDependRelations(const std::vector &depends) { for (const auto &node : depends) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { return; } auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) { MS_LOG(EXCEPTION) << node->DebugString() << " is not a control depend"; } auto prior_node = cnode->input(kControlDependPriorIndex); auto depend_node = cnode->input(kControlDependBehindIndex); MS_EXCEPTION_IF_NULL(prior_node); MS_EXCEPTION_IF_NULL(depend_node); std::vector prior_nodes = {prior_node}; std::vector depend_nodes = {depend_node}; int depend_mode = 0; if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) { depend_mode = AnfAlgo::GetNodeAttr(cnode, kControlDependMode); } MS_LOG(DEBUG) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString() << "], depend_mode :" << depend_mode << "."; if (prior_node->isa() && depend_mode == 1) { prior_nodes = GetOutputNodes(prior_node); } if (depend_node->isa()) { depend_nodes = depend_mode == 1 ? 
GetOutputNodes(depend_node) : std::vector{}; } std::vector real_prior_nodes; std::set prior_visited; for (const auto &tmp : prior_nodes) { GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); } std::vector real_depend_nodes; std::set depend_visited; for (const auto &tmp : depend_nodes) { GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited); } for (auto &first_node : real_prior_nodes) { if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) { continue; } for (auto &second_node : real_depend_nodes) { if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) { continue; } MS_EXCEPTION_IF_NULL(first_node); MS_EXCEPTION_IF_NULL(second_node); MS_LOG(DEBUG) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString(); AddDependEdge(second_node, first_node, 1); } } } } bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue *que, std::unordered_set *visited_nodes) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(que); MS_EXCEPTION_IF_NULL(visited_nodes); if (!node->isa()) { return false; } auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) { return false; } // set the control depend visited but don't push it into the que if (visited_nodes->find(node) != visited_nodes->end()) { return true; } (void)visited_nodes->insert(cnode); // add a 0 depend num to keep the link relations to prepare for finding zero output nodes auto prior_node = cnode->input(kControlDependPriorIndex); auto depend_node = cnode->input(kControlDependBehindIndex); for (const auto &input : cnode->inputs()) { AddDependEdge(node, input, 0); } PushNoVisitedNode(depend_node, que, visited_nodes); PushNoVisitedNode(prior_node, que, visited_nodes); return true; } void KernelGraph::UpdateNodeEdgeList(std::queue *seed_nodes) { MS_EXCEPTION_IF_NULL(seed_nodes); node_output_edges_.clear(); node_input_num_.clear(); node_input_edges_.clear(); 
std::vector control_depends; std::unordered_set visited_nodes; std::queue que; que.push(get_return()); while (!que.empty()) { auto node = que.front(); que.pop(); MS_EXCEPTION_IF_NULL(node); if (node->isa() || node->isa()) { seed_nodes->push(node); continue; } if (!node->isa()) { continue; } auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); // handle data links for (const auto &input : cnode->inputs()) { size_t depend_edge_num = 1; // handle control depend,all inputs of control depend has no depend edge if (HandleControlDependNode(input, &que, &visited_nodes)) { control_depends.push_back(input); depend_edge_num = 0; } PushNoVisitedNode(input, &que, &visited_nodes); AddDependEdge(node, input, depend_edge_num); } } UpdateControlDependRelations(control_depends); } void KernelGraph::AddValueNodeToGraph(const ValueNodePtr &value_node) { (void)graph_value_nodes_.insert(value_node); } bool KernelGraph::IsInRefOutputMap(const AnfWithOutIndex &pair) const { return ref_out_in_map_.count(pair) != 0; } AnfWithOutIndex KernelGraph::GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const { if (!IsInRefOutputMap(out_pair)) { MS_LOG(EXCEPTION) << "Out_pair is not in RefOutputMap"; } return ref_out_in_map_.at(out_pair); } void KernelGraph::AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair) { if (IsInRefOutputMap(final_pair)) { MS_LOG(EXCEPTION) << "Out_pair is already in RefOutputMap"; } (void)ref_out_in_map_.insert(std::make_pair(final_pair, origin_pair)); } bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) { if (graph_value_nodes_.find(value_node) != graph_value_nodes_.end()) { (void)graph_value_nodes_.erase(value_node); return true; } return false; } void KernelGraph::ReplaceNode(NotNull old_anf_node, NotNull new_anf_node) { MS_EXCEPTION_IF_NULL(inputs_); { std::queue seed_nodes; UpdateNodeEdgeList(&seed_nodes); } auto it = node_output_edges_.find(old_anf_node); if (it != node_output_edges_.end()) { 
    // ---- continuation of KernelGraph::ReplaceNode ----
    const auto &outputs = it->second;
    for (auto &output_node : outputs) {
      MS_EXCEPTION_IF_NULL(output_node.first);
      auto output_cnode = output_node.first->cast();
      MS_EXCEPTION_IF_NULL(output_cnode);
      auto &output_node_inputs = output_cnode->inputs();
      // don't replace node if it is a control edge => output_node.second == 0
      if (output_node.second == 0) {
        continue;
      }
      for (size_t i = 1; i < output_node_inputs.size(); i++) {
        if (output_node_inputs[i] == old_anf_node.get()) {
          output_cnode->set_input(i, new_anf_node);
        }
      }
      // update graph inputs
      // NOTE(review): this graph-input scan runs once per consumer node; hoisting it out of
      // the outer loop would be equivalent — left as-is to preserve behavior.
      for (size_t i = 0; i < inputs_->size(); i++) {
        if ((*inputs_)[i] == old_anf_node.get()) {
          MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_anf_node->DebugString()
                       << ",new graph input:" << new_anf_node->DebugString();
          (*inputs_)[i] = new_anf_node.get();
          break;
        }
      }
    }
    // update front to backend map
    FrontBackendlMapUpdate(old_anf_node, new_anf_node);
  }
  {
    std::queue seed_nodes;
    UpdateNodeEdgeList(&seed_nodes);  // rebuild edge lists after rewriting inputs
  }
  // update graph inputs in child graph
  auto it_real_inputs =
    std::find_if(real_inputs_.begin(), real_inputs_.end(),
                 [&old_anf_node](const std::pair> &n) -> bool { return n.first == old_anf_node.get(); });
  if (it_real_inputs != real_inputs_.end()) {
    // erase old parameter in map
    auto old_args = it_real_inputs->second;
    real_inputs_.erase(it_real_inputs);
    // insert new parameter to map
    auto iter =
      std::find_if(real_inputs_.begin(), real_inputs_.end(),
                   [&new_anf_node](const std::pair> &n) -> bool { return n.first == new_anf_node.get(); });
    if (iter != real_inputs_.end()) {
      MS_LOG(WARNING) << new_anf_node->DebugString() << " Already exist in real inputs, will be rewrited.";
      iter->second = old_args;
    } else {
      real_inputs_.emplace_back(new_anf_node, old_args);
    }
  }
}

// Stamp every kernel in the execution order with this graph's stream distinction label.
void KernelGraph::UpdateExecuteKernelStreamLabel() {
  for (auto &kernel : execution_order_) {
    AnfAlgo::SetStreamDistinctionLabel(stream_distinction_label_, kernel.get());
  }
}

// Collect the leaf graphs (graphs without child graphs) reachable from this graph, in
// child-graph order; a leaf graph returns just itself.
std::vector> KernelGraph::GetLeafGraphOrder() {
  std::vector>
    // ---- continuation of KernelGraph::GetLeafGraphOrder ----
    leaf_graph_order;
  if (IsLeafGraph()) {
    leaf_graph_order.push_back(shared_from_this()->cast());
  } else {
    for (const auto &child_graph : child_graph_order_) {
      MS_EXCEPTION_IF_NULL(child_graph);
      auto child_leaf_graph_order = child_graph->GetLeafGraphOrder();
      std::copy(child_leaf_graph_order.begin(), child_leaf_graph_order.end(), std::back_inserter(leaf_graph_order));
    }
  }
  return leaf_graph_order;
}

// A leaf graph is one without child graphs.
bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); }

// Collect the cnodes in this graph's execution order whose primitive matches `primitive`
// and which belong to this graph id.
std::vector KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primitive) const {
  std::vector result;
  for (const auto &anf : execution_order_) {
    if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
      result.push_back(anf->cast());
    }
  }
  return result;
}

// Append `arg` to the real-input argument list of `parameter`, creating the entry on first use.
void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg) {
  MS_EXCEPTION_IF_NULL(parameter);
  MS_EXCEPTION_IF_NULL(arg);
  MS_LOG(INFO) << "Parameter: " << parameter->DebugString() << ", real input : " << arg->DebugString();
  // NOTE(review): the two null checks below duplicate the ones above; harmless but redundant.
  MS_EXCEPTION_IF_NULL(parameter);
  MS_EXCEPTION_IF_NULL(arg);
  auto iter = std::find_if(
    real_inputs_.begin(), real_inputs_.end(),
    [&parameter](const std::pair> &n) -> bool { return n.first == parameter; });
  if (iter != real_inputs_.end()) {
    auto &args = iter->second;
    args.push_back(arg);
  } else {
    real_inputs_.emplace_back(parameter, std::vector(1, arg));
  }
}

// Remember which graph `arg` came from so callers can exclude it from reuse.
// NOTE(review): exact reuse semantics live with the consumers of unreuse_args_ — confirm there.
void KernelGraph::AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr &from_graph) {
  unreuse_args_[arg] = from_graph;
}

// Rewrite real inputs whose argument is a Call node into the called child graphs' real
// outputs, keeping unreuse_args_ consistent with the substitution.
void KernelGraph::UpdateCallRealInput() {
  MS_LOG(INFO) << "Update graph id: " << graph_id_;
  std::vector>> real_inputs_map;
  for (auto &it : real_inputs_) {
    auto parameter = it.first;
    MS_EXCEPTION_IF_NULL(parameter);
    auto real_inputs = it.second;
    std::vector new_real_inputs;
    for (auto &real_input : real_inputs) {
      // if real input is a call node, find the child graph output to act as the new real input
      auto tmp_real_input = GetCallRealOutputs(real_input);
      std::copy(tmp_real_input.begin(),
                // ---- continuation of KernelGraph::UpdateCallRealInput ----
                tmp_real_input.end(), std::back_inserter(new_real_inputs));
      // replace the call in unreuse_args_
      auto unreuse_arg_it = unreuse_args_.find(real_input);
      if (unreuse_arg_it != unreuse_args_.end()) {
        auto old_graph = unreuse_arg_it->second;
        for (auto new_real_input : new_real_inputs) {
          // if call reference graph output is parameter, it will be allowed to reuse
          if (!new_real_input->isa()) {
            unreuse_args_[new_real_input] = old_graph;
          }
        }
      }
    }
    real_inputs_map.emplace_back(parameter, new_real_inputs);
  }
  real_inputs_ = real_inputs_map;
}

// Log the execution order: one INFO line per kernel with its logic (stream distinction)
// label, stream id and, when the attributes are present, event id / label id(s).
void KernelGraph::PrintGraphExecuteOrder() const {
  MS_LOG(INFO) << "Graph:" << graph_id_ << "execution order";
  for (size_t i = 0; i < execution_order_.size(); i++) {
    CNodePtr cur_cnode_ptr = execution_order_[i];
    MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
    std::string event_str;
    std::string label_str;
    if (AnfAlgo::HasNodeAttr(kAttrEventId, cur_cnode_ptr)) {
      event_str = ", event_id[" + std::to_string(AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrEventId)) + "]";
    }
    if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_cnode_ptr)) {
      label_str = ", label_id[" + std::to_string(AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrLabelIndex)) + "]";
    }
    if (AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cur_cnode_ptr)) {
      auto label_list = AnfAlgo::GetNodeAttr>(cur_cnode_ptr, kAttrLabelSwitchList);
      label_str = ", label_id[";
      for (size_t j = 0; j < label_list.size(); ++j) {
        // Join the switch-label list with ", ", closing with "]" after the last element.
        label_str += std::to_string(label_list[j]) + (j + 1 < label_list.size() ?
                                                     // ---- continuation of KernelGraph::PrintGraphExecuteOrder ----
                                                     ", " : "]");
      }
    }
    MS_LOG(INFO) << "Index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id["
                 << AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
                 << AnfAlgo::GetStreamId(cur_cnode_ptr) << "], node info[" << cur_cnode_ptr->DebugString() << "]"
                 << event_str << label_str;
  }
}

// Record `node` as the backend internal output for `front_node`; the output index is taken
// from the TupleGetItem index when the front node is a TupleGetItem, otherwise 0.
void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node) {
  if (front_node == nullptr || node == nullptr) {
    MS_LOG(INFO) << "Front node or node is nullptr";
    return;
  }
  MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString();
  front_to_internal_outputs_map_[front_node] = node;
  int output_idx = 0;
  if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
    output_idx = AnfAlgo::GetTupleGetItemOutIndex(front_node->cast());
  }
  internal_outputs_to_front_map_[node][output_idx] = front_node;
}

// Re-point internal-output bookkeeping from `node` to `new_node`.
// src_output_idx == -1 moves every recorded front node mapping; otherwise only the mapping
// for src_output_idx is moved (stored under dst_output_idx on the new node).
// No-ops (with an INFO log) when either node is null, they are the same node, or `node`
// is not a recorded internal output.
void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx,
                                        int dst_output_idx) {
  if (new_node == nullptr || node == nullptr) {
    MS_LOG(INFO) << "New node or node is nullptr";
    return;
  }
  if (node == new_node) {
    MS_LOG(INFO) << "New node and node is the same";
    return;
  }
  auto iter = internal_outputs_to_front_map_.find(node);
  if (iter == internal_outputs_to_front_map_.end()) {
    MS_LOG(INFO) << "Node is not internal output";
    return;
  }
  MS_LOG(INFO) << "Replace internal node " << node->DebugString() << " To " << new_node->DebugString();
  auto &front_nodes = iter->second;
  // Move all front nodes to new node mapping
  if (src_output_idx == -1) {
    internal_outputs_to_front_map_[new_node] = front_nodes;
    for (const auto &front_node_iter : front_nodes) {
      front_to_internal_outputs_map_[front_node_iter.second] = new_node;
    }
    internal_outputs_to_front_map_.erase(iter);
    return;
  }
  // Move specified front node to new node mapping
  int index = SizeToInt(src_output_idx);
  auto front_node_iter = front_nodes.find(index);
  if (front_node_iter ==
      // ---- continuation of KernelGraph::ReplaceInternalOutput ----
      front_nodes.end()) {
    MS_LOG(INFO) << "The output " << src_output_idx << " of node " << node->DebugString() << " is not an internal node";
    return;
  }
  auto front_node = front_node_iter->second;
  internal_outputs_to_front_map_[new_node][dst_output_idx] = front_node;
  front_to_internal_outputs_map_[front_node] = new_node;
  front_nodes.erase(index);
  if (front_nodes.empty()) {
    internal_outputs_to_front_map_.erase(iter);
  }
}

// Return the backend internal-output node recorded for `front_node`, or nullptr if none.
AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const {
  auto iter = front_to_internal_outputs_map_.find(front_node);
  if (iter != front_to_internal_outputs_map_.end()) {
    return iter->second;
  }
  return nullptr;
}

// True when `node` is recorded as an internal output; output_idx == -1 matches any index.
bool KernelGraph::IsInternalOutput(const AnfNodePtr &node, int output_idx) const {
  auto front_nodes_iter = internal_outputs_to_front_map_.find(node);
  if (front_nodes_iter != internal_outputs_to_front_map_.end()) {
    if (output_idx == -1) {
      return true;
    }
    auto &front_nodes = front_nodes_iter->second;
    if (front_nodes.find(output_idx) != front_nodes.end()) {
      return true;
    }
  }
  return false;
}

// Recompute child_graph_order_ from this graph's Call nodes, setting this graph as the
// parent of every called child graph (this graph's own parent is excluded from re-parenting
// but still recorded in the order).
void KernelGraph::UpdateChildGraphOrder() {
  MS_LOG(INFO) << "Update " << ToString() << " child graph order.";
  SetExecOrderByDefault();
  auto call_nodes = FindNodeByPrimitive(std::make_shared(prim::kPrimCall->name()));
  std::vector child_graph_order;
  for (auto &call_node : call_nodes) {
    MS_EXCEPTION_IF_NULL(call_node);
    auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast());
    for (const auto &child_graph : call_child_graphs) {
      MS_EXCEPTION_IF_NULL(child_graph);
      if (child_graph != parent_graph_) {
        auto shared_this = std::dynamic_pointer_cast(shared_from_this());
        MS_EXCEPTION_IF_NULL(shared_this);
        child_graph->set_parent_graph(shared_this);
      }
      child_graph_order.push_back(child_graph);
    }
  }
  for (size_t i = 0; i < child_graph_order.size(); ++i) {
    MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph_order[i]->graph_id() << "]";
  }
  child_graph_order_ = child_graph_order;
}

// Human-readable graph name: "kernel_graph_<id>".
std::string KernelGraph::ToString() const {
  // ---- continuation of KernelGraph::ToString ----
  return std::string("kernel_graph_").append(std::to_string(graph_id_));
}

// Release all device runtime resources registered under this graph id.
KernelGraph::~KernelGraph() { device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_); }
}  // namespace session
}  // namespace mindspore