/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// NOTE(review): template argument lists and standard header names in this file were
// reconstructed from how each value is used after being lost in an earlier copy;
// verify against upstream MindSpore if a build fails.

#include "tools/optimizer/common/gllo_utils.h"
#include <algorithm>
#include <cstdint>
#include <memory>
#include <new>
#include <vector>
#include "src/ops/primitive_c.h"
#include "src/common/common.h"
#include "frontend/operator/ops.h"
#include "backend/optimizer/common/helper.h"

namespace mindspore {
namespace opt {
namespace {
constexpr auto kAnfPrimitiveIndex = 0;

// Returns true when `node` is a CNode whose input 0 is the primitive `primitive_type`.
// A null node, or a CNode that fails to cast, records RET_NULL_PTR and yields false.
bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
  if (node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  if (!node->isa<CNode>()) {
    return false;
  }
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
}

// Returns false for CNodes whose primitive is one of the "virtual" ops listed below
// (summaries, MakeTuple, StateSetItem, Depend, TupleGetItem, ControlDepend, Return,
// Partial) and true for every other CNode.
bool IsRealKernel(const AnfNodePtr &node) {
  if (node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  // Non-CNodes (e.g. parameters and value nodes) are reported as real kernels here;
  // IsRealCNodeKernel rejects them before calling this function.
  if (!node->isa<CNode>()) {
    return true;
  }
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  if (cnode->inputs().empty()) {
    MS_LOG(ERROR) << "Illegal null input of cnode: " << node->DebugString();
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INPUT_TENSOR_ERROR);
    return false;
  }
  auto input = cnode->inputs()[kAnfPrimitiveIndex];
  bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
                         IsPrimitive(input, prim::kPrimTensorSummary) ||
                         IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
                         IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
                         IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
                         IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
  return !is_virtual_node;
}

// Wraps a VectorRef / ValuePtr / int / float pattern element in a ValueNode.
// Returns nullptr for any other element kind.
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
  if (utils::isa<VectorRef>(sexp)) {
    return NewValueNode(utils::cast<VectorRef>(sexp));
  }
  if (utils::isa<ValuePtr>(sexp)) {
    return NewValueNode(utils::cast<ValuePtr>(sexp));
  }
  if (utils::isa<int>(sexp)) {
    return NewValueNode(utils::cast<int>(sexp));
  }
  if (utils::isa<float>(sexp)) {
    return NewValueNode(utils::cast<float>(sexp));
  }
  return nullptr;
}

// Builds a CNode over `input_nodes` owned by `graph`, which may be a FuncGraphPtr or a
// VarPtr placeholder. Returns nullptr when `graph` is neither.
CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
  if (utils::isa<FuncGraphPtr>(graph)) {
    return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
  }
  if (utils::isa<VarPtr>(graph)) {
    return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
  }
  return nullptr;
}

// Wraps the Var `sexp` in a VarNode. When `graph` is a VarPtr the node gets no owning
// graph; when it is a FuncGraphPtr that graph owns it. Any other `graph` yields nullptr.
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
  if (utils::isa<VarPtr>(graph)) {
    MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
    return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
  }
  if (utils::isa<FuncGraphPtr>(graph)) {
    MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
    return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
  }
  MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
  return nullptr;
}

// Converts a VectorRef pattern (primitive followed by inputs) into a CNode by
// recursively converting every element with SexpToNode.
// With `multigraph` and a Var `graph`, children are built under a fresh Var("G").
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
                            bool multigraph) {
  MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  std::vector<AnfNodePtr> input_nodes;
  const auto &tuple = utils::cast<VectorRef>(sexp);
  if (multigraph && utils::isa<VarPtr>(graph)) {
    for (auto &x : tuple) {
      AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
      input_nodes.push_back(node);
    }
    VarPtr var_ptr = utils::cast<VarPtr>(graph);
    return std::make_shared<CNode>(input_nodes, var_ptr);
  }

  for (auto &x : tuple) {
    AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
    input_nodes.push_back(node);
  }
  return CreateCNodeWithGraph(input_nodes, graph);
}
}  // namespace

