/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/common/helper.h"
// NOTE(review): the six system-header targets below were stripped by the same
// angle-bracket mangling that removed the template arguments; restored to the
// standard-library names this file actually uses — confirm against upstream.
#include <string>
#include <utility>
#include <algorithm>
#include <map>
#include <set>
#include <deque>
#include "utils/hash_set.h"
#include "utils/utils.h"
#include "base/base_ref.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "base/core_ops.h"
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
#include "frontend/operator/ops.h"
#include "utils/ms_utils.h"
#include "utils/convert_utils.h"
#include "runtime/device/kernel_info.h"
#include "utils/ms_context.h"
#include "utils/trace_base.h"
#include "backend/optimizer/common/const_input_to_attr_registry.h"
#include "abstract/primitive_infer_map.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t kType32Len = 4;
constexpr size_t kType64Len = 8;

// Propagate the dump flag and the fused debug info from the original real
// cnode kernels onto the newly created fused node.
void UpdateDumpFlagAndDebugInfo(const CNodePtr &node, const std::vector<AnfNodePtr> &orig_nodes) {
  std::vector<AnfNodePtr> orig_real_cnodes;
  for (auto &orig_node : orig_nodes) {
    if (AnfUtils::IsRealCNodeKernel(orig_node)) {
      auto orig_cnode = orig_node->cast<CNodePtr>();
      if (AnfAlgo::HasNodeAttr(kAttrDump, orig_cnode)) {
        AnfAlgo::CopyNodeAttr(kAttrDump, orig_cnode, node);
      }
      orig_real_cnodes.push_back(orig_node);
    }
  }
  node->AddFusedDebugInfoList(orig_real_cnodes);
}
}  // namespace

// Convert a vector of size_t to a vector of int.
std::vector<int> Convert2Int(const std::vector<size_t> &v) {
  std::vector<int> result;
  (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToInt);
  return result;
}

// Convert a vector of size_t to a vector of int64_t.
std::vector<int64_t> Convert2Long(const std::vector<size_t> &v) {
  std::vector<int64_t> result;
  (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToLong);
  return result;
}

// BFS over the input edges of `node`: returns true if any node in `nodes` is
// reachable from `node` (i.e. `node` depends on it).
bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<AnfNodePtr> &nodes) {
  MS_EXCEPTION_IF_NULL(node);
  FuncGraphManagerPtr manager = graph.manager();
  MS_EXCEPTION_IF_NULL(manager);
  mindspore::HashSet<AnfNodePtr> seen_node;
  std::deque<AnfNodePtr> todo{node};
  while (!todo.empty()) {
    AnfNodePtr nd = todo.front();
    todo.pop_front();
    // Skip already-visited nodes and nodes no longer managed by this graph.
    if (seen_node.count(nd) > 0 || !manager->all_nodes().contains(nd)) {
      continue;
    }
    (void)seen_node.insert(nd);
    if (std::any_of(nodes.begin(), nodes.end(), [&nd](const AnfNodePtr &item) { return nd == item; })) {
      return true;
    }
    if (nd->isa<CNode>()) {
      auto cnode = nd->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      auto inputs = cnode->inputs();
      (void)todo.insert(todo.end(), inputs.begin(), inputs.end());
    }
  }
  return false;
}

// Pattern predicate: true when the primitive/func_graph value node has NOT
// been marked with the kAttrVisited flag yet.
bool UnVisited(const BaseRef &n) {
  if (utils::isa<AnfNodePtr>(n)) {
    AnfNodePtr in = utils::cast<AnfNodePtr>(n);
    MS_EXCEPTION_IF_NULL(in);
    if (IsValueNode<Primitive>(in)) {
      auto value_node = in->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      auto value = value_node->value();
      MS_EXCEPTION_IF_NULL(value);
      auto prim_py = value->cast<PrimitivePtr>();
      MS_EXCEPTION_IF_NULL(prim_py);
      return !prim_py->HasAttr(kAttrVisited);
    } else if (IsValueNode<FuncGraph>(in)) {
      auto func_graph = GetValueNode<FuncGraphPtr>(in);
      MS_EXCEPTION_IF_NULL(func_graph);
      return !func_graph->has_flag(kAttrVisited);
    }
    return false;
  }
  return false;
}

// Create a new CNode in `fg` and inherit dump flag / debug info from the
// original nodes it replaces.
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg,
                  const std::vector<AnfNodePtr> &orig_nodes) {
  MS_EXCEPTION_IF_NULL(fg);
  auto node = fg->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(node);
  UpdateDumpFlagAndDebugInfo(node, orig_nodes);
  return node;
}

// Overload: clone an existing CNode into the kernel graph, inheriting dump
// flag / debug info from the original nodes.
CNodePtr NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg, const std::vector<AnfNodePtr> &orig_nodes) {
  MS_EXCEPTION_IF_NULL(fg);
  auto node = fg->NewCNode(cnode);
  MS_EXCEPTION_IF_NULL(node);
  UpdateDumpFlagAndDebugInfo(node, orig_nodes);
  return node;
}

// Cast `node` to CNode and verify its input tensor count; throws on mismatch.
CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, size_t input_size) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "The node is expected to be a cnode";
  }
  auto cnode = node->cast<CNodePtr>();
  CheckCNodeInputSize(cnode, input_size);
  return cnode;
}

// Throw (with source trace) if the cnode's real input tensor count differs
// from the expected value.
void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_tensor_size) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode);
  if (real_input_tensor_num != input_tensor_size) {
    MS_LOG(EXCEPTION) << "The input tensor size[" << real_input_tensor_num
                      << "] of node [" + cnode->DebugString() + "] is not equal to " << input_tensor_size
                      << trace::DumpSourceLines(cnode);
  }
}

