/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/common/node_util.h"
#include <unordered_map>
#include <vector>
#include "src/common/common.h"
#include "src/common/log_adapter.h"
#include "tools/common/graph_util.h"
#include "tools/common/tensor_util.h"

namespace mindspore {
namespace lite {
// Primitive types returned by GetNhwcOpList(). The gradient/optimizer entries are only
// compiled in when SUPPORT_TRAIN is defined.
static const std::vector<schema::PrimitiveType> nhwcOpList = {
#ifdef SUPPORT_TRAIN
  schema::PrimitiveType_Conv2DGradFilter,
  schema::PrimitiveType_Conv2DGradInput,
  schema::PrimitiveType_PoolingGrad,
  schema::PrimitiveType_BiasGrad,
  schema::PrimitiveType_BNGrad,
  schema::PrimitiveType_ActivationGrad,
  schema::PrimitiveType_ApplyMomentum,
  schema::PrimitiveType_Sgd,
#endif
  schema::PrimitiveType_Conv2D,
  schema::PrimitiveType_DeConv2D,
  schema::PrimitiveType_DepthwiseConv2D,
  schema::PrimitiveType_DeDepthwiseConv2D,
  schema::PrimitiveType_Pooling,
  schema::PrimitiveType_Resize,
  schema::PrimitiveType_BatchNorm,
  schema::PrimitiveType_FusedBatchNorm,
  schema::PrimitiveType_PReLU,
  schema::PrimitiveType_BiasAdd};

// Primitive types returned by GetNhwcDualInputOpList(); empty unless SUPPORT_TRAIN is defined.
static const std::vector<schema::PrimitiveType> nhwcOpDualInputList = {
#ifdef SUPPORT_TRAIN
  schema::PrimitiveType_Conv2DGradFilter, schema::PrimitiveType_BNGrad
#endif
};

// Primitive types returned by GetNhwcAllInputOpList(); empty unless SUPPORT_TRAIN is defined.
static const std::vector<schema::PrimitiveType> nhwcOpAllInputList = {
#ifdef SUPPORT_TRAIN
  schema::PrimitiveType_PoolingGrad, schema::PrimitiveType_ActivationGrad
#endif
};

static const std::vector<schema::PrimitiveType> fp32FullOpList = {
  schema::PrimitiveType_Concat, schema::PrimitiveType_Add,
  schema::PrimitiveType_Floor};  // fp32 ops support C4 and nhwc in fp32

// NOTE(review): exposed through GetUint8NhwcOpList() although the variable is named "int8";
// confirm which of the two names is intended.
static const std::vector<schema::PrimitiveType> int8NeedNhwcOpList = {};

// Primitive types returned by GetInt8OpList().
// NOTE(review): PrimitiveType_SoftMax is listed twice; kept as-is because callers receive
// the vector verbatim.
static const std::vector<schema::PrimitiveType> int8OpList = {
  schema::PrimitiveType_Nchw2Nhwc,
  schema::PrimitiveType_Nhwc2Nchw,
  schema::PrimitiveType_Conv2D,
  schema::PrimitiveType_DepthwiseConv2D,
  schema::PrimitiveType_Add,
  schema::PrimitiveType_Pooling,
  schema::PrimitiveType_Concat,
  schema::PrimitiveType_SoftMax,
  schema::PrimitiveType_Reshape,
  schema::PrimitiveType_Activation,
  schema::PrimitiveType_Resize,
  schema::PrimitiveType_FullConnection,
  schema::PrimitiveType_ArgMax,
  schema::PrimitiveType_ArgMin,
  schema::PrimitiveType_BatchNorm,
  schema::PrimitiveType_FusedBatchNorm,
  schema::PrimitiveType_BiasAdd,
  schema::PrimitiveType_Div,
  schema::PrimitiveType_Mul,
  schema::PrimitiveType_Slice,
  schema::PrimitiveType_SoftMax,
  schema::PrimitiveType_Split,
  schema::PrimitiveType_Squeeze,
  schema::PrimitiveType_Sub,
  schema::PrimitiveType_StridedSlice,
  schema::PrimitiveType_TopK,
  schema::PrimitiveType_Unsqueeze,
  schema::PrimitiveType_MatMul,
  schema::PrimitiveType_Pad,
  schema::PrimitiveType_DeConv2D,
  schema::PrimitiveType_Scale};

// Primitive types returned by GetInsertOpList().
static const std::vector<schema::PrimitiveType> needInsertOpList = {
  schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation,   schema::PrimitiveType_Concat,
  schema::PrimitiveType_Power,   schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Add,
  schema::PrimitiveType_Split,   schema::PrimitiveType_Slice,        schema::PrimitiveType_Crop};

// Axis remapping table returned by GetNc2NhAxisMap(): 0 -> 0, 1 -> -1, 2 -> 1, 3 -> 2.
// Presumably NCHW axis index -> NHWC axis index (with -1 meaning "last") - confirm with callers.
static const std::unordered_map<int, int> nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}};