// Pattern-matcher equality for two BaseRefs:
//  - two primitive value nodes are equal when their lite PrimitiveC Type() matches;
//  - two other value nodes compare their values (PrimitiveC via operator==);
//  - two PrimitiveC refs compare Type();
//  - anything else falls back to BaseRef::operator==.
// Null intermediates record RET_NULL_PTR and yield false.
bool AnfEqual(const BaseRef &a, const BaseRef &b) {
  if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
    auto a_node = utils::cast<AnfNodePtr>(a);
    auto b_node = utils::cast<AnfNodePtr>(b);
    if (a_node == nullptr || b_node == nullptr) {
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
      return false;
    }
    if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
      auto a_value_node = a_node->cast<ValueNodePtr>();
      auto b_value_node = b_node->cast<ValueNodePtr>();
      if (a_value_node == nullptr || b_value_node == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      auto a_value = a_value_node->value();
      auto b_value = b_value_node->value();
      if (a_value == nullptr || b_value == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      auto a_prim = a_value->cast<PrimitivePtr>();
      auto b_prim = b_value->cast<PrimitivePtr>();
      if (a_prim == nullptr || b_prim == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      // A plain (non-lite) Primitive does not cast to PrimitiveC; guard instead of
      // dereferencing a null result.
      auto a_prim_c = a_prim->cast<PrimitiveCPtr>();
      auto b_prim_c = b_prim->cast<PrimitiveCPtr>();
      if (a_prim_c == nullptr || b_prim_c == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      return a_prim_c->Type() == b_prim_c->Type();
    } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
      auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
      auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
      if (a_value_node_ptr == nullptr || b_value_node_ptr == nullptr) {
        MS_LOG(ERROR) << "cast value node ptr fail";
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      auto a_value_ptr = a_value_node_ptr->value();
      auto b_value_ptr = b_value_node_ptr->value();
      if (a_value_ptr == nullptr || b_value_ptr == nullptr) {
        MS_LOG(ERROR) << "value ptr is nullptr";
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      if (utils::isa<lite::PrimitiveC>(a_value_ptr) && utils::isa<lite::PrimitiveC>(b_value_ptr)) {
        auto a_obj = (lite::PrimitiveC *)(a_value_ptr.get());
        auto b_obj = (lite::PrimitiveC *)(b_value_ptr.get());
        return (*a_obj) == (*b_obj);
      }
      return (*a_value_ptr) == (*b_value_ptr);
    }
  }
  if (a.m_ptr->isa<lite::PrimitiveC>() && b.m_ptr->isa<lite::PrimitiveC>()) {
    auto a_value_node_ptr = a.m_ptr->cast<PrimitiveCPtr>();
    auto b_value_node_ptr = b.m_ptr->cast<PrimitiveCPtr>();
    return a_value_node_ptr->Type() == b_value_node_ptr->Type();
  }

  return a == b;
}

// Type equality used by the pattern matcher: any two CNodes match each other,
// otherwise the BaseRef type()s must be equal.
bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
  // To match CNode and Kernel's type
  if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
    return true;
  }
  return a.type() == b.type();
}

// Converts one pattern s-expression element into an ANF node:
//  - VectorRef  -> CNode (via HandleSexpVector);
//  - Var        -> its primitive's ValueNode (recording the Var in *primitive_vars),
//                  or a VarNode;
//  - AnfNodePtr -> itself;
//  - otherwise  -> a ValueNode via CreateValueNodeWithSexp.
// Returns nullptr (and records RET_NULL_PTR) on failure.
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
  MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  if (primitive_vars == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  if (utils::isa<VectorRef>(sexp)) {
    return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
  }
  if (utils::isa<VarPtr>(sexp)) {
    auto var_ptr = utils::cast<VarPtr>(sexp);
    if (var_ptr == nullptr) {
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
      return nullptr;
    }
    if (var_ptr->primitive()) {
      (*primitive_vars)[var_ptr->primitive()] = var_ptr;
      return NewValueNode(var_ptr->primitive());
    }
    return CreateVarNodeWithSexp(sexp, graph);
  }
  if (utils::isa<AnfNodePtr>(sexp)) {
    return utils::cast<AnfNodePtr>(sexp);
  }
  auto value_node = CreateValueNodeWithSexp(sexp);
  if (value_node == nullptr) {
    MS_LOG(ERROR) << "sexp cannot converted. sexp: " << sexp.ToString();
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  return value_node;
}

// True for CNodes that are Return or that IsRealKernel accepts; false for non-CNodes.
bool IsRealCNodeKernel(const AnfNodePtr &node) {
  if (node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  // Parameters and value nodes are not real CNode kernels.
  if (!node->isa<CNode>()) {
    return false;
  }
  // Return is considered a real node.
  if (CheckPrimitiveType(node, prim::kPrimReturn)) {
    return true;
  }
  return IsRealKernel(node);
}

// True when `node` is a real CNode kernel whose input 0 is a FuncGraph carrying
// the FUNC_GRAPH_ATTR_GRAPH_KERNEL attribute.
bool IsGraphKernel(const AnfNodePtr &node) {
  if (node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  // A graph kernel must be a real CNode kernel.
  if (!IsRealCNodeKernel(node)) {
    return false;
  }
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  auto input = cnode->input(kAnfPrimitiveIndex);
  // A graph kernel has a func_graph as its first input.
  if (!IsValueNode<FuncGraph>(input)) {
    return false;
  }
  auto func_graph = GetValueNode<FuncGraphPtr>(input);
  if (func_graph == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return false;
  }
  return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
}

// Returns RET_NULL_PTR (logging and recording it) when `graph` is null, else RET_OK.
int CheckIfFuncGraphIsNull(const FuncGraphPtr &graph) {
  if (graph == nullptr) {
    MS_LOG(ERROR) << "The graph is null.";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return lite::RET_NULL_PTR;
  }
  return lite::RET_OK;
}

// Returns RET_NULL_PTR (logging and recording it) when `node` is null, else RET_OK.
int CheckIfAnfNodeIsNull(const AnfNodePtr &node) {
  if (node == nullptr) {
    MS_LOG(ERROR) << "The AnfNode is null.";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return lite::RET_NULL_PTR;
  }
  return lite::RET_OK;
}

// Returns RET_NULL_PTR (logging and recording it) when `node` is null, else RET_OK.
int CheckIfCNodeIsNull(const CNodePtr &node) {
  if (node == nullptr) {
    MS_LOG(ERROR) << "The CNode is null.";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return lite::RET_NULL_PTR;
  }
  return lite::RET_OK;
}

// Returns RET_NULL_PTR (logging and recording it) when `var` is null, else RET_OK.
int CheckIfVarIsNull(const VarPtr &var) {
  if (var == nullptr) {
    MS_LOG(ERROR) << "The Var is null.";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return lite::RET_NULL_PTR;
  }
  return lite::RET_OK;
}

// Returns RET_INVALID_OP_ATTR when a non-null `node` is not a Parameter.
// A null node is accepted (RET_OK).
int CheckIfNodeIsParam(const AnfNodePtr &node) {
  if (node != nullptr && !utils::isa<ParameterPtr>(node)) {
    MS_LOG(ERROR) << "The Node is not param.";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
    return lite::RET_INVALID_OP_ATTR;
  }
  return lite::RET_OK;
}

// Returns RET_OK when `node` has exactly `size` inputs (primitive included),
// RET_INVALID_OP_ATTR otherwise, and RET_NULL_PTR for a null node.
int CheckInputSize(const CNodePtr &node, const int size) {
  if (CheckIfCNodeIsNull(node) != lite::RET_OK) {
    return lite::RET_NULL_PTR;
  }
  if (static_cast<int>(node->inputs().size()) != size) {
    MS_LOG(ERROR) << "The input size of node must be " << size << ", but it is " << node->inputs().size();
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
    return lite::RET_INVALID_OP_ATTR;
  }
  return lite::RET_OK;
}

// Returns RET_OK when `node` has at least `size` inputs (primitive included),
// RET_INVALID_OP_ATTR otherwise, and RET_NULL_PTR for a null node.
int CheckLeastInputSize(const CNodePtr &node, const int size) {
  if (CheckIfCNodeIsNull(node) != lite::RET_OK) {
    return lite::RET_NULL_PTR;
  }
  if (static_cast<int>(node->inputs().size()) < size) {
    MS_LOG(ERROR) << "The input size of node must be at least " << size << ", but it is " << node->inputs().size();
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
    return lite::RET_INVALID_OP_ATTR;
  }
  return lite::RET_OK;
}

// Adds a new 1-D parameter of length `kernel_num` to `func_graph`.
// Its value points at `bias_data` (the pointer is stored, not copied; ownership
// semantics are those of ParamValueLite::set_tensor_addr — confirm). Format and
// tensor type are copied from `weight_tensor`.
// NOTE(review): the byte size assumes float elements even though the tensor type is
// taken from weight_tensor; confirm callers only pass float weights.
ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num,
                            const ParamValueLitePtr &weight_tensor) {
  auto bias_parameter = func_graph->add_parameter();
  MS_ASSERT(bias_parameter != nullptr);
  std::vector<int> shape = {kernel_num};
  auto abstract_tensor =
    std::make_shared<abstract::AbstractTensor>(TypeIdToType(weight_tensor->tensor_type()), shape);
  bias_parameter->set_abstract(abstract_tensor);

  ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
  MS_ASSERT(param_value != nullptr);
  param_value->set_tensor_addr(bias_data);
  param_value->set_tensor_size(kernel_num * sizeof(float) / sizeof(uint8_t));
  param_value->set_format(weight_tensor->format());
  param_value->set_tensor_type(weight_tensor->tensor_type());
  param_value->set_tensor_shape(shape);
  bias_parameter->set_default_param(param_value);
  return bias_parameter;
}

// Returns the lite schema primitive type of a CNode (via its input 0) or a ValueNode.
// Returns PrimitiveType_NONE for other nodes, for null value nodes, and for
// non-lite (plain Primitive) values.
schema::PrimitiveType GetCNodeType(const BaseRef &n) {
  ValueNodePtr value_node;
  if (utils::isa<CNodePtr>(n)) {
    auto in = utils::cast<CNodePtr>(n);
    value_node = in->input(0)->cast<ValueNodePtr>();
  } else if (utils::isa<ValueNodePtr>(n)) {
    value_node = utils::cast<ValueNodePtr>(n);
  } else {
    MS_LOG(ERROR) << "only value node or cnode has type";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
    return schema::PrimitiveType_NONE;
  }
  if (value_node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return schema::PrimitiveType_NONE;
  }
  auto value = value_node->value();
  MS_ASSERT(value != nullptr);
  if (utils::isa<PrimitiveCPtr>(value)) {
    auto primitive = value->cast<PrimitiveCPtr>();
    MS_ASSERT(primitive != nullptr);
    return static_cast<schema::PrimitiveType>(primitive->Type());
  }
  if (utils::isa<Primitive>(value)) {
    auto primitive = value->cast<PrimitivePtr>();
    MS_ASSERT(primitive != nullptr);
    MS_LOG(INFO) << "anf primitive node type:" << primitive->name();
  }
  return schema::PrimitiveType_NONE;
}

// Returns the ParamValueLite default value of a Parameter node, or nullptr when
// `node` is not a Parameter or its default is not a ParamValueLite.
ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node) {
  MS_ASSERT(node != nullptr);
  if (!utils::isa<ParameterPtr>(node)) {
    MS_LOG(ERROR) << "get lite param value node must paramter";
    return nullptr;
  }
  auto param = node->cast<ParameterPtr>();
  MS_ASSERT(param != nullptr);
  return std::dynamic_pointer_cast<ParamValueLite>(param->default_param());
}

// True when `n` is a Parameter whose ParamValueLite default holds a non-null data address.
bool IsParamNode(const BaseRef &n) {
  if (!utils::isa<ParameterPtr>(n)) {
    return false;
  }
  auto param = utils::cast<ParameterPtr>(n)->default_param();
  auto tensor = std::dynamic_pointer_cast<ParamValueLite>(param);
  if (tensor == nullptr) {
    return false;
  }
  return tensor->tensor_addr() != nullptr;
}

// True for CNode/ValueNode refs whose type is Conv2D or DepthwiseConv2D.
bool IsConvNode(const BaseRef &n) {
  if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
    auto type = opt::GetCNodeType(n);
    return type == schema::PrimitiveType_Conv2D || type == schema::PrimitiveType_DepthwiseConv2D;
  }
  return false;
}

// True for CNode/ValueNode refs whose type is Pooling.
bool IsPoolingNode(const BaseRef &n) {
  if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
    auto type = opt::GetCNodeType(n);
    return type == schema::PrimitiveType_Pooling;
  }
  return false;
}