// True when x's input dtype equals y's output dtype and vice versa (the two
// kernels are data-type symmetric on port 0).
bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y) {
  MS_EXCEPTION_IF_NULL(node_x);
  MS_EXCEPTION_IF_NULL(node_y);
  return (AnfAlgo::GetInputDeviceDataType(node_x, 0) == AnfAlgo::GetOutputDeviceDataType(node_y, 0) &&
          AnfAlgo::GetOutputDeviceDataType(node_x, 0) == AnfAlgo::GetInputDeviceDataType(node_y, 0));
}

// Rewrite trans_op(depend(trans_op(x), attach)) into depend(x, attach),
// eliminating the back-to-back transpose/transdata pair around the Depend.
const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputTensorNum);
  MS_EXCEPTION_IF_NULL(transop_cnode);
  auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(1), kDependInputTensorNum);
  auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputTensorNum);
  auto transed_node = prev_transop_cnode->input(1);
  MS_EXCEPTION_IF_NULL(transed_node);
  std::vector<AnfNodePtr> replace_depend_inputs{NewValueNode(prim::kPrimDepend), transed_node,
                                                depend_cnode->input(kDependAttachNodeIndex)};
  AnfNodePtr replace_depend = func_graph->NewCNode(replace_depend_inputs);
  MS_EXCEPTION_IF_NULL(replace_depend);
  auto transed_abstract = transed_node->abstract();
  replace_depend->set_abstract(transed_abstract);
  return replace_depend;
}

// Pattern predicate: dual of UnVisited — true when the kAttrVisited flag IS
// already present on the primitive/func_graph value node.
bool Visited(const BaseRef &n) {
  if (utils::isa<AnfNodePtr>(n)) {
    AnfNodePtr in = utils::cast<AnfNodePtr>(n);
    MS_EXCEPTION_IF_NULL(in);
    if (IsValueNode<Primitive>(in)) {
      auto value_node = in->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      auto value = value_node->value();
      MS_EXCEPTION_IF_NULL(value);
      auto prim_py = value->cast<PrimitivePtr>();
      MS_EXCEPTION_IF_NULL(prim_py);
      return prim_py->HasAttr(kAttrVisited);
    } else if (IsValueNode<FuncGraph>(in)) {
      auto func_graph = GetValueNode<FuncGraphPtr>(in);
      MS_EXCEPTION_IF_NULL(func_graph);
      return func_graph->has_flag(kAttrVisited);
    }
    return false;
  }
  return false;
}

// Create `output_num` TupleGetItem nodes taking the i-th output of a
// multi-output node, each with its inferred type/shape set; results are
// appended to `outputs`.
void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num,
                                    std::vector<AnfNodePtr> *outputs) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(outputs);
  auto type_ptr = node->Type();
  auto shape_ptr = node->Shape();
  for (size_t i = 0; i < output_num; i++) {
    int64_t temp = SizeToLong(i);
    auto idx = NewValueNode(temp);
    MS_EXCEPTION_IF_NULL(idx);
    auto imm = std::make_shared<Int64Imm>(temp);
    auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
    idx->set_abstract(abstract_scalar);
    auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
    MS_EXCEPTION_IF_NULL(tuple_getitem);
    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(type_ptr, i)},
                                        {AnfAlgo::GetOutputInferShape(node, shape_ptr, i)}, tuple_getitem.get());
    (*outputs).push_back(tuple_getitem);
  }
}

// Build a 1-D tensor of element type T from a tuple of scalars; returns
// nullptr (with a warning) if any element is not a scalar.
template <typename T>
tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr,
                                             size_t data_length) {
  MS_EXCEPTION_IF_NULL(value_tuple_ptr);
  MS_EXCEPTION_IF_NULL(type_ptr);
  std::vector<T> values;
  for (const auto &v : value_tuple_ptr->value()) {
    MS_EXCEPTION_IF_NULL(v);
    if (v->isa<Scalar>()) {
      ScalarPtr scalar = v->cast<ScalarPtr>();
      values.push_back(GetValue<T>(scalar));
    } else {
      MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar";
      return nullptr;
    }
  }
  std::vector<int64_t> tensor_shape = {SizeToLong(values.size())};
  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_ptr->type_id(), tensor_shape);
  MS_EXCEPTION_IF_NULL(tensor);
  tensor::DeviceInfo device_info{kOpFormat_DEFAULT, type_ptr};
  tensor->set_device_info(device_info);
  auto data_ptr = tensor->data_c();
  MS_EXCEPTION_IF_NULL(data_ptr);
  auto elem_num = values.size() * data_length;
  auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(tensor->data().nbytes()), values.data(), elem_num);
  if (ret_code != 0) {
    MS_LOG(EXCEPTION) << "Failed to copy data into tensor, memcpy_s errorno: " << ret_code;
  }
  return tensor;
}

// Convert a tuple of homogeneous scalars (int32/int64/float32) into a 1-D
// tensor; returns nullptr for empty tuples, non-scalar elements, or
// unsupported scalar types.
tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
  MS_EXCEPTION_IF_NULL(value_tuple);
  tensor::TensorPtr tensor = nullptr;
  if (value_tuple->value().empty()) {
    MS_LOG(WARNING) << "The value tuple is empty.";
    return nullptr;
  }
  ValuePtr v = *(value_tuple->value().begin());
  MS_EXCEPTION_IF_NULL(v);
  // Currently we only deal with the scalar tuple
  if (!v->isa<Scalar>()) {
    MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar";
    return nullptr;
  }
  ScalarPtr scalar = v->cast<ScalarPtr>();
  MS_EXCEPTION_IF_NULL(scalar);
  if (scalar->isa<Int32Imm>()) {
    tensor = CreateTensorWithValueTuple<int32_t>(value_tuple, kInt32, sizeof(int32_t));
  } else if (scalar->isa<Int64Imm>()) {
    tensor = CreateTensorWithValueTuple<int64_t>(value_tuple, kInt64, sizeof(int64_t));
  } else if (scalar->isa<FP32Imm>()) {
    tensor = CreateTensorWithValueTuple<float>(value_tuple, kFloat32, sizeof(float));
  } else {
    auto type = scalar->type();
    auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
    MS_LOG(ERROR) << "Invalid scalar type: " << type_str;
    return nullptr;
  }
  return tensor;
}