// Accessors below each return a copy of the corresponding file-local table.
std::unordered_map<int, int> GetNc2NhAxisMap() { return nc2NhAxisMap; }

std::vector<schema::PrimitiveType> GetInsertOpList() { return needInsertOpList; }

std::vector<schema::PrimitiveType> Getfp32FullOpList() { return fp32FullOpList; }

std::vector<schema::PrimitiveType> GetNhwcOpList() { return nhwcOpList; }

std::vector<schema::PrimitiveType> GetNhwcDualInputOpList() { return nhwcOpDualInputList; }

std::vector<schema::PrimitiveType> GetNhwcAllInputOpList() { return nhwcOpAllInputList; }

std::vector<schema::PrimitiveType> GetUint8NhwcOpList() { return int8NeedNhwcOpList; }

std::vector<schema::PrimitiveType> GetInt8OpList() { return int8OpList; }

// Converts a shape between NCHW and NHWC ordering.
//
// src_format / dst_format: only Format_NCHW and Format_NHWC are accepted.
// src_dims: 4-D shape (DIM_DEFAULT_SIZE entries) or 3-D shape (treated as CHW / HWC).
// dst_dims: receives the converted shape; any previous content is replaced.
//
// Returns:
//   RET_NULL_PTR      - dst_dims is nullptr.
//   RET_PARAM_INVALID - src_dims has neither 3 nor DIM_DEFAULT_SIZE entries, or the two
//                       formats are equal; *dst_dims is set to a copy of src_dims.
//   RET_ERROR         - unsupported source or destination format; *dst_dims is untouched.
//   RET_OK            - *dst_dims holds the converted shape.
STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::vector<int32_t> &src_dims,
                              mindspore::schema::Format dst_format, std::vector<int32_t> *dst_dims) {
  if (dst_dims == nullptr) {
    MS_LOG(ERROR) << "dst_dims is nullptr";
    return RET_NULL_PTR;
  }
  const bool is_4d = src_dims.size() == DIM_DEFAULT_SIZE;
  const bool is_3d = src_dims.size() == 3;
  if ((!is_4d && !is_3d) || src_format == dst_format) {
    MS_LOG(ERROR) << "Convert format , src size " << src_dims.size()
                  << " <3 or src format is equal to dst format,not need convert";
    *dst_dims = src_dims;
    return RET_PARAM_INVALID;
  }

  // Step 1: normalise the input to NCHW (or CHW for 3-D input).
  std::vector<int32_t> nchw_dim;
  switch (src_format) {
    case schema::Format::Format_NCHW:
      nchw_dim = src_dims;
      break;
    case schema::Format::Format_NHWC:
      if (is_4d) {
        nchw_dim = {src_dims[NHWC_N], src_dims[NHWC_C], src_dims[NHWC_H], src_dims[NHWC_W]};
      } else {
        nchw_dim = {src_dims[HWC_C], src_dims[HWC_H], src_dims[HWC_W]};
      }
      break;
    default:
      MS_LOG(ERROR) << "Not support src format: " << EnumNameFormat(src_format);
      return RET_ERROR;
  }

  // Step 2: emit the requested layout. The result is built locally and assigned at the end
  // so that stale content in *dst_dims is never appended to.
  std::vector<int32_t> converted;
  switch (dst_format) {
    case schema::Format::Format_NCHW:
      converted = nchw_dim;
      break;
    case schema::Format::Format_NHWC:
      if (is_4d) {
        converted = {nchw_dim[NCHW_N], nchw_dim[NCHW_H], nchw_dim[NCHW_W], nchw_dim[NCHW_C]};
      } else {
        // nchw_dim is {C, H, W} here; previously this case returned RET_OK without writing
        // anything to *dst_dims.
        converted = {nchw_dim[1], nchw_dim[2], nchw_dim[0]};
      }
      break;
    default:
      MS_LOG(ERROR) << "Not support dst format: " << EnumNameFormat(dst_format);
      return RET_ERROR;
  }
  *dst_dims = converted;
  return RET_OK;
}