// True for CNode/ValueNode refs whose type is QuantDTypeCast.
bool IsQuantNode(const BaseRef &n) {
  if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
    auto type = opt::GetCNodeType(n);
    return type == schema::PrimitiveType_QuantDTypeCast;
  }
  return false;
}

// True when `node` is a CNode and every data input (index >= 1) is a Parameter.
bool CheckIsAllInputsParam(const AnfNodePtr &node) {
  if (!utils::isa<CNode>(node)) {
    return false;
  }
  auto cnode = node->cast<CNodePtr>();
  for (size_t i = 1; i < cnode->inputs().size(); i++) {
    if (!utils::isa<Parameter>(cnode->input(i))) {
      return false;
    }
  }
  return true;
}

// Number of output tensors inferred from the node's Type():
//  - tuple            -> element count;
//  - TypeNone         -> 0;
//  - unknown/other    -> 1;
//  - null node        -> 0 (records RET_NULL_PTR).
size_t GetOutputTensorNum(const AnfNodePtr &node) {
  if (node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return 0;
  }
  auto type = node->Type();
  if (type == nullptr) {
    return 1;
  }
  if (type->isa<Tuple>()) {
    auto tuple_type = type->cast<TuplePtr>();
    if (tuple_type == nullptr) {
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
      return 0;
    }
    return tuple_type->size();
  }
  if (type->isa<TensorType>() || type->isa<Number>()) {
    return 1;
  }
  if (type->isa<TypeNone>()) {
    return 0;
  }
  return 1;
}