// A "nop" node is a data-movement-free op (Reshape, ExpandDims, Squeeze,
// Flatten, FlattenGrad, Reformat, or anything carrying the "nop_op" attr) on
// Ascend/GPU whose single input is a CNode.
// NOTE(review): GetCNodeTarget(node) is called before the node==nullptr
// check below — behavior preserved as-is; confirm callers never pass nullptr.
bool IsNopNode(const AnfNodePtr &node) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  auto target = GetCNodeTarget(node);
  if (target == kCPUDevice) {
    return false;
  }
  if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice &&
      context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
    return false;
  }
  static mindspore::HashSet<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName,
                                                      prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(),
                                                      kFlattenGradOpName,         prim::kPrimReformat->name()};
  if (node == nullptr || !node->isa<CNode>()) {
    return false;
  }
  CNodePtr cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->inputs().empty()) {
    return false;
  }
  auto input0 = cnode->input(0);
  MS_EXCEPTION_IF_NULL(input0);
  if (!input0->isa<ValueNode>()) {
    return false;
  }
  bool is_nop_node = false;
  if (AnfAlgo::HasNodeAttr("nop_op", cnode)) {
    is_nop_node = AnfAlgo::GetNodeAttr<bool>(cnode, "nop_op");
  }
  if (nop_nodes.find(AnfAlgo::GetCNodeName(cnode)) == nop_nodes.end() && !is_nop_node) {
    return false;
  }
  return true;
}

// True when every node in the graph's execution order is a nop node.
bool IsAllNopNode(const session::KernelGraph *const graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto execution_order = graph->execution_order();
  for (auto &cnode : execution_order) {
    MS_EXCEPTION_IF_NULL(cnode);
    if (!IsNopNode(cnode)) {
      return false;
    }
  }
  return true;
}

// Decide whether a nop node can be removed from the execution order; nop
// nodes that are graph outputs of a dynamic graph must be kept.
bool NeedHideNode(const std::vector<AnfNodePtr> &outputs, const AnfNodePtr &node, bool is_dynamic_graph) {
  MS_EXCEPTION_IF_NULL(node);
  // if node is not a nop node, keep it in execution order
  if (!IsNopNode(node)) {
    return false;
  }
  // if node is nop node and the graph is dynamic graph, check if the nop node is graph's output.
  if (is_dynamic_graph) {
    auto iter = find(outputs.begin(), outputs.end(), node);
    if (iter != outputs.end()) {
      return false;
    }
  }
  return true;
}

// Drop hideable nop nodes from the execution order (marking them with
// kAttrSkipNopOpAddr) without rewiring their users.
void HideNopNode(session::KernelGraph *const graph) {
  MS_EXCEPTION_IF_NULL(graph);
  if (IsAllNopNode(graph) == true) {
    return;
  }
  auto execution_order = graph->execution_order();
  auto outputs = graph->outputs();
  bool is_dynamic_graph = graph->is_dynamic_shape();
  MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size();
  std::vector<CNodePtr> new_nodes;
  for (auto &cnode : execution_order) {
    MS_EXCEPTION_IF_NULL(cnode);
    if (NeedHideNode(outputs, cnode, is_dynamic_graph)) {
      AnfAlgo::SetNodeAttr(kAttrSkipNopOpAddr, MakeValue(true), cnode);
    } else {
      new_nodes.push_back(cnode);
    }
  }
  graph->set_execution_order(new_nodes);
  MS_LOG(INFO) << "nop node info (After Remove) size: " << graph->execution_order().size();
}

// Remove hideable nop nodes from the execution order AND splice each nop
// node out of its users' input lists; iterates until a fixed point.
void RemoveNopNode(session::KernelGraph *const graph) {
  MS_EXCEPTION_IF_NULL(graph);
  if (IsAllNopNode(graph) == true) {
    return;
  }
  bool changed = true;
  while (changed) {
    changed = false;
    std::vector<CNodePtr> new_nodes;
    auto outputs = graph->outputs();
    bool is_dynamic_graph = graph->is_dynamic_shape();
    for (auto &cnode : graph->execution_order()) {
      MS_EXCEPTION_IF_NULL(cnode);
      // ignore nop node itself
      if (NeedHideNode(outputs, cnode, is_dynamic_graph)) {
        AnfAlgo::SetNodeAttr(kAttrSkipNopOpAddr, MakeValue(true), cnode);
        continue;
      }
      // Replace the input which is nop node
      std::vector<AnfNodePtr> new_inputs;
      new_inputs.push_back(cnode->input(0));
      bool need_update = false;
      for (size_t i = 1; i < cnode->inputs().size(); ++i) {
        auto input = cnode->input(i);
        MS_EXCEPTION_IF_NULL(input);
        auto cinput = input->cast<CNodePtr>();
        if (cinput == nullptr || !IsNopNode(cinput)) {
          new_inputs.push_back(input);
          continue;
        }
        constexpr auto kInputSize = 2;
        if (cinput->inputs().size() == kInputSize) {
          // Bypass the nop node: use its single data input directly.
          new_inputs.push_back(cinput->input(1));
          need_update = true;
          changed = true;
        } else {
          new_inputs.push_back(input);
        }
      }
      if (need_update) {
        cnode->set_inputs(new_inputs);
      }
      // push into new execution list
      new_nodes.push_back(cnode);
    }
    graph->set_execution_order(new_nodes);
  }
}

