/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/common/gllo_utils.h"
#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "src/ops/primitive_c.h"
#include "src/common/common.h"
#include "frontend/operator/ops.h"
#include "backend/optimizer/common/helper.h"

namespace mindspore {
namespace opt {
namespace {
constexpr auto kAnfPrimitiveIndex = 0;

// Returns whether `node` is a "real" kernel, i.e. not one of the virtual helper primitives
// (summaries, MakeTuple, StateSetItem, Depend, TupleGetItem, ControlDepend, Return, Partial).
// Non-CNode inputs (parameters, value nodes) are reported as real by this helper; IsRealCNodeKernel
// filters them out before calling it.
bool IsRealKernel(const AnfNodePtr &node) {
  if (node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  if (!node->isa<CNode>()) {
    return true;
  }
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  if (cnode->inputs().empty()) {
    // Log streams do not expand printf-style specifiers, so append the node description directly.
    MS_LOG(ERROR) << "Illegal null input of cnode: " << node->DebugString();
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INPUT_TENSOR_ERROR);
    return false;
  }
  const auto &input = cnode->inputs()[0];
  bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
                         IsPrimitive(input, prim::kPrimTensorSummary) ||
                         IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
                         IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
                         IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
                         IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
  return !is_virtual_node;
}

// Wraps the scalar or Value held by `sexp` in a new ValueNode; returns nullptr for any other payload.
// NOTE(review): the four payload types (int, float, bool, ValuePtr) mirror the backend pattern helper — confirm.
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
  if (utils::isa<int>(sexp)) {
    return NewValueNode(utils::cast<int>(sexp));
  }
  if (utils::isa<float>(sexp)) {
    return NewValueNode(utils::cast<float>(sexp));
  }
  if (utils::isa<bool>(sexp)) {
    return NewValueNode(utils::cast<bool>(sexp));
  }
  if (utils::isa<ValuePtr>(sexp)) {
    return NewValueNode(utils::cast<ValuePtr>(sexp));
  }
  return nullptr;
}

// Builds a CNode over `input_nodes` owned by `graph`, which must hold either a FuncGraphPtr or a VarPtr.
// Returns nullptr when `graph` holds neither.
CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
  if (utils::isa<FuncGraphPtr>(graph)) {
    return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
  }
  if (utils::isa<VarPtr>(graph)) {
    return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
  }
  return nullptr;
}

// Builds a VarNode for the Var in `sexp`. A Var-valued `graph` yields a graph-less VarNode, a FuncGraph-valued
// `graph` attaches the node to that graph; anything else is logged and yields nullptr.
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
  if (utils::isa<VarPtr>(graph)) {
    MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
    return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
  }
  if (utils::isa<FuncGraphPtr>(graph)) {
    MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
    return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
  }
  MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
  return nullptr;
}

// Converts a VectorRef s-expression into a CNode by recursively converting each element with SexpToNode.
// In multigraph mode with a Var-valued `graph`, children are built against a fresh Var("G").
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
                            bool multigraph) {
  if (primitive_vars == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  std::vector<AnfNodePtr> input_nodes;
  const auto &tuple = utils::cast<VectorRef>(sexp);
  input_nodes.reserve(tuple.size());
  if (multigraph && utils::isa<VarPtr>(graph)) {
    for (auto &x : tuple) {
      input_nodes.push_back(SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true));
    }
    VarPtr var_ptr = utils::cast<VarPtr>(graph);
    return std::make_shared<CNode>(input_nodes, var_ptr);
  }
  for (auto &x : tuple) {
    input_nodes.push_back(SexpToNode(x, graph, primitive_vars, multigraph));
  }
  return CreateCNodeWithGraph(input_nodes, graph);
}
}  // namespace

// Returns false (and logs) when `cnode` is null or any of its inputs is null.
bool CheckInputs(const CNodePtr &cnode) {
  if (cnode == nullptr) {
    MS_LOG(ERROR) << "cnode is nullptr.";
    return false;
  }
  const auto &inputs = cnode->inputs();
  if (std::any_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &anf_node) { return anf_node == nullptr; })) {
    MS_LOG(ERROR) << "input is nullptr.";
    return false;
  }
  return true;
}

// Returns whether `node` is a CNode whose primitive input matches `primitive_type`.
bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
  if (node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  if (!node->isa<CNode>()) {
    return false;
  }
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
}

// Compares two primitive value nodes by their lite primitive type.
// Returns false if either node is not a value node holding a lite::PrimitiveC.
bool AnfEqualPrimitive(AnfNodePtr a_node, AnfNodePtr b_node) {
  auto a_value_node = a_node->cast<ValueNodePtr>();
  auto b_value_node = b_node->cast<ValueNodePtr>();
  if (a_value_node == nullptr || b_value_node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }

  auto a_value = a_value_node->value();
  auto b_value = b_value_node->value();
  if (a_value == nullptr || b_value == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }

  auto a_prim = a_value->cast<PrimitivePtr>();
  auto b_prim = b_value->cast<PrimitivePtr>();
  if (a_prim == nullptr || b_prim == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  // A plain (non-lite) Primitive has no lite Type(); the original dereferenced the null cast result here.
  // Treat such a pair as "not equal" without flagging a global error, since this is a normal mismatch.
  auto a_prim_c = a_prim->cast<PrimitiveCPtr>();
  auto b_prim_c = b_prim->cast<PrimitiveCPtr>();
  if (a_prim_c == nullptr || b_prim_c == nullptr) {
    MS_LOG(DEBUG) << "primitive is not a lite PrimitiveC, treat as unequal";
    return false;
  }
  return a_prim_c->Type() == b_prim_c->Type();
}

// Compares two value nodes: lite primitives via PrimitiveC::operator==, other values via Value::operator==.
bool AnfEqualValueNode(AnfNodePtr a_node, AnfNodePtr b_node) {
  auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
  auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
  if (a_value_node_ptr == nullptr || b_value_node_ptr == nullptr) {
    MS_LOG(ERROR) << "cast value node ptr fail";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  auto a_value_ptr = a_value_node_ptr->value();
  auto b_value_ptr = b_value_node_ptr->value();
  if (a_value_ptr == nullptr || b_value_ptr == nullptr) {
    MS_LOG(ERROR) << "value ptr is nullptr";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  if (utils::isa<lite::PrimitiveC>(a_value_ptr) && utils::isa<lite::PrimitiveC>(b_value_ptr)) {
    // Checked downcast instead of the original C-style pointer casts.
    auto a_obj = a_value_ptr->cast<PrimitiveCPtr>();
    auto b_obj = b_value_ptr->cast<PrimitiveCPtr>();
    if (a_obj == nullptr || b_obj == nullptr) {
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
      return false;
    }
    return (*a_obj) == (*b_obj);
  }
  return (*a_value_ptr) == (*b_value_ptr);
}

// Pattern-matching equality: primitive value nodes compare by lite type, other value nodes by value,
// bare lite primitives by type, everything else with BaseRef::operator==.
bool AnfEqual(const BaseRef &a, const BaseRef &b) {
  if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
    auto a_node = utils::cast<AnfNodePtr>(a);
    auto b_node = utils::cast<AnfNodePtr>(b);
    if (a_node == nullptr || b_node == nullptr) {
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
      return false;
    }
    if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
      return AnfEqualPrimitive(a_node, b_node);
    }
    if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
      return AnfEqualValueNode(a_node, b_node);
    }
  }
  // A BaseRef may hold no object at all; only dereference m_ptr when both sides carry one.
  if (a.m_ptr != nullptr && b.m_ptr != nullptr && a.m_ptr->isa<lite::PrimitiveC>() &&
      b.m_ptr->isa<lite::PrimitiveC>()) {
    auto a_value_node_ptr = a.m_ptr->cast<PrimitiveCPtr>();
    auto b_value_node_ptr = b.m_ptr->cast<PrimitiveCPtr>();
    return a_value_node_ptr->Type() == b_value_node_ptr->Type();
  }
  return a == b;
}