// True when `node` has a number of users other than exactly one.
// If the user list cannot be obtained, this logs and returns false.
bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  auto output_node_list = GetRealNodeUsedList(graph, node);
  // GetRealNodeUsedList returns nullptr for a null graph/manager or a node without
  // users; guard instead of dereferencing it.
  if (output_node_list == nullptr) {
    MS_LOG(ERROR) << "output node list is nullptr";
    return false;
  }
  if (output_node_list->size() != 1) {
    MS_LOG(DEBUG) << "fusion node has multi output nodes";
    return true;
  }
  return false;
}

// Returns a copy of the manager's (user, input index) pairs for `node`.
// Returns nullptr (recording the error code) when `graph` or its manager is null,
// or when the manager has no user entry for `node`.
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
                                                                             const AnfNodePtr &node) {
  if (graph == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  auto manager = graph->manager();
  if (manager == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  auto iter = manager->node_users().find(node);
  if (iter == manager->node_users().end()) {
    MS_LOG(ERROR) << "node has no output in manager";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NOT_FIND_OP);
    return nullptr;
  }
  const auto &output_info_list = iter->second;
  return std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>(output_info_list.begin(),
                                                                   output_info_list.end());
}

// Reads the constant output index selected by a TupleGetItem CNode.
// Returns static_cast<size_t>(-1) when the node does not have kTupleGetItemInputSize inputs.
size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
  MS_ASSERT(tuple_get_item != nullptr);
  if (tuple_get_item->size() != kTupleGetItemInputSize) {
    MS_LOG(ERROR) << "The node tuple_get_item must have 2 inputs!";
    return static_cast<size_t>(-1);
  }
  auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
  MS_ASSERT(output_index_value_node != nullptr);
  auto value_node = output_index_value_node->cast<ValueNodePtr>();
  MS_ASSERT(value_node != nullptr);
  return IntToSize(GetValue<int>(value_node->value()));
}