// Number of real (non-Depend-attach, non-UpdateState) users of `node`.
size_t GetRealNodeNum(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  auto out_list = GetRealNodeUsedList(graph, node);
  MS_EXCEPTION_IF_NULL(out_list);
  return out_list->size();
}

// Collect the (user, input-index) pairs of `node`, skipping Depend attach
// edges and UpdateState users; empty list if the node is unmanaged.
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
                                                                             const AnfNodePtr &node) {
  auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
  MS_EXCEPTION_IF_NULL(graph);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto iter = manager->node_users().find(node);
  if (iter == manager->node_users().end()) {
    return output_node_list;
  }
  auto output_info_list = iter->second;
  for (const auto &output_info : output_info_list) {
    auto cnode_name = AnfAlgo::GetCNodeName(output_info.first);
    if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) ||
        (cnode_name == prim::kPrimUpdateState->name())) {
      continue;
    }
    output_node_list->push_back(output_info);
  }
  return output_node_list;
}

// Like GetRealNodeUsedList but keeps only users that consume the given
// output index of a multi-output node; throws if the node has no users.
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
                                                                                        const AnfNodePtr &node,
                                                                                        size_t output_index) {
  auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
  MS_EXCEPTION_IF_NULL(graph);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto iter = manager->node_users().find(node);
  if (iter == manager->node_users().end()) {
    MS_LOG(EXCEPTION) << "node has no output in manager";
  }
  auto output_info_list = iter->second;
  for (const auto &output_info : output_info_list) {
    auto cnode_name = AnfAlgo::GetCNodeName(output_info.first);
    if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) ||
        (cnode_name == prim::kPrimUpdateState->name())) {
      continue;
    }
    size_t used_output_index;
    if (cnode_name == prim::kPrimTupleGetItem->name()) {
      used_output_index = AnfAlgo::GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
    } else if (AnfAlgo::GetCNodeName(node) == prim::kPrimTupleGetItem->name()) {
      used_output_index = output_index;
    } else {
      auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(output_info.first, IntToSize(output_info.second - 1));
      if (kernel_with_index.first.get() != node.get()) {
        MS_LOG(EXCEPTION) << "Get used node failed for op[" << AnfAlgo::GetCNodeName(node) << "]";
      }
      used_output_index = kernel_with_index.second;
    }
    if (used_output_index == output_index) {
      output_node_list->push_back(output_info);
    }
  }
  return output_node_list;
}

// True when `node` has more than one real user.
bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto output_node_list = GetRealNodeUsedList(graph, node);
  MS_EXCEPTION_IF_NULL(output_node_list);
  return output_node_list->size() > 1;
}

// True when, transitively looking through Depend/MakeTuple/TupleGetItem/Load
// wrappers, no real compute node consumes `node`.
bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto output_node_list = GetRealNodeUsedList(graph, node);
  MS_EXCEPTION_IF_NULL(output_node_list);
  if (output_node_list->empty()) {
    return true;
  }
  for (const auto &output : *output_node_list) {
    auto out_node = output.first;
    auto name = AnfAlgo::GetCNodeName(out_node);
    if (name == prim::kPrimDepend->name() || name == prim::kPrimMakeTuple->name() ||
        name == prim::kPrimTupleGetItem->name() || name == prim::kPrimLoad->name()) {
      auto result = IsNotRealUsedByOthers(graph, out_node);
      if (!result) {
        return result;
      }
      continue;
    }
    return false;
  }
  return true;
}

// Create a TupleGetItem(node, output_idx) cnode with scope and abstract
// copied from the corresponding tuple element of `node`.
CNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto idx = NewValueNode(SizeToLong(output_idx));
  MS_EXCEPTION_IF_NULL(idx);
  auto imm = std::make_shared<Int64Imm>(SizeToLong(output_idx));
  auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
  idx->set_abstract(abstract_scalar);
  CNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
  MS_EXCEPTION_IF_NULL(tuple_getitem);
  tuple_getitem->set_scope(node->scope());
  auto abs = node->abstract()->cast<abstract::AbstractTuplePtr>();
  MS_EXCEPTION_IF_NULL(abs);
  auto abs_i = abs->elements()[output_idx];
  MS_EXCEPTION_IF_NULL(abs_i);
  tuple_getitem->set_abstract(abs_i);
  return tuple_getitem;
}