// Type equality used by the pattern engine: any two CNodes are considered type-equal.
bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
  // To match CNode and Kernel's type
  if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
    return true;
  }
  return a.type() == b.type();
}

// Converts a pattern s-expression into an ANF node. Vars carrying a primitive are recorded in
// `primitive_vars` and become primitive value nodes; returns nullptr (with RET_NULL_PTR) on failure.
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
  MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  if (primitive_vars == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  if (utils::isa<VectorRef>(sexp)) {
    return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
  }
  if (utils::isa<VarPtr>(sexp)) {
    auto var_ptr = utils::cast<VarPtr>(sexp);
    if (var_ptr == nullptr) {
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
      return nullptr;
    }
    if (var_ptr->primitive()) {
      (*primitive_vars)[var_ptr->primitive()] = var_ptr;
      return NewValueNode(var_ptr->primitive());
    }
    return CreateVarNodeWithSexp(sexp, graph);
  }
  if (utils::isa<AnfNodePtr>(sexp)) {
    return utils::cast<AnfNodePtr>(sexp);
  }
  auto value_node = CreateValueNodeWithSexp(sexp);
  if (value_node == nullptr) {
    MS_LOG(ERROR) << "sexp cannot converted. sexp: " << sexp.ToString();
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  return value_node;
}

// Returns whether `node` is a CNode that does real work; Return is deliberately counted as real.
bool IsRealCNodeKernel(const AnfNodePtr &node) {
  if (node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  // parameter and value node is not a real cnode kernel
  if (!node->isa<CNode>()) {
    return false;
  }
  // return considered as a real node
  if (CheckPrimitiveType(node, prim::kPrimReturn)) {
    return true;
  }
  return IsRealKernel(node);
}

// Returns whether `node` is a real CNode whose first input is a FuncGraph tagged FUNC_GRAPH_ATTR_GRAPH_KERNEL.
bool IsGraphKernel(const AnfNodePtr &node) {
  if (node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  // graph kernel should be a real cnode kernel.
  if (!IsRealCNodeKernel(node)) {
    return false;
  }

  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  auto input = cnode->input(kAnfPrimitiveIndex);
  // graph kernel should have func_graph as first input.
  if (!IsValueNode<FuncGraph>(input)) {
    return false;
  }

  auto func_graph = GetValueNode<FuncGraphPtr>(input);
  if (func_graph == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
}

// Returns RET_NULL_PTR (and records it globally) when `graph` is null, RET_OK otherwise.
int CheckIfFuncGraphIsNull(const FuncGraphPtr &graph) {
  if (graph == nullptr) {
    MS_LOG(ERROR) << "The graph is null.";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return lite::RET_NULL_PTR;
  }
  return lite::RET_OK;
}

// Returns RET_NULL_PTR (and records it globally) when `node` is null, RET_OK otherwise.
int CheckIfAnfNodeIsNull(const AnfNodePtr &node) {
  if (node == nullptr) {
    MS_LOG(ERROR) << "The AnfNode is null.";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return lite::RET_NULL_PTR;
  }
  return lite::RET_OK;
}

// Returns RET_NULL_PTR (and records it globally) when `node` is null, RET_OK otherwise.
int CheckIfCNodeIsNull(const CNodePtr &node) {
  if (node == nullptr) {
    MS_LOG(ERROR) << "The CNode is null.";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return lite::RET_NULL_PTR;
  }
  return lite::RET_OK;
}

// Returns RET_NULL_PTR (and records it globally) when `var` is null, RET_OK otherwise.
int CheckIfVarIsNull(const VarPtr &var) {
  if (var == nullptr) {
    MS_LOG(ERROR) << "The Var is null.";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return lite::RET_NULL_PTR;
  }
  return lite::RET_OK;
}

// Returns RET_INVALID_OP_ATTR when `node` is non-null but not a Parameter; a null node yields RET_OK.
int CheckIfNodeIsParam(const AnfNodePtr &node) {
  if (node != nullptr && !utils::isa<ParameterPtr>(node)) {
    MS_LOG(DEBUG) << "The Node is not param.";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
    return lite::RET_INVALID_OP_ATTR;
  }
  return lite::RET_OK;
}

// Returns RET_OK when `node` has exactly `size` inputs (primitive included), RET_INVALID_OP_ATTR otherwise,
// and RET_NULL_PTR for a null node.
int CheckInputSize(const CNodePtr &node, const int size) {
  if (CheckIfCNodeIsNull(node) != lite::RET_OK) {
    return lite::RET_NULL_PTR;
  }
  if (static_cast<int>(node->inputs().size()) != size) {
    MS_LOG(ERROR) << "The input size of node must be " << size << ", but it is " << node->inputs().size();
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
    return lite::RET_INVALID_OP_ATTR;
  }
  return lite::RET_OK;
}

// Returns RET_OK when `node` has at least `size` inputs (primitive included), RET_INVALID_OP_ATTR otherwise,
// and RET_NULL_PTR for a null node.
int CheckLeastInputSize(const CNodePtr &node, const int size) {
  if (CheckIfCNodeIsNull(node) != lite::RET_OK) {
    return lite::RET_NULL_PTR;
  }
  if (static_cast<int>(node->inputs().size()) < size) {
    MS_LOG(ERROR) << "The input size of node must be at least " << size << ", but it is " << node->inputs().size();
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
    return lite::RET_INVALID_OP_ATTR;
  }
  return lite::RET_OK;
}

// Adds a 1-D parameter of length `kernel_num` to `func_graph` holding `bias_data` (kernel_num floats),
// copying format and data type from `weight_tensor`. Returns nullptr if the graph or weight tensor is null.
ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num,
                            const ParamValueLitePtr &weight_tensor) {
  if (func_graph == nullptr || weight_tensor == nullptr) {
    MS_LOG(ERROR) << "func_graph or weight_tensor is nullptr";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  auto bias_parameter = func_graph->add_parameter();
  MS_ASSERT(bias_parameter != nullptr);
  std::vector<int> shape = {kernel_num};
  std::vector<int64_t> shape_vector = {static_cast<int64_t>(kernel_num)};
  auto abstract_tensor =
    std::make_shared<abstract::AbstractTensor>(TypeIdToType(weight_tensor->tensor_type()), shape_vector);
  bias_parameter->set_abstract(abstract_tensor);

  ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
  MS_ASSERT(param_value != nullptr);
  param_value->SetTensorData(bias_data, static_cast<size_t>(kernel_num) * sizeof(float));
  param_value->set_format(weight_tensor->format());
  param_value->set_tensor_type(weight_tensor->tensor_type());
  param_value->set_tensor_shape(shape);
  bias_parameter->set_default_param(param_value);
  return bias_parameter;
}

// Returns the lite schema type of a CNode (via its primitive input) or of a primitive value node;
// PrimitiveType_NONE for anything else, including plain non-lite primitives.
schema::PrimitiveType GetCNodeType(const BaseRef &n) {
  ValueNodePtr value_node;
  if (utils::isa<CNodePtr>(n)) {
    auto in = utils::cast<CNodePtr>(n);
    // Guard against an empty CNode; a missing primitive falls through to the null value_node path below.
    if (in != nullptr && !in->inputs().empty() && in->input(0) != nullptr) {
      value_node = in->input(0)->cast<ValueNodePtr>();
    }
  } else if (utils::isa<ValueNodePtr>(n)) {
    value_node = utils::cast<ValueNodePtr>(n);
  } else {
    MS_LOG(INFO) << "only value node or cnode has type";
    return schema::PrimitiveType_NONE;
  }
  if (value_node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return schema::PrimitiveType_NONE;
  }
  auto value = value_node->value();
  if (value == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return schema::PrimitiveType_NONE;
  }
  if (utils::isa<PrimitiveCPtr>(value)) {
    auto primitive = value->cast<PrimitiveCPtr>();
    MS_ASSERT(primitive != nullptr);
    return static_cast<schema::PrimitiveType>(primitive->Type());
  }
  if (utils::isa<Primitive>(value)) {
    auto primitive = value->cast<PrimitivePtr>();
    MS_ASSERT(primitive != nullptr);
    MS_LOG(INFO) << "anf primitive node type:" << primitive->name();
    return schema::PrimitiveType_NONE;
  }
  return schema::PrimitiveType_NONE;
}

// Returns the ParamValueLite stored as a Parameter's default or as a ValueNode's value; nullptr otherwise.
ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node) {
  MS_ASSERT(node != nullptr);
  if (!utils::isa<ParameterPtr>(node)) {
    if (utils::isa<ValueNodePtr>(node)) {
      auto value_node = node->cast<ValueNodePtr>();
      auto value = std::dynamic_pointer_cast<ParamValueLite>(value_node->value());
      if (value != nullptr) {
        return value;
      }
    }
    MS_LOG(DEBUG) << "get lite param value node neither parameternode or valuenode";
    return nullptr;
  }
  auto param = node->cast<ParameterPtr>();
  MS_ASSERT(param != nullptr);
  return std::dynamic_pointer_cast<ParamValueLite>(param->default_param());
}

// Returns the abstract of input `index` (1-based; 0 is the primitive) of `cnode`. For a TupleGetItem
// input, the selected element of the producer's AbstractTuple is returned. nullptr on any failure.
AbstractBasePtr GetCNodeInputAbstract(const CNodePtr &cnode, size_t index) {
  if (cnode == nullptr) {
    MS_LOG(ERROR) << "CNodePtr is nullptr";
    return nullptr;
  }
  const auto &inputs = cnode->inputs();
  if (!(0 < index && index < inputs.size())) {
    return nullptr;
  }
  auto input = inputs[index];
  if (input == nullptr) {
    MS_LOG(ERROR) << "CNode input is nullptr";
    return nullptr;
  }

  AbstractBasePtr abstract = nullptr;
  if (utils::isa<ParameterPtr>(input)) {
    auto parameter = input->cast<ParameterPtr>();
    abstract = parameter->abstract();
  } else if (utils::isa<CNodePtr>(input)) {
    auto input_cnode = input->cast<CNodePtr>();
    if (GetCNodeType(input_cnode) == schema::PrimitiveType_TupleGetItem) {
      const auto &tuple_inputs = input_cnode->inputs();
      if (tuple_inputs.size() != kTupleGetItemInputSize) {
        MS_LOG(ERROR) << "TupleGetItem's input size is invalid";
        return nullptr;
      }
      auto get_item_input_cnode = tuple_inputs.at(1);
      if (get_item_input_cnode == nullptr) {
        MS_LOG(ERROR) << "TupleGetItem's input is nullptr";
        return nullptr;
      }
      auto idx = GetTupleGetItemOutIndex(input_cnode);
      if (!utils::isa<abstract::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
        MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple";
        return nullptr;
      }
      auto abstract_tuple = utils::cast<abstract::AbstractTuplePtr>(get_item_input_cnode->abstract());
      const auto &abstract_list = abstract_tuple->elements();
      // Also rejects the all-ones error value returned by GetTupleGetItemOutIndex.
      if (abstract_list.size() <= idx) {
        MS_LOG(ERROR) << "AbstractTuple's size is smaller than expect";
        return nullptr;
      }
      abstract = abstract_list[idx];
    } else {
      abstract = input_cnode->abstract();
    }
  } else {
    MS_LOG(ERROR) << "unsupported input node type";
    return nullptr;
  }
  return abstract;
}

// Returns whether `n` is a Parameter whose default value is a ParamValueLite with data attached.
bool IsParamNode(const BaseRef &n) {
  if (!utils::isa<ParameterPtr>(n)) {
    return false;
  }
  auto param = utils::cast<ParameterPtr>(n)->default_param();
  auto tensor = std::dynamic_pointer_cast<ParamValueLite>(param);
  if (tensor == nullptr) {
    return false;
  }
  return tensor->tensor_addr() != nullptr;
}

// Returns whether `n` is a Conv2D, DepthwiseConv2D or DeConv2D node.
bool IsConvNode(const BaseRef &n) {
  if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
    auto type = opt::GetCNodeType(n);
    return type == schema::PrimitiveType_Conv2D || type == schema::PrimitiveType_DepthwiseConv2D ||
           type == schema::PrimitiveType_DeConv2D;
  }
  return false;
}

// Returns whether `n` is a Pooling node.
bool IsPoolingNode(const BaseRef &n) {
  if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
    return opt::GetCNodeType(n) == schema::PrimitiveType_Pooling;
  }
  return false;
}

// Returns whether `n` is an Activation node.
bool IsActivationNode(const BaseRef &n) {
  if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
    return opt::GetCNodeType(n) == schema::PrimitiveType_Activation;
  }
  return false;
}