// Extracts the K/C/H/W extents of a 4-D filter shape according to the SOURCE layout
// encoded in `type` (the part of the enumerator name before "2").
//
// oriDims: filter shape, must have exactly 4 entries.
// filterK/filterC/filterH/filterW: outputs, must all be non-null. They are only written
//   when `type` is recognised.
//
// Returns RET_NULL_PTR for a null output pointer, RET_ERROR for a wrong rank or an
// unrecognised `type`, RET_OK otherwise.
STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
                    int32_t *filterH, int32_t *filterW) {
  // Runtime checks instead of MS_ASSERT only: a short vector would otherwise make .at()
  // throw, and a null output pointer would be dereferenced.
  if (filterK == nullptr || filterC == nullptr || filterH == nullptr || filterW == nullptr) {
    MS_LOG(ERROR) << "GetFilterDim output pointer is nullptr";
    return RET_NULL_PTR;
  }
  if (oriDims.size() != 4) {
    MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
    return RET_ERROR;
  }

  if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW) {
    *filterK = oriDims.at(KCHW_K);
    *filterC = oriDims.at(KCHW_C);
    *filterH = oriDims.at(KCHW_H);
    *filterW = oriDims.at(KCHW_W);
  } else if (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC) {
    *filterC = oriDims.at(CKHW_C);
    *filterK = oriDims.at(CKHW_K);
    *filterH = oriDims.at(CKHW_H);
    *filterW = oriDims.at(CKHW_W);
  } else if (type == kHWCK2KCHW || type == kHWCK2CKHW) {
    *filterH = oriDims.at(HWCK_H);
    *filterW = oriDims.at(HWCK_W);
    *filterC = oriDims.at(HWCK_C);
    *filterK = oriDims.at(HWCK_K);
  } else if (type == kHWKC2KCHW || type == kHWKC2CKHW) {
    *filterH = oriDims.at(HWKC_H);
    *filterW = oriDims.at(HWKC_W);
    *filterK = oriDims.at(HWKC_K);
    *filterC = oriDims.at(HWKC_C);
  } else if (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW) {
    // For NHWC filters the N axis carries the K extent.
    *filterK = oriDims.at(NHWC_N);
    *filterH = oriDims.at(NHWC_H);
    *filterW = oriDims.at(NHWC_W);
    *filterC = oriDims.at(NHWC_C);
  } else if (type == kCHWK2HWCK || type == kCHWK2KHWC) {
    *filterC = oriDims.at(CHWK_C);
    *filterH = oriDims.at(CHWK_H);
    *filterW = oriDims.at(CHWK_W);
    *filterK = oriDims.at(CHWK_K);
  } else if (type == kKHWC2HWCK || type == kKHWC2CHWK) {
    *filterK = oriDims.at(KHWC_K);
    *filterH = oriDims.at(KHWC_H);
    *filterW = oriDims.at(KHWC_W);
    *filterC = oriDims.at(KHWC_C);
  } else {
    MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
    return RET_ERROR;
  }
  return RET_OK;
}

// Writes tensor->dims in the DESTINATION layout encoded in `type` (the part of the
// enumerator name after "2"), using the supplied K/C/H/W extents.
//
// Returns RET_NULL_PTR when tensor is nullptr, RET_ERROR for an unrecognised `type`
// (tensor->dims is left untouched), RET_OK otherwise.
STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH,
                    int32_t filterW) {
  // Runtime check instead of MS_ASSERT only, matching TransFilterFormat below.
  if (tensor == nullptr) {
    MS_LOG(ERROR) << "tensor is nullptr";
    return RET_NULL_PTR;
  }

  if (type == kKCHW2HWCK || type == kCKHW2HWCK || type == kNHWC2HWCK || type == kKHWC2HWCK || type == kCHWK2HWCK) {
    tensor->dims = {filterH, filterW, filterC, filterK};
  } else if (type == kKCHW2HWKC || type == kCKHW2HWKC) {
    tensor->dims = {filterH, filterW, filterK, filterC};
  } else if (type == kHWCK2KCHW || type == kHWKC2KCHW || type == kNHWC2KCHW) {
    tensor->dims = {filterK, filterC, filterH, filterW};
  } else if (type == kHWCK2CKHW || type == kHWKC2CKHW || type == kNHWC2CKHW || type == kKCHW2CKHW) {
    tensor->dims = {filterC, filterK, filterH, filterW};
  } else if (type == kKHWC2CHWK) {
    tensor->dims = {filterC, filterH, filterW, filterK};
  } else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC) {
    tensor->dims = {filterK, filterH, filterW, filterC};
  } else {
    MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
    return RET_ERROR;
  }
  return RET_OK;
}

// One supported (source layout, destination layout) pair and the transform that realises it.
struct FilterTransRule {
  schema::Format src;
  schema::Format dst;
  kTransFilterType type;
};