// Materialize a shape as a graph value node: either a 1-D int64 Tensor
// (to_tensor=true) or a ValueTuple of int64 scalars.
ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &shape, bool to_tensor) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  ValuePtr shape_value = nullptr;
  AbstractBasePtr abstract = nullptr;
  if (to_tensor) {
    // create Tensor
    int64_t shape_dim = SizeToLong(shape.size());
    std::vector<int64_t> shape_vec_shape = {shape_dim};
    auto shape_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, shape_vec_shape);
    MS_EXCEPTION_IF_NULL(shape_tensor);
    auto data_ptr = shape_tensor->data_c();
    MS_EXCEPTION_IF_NULL(data_ptr);
    auto elem_num = shape.size() * kType64Len;
    auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(shape_tensor->data().nbytes()), &shape[0], elem_num);
    if (ret_code != 0) {
      MS_LOG(EXCEPTION) << "Failed to copy data into tensor, memcpy_s errorno: " << ret_code;
      return nullptr;
    }
    shape_value = shape_tensor;
    abstract = std::make_shared<abstract::AbstractTensor>(kInt64, shape_vec_shape);
  } else {
    // create ValueTuple
    std::vector<ValuePtr> dim_values{};
    abstract::AbstractBasePtrList abs{};
    for (const auto &dim : shape) {
      dim_values.push_back(MakeValue(dim));
      abs.push_back(std::make_shared<abstract::AbstractScalar>(dim));
    }
    shape_value = std::make_shared<ValueTuple>(dim_values);
    abstract = std::make_shared<abstract::AbstractTuple>(abs);
  }
  MS_EXCEPTION_IF_NULL(shape_value);
  MS_EXCEPTION_IF_NULL(abstract);
  auto shape_value_node = kernel_graph->NewValueNode(abstract, shape_value);
  MS_EXCEPTION_IF_NULL(shape_value_node);
  kernel_graph->AddValueNodeToGraph(shape_value_node);
  return shape_value_node;
}

// Move the constant inputs listed in `input_attrs` (by input position) into
// attributes on a cloned primitive; aborts without change if any such input
// is a tensor with no constant data.
void ConstInputToAttr(const CNodePtr &cnode, const mindspore::HashSet<size_t> &input_attrs) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> new_inputs;
  auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(primitive);
  primitive = primitive->Clone();
  auto input_names = primitive->GetAttr(kAttrInputNames);
  if (input_names == nullptr) {
    MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]";
    return;
  }
  auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
  auto inputs = cnode->inputs();
  new_inputs.push_back(inputs[0]);
  bool need_update = false;
  for (size_t i = 0; i < inputs.size() - 1; ++i) {
    auto input_node = inputs[i + 1];
    if (AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimDepend)) {
      input_node = AnfAlgo::VisitKernel(input_node, 0).first;
    }
    MS_EXCEPTION_IF_NULL(input_node);
    if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>() && !HasAbstractMonad(input_node)) {
      auto value_node = input_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      MS_LOG(DEBUG) << "start erase input[" << i << "] of cnode[" + cnode->DebugString() + "]";
      if (i >= input_names_vec.size()) {
        MS_LOG(EXCEPTION) << "Index " << i << " is larger than input names size [" << input_names_vec.size() << "]";
      }
      auto value = value_node->value();
      if (value->isa<tensor::Tensor>()) {
        auto tensor = value->cast<tensor::TensorPtr>();
        if (tensor->data().const_data() == nullptr) {
          // A tensor without constant data cannot be turned into an attr;
          // give up the whole conversion.
          need_update = false;
          break;
        }
      }
      primitive->set_attr(input_names_vec[i], value);
      need_update = true;
    } else {
      new_inputs.push_back(inputs[i + 1]);
    }
  }
  if (need_update) {
    // Update cnode's inputs
    new_inputs[0] = NewValueNode(primitive);
    cnode->set_inputs(new_inputs);
  }
}

// Equality used by the pattern engine: primitives compare by name, value
// nodes by value; everything else falls back to BaseRef equality.
bool AnfEqual(const BaseRef &a, const BaseRef &b) {
  if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
    auto a_node = utils::cast<AnfNodePtr>(a);
    auto b_node = utils::cast<AnfNodePtr>(b);
    MS_EXCEPTION_IF_NULL(a_node);
    MS_EXCEPTION_IF_NULL(b_node);
    if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
      auto a_value_node = a_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(a_value_node);
      auto a_value = a_value_node->value();
      MS_EXCEPTION_IF_NULL(a_value);
      auto a_prim = a_value->cast<PrimitivePtr>();
      MS_EXCEPTION_IF_NULL(a_prim);
      auto b_value_node = b_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(b_value_node);
      auto b_value = b_value_node->value();
      MS_EXCEPTION_IF_NULL(b_value);
      auto b_prim = b_value->cast<PrimitivePtr>();
      MS_EXCEPTION_IF_NULL(b_prim);
      return a_prim->name() == b_prim->name();
    } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
      auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
      if (a_value_node_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "Cast value node ptr fail.";
      }
      auto a_value_ptr = a_value_node_ptr->value();
      if (a_value_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "Value ptr is nullptr.";
      }
      auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
      if (b_value_node_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "Cast value node ptr fail.";
      }
      auto b_value_ptr = b_value_node_ptr->value();
      if (b_value_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "Value ptr is nullptr.";
      }
      return (*a_value_ptr) == (*b_value_ptr);
    }
    MS_LOG(DEBUG) << "check AnfNodePtr equal";
  }
  if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) {
    MS_LOG(DEBUG) << "check GraphPtr equal";
  }
  return a == b;
}

// Type equality used by the pattern engine; any two CNodes are considered
// type-equal.
bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
  // To match CNode and Kernel's type
  if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
    return true;
  }
  return a.type() == b.type();
}

namespace {
// Turn a scalar / ValuePtr sexp leaf into a ValueNode; primitives already
// bound in `primitive_vars` are re-instantiated so each match gets a fresh
// Var. Returns nullptr for unsupported sexp kinds.
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp, PrimitiveVarMap *primitive_vars) {
  if (utils::isa<int>(sexp)) {
    return NewValueNode(utils::cast<int>(sexp));
  }
  if (utils::isa<int64_t>(sexp)) {
    return NewValueNode(utils::cast<int64_t>(sexp));
  }
  if (utils::isa<float>(sexp)) {
    return NewValueNode(utils::cast<float>(sexp));
  }
  if (utils::isa<bool>(sexp)) {
    return NewValueNode(utils::cast<bool>(sexp));
  }
  if (utils::isa<ValuePtr>(sexp)) {
    auto value = utils::cast<ValuePtr>(sexp);
    if (utils::isa<PrimitivePtr>(sexp)) {
      auto prim = utils::cast<PrimitivePtr>(sexp);
      if (primitive_vars->find(prim) != primitive_vars->end()) {
        prim = std::make_shared<Primitive>(prim->name());
        value = prim;
      }
      (*primitive_vars)[prim] = std::make_shared<Var>(prim);
    }
    return NewValueNode(value);
  }
  return nullptr;
}

// Build a CNode whose owning graph is either a concrete FuncGraph or a Var
// placeholder; nullptr otherwise.
CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
  if (utils::isa<FuncGraphPtr>(graph)) {
    return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
  }
  if (utils::isa<VarPtr>(graph)) {
    return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
  }
  return nullptr;
}

// Wrap a Var sexp into a VarNode attached to the given graph (or detached
// when the graph itself is a Var).
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
  if (utils::isa<VarPtr>(graph)) {
    MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
    return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
  }
  if (utils::isa<FuncGraphPtr>(graph)) {
    MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
    return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
  }
  MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
  return nullptr;
}

// Recursively convert a VectorRef sexp into a CNode, creating fresh graph
// Vars per element in multigraph mode.
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
                            bool multigraph) {
  MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  std::vector<AnfNodePtr> input_nodes;
  const auto &tuple = utils::cast<VectorRef>(sexp);
  if (multigraph && utils::isa<VarPtr>(graph)) {
    for (auto &x : tuple) {
      AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
      input_nodes.push_back(node);
    }
    VarPtr var_ptr = utils::cast<VarPtr>(graph);
    return std::make_shared<CNode>(input_nodes, var_ptr);
  }
  for (auto &x : tuple) {
    AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
    input_nodes.push_back(node);
  }
  return CreateCNodeWithGraph(input_nodes, graph);
}

// rectify absttract if the input has been converted to the attr
AbstractBasePtrList RectifyAbstractFromRegAttr(const PrimitivePtr &primitive,
                                               const AbstractBasePtrList &input_abstract) {
  MS_EXCEPTION_IF_NULL(primitive);
  opt::ConstInputToAttrInfoRegister reg;
  if (!opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(primitive->name(), &reg)) {
    return input_abstract;
  }
  if (AnfAlgo::HasDynamicShapeFlag(primitive)) {
    return input_abstract;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto device = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (device == kGPUDevice) {
    if (DynamicShapeConstInputToAttrGPU.find(primitive->name()) != DynamicShapeConstInputToAttrGPU.end()) {
      return input_abstract;
    }
  } else if (DynamicShapeConstInputToAttr.find(primitive->name()) != DynamicShapeConstInputToAttr.end()) {
    return input_abstract;
  }
  auto convert_input_list = reg.GetConstInputAttrInfo();
  auto input_names = primitive->GetAttr(kAttrInputNames);
  if (input_names == nullptr) {
    return input_abstract;
  }
  auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
  AbstractBasePtrList rectify_abs_list;
  size_t ori_index = 0;
  rectify_abs_list.resize(input_names_vec.size());
  for (size_t index = 0; index < rectify_abs_list.size(); ++index) {
    // if convert input list find the index it means the input has been converted to the attr
    if (convert_input_list.find(index) != convert_input_list.end()) {
      AbstractBasePtr rectify_abs = nullptr;
      auto input_name = input_names_vec[index];
      auto attr = primitive->GetAttr(input_name);
      if (attr != nullptr) {
        rectify_abs = attr->ToAbstract();
      } else {
        MS_LOG(DEBUG) << "the node prim name :" << primitive->name() << "input index :" << index
                      << " input name :" << input_name << "has not been converted to the attr";
        rectify_abs = input_abstract[ori_index++];
      }
      rectify_abs_list[index] = rectify_abs;
      continue;
    }
    if (ori_index > input_abstract.size()) {
      MS_LOG(EXCEPTION) << "Index " << ori_index << " is out of range in input abstract " << input_abstract.size();
    }
    rectify_abs_list[index] = input_abstract[ori_index++];
  }
  return rectify_abs_list;
}

// Regroup flat input abstracts into tuples according to the primitive's
// kAttrDynInputSizes attribute (-1 marks a non-dynamic input).
AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitive,
                                                    const AbstractBasePtrList &input_abstract) {
  auto dynamic_inputs_list = primitive->GetAttr(kAttrDynInputSizes);
  if (dynamic_inputs_list == nullptr) {
    return input_abstract;
  }
  AbstractBasePtrList rectifyed_abs_list;
  const int kNotDynamicFlag = -1;
  auto dynamic_inputs_index = GetValue<std::vector<int64_t>>(dynamic_inputs_list);
  size_t input_index = 0;
  for (auto item : dynamic_inputs_index) {
    if (item == kNotDynamicFlag) {
      if (input_index >= input_abstract.size()) {
        MS_LOG(EXCEPTION) << "Index " << input_index << " is out of range in input abstract " << input_abstract.size();
      }
      (void)rectifyed_abs_list.emplace_back(input_abstract[input_index++]);
    } else {
      if (item < 0) {
        MS_LOG(EXCEPTION) << "The dynamic input size check error the index should be -1 or positive number but got "
                          << item;
      }
      AbstractBasePtrList dynamic_inputs_abs;
      for (auto index = item; index > 0; --index) {
        if (input_index >= input_abstract.size()) {
          MS_LOG(EXCEPTION) << "Index " << input_index << " is out of range in input abstract "
                            << input_abstract.size();
        }
        (void)dynamic_inputs_abs.emplace_back(input_abstract[input_index++]);
      }
      (void)rectifyed_abs_list.emplace_back(std::make_shared<abstract::AbstractTuple>(dynamic_inputs_abs));
    }
  }
  return rectifyed_abs_list;
}