// Returns whether `n` is a QuantDTypeCast node.
bool IsQuantNode(const BaseRef &n) {
  if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
    return opt::GetCNodeType(n) == schema::PrimitiveType_QuantDTypeCast;
  }
  return false;
}

// Returns whether `node` is a CNode whose data inputs (index >= 1) are all Parameters or ValueNodes.
bool CheckIsAllInputsParam(const AnfNodePtr &node) {
  if (node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  if (node->isa<CNode>()) {
    auto cnode = node->cast<CNodePtr>();
    for (size_t i = 1; i < cnode->inputs().size(); i++) {
      if (!utils::isa<Parameter>(cnode->input(i)) && !utils::isa<ValueNodePtr>(cnode->input(i))) {
        return false;
      }
    }
    return true;
  }
  return false;
}

// Returns the number of tensors `node` outputs: tuple size for tuple types, 0 for None, 1 otherwise
// (including an unknown type); 0 with RET_NULL_PTR for a null node.
size_t GetOutputTensorNum(const AnfNodePtr &node) {
  if (node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return 0;
  }
  auto type = node->Type();
  if (type == nullptr) {
    return 1;
  }
  if (type->isa<Tuple>()) {
    auto tuple_type = type->cast<TuplePtr>();
    if (tuple_type == nullptr) {
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
      return 0;
    }
    return tuple_type->size();
  }
  if (type->isa<TensorType>() || type->isa<Number>()) {
    return 1;
  }
  if (type->isa<TypeNone>()) {
    return 0;
  }
  return 1;
}

// Returns whether `node` has anything other than exactly one user in `graph`.
bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  if (node == nullptr || graph == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  auto output_node_list = GetRealNodeUsedList(graph, node);
  if (output_node_list == nullptr) {
    MS_LOG(ERROR) << "output node list is nullptr";
    return false;
  }
  if (output_node_list->size() != 1) {
    MS_LOG(DEBUG) << "fusion node has multi output nodes";
    return true;
  }
  return false;
}

// Returns (user, input-index) pairs for every user of `node` recorded in `graph`'s manager;
// nullptr if any argument/manager is null or the node has no users entry.
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
                                                                             const AnfNodePtr &node) {
  if (graph == nullptr || node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  auto manager = graph->manager();
  if (manager == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  auto iter = manager->node_users().find(node);
  if (iter == manager->node_users().end()) {
    MS_LOG(ERROR) << "node has no output in manager";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
    return nullptr;
  }
  const auto &output_info_list = iter->second;
  return std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>(output_info_list.begin(),
                                                                    output_info_list.end());
}

// Returns the constant output index selected by a TupleGetItem node, or static_cast<size_t>(-1) when the
// node is malformed (wrong input count, non-value index input, or an empty index value).
size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
  MS_ASSERT(tuple_get_item != nullptr);
  if (tuple_get_item->size() != kTupleGetItemInputSize) {
    MS_LOG(ERROR) << "The node tuple_get_item must have 2 inputs!";
    return static_cast<size_t>(-1);
  }
  auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
  if (output_index_value_node == nullptr) {
    MS_LOG(ERROR) << "The output index of tuple_get_item is nullptr";
    return static_cast<size_t>(-1);
  }
  auto value_node = output_index_value_node->cast<ValueNodePtr>();
  if (value_node == nullptr) {
    MS_LOG(ERROR) << "The output index of tuple_get_item is not a value node";
    return static_cast<size_t>(-1);
  }
  auto indices = lite::CastToInt(value_node->value());
  // front() on an empty vector is undefined behaviour; report it as a malformed node instead.
  if (indices.empty()) {
    MS_LOG(ERROR) << "The output index of tuple_get_item is empty";
    return static_cast<size_t>(-1);
  }
  return IntToSize(indices.front());
}

// Returns the users of `node` that consume output `output_index`, resolving TupleGetItem users by their index.
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
                                                                                         const AnfNodePtr &node,
                                                                                         size_t output_index) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(node != nullptr);
  auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
  auto manager = graph->manager();
  MS_ASSERT(manager != nullptr);
  auto iter = manager->node_users().find(node);
  if (iter == manager->node_users().end()) {
    MS_LOG(ERROR) << "node has no output in manager";
    return output_node_list;
  }
  const auto &output_info_list = iter->second;
  for (const auto &output_info : output_info_list) {
    size_t used_output_index;
    if (GetCNodeType(output_info.first) == schema::PrimitiveType_TupleGetItem) {
      used_output_index = GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
    } else if (GetCNodeType(node) == schema::PrimitiveType_TupleGetItem) {
      used_output_index = output_index;
    } else {
      // NOTE(review): both paths return immediately with whatever was collected so far, even when
      // output_index == 0; preserved as-is — confirm whether single-output users should be collected.
      if (output_index != 0) {
        MS_LOG(ERROR) << "node has no output in manager";
        return output_node_list;
      }
      return output_node_list;
    }
    if (used_output_index == output_index) {
      output_node_list->push_back(output_info);
    }
  }
  return output_node_list;
}