// Collects the users of `node` that consume output `output_index`:
//  - TupleGetItem users are matched by the index they select;
//  - if `node` itself is a TupleGetItem, every user is matched.
// Returns an empty list when the manager has no user entry for `node`.
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
                                                                                        const AnfNodePtr &node,
                                                                                        size_t output_index) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(node != nullptr);
  auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
  auto manager = graph->manager();
  MS_ASSERT(manager != nullptr);
  auto iter = manager->node_users().find(node);
  if (iter == manager->node_users().end()) {
    MS_LOG(ERROR) << "node has no output in manager";
    return output_node_list;
  }
  const auto &output_info_list = iter->second;
  for (const auto &output_info : output_info_list) {
    size_t used_output_index;
    if (GetCNodeType(output_info.first) == schema::PrimitiveType_TupleGetItem) {
      used_output_index = GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
    } else if (GetCNodeType(node) == schema::PrimitiveType_TupleGetItem) {
      used_output_index = output_index;
    } else {
      // NOTE(review): the first user that is not a TupleGetItem (of a node that is not
      // a TupleGetItem) ends the scan and returns what was collected so far; confirm
      // this is intended.
      if (output_index != 0) {
        MS_LOG(ERROR) << "node has no output in manager";
      }
      return output_node_list;
    }
    if (used_output_index == output_index) {
      output_node_list->push_back(output_info);
    }
  }
  return output_node_list;
}

// Reads the K/C/H/W extents from the 4-D source shape `oriDims`, interpreting it
// according to the source layout encoded in `type`.
// Returns RET_ERROR for unsupported transform types.
STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
                    int32_t *filterH, int32_t *filterW) {
  MS_ASSERT(oriDims.size() == 4);
  if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW) {
    *filterK = oriDims.at(lite::KCHW_K);
    *filterC = oriDims.at(lite::KCHW_C);
    *filterH = oriDims.at(lite::KCHW_H);
    *filterW = oriDims.at(lite::KCHW_W);
  } else if (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC) {
    *filterC = oriDims.at(lite::CKHW_C);
    *filterK = oriDims.at(lite::CKHW_K);
    *filterH = oriDims.at(lite::CKHW_H);
    *filterW = oriDims.at(lite::CKHW_W);
  } else if (type == kHWCK2KCHW || type == kHWCK2CKHW) {
    *filterH = oriDims.at(lite::HWCK_H);
    *filterW = oriDims.at(lite::HWCK_W);
    *filterC = oriDims.at(lite::HWCK_C);
    *filterK = oriDims.at(lite::HWCK_K);
  } else if (type == kHWKC2KCHW || type == kHWKC2CKHW) {
    *filterH = oriDims.at(lite::HWKC_H);
    *filterW = oriDims.at(lite::HWKC_W);
    *filterK = oriDims.at(lite::HWKC_K);
    *filterC = oriDims.at(lite::HWKC_C);
  } else if (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW) {
    *filterK = oriDims.at(lite::NHWC_N);
    *filterH = oriDims.at(lite::NHWC_H);
    *filterW = oriDims.at(lite::NHWC_W);
    *filterC = oriDims.at(lite::NHWC_C);
  } else if (type == kCHWK2HWCK || type == kCHWK2KHWC) {
    *filterC = oriDims.at(lite::CHWK_C);
    *filterH = oriDims.at(lite::CHWK_H);
    *filterW = oriDims.at(lite::CHWK_W);
    *filterK = oriDims.at(lite::CHWK_K);
  } else if (type == kKHWC2HWCK || type == kKHWC2CHWK) {
    *filterK = oriDims.at(lite::KHWC_K);
    *filterH = oriDims.at(lite::KHWC_H);
    *filterW = oriDims.at(lite::KHWC_W);
    *filterC = oriDims.at(lite::KHWC_C);
  } else {
    MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
    return RET_ERROR;
  }
  return RET_OK;
}

// Writes the destination-layout shape for `type` into `tensor`.
// Returns RET_ERROR for unsupported transform types.
STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
                    int32_t filterH, int32_t filterW) {
  MS_ASSERT(tensor != nullptr);
  if (type == kKCHW2HWCK || type == kCKHW2HWCK || type == kNHWC2HWCK || type == kKHWC2HWCK || type == kCHWK2HWCK) {
    tensor->set_tensor_shape({filterH, filterW, filterC, filterK});
  } else if (type == kKCHW2HWKC || type == kCKHW2HWKC) {
    tensor->set_tensor_shape({filterH, filterW, filterK, filterC});
  } else if (type == kHWCK2KCHW || type == kHWKC2KCHW || type == kNHWC2KCHW) {
    tensor->set_tensor_shape({filterK, filterC, filterH, filterW});
  } else if (type == kHWCK2CKHW || type == kHWKC2CKHW || type == kNHWC2CKHW || type == kKCHW2CKHW) {
    tensor->set_tensor_shape({filterC, filterK, filterH, filterW});
  } else if (type == kKHWC2CHWK) {
    tensor->set_tensor_shape({filterC, filterH, filterW, filterK});
  } else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC) {
    tensor->set_tensor_shape({filterK, filterH, filterW, filterC});
  } else {
    MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
    return RET_ERROR;
  }
  return RET_OK;
}

namespace {
// Row-major flat offset of index (i0, i1, i2, i3) in a dense 4-D buffer whose last
// three extents are (n1, n2, n3).
constexpr int FlatOffset4D(int i0, int i1, int i2, int i3, int n1, int n2, int n3) {
  return ((i0 * n1 + i1) * n2 + i2) * n3 + i3;
}
}  // namespace

// Permutes the K*C*H*W elements at tensor->tensor_addr() from the source layout to
// the destination layout encoded in `type`. Works through a temporary buffer, then
// copies the result back in place.
// Returns RET_ERROR:
//  - for non-positive sizes;
//  - on allocation failure;
//  - for null data;
//  - for unsupported types;
//  - on copy failure.
template <typename T>
static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK,
                              int32_t filterC, int32_t filterH, int32_t filterW) {
  MS_ASSERT(tensor != nullptr);
  int count = filterH * filterW * filterC * filterK;
  if (count <= 0) {
    MS_LOG(ERROR) << "Dim size invalid";
    return RET_ERROR;
  }
  std::unique_ptr<T[]> buf(new (std::nothrow) T[count]);
  if (buf == nullptr) {
    MS_LOG(ERROR) << "new buf failed";
    return RET_ERROR;
  }
  auto *weightData = static_cast<T *>(tensor->tensor_addr());
  if (weightData == nullptr) {
    MS_LOG(ERROR) << "weightData is nullptr";
    return RET_ERROR;
  }

  switch (type) {
    case kCHWK2HWCK:
    case kCHWK2KHWC: {
      for (int c = 0; c < filterC; ++c) {
        for (int h = 0; h < filterH; ++h) {
          for (int w = 0; w < filterW; ++w) {
            for (int k = 0; k < filterK; ++k) {
              const T value = weightData[FlatOffset4D(c, h, w, k, filterH, filterW, filterK)];
              if (type == kCHWK2HWCK) {
                buf[FlatOffset4D(h, w, c, k, filterW, filterC, filterK)] = value;
              } else {
                buf[FlatOffset4D(k, h, w, c, filterH, filterW, filterC)] = value;
              }
            }
          }
        }
      }
    } break;
    case kKHWC2HWCK: {
      for (int k = 0; k < filterK; ++k) {
        for (int h = 0; h < filterH; ++h) {
          for (int w = 0; w < filterW; ++w) {
            for (int c = 0; c < filterC; ++c) {
              buf[FlatOffset4D(h, w, c, k, filterW, filterC, filterK)] =
                weightData[FlatOffset4D(k, h, w, c, filterH, filterW, filterC)];
            }
          }
        }
      }
    } break;
    case kKCHW2HWCK:
    case kKCHW2CKHW:
    case kKCHW2KHWC:
    case kKCHW2HWKC: {
      for (int k = 0; k < filterK; ++k) {
        for (int c = 0; c < filterC; ++c) {
          for (int h = 0; h < filterH; ++h) {
            for (int w = 0; w < filterW; ++w) {
              const T value = weightData[FlatOffset4D(k, c, h, w, filterC, filterH, filterW)];
              if (type == kKCHW2HWCK) {
                buf[FlatOffset4D(h, w, c, k, filterW, filterC, filterK)] = value;
              } else if (type == kKCHW2KHWC) {
                buf[FlatOffset4D(k, h, w, c, filterH, filterW, filterC)] = value;
              } else if (type == kKCHW2CKHW) {
                buf[FlatOffset4D(c, k, h, w, filterK, filterH, filterW)] = value;
              } else {
                buf[FlatOffset4D(h, w, k, c, filterW, filterK, filterC)] = value;
              }
            }
          }
        }
      }
    } break;
    case kCKHW2HWCK:
    case kCKHW2KHWC:
    case kCKHW2HWKC: {
      for (int c = 0; c < filterC; ++c) {
        for (int k = 0; k < filterK; ++k) {
          for (int h = 0; h < filterH; ++h) {
            for (int w = 0; w < filterW; ++w) {
              const T value = weightData[FlatOffset4D(c, k, h, w, filterK, filterH, filterW)];
              if (type == kCKHW2HWCK) {
                buf[FlatOffset4D(h, w, c, k, filterW, filterC, filterK)] = value;
              } else if (type == kCKHW2KHWC) {
                buf[FlatOffset4D(k, h, w, c, filterH, filterW, filterC)] = value;
              } else {
                buf[FlatOffset4D(h, w, k, c, filterW, filterK, filterC)] = value;
              }
            }
          }
        }
      }
    } break;
    case kHWCK2KCHW:
    case kHWCK2CKHW: {
      for (int h = 0; h < filterH; ++h) {
        for (int w = 0; w < filterW; ++w) {
          for (int c = 0; c < filterC; ++c) {
            for (int k = 0; k < filterK; ++k) {
              const T value = weightData[FlatOffset4D(h, w, c, k, filterW, filterC, filterK)];
              if (type == kHWCK2KCHW) {
                buf[FlatOffset4D(k, c, h, w, filterC, filterH, filterW)] = value;
              } else {
                buf[FlatOffset4D(c, k, h, w, filterK, filterH, filterW)] = value;
              }
            }
          }
        }
      }
    } break;
    case kHWKC2KCHW:
    case kHWKC2CKHW: {
      for (int h = 0; h < filterH; ++h) {
        for (int w = 0; w < filterW; ++w) {
          for (int c = 0; c < filterC; ++c) {
            for (int k = 0; k < filterK; ++k) {
              const T value = weightData[FlatOffset4D(h, w, k, c, filterW, filterK, filterC)];
              if (type == kHWKC2KCHW) {
                buf[FlatOffset4D(k, c, h, w, filterC, filterH, filterW)] = value;
              } else {
                buf[FlatOffset4D(c, k, h, w, filterK, filterH, filterW)] = value;
              }
            }
          }
        }
      }
    } break;
    case kNHWC2HWCK:
    case kNHWC2KCHW:
    case kNHWC2CKHW: {
      for (int k = 0; k < filterK; ++k) {
        for (int h = 0; h < filterH; ++h) {
          for (int w = 0; w < filterW; ++w) {
            for (int c = 0; c < filterC; ++c) {
              // GetFilterDim reads NHWC filters as [K][H][W][C]; index the source the
              // same way (this is the KHWC layout, as in kKHWC2HWCK).
              const T value = weightData[FlatOffset4D(k, h, w, c, filterH, filterW, filterC)];
              if (type == kNHWC2HWCK) {
                buf[FlatOffset4D(h, w, c, k, filterW, filterC, filterK)] = value;
              } else if (type == kNHWC2CKHW) {
                buf[FlatOffset4D(c, k, h, w, filterK, filterH, filterW)] = value;
              } else {
                buf[FlatOffset4D(k, c, h, w, filterC, filterH, filterW)] = value;
              }
            }
          }
        }
      }
    } break;
    case kKHWC2CHWK: {
      for (int k = 0; k < filterK; ++k) {
        for (int h = 0; h < filterH; ++h) {
          for (int w = 0; w < filterW; ++w) {
            for (int c = 0; c < filterC; ++c) {
              buf[FlatOffset4D(c, h, w, k, filterH, filterW, filterK)] =
                weightData[FlatOffset4D(k, h, w, c, filterH, filterW, filterC)];
            }
          }
        }
      }
    } break;
    default: {
      MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
      return RET_ERROR;
    }
  }

  auto ret = ::memcpy_s(tensor->tensor_addr(), count * sizeof(T), buf.get(), count * sizeof(T));
  if (ret != EOK) {
    MS_LOG(ERROR) << "memcpy_s failed: " << ret;
    return RET_ERROR;
  }
  return RET_OK;
}

// Transposes a 4-D filter tensor of element type T according to `type`, in three steps:
//  1. read the source dims;
//  2. write the destination shape;
//  3. permute the data.
// Returns the first failing step's status.
template <typename T>
static STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type) {
  MS_ASSERT(tensor != nullptr);
  auto oriDims = tensor->tensor_shape();
  if (oriDims.size() != static_cast<size_t>(lite::DIM_DEFAULT_SIZE)) {
    MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
    return lite::RET_ERROR;
  }

  int32_t filterH;
  int32_t filterW;
  int32_t filterC;
  int32_t filterK;
  auto status = GetFilterDim(oriDims, type, &filterK, &filterC, &filterH, &filterW);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "GetFilterDim failed: " << status;
    return status;
  }
  status = SetFilterDim(tensor, type, filterK, filterC, filterH, filterW);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "SetFilterDim failed: " << status;
    return status;
  }
  status = TransFilterData<T>(tensor, type, filterK, filterC, filterH, filterW);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "TransFilterData failed: " << status;
    return status;
  }
  return lite::RET_OK;
}