// All layout pairs TransFilterFormat(tensor, dstFormat) can convert between.
// NOTE(review): GetFilterDim/SetFilterDim above have no branch for kKHWC2KCHW, kCKHW2KCHW
// or kCHWK2KCHW. If the templated TransFilterFormat overload routes through them, those
// three rules will fail with "Unsupported transFilterType" - confirm.
static const FilterTransRule kFilterTransRules[] = {
  // -> KHWC
  {schema::Format::Format_KCHW, schema::Format::Format_KHWC, kKCHW2KHWC},
  {schema::Format::Format_CKHW, schema::Format::Format_KHWC, kCKHW2KHWC},
  {schema::Format::Format_CHWK, schema::Format::Format_KHWC, kCHWK2KHWC},
  // -> HWCK
  {schema::Format::Format_KCHW, schema::Format::Format_HWCK, kKCHW2HWCK},
  {schema::Format::Format_KHWC, schema::Format::Format_HWCK, kKHWC2HWCK},
  {schema::Format::Format_CKHW, schema::Format::Format_HWCK, kCKHW2HWCK},
  {schema::Format::Format_CHWK, schema::Format::Format_HWCK, kCHWK2HWCK},
  // -> KCHW
  {schema::Format::Format_HWCK, schema::Format::Format_KCHW, kHWCK2KCHW},
  {schema::Format::Format_HWKC, schema::Format::Format_KCHW, kHWKC2KCHW},
  {schema::Format::Format_KHWC, schema::Format::Format_KCHW, kKHWC2KCHW},
  {schema::Format::Format_CKHW, schema::Format::Format_KCHW, kCKHW2KCHW},
  {schema::Format::Format_CHWK, schema::Format::Format_KCHW, kCHWK2KCHW},
  // -> CKHW
  {schema::Format::Format_HWCK, schema::Format::Format_CKHW, kHWCK2CKHW},
  {schema::Format::Format_HWKC, schema::Format::Format_CKHW, kHWKC2CKHW},
  {schema::Format::Format_KCHW, schema::Format::Format_CKHW, kKCHW2CKHW},
};

// Returns true for the destination layouts TransFilterFormat knows how to produce.
static bool IsSupportedFilterDstFormat(schema::Format format) {
  return format == schema::Format::Format_KHWC || format == schema::Format::Format_HWCK ||
         format == schema::Format::Format_KCHW || format == schema::Format::Format_CKHW;
}

// Returns the rule for (src, dst), or nullptr when the pair is not supported.
static const FilterTransRule *FindFilterTransRule(schema::Format src, schema::Format dst) {
  for (const auto &rule : kFilterTransRules) {
    if (rule.src == src && rule.dst == dst) {
      return &rule;
    }
  }
  return nullptr;
}

// Converts a 4-D filter tensor from its current layout (tensor->format) to dstFormat by
// dispatching to the element-type specific TransFilterFormat<T> overload.
//
// Returns:
//   RET_NULL_PTR - tensor is nullptr.
//   RET_ERROR    - tensor->dims does not have DIM_DEFAULT_SIZE entries, the layout pair is
//                  not supported, or tensor->dataType is not float32 / uint8 / int8.
//   RET_OK       - the tensor already has a supported dstFormat (nothing is done), or the
//                  conversion succeeded.
//   Any other status returned by the typed overload is logged and passed through.
STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
  if (tensor == nullptr) {
    return RET_NULL_PTR;
  }
  if (tensor->dims.size() != (size_t)DIM_DEFAULT_SIZE) {
    MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << tensor->dims.size();
    return RET_ERROR;
  }
  const auto srcFormat = tensor->format;
  const auto dataType = tensor->dataType;

  // Already in the requested layout: nothing to do (only for layouts we can target at all).
  if (srcFormat == dstFormat && IsSupportedFilterDstFormat(dstFormat)) {
    return RET_OK;
  }

  const FilterTransRule *rule = FindFilterTransRule(srcFormat, dstFormat);
  if (rule == nullptr) {
    MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to "
                  << EnumNameFormat(dstFormat);
    return RET_ERROR;
  }

  // The data shuffle is templated on the element type, so pick the instantiation here.
  STATUS status = RET_ERROR;
  if (dataType == kNumberTypeFloat32) {
    status = TransFilterFormat<float>(tensor, rule->type);
  } else if (dataType == kNumberTypeUInt8) {
    status = TransFilterFormat<uint8_t>(tensor, rule->type);
  } else if (dataType == kNumberTypeInt8) {
    status = TransFilterFormat<int8_t>(tensor, rule->type);
  } else {
    MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
    return RET_ERROR;
  }
  if (status != RET_OK) {
    MS_LOG(ERROR) << "TransFilterData failed: " << status;
    return status;
  }
  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore