Merge pull request !4548 from hangq/mastertags/v0.7.0-beta
| @@ -353,7 +353,6 @@ if (BUILD_CONVERTER) | |||
| tflite_parser_mid | |||
| caffe_parser_mid | |||
| onnx_parser_mid | |||
| node_mid | |||
| graph_pass_mid | |||
| fusion_mid | |||
| quantizer_mid | |||
| @@ -24,74 +24,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr<CNodeT> &node) { | |||
| MS_ASSERT(graphT != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| // set quantParam to preNode | |||
| for (size_t i = 0; i < node->inputIndex.size(); i++) { | |||
| auto preNodeIdexes = GetInputNodeIdx(*graphT, *(node.get()), i); | |||
| for (auto preNodeIdx : preNodeIdexes) { | |||
| MS_ASSERT(graphT->nodes.size() > preNodeIdx); | |||
| auto &preNode = graphT->nodes.at(preNodeIdx); | |||
| MS_ASSERT(preNode != nullptr); | |||
| // if preNode is not init, it maybe not a quantNode, so skip | |||
| // if (preNode->inputIndex.size() + preNode->outputIndex.size() != preNode->quantParam.size()) { | |||
| // continue; | |||
| // } | |||
| auto preNodeOutputIndexes = preNode->outputIndex; | |||
| int32_t currentNodeIndexInPre = -1; | |||
| for (auto index : preNodeOutputIndexes) { | |||
| currentNodeIndexInPre++; | |||
| if (index == node->inputIndex.at(i)) { | |||
| break; | |||
| } | |||
| } | |||
| MS_ASSERT(currentNodeIndexInPre != -1); | |||
| MS_ASSERT(node->quantParam.size() > i); | |||
| MS_ASSERT(node->quantParam.at(i) != nullptr); | |||
| // auto quantParamArrayCopy = CopyQuantParamArrayT(node->quantParam.at(i)); | |||
| // if (quantParamArrayCopy == nullptr) { | |||
| // //MS_LOG(ERROR)("CopyQuantParamArray return nullptr, node: %s", node->name.c_str()); | |||
| // return RET_ERROR; | |||
| // } | |||
| // preNode->quantParam.at(preNode->inputIndex.size() + currentNodeIndexInPre) = | |||
| // std::move(CopyQuantParamArrayT(quantParamArrayCopy)); | |||
| } | |||
| } | |||
| // set quantParam to postNode | |||
| for (size_t i = 0; i < node->outputIndex.size(); i++) { | |||
| auto postNodeIdexes = GetOutputNodeIdx(*graphT, *(node.get()), i); | |||
| for (auto postNodeIdx : postNodeIdexes) { | |||
| MS_ASSERT(graphT->nodes.size() > postNodeIdx); | |||
| auto &postNode = graphT->nodes.at(postNodeIdx); | |||
| MS_ASSERT(postNode != nullptr); | |||
| // if postNode is not init, it maybe not a quantNode, so skip | |||
| // if (postNode->inputIndex.size() + postNode->outputIndex.size() != postNode->quantParam.size()) { | |||
| // continue; | |||
| // } | |||
| auto postNodeInputIndexes = postNode->inputIndex; | |||
| int32_t currentNodeIndexInPost = -1; | |||
| for (auto index : postNodeInputIndexes) { | |||
| currentNodeIndexInPost++; | |||
| if (index == node->outputIndex.at(i)) { | |||
| break; | |||
| } | |||
| } | |||
| MS_ASSERT(currentNodeIndexInPost != -1); | |||
| MS_ASSERT(node->quantParam.size() > node->inputIndex.size() + i); | |||
| MS_ASSERT(node->quantParam.at(node->inputIndex.size() + i) != nullptr); | |||
| // auto quantParamArrayCopy = CopyQuantParamArrayT(node->quantParam.at(node->inputIndex.size() + i)); | |||
| // if (quantParamArrayCopy == nullptr) { | |||
| // //MS_LOG(ERROR)("CopyQuantParamArray return nullptr, node: %s", node->name.c_str()); | |||
| // return RET_ERROR; | |||
| // } | |||
| // postNode->quantParam.at(currentNodeIndexInPost) = std::move(CopyQuantParamArrayT(quantParamArrayCopy)); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| static const std::vector<schema::PrimitiveType> nhwcOpList = { | |||
| schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D, | |||
| schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D, | |||
| @@ -121,8 +53,8 @@ std::vector<schema::PrimitiveType> GetUint8OpList() { return uint8OpList; } | |||
| STATUS NodeUtils::ConvertDims(mindspore::lite::Format src_format, const std::vector<int32_t> &src_dims, | |||
| mindspore::lite::Format dst_format, std::vector<int32_t> *dst_dims) { | |||
| if ((src_dims.size() != DIM_DEFAULT_SIZE && src_dims.size() != 3) || src_format == dst_format) { | |||
| // MS_LOG(ERROR)("Convert format , src size %lu <3 or src format is equal to dst format,not need convert", | |||
| // src_dims.size()); | |||
| MS_LOG(ERROR) << "Convert format , src size " << src_dims.size() | |||
| << " <3 or src format is equal to dst format,not need convert"; | |||
| *dst_dims = src_dims; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| @@ -145,12 +77,12 @@ STATUS NodeUtils::ConvertDims(mindspore::lite::Format src_format, const std::vec | |||
| } | |||
| break; | |||
| default: | |||
| // MS_LOG(ERROR)("Not support src format: %d", src_format); | |||
| MS_LOG(ERROR) << "Not support src format: " << schema::EnumNameFormat(src_format); | |||
| return RET_ERROR; | |||
| } | |||
| if (nchw_dim.size() == 0) { | |||
| // MS_LOG(ERROR)("Param nchw_dim is empty!"); | |||
| MS_LOG(ERROR) << "Param nchw_dim is empty!"; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -172,7 +104,250 @@ STATUS NodeUtils::ConvertDims(mindspore::lite::Format src_format, const std::vec | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) { | |||
| if (tensor == nullptr) { | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::vector<int32_t> oriDims = tensor->dims; | |||
| if (oriDims.size() != (size_t)DIM_DEFAULT_SIZE) { | |||
| MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size(); | |||
| return RET_ERROR; | |||
| } | |||
| auto srcFormat = tensor->format; | |||
| auto dataType = tensor->dataType; | |||
| STATUS status; | |||
| switch (dstFormat) { | |||
| case schema::Format_KHWC: { | |||
| switch (srcFormat) { | |||
| case schema::Format_KCHW: | |||
| if (dataType == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kKCHW2KHWC); | |||
| } else if (dataType == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kKCHW2KHWC); | |||
| } else if (dataType == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kKCHW2KHWC); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format_CKHW: | |||
| if (dataType == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kCKHW2KHWC); | |||
| } else if (dataType == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kCKHW2KHWC); | |||
| } else if (dataType == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kCKHW2KHWC); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format_CHWK: | |||
| if (dataType == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kCHWK2KHWC); | |||
| } else if (dataType == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kCHWK2KHWC); | |||
| } else if (dataType == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kCHWK2KHWC); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format_KHWC: | |||
| return RET_OK; | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupported transform from " << schema::EnumNameFormat(srcFormat) << " to " | |||
| << schema::EnumNameFormat(dstFormat); | |||
| return RET_ERROR; | |||
| } | |||
| } break; | |||
| case schema::Format_HWCK: { | |||
| switch (srcFormat) { | |||
| case schema::Format_KCHW: | |||
| if (dataType == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kKCHW2HWCK); | |||
| } else if (dataType == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kKCHW2HWCK); | |||
| } else if (dataType == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kKCHW2HWCK); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format_KHWC: | |||
| if (dataType == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kKHWC2HWCK); | |||
| } else if (dataType == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kKHWC2HWCK); | |||
| } else if (dataType == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kKHWC2HWCK); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format_CKHW: | |||
| if (dataType == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kCKHW2HWCK); | |||
| } else if (dataType == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kCKHW2HWCK); | |||
| } else if (dataType == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kCKHW2HWCK); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format_CHWK: | |||
| if (dataType == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kCHWK2HWCK); | |||
| } else if (dataType == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kCHWK2HWCK); | |||
| } else if (dataType == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kCHWK2HWCK); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format_HWCK: | |||
| return RET_OK; | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupported transform from " << schema::EnumNameFormat(srcFormat) << " to " | |||
| << schema::EnumNameFormat(dstFormat); | |||
| return RET_ERROR; | |||
| } | |||
| } break; | |||
| case schema::Format_KCHW: { | |||
| switch (srcFormat) { | |||
| case schema::Format_KCHW: | |||
| return RET_OK; | |||
| case schema::Format_HWCK: | |||
| if (dataType == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kHWCK2KCHW); | |||
| } else if (dataType == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kHWCK2KCHW); | |||
| } else if (dataType == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kHWCK2KCHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format_HWKC: | |||
| if (dataType == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kHWKC2KCHW); | |||
| } else if (dataType == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kHWKC2KCHW); | |||
| } else if (dataType == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kHWKC2KCHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format_KHWC: | |||
| if (dataType == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kKHWC2KCHW); | |||
| } else if (dataType == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kKHWC2KCHW); | |||
| } else if (dataType == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kKHWC2KCHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format_CKHW: | |||
| if (dataType == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kCKHW2KCHW); | |||
| } else if (dataType == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kCKHW2KCHW); | |||
| } else if (dataType == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kCKHW2KCHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format_CHWK: | |||
| if (dataType == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kCHWK2KCHW); | |||
| } else if (dataType == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kCHWK2KCHW); | |||
| } else if (dataType == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kCHWK2KCHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupported transform from " << schema::EnumNameFormat(srcFormat) << " to " | |||
| << schema::EnumNameFormat(dstFormat); | |||
| return RET_ERROR; | |||
| } | |||
| } break; | |||
| case schema::Format_CKHW: { | |||
| switch (srcFormat) { | |||
| case schema::Format_HWCK: | |||
| if (dataType == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kHWCK2CKHW); | |||
| } else if (dataType == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kHWCK2CKHW); | |||
| } else if (dataType == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kHWCK2CKHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format_HWKC: | |||
| if (dataType == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kHWKC2CKHW); | |||
| } else if (dataType == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kHWKC2CKHW); | |||
| } else if (dataType == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kHWKC2CKHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format_KCHW: | |||
| if (dataType == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kKCHW2CKHW); | |||
| } else if (dataType == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kKCHW2CKHW); | |||
| } else if (dataType == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kKCHW2CKHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format_CKHW: | |||
| return RET_OK; | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupported transform from " << schema::EnumNameFormat(srcFormat) << " to " | |||
| << schema::EnumNameFormat(dstFormat); | |||
| return RET_ERROR; | |||
| } | |||
| } break; | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupported transform from " << schema::EnumNameFormat(srcFormat) << " to " | |||
| << schema::EnumNameFormat(dstFormat); | |||
| return RET_ERROR; | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "TransFilterData failed: " << status; | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -53,7 +53,7 @@ class NodeUtils { | |||
| // todo check this | |||
| enum kTransFilterType { | |||
| kKCHW2HWCK, | |||
| kKCHW2HWCK, // 0 | |||
| kKCHW2KHWC, | |||
| kCKHW2KHWC, | |||
| kCKHW2HWCK, | |||
| @@ -63,19 +63,23 @@ enum kTransFilterType { | |||
| kHWCK2CKHW, | |||
| kHWKC2KCHW, | |||
| kHWKC2CKHW, | |||
| kNHWC2KCHW, | |||
| kNHWC2KCHW, // 10 | |||
| kNHWC2CKHW, | |||
| kNHWC2HWCK, | |||
| kKHWC2HWCK, | |||
| kCHWK2HWCK, | |||
| kKHWC2CHWK, | |||
| kCHWK2KHWC | |||
| kCHWK2KHWC, | |||
| kKHWC2KCHW, | |||
| kCKHW2KCHW, | |||
| kCHWK2KCHW, | |||
| kKCHW2CKHW // 20 | |||
| }; | |||
| static STATUS GetFilterDim(std::vector<int32_t> &oriDims, kTransFilterType type, int32_t &filterK, int32_t &filterC, | |||
| int32_t &filterH, int32_t &filterW) { | |||
| MS_ASSERT(oriDims.size() == 4); | |||
| if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC) { | |||
| if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW) { | |||
| filterK = oriDims.at(KCHW_K); | |||
| filterC = oriDims.at(KCHW_C); | |||
| filterH = oriDims.at(KCHW_H); | |||
| @@ -126,7 +130,7 @@ static STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32 | |||
| tensor->dims = {filterH, filterW, filterK, filterC}; | |||
| } else if (type == kHWCK2KCHW || type == kHWKC2KCHW || type == kNHWC2KCHW) { | |||
| tensor->dims = {filterK, filterC, filterH, filterW}; | |||
| } else if (type == kHWCK2CKHW || type == kHWKC2CKHW || type == kNHWC2CKHW) { | |||
| } else if (type == kHWCK2CKHW || type == kHWKC2CKHW || type == kNHWC2CKHW || type == kKCHW2CKHW) { | |||
| tensor->dims = {filterC, filterK, filterH, filterW}; | |||
| } else if (type == kKHWC2CHWK) { | |||
| tensor->dims = {filterC, filterH, filterW, filterK}; | |||
| @@ -198,6 +202,7 @@ static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, in | |||
| } | |||
| } break; | |||
| case kKCHW2HWCK: | |||
| case kKCHW2CKHW: | |||
| case kKCHW2KHWC: | |||
| case kKCHW2HWKC: { | |||
| for (int k = 0; k < filterK; ++k) { | |||
| @@ -211,6 +216,9 @@ static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, in | |||
| } else if (type == kKCHW2KHWC) { | |||
| p2Buff = | |||
| buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| } else if (type == kKCHW2CKHW) { | |||
| p2Buff = | |||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| } else { | |||
| p2Buff = | |||
| buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); | |||
| @@ -367,7 +375,8 @@ static STATUS TransFilterFormat(schema::TensorT *tensor, kTransFilterType type) | |||
| return RET_OK; | |||
| } | |||
| STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_NODE_UTIL_H | |||
| @@ -103,7 +103,6 @@ target_link_libraries(converter_lite PRIVATE | |||
| onnx_parser_mid | |||
| anf_importer_mid | |||
| anf_exporter_mid | |||
| node_mid | |||
| graph_pass_mid | |||
| fusion_mid | |||
| quantizer_mid | |||
| @@ -23,34 +23,12 @@ | |||
| #include "src/common/op_utils.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h" | |||
| #include "tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h" | |||
| #include "tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.h" | |||
| #include "tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.h" | |||
| #include "tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.h" | |||
| #include "tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h" | |||
| // #include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h" | |||
| #include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h" | |||
| #include "tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h" | |||
| #include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h" | |||
| // #include "tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h" | |||
| // | |||
| // #include "tools/converter/legacy_optimizer/const_fold/add_const_fold_pass.h" | |||
| // #include "tools/converter/legacy_optimizer/const_fold/cast_const_fold_pass.h" | |||
| // #include "tools/converter/legacy_optimizer/const_fold/concat_v2_const_fold_pass.h" | |||
| // #include "tools/converter/legacy_optimizer/const_fold/expand_dims_const_fold_pass.h" | |||
| // #include "tools/converter/legacy_optimizer/const_fold/mul_const_fold_pass.h" | |||
| // #include "tools/converter/legacy_optimizer/const_fold/range_const_fold_pass.h" | |||
| // #include "tools/converter/legacy_optimizer/const_fold/reshape_const_fold_pass.h" | |||
| // #include "tools/converter/legacy_optimizer/const_fold/rsqrt_const_fold_pass.h" | |||
| // #include "tools/converter/legacy_optimizer/const_fold/shape_const_fold_pass.h" | |||
| // #include "tools/converter/legacy_optimizer/const_fold/slice_const_fold_pass.h" | |||
| // #include "tools/converter/legacy_optimizer/const_fold/stack_const_fold_pass.h" | |||
| // #include "tools/converter/legacy_optimizer/const_fold/strided_slice_const_fold_pass.h" | |||
| // #include "tools/converter/legacy_optimizer/const_fold/sub_const_fold_pass.h" | |||
| // #include "tools/converter/legacy_optimizer/const_fold/tile_const_fold_pass.h" | |||
| // #include "tools/converter/legacy_optimizer/const_fold/transpose_const_fold_pass.h" | |||
| // | |||
| #include "tools/converter/legacy_optimizer/node/weight_format_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/format_trans_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h" | |||
| @@ -76,19 +54,6 @@ void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) { | |||
| std::make_unique<quant::AwareQuantizer>(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean); | |||
| break; | |||
| } | |||
| // case QuantType::QuantType_WeightQuant: { | |||
| // MS_LOGI("create WeightQuantizer!"); | |||
| // mQuantizer.reset(new WeightQuantizer(graphDefT, flags->quantSize)); | |||
| // break; | |||
| // } | |||
| // case QuantType_PostTraining: { | |||
| // MS_LOGI("create PostTrainningQuantizer!"); | |||
| // mQuantizer.reset(new PostTrainingQuantizer(graphDefT, flags->configFile)); | |||
| // break; | |||
| // } | |||
| // case QuantType::QuantType_QUANT_NONE: | |||
| // MS_LOGD("Not do quantization for model!"); | |||
| // break; | |||
| default: | |||
| // MS_LOGI("will support quantizer type %s in the future!", flags->quantTypeIn.c_str()); | |||
| break; | |||
| @@ -97,16 +62,16 @@ void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) { | |||
| int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| STATUS status; | |||
| // weight format trans | |||
| if (ctx.formatTrans) { | |||
| { | |||
| Optimizer weightFormatOptimizer; | |||
| auto weightFormatPass = new (std::nothrow) WeightFormatPass(); | |||
| if (weightFormatPass == nullptr) { | |||
| MS_LOG(ERROR) << "new weightFormatPass failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto weightHardCodePass = new WeightFormatHardCodePass(); | |||
| auto weightFormatPass = new WeightFormatTransformPass(); | |||
| weightHardCodePass->SetQuantType(ctx.quantType); | |||
| weightHardCodePass->SetFmkType(ctx.fmk); | |||
| weightFormatPass->SetQuantType(ctx.quantType); | |||
| weightFormatPass->SetFmkType(ctx.fmk); | |||
| // weightFormatPass->SetDstFormat(Format_KHWC); | |||
| weightFormatOptimizer.AddPass(weightHardCodePass); | |||
| weightFormatOptimizer.AddPass(weightFormatPass); | |||
| status = weightFormatOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| @@ -1,6 +1,4 @@ | |||
| include_directories(${CMAKE_CURRENT_SOURCE_DIR}) | |||
| add_subdirectory(fusion) | |||
| #add_subdirectory(const_fold) | |||
| add_subdirectory(node) | |||
| add_subdirectory(graph) | |||
| @@ -1,13 +1,6 @@ | |||
| add_library(fusion_mid OBJECT | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/fusion_pattern.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/fusion_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/conv_scale_bias_fusion_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/conv_bn_fusion_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/conv_scale_fusion_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/conv_activation_fusion_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/conv_relu_fusion_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/conv_relu6_fusion_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/conv_biasadd_fusion_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/matmul_biasadd_fusion_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast_fusion_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_fold_fusion_pass.cc | |||
| @@ -1,101 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.h" | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "utils/log_adapter.h" | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "src/common/op_utils.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #define CONV_ACTIVATION_MATCH_PATH_LEN 2 | |||
| STATUS ConvActivationFusionPass::DefinePattern() { | |||
| auto convOp = std::make_shared<PatternOp>(); | |||
| convOp->id = kConvName; | |||
| convOp->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; | |||
| auto actOp = std::make_shared<PatternOp>(); | |||
| actOp->id = ACTIVATION_NAME; | |||
| actOp->types = {schema::PrimitiveType_Activation}; | |||
| actOp->left = convOp; | |||
| std::unique_ptr<FusionPattern> fusionPattern(new (std::nothrow) FusionPattern("ConvActivationFusion")); | |||
| if (fusionPattern == nullptr) { | |||
| MS_LOG(ERROR) << "new fusionPattern failed"; | |||
| return RET_ERROR; | |||
| } | |||
| fusionPattern->AddPatternOp(convOp); | |||
| fusionPattern->AddPatternOp(actOp); | |||
| fusionPattern->Finish(); | |||
| this->patterns.emplace_back(fusionPattern.release()); | |||
| return RET_OK; | |||
| } | |||
| // 1. change attr of conv | |||
| // 2. delete Activation node | |||
| STATUS ConvActivationFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, | |||
| std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) { | |||
| MS_ASSERT(graph != nullptr); | |||
| if (matchedPath.size() != CONV_ACTIVATION_MATCH_PATH_LEN) { | |||
| MS_LOG(ERROR) << "Conv-Activation-Fusion should have two NodeIndex in matchedPair"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| auto convPath = matchedPath[kConvName]; | |||
| auto actPath = matchedPath[ACTIVATION_NAME]; | |||
| auto &convNode = graph->nodes.at(convPath->nodeIdx); | |||
| auto &actNode = graph->nodes.at(actPath->nodeIdx); | |||
| // todo if combine conv_relu_fusion and conv_relu6_fusion to conv_activation_fusion | |||
| if (actNode->primitive->value.AsActivation()->type != this->activationType) { | |||
| return RET_NO_CHANGE; | |||
| } | |||
| if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) { | |||
| convNode->primitive->value.AsConv2D()->activationType = this->activationType; | |||
| } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { | |||
| convNode->primitive->value.AsDepthwiseConv2D()->activationType = this->activationType; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported opType, " << convNode->primitive->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| // remove activation node | |||
| MergeNodeAttrFromPost(convNode, actNode); | |||
| auto status = IsolateOneWayNode(graph, actPath->nodeIdx); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "IsolateOneWayNode failed, subGraph: " << actPath->subGraphIdx << ", node: " << actPath->nodeIdx | |||
| << ", error: " << status; | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS ConvActivationFusionPass::Run(schema::MetaGraphT *graph) { | |||
| SetActivationType(); | |||
| return FusionPass::Run(graph); | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,50 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_PREDICT_CONV_ACTIVATION_FUSION_PASS_H | |||
| #define MINDSPORE_PREDICT_CONV_ACTIVATION_FUSION_PASS_H | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class ConvActivationFusionPass : public FusionPass { | |||
| public: | |||
| ConvActivationFusionPass() = default; | |||
| ~ConvActivationFusionPass() override = default; | |||
| STATUS DefinePattern() override; | |||
| virtual STATUS SetActivationType() = 0; | |||
| // 1. change attr of conv | |||
| // 2. delete Activation node | |||
| STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, | |||
| std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override; | |||
| STATUS Run(schema::MetaGraphT *graph) override; | |||
| protected: | |||
| schema::ActivationType activationType = schema::ActivationType_RELU; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_CONV_ACTIVATION_FUSION_PASS_H | |||
| @@ -1,295 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h" | |||
| #include <cfloat> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "utils/log_adapter.h" | |||
| #include "securec/include/securec.h" | |||
| // #include "utils/log_adapter.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "src/common/op_utils.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #define CONV_BIASADD_MATCH_PATH_LEN 2 | |||
| #define BIASADD_OP_BIAS_INDEX_IN_WEIGHT 0 | |||
| #define BIASADD_OP_INPUT_NUM 2 | |||
| #define BIASADD_OP_CONST_TENSOR_INDEX 1 | |||
| STATUS ConvBiasAddFusionPass::Run(MetaGraphT *graph) { return FusionPass::Run(graph); } | |||
| STATUS ConvBiasAddFusionPass::DefinePattern() { | |||
| auto convOp = std::make_shared<PatternOp>(); | |||
| convOp->id = kConvName; | |||
| convOp->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeConv2D}; | |||
| auto baOp = std::make_shared<PatternOp>(); | |||
| baOp->id = BIASADD_NAME; | |||
| baOp->types = {schema::PrimitiveType_BiasAdd, schema::PrimitiveType_Add}; | |||
| baOp->left = convOp; | |||
| std::unique_ptr<FusionPattern> fusionPattern(new (std::nothrow) FusionPattern("ConvBiasAddFusion")); | |||
| if (fusionPattern == nullptr) { | |||
| MS_LOG(ERROR) << "new fusionPattern failed"; | |||
| return RET_ERROR; | |||
| } | |||
| fusionPattern->AddPatternOp(convOp); | |||
| fusionPattern->AddPatternOp(baOp); | |||
| fusionPattern->Finish(); | |||
| this->patterns.emplace_back(fusionPattern.release()); | |||
| return RET_OK; | |||
| } | |||
| STATUS ConvBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName, | |||
| std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) { | |||
| MS_ASSERT(graph != nullptr); | |||
| if (matchedPath.size() != CONV_BIASADD_MATCH_PATH_LEN) { | |||
| MS_LOG(ERROR) << "Conv-BiasAdd-Fusion should have two NodeIndex in matchedPair"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| auto convPath = matchedPath[kConvName]; | |||
| auto baPath = matchedPath[BIASADD_NAME]; | |||
| auto &convNode = graph->nodes.at(convPath->nodeIdx); | |||
| auto &baNode = graph->nodes.at(baPath->nodeIdx); | |||
| // add/biasadd node the second tensor is not constant tensor, don't fusion | |||
| auto baNodeInputIndex = baNode->inputIndex; | |||
| if (baNodeInputIndex.size() != BIASADD_OP_INPUT_NUM) { | |||
| MS_LOG(ERROR) << baNode->name.c_str() << " node tensors number is invalid! "; | |||
| return RET_ERROR; | |||
| } | |||
| auto baNodeBiasTensor = graph->allTensors.at(baNodeInputIndex[BIASADD_OP_CONST_TENSOR_INDEX]).get(); | |||
| MS_ASSERT(baNodeBiasTensor != nullptr); | |||
| if (baNodeBiasTensor->nodeType != schema::NodeType_ValueNode) { | |||
| // dont fusion, return | |||
| return RET_OK; | |||
| } | |||
| // 1. generate newBiasTensor for conv | |||
| auto status = GenConvBiasTensor(convPath, baPath, graph); | |||
| if (RET_OK != status) { | |||
| MS_LOG(ERROR) << "GenConvBiasTensor failed, " << status; | |||
| return status; | |||
| } | |||
| if (this->newBiasTensor != nullptr) { | |||
| status = AddTensor2Node(graph, convPath->nodeIdx, std::move(this->newBiasTensor)); | |||
| this->newBiasTensor = nullptr; | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "AddTensor2Node failed, node: " << convPath->nodeIdx << ", error: " << status; | |||
| return status; | |||
| } | |||
| // add bias quantParam | |||
| // todo add quantParam for tensors | |||
| // if (convNode->quantParam.size() == convNode->inputIndex.size() + convNode->outputIndex.size() - 1) { | |||
| // std::unique_ptr<QuantParamArrayT> quantParamArray(new QuantParamArrayT()); | |||
| // if (quantParamArray == nullptr) { | |||
| // MS_LOG(ERROR) << "new QuantParamArrayT failed"); | |||
| // return RET_ERROR; | |||
| // } | |||
| // std::unique_ptr<QuantParamT> quantParam(new QuantParamT()); | |||
| // if (quantParam == nullptr) { | |||
| // MS_LOG(ERROR) << "new QuantParamT failed"); | |||
| // return RET_ERROR; | |||
| // } | |||
| // quantParam->numBits = -1; | |||
| // quantParam->scale = FLT_MAX; | |||
| // quantParam->zeroPoint = 0; | |||
| // quantParam->narrowRange = true; | |||
| // quantParam->min = FLT_MAX; | |||
| // quantParam->max = FLT_MAX; | |||
| // quantParamArray->param.emplace_back(quantParam.release()); | |||
| // convNode->quantParam.emplace_back(quantParamArray.release()); | |||
| // } | |||
| } | |||
| // 2. change attr of conv | |||
| if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) { | |||
| convNode->primitive->value.AsConv2D()->hasBias = true; | |||
| } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { | |||
| convNode->primitive->value.AsDepthwiseConv2D()->hasBias = true; | |||
| } else if (convNode->primitive->value.type == schema::PrimitiveType_DeConv2D) { | |||
| convNode->primitive->value.AsDeConv2D()->hasBias = true; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported opType, " << convNode->primitive->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| // 5. delete BiasAdd node | |||
| MergeNodeAttrFromPost(convNode, baNode); | |||
| status = IsolateOneWayNode(graph, baPath->nodeIdx); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "IsolateOneWayNode failed, graph: %zu, node: %zu, error: %d"; | |||
| //, baPath->subGraphIdx, baPath->nodeIdx, status); | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #define BIASADD_WEIGHT_SHAPE_SIZE 1 | |||
| #define BIASADD_BIAS_DIM_INDEX 0 | |||
| STATUS ConvBiasAddFusionPass::GenConvBiasTensor(std::shared_ptr<Path> convPath, std::shared_ptr<Path> baPath, | |||
| MetaGraphT *graph) { | |||
| MS_ASSERT(convPath != nullptr); | |||
| MS_ASSERT(baPath != nullptr); | |||
| MS_ASSERT(graph != nullptr); | |||
| auto convNode = graph->nodes.at(convPath->nodeIdx).get(); | |||
| MS_ASSERT(convNode != nullptr); | |||
| auto baNode = graph->nodes.at(baPath->nodeIdx).get(); | |||
| MS_ASSERT(baNode != nullptr); | |||
| int32_t kernelNum = 0; | |||
| if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) { | |||
| kernelNum = convNode->primitive->value.AsConv2D()->channelOut; | |||
| } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { | |||
| kernelNum = convNode->primitive->value.AsDepthwiseConv2D()->channelIn * | |||
| convNode->primitive->value.AsDepthwiseConv2D()->channelMultiplier; | |||
| } else if (convNode->primitive->value.type == schema::PrimitiveType_DeConv2D) { | |||
| kernelNum = convNode->primitive->value.AsDeConv2D()->channelOut; | |||
| } | |||
| auto convWeightTensorIdxes = convNode->inputIndex; | |||
| if (convWeightTensorIdxes.size() < CONV_OP_NO_BIAS_INPUT_NUM) { | |||
| MS_LOG(ERROR) << convNode->name.c_str() << " node tensors number is invalid! "; | |||
| return RET_ERROR; | |||
| } | |||
| convWeightTensorIdxes.erase(convWeightTensorIdxes.begin()); | |||
| auto baWeightTensorIdxes = baNode->inputIndex; | |||
| if (baWeightTensorIdxes.size() != BIASADD_OP_INPUT_NUM) { | |||
| MS_LOG(ERROR) << baNode->name.c_str() << " node tensors number is invalid! "; | |||
| return RET_ERROR; | |||
| } | |||
| baWeightTensorIdxes.erase(baWeightTensorIdxes.begin()); | |||
| if (convWeightTensorIdxes.empty()) { | |||
| MS_LOG(ERROR) << "Conv2D should has one weight tensors at least, current number of weight tensors " | |||
| << convWeightTensorIdxes.size(); | |||
| return RET_ERROR; | |||
| } | |||
| if (baWeightTensorIdxes.empty()) { | |||
| MS_LOG(ERROR) << "BiasAdd should has one weight tensors at least, current number of weight tensors " | |||
| << baWeightTensorIdxes.size(); | |||
| return RET_ERROR; | |||
| } | |||
| TensorT *oldBiasTensor = nullptr; | |||
| TensorT *biasTensor = nullptr; | |||
| if (convWeightTensorIdxes.size() == CONV_OP_HAS_BIAS_WEIGHT_NUM) { | |||
| oldBiasTensor = graph->allTensors.at(convWeightTensorIdxes[CONV_OP_BIAS_INDEX_IN_WEIGHT]).get(); | |||
| MS_ASSERT(oldBiasTensor != nullptr); | |||
| } | |||
| biasTensor = graph->allTensors.at(baWeightTensorIdxes.at(BIASADD_OP_BIAS_INDEX_IN_WEIGHT)).get(); | |||
| MS_ASSERT(biasTensor != nullptr); | |||
| auto biasDims = biasTensor->dims; | |||
| // if biasTensor is a scaler | |||
| if (biasDims.empty() && biasTensor->data.data() == nullptr) { | |||
| MS_LOG(ERROR) << "BiasAdd node %s bias tensor is invalid" << baNode->name.c_str(); | |||
| return RET_ERROR; | |||
| } | |||
| if (!biasDims.empty() && biasDims.size() != BIASADD_WEIGHT_SHAPE_SIZE) { | |||
| MS_LOG(ERROR) << "BiasAdd bias tensor should has one dimension, current number of dimension " << biasDims.size() | |||
| << ". or bias tensor is a scaler"; | |||
| return RET_ERROR; | |||
| } | |||
| bool bias_const = !biasDims.empty() && biasDims.size() == 1 && biasDims[0] == 1; | |||
| if (!biasDims.empty() && !bias_const && biasDims.at(BIASADD_BIAS_DIM_INDEX) != kernelNum) { | |||
| MS_LOG(ERROR) << "Size(%d) of BiasAdd(%s) bias tensor should be equal to kernelNum(%d)" | |||
| << biasDims.at(BIASADD_BIAS_DIM_INDEX) << baNode->name.c_str() << kernelNum; | |||
| return RET_ERROR; | |||
| } | |||
| // cal new biasData | |||
| this->newBiasData = new (std::nothrow) float[kernelNum]; | |||
| if (newBiasData == nullptr) { | |||
| MS_LOG(ERROR) << "new newBiasData failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (biasDims.empty() && biasTensor->data.data() != nullptr) { | |||
| auto *biasData = reinterpret_cast<float *>(biasTensor->data.data()); | |||
| if (0 != memset_s(newBiasData, kernelNum * sizeof(float), *biasData, kernelNum * sizeof(float))) { | |||
| MS_LOG(ERROR) << "memset_s newBiasData failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } else if (bias_const) { | |||
| auto *biasData = reinterpret_cast<float *>(biasTensor->data.data()); | |||
| for (size_t i = 0; i < kernelNum; i++) { | |||
| newBiasData[i] = *biasData; | |||
| } | |||
| } else { | |||
| if (0 != memcpy_s(newBiasData, kernelNum * sizeof(float), biasTensor->data.data(), kernelNum * sizeof(float))) { | |||
| MS_LOG(ERROR) << "memcpy_s newBiasData failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| if (oldBiasTensor != nullptr) { | |||
| auto oldBiasDims = oldBiasTensor->dims; | |||
| if (oldBiasDims.size() != 1) { | |||
| MS_LOG(ERROR) | |||
| << "Conv bias tensor should has one dimension, current number of dimension %zu"; // oldBiasDims.size()); | |||
| return RET_ERROR; | |||
| } | |||
| if (oldBiasDims.at(0) != kernelNum) { | |||
| MS_LOG(ERROR) | |||
| << "Size(%zu) of Conv bias tensor should be equal to kernelNum(%d), current number of dimension %zu"; | |||
| // oldBiasDims.size(), kernelNum); | |||
| return RET_ERROR; | |||
| } | |||
| auto *oldBiasData = reinterpret_cast<float *>(oldBiasTensor->data.data()); | |||
| for (size_t i = 0; i < kernelNum; i++) { | |||
| oldBiasData[i] += newBiasData[i]; | |||
| } | |||
| } else { | |||
| auto *newCharBiasData = reinterpret_cast<uint8_t *>(newBiasData); | |||
| std::vector<uint8_t> tmpBiasVec(newCharBiasData, newCharBiasData + kernelNum * sizeof(float) / sizeof(uint8_t)); | |||
| auto weightTensor = graph->allTensors.at(convWeightTensorIdxes[CONV_OP_FILTER_INDEX_IN_WEIGHT]).get(); | |||
| this->newBiasTensor = std::unique_ptr<TensorT>(new (std::nothrow) TensorT); | |||
| // todo biasShape | |||
| this->newBiasTensor->dims = {kernelNum}; | |||
| this->newBiasTensor->dataType = weightTensor->dataType; | |||
| this->newBiasTensor->format = weightTensor->format; | |||
| this->newBiasTensor->refCount = weightTensor->refCount; | |||
| this->newBiasTensor->data.swap(tmpBiasVec); | |||
| newCharBiasData = nullptr; | |||
| } | |||
| delete (this->newBiasData); | |||
| newBiasData = nullptr; | |||
| return RET_OK; | |||
| } | |||
| ConvBiasAddFusionPass::~ConvBiasAddFusionPass() { | |||
| if (this->newBiasData != nullptr) { | |||
| delete (this->newBiasData); | |||
| } | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,51 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_PREDICT_CONV_BIASADD_FUSION_PASS_H | |||
| #define MINDSPORE_PREDICT_CONV_BIASADD_FUSION_PASS_H | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
// Fuses (Conv2D | DepthwiseConv2D | DeConv2D) + (BiasAdd | Add) into a single
// conv node carrying the combined bias.
class ConvBiasAddFusionPass : public FusionPass {
 public:
  ConvBiasAddFusionPass() = default;
  // non-default: releases newBiasData if a fusion aborted part-way through
  ~ConvBiasAddFusionPass() override;
  STATUS DefinePattern() override;
  STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
                  std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;
  STATUS Run(schema::MetaGraphT *graph) override;

 protected:
  // gen this->newBiasTensor if conv has no bias before
  STATUS GenConvBiasTensor(std::shared_ptr<Path> convPath, std::shared_ptr<Path> dstPath, schema::MetaGraphT *graph);

 protected:
  // scratch bias buffer (new[] float[kernelNum]); freed once fusion of one match completes
  float *newBiasData = nullptr;
  // bias tensor built by GenConvBiasTensor when the conv had no bias; consumed by DoFusion
  std::unique_ptr<TensorT> newBiasTensor = nullptr;
};
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_CONV_BIASADD_FUSION_PASS_H | |||
| @@ -1,224 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include <cmath> | |||
| #include "tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h" | |||
| #include "securec/include/securec.h" | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #define CAFFE_BATCHNORM_OP_WEIGHT_NUM 2 | |||
| #define TF_BATCHNORM_OP_WEIGHT_NUM 4 | |||
| #define CAFFE_BATCHNORM_MEAN_INDEX 0 | |||
| #define CAFFE_BATCHNORM_VARIANCE_INDEX 1 | |||
| #define TF_BATCHNORM_SCALE_INDEX 0 | |||
| #define TF_BATCHNORM_BIAS_INDEX 1 | |||
| #define TF_BATCHNORM_MEAN_INDEX 2 | |||
| #define TF_BATCHNORM_VARIANCE_INDEX 3 | |||
| constexpr const float EPS = 1e-8; | |||
| constexpr const float EPS_DEFAULT_FLOAT = 1e-5; | |||
| constexpr const float POW_NUM = 0.5; | |||
// BN fusion reuses the generic conv + per-channel scale/bias folding; the
// BN-specific work lives in GetTransParam / GetBnWeightTensors below.
STATUS ConvBNFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
                                  std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
  return ConvScaleBiasFusionPass::DoFusion(graph, patternName, matchedPath);
}
| STATUS ConvBNFusionPass::DefinePattern() { | |||
| auto convOp = std::make_shared<PatternOp>(); | |||
| convOp->id = kConvName; | |||
| convOp->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; | |||
| auto bnOp = std::make_shared<PatternOp>(); | |||
| bnOp->id = DST_NAME; | |||
| bnOp->types = {schema::PrimitiveType_FusedBatchNorm, schema::PrimitiveType_BatchNorm}; | |||
| bnOp->left = convOp; | |||
| std::unique_ptr<FusionPattern> fusionPattern(new (std::nothrow) FusionPattern("ConvBatchNormFusion")); | |||
| if (fusionPattern == nullptr) { | |||
| MS_LOG(ERROR) << "new fusionPattern failed"; | |||
| return RET_ERROR; | |||
| } | |||
| fusionPattern->AddPatternOp(convOp); | |||
| fusionPattern->AddPatternOp(bnOp); | |||
| fusionPattern->Finish(); | |||
| this->patterns.emplace_back(fusionPattern.release()); | |||
| return RET_OK; | |||
| } | |||
// Entry point: run the shared conv + scale/bias fusion driver over the graph.
STATUS ConvBNFusionPass::Run(schema::MetaGraphT *graph) { return ConvScaleBiasFusionPass::Run(graph); }
| STATUS ConvBNFusionPass::GetTransParam(schema::MetaGraphT *graph, std::shared_ptr<Path> bnPath, int32_t kernelNum) { | |||
| MS_ASSERT(graph != nullptr); | |||
| MS_ASSERT(bnPath != nullptr); | |||
| BNWeightTensors bnWeightTensors; | |||
| auto status = GetBnWeightTensors(graph, bnPath, kernelNum, bnWeightTensors); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "GetBnWeightTensors error " << status; | |||
| return status; | |||
| } | |||
| schema::TensorT *meanTensor = bnWeightTensors.meanTensor; | |||
| schema::TensorT *varianceTensor = bnWeightTensors.varianceTensor; | |||
| schema::TensorT *scaleTensor = bnWeightTensors.scaleTensor; | |||
| schema::TensorT *biasTensor = bnWeightTensors.biasTensor; | |||
| auto *meanData = reinterpret_cast<float *>(meanTensor->data.data()); | |||
| auto *varianceData = reinterpret_cast<float *>(varianceTensor->data.data()); | |||
| float eps = EPS_DEFAULT_FLOAT; | |||
| status = GetBnEpsilon(graph, bnPath, eps); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "GetBnEpsilon failed " << status; | |||
| return status; | |||
| } | |||
| // cal transScale, tf : scale/sqrt(variance + eps); caffe : 1/sqrt(variance + eps) | |||
| if (memcpy_s(transScale, kernelNum * sizeof(float), varianceData, kernelNum * sizeof(float)) != 0) { | |||
| MS_LOG(ERROR) << "memcpy_s transScale error"; | |||
| return RET_ERROR; | |||
| } | |||
| // 1/sqrt(variance + eps) | |||
| for (int32_t i = 0; i < kernelNum; i++) { | |||
| float tmp = transScale[i] + eps; | |||
| tmp = pow(tmp, POW_NUM); | |||
| transScale[i] = 1 / tmp; | |||
| } | |||
| if (scaleTensor != nullptr) { | |||
| auto *scaleData = reinterpret_cast<float *>(scaleTensor->data.data()); | |||
| // scale/sqrt(variance + eps) | |||
| for (int32_t i = 0; i < kernelNum; i++) { | |||
| transScale[i] *= scaleData[i]; | |||
| } | |||
| } | |||
| // cal transBias, tf : -scale*mean/sqrt(variance + eps) + bias; caffe : -mean/sqrt(variance + eps) | |||
| // -mean/sqrt(variance + eps) | |||
| for (int32_t i = 0; i < kernelNum; i++) { | |||
| transBias[i] = -meanData[i] * transScale[i]; | |||
| } | |||
| if (biasTensor != nullptr) { | |||
| auto *biasData = reinterpret_cast<float *>(biasTensor->data.data()); | |||
| // -scale*mean/sqrt(variance + eps) + bias | |||
| for (int32_t i = 0; i < kernelNum; i++) { | |||
| transBias[i] += biasData[i]; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| // BatchNorm weight Tensor definition: | |||
| // caffe | |||
| // estimated_mean --0 | |||
| // estimated_variance --1 | |||
| // tensorflow | |||
| // scale -- 0 | |||
| // bias --1 | |||
| // estimated_mean --2 | |||
| // estimated_variance --3 | |||
| STATUS ConvBNFusionPass::GetBnWeightTensors(schema::MetaGraphT *graph, std::shared_ptr<Path> bnPath, int32_t kernelNum, | |||
| BNWeightTensors &bnWeightTensors) { | |||
| MS_ASSERT(graph != nullptr); | |||
| MS_ASSERT(bnPath != nullptr); | |||
| auto bnNode = graph->nodes.at(bnPath->nodeIdx).get(); | |||
| auto bnWeightTensorIdxes = bnNode->inputIndex; | |||
| bnWeightTensorIdxes.erase(bnWeightTensorIdxes.begin()); | |||
| if (bnWeightTensorIdxes.size() == CAFFE_BATCHNORM_OP_WEIGHT_NUM) { | |||
| bnWeightTensors.meanTensor = graph->allTensors.at(bnWeightTensorIdxes[CAFFE_BATCHNORM_MEAN_INDEX]).get(); | |||
| bnWeightTensors.varianceTensor = graph->allTensors.at(bnWeightTensorIdxes[CAFFE_BATCHNORM_VARIANCE_INDEX]).get(); | |||
| } else if (bnWeightTensorIdxes.size() == TF_BATCHNORM_OP_WEIGHT_NUM) { | |||
| bnWeightTensors.scaleTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_SCALE_INDEX]).get(); | |||
| bnWeightTensors.biasTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_BIAS_INDEX]).get(); | |||
| bnWeightTensors.meanTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_MEAN_INDEX]).get(); | |||
| bnWeightTensors.varianceTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_VARIANCE_INDEX]).get(); | |||
| } else { | |||
| MS_LOG(ERROR) << "BatchNorm should has " << CAFFE_BATCHNORM_OP_WEIGHT_NUM << " or " << TF_BATCHNORM_OP_WEIGHT_NUM | |||
| << " weight tensors, current number of weight tensors " << bnWeightTensorIdxes.size(); | |||
| return RET_ERROR; | |||
| } | |||
| if (bnWeightTensors.meanTensor == nullptr) { | |||
| MS_LOG(ERROR) << "BatchNorm's mean tensor is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| if (bnWeightTensors.varianceTensor == nullptr) { | |||
| MS_LOG(ERROR) << "BatchNorm's variance tensor is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| if (kernelNum != bnWeightTensors.meanTensor->data.size() * sizeof(uint8_t) / sizeof(float)) { | |||
| MS_LOG(ERROR) << "conv kernel num " << kernelNum << " is expected to be equal to mean size(" | |||
| << bnWeightTensors.meanTensor->data.size() * sizeof(uint8_t) / sizeof(float) << ")"; | |||
| return RET_ERROR; | |||
| } | |||
| if (kernelNum != bnWeightTensors.varianceTensor->data.size() * sizeof(uint8_t) / sizeof(float)) { | |||
| MS_LOG(ERROR) << "conv kernel num " << kernelNum << " is expected to be equal to mean size(" | |||
| << bnWeightTensors.meanTensor->data.size() * sizeof(uint8_t) / sizeof(float) << ")"; | |||
| return RET_ERROR; | |||
| } | |||
| if (bnWeightTensors.scaleTensor != nullptr) { | |||
| if (kernelNum != bnWeightTensors.scaleTensor->data.size() * sizeof(uint8_t) / sizeof(float)) { | |||
| MS_LOG(ERROR) << "conv kernel num " << kernelNum << " is expected to be equal to mean size(" | |||
| << bnWeightTensors.meanTensor->data.size() * sizeof(uint8_t) / sizeof(float) << ")"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| if (bnWeightTensors.biasTensor != nullptr) { | |||
| if (kernelNum != bnWeightTensors.biasTensor->data.size() * sizeof(uint8_t) / sizeof(float)) { | |||
| MS_LOG(ERROR) << "conv kernel num " << kernelNum << " is expected to be equal to mean size(" | |||
| << bnWeightTensors.meanTensor->data.size() * sizeof(uint8_t) / sizeof(float) << ")"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS ConvBNFusionPass::GetBnEpsilon(schema::MetaGraphT *graph, std::shared_ptr<Path> bnPath, float &eps) { | |||
| MS_ASSERT(graph != nullptr); | |||
| auto bnNode = graph->nodes.at(bnPath->nodeIdx).get(); | |||
| MS_ASSERT(bnNode != nullptr); | |||
| if (bnNode->primitive->value.type == schema::PrimitiveType_FusedBatchNorm) { | |||
| eps = bnNode->primitive->value.AsFusedBatchNorm()->epsilon; | |||
| } else if (bnNode->primitive->value.type == schema::PrimitiveType_BatchNorm) { | |||
| eps = bnNode->primitive->value.AsBatchNorm()->epsilon; | |||
| } else { | |||
| MS_LOG(ERROR) << "match pattern has error, " << bnNode->name.c_str() << " not BatchNorm node"; | |||
| return RET_ERROR; | |||
| } | |||
| if (eps < EPS) { | |||
| eps = EPS_DEFAULT_FLOAT; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,54 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #ifndef MINDSPORE_CONV_BN_FUSION_PASS_H | |||
| #define MINDSPORE_CONV_BN_FUSION_PASS_H | |||
| #include "tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h" | |||
| #include "tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class ConvBNFusionPass : public ConvScaleBiasFusionPass { | |||
| public: | |||
| ConvBNFusionPass() = default; | |||
| ~ConvBNFusionPass() override = default; | |||
| STATUS DefinePattern() override; | |||
| STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, | |||
| std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override; | |||
| STATUS Run(schema::MetaGraphT *graph) override; | |||
| protected: | |||
| STATUS GetTransParam(schema::MetaGraphT *graph, std::shared_ptr<Path> bnPath, int32_t kernelNum) override; | |||
| // Get and check BNNode weight tensor | |||
| STATUS GetBnWeightTensors(schema::MetaGraphT *graph, std::shared_ptr<Path> bnPath, int32_t kernelNum, | |||
| BNWeightTensors &bnWeightTensors); | |||
| STATUS GetBnEpsilon(schema::MetaGraphT *graph, std::shared_ptr<Path> bnPath, float &eps); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CONV_BN_FUSION_PASS_H | |||
| @@ -1,41 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.h" | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
// The conv+activation pattern itself is shared; only the activation type differs.
STATUS ConvRelu6FusionPass::DefinePattern() { return ConvActivationFusionPass::DefinePattern(); }
| STATUS ConvRelu6FusionPass::SetActivationType() { | |||
| this->activationType = ActivationType_RELU6; | |||
| return RET_OK; | |||
| } | |||
// Delegates to the shared conv+activation fusion; SetActivationType supplies ReLU6.
STATUS ConvRelu6FusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName,
                                     std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
  return ConvActivationFusionPass::DoFusion(graph, patternName, matchedPath);
}
// Entry point: run the shared conv+activation fusion driver.
STATUS ConvRelu6FusionPass::Run(MetaGraphT *graph) { return ConvActivationFusionPass::Run(graph); }
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,46 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_PREDICT_CONV_RELU6_FUSION_PASS_H | |||
| #define MINDSPORE_PREDICT_CONV_RELU6_FUSION_PASS_H | |||
| #include "tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.h" | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| namespace mindspore { | |||
| namespace lite { | |||
// Conv + ReLU6 fusion: identical to the generic ConvActivationFusionPass
// except for the activation type it stamps on the fused conv.
class ConvRelu6FusionPass : public ConvActivationFusionPass {
 public:
  ConvRelu6FusionPass() = default;
  ~ConvRelu6FusionPass() override = default;
  STATUS DefinePattern() override;
  // sets activationType to RELU6
  STATUS SetActivationType() override;
  STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
                  std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;
  STATUS Run(schema::MetaGraphT *graph) override;
};
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_CONV_RELU6_FUSION_PASS_H | |||
| @@ -1,40 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.h" | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
// Delegates to the shared conv+activation fusion; SetActivationType supplies ReLU.
STATUS ConvReluFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
                                    std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
  return ConvActivationFusionPass::DoFusion(graph, patternName, matchedPath);
}
// Entry point: run the shared conv+activation fusion driver.
STATUS ConvReluFusionPass::Run(schema::MetaGraphT *graph) { return ConvActivationFusionPass::Run(graph); }
// Marks the fused activation as plain ReLU.
STATUS ConvReluFusionPass::SetActivationType() {
  this->activationType = schema::ActivationType_RELU;
  return RET_OK;
}
// The conv+activation pattern itself is shared; only the activation type differs.
STATUS ConvReluFusionPass::DefinePattern() { return ConvActivationFusionPass::DefinePattern(); }
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,45 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_PREDICT_CONV_RELU_FUSION_PASS_H | |||
| #define MINDSPORE_PREDICT_CONV_RELU_FUSION_PASS_H | |||
| #include "tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.h" | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| namespace mindspore { | |||
| namespace lite { | |||
// Conv + ReLU fusion: identical to the generic ConvActivationFusionPass
// except for the activation type it stamps on the fused conv.
class ConvReluFusionPass : public ConvActivationFusionPass {
 public:
  ConvReluFusionPass() = default;
  ~ConvReluFusionPass() override = default;
  STATUS DefinePattern() override;
  // sets activationType to RELU
  STATUS SetActivationType() override;
  STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
                  std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;
  STATUS Run(schema::MetaGraphT *graph) override;
};
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_CONV_RELU_FUSION_PASS_H | |||
| @@ -1,361 +0,0 @@ | |||
| /* | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2018-2019. All rights reserved. | |||
| * Description: mslite | |||
| * Author: mslite | |||
| * Create: 2019-12-13 | |||
| */ | |||
| #include <cfloat> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.h" | |||
| #include "securec/include/securec.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "src/common/op_utils.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/common/tensor_util.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #define CONV_SCALE_BIAS_MATCH_PATH_LEN 2 | |||
| // 1. generate biasTensor according to BN weightTensor | |||
| // 2. change attr of conv | |||
| // 3. delete BN node | |||
| STATUS ConvScaleBiasFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, | |||
| std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) { | |||
| MS_ASSERT(graph != nullptr); | |||
| if (matchedPath.size() != CONV_SCALE_BIAS_MATCH_PATH_LEN) { | |||
| MS_LOG(ERROR) << "Conv-Scale-Bias-Fusion should have two NodeIndex in matchedPair"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| auto convPath = matchedPath[kConvName]; | |||
| MS_ASSERT(convPath != nullptr); | |||
| auto dstPath = matchedPath[DST_NAME]; | |||
| MS_ASSERT(dstPath != nullptr); | |||
| MS_ASSERT(subGraph != nullptr); | |||
| auto &convNode = graph->nodes.at(convPath->nodeIdx); | |||
| MS_ASSERT(convNode != nullptr); | |||
| auto &dstNode = graph->nodes.at(dstPath->nodeIdx); | |||
| MS_ASSERT(dstNode != nullptr); | |||
| // 1. generate new weightTensor and biasTensor for conv | |||
| auto status = GenConvWeightTensors(graph, convPath, dstPath); | |||
| if (RET_OK != status) { | |||
| MS_LOG(ERROR) << "GenConvWeightTensors failed, " << status; | |||
| return status; | |||
| } | |||
| if (convNode->inputIndex.size() == CONV_OP_HAS_BIAS_INPUT_NUM) { | |||
| status = ReplaceTensorOfNode(graph, convPath->nodeIdx, convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT), | |||
| std::move(this->newWeightTensor)); | |||
| this->newWeightTensor = nullptr; | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "ReplaceTensorOfNode failed, subGraph: " << convPath->subGraphIdx | |||
| << ", node: " << convPath->nodeIdx << ", tensor " | |||
| << convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT) << ", error: " << status; | |||
| return status; | |||
| } | |||
| status = ReplaceTensorOfNode(graph, convPath->nodeIdx, convNode->inputIndex.at(CONV_OP_BIAS_INDEX_IN_INPUT), | |||
| std::move(this->newBiasTensor)); | |||
| this->newBiasTensor = nullptr; | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "ReplaceTensorOfNode failed, subGraph: " << convPath->subGraphIdx | |||
| << ", node: " << convPath->nodeIdx << ", tensor " | |||
| << convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT) << ", error: " << status; | |||
| return status; | |||
| } | |||
| } else if (convNode->inputIndex.size() == CONV_OP_NO_BIAS_INPUT_NUM) { | |||
| status = ReplaceTensorOfNode(graph, convPath->nodeIdx, convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT), | |||
| std::move(this->newWeightTensor)); | |||
| this->newWeightTensor = nullptr; | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "ReplaceTensorOfNode failed, subGraph: " << convPath->subGraphIdx | |||
| << ", node: " << convPath->nodeIdx << ", tensor " | |||
| << convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT) << ", error: " << status; | |||
| return status; | |||
| } | |||
| status = AddTensor2Node(graph, convPath->nodeIdx, std::move(this->newBiasTensor)); | |||
| this->newBiasTensor = nullptr; | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "ReplaceTensorOfNode failed, subGraph: " << convPath->subGraphIdx | |||
| << ", node: " << convPath->nodeIdx << ", tensor " | |||
| << convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT) << ", error: " << status; | |||
| return status; | |||
| } | |||
| // if (convNode->name == "Conv_461") { | |||
| // } | |||
| // add bias quantParam | |||
| // todo use tensor quant param | |||
| // if (convNode->quantParam.size() == convNode->inputIndex.size() + convNode->outputIndex.size() - 1) { | |||
| // std::unique_ptr<QuantParamArrayT> quantParamArray(new QuantParamArrayT()); | |||
| // if (quantParamArray == nullptr) { | |||
| // MS_LOG(ERROR) << "new QuantParamArrayT failed"; | |||
| // return RET_ERROR; | |||
| // } | |||
| // std::unique_ptr<QuantParamT> quantParam(new QuantParamT()); | |||
| // if (quantParam == nullptr) { | |||
| // MS_LOG(ERROR) << "new QuantParamT failed"; | |||
| // return RET_ERROR; | |||
| // } | |||
| // quantParam->numBits = -1; | |||
| // quantParam->scale = FLT_MAX; | |||
| // quantParam->zeroPoint = 0; | |||
| // quantParam->narrowRange = true; | |||
| // quantParam->min = FLT_MAX; | |||
| // quantParam->max = FLT_MAX; | |||
| // quantParamArray->param.emplace_back(quantParam.release()); | |||
| // convNode->quantParam.emplace_back(quantParamArray.release()); | |||
| // } | |||
| } else { | |||
| MS_LOG(ERROR) << "Conv node should has 2 or 3 weight tensors rather than " << convNode->inputIndex.size(); | |||
| return RET_ERROR; | |||
| } | |||
| // 2. change attr of conv | |||
| if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) { | |||
| convNode->primitive->value.AsConv2D()->hasBias = true; | |||
| } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { | |||
| convNode->primitive->value.AsDepthwiseConv2D()->hasBias = true; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported opType, " << convNode->primitive->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| // 3. delete DST node | |||
| MergeNodeAttrFromPost(convNode, dstNode); | |||
| status = IsolateOneWayNode(graph, dstPath->nodeIdx); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstPath->nodeIdx << ", error: " << status; | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS ConvScaleBiasFusionPass::GenConvWeightTensors(schema::MetaGraphT *graph, const std::shared_ptr<Path> &convPath, | |||
| std::shared_ptr<Path> dstPath) { | |||
| MS_ASSERT(graph != nullptr); | |||
| MS_ASSERT(convPath != nullptr); | |||
| MS_ASSERT(dstPath != nullptr); | |||
| MS_ASSERT(subGraph != nullptr); | |||
| auto &convNode = graph->nodes.at(convPath->nodeIdx); | |||
| MS_ASSERT(convNode != nullptr); | |||
| int32_t kernelNum = -1; | |||
| if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) { | |||
| kernelNum = convNode->primitive->value.AsConv2D()->channelOut; | |||
| } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { | |||
| kernelNum = convNode->primitive->value.AsDepthwiseConv2D()->channelMultiplier * | |||
| convNode->primitive->value.AsDepthwiseConv2D()->channelIn; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported opType, " << convNode->primitive->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (kernelNum <= 0) { | |||
| MS_LOG(ERROR) << "KernelNum should be positive, " << kernelNum; | |||
| return RET_ERROR; | |||
| } | |||
| this->transScale = new (std::nothrow) float[kernelNum]; | |||
| this->transBias = new (std::nothrow) float[kernelNum]; | |||
| if (transScale == nullptr) { | |||
| MS_LOG(ERROR) << "new transScale failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (transBias == nullptr) { | |||
| MS_LOG(ERROR) << "new transBias failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (0 != memset_s(transScale, kernelNum * sizeof(float), 0, kernelNum * sizeof(float))) { | |||
| MS_LOG(ERROR) << "memset transScale failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (0 != memset_s(transBias, kernelNum * sizeof(float), 0, kernelNum * sizeof(float))) { | |||
| MS_LOG(ERROR) << "memset transBias failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto status = GetTransParam(graph, dstPath, kernelNum); | |||
| if (RET_OK != status) { | |||
| MS_LOG(ERROR) << "GetTransParam failed, " << status; | |||
| return status; | |||
| } | |||
| status = CalConvWeightTensors(graph, convPath, kernelNum); | |||
| if (RET_OK != status) { | |||
| MS_LOG(ERROR) << "GenConvWeightTensors failed, " << status; | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS ConvScaleBiasFusionPass::CalNewWeightTensor(TensorT *oldWeightTensor, const int32_t kernelNum, | |||
| const size_t kernelSize) { | |||
| MS_ASSERT(oldWeightTensor != nullptr); | |||
| auto weightData = reinterpret_cast<float *>(oldWeightTensor->data.data()); | |||
| size_t kernelDataCount = kernelNum * kernelSize; | |||
| if (kernelDataCount == 0) { | |||
| MS_LOG(ERROR) << "KernelDataCount should be positive, " << kernelDataCount; | |||
| return RET_ERROR; | |||
| } | |||
| this->newWeightData = new (std::nothrow) float[kernelDataCount]; | |||
| if (newWeightData == nullptr) { | |||
| MS_LOG(ERROR) << "new newWeightData failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (0 != memset_s(newWeightData, kernelDataCount * sizeof(float), 0, kernelDataCount * sizeof(float))) { | |||
| MS_LOG(ERROR) << "memset newWeightData failed"; | |||
| return RET_ERROR; | |||
| } | |||
| for (size_t i = 0; i < kernelNum; i++) { | |||
| for (size_t j = 0; j < kernelSize; j++) { | |||
| newWeightData[i * kernelSize + j] = weightData[i * kernelSize + j] * transScale[i]; | |||
| } | |||
| } | |||
| auto newCharWeightData = reinterpret_cast<uint8_t *>(newWeightData); | |||
| std::vector<uint8_t> tmpWeightVec(newCharWeightData, | |||
| newCharWeightData + kernelDataCount * sizeof(float) / sizeof(uint8_t)); | |||
| this->newWeightTensor = std::unique_ptr<TensorT>(new (std::nothrow) TensorT); | |||
| if (this->newWeightTensor == nullptr) { | |||
| MS_LOG(ERROR) << "new newWeightTensor failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->newWeightTensor->dims.insert(this->newWeightTensor->dims.begin(), oldWeightTensor->dims.begin(), | |||
| oldWeightTensor->dims.end()); | |||
| this->newWeightTensor->dataType = oldWeightTensor->dataType; | |||
| this->newWeightTensor->format = oldWeightTensor->format; | |||
| this->newWeightTensor->refCount = oldWeightTensor->refCount; | |||
| this->newWeightTensor->data.swap(tmpWeightVec); | |||
| delete (this->newWeightData); | |||
| newWeightData = nullptr; | |||
| return RET_OK; | |||
| } | |||
| STATUS ConvScaleBiasFusionPass::CalNewBiasTensor(TensorT *oldWeightTensor, TensorT *oldBiasTensor, | |||
| const int32_t kernelNum) { | |||
| MS_ASSERT(oldWeightTensor != nullptr); | |||
| this->newBiasData = new (std::nothrow) float[kernelNum]; | |||
| if (newBiasData == nullptr) { | |||
| MS_LOG(ERROR) << "new newBiasData failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (0 != memset_s(newBiasData, kernelNum * sizeof(float), 0, kernelNum * sizeof(float))) { | |||
| MS_LOG(ERROR) << "memset newBiasData failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (oldBiasTensor != nullptr) { | |||
| auto *biasData = reinterpret_cast<float *>(oldBiasTensor->data.data()); | |||
| for (size_t i = 0; i < kernelNum; i++) { | |||
| this->newBiasData[i] = biasData[i] * transScale[i] + transBias[i]; | |||
| } | |||
| } else { | |||
| if (0 != memcpy_s(newBiasData, kernelNum * sizeof(float), transBias, kernelNum * sizeof(float))) { | |||
| MS_LOG(ERROR) << "memcpy_s newBiasData failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| auto *newCharBiasData = reinterpret_cast<uint8_t *>(newBiasData); | |||
| std::vector<uint8_t> tmpBiasVec(newCharBiasData, newCharBiasData + kernelNum * sizeof(float) / sizeof(uint8_t)); | |||
| this->newBiasTensor = std::unique_ptr<TensorT>(new (std::nothrow) TensorT); | |||
| if (this->newBiasTensor == nullptr) { | |||
| MS_LOG(ERROR) << "new newBiasTensor failed"; | |||
| return RET_ERROR; | |||
| } | |||
| // todo biasShape | |||
| this->newBiasTensor->dims = {kernelNum}; | |||
| this->newBiasTensor->dataType = oldWeightTensor->dataType; | |||
| this->newBiasTensor->format = oldWeightTensor->format; | |||
| this->newBiasTensor->refCount = oldWeightTensor->refCount; | |||
| this->newBiasTensor->data.swap(tmpBiasVec); | |||
| delete (this->newBiasData); | |||
| newCharBiasData = nullptr; | |||
| newBiasData = nullptr; | |||
| return RET_OK; | |||
| } | |||
| STATUS ConvScaleBiasFusionPass::CalConvWeightTensors(schema::MetaGraphT *graph, const std::shared_ptr<Path> &convPath, | |||
| int32_t kernelNum) { | |||
| MS_ASSERT(graph != nullptr); | |||
| MS_ASSERT(convPath != nullptr); | |||
| auto convNode = graph->nodes.at(convPath->nodeIdx).get(); | |||
| MS_ASSERT(convNode != nullptr); | |||
| auto convWeightTensorIdxes = convNode->inputIndex; | |||
| convWeightTensorIdxes.erase(convWeightTensorIdxes.begin()); | |||
| TensorT *weightTensor = nullptr; | |||
| TensorT *biasTensor = nullptr; | |||
| if (convWeightTensorIdxes.size() == CONV_OP_NO_BIAS_WEIGHT_NUM) { | |||
| weightTensor = graph->allTensors.at(convWeightTensorIdxes[CONV_OP_FILTER_INDEX_IN_WEIGHT]).get(); | |||
| } else if (convWeightTensorIdxes.size() == CONV_OP_HAS_BIAS_WEIGHT_NUM) { | |||
| weightTensor = graph->allTensors.at(convWeightTensorIdxes[CONV_OP_FILTER_INDEX_IN_WEIGHT]).get(); | |||
| biasTensor = graph->allTensors.at(convWeightTensorIdxes[CONV_OP_BIAS_INDEX_IN_WEIGHT]).get(); | |||
| } else { | |||
| MS_LOG(ERROR) << "Conv2D should has " << CONV_OP_NO_BIAS_WEIGHT_NUM << " or " << CONV_OP_HAS_BIAS_WEIGHT_NUM | |||
| << " weight tensors, current number of weight tensors " << convWeightTensorIdxes.size(); | |||
| return RET_ERROR; | |||
| } | |||
| if (weightTensor == nullptr) { | |||
| MS_LOG(ERROR) << "Conv2D's weight tensor is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto weightShape = weightTensor->dims; | |||
| if (weightShape.size() != CONV_FILTER_SHAPE_SIZE) { | |||
| MS_LOG(ERROR) << "Size of dims of weight tensor should be " << CONV_FILTER_SHAPE_SIZE << " rather than " | |||
| << weightShape.size(); | |||
| return RET_ERROR; | |||
| } | |||
| size_t kernelSize = GetShapeSize(*weightTensor) / kernelNum; | |||
| // cal new weightData | |||
| auto status = CalNewWeightTensor(weightTensor, kernelNum, kernelSize); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "CalNewWeightTensor error " << status; | |||
| return status; | |||
| } | |||
| // cal new biasData | |||
| status = CalNewBiasTensor(weightTensor, biasTensor, kernelNum); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "CalNewBiasTensor error " << status; | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS ConvScaleBiasFusionPass::Run(schema::MetaGraphT *graph) { return FusionPass::Run(graph); } | |||
| ConvScaleBiasFusionPass::~ConvScaleBiasFusionPass() { | |||
| if (this->transScale != nullptr) { | |||
| delete (this->transScale); | |||
| } | |||
| if (this->transBias != nullptr) { | |||
| delete (this->transBias); | |||
| } | |||
| if (this->newWeightData != nullptr) { | |||
| delete (this->newWeightData); | |||
| } | |||
| if (this->newBiasData != nullptr) { | |||
| delete (this->newBiasData); | |||
| } | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,67 +0,0 @@ | |||
| /* | |||
| * Copyright (c) Huawei Technologies Co., Ltd. 2018-2019. All rights reserved. | |||
| * Description: mslite | |||
| * Author: mslite | |||
| * Create: 2019-12-13 | |||
| */ | |||
| #ifndef MINDSPORE_PREDICT_CONV_SCALE_BIAS_FUSION_PASS_H | |||
| #define MINDSPORE_PREDICT_CONV_SCALE_BIAS_FUSION_PASS_H | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| struct BNWeightTensors { | |||
| schema::TensorT *meanTensor = nullptr; | |||
| schema::TensorT *varianceTensor = nullptr; | |||
| schema::TensorT *scaleTensor = nullptr; | |||
| schema::TensorT *biasTensor = nullptr; | |||
| }; | |||
| class ConvScaleBiasFusionPass : public FusionPass { | |||
| public: | |||
| ConvScaleBiasFusionPass() = default; | |||
| ~ConvScaleBiasFusionPass() override; | |||
| STATUS DefinePattern() override = 0; | |||
| // 1. generate biasTensor according to BN weightTensor | |||
| // 2. change attr of conv | |||
| // 3. delete BN node | |||
| STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, | |||
| std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override; | |||
| STATUS Run(schema::MetaGraphT *graph) override; | |||
| protected: | |||
| // call GetTransParam() and CalConvWeightTensors() | |||
| STATUS GenConvWeightTensors(schema::MetaGraphT *graph, const std::shared_ptr<Path> &convPath, | |||
| std::shared_ptr<Path> dstPath); | |||
| // fill this->transScale and this->transBias | |||
| virtual STATUS GetTransParam(schema::MetaGraphT *graph, std::shared_ptr<Path> dstPath, int32_t kernelNum) = 0; | |||
| // fill this->newWeightTensor and this->newBiasTensor according to this->transScale and this->transBias | |||
| STATUS CalConvWeightTensors(schema::MetaGraphT *graph, const std::shared_ptr<Path> &convPath, int32_t kernelNum); | |||
| STATUS CalNewWeightTensor(schema::TensorT *oldWeightTensor, int32_t kernelNum, size_t kernelSize); | |||
| STATUS CalNewBiasTensor(schema::TensorT *oldWeightTensor, schema::TensorT *oldBiasTensor, int32_t kernelNum); | |||
| protected: | |||
| float *transScale = nullptr; | |||
| float *transBias = nullptr; | |||
| float *newWeightData = nullptr; | |||
| float *newBiasData = nullptr; | |||
| std::unique_ptr<schema::TensorT> newWeightTensor = nullptr; | |||
| std::unique_ptr<schema::TensorT> newBiasTensor = nullptr; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_CONV_SCALE_BIAS_FUSION_PASS_H | |||
| @@ -1,126 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include "tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.h" | |||
| #include "securec/include/securec.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #define SCALE_OP_NO_BIAS_WEIGHT_NUM 1 | |||
| #define SCALE_OP_HAS_BIAS_WEIGHT_NUM 2 | |||
| #define SCALE_OP_SCALE_INDEX_IN_WEIGHT 0 | |||
| #define SCALE_OP_BIAS_INDEX_IN_WEIGHT 1 | |||
| STATUS ConvScaleFusionPass::DefinePattern() { | |||
| auto convOp = std::make_shared<PatternOp>(); | |||
| convOp->id = kConvName; | |||
| convOp->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; | |||
| auto scaleOp = std::make_shared<PatternOp>(); | |||
| scaleOp->id = DST_NAME; | |||
| scaleOp->types = {schema::PrimitiveType_Scale}; | |||
| scaleOp->left = convOp; | |||
| std::unique_ptr<FusionPattern> fusionPattern(new (std::nothrow) FusionPattern("ConvScaleFusion")); | |||
| if (fusionPattern == nullptr) { | |||
| MS_LOG(ERROR) << "new fusionPattern failed"; | |||
| return RET_ERROR; | |||
| } | |||
| fusionPattern->AddPatternOp(convOp); | |||
| fusionPattern->AddPatternOp(scaleOp); | |||
| fusionPattern->Finish(); | |||
| this->patterns.emplace_back(fusionPattern.release()); | |||
| return RET_OK; | |||
| } | |||
| STATUS ConvScaleFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, | |||
| std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) { | |||
| return ConvScaleBiasFusionPass::DoFusion(graph, patternName, matchedPath); | |||
| } | |||
| STATUS ConvScaleFusionPass::Run(schema::MetaGraphT *graph) { return ConvScaleBiasFusionPass::Run(graph); } | |||
| STATUS ConvScaleFusionPass::GetTransParam(schema::MetaGraphT *graph, std::shared_ptr<Path> scalePath, | |||
| int32_t kernelNum) { | |||
| MS_ASSERT(graph != nullptr); | |||
| MS_ASSERT(scalePath != nullptr); | |||
| auto scaleNode = graph->nodes.at(scalePath->nodeIdx).get(); | |||
| MS_ASSERT(scaleNode != nullptr); | |||
| auto scaleWeightTensorIdxes = scaleNode->inputIndex; | |||
| scaleWeightTensorIdxes.erase(scaleWeightTensorIdxes.begin()); | |||
| schema::TensorT *scaleTensor = nullptr; | |||
| schema::TensorT *biasTensor = nullptr; | |||
| if (scaleWeightTensorIdxes.size() == SCALE_OP_NO_BIAS_WEIGHT_NUM) { | |||
| scaleTensor = graph->allTensors.at(scaleWeightTensorIdxes[SCALE_OP_SCALE_INDEX_IN_WEIGHT]).get(); | |||
| } else if (scaleWeightTensorIdxes.size() == SCALE_OP_HAS_BIAS_WEIGHT_NUM) { | |||
| scaleTensor = graph->allTensors.at(scaleWeightTensorIdxes[SCALE_OP_SCALE_INDEX_IN_WEIGHT]).get(); | |||
| biasTensor = graph->allTensors.at(scaleWeightTensorIdxes[SCALE_OP_BIAS_INDEX_IN_WEIGHT]).get(); | |||
| } else { | |||
| MS_LOG(ERROR) << "Scale should has %d or %d weight tensors, current number of weight tensors %zu"; | |||
| // SCALE_OP_NO_BIAS_WEIGHT_NUM, SCALE_OP_HAS_BIAS_WEIGHT_NUM, scaleWeightTensorIdxes.size()); | |||
| return RET_ERROR; | |||
| } | |||
| if (scaleTensor == nullptr) { | |||
| MS_LOG(ERROR) << "Scale's scale tensor is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| if (kernelNum != scaleTensor->data.size() * sizeof(uint8_t) / sizeof(float)) { | |||
| MS_LOG(ERROR) << "conv kernel num %u is expected to be equal to scale size(%lu)"; | |||
| //, kernelNum, scaleTensor->data.size() * sizeof(uint8_t) / sizeof(float)); | |||
| return RET_ERROR; | |||
| } | |||
| const float *scaleData = reinterpret_cast<float *>(scaleTensor->data.data()); | |||
| if (0 != memcpy_s(transScale, kernelNum * sizeof(float), scaleData, kernelNum * sizeof(float))) { | |||
| MS_LOG(ERROR) << "memcpy_s transScale failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (biasTensor != nullptr) { | |||
| if (kernelNum != biasTensor->data.size() * sizeof(uint8_t) / sizeof(float)) { | |||
| MS_LOG(ERROR) << "conv kernel num %u is expected to be equal to bias size(%lu)"; | |||
| //, kernelNum, biasTensor->data.size() * sizeof(uint8_t) / sizeof(float)); | |||
| return RET_ERROR; | |||
| } | |||
| const float *biasData = reinterpret_cast<float *>(biasTensor->data.data()); | |||
| if (0 != memcpy_s(transBias, kernelNum * sizeof(float), biasData, kernelNum * sizeof(float))) { | |||
| MS_LOG(ERROR) << "memcpy_s transBias failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,46 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_PREDICT_CONV_SCALE_FUSION_PASS_H | |||
| #define MINDSPORE_PREDICT_CONV_SCALE_FUSION_PASS_H | |||
| #include "tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.h" | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class ConvScaleFusionPass : public ConvScaleBiasFusionPass { | |||
| public: | |||
| ConvScaleFusionPass() = default; | |||
| ~ConvScaleFusionPass() override = default; | |||
| STATUS DefinePattern() override; | |||
| STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, | |||
| std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override; | |||
| STATUS Run(schema::MetaGraphT *graph) override; | |||
| private: | |||
| STATUS GetTransParam(schema::MetaGraphT *graph, std::shared_ptr<Path> scalePath, int32_t kernelNum) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_CONV_SCALE_FUSION_PASS_H | |||
| @@ -4,6 +4,8 @@ add_library(graph_pass_mid OBJECT | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/dtype_trans_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/isolated_node_remove_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/model_input_format_preprocess_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/weight_format_hardcode_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/weight_format_transform_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/topological_sort_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/unused_node_remove_pass.cc | |||
| ) | |||
| @@ -25,147 +25,134 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS EltwiseFormatTransPass::Run(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| auto &node = *iter; | |||
| if (node->primitive->value.type != PrimitiveType_Eltwise) { | |||
| continue; | |||
| } | |||
| auto node_name = node->name; | |||
| auto input_node_indexes = GetInputNodeIdx(*graph, *node); | |||
| auto pre_type = schema::PrimitiveType_NONE; | |||
| size_t has_trans_count = 0; | |||
| auto can_fusion = true; | |||
| for (auto input_node_index : input_node_indexes) { | |||
| MS_ASSERT(graph->nodes.size() > input_node_index); | |||
| auto &pre_node = graph->nodes.at(input_node_index); | |||
| MS_ASSERT(pre_node != nullptr); | |||
| if (pre_type == schema::PrimitiveType_NONE) { | |||
| if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || | |||
| pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { | |||
| pre_type = pre_node->primitive->value.type; | |||
| bool EltwiseFormatTransPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) { | |||
| auto input_node_indexes = GetInputNodeIdx(*graph, *node); | |||
| pre_type_ = schema::PrimitiveType_NONE; | |||
| size_t has_trans_count = 0; | |||
| auto can_fusion = true; | |||
| for (auto input_node_index : input_node_indexes) { | |||
| MS_ASSERT(graph->nodes.size() > input_node_index); | |||
| auto &pre_node = graph->nodes.at(input_node_index); | |||
| MS_ASSERT(pre_node != nullptr); | |||
| if (pre_type_ == schema::PrimitiveType_NONE) { | |||
| if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || | |||
| pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { | |||
| pre_type_ = pre_node->primitive->value.type; | |||
| has_trans_count++; | |||
| } | |||
| } else { | |||
| if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || | |||
| pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { | |||
| if (pre_type_ != pre_node->primitive->value.type) { | |||
| can_fusion = false; | |||
| break; | |||
| } else { | |||
| has_trans_count++; | |||
| } | |||
| } else { | |||
| if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || | |||
| pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { | |||
| if (pre_type != pre_node->primitive->value.type) { | |||
| can_fusion = false; | |||
| break; | |||
| } else { | |||
| has_trans_count++; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (!can_fusion) { | |||
| continue; | |||
| } | |||
| auto output_node_indexes = GetOutputNodeIdx(*graph, *node); | |||
| auto post_type = schema::PrimitiveType_NONE; | |||
| for (auto output_node_index : output_node_indexes) { | |||
| MS_ASSERT(graph->nodes.size() > output_node_index); | |||
| auto &post_node = graph->nodes.at(output_node_index); | |||
| MS_ASSERT(post_node != nullptr); | |||
| if (post_type == schema::PrimitiveType_NONE) { | |||
| if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || | |||
| post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { | |||
| post_type = post_node->primitive->value.type; | |||
| } | |||
| if (!can_fusion) { | |||
| return false; | |||
| } | |||
| auto output_node_indexes = GetOutputNodeIdx(*graph, *node); | |||
| post_type_ = schema::PrimitiveType_NONE; | |||
| for (auto output_node_index : output_node_indexes) { | |||
| MS_ASSERT(graph->nodes.size() > output_node_index); | |||
| auto &post_node = graph->nodes.at(output_node_index); | |||
| MS_ASSERT(post_node != nullptr); | |||
| if (post_type_ == schema::PrimitiveType_NONE) { | |||
| if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || | |||
| post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { | |||
| post_type_ = post_node->primitive->value.type; | |||
| has_trans_count++; | |||
| } | |||
| } else { | |||
| if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || | |||
| post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { | |||
| if (post_type_ != post_node->primitive->value.type) { | |||
| can_fusion = false; | |||
| break; | |||
| } else { | |||
| has_trans_count++; | |||
| } | |||
| } else { | |||
| if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || | |||
| post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { | |||
| if (post_type != post_node->primitive->value.type) { | |||
| can_fusion = false; | |||
| break; | |||
| } else { | |||
| has_trans_count++; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (!can_fusion) { | |||
| continue; | |||
| } | |||
| auto total_node_count = input_node_indexes.size() + output_node_indexes.size(); | |||
| size_t half_count = total_node_count / 2; | |||
| if (total_node_count % 2 == 0) { | |||
| can_fusion = has_trans_count > half_count; | |||
| } else { | |||
| can_fusion = has_trans_count >= half_count; | |||
| } | |||
| if (!can_fusion) { | |||
| return false; | |||
| } | |||
| if (pre_type_ == PrimitiveType_NONE && post_type_ == PrimitiveType_NONE) { | |||
| return false; | |||
| } | |||
| auto total_node_count = input_node_indexes.size() + output_node_indexes.size(); | |||
| size_t half_count = total_node_count / 2; | |||
| if (total_node_count % 2 == 0) { | |||
| can_fusion = has_trans_count > half_count; | |||
| } else { | |||
| can_fusion = has_trans_count >= half_count; | |||
| } | |||
| return can_fusion; | |||
| } | |||
| STATUS EltwiseFormatTransPass::FindOutTransType() { | |||
| pre_insert_trans_type_ = kNHWC2NCHW; | |||
| post_insert_trans_type_ = kNHWC2NCHW; | |||
| if (pre_type_ == PrimitiveType_NONE && post_type_ != PrimitiveType_NONE) { | |||
| pre_insert_trans_type_ = post_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC; | |||
| post_insert_trans_type_ = post_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; | |||
| } else if (pre_type_ != PrimitiveType_NONE && post_type_ == PrimitiveType_NONE) { | |||
| pre_insert_trans_type_ = pre_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; | |||
| post_insert_trans_type_ = pre_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC; | |||
| } else if (pre_type_ == PrimitiveType_NONE && post_type_ == PrimitiveType_NONE) { | |||
| MS_ASSERT(false); | |||
| } else { | |||
| if (pre_type_ == post_type_) { | |||
| MS_LOG(ERROR) << "Unknow error"; | |||
| return RET_ERROR; | |||
| } | |||
| if (!can_fusion) { | |||
| pre_insert_trans_type_ = pre_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; | |||
| post_insert_trans_type_ = post_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS EltwiseFormatTransPass::Run(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| auto &node = *iter; | |||
| if (node->primitive->value.type != PrimitiveType_Eltwise) { | |||
| continue; | |||
| } | |||
| FormatTransNodeType pre_insert_trans_type = kNHWC2NCHW; | |||
| FormatTransNodeType post_insert_trans_type = kNHWC2NCHW; | |||
| if (pre_type == PrimitiveType_NONE && post_type != PrimitiveType_NONE) { | |||
| pre_insert_trans_type = post_type == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC; | |||
| post_insert_trans_type = post_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; | |||
| } else if (pre_type != PrimitiveType_NONE && post_type == PrimitiveType_NONE) { | |||
| pre_insert_trans_type = pre_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; | |||
| post_insert_trans_type = pre_type == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC; | |||
| } else if (pre_type == PrimitiveType_NONE && post_type == PrimitiveType_NONE) { | |||
| auto node_name = node->name; | |||
| if (!CanFusion(graph, node)) { | |||
| continue; | |||
| } else { | |||
| if (pre_type == post_type) { | |||
| MS_LOG(ERROR) << "Unknow error"; | |||
| return RET_ERROR; | |||
| } | |||
| pre_insert_trans_type = pre_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; | |||
| post_insert_trans_type = post_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; | |||
| } | |||
| auto ret = FindOutTransType(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "FindOutTransType error"; | |||
| return ret; | |||
| } | |||
| STATUS status = RET_OK; | |||
| auto input_tensor_size = (*iter)->inputIndex.size(); | |||
| for (auto i = 0; i < input_tensor_size; i++) { | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type, &status); | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type_, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Insert" << pre_insert_trans_type << "before " << (*iter)->name << " failed"; | |||
| MS_LOG(ERROR) << "Insert" << pre_insert_trans_type_ << "before " << (*iter)->name << " failed"; | |||
| return status; | |||
| } | |||
| } | |||
| auto output_tensor_size = (*iter)->outputIndex.size(); | |||
| for (auto i = 0; i < output_tensor_size; i++) { | |||
| iter = InsertFormatTransNode(graph, iter, kAfter, i, post_insert_trans_type, &status); | |||
| iter = InsertFormatTransNode(graph, iter, kAfter, i, post_insert_trans_type_, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Insert" << post_insert_trans_type << "Node before " << (*iter)->name << " failed"; | |||
| MS_LOG(ERROR) << "Insert" << post_insert_trans_type_ << "Node before " << (*iter)->name << " failed"; | |||
| return status; | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| NodeIter EltwiseFormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, | |||
| InsertPlace place, size_t inoutIdx, FormatTransNodeType nodeType, | |||
| STATUS *errorCode) { | |||
| MS_ASSERT((*existNodeIter) != nullptr); | |||
| auto existNodeName = (*existNodeIter)->name; | |||
| std::string tileName; | |||
| if (place == kBefore) { | |||
| tileName = existNodeName + "_pre"; | |||
| } else { | |||
| tileName = existNodeName + "_post"; | |||
| } | |||
| auto transNode = std::make_unique<schema::CNodeT>(); | |||
| transNode->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (nodeType == kNCHW2NHWC) { | |||
| transNode->name = "nchw2nhwc_" + tileName + std::to_string(id++); | |||
| transNode->primitive->value.type = schema::PrimitiveType_Nchw2Nhwc; | |||
| } else { | |||
| transNode->name = "nhwc2nchw_" + tileName + std::to_string(id++); | |||
| transNode->primitive->value.type = schema::PrimitiveType_Nhwc2Nchw; | |||
| } | |||
| return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode); | |||
| } | |||
| void EltwiseFormatTransPass::SetQuantType(QuantType quantType) { this->quantType = quantType; } | |||
| void EltwiseFormatTransPass::SetFmk(converter::FmkType fmkType) { this->fmkType = fmkType; } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -17,7 +17,7 @@ | |||
| #ifndef MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H | |||
| #define MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H | |||
| #include "tools/converter/optimizer.h" | |||
| #include <memory> | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/converter/legacy_optimizer/graph/format_trans_pass.h" | |||
| @@ -25,26 +25,24 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class EltwiseFormatTransPass : public GraphPass { | |||
| class EltwiseFormatTransPass : public FormatTransPass { | |||
| public: | |||
| EltwiseFormatTransPass() : id(0) {} | |||
| EltwiseFormatTransPass() : FormatTransPass() {} | |||
| ~EltwiseFormatTransPass() override = default; | |||
| STATUS Run(schema::MetaGraphT *graph) override; | |||
| void SetQuantType(QuantType quantType); | |||
| void SetFmk(converter::FmkType fmkType); | |||
| private: | |||
| NodeIter InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, | |||
| FormatTransNodeType nodeType, STATUS *errorCode); | |||
| bool CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node); | |||
| STATUS FindOutTransType(); | |||
| private: | |||
| size_t id; | |||
| QuantType quantType = QuantType_QUANT_NONE; | |||
| converter::FmkType fmkType = converter::FmkType_TF; | |||
| FormatTransNodeType pre_insert_trans_type_ = kNHWC2NCHW; | |||
| FormatTransNodeType post_insert_trans_type_ = kNHWC2NCHW; | |||
| schema::PrimitiveType pre_type_ = schema::PrimitiveType_NONE; | |||
| schema::PrimitiveType post_type_ = schema::PrimitiveType_NONE; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -37,16 +37,19 @@ class FormatTransPass : public GraphPass { | |||
| void SetFmk(converter::FmkType fmkType); | |||
| protected: | |||
| NodeIter InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, | |||
| FormatTransNodeType nodeType, STATUS *errorCode); | |||
| private: | |||
| STATUS DoModelInputFormatTrans(schema::MetaGraphT *graph); | |||
| STATUS DoNodeInoutFormatTrans(schema::MetaGraphT *graph); | |||
| NodeIter InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, | |||
| FormatTransNodeType nodeType, STATUS *errorCode); | |||
| protected: | |||
| size_t id = 0; | |||
| private: | |||
| size_t id; | |||
| QuantType quantType = QuantType_QUANT_NONE; | |||
| converter::FmkType fmkType = converter::FmkType_TF; | |||
| }; | |||
| @@ -54,4 +57,3 @@ class FormatTransPass : public GraphPass { | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_FORMAT_TRANS_PASS_H | |||
| @@ -0,0 +1,213 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h" | |||
| #include "tools/common/converter_op_utils.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "src/common/utils.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| void WeightFormatHardCodePass::SetQuantType(QuantType quantType) { this->quantType = quantType; } | |||
| void WeightFormatHardCodePass::SetFmkType(converter::FmkType fmkType) { this->fmkType = fmkType; } | |||
| // pre set tensor format | |||
| // non quant, filterFormat: | |||
| // conv deconv depth dedepth | |||
| // caffe K(C/g)HW C(K/g)HW / / // todo with deconvOp | |||
| // tf HWCK HWKC HWCK HWKC | |||
| // onnx K(C/g)HW C(K/g)HW / / | |||
| // awareing quant, filterFormat: | |||
| // conv deconv depth dedepth | |||
| // onnx KHWC ? CHWK ? | |||
| // tf HWCK ? HWCK ? | |||
| STATUS WeightFormatHardCodePass::Run(MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| for (auto &node : graph->nodes) { | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(node->primitive != nullptr); | |||
| auto opType = node->primitive->value.type; | |||
| if (opType != PrimitiveType_Conv2D && opType != PrimitiveType_DepthwiseConv2D && opType != PrimitiveType_DeConv2D && | |||
| opType != PrimitiveType_DeDepthwiseConv2D) { | |||
| continue; | |||
| } | |||
| MS_ASSERT(node->inputIndex.size() >= 2); | |||
| auto weightIndex = node->inputIndex.at(1); | |||
| MS_ASSERT(subGraph->allTensors.size() > weightIndex); | |||
| auto &weightTensor = graph->allTensors[weightIndex]; | |||
| MS_ASSERT(weightTensor->dims.size() == 4 || weightTensor->dims.empty()); // for conv with fakqQuant before weight | |||
| STATUS status; | |||
| switch (fmkType) { | |||
| case converter::FmkType_CAFFE: | |||
| status = HardCodeCAFFE(node, weightTensor); | |||
| break; | |||
| case converter::FmkType_TFLITE: | |||
| status = HardCodeTFLITE(node, weightTensor); | |||
| break; | |||
| case converter::FmkType_ONNX: | |||
| status = HardCodeONNX(node, weightTensor); | |||
| break; | |||
| case converter::FmkType_MS: | |||
| status = HardCodeMS(node, weightTensor); | |||
| break; | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupported fmkType: " << fmkType << ", node: " << node->name; | |||
| return RET_ERROR; | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Format hardCode faild: " << status << ", node: " << node->name; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightFormatHardCodePass::HardCodeCAFFE(const std::unique_ptr<CNodeT> &node, | |||
| const std::unique_ptr<TensorT> &weightTensor) { | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(weightTensor != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(node->primitive != nullptr); | |||
| auto opType = node->primitive->value.type; | |||
| switch (this->quantType) { | |||
| case QuantType_QUANT_NONE: { | |||
| if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D || | |||
| opType == schema::PrimitiveType_DeConv2D || opType == schema::PrimitiveType_DeDepthwiseConv2D) { | |||
| weightTensor->format = Format_KCHW; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name; | |||
| } | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(node->quantType) << ", node: " << node->name; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightFormatHardCodePass::HardCodeONNX(const std::unique_ptr<CNodeT> &node, | |||
| const std::unique_ptr<TensorT> &weightTensor) { | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(weightTensor != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(node->primitive != nullptr); | |||
| auto opType = node->primitive->value.type; | |||
| switch (this->quantType) { | |||
| case QuantType_AwareTraining: { | |||
| // sum up from current onnx quant models | |||
| if (opType == PrimitiveType_Conv2D) { | |||
| weightTensor->format = Format_KHWC; | |||
| } else if (opType == PrimitiveType_DepthwiseConv2D) { | |||
| weightTensor->format = Format_CHWK; | |||
| } else if (opType == PrimitiveType_DeConv2D) { | |||
| weightTensor->format = Format_CKHW; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name; | |||
| return RET_ERROR; | |||
| } | |||
| } break; | |||
| case QuantType_QUANT_NONE: { | |||
| // conv (K x C/group x kH x kW) group = 1 | |||
| // depth (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W) | |||
| // deconv (C x K/group x kH x kW) group = 1 | |||
| // dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W) | |||
| if (opType == PrimitiveType_Conv2D || opType == PrimitiveType_DepthwiseConv2D) { | |||
| weightTensor->format = Format_KCHW; | |||
| } else if (opType == PrimitiveType_DeConv2D) { | |||
| weightTensor->format = Format_CKHW; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name; | |||
| return RET_ERROR; | |||
| } | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(node->quantType) << ", node: " << node->name; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightFormatHardCodePass::HardCodeMS(const std::unique_ptr<CNodeT> &node, | |||
| const std::unique_ptr<TensorT> &weightTensor) { | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(weightTensor != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(node->primitive != nullptr); | |||
| auto opType = node->primitive->value.type; | |||
| switch (this->quantType) { | |||
| case QuantType_AwareTraining: { | |||
| if (opType == schema::PrimitiveType_Conv2D) { | |||
| weightTensor->format = schema::Format_HWCK; | |||
| } else if (opType == PrimitiveType_DepthwiseConv2D) { | |||
| weightTensor->format = Format_CKHW; | |||
| } else { | |||
| weightTensor->format = schema::Format_HWKC; | |||
| } | |||
| } break; | |||
| case QuantType_QUANT_NONE: { | |||
| // sum up from current ms quant models | |||
| if (opType == PrimitiveType_Conv2D) { | |||
| weightTensor->format = Format_KCHW; | |||
| } else if (opType == PrimitiveType_DepthwiseConv2D) { | |||
| weightTensor->format = Format_CKHW; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name; | |||
| return RET_ERROR; | |||
| } | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(node->quantType) << ", node: " << node->name; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightFormatHardCodePass::HardCodeTFLITE(const std::unique_ptr<CNodeT> &node, | |||
| const std::unique_ptr<TensorT> &weightTensor) { | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(weightTensor != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(node->primitive != nullptr); | |||
| auto opType = node->primitive->value.type; | |||
| switch (this->quantType) { | |||
| case QuantType_AwareTraining: | |||
| case QuantType_PostTraining: | |||
| case QuantType_QUANT_NONE: { | |||
| if (opType == schema::PrimitiveType_Conv2D) { | |||
| weightTensor->format = schema::Format_KHWC; | |||
| } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { | |||
| weightTensor->format = schema::Format_CHWK; | |||
| } else if (opType == schema::PrimitiveType_DeConv2D) { | |||
| weightTensor->format = schema::Format_CHWK; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name; | |||
| return RET_ERROR; | |||
| } | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,52 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_HARDCODE_PASS_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_HARDCODE_PASS_H | |||
| #include <memory> | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/converter/optimizer.h" | |||
| #include "tools/common/graph_util.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class WeightFormatHardCodePass : public GraphPass { | |||
| public: | |||
| WeightFormatHardCodePass() = default; | |||
| ~WeightFormatHardCodePass() override = default; | |||
| void SetQuantType(QuantType quantType); | |||
| void SetFmkType(converter::FmkType fmkType); | |||
| STATUS Run(MetaGraphT *graph) override; | |||
| private: | |||
| STATUS HardCodeCAFFE(const std::unique_ptr<CNodeT> &node, const std::unique_ptr<TensorT> &weightTensor); | |||
| STATUS HardCodeTFLITE(const std::unique_ptr<CNodeT> &node, const std::unique_ptr<TensorT> &weightTensor); | |||
| STATUS HardCodeONNX(const std::unique_ptr<CNodeT> &node, const std::unique_ptr<TensorT> &weightTensor); | |||
| STATUS HardCodeMS(const std::unique_ptr<CNodeT> &node, const std::unique_ptr<TensorT> &weightTensor); | |||
| private: | |||
| QuantType quantType = QuantType_QUANT_NONE; | |||
| converter::FmkType fmkType = converter::FmkType_TF; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_HARDCODE_PASS_H | |||
| @@ -0,0 +1,140 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h" | |||
| #include <queue> | |||
| #include "tools/common/node_util.h" | |||
| #include "tools/common/converter_op_utils.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "src/common/utils.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| void WeightFormatTransformPass::SetQuantType(QuantType quantType) { this->quantType = quantType; } | |||
| void WeightFormatTransformPass::SetFmkType(converter::FmkType fmkType) { this->fmkType = fmkType; } | |||
| void WeightFormatTransformPass::SetDstFormat(Format format) { this->dstFormat = format; } | |||
| STATUS WeightFormatTransformPass::Run(MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| if (this->quantType == QuantType_AwareTraining) { | |||
| auto status = QuantDataFormatTrans(graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status; | |||
| return status; | |||
| } | |||
| } else { | |||
| auto status = NonQuantDataFormatTrans(graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "NonQuantDataFormatTrans failed: " << status; | |||
| return status; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightFormatTransformPass::QuantDataFormatTrans(MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| for (auto &node : graph->nodes) { | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(node->primitive != nullptr); | |||
| auto opType = node->primitive->value.type; | |||
| if (opType != PrimitiveType_Conv2D && opType != PrimitiveType_DepthwiseConv2D) { | |||
| continue; | |||
| } | |||
| MS_ASSERT(node->inputIndex.size() >= 2); | |||
| auto weightIndex = node->inputIndex.at(1); | |||
| MS_ASSERT(subGraph->allTensors.size() > weightIndex); | |||
| auto &weightTensor = graph->allTensors[weightIndex]; | |||
| MS_ASSERT(weightTensor->dataType == DataType_DT_UINT8 || weightTensor->dataType == DataType_DT_FLOAT); | |||
| STATUS status; | |||
| if (opType == PrimitiveType_Conv2D || opType == PrimitiveType_DepthwiseConv2D) { // weight should be HWCK | |||
| Format curDstFormat; | |||
| if (this->dstFormat == Format_NUM_OF_FORMAT) { | |||
| curDstFormat = Format_KHWC; | |||
| } else { | |||
| curDstFormat = this->dstFormat; | |||
| } | |||
| status = TransFilterFormat(weightTensor.get(), curDstFormat); | |||
| if (status == RET_OK) { | |||
| // node->primitive->value.AsConv2D()->format = schema::Format_NHWC; | |||
| weightTensor->format = curDstFormat; | |||
| } else { | |||
| MS_LOG(WARNING) << "TransFilter " << EnumNameFormat(weightTensor->format) << "To" | |||
| << EnumNameFormat(curDstFormat) << " failed, node : " << node->name; | |||
| // todo(00445839): consider varible weight condition | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightFormatTransformPass::NonQuantDataFormatTrans(MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| for (auto &node : graph->nodes) { | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(node->primitive != nullptr); | |||
| auto opType = node->primitive->value.type; | |||
| if (opType != PrimitiveType_Conv2D && opType != PrimitiveType_DepthwiseConv2D && opType != PrimitiveType_DeConv2D && | |||
| opType != PrimitiveType_DeDepthwiseConv2D) { | |||
| continue; | |||
| } | |||
| MS_ASSERT(node->inputIndex.size() >= 2); | |||
| auto weightIndex = node->inputIndex.at(1); | |||
| MS_ASSERT(subGraph->allTensors.size() > weightIndex); | |||
| auto &weightTensor = graph->allTensors[weightIndex]; | |||
| MS_ASSERT(weightTensor->dataType == DataType_DT_UINT8 || weightTensor->dataType == DataType_DT_FLOAT); | |||
| STATUS status; | |||
| if (opType == PrimitiveType_Conv2D || opType == PrimitiveType_DepthwiseConv2D || | |||
| opType == schema::PrimitiveType_DeConv2D) { | |||
| Format curDstFormat; | |||
| if (this->dstFormat == Format_NUM_OF_FORMAT) { | |||
| curDstFormat = Format_KHWC; | |||
| } else { | |||
| curDstFormat = this->dstFormat; | |||
| } | |||
| status = TransFilterFormat(weightTensor.get(), curDstFormat); | |||
| if (status == RET_OK) { | |||
| // node->attr.AsConv2D()->format = Format_NCHW; | |||
| weightTensor->format = curDstFormat; | |||
| } else { | |||
| MS_LOG(WARNING) << "TransFilter " << EnumNameFormat(weightTensor->format) << "To" | |||
| << EnumNameFormat(curDstFormat) << " failed, node : " << node->name; | |||
| // todo(00445839): consider varible weight condition | |||
| } | |||
| } else { // weight should be CKHW | |||
| Format curDstFormat; | |||
| if (this->dstFormat == Format_NUM_OF_FORMAT) { | |||
| curDstFormat = Format_KHWC; | |||
| } else { | |||
| curDstFormat = this->dstFormat; | |||
| } | |||
| status = TransFilterFormat(weightTensor.get(), curDstFormat); | |||
| if (status == RET_OK) { | |||
| // node->attr.AsDepthwiseConv2D()->format = Format_NCHW; | |||
| weightTensor->format = curDstFormat; | |||
| } else { | |||
| MS_LOG(WARNING) << "TransFilter " << EnumNameFormat(weightTensor->format) << "To" | |||
| << EnumNameFormat(curDstFormat) << " failed, node : " << node->name; | |||
| // todo(00445839): consider varible weight condition | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -14,45 +14,40 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_PREDICT_WEIGHT_FORMAT_PASS_H | |||
| #define MINDSPORE_PREDICT_WEIGHT_FORMAT_PASS_H | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_TRANSFORM_PASS_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_TRANSFORM_PASS_H | |||
| #include "tools/converter/optimizer.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class WeightFormatPass : public NodePass { | |||
| class WeightFormatTransformPass : public GraphPass { | |||
| public: | |||
| WeightFormatPass() = default; | |||
| WeightFormatTransformPass() = default; | |||
| ~WeightFormatPass() override = default; | |||
| ~WeightFormatTransformPass() override = default; | |||
| void SetQuantType(QuantType quantType); | |||
| void SetFmkType(converter::FmkType fmkType); | |||
| int Run(GraphNode *graphNode) override; | |||
| void SetDstFormat(Format format); | |||
| private: | |||
| // correct weightTensor->Format | |||
| int ShapeFormatTrans(GraphNode *graphNode); | |||
| STATUS Run(MetaGraphT *graph) override; | |||
| // transform weightTensor data and format | |||
| // if quant : conv transform dataFormat to NHWC, weight format to HWCK | |||
| // if quant : depth transform dataFormat to NCHW, weight format to CKHW | |||
| int QuantDataFormatTrans(GraphNode *graphNode); | |||
| private: | |||
| STATUS QuantDataFormatTrans(MetaGraphT *graph); | |||
| // if no quant : transform dataFormat to NCHW, weight format to KCHW/CKHW | |||
| int NonQuantDataFormatTrans(GraphNode *graphNode); | |||
| STATUS NonQuantDataFormatTrans(MetaGraphT *graph); | |||
| private: | |||
| QuantType quantType = QuantType_QUANT_NONE; | |||
| converter::FmkType fmkType = converter::FmkType_TF; | |||
| Format dstFormat = Format_NUM_OF_FORMAT; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_WEIGHT_FORMAT_PASS_H | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_TRANSFORM_PASS_H | |||
| @@ -1,3 +0,0 @@ | |||
| add_library(node_mid OBJECT | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/weight_format_pass.cc | |||
| ) | |||
| @@ -1,407 +0,0 @@ | |||
| /** | |||
| * Copyright 201+ Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/legacy_optimizer/node/weight_format_pass.h" | |||
| #include "tools/common/node_util.h" | |||
| #include "tools/common/tensor_util.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| int WeightFormatPass::Run(GraphNode *graphNode) { | |||
| MS_ASSERT(graphNode != nullptr); | |||
| auto status = ShapeFormatTrans(graphNode); | |||
| if (status != 0) { | |||
| MS_LOG(ERROR) << "ShapeFormatTrans failed: " << status; | |||
| return status; | |||
| } | |||
| if (this->quantType == QuantType_AwareTraining || this->quantType == QuantType_PostTraining) { | |||
| status = QuantDataFormatTrans(graphNode); | |||
| if (status != 0) { | |||
| MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status; | |||
| return status; | |||
| } | |||
| } else { | |||
| status = NonQuantDataFormatTrans(graphNode); | |||
| if (status != 0) { | |||
| MS_LOG(ERROR) << "NonQuantDataFormatTrans failed: " << status; | |||
| return status; | |||
| } | |||
| } | |||
| return 0; | |||
| } | |||
| void WeightFormatPass::SetQuantType(QuantType quantType) { this->quantType = quantType; } | |||
| void WeightFormatPass::SetFmkType(converter::FmkType fmkType) { this->fmkType = fmkType; } | |||
| // pre set tensor format | |||
| // non quant, filterFormat: | |||
| // conv deconv depth dedepth | |||
| // caffe K(C/g)HW C(K/g)HW / / | |||
| // tf HWCK HWKC HWCK HWKC | |||
| // onnx K(C/g)HW C(K/g)HW / / | |||
| // awareing quant, filterFormat: | |||
| // conv deconv depth dedepth | |||
| // onnx KHWC ? CHWK ? | |||
| // tf HWCK ? HWCK ? | |||
| int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { | |||
| MS_ASSERT(graphNode != nullptr); | |||
| auto &subGraph = graphNode->subGraph; | |||
| auto &node = graphNode->opDef; | |||
| MS_ASSERT(subGraph != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| auto opType = node->primitive->value.type; | |||
| if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D && | |||
| opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) { | |||
| return 0; | |||
| } | |||
| MS_ASSERT(node->inputIndex.size() >= 2); | |||
| auto weightIndex = node->inputIndex.at(1); | |||
| MS_ASSERT(subGraph->allTensors.size() > weightIndex); | |||
| auto &weightTensor = subGraph->allTensors[weightIndex]; | |||
| auto &shape = weightTensor->dims; | |||
| MS_ASSERT(shape.size() == 4); | |||
| if (fmkType == converter::FmkType_CAFFE) { | |||
| switch (node->quantType) { | |||
| case QuantType_QUANT_NONE: { | |||
| if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D || | |||
| opType == schema::PrimitiveType_DeConv2D || opType == schema::PrimitiveType_DeDepthwiseConv2D) { | |||
| weightTensor->format = schema::Format_KCHW; | |||
| } else { | |||
| MS_LOG(ERROR) << "Invalid opType: " << schema::EnumNamePrimitiveType(opType) | |||
| << ", node: " << node->name.c_str(); | |||
| return -1; | |||
| } | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Invalid quantType: " << schema::EnumNameQuantType(node->quantType) | |||
| << ", node: " << node->name.c_str(); | |||
| return -1; | |||
| } | |||
| } | |||
| return 0; | |||
| } else if (fmkType == converter::FmkType_MS) { | |||
| switch (node->quantType) { | |||
| case QuantType_AwareTraining: { | |||
| if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { | |||
| weightTensor->format = schema::Format_HWCK; | |||
| } else { | |||
| weightTensor->format = schema::Format_HWKC; | |||
| } | |||
| } break; | |||
| case QuantType_QUANT_NONE: { | |||
| // conv [filter_height, filter_width, in_channels, out_channels] | |||
| // depthwise [filter_height, filter_width, in_channels, channel_multiplier] | |||
| if (opType == schema::PrimitiveType_Conv2D) { | |||
| weightTensor->format = schema::Format_KCHW; | |||
| } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { | |||
| weightTensor->format = schema::Format_KCHW; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name; | |||
| return -1; | |||
| } | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Invalid opType: %d, node: " << opType, node->name.c_str(); | |||
| return -1; | |||
| } | |||
| } | |||
| return 0; | |||
| } else if (fmkType == converter::FmkType_TF) { | |||
| switch (node->quantType) { | |||
| case QuantType_AwareTraining: { | |||
| if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { | |||
| weightTensor->format = schema::Format_HWCK; | |||
| } else { | |||
| weightTensor->format = schema::Format_HWKC; | |||
| } | |||
| } break; | |||
| case QuantType_QUANT_NONE: { | |||
| // conv [filter_height, filter_width, in_channels, out_channels] | |||
| // depthwise [filter_height, filter_width, in_channels, channel_multiplier] | |||
| if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { | |||
| weightTensor->format = schema::Format_HWCK; | |||
| } else { | |||
| weightTensor->format = schema::Format_HWKC; | |||
| } | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Invalid opType: %d, node: " << opType, node->name.c_str(); | |||
| return -1; | |||
| } | |||
| } | |||
| return 0; | |||
| } else if (fmkType == converter::FmkType_TFLITE) { | |||
| switch (node->quantType) { | |||
| case QuantType_QUANT_NONE: | |||
| case QuantType_AwareTraining: | |||
| case QuantType_PostTraining: { | |||
| if (opType == schema::PrimitiveType_Conv2D) { | |||
| weightTensor->format = schema::Format_KHWC; | |||
| } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { | |||
| weightTensor->format = schema::Format_CHWK; | |||
| } else if (opType == schema::PrimitiveType_DeConv2D) { | |||
| weightTensor->format = schema::Format_CHWK; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported format"; | |||
| return -1; | |||
| } | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Invalid opType: %d, node: " << opType, node->name.c_str(); | |||
| return -1; | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "weight_tensor_format: " << weightTensor->format; | |||
| return 0; | |||
| } else if (fmkType == converter::FmkType_ONNX) { | |||
| switch (node->quantType) { | |||
| case QuantType_AwareTraining: { | |||
| // sum up from current onnx quant models | |||
| if (opType == schema::PrimitiveType_Conv2D) { | |||
| weightTensor->format = schema::Format_KHWC; | |||
| } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { | |||
| weightTensor->format = schema::Format_CHWK; | |||
| } else { | |||
| MS_LOG(ERROR) << "Invalid opType: %d, node: " << opType, node->name.c_str(); | |||
| return -1; | |||
| } | |||
| } break; | |||
| case QuantType_QUANT_NONE: { | |||
| // conv (K x C/group x kH x kW) group = 1 | |||
| // depth (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W) | |||
| // deconv (C x K/group x kH x kW) group = 1 | |||
| // dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W) | |||
| if (opType == schema::PrimitiveType_Conv2D) { | |||
| weightTensor->format = schema::Format_KCHW; | |||
| } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { | |||
| weightTensor->format = schema::Format_KCHW; | |||
| } else if (opType == schema::PrimitiveType_DeConv2D) { | |||
| weightTensor->format = schema::Format_CKHW; | |||
| } else { | |||
| MS_LOG(ERROR) << "Invalid opType: %d, node: " << opType, node->name.c_str(); | |||
| return -1; | |||
| } | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Unsupported quantType: %d, node: " << node->quantType, node->name.c_str(); | |||
| return -1; | |||
| } | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "Invalid fmkType: %d, node: " << fmkType, node->name.c_str(); | |||
| return -1; | |||
| } | |||
| return 0; | |||
| } | |||
| // inference needed filterFormat: | |||
| // conv deconv depth dedepth | |||
| // uint8 KHWC KHWC KHWC KHWC | |||
| int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { | |||
| MS_ASSERT(graphNode != nullptr); | |||
| auto &subGraph = graphNode->subGraph; | |||
| auto &node = graphNode->opDef; | |||
| MS_ASSERT(subGraph != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| auto opType = node->primitive->value.type; | |||
| if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D && | |||
| opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) { | |||
| return RET_OK; | |||
| } | |||
| MS_ASSERT(node->inputIndex.size() >= 2); | |||
| auto weightIndex = node->inputIndex.at(1); | |||
| MS_ASSERT(subGraph->allTensors.size() > weightIndex); | |||
| auto &weightTensor = subGraph->allTensors[weightIndex]; | |||
| MS_ASSERT(weightTensor->dataType == kNumberTypeInt8); // DataType_DT_FLOAT | |||
| STATUS status = RET_OK; | |||
| if (opType == schema::PrimitiveType_Conv2D) { // weight should be KHWC | |||
| if (weightTensor->format == schema::Format_KCHW) { // from caffe | |||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||
| MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format | |||
| << weightTensor->dataType; | |||
| status = TransFilterFormat<int8_t>(weightTensor.get(), kKCHW2HWCK); | |||
| } else { | |||
| MS_LOG(DEBUG) << "--weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format | |||
| << weightTensor->dataType; | |||
| status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC); | |||
| } | |||
| } else if (weightTensor->format != schema::Format_KHWC) { | |||
| MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; | |||
| return -1; | |||
| } | |||
| if (status == 0) { | |||
| node->primitive->value.AsConv2D()->format = schema::Format_NHWC; | |||
| weightTensor->format = schema::Format_KHWC; | |||
| } else { | |||
| MS_LOG(WARNING) << "TransFilter %sToKHWC failed, node : " | |||
| << (weightTensor->format == schema::Format_KHWC ? "KHWC" : "KCHW") << node->name.c_str(); | |||
| // todo(00445839): consider varible weight condition | |||
| } | |||
| } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be KHWC | |||
| if (weightTensor->format == schema::Format_CKHW) { // from caffe | |||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||
| MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format | |||
| << "datatype: " << weightTensor->dataType; | |||
| status = TransFilterFormat<int8_t>(weightTensor.get(), kCKHW2KHWC); | |||
| } else if (weightTensor->dataType == kNumberTypeUInt8) { | |||
| MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format | |||
| << "datatype: " << weightTensor->dataType; | |||
| status = TransFilterFormat<uint8_t>(weightTensor.get(), kCKHW2KHWC); | |||
| } else { | |||
| MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format | |||
| << "datatype: " << weightTensor->dataType; | |||
| status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC); | |||
| } | |||
| } else if (weightTensor->format == schema::Format_CHWK) { // from onnx | |||
| if (weightTensor->dataType == kNumberTypeInt8) { | |||
| MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format | |||
| << "datatype: " << weightTensor->dataType; | |||
| status = TransFilterFormat<int8_t>(weightTensor.get(), kCHWK2KHWC); | |||
| } else if (weightTensor->dataType == kNumberTypeUInt8) { | |||
| MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format | |||
| << "datatype: " << weightTensor->dataType; | |||
| status = TransFilterFormat<uint8_t>(weightTensor.get(), kCHWK2KHWC); | |||
| } else { | |||
| MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format | |||
| << "datatype: " << weightTensor->dataType; | |||
| status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC); | |||
| } | |||
| } else if (weightTensor->format != schema::Format_KHWC) { | |||
| MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; | |||
| return -1; | |||
| } | |||
| if (status == 0) { | |||
| node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC; | |||
| weightTensor->format = schema::Format_KHWC; | |||
| } else { | |||
| MS_LOG(WARNING) << "TransFilter" << (weightTensor->format == schema::Format_KHWC ? "KHWC" : "CKHW") | |||
| << "To KHWC failed, node : " << node->name.c_str(); | |||
| // todo(00445839): consider varible weight condition | |||
| } | |||
| } else { // weight should be HWCK | |||
| node->primitive->value.AsDeConv2D()->format = schema::Format_NHWC; | |||
| weightTensor->format = schema::Format_KHWC; | |||
| } | |||
| return 0; | |||
| } | |||
// inference needed filterFormat (as actually produced by NonQuantDataFormatTrans below):
//          conv    deconv    depth    dedepth
// fp32     KHWC    KHWC      KHWC     CKHW
// Transforms the filter tensor of non-quantized (fp32) conv-family nodes into
// the layout expected by inference kernels: Conv2D / DepthwiseConv2D /
// DeConv2D end up labelled KHWC, DeDepthwiseConv2D ends up labelled CKHW.
// @param graphNode holds the subgraph and the node (opDef) being rewritten.
// @return 0 on success, -1 on an unsupported source weight format.
int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
  MS_ASSERT(graphNode != nullptr);
  auto &subGraph = graphNode->subGraph;
  auto &node = graphNode->opDef;
  MS_ASSERT(subGraph != nullptr);
  MS_ASSERT(node != nullptr);
  auto opType = node->primitive->value.type;
  // Only convolution-family ops carry a filter tensor; everything else passes through.
  if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D &&
      opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) {
    return 0;
  }
  MS_ASSERT(node->inputIndex.size() >= 2);
  auto weightIndex = node->inputIndex.at(1);  // input 1 is the filter tensor
  MS_ASSERT(subGraph->allTensors.size() > weightIndex);
  auto &weightTensor = subGraph->allTensors[weightIndex];
  if (weightTensor->dataType != TypeId::kNumberTypeFloat32) {
    // NOTE(review): the error is logged but the return is commented out, so
    // non-fp32 weights fall through and are transposed with the float
    // TransFilterFormat anyway -- confirm this best-effort behavior is intended.
    MS_LOG(ERROR) << "weight tensor data should be float";
    // return -1;
  }
  STATUS status = RET_OK;
  if (opType == schema::PrimitiveType_Conv2D) {  // converts to KHWC (labelled below)
    if (weightTensor->format == schema::Format_KCHW) {  // from caffe or onnx or ms
      status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
    } else if (weightTensor->format == schema::Format_KHWC) {
      status = RET_OK;  // already in the target layout
    } else if (weightTensor->format == schema::Format_CHWK) {
      status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
    } else {
      MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
      return -1;
    }
    if (status == 0) {
      node->primitive->value.AsConv2D()->format = schema::Format_NHWC;
      weightTensor->format = schema::Format_KHWC;
    } else {
      // NOTE(review): message names HWCK/NHWC, but the branches above only
      // handle KCHW/KHWC/CHWK -- the text looks stale.
      MS_LOG(WARNING) << "TransFilter " << ((weightTensor->format == schema::Format_HWCK) ? "HWCK" : "NHWC")
                      << "ToKCHW failed, node : " << node->name.c_str();
      // todo(00445839): consider varible weight condition
    }
  } else if (opType == schema::PrimitiveType_DepthwiseConv2D) {  // converts to KHWC (labelled below)
    // MS-exported depthwise filters are force-relabelled CKHW before the
    // transform, regardless of the recorded format.
    if (fmkType == converter::FmkType_MS) {
      weightTensor->format = schema::Format_CKHW;
    }
    if (weightTensor->format == schema::Format_CKHW) {  // from caffe or onnx or ms
      status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC);
    } else if (weightTensor->format == schema::Format_KCHW) {
      status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
    } else if (weightTensor->format == schema::Format_CHWK) {  // from tflite
      status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
    } else {
      MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
      return -1;
    }
    if (status == 0) {
      node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC;
      weightTensor->format = schema::Format_KHWC;
    } else {
      // NOTE(review): "HWCKToCKHW" does not match any transform performed
      // above -- stale message text.
      MS_LOG(WARNING) << "TransFilter HWCKToCKHW failed, node : " << node->name.c_str();
      // todo(00445839): consider varible weight condition
    }
  } else if (opType == schema::PrimitiveType_DeConv2D) {  // converts to KHWC (labelled below)
    if (weightTensor->format == schema::Format_KCHW) {  // from caffe or onnx or ms
      status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
    } else if (weightTensor->format == schema::Format_CHWK) {  // from tflite
      status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
    } else {
      MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
      return -1;
    }
    if (status == 0) {
      // NOTE(review): DeConv records NCHW on the primitive while every other
      // branch records NHWC -- confirm this asymmetry is intended.
      node->primitive->value.AsDeConv2D()->format = schema::Format_NCHW;
      weightTensor->format = schema::Format_KHWC;
    } else {
      MS_LOG(WARNING) << "TransFilter HWKCToKCHW failed, node : " << node->name.c_str();
      // todo(00445839): consider varible weight condition
    }
  } else if (opType == schema::PrimitiveType_DeDepthwiseConv2D) {  // labelled CKHW below
    if (weightTensor->format == schema::Format_KHWC) {
      return 0;  // nothing to do
    } else if (weightTensor->format == schema::Format_KCHW) {  // from caffe
      // NOTE(review): kKCHW2KHWC produces KHWC-ordered data, yet the tensor is
      // relabelled Format_CKHW on success below -- verify this pairing.
      status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
    } else if (weightTensor->format == schema::Format_HWKC) {  // from tf or onnx
      status = TransFilterFormat<float>(weightTensor.get(), kHWKC2CKHW);
    } else {
      MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
      return -1;
    }
    if (status == 0) {
      node->primitive->value.AsDeDepthwiseConv2D()->format = schema::Format_NHWC;
      weightTensor->format = schema::Format_CKHW;
    } else {
      MS_LOG(WARNING) << "TransFilter HWKCToCKHW failed, node : " << node->name.c_str();
      // todo(00445839): consider varible weight condition
    }
  }
  return 0;
}
| } // namespace lite | |||
| } // namespace mindspore | |||