namespace {
// One supported (source layout -> destination layout) filter conversion.
struct FilterFormatTransRule {
  schema::Format src;
  schema::Format dst;
  kTransFilterType type;
};

// Every non-identity conversion dispatched by TransFilterFormat(tensor, dst_format).
// Entries without GetFilterDim/TransFilterData support (KHWC/CKHW/CHWK -> KCHW)
// fail at run time with "Unsupported transFilterType".
constexpr FilterFormatTransRule kFilterFormatTransRules[] = {
  {schema::Format::Format_KCHW, schema::Format::Format_KHWC, kKCHW2KHWC},
  {schema::Format::Format_CKHW, schema::Format::Format_KHWC, kCKHW2KHWC},
  {schema::Format::Format_CHWK, schema::Format::Format_KHWC, kCHWK2KHWC},
  {schema::Format::Format_KCHW, schema::Format::Format_HWCK, kKCHW2HWCK},
  {schema::Format::Format_KHWC, schema::Format::Format_HWCK, kKHWC2HWCK},
  {schema::Format::Format_CKHW, schema::Format::Format_HWCK, kCKHW2HWCK},
  {schema::Format::Format_CHWK, schema::Format::Format_HWCK, kCHWK2HWCK},
  {schema::Format::Format_HWCK, schema::Format::Format_KCHW, kHWCK2KCHW},
  {schema::Format::Format_HWKC, schema::Format::Format_KCHW, kHWKC2KCHW},
  {schema::Format::Format_KHWC, schema::Format::Format_KCHW, kKHWC2KCHW},
  {schema::Format::Format_CKHW, schema::Format::Format_KCHW, kCKHW2KCHW},
  {schema::Format::Format_CHWK, schema::Format::Format_KCHW, kCHWK2KCHW},
  {schema::Format::Format_HWCK, schema::Format::Format_CKHW, kHWCK2CKHW},
  {schema::Format::Format_HWKC, schema::Format::Format_CKHW, kHWKC2CKHW},
  {schema::Format::Format_KCHW, schema::Format::Format_CKHW, kKCHW2CKHW},
};
}  // namespace

// Converts the 4-D filter `tensor` in place from its current format to `dst_format`.
// Supported destinations are KHWC, HWCK, KCHW and CKHW; converting to the current
// format is a no-op. Element types: Float32, UInt8, Int8, Float16.
// Returns:
//  - RET_NULL_PTR for a null tensor;
//  - RET_ERROR for a non-4-D shape or an unsupported format pair / data type;
//  - otherwise the status of the underlying transpose.
STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format) {
  if (tensor == nullptr) {
    return lite::RET_NULL_PTR;
  }
  auto ori_dims = tensor->tensor_shape();
  if (ori_dims.size() != static_cast<size_t>(lite::DIM_DEFAULT_SIZE)) {
    MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << ori_dims.size();
    return lite::RET_ERROR;
  }
  auto src_format = tensor->format();
  auto data_type = tensor->tensor_type();

  const bool dst_supported = dst_format == schema::Format::Format_KHWC ||
                             dst_format == schema::Format::Format_HWCK ||
                             dst_format == schema::Format::Format_KCHW || dst_format == schema::Format::Format_CKHW;
  if (!dst_supported) {
    MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
    return RET_ERROR;
  }
  if (src_format == dst_format) {
    return RET_OK;
  }

  const FilterFormatTransRule *rule = nullptr;
  for (const auto &candidate : kFilterFormatTransRules) {
    if (candidate.src == src_format && candidate.dst == dst_format) {
      rule = &candidate;
      break;
    }
  }
  if (rule == nullptr) {
    MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
    return RET_ERROR;
  }

  // A single dispatch on the element type keeps every data type on the same
  // transform enum (the per-type ladders this replaces had drifted apart for Float16).
  lite::STATUS status;
  switch (data_type) {
    case kNumberTypeFloat32:
      status = TransFilterFormat<float>(tensor, rule->type);
      break;
    case kNumberTypeUInt8:
      status = TransFilterFormat<uint8_t>(tensor, rule->type);
      break;
    case kNumberTypeInt8:
      status = TransFilterFormat<int8_t>(tensor, rule->type);
      break;
    case kNumberTypeFloat16:
      // The transpose only copies elements, so half-precision weights are moved as raw
      // 16-bit words.
      status = TransFilterFormat<uint16_t>(tensor, rule->type);
      break;
    default:
      MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
      return RET_ERROR;
  }
  if (status != RET_OK) {
    MS_LOG(ERROR) << "TransFilterData failed: " << status;
    return status;
  }
  return RET_OK;
}
}  // namespace opt
}  // namespace mindspore