// Reads K/C/H/W of a 4-D filter from `oriDims`, interpreting the dims according to the source layout
// encoded in `type`. Returns RET_ERROR for an unsupported transform.
STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
                    int32_t *filterH, int32_t *filterW) {
  MS_ASSERT(oriDims.size() == 4);
  // Source-layout group per transform: 1 KCHW, 2 CKHW, 3 HWCK, 4 HWKC, 5 NHWC, 6 CHWK, 7 KHWC.
  static const std::unordered_map<kTransFilterType, int> maps = {
    {kKCHW2HWCK, 1}, {kKCHW2HWKC, 1}, {kKCHW2KHWC, 1}, {kKCHW2CKHW, 1}, {kCKHW2HWCK, 2},
    {kCKHW2HWKC, 2}, {kCKHW2KHWC, 2}, {kHWCK2KCHW, 3}, {kHWCK2CKHW, 3}, {kHWCK2KHWC, 3},
    {kHWKC2KCHW, 4}, {kHWKC2CKHW, 4}, {kHWKC2KHWC, 4}, {kNHWC2KCHW, 5}, {kNHWC2HWCK, 5},
    {kNHWC2CKHW, 5}, {kCHWK2HWCK, 6}, {kCHWK2KHWC, 6}, {kKHWC2HWCK, 7}, {kKHWC2CHWK, 7},
  };
  auto iter = maps.find(type);
  if (iter == maps.end()) {
    MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
    return RET_ERROR;
  }

  switch (iter->second) {
    case 1:
      *filterK = oriDims.at(lite::KCHW_K);
      *filterC = oriDims.at(lite::KCHW_C);
      *filterH = oriDims.at(lite::KCHW_H);
      *filterW = oriDims.at(lite::KCHW_W);
      break;
    case 2:
      *filterC = oriDims.at(lite::CKHW_C);
      *filterK = oriDims.at(lite::CKHW_K);
      *filterH = oriDims.at(lite::CKHW_H);
      *filterW = oriDims.at(lite::CKHW_W);
      break;
    case 3:
      *filterH = oriDims.at(lite::HWCK_H);
      *filterW = oriDims.at(lite::HWCK_W);
      *filterC = oriDims.at(lite::HWCK_C);
      *filterK = oriDims.at(lite::HWCK_K);
      break;
    case 4:
      *filterH = oriDims.at(lite::HWKC_H);
      *filterW = oriDims.at(lite::HWKC_W);
      *filterK = oriDims.at(lite::HWKC_K);
      *filterC = oriDims.at(lite::HWKC_C);
      break;
    case 5:
      *filterK = oriDims.at(lite::NHWC_N);
      *filterH = oriDims.at(lite::NHWC_H);
      *filterW = oriDims.at(lite::NHWC_W);
      *filterC = oriDims.at(lite::NHWC_C);
      break;
    case 6:
      *filterC = oriDims.at(lite::CHWK_C);
      *filterH = oriDims.at(lite::CHWK_H);
      *filterW = oriDims.at(lite::CHWK_W);
      *filterK = oriDims.at(lite::CHWK_K);
      break;
    case 7:
      *filterK = oriDims.at(lite::KHWC_K);
      *filterH = oriDims.at(lite::KHWC_H);
      *filterW = oriDims.at(lite::KHWC_W);
      *filterC = oriDims.at(lite::KHWC_C);
      break;
    default:
      MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
      return RET_ERROR;
  }
  return RET_OK;
}