// Apply both attr-conversion and dynamic-input rectification to the input
// abstract list before inference.
AbstractBasePtrList RectifyAbstract(const PrimitivePtr &primitive, const AbstractBasePtrList &input_abstract) {
  auto rectify_abs_list = RectifyAbstractFromRegAttr(primitive, input_abstract);
  return RectifyAbstractFromDynamicInput(primitive, rectify_abs_list);
}
}  // namespace

// Convert a pattern s-expression (VectorRef / Var / AnfNode / scalar) into
// an AnfNode, recording primitive Vars in `primitive_vars`.
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
  MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  MS_EXCEPTION_IF_NULL(primitive_vars);
  if (utils::isa<VectorRef>(sexp)) {
    return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
  }
  if (utils::isa<VarPtr>(sexp)) {
    auto var_ptr = utils::cast<VarPtr>(sexp);
    MS_EXCEPTION_IF_NULL(var_ptr);
    if (var_ptr->primitive()) {
      (*primitive_vars)[var_ptr->primitive()] = var_ptr;
      return NewValueNode(var_ptr->primitive());
    }
    return CreateVarNodeWithSexp(sexp, graph);
  }
  if (utils::isa<AnfNodePtr>(sexp)) {
    return utils::cast<AnfNodePtr>(sexp);
  }
  auto value_node = CreateValueNodeWithSexp(sexp, primitive_vars);
  if (value_node == nullptr) {
    MS_LOG(EXCEPTION) << "Sexp cannot converted, sexp: " + sexp.ToString();
  }
  return value_node;
}

// True when the two equivs bind `var_node` to equal AnfNodes.
bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node) {
  MS_EXCEPTION_IF_NULL(equiv1);
  MS_EXCEPTION_IF_NULL(equiv2);
  MS_EXCEPTION_IF_NULL(var_node);
  auto equiv1_node = GetAnfNodeByVar(equiv1, var_node);
  MS_EXCEPTION_IF_NULL(equiv1_node);
  auto equiv2_node = GetAnfNodeByVar(equiv2, var_node);
  MS_EXCEPTION_IF_NULL(equiv2_node);
  return *equiv1_node == *equiv2_node;
}

// Look up the AnfNode bound to `var_node` in a match equiv; nullptr when the
// var is unbound, exception when the binding is not an AnfNode.
AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) {
  MS_EXCEPTION_IF_NULL(equiv);
  MS_EXCEPTION_IF_NULL(var_node);
  auto iter = (*equiv).find(var_node);
  if (iter == (*equiv).end()) {
    MS_LOG(INFO) << "The equiv map doesn't contain the var_node after matched.";
    return nullptr;
  }
  auto res = utils::cast<AnfNodePtr>(iter->second);
  if (res == nullptr) {
    MS_LOG(EXCEPTION) << "Cast fail! Maybe var is not a anf node";
  }
  return res;
}

// Ordering comparator for TupleGetItem nodes by their constant output index.
bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) {
  MS_EXCEPTION_IF_NULL(n1);
  MS_EXCEPTION_IF_NULL(n2);
  auto n1_cnode = n1->cast<CNodePtr>();
  auto n2_cnode = n2->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(n1_cnode);
  MS_EXCEPTION_IF_NULL(n2_cnode);
  auto index_input1 = n1_cnode->input(kInputNodeOutputIndexInTupleGetItem);
  MS_EXCEPTION_IF_NULL(index_input1);
  auto value_node1 = index_input1->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node1);
  auto index_input2 = n2_cnode->input(kInputNodeOutputIndexInTupleGetItem);
  MS_EXCEPTION_IF_NULL(index_input2);
  auto value_node2 = index_input2->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node2);
  return GetValue<int64_t>(value_node1->value()) < GetValue<int64_t>(value_node2->value());
}

// Read a boolean attribute from a cnode; false when the node is not a cnode
// or the attribute is absent.
bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    MS_LOG(INFO) << "node is not a cnode";
    return false;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  return AnfAlgo::HasNodeAttr(attr_name, cnode) && AnfAlgo::GetNodeAttr<bool>(node, attr_name);
}

// True when the node's inferred output dtype (port 0) is in the given set.
bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &supported_data_type_set) {
  MS_EXCEPTION_IF_NULL(node);
  TypeId data_type = AnfAlgo::GetOutputInferDataType(node, 0);
  if (supported_data_type_set.find(data_type) != supported_data_type_set.end()) {
    return true;
  }
  MS_LOG(DEBUG) << "Not supported data type. Node:" << node->DebugString();
  return false;
}

// Clone a value node and attach fresh kernel info with default format and
// unknown device types (to be selected later).
ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(value_node);
  ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
  MS_EXCEPTION_IF_NULL(new_value_node);
  new_value_node->set_abstract(value_node->abstract());
  // create kernel_info fo new value node
  auto kernel_info = std::make_shared<device::KernelInfo>();
  new_value_node->set_kernel_info(kernel_info);
  // create kernel_build_info for new value node
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
  // set the format of value_node to DEFAULT_FORMAT
  kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
  // set value node initial device data type = infer data type
  std::vector<TypeId> types;
  size_t output_num = AnfAlgo::GetOutputTensorNum(value_node);
  for (size_t index = 0; index < output_num; ++index) {
    types.push_back(kTypeUnknown);
  }
  kernel_build_info_builder->SetOutputsDeviceType(types);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  return new_value_node;
}