// Writes the destination-layout shape implied by `type` into `tensor`. Returns RET_ERROR for an
// unsupported transform.
STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
                    int32_t filterH, int32_t filterW) {
  MS_ASSERT(tensor != nullptr);
  // Destination-layout group per transform: 1 HWCK, 2 HWKC, 3 KCHW, 4 CKHW, 5 CHWK, 6 KHWC.
  static const std::unordered_map<kTransFilterType, int> maps = {
    {kKCHW2HWCK, 1}, {kCKHW2HWCK, 1}, {kNHWC2HWCK, 1}, {kKHWC2HWCK, 1}, {kCHWK2HWCK, 1},
    {kKCHW2HWKC, 2}, {kCKHW2HWKC, 2}, {kHWCK2KCHW, 3}, {kHWKC2KCHW, 3}, {kNHWC2KCHW, 3},
    {kHWCK2CKHW, 4}, {kHWKC2CKHW, 4}, {kNHWC2CKHW, 4}, {kKCHW2CKHW, 4}, {kKHWC2CHWK, 5},
    {kKCHW2KHWC, 6}, {kCKHW2KHWC, 6}, {kCHWK2KHWC, 6}, {kHWCK2KHWC, 6}, {kHWKC2KHWC, 6},
  };
  auto iter = maps.find(type);
  if (iter == maps.end()) {
    MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
    return RET_ERROR;
  }

  switch (iter->second) {
    case 1:
      tensor->set_tensor_shape({filterH, filterW, filterC, filterK});
      break;
    case 2:
      tensor->set_tensor_shape({filterH, filterW, filterK, filterC});
      break;
    case 3:
      tensor->set_tensor_shape({filterK, filterC, filterH, filterW});
      break;
    case 4:
      tensor->set_tensor_shape({filterC, filterK, filterH, filterW});
      break;
    case 5:
      tensor->set_tensor_shape({filterC, filterH, filterW, filterK});
      break;
    case 6:
      tensor->set_tensor_shape({filterK, filterH, filterW, filterC});
      break;
    default:
      MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
      return RET_ERROR;
  }
  return RET_OK;
}

// Transposes CHWK data in `weightData` into `buf` as HWCK or KHWC. p1Buff/p2Buff are scratch cursors.
template <typename T>
void TransFilterDataCHWK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
                         T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
  for (int c = 0; c < filterC; ++c) {
    for (int h = 0; h < filterH; ++h) {
      for (int w = 0; w < filterW; ++w) {
        for (int k = 0; k < filterK; ++k) {
          p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k));
          if (type == kCHWK2HWCK) {
            p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
          } else if (type == kCHWK2KHWC) {
            p2Buff = buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
          }
          *p2Buff = *p1Buff;
        }
      }
    }
  }
}

// Transposes KHWC data in `weightData` into `buf` as HWCK (the only transform routed here).
template <typename T>
void TransFilterDataKHWC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
                         T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
  for (int k = 0; k < filterK; ++k) {
    for (int h = 0; h < filterH; ++h) {
      for (int w = 0; w < filterW; ++w) {
        for (int c = 0; c < filterC; ++c) {
          p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
          p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
          *p2Buff = *p1Buff;
        }
      }
    }
  }
}

// Transposes KCHW data in `weightData` into `buf` as HWCK, KHWC, CKHW, or (otherwise) HWKC.
template <typename T>
void TransFilterDataKCHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
                         T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
  for (int k = 0; k < filterK; ++k) {
    for (int c = 0; c < filterC; ++c) {
      for (int h = 0; h < filterH; ++h) {
        for (int w = 0; w < filterW; ++w) {
          p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
          if (type == kKCHW2HWCK) {
            p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
          } else if (type == kKCHW2KHWC) {
            p2Buff = buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
          } else if (type == kKCHW2CKHW) {
            p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
          } else {
            p2Buff = buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
          }
          *p2Buff = *p1Buff;
        }
      }
    }
  }
}

// Transposes CKHW data in `weightData` into `buf` as HWCK, KHWC, or (otherwise) HWKC.
template <typename T>
void TransFilterDataCKHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
                         T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
  for (int c = 0; c < filterC; ++c) {
    for (int k = 0; k < filterK; ++k) {
      for (int h = 0; h < filterH; ++h) {
        for (int w = 0; w < filterW; ++w) {
          p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
          if (type == kCKHW2HWCK) {
            p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
          } else if (type == kCKHW2KHWC) {
            p2Buff = buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
          } else {
            p2Buff = buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
          }
          *p2Buff = *p1Buff;
        }
      }
    }
  }
}

// Transposes HWCK data in `weightData` into `buf` as KCHW, CKHW, or (otherwise) KHWC.
template <typename T>
void TransFilterDataHWCK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
                         T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
  for (int h = 0; h < filterH; ++h) {
    for (int w = 0; w < filterW; ++w) {
      for (int c = 0; c < filterC; ++c) {
        for (int k = 0; k < filterK; ++k) {
          p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
          if (type == kHWCK2KCHW) {
            p2Buff = buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
          } else if (type == kHWCK2CKHW) {
            p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
          } else {
            p2Buff = buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
          }
          *p2Buff = *p1Buff;
        }
      }
    }
  }
}

// Transposes HWKC data in `weightData` into `buf` as KCHW, KHWC, or (otherwise) CKHW.
template <typename T>
void TransFilterDataHWKC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
                         T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
  for (int h = 0; h < filterH; ++h) {
    for (int w = 0; w < filterW; ++w) {
      for (int c = 0; c < filterC; ++c) {
        for (int k = 0; k < filterK; ++k) {
          p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
          if (type == kHWKC2KCHW) {
            p2Buff = buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
          } else if (type == kHWKC2KHWC) {
            // kHWKC2KHWC is routed here and SetFilterDim labels the result KHWC; the original fell into the
            // CKHW branch and wrote CKHW-ordered data under a KHWC shape.
            p2Buff = buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
          } else {
            p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
          }
          *p2Buff = *p1Buff;
        }
      }
    }
  }
}

// Transposes NHWC (N == K) data in `weightData` into `buf` as HWCK, CKHW, or (otherwise) KCHW.
template <typename T>
void TransFilterDataNHWC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
                         T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
  for (int k = 0; k < filterK; ++k) {
    for (int h = 0; h < filterH; ++h) {
      for (int w = 0; w < filterW; ++w) {
        for (int c = 0; c < filterC; ++c) {
          // GetFilterDim reads NHWC dims as (K, H, W, C), so the row-major source offset is KHWC. The original
          // used an HWKC offset, which only coincides with this when K == 1.
          p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
          if (type == kNHWC2HWCK) {
            p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
          } else if (type == kNHWC2CKHW) {
            p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
          } else {
            p2Buff = buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
          }
          *p2Buff = *p1Buff;
        }
      }
    }
  }
}

// Transposes KHWC data in `weightData` into `buf` as CHWK.
template <typename T>
void TransFilterDataKHWC2CHWK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH,
                              int32_t filterW, T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
  for (int k = 0; k < filterK; ++k) {
    for (int h = 0; h < filterH; ++h) {
      for (int w = 0; w < filterW; ++w) {
        for (int c = 0; c < filterC; ++c) {
          p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
          p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k));
          *p2Buff = *p1Buff;
        }
      }
    }
  }
}

// Transposes the filter data of `tensor` in place according to `type`, using a temporary buffer of
// K*C*H*W elements of T. Does not touch the tensor's shape. Returns RET_ERROR on any failure.
template <typename T>
static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK,
                              int32_t filterC, int32_t filterH, int32_t filterW) {
  MS_ASSERT(tensor != nullptr);
  int count = filterH * filterW * filterC * filterK;
  if (count <= 0) {
    MS_LOG(ERROR) << "Dim size invalid";
    return RET_ERROR;
  }
  std::unique_ptr<T[]> buf(new (std::nothrow) T[count]);
  if (buf == nullptr) {
    MS_LOG(ERROR) << "new buf failed";
    return RET_ERROR;
  }

  T *weightData = static_cast<T *>(tensor->tensor_addr());
  if (weightData == nullptr) {
    MS_LOG(ERROR) << "weightData is nullptr";
    return RET_ERROR;
  }
  T *p1Buff = nullptr;
  T *p2Buff = nullptr;

  // Source-layout helper per transform: 1 CHWK, 2 KHWC->HWCK, 3 KCHW, 4 CKHW, 5 HWCK, 6 HWKC, 7 NHWC,
  // 8 KHWC->CHWK.
  static const std::unordered_map<kTransFilterType, int> maps = {
    {kCHWK2HWCK, 1}, {kCHWK2KHWC, 1}, {kKHWC2HWCK, 2}, {kKCHW2HWCK, 3}, {kKCHW2CKHW, 3},
    {kKCHW2KHWC, 3}, {kKCHW2HWKC, 3}, {kCKHW2HWCK, 4}, {kCKHW2KHWC, 4}, {kCKHW2HWKC, 4},
    {kHWCK2KCHW, 5}, {kHWCK2CKHW, 5}, {kHWCK2KHWC, 5}, {kHWKC2KCHW, 6}, {kHWKC2KHWC, 6},
    {kHWKC2CKHW, 6}, {kNHWC2HWCK, 7}, {kNHWC2KCHW, 7}, {kNHWC2CKHW, 7}, {kKHWC2CHWK, 8},
  };
  auto iter = maps.find(type);
  if (iter == maps.end()) {
    MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
    return RET_ERROR;
  }
  switch (iter->second) {
    case 1:
      TransFilterDataCHWK(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
      break;
    case 2:
      TransFilterDataKHWC(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
      break;
    case 3:
      TransFilterDataKCHW(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
      break;
    case 4:
      TransFilterDataCKHW(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
      break;
    case 5:
      TransFilterDataHWCK(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
      break;
    case 6:
      TransFilterDataHWKC(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
      break;
    case 7:
      TransFilterDataNHWC(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
      break;
    case 8:
      TransFilterDataKHWC2CHWK(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
      break;
    default:
      MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
      return RET_ERROR;
  }

  auto ret = ::memcpy_s(tensor->tensor_addr(), count * sizeof(T), buf.get(), count * sizeof(T));
  if (ret != EOK) {
    MS_LOG(ERROR) << "memcpy_s failed: " << ret;
    return RET_ERROR;
  }
  return RET_OK;
}

// Converts a 4-D filter tensor of element type T in place according to `type`: data first, then shape.
template <typename T>
static STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type) {
  MS_ASSERT(tensor != nullptr);
  auto oriDims = tensor->tensor_shape();
  if (oriDims.size() != static_cast<size_t>(lite::DIM_DEFAULT_SIZE)) {
    MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
    return lite::RET_ERROR;
  }

  int32_t filterH;
  int32_t filterW;
  int32_t filterC;
  int32_t filterK;
  auto status = GetFilterDim(oriDims, type, &filterK, &filterC, &filterH, &filterW);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "GetFilterDim failed: " << status;
    return status;
  }
  // Transpose the data before rewriting the shape so a data failure leaves shape and data consistent.
  // TransFilterData only uses the dims passed in, not the tensor's current shape.
  status = TransFilterData<T>(tensor, type, filterK, filterC, filterH, filterW);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "TransFilterData failed: " << status;
    return status;
  }
  status = SetFilterDim(tensor, type, filterK, filterC, filterH, filterW);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "SetFilterDim failed: " << status;
    return status;
  }
  return lite::RET_OK;
}

// Dispatches TransFilterFormat<T> on `data_type` (float32, uint8, int8, float16); RET_ERROR otherwise.
STATUS TransFilterFormatWithType(const ParamValueLitePtr &tensor, TypeId data_type,
                                 kTransFilterType trans_filter_type) {
  switch (data_type) {
    case kNumberTypeFloat32:
      return TransFilterFormat<float>(tensor, trans_filter_type);
    case kNumberTypeUInt8:
      return TransFilterFormat<uint8_t>(tensor, trans_filter_type);
    case kNumberTypeInt8:
      return TransFilterFormat<int8_t>(tensor, trans_filter_type);
    case kNumberTypeFloat16:
      return TransFilterFormat<float16>(tensor, trans_filter_type);
    default:
      MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
      return RET_ERROR;
  }
}

// Rewrites a 4-D filter tensor in place from its current format (tensor->format()) to `dst_format`.
STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format) {
  if (tensor == nullptr) {
    return lite::RET_NULL_PTR;
  }
  auto ori_dims = tensor->tensor_shape();
  if (ori_dims.size() != (size_t)lite::DIM_DEFAULT_SIZE) {
    MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << ori_dims.size();
    return lite::RET_ERROR;
  }
  auto src_format = tensor->format();
  auto data_type = tensor->tensor_type();
  lite::STATUS status;
  std::unordered_map<schema::Format, kTransFilterType> khwc_trans_maps = {
    {schema::Format::Format_KCHW, kKCHW2KHWC}, {schema::Format::Format_CKHW, kCKHW2KHWC},
    {schema::Format::Format_CHWK, kCHWK2KHWC}, {schema::Format::Format_HWCK, kHWCK2KHWC},
    {schema::Format::Format_HWKC, kHWKC2KHWC},
  };

  std::unordered_map<schema::Format, kTransFilterType> hwck_trans_maps = {
    {schema::Format::Format_KCHW, kKCHW2HWCK},
    {schema::Format::Format_KHWC, kKHWC2HWCK},
    {schema::Format::Format_CKHW, kCKHW2HWCK},
    {schema::Format::Format_CHWK, kCHWK2HWCK},
  };

  std::unordered_map<schema::Format, kTransFilterType> kchw_trans_maps = {
    {schema::Format::Format_HWCK, kHWCK2KCHW}, {schema::Format::Format_HWKC, kHWKC2KCHW},
    {schema::Format::Format_KHWC, kKHWC2KCHW}, {schema::Format::Format_CKHW, kCKHW2KCHW},
    {schema::Format::Format_CHWK, kCHWK2KCHW},
  };

  std::unordered_map<schema::Format, kTransFilterType> ckhw_trans_maps = {{schema::Format::Format_HWCK, kHWCK2CKHW},
                                                                        {schema::Format::Format_HWKC, kHWKC2CKHW},
                                                                        {schema::Format::Format_KCHW,
kKCHW2CKHW}}; std::unordered_map chwk_trans_maps = {{schema::Format::Format_KHWC, kKHWC2CHWK}}; if (src_format == dst_format) { return RET_OK; } switch (dst_format) { case schema::Format::Format_KHWC: { if (khwc_trans_maps.find(static_cast(src_format)) == khwc_trans_maps.end()) { MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(static_cast(src_format)) << " to " << EnumNameFormat(dst_format); return RET_ERROR; } else { status = TransFilterFormatWithType(tensor, data_type, khwc_trans_maps.find(static_cast(src_format))->second); } } break; case schema::Format::Format_HWCK: { if (hwck_trans_maps.find(static_cast(src_format)) == hwck_trans_maps.end()) { MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(static_cast(src_format)) << " to " << EnumNameFormat(dst_format); return RET_ERROR; } else { status = TransFilterFormatWithType(tensor, data_type, hwck_trans_maps.find(static_cast(src_format))->second); } } break; case schema::Format::Format_KCHW: { if (kchw_trans_maps.find(static_cast(src_format)) == kchw_trans_maps.end()) { MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(static_cast(src_format)) << " to " << EnumNameFormat(dst_format); return RET_ERROR; } else { status = TransFilterFormatWithType(tensor, data_type, kchw_trans_maps.find(static_cast(src_format))->second); } } break; case schema::Format::Format_CKHW: { if (ckhw_trans_maps.find(static_cast(src_format)) == ckhw_trans_maps.end()) { MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(static_cast(src_format)) << " to " << EnumNameFormat(dst_format); return RET_ERROR; } else { status = TransFilterFormatWithType(tensor, data_type, ckhw_trans_maps.find(static_cast(src_format))->second); } } break; case schema::Format::Format_CHWK: { if (chwk_trans_maps.find(static_cast(src_format)) == chwk_trans_maps.end()) { MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(static_cast(src_format)) << " to " << EnumNameFormat(dst_format); return 
RET_ERROR; } else { status = TransFilterFormatWithType(tensor, data_type, chwk_trans_maps.find(static_cast(src_format))->second); } } break; default: MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format); return RET_ERROR; } if (status != RET_OK) { MS_LOG(ERROR) << "TransFilterData failed: " << status; return status; } return RET_OK; } ParameterPtr BuildIntValueParameterNode(const FuncGraphPtr &func_graph, const int32_t &data, const std::string &node_name) { MS_ASSERT(func_graph != nullptr); MS_ASSERT(data.size() != 0); auto param_node = func_graph->add_parameter(); auto type_ptr = TypeIdToType(kNumberTypeInt32); auto abstract_tensor = std::make_shared(type_ptr); param_node->set_abstract(abstract_tensor); param_node->set_name(node_name); ParamValueLitePtr param_value = std::make_shared(); MS_ASSERT(param_value != nullptr); param_value->set_tensor_shape({1}); param_value->set_tensor_type(kNumberTypeInt32); char *default_data = new (std::nothrow) char[sizeof(int32_t)]; *(reinterpret_cast(default_data)) = data; param_value->SetTensorData(default_data, sizeof(int32_t)); param_node->set_default_param(param_value); return param_node; } ParameterPtr BuildIntVecParameterNode(const FuncGraphPtr &func_graph, const std::vector &data, const std::string &node_name) { MS_ASSERT(func_graph != nullptr); MS_ASSERT(data.size() != 0); auto param_node = func_graph->add_parameter(); auto type_ptr = TypeIdToType(kNumberTypeInt32); std::vector shape_vector{static_cast(data.size())}; auto abstract_tensor = std::make_shared(type_ptr, shape_vector); param_node->set_abstract(abstract_tensor); param_node->set_name(node_name); ParamValueLitePtr param_value = std::make_shared(); MS_ASSERT(param_value != nullptr); std::vector shape{static_cast(data.size())}; param_value->set_tensor_shape(shape); param_value->set_tensor_type(kNumberTypeInt32); char *default_data = new (std::nothrow) char[data.size() * sizeof(int32_t)]; if (memcpy_s(default_data, 
data.size() * sizeof(int32_t), data.data(), data.size() * sizeof(int32_t)) != EOK) { MS_LOG(ERROR) << "memcpy data failed."; delete[] default_data; return nullptr; } param_value->SetTensorData(default_data, data.size() * sizeof(int32_t)); param_node->set_default_param(param_value); return param_node; } ParameterPtr BuildIntVec2DParameterNode(const FuncGraphPtr &func_graph, const std::vector> &data, const std::string &node_name) { MS_ASSERT(func_graph != nullptr); MS_ASSERT(data.size() != 0); auto param_node = func_graph->add_parameter(); auto type_ptr = TypeIdToType(kNumberTypeInt32); std::vector shape_vector; shape_vector.push_back(data.size()); shape_vector.push_back(2); auto abstract_tensor = std::make_shared(type_ptr, shape_vector); param_node->set_abstract(abstract_tensor); param_node->set_name(node_name); ParamValueLitePtr param_value = std::make_shared(); MS_ASSERT(param_value != nullptr); std::vector shape; shape.push_back(data.size()); shape.push_back(2); param_value->set_tensor_shape(shape); param_value->set_tensor_type(kNumberTypeInt32); std::vector data_1d; for (auto pair : data) { data_1d.insert(data_1d.end(), pair.begin(), pair.end()); } auto size = data_1d.size() * sizeof(int32_t); char *default_data = new (std::nothrow) char[size]; if (memcpy_s(default_data, size, data_1d.data(), size) != EOK) { MS_LOG(ERROR) << "memcpy data failed."; delete[] default_data; return nullptr; } param_value->SetTensorData(default_data, size); param_node->set_default_param(param_value); return param_node; } ParameterPtr BuildFloatValueParameterNode(const FuncGraphPtr &func_graph, const float &data, const std::string &node_name) { MS_ASSERT(func_graph != nullptr); MS_ASSERT(data.size() != 0); auto param_node = func_graph->add_parameter(); auto type_ptr = TypeIdToType(kNumberTypeFloat32); std::vector shape_vector = {1}; auto abstract_tensor = std::make_shared(type_ptr, shape_vector); param_node->set_abstract(abstract_tensor); param_node->set_name(node_name); 
ParamValueLitePtr param_value = std::make_shared(); MS_ASSERT(param_value != nullptr); param_value->set_tensor_shape({1}); param_value->set_tensor_type(kNumberTypeFloat32); char *default_data = new (std::nothrow) char[sizeof(float)]; if (memcpy_s(default_data, sizeof(float), &data, sizeof(float)) != EOK) { MS_LOG(ERROR) << "memcpy data failed."; delete[] default_data; return nullptr; } param_value->SetTensorData(default_data, sizeof(float)); param_node->set_default_param(param_value); return param_node; } } // namespace opt } // namespace mindspore