// Re-point every Depend / UpdateState user of `old_node` to `new_node`.
void TransferDependOrUpdateState(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
  MS_EXCEPTION_IF_NULL(old_node);
  MS_EXCEPTION_IF_NULL(graph);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  // Find BatchNorm's output which is a Depend or UpdateState.
  auto node_users = manager->node_users()[old_node];
  for (const auto &node_index : node_users) {
    AnfNodePtr output = node_index.first;
    MS_EXCEPTION_IF_NULL(output);
    if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend) ||
        AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) {
      auto depend = output->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(depend);
      manager->SetEdge(depend, node_index.second, new_node);
    }
  }
}

// Run C++ shape inference for `prim`, looking first in the frontend infer
// map and then in the backend map; abstracts are rectified for attr-converted
// and dynamic inputs where required.
AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) {
  MS_EXCEPTION_IF_NULL(prim);
  auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap();
  auto ret = prim_eval_implement_map.find(prim);
  if (ret != prim_eval_implement_map.end()) {
    // fing infer function in the front infer map and restore input abastract form dynamic inputs and reg attr
    MS_EXCEPTION_IF_NULL(ret->second.infer_shape_impl_);
    auto infer_spec_list = RectifyAbstract(prim, args_spec_list);
    return ret->second.infer_shape_impl_(nullptr, prim, infer_spec_list);
  } else {
    // if the infer function has been not founded in the front infer map find it in the backend infer map instead
    auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap();
    auto ret_backend = prim_backend_eval_impl_map.find(prim);
    if (ret_backend != prim_backend_eval_impl_map.end()) {
      MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_impl_);
      auto infer_spec_list = args_spec_list;
      if (!ret_backend->second.in_white_list_) {
        infer_spec_list = RectifyAbstract(prim, args_spec_list);
      }
      return ret_backend->second.infer_shape_impl_(nullptr, prim, infer_spec_list);
    }
  }
  MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name()
                    << " primitive type:" << prim->type_name();
}

// Build a default-format KernelBuildInfo whose input/output device types
// mirror the inferred types of all nodes in `node_list`.
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list) {
  std::vector<std::string> inputs_device_format;
  std::vector<std::string> outputs_device_format;
  std::vector<TypeId> inputs_device_type;
  std::vector<TypeId> outputs_device_type;
  std::vector<std::vector<size_t>> outputs_shape;
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  for (size_t idx = 0; idx < node_list.size(); ++idx) {
    auto cnode = utils::cast<CNodePtr>(node_list[idx]);
    MS_EXCEPTION_IF_NULL(cnode);
    size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
    for (size_t input_index = 0; input_index < input_num; ++input_index) {
      (void)inputs_device_format.emplace_back(kOpFormat_DEFAULT);
      (void)inputs_device_type.emplace_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
    }
    size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
    for (size_t output_index = 0; output_index < output_num; ++output_index) {
      (void)outputs_device_format.emplace_back(kOpFormat_DEFAULT);
      (void)outputs_device_type.emplace_back(AnfAlgo::GetOutputInferDataType(cnode, output_index));
      (void)outputs_shape.emplace_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
    }
  }
  builder.SetInputsFormat(inputs_device_format);
  builder.SetOutputsFormat(outputs_device_format);
  builder.SetInputsDeviceType(inputs_device_type);
  builder.SetOutputsDeviceType(outputs_device_type);
  return builder.Build();
}

// Count, per output index, how many users consume each output of `node`.
// (Definition continues past the visible chunk; the body below ends at the
// same syntactic point the chunk ends.)
std::vector<int64_t> GetNodeOutputUsedNum(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto manager = kernel_graph.manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto output_num = AnfAlgo::GetOutputTensorNum(node);
  std::vector<int64_t> output_used_num(output_num, 0);
  if (output_num == 1) {
    output_used_num[0] = SizeToLong(manager->node_users()[node].size());
  } else {
    for (auto out_getitem : manager->node_users()[node]) {
      MS_EXCEPTION_IF_NULL(out_getitem.first);
      if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) {
        continue;
      }
      auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(out_getitem_ptr);
      auto getitem_input2 = out_getitem_ptr->input(kInputNodeOutputIndexInTupleGetItem);
      auto output_idx = LongToSize(GetValue<int64_t>(GetValueNode(getitem_input2)));
      output_used_num[output_idx] = SizeToLong(manager->node_users()[out_getitem.first].size());
    }
  }
  return
output_used_num; } int64_t GetNodeOutputTotalUsedNum(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) { auto output_used_num = GetNodeOutputUsedNum(kernel_graph, node); return std::accumulate(output_used_num.begin(), output_used_num.end(), int64_t(0)); } void GetCustomOpAttrIndex(const PrimitivePtr &primitive, mindspore::HashSet *indexes) { if (primitive == nullptr || primitive->name() != prim::kPrimCustom->name()) { return; } MS_EXCEPTION_IF_NULL(indexes); auto input_names = primitive->GetAttr(kAttrInputNames); auto attr_names = primitive->GetAttr(kAttrAttrNames); if (input_names == nullptr || attr_names == nullptr) { return; } auto input_names_vec = GetValue>(input_names); auto attr_names_vec = GetValue>(attr_names); if (input_names_vec.size() >= attr_names_vec.size()) { size_t offset = input_names_vec.size() - attr_names_vec.size(); for (size_t i = offset; i < input_names_vec.size(); ++i) { if (input_names_vec[i] != attr_names_vec[i - offset]) { MS_LOG(EXCEPTION) << primitive->name() << " found mismatching attr name " << input_names_vec[i] << "in input_names and " << attr_names_vec[i - offset] << " in attr_names"; } indexes->insert(i); } } } } // namespace opt } // namespace mindspore