diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt
index a58524abb6..f5595794ff 100644
--- a/mindspore/lite/test/CMakeLists.txt
+++ b/mindspore/lite/test/CMakeLists.txt
@@ -353,7 +353,6 @@ if (BUILD_CONVERTER)
         tflite_parser_mid
         caffe_parser_mid
         onnx_parser_mid
-        node_mid
         graph_pass_mid
         fusion_mid
         quantizer_mid
diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc
index 6f9a4e52e2..1e92d9d069 100644
--- a/mindspore/lite/tools/common/node_util.cc
+++ b/mindspore/lite/tools/common/node_util.cc
@@ -24,74 +24,6 @@
 namespace mindspore {
 namespace lite {
-STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr<CNodeT> &node) {
-  MS_ASSERT(graphT != nullptr);
-  MS_ASSERT(node != nullptr);
-  // set quantParam to preNode
-  for (size_t i = 0; i < node->inputIndex.size(); i++) {
-    auto preNodeIdexes = GetInputNodeIdx(*graphT, *(node.get()), i);
-    for (auto preNodeIdx : preNodeIdexes) {
-      MS_ASSERT(graphT->nodes.size() > preNodeIdx);
-      auto &preNode = graphT->nodes.at(preNodeIdx);
-      MS_ASSERT(preNode != nullptr);
-      // if preNode is not init, it maybe not a quantNode, so skip
-      // if (preNode->inputIndex.size() + preNode->outputIndex.size() != preNode->quantParam.size()) {
-      //   continue;
-      // }
-      auto preNodeOutputIndexes = preNode->outputIndex;
-      int32_t currentNodeIndexInPre = -1;
-      for (auto index : preNodeOutputIndexes) {
-        currentNodeIndexInPre++;
-        if (index == node->inputIndex.at(i)) {
-          break;
-        }
-      }
-      MS_ASSERT(currentNodeIndexInPre != -1);
-      MS_ASSERT(node->quantParam.size() > i);
-      MS_ASSERT(node->quantParam.at(i) != nullptr);
-      // auto quantParamArrayCopy = CopyQuantParamArrayT(node->quantParam.at(i));
-      // if (quantParamArrayCopy == nullptr) {
-      //   // MS_LOG(ERROR)("CopyQuantParamArray return nullptr, node: %s", node->name.c_str());
-      //   return RET_ERROR;
-      // }
-      // preNode->quantParam.at(preNode->inputIndex.size() + currentNodeIndexInPre) =
-      //   std::move(CopyQuantParamArrayT(quantParamArrayCopy));
-    }
-  }
-
-  // set quantParam to postNode
-  for (size_t i = 0; i < node->outputIndex.size(); i++) {
-    auto postNodeIdexes = GetOutputNodeIdx(*graphT, *(node.get()), i);
-    for (auto postNodeIdx : postNodeIdexes) {
-      MS_ASSERT(graphT->nodes.size() > postNodeIdx);
-      auto &postNode = graphT->nodes.at(postNodeIdx);
-      MS_ASSERT(postNode != nullptr);
-      // if postNode is not init, it maybe not a quantNode, so skip
-      // if (postNode->inputIndex.size() + postNode->outputIndex.size() != postNode->quantParam.size()) {
-      //   continue;
-      // }
-      auto postNodeInputIndexes = postNode->inputIndex;
-      int32_t currentNodeIndexInPost = -1;
-      for (auto index : postNodeInputIndexes) {
-        currentNodeIndexInPost++;
-        if (index == node->outputIndex.at(i)) {
-          break;
-        }
-      }
-      MS_ASSERT(currentNodeIndexInPost != -1);
-      MS_ASSERT(node->quantParam.size() > node->inputIndex.size() + i);
-      MS_ASSERT(node->quantParam.at(node->inputIndex.size() + i) != nullptr);
-      // auto quantParamArrayCopy = CopyQuantParamArrayT(node->quantParam.at(node->inputIndex.size() + i));
-      // if (quantParamArrayCopy == nullptr) {
-      //   // MS_LOG(ERROR)("CopyQuantParamArray return nullptr, node: %s", node->name.c_str());
-      //   return RET_ERROR;
-      // }
-      // postNode->quantParam.at(currentNodeIndexInPost) = std::move(CopyQuantParamArrayT(quantParamArrayCopy));
-    }
-  }
-  return RET_OK;
-}
-
 static const std::vector<schema::PrimitiveType> nhwcOpList = {
     schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D,
     schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
@@ -121,8 +53,8 @@ std::vector<schema::PrimitiveType> GetUint8OpList() { return uint8OpList; }
 STATUS NodeUtils::ConvertDims(mindspore::lite::Format src_format, const std::vector<int32_t> &src_dims,
                               mindspore::lite::Format dst_format, std::vector<int32_t> *dst_dims) {
   if ((src_dims.size() != DIM_DEFAULT_SIZE && src_dims.size() != 3) || src_format == dst_format) {
-    // MS_LOG(ERROR)("Convert format , src size %lu <3 or src format is equal to dst format,not need convert",
-    // src_dims.size());
+    MS_LOG(ERROR) << "Convert format , src size " << src_dims.size()
+                  << " <3 or src format is equal to dst format,not need convert";
     *dst_dims = src_dims;
     return RET_PARAM_INVALID;
   }
@@ -145,12 +77,12 @@ STATUS NodeUtils::ConvertDims(mindspore::lite::Format src_format, const std::vec
       }
       break;
     default:
-      // MS_LOG(ERROR)("Not support src format: %d", src_format);
+      MS_LOG(ERROR) << "Not support src format: " << schema::EnumNameFormat(src_format);
      return RET_ERROR;
   }
 
   if (nchw_dim.size() == 0) {
-    // MS_LOG(ERROR)("Param nchw_dim is empty!");
+    MS_LOG(ERROR) << "Param nchw_dim is empty!";
     return RET_ERROR;
   }
 
@@ -172,7 +104,250 @@ STATUS NodeUtils::ConvertDims(mindspore::lite::Format src_format, const std::vec
   }
   return RET_OK;
 }
+
+STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
+  if (tensor == nullptr) {
+    return RET_NULL_PTR;
+  }
+  std::vector<int32_t> oriDims = tensor->dims;
+  if (oriDims.size() != (size_t)DIM_DEFAULT_SIZE) {
+    MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
+    return RET_ERROR;
+  }
+  auto srcFormat = tensor->format;
+  auto dataType = tensor->dataType;
+  STATUS status;
+  switch (dstFormat) {
+    case schema::Format_KHWC: {
+      switch (srcFormat) {
+        case schema::Format_KCHW:
+          if (dataType == kNumberTypeFloat32) {
+            status = TransFilterFormat<float>(tensor, kKCHW2KHWC);
+          } else if (dataType == kNumberTypeUInt8) {
+            status = TransFilterFormat<uint8_t>(tensor, kKCHW2KHWC);
+          } else if (dataType == kNumberTypeInt8) {
+            status = TransFilterFormat<int8_t>(tensor, kKCHW2KHWC);
+          } else {
+            MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
+            return RET_ERROR;
+          }
+          break;
+        case schema::Format_CKHW:
+          if (dataType == kNumberTypeFloat32) {
+            status = TransFilterFormat<float>(tensor, kCKHW2KHWC);
+          } else if (dataType == kNumberTypeUInt8) {
+            status = TransFilterFormat<uint8_t>(tensor, kCKHW2KHWC);
+          } else if (dataType == kNumberTypeInt8) {
+            status = TransFilterFormat<int8_t>(tensor, kCKHW2KHWC);
+          } else {
+            MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
+            return RET_ERROR;
+          }
+          break;
+        case schema::Format_CHWK:
+          if (dataType == kNumberTypeFloat32) {
+            status = TransFilterFormat<float>(tensor, kCHWK2KHWC);
+          } else if (dataType == kNumberTypeUInt8) {
+            status = TransFilterFormat<uint8_t>(tensor, kCHWK2KHWC);
+          } else if (dataType == kNumberTypeInt8) {
+            status = TransFilterFormat<int8_t>(tensor, kCHWK2KHWC);
+          } else {
+            MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
+            return RET_ERROR;
+          }
+          break;
+        case schema::Format_KHWC:
+          return RET_OK;
+        default:
+          MS_LOG(ERROR) << "Unsupported transform from " << schema::EnumNameFormat(srcFormat) << " to "
+                        << schema::EnumNameFormat(dstFormat);
+          return RET_ERROR;
+      }
+    } break;
+    case schema::Format_HWCK: {
+      switch (srcFormat) {
+        case schema::Format_KCHW:
+          if (dataType == kNumberTypeFloat32) {
+            status = TransFilterFormat<float>(tensor, kKCHW2HWCK);
+          } else if (dataType == kNumberTypeUInt8) {
+            status = TransFilterFormat<uint8_t>(tensor, kKCHW2HWCK);
+          } else if (dataType == kNumberTypeInt8) {
+            status = TransFilterFormat<int8_t>(tensor, kKCHW2HWCK);
+          } else {
+            MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
+            return RET_ERROR;
+          }
+          break;
+        case schema::Format_KHWC:
+          if (dataType == kNumberTypeFloat32) {
+            status = TransFilterFormat<float>(tensor, kKHWC2HWCK);
+          } else if (dataType == kNumberTypeUInt8) {
+            status = TransFilterFormat<uint8_t>(tensor, kKHWC2HWCK);
+          } else if (dataType == kNumberTypeInt8) {
+            status = TransFilterFormat<int8_t>(tensor, kKHWC2HWCK);
+          } else {
+            MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
+            return RET_ERROR;
+          }
+          break;
+        case schema::Format_CKHW:
+          if (dataType == kNumberTypeFloat32) {
+            status = TransFilterFormat<float>(tensor, kCKHW2HWCK);
+          } else if (dataType == kNumberTypeUInt8) {
+            status = TransFilterFormat<uint8_t>(tensor, kCKHW2HWCK);
+          } else if (dataType == kNumberTypeInt8) {
+            status = TransFilterFormat<int8_t>(tensor, kCKHW2HWCK);
+          } else {
+            MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
+            return RET_ERROR;
+          }
+          break;
+        case schema::Format_CHWK:
+          if (dataType == kNumberTypeFloat32) {
+            status = TransFilterFormat<float>(tensor, kCHWK2HWCK);
+          } else if (dataType == kNumberTypeUInt8) {
+            status = TransFilterFormat<uint8_t>(tensor, kCHWK2HWCK);
+          } else if (dataType == kNumberTypeInt8) {
+            status = TransFilterFormat<int8_t>(tensor, kCHWK2HWCK);
+          } else {
+            MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
+            return RET_ERROR;
+          }
+          break;
+        case schema::Format_HWCK:
+          return RET_OK;
+        default:
+          MS_LOG(ERROR) << "Unsupported transform from " << schema::EnumNameFormat(srcFormat) << " to "
+                        << schema::EnumNameFormat(dstFormat);
+          return RET_ERROR;
+      }
+    } break;
+    case schema::Format_KCHW: {
+      switch (srcFormat) {
+        case schema::Format_KCHW:
+          return RET_OK;
+        case schema::Format_HWCK:
+          if (dataType == kNumberTypeFloat32) {
+            status = TransFilterFormat<float>(tensor, kHWCK2KCHW);
+          } else if (dataType == kNumberTypeUInt8) {
+            status = TransFilterFormat<uint8_t>(tensor, kHWCK2KCHW);
+          } else if (dataType == kNumberTypeInt8) {
+            status = TransFilterFormat<int8_t>(tensor, kHWCK2KCHW);
+          } else {
+            MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
+            return RET_ERROR;
+          }
+          break;
+        case schema::Format_HWKC:
+          if (dataType == kNumberTypeFloat32) {
+            status = TransFilterFormat<float>(tensor, kHWKC2KCHW);
+          } else if (dataType == kNumberTypeUInt8) {
+            status = TransFilterFormat<uint8_t>(tensor, kHWKC2KCHW);
+          } else if (dataType == kNumberTypeInt8) {
+            status = TransFilterFormat<int8_t>(tensor, kHWKC2KCHW);
+          } else {
+            MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
+            return RET_ERROR;
+          }
+          break;
+        case schema::Format_KHWC:
+          if (dataType == kNumberTypeFloat32) {
+            status = TransFilterFormat<float>(tensor, kKHWC2KCHW);
+          } else if (dataType == kNumberTypeUInt8) {
+            status = TransFilterFormat<uint8_t>(tensor, kKHWC2KCHW);
+          } else if (dataType == kNumberTypeInt8) {
+            status = TransFilterFormat<int8_t>(tensor, kKHWC2KCHW);
+          } else {
+            MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
+            return RET_ERROR;
+          }
+          break;
+        case schema::Format_CKHW:
+          if (dataType == kNumberTypeFloat32) {
+            status = TransFilterFormat<float>(tensor, kCKHW2KCHW);
+          } else if (dataType == kNumberTypeUInt8) {
+            status = TransFilterFormat<uint8_t>(tensor, kCKHW2KCHW);
+          } else if (dataType == kNumberTypeInt8) {
+            status = TransFilterFormat<int8_t>(tensor, kCKHW2KCHW);
+          } else {
+            MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
+            return RET_ERROR;
+          }
+          break;
+        case schema::Format_CHWK:
+          if (dataType == kNumberTypeFloat32) {
+            status = TransFilterFormat<float>(tensor, kCHWK2KCHW);
+          } else if (dataType == kNumberTypeUInt8) {
+            status = TransFilterFormat<uint8_t>(tensor, kCHWK2KCHW);
+          } else if (dataType == kNumberTypeInt8) {
+            status = TransFilterFormat<int8_t>(tensor, kCHWK2KCHW);
+          } else {
+            MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
+            return RET_ERROR;
+          }
+          break;
+        default:
+          MS_LOG(ERROR) << "Unsupported transform from " << schema::EnumNameFormat(srcFormat) << " to "
+                        << schema::EnumNameFormat(dstFormat);
+          return RET_ERROR;
+      }
+    } break;
+    case schema::Format_CKHW: {
+      switch (srcFormat) {
+        case schema::Format_HWCK:
+          if (dataType == kNumberTypeFloat32) {
+            status = TransFilterFormat<float>(tensor, kHWCK2CKHW);
+          } else if (dataType == kNumberTypeUInt8) {
+            status = TransFilterFormat<uint8_t>(tensor, kHWCK2CKHW);
+          } else if (dataType == kNumberTypeInt8) {
+            status = TransFilterFormat<int8_t>(tensor, kHWCK2CKHW);
+          } else {
+            MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
+            return RET_ERROR;
+          }
+          break;
+        case schema::Format_HWKC:
+          if (dataType == kNumberTypeFloat32) {
+            status = TransFilterFormat<float>(tensor, kHWKC2CKHW);
+          } else if (dataType == kNumberTypeUInt8) {
+            status = TransFilterFormat<uint8_t>(tensor, kHWKC2CKHW);
+          } else if (dataType == kNumberTypeInt8) {
+            status = TransFilterFormat<int8_t>(tensor, kHWKC2CKHW);
+          } else {
+            MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
+            return RET_ERROR;
+          }
+          break;
+        case schema::Format_KCHW:
+          if (dataType == kNumberTypeFloat32) {
+            status = TransFilterFormat<float>(tensor, kKCHW2CKHW);
+          } else if (dataType == kNumberTypeUInt8) {
+            status = TransFilterFormat<uint8_t>(tensor, kKCHW2CKHW);
+          } else if (dataType == kNumberTypeInt8) {
+            status = TransFilterFormat<int8_t>(tensor, kKCHW2CKHW);
+          } else {
+            MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
+            return RET_ERROR;
+          }
+          break;
+        case schema::Format_CKHW:
+          return RET_OK;
+        default:
+          MS_LOG(ERROR) << "Unsupported transform from " << schema::EnumNameFormat(srcFormat) << " to "
+                        << schema::EnumNameFormat(dstFormat);
+          return RET_ERROR;
+      }
+    } break;
+    default:
+      MS_LOG(ERROR) << "Unsupported transform from " << schema::EnumNameFormat(srcFormat) << " to "
+                    << schema::EnumNameFormat(dstFormat);
+      return RET_ERROR;
+  }
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "TransFilterData failed: " << status;
+    return status;
+  }
+  return RET_OK;
+}
 }  // namespace lite
 }  // namespace mindspore
-
-
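Aside (sketch, not part of the patch): the new non-template TransFilterFormat overload only dispatches on (dstFormat, srcFormat, dataType) and forwards to the existing templated TransFilterFormat<T>(tensor, kTransFilterType). A minimal usage sketch, assuming the converter headers are on the include path; ConvertFilterToKHWC is a hypothetical helper name:

    // Illustrative sketch: normalize a conv filter tensor to KHWC.
    #include "tools/common/node_util.h"

    mindspore::lite::STATUS ConvertFilterToKHWC(mindspore::schema::TensorT *filter) {
      // A null tensor returns RET_NULL_PTR; unsupported format/dtype pairs return RET_ERROR.
      auto ret = mindspore::lite::TransFilterFormat(filter, mindspore::schema::Format_KHWC);
      if (ret != mindspore::lite::RET_OK) {
        MS_LOG(ERROR) << "TransFilterFormat to KHWC failed: " << ret;
      }
      return ret;
    }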
diff --git a/mindspore/lite/tools/common/node_util.h b/mindspore/lite/tools/common/node_util.h
index f3db9cfbb5..26fb9be223 100644
--- a/mindspore/lite/tools/common/node_util.h
+++ b/mindspore/lite/tools/common/node_util.h
@@ -53,7 +53,7 @@ class NodeUtils {
 
 // todo check this
 enum kTransFilterType {
-  kKCHW2HWCK,
+  kKCHW2HWCK,  // 0
   kKCHW2KHWC,
   kCKHW2KHWC,
   kCKHW2HWCK,
@@ -63,19 +63,23 @@ enum kTransFilterType {
   kHWCK2CKHW,
   kHWKC2KCHW,
   kHWKC2CKHW,
-  kNHWC2KCHW,
+  kNHWC2KCHW,  // 10
   kNHWC2CKHW,
   kNHWC2HWCK,
   kKHWC2HWCK,
   kCHWK2HWCK,
   kKHWC2CHWK,
-  kCHWK2KHWC
+  kCHWK2KHWC,
+  kKHWC2KCHW,
+  kCKHW2KCHW,
+  kCHWK2KCHW,
+  kKCHW2CKHW  // 20
 };
 
 static STATUS GetFilterDim(std::vector<int32_t> &oriDims, kTransFilterType type, int32_t &filterK, int32_t &filterC,
                            int32_t &filterH, int32_t &filterW) {
   MS_ASSERT(oriDims.size() == 4);
-  if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC) {
+  if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW) {
     filterK = oriDims.at(KCHW_K);
     filterC = oriDims.at(KCHW_C);
     filterH = oriDims.at(KCHW_H);
@@ -126,7 +130,7 @@ static STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32
     tensor->dims = {filterH, filterW, filterK, filterC};
   } else if (type == kHWCK2KCHW || type == kHWKC2KCHW || type == kNHWC2KCHW) {
     tensor->dims = {filterK, filterC, filterH, filterW};
-  } else if (type == kHWCK2CKHW || type == kHWKC2CKHW || type == kNHWC2CKHW) {
+  } else if (type == kHWCK2CKHW || type == kHWKC2CKHW || type == kNHWC2CKHW || type == kKCHW2CKHW) {
     tensor->dims = {filterC, filterK, filterH, filterW};
   } else if (type == kKHWC2CHWK) {
     tensor->dims = {filterC, filterH, filterW, filterK};
@@ -198,6 +202,7 @@ static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, in
       }
     } break;
     case kKCHW2HWCK:
+    case kKCHW2CKHW:
     case kKCHW2KHWC:
     case kKCHW2HWKC: {
       for (int k = 0; k < filterK; ++k) {
@@ -211,6 +216,9 @@ static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, in
           } else if (type == kKCHW2KHWC) {
             p2Buff =
               buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
+          } else if (type == kKCHW2CKHW) {
+            p2Buff =
+              buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
           } else {
             p2Buff =
               buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
@@ -367,7 +375,8 @@ static STATUS TransFilterFormat(schema::TensorT *tensor, kTransFilterType type)
   return RET_OK;
 }
+
+STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat);
 }  // namespace lite
 }  // namespace mindspore
 
 #endif  // MINDSPORE_PREDICT_NODE_UTIL_H
-
diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt
index 706721e1a1..c07eb40d5c 100644
--- a/mindspore/lite/tools/converter/CMakeLists.txt
+++ b/mindspore/lite/tools/converter/CMakeLists.txt
@@ -103,7 +103,6 @@ target_link_libraries(converter_lite PRIVATE
         onnx_parser_mid
         anf_importer_mid
        anf_exporter_mid
-        node_mid
         graph_pass_mid
         fusion_mid
         quantizer_mid
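The kKCHW2CKHW additions above boil down to one offset formula: element (k, c, h, w) of a KCHW buffer lands at c*K*H*W + k*H*W + h*W + w in the CKHW buffer, i.e. only the two outermost axes swap. A self-contained check of that mapping (hypothetical sizes, plain ints instead of filter data; not part of the patch):

    #include <cstdio>
    #include <vector>

    int main() {
      const int K = 2, C = 3, H = 4, W = 5;  // hypothetical filter dims
      std::vector<int> src(K * C * H * W), dst(C * K * H * W);
      for (size_t i = 0; i < src.size(); ++i) src[i] = static_cast<int>(i);
      for (int k = 0; k < K; ++k)
        for (int c = 0; c < C; ++c)
          for (int h = 0; h < H; ++h)
            for (int w = 0; w < W; ++w)
              // same arithmetic as the new kKCHW2CKHW branch in TransFilterData
              dst[(c * K * H * W) + (k * H * W) + (h * W) + w] =
                src[(k * C * H * W) + (c * H * W) + (h * W) + w];
      // spot-check: dst(c=0, k=1, h=2, w=3) must equal src(k=1, c=0, h=2, w=3)
      std::printf("%d == %d\n", dst[(0 * K * H * W) + (1 * H * W) + (2 * W) + 3],
                  src[(1 * C * H * W) + (0 * H * W) + (2 * W) + 3]);
      return 0;
    }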
"tools/converter/legacy_optimizer/const_fold/rsqrt_const_fold_pass.h" -// #include "tools/converter/legacy_optimizer/const_fold/shape_const_fold_pass.h" -// #include "tools/converter/legacy_optimizer/const_fold/slice_const_fold_pass.h" -// #include "tools/converter/legacy_optimizer/const_fold/stack_const_fold_pass.h" -// #include "tools/converter/legacy_optimizer/const_fold/strided_slice_const_fold_pass.h" -// #include "tools/converter/legacy_optimizer/const_fold/sub_const_fold_pass.h" -// #include "tools/converter/legacy_optimizer/const_fold/tile_const_fold_pass.h" -// #include "tools/converter/legacy_optimizer/const_fold/transpose_const_fold_pass.h" -// -#include "tools/converter/legacy_optimizer/node/weight_format_pass.h" +#include "tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h" +#include "tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h" #include "tools/converter/legacy_optimizer/graph/format_trans_pass.h" #include "tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.h" #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h" @@ -76,19 +54,6 @@ void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) { std::make_unique(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean); break; } - // case QuantType::QuantType_WeightQuant: { - // MS_LOGI("create WeightQuantizer!"); - // mQuantizer.reset(new WeightQuantizer(graphDefT, flags->quantSize)); - // break; - // } - // case QuantType_PostTraining: { - // MS_LOGI("create PostTrainningQuantizer!"); - // mQuantizer.reset(new PostTrainingQuantizer(graphDefT, flags->configFile)); - // break; - // } - // case QuantType::QuantType_QUANT_NONE: - // MS_LOGD("Not do quantization for model!"); - // break; default: // MS_LOGI("will support quantizer type %s in the future!", flags->quantTypeIn.c_str()); break; @@ -97,16 +62,16 @@ void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) { int GraphDefTransform::Transform(const converter::Flags &ctx) { STATUS status; - // weight format trans - if (ctx.formatTrans) { + { Optimizer weightFormatOptimizer; - auto weightFormatPass = new (std::nothrow) WeightFormatPass(); - if (weightFormatPass == nullptr) { - MS_LOG(ERROR) << "new weightFormatPass failed"; - return RET_ERROR; - } + auto weightHardCodePass = new WeightFormatHardCodePass(); + auto weightFormatPass = new WeightFormatTransformPass(); + weightHardCodePass->SetQuantType(ctx.quantType); + weightHardCodePass->SetFmkType(ctx.fmk); weightFormatPass->SetQuantType(ctx.quantType); weightFormatPass->SetFmkType(ctx.fmk); +// weightFormatPass->SetDstFormat(Format_KHWC); + weightFormatOptimizer.AddPass(weightHardCodePass); weightFormatOptimizer.AddPass(weightFormatPass); status = weightFormatOptimizer.Run(graphDefT); if (status != RET_OK && status != RET_NO_CHANGE) { diff --git a/mindspore/lite/tools/converter/legacy_optimizer/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/CMakeLists.txt index 898d060738..ed38cf819a 100755 --- a/mindspore/lite/tools/converter/legacy_optimizer/CMakeLists.txt +++ b/mindspore/lite/tools/converter/legacy_optimizer/CMakeLists.txt @@ -1,6 +1,4 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}) add_subdirectory(fusion) -#add_subdirectory(const_fold) -add_subdirectory(node) add_subdirectory(graph) \ No newline at end of file diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt index 
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.cc
deleted file mode 100644
index d132aa8981..0000000000
--- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.cc
+++ /dev/null
@@ -1,101 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.h"
-#include <string>
-#include <unordered_map>
-#include <memory>
-#include "utils/log_adapter.h"
-#include "include/errorcode.h"
-#include "schema/inner/model_generated.h"
-#include "tools/common/graph_util.h"
-#include "src/common/op_utils.h"
-
-namespace mindspore {
-namespace lite {
-#define CONV_ACTIVATION_MATCH_PATH_LEN 2
-
-STATUS ConvActivationFusionPass::DefinePattern() {
-  auto convOp = std::make_shared<PatternOp>();
-  convOp->id = kConvName;
-  convOp->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D};
-  auto actOp = std::make_shared<PatternOp>();
-  actOp->id = ACTIVATION_NAME;
-  actOp->types = {schema::PrimitiveType_Activation};
-  actOp->left = convOp;
-
-  std::unique_ptr<FusionPattern> fusionPattern(new (std::nothrow) FusionPattern("ConvActivationFusion"));
-  if (fusionPattern == nullptr) {
-    MS_LOG(ERROR) << "new fusionPattern failed";
-    return RET_ERROR;
-  }
-  fusionPattern->AddPatternOp(convOp);
-  fusionPattern->AddPatternOp(actOp);
-  fusionPattern->Finish();
-
-  this->patterns.emplace_back(fusionPattern.release());
-
-  return RET_OK;
-}
-
-// 1. change attr of conv
-// 2. delete Activation node
-STATUS ConvActivationFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
-                                          std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
-  MS_ASSERT(graph != nullptr);
-  if (matchedPath.size() != CONV_ACTIVATION_MATCH_PATH_LEN) {
-    MS_LOG(ERROR) << "Conv-Activation-Fusion should have two NodeIndex in matchedPair";
-    return RET_PARAM_INVALID;
-  }
-
-  auto convPath = matchedPath[kConvName];
-  auto actPath = matchedPath[ACTIVATION_NAME];
-  auto &convNode = graph->nodes.at(convPath->nodeIdx);
-  auto &actNode = graph->nodes.at(actPath->nodeIdx);
-
-  // todo if combine conv_relu_fusion and conv_relu6_fusion to conv_activation_fusion
-  if (actNode->primitive->value.AsActivation()->type != this->activationType) {
-    return RET_NO_CHANGE;
-  }
-
-  if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) {
-    convNode->primitive->value.AsConv2D()->activationType = this->activationType;
-  } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
-    convNode->primitive->value.AsDepthwiseConv2D()->activationType = this->activationType;
-  } else {
-    MS_LOG(ERROR) << "Unsupported opType, " << convNode->primitive->value.type;
-    return RET_ERROR;
-  }
-
-  // remove activation node
-  MergeNodeAttrFromPost(convNode, actNode);
-  auto status = IsolateOneWayNode(graph, actPath->nodeIdx);
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "IsolateOneWayNode failed, subGraph: " << actPath->subGraphIdx << ", node: " << actPath->nodeIdx
-                  << ", error: " << status;
-    return status;
-  }
-
-  return RET_OK;
-}
-
-STATUS ConvActivationFusionPass::Run(schema::MetaGraphT *graph) {
-  SetActivationType();
-  return FusionPass::Run(graph);
-}
-
-}  // namespace lite
-}  // namespace mindspore
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.h
deleted file mode 100644
index b760b503d5..0000000000
--- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.h
+++ /dev/null
@@ -1,50 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_PREDICT_CONV_ACTIVATION_FUSION_PASS_H
-#define MINDSPORE_PREDICT_CONV_ACTIVATION_FUSION_PASS_H
-
-#include <string>
-#include <unordered_map>
-#include <memory>
-#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"
-
-namespace mindspore {
-namespace lite {
-class ConvActivationFusionPass : public FusionPass {
- public:
-  ConvActivationFusionPass() = default;
-
-  ~ConvActivationFusionPass() override = default;
-
-  STATUS DefinePattern() override;
-
-  virtual STATUS SetActivationType() = 0;
-
-  // 1. change attr of conv
-  // 2. delete Activation node
-  STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
-                  std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;
-
-  STATUS Run(schema::MetaGraphT *graph) override;
-
- protected:
-  schema::ActivationType activationType = schema::ActivationType_RELU;
-};
-}  // namespace lite
-}  // namespace mindspore
-
-#endif  // MINDSPORE_PREDICT_CONV_ACTIVATION_FUSION_PASS_H
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.cc
deleted file mode 100644
index 093af0ed54..0000000000
--- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.cc
+++ /dev/null
@@ -1,295 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h"
-#include <cfloat>
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <utility>
-#include <vector>
-#include "utils/log_adapter.h"
-#include "securec/include/securec.h"
-// #include "utils/log_adapter.h"
-#include "tools/common/graph_util.h"
-#include "include/errorcode.h"
-#include "schema/inner/model_generated.h"
-#include "src/common/op_utils.h"
-
-namespace mindspore {
-namespace lite {
-#define CONV_BIASADD_MATCH_PATH_LEN 2
-#define BIASADD_OP_BIAS_INDEX_IN_WEIGHT 0
-#define BIASADD_OP_INPUT_NUM 2
-#define BIASADD_OP_CONST_TENSOR_INDEX 1
-
-STATUS ConvBiasAddFusionPass::Run(MetaGraphT *graph) { return FusionPass::Run(graph); }
-
-STATUS ConvBiasAddFusionPass::DefinePattern() {
-  auto convOp = std::make_shared<PatternOp>();
-  convOp->id = kConvName;
-  convOp->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeConv2D};
-  auto baOp = std::make_shared<PatternOp>();
-  baOp->id = BIASADD_NAME;
-  baOp->types = {schema::PrimitiveType_BiasAdd, schema::PrimitiveType_Add};
-  baOp->left = convOp;
-
-  std::unique_ptr<FusionPattern> fusionPattern(new (std::nothrow) FusionPattern("ConvBiasAddFusion"));
-  if (fusionPattern == nullptr) {
-    MS_LOG(ERROR) << "new fusionPattern failed";
-    return RET_ERROR;
-  }
-  fusionPattern->AddPatternOp(convOp);
-  fusionPattern->AddPatternOp(baOp);
-  fusionPattern->Finish();
-
-  this->patterns.emplace_back(fusionPattern.release());
-
-  return RET_OK;
-}
-
-STATUS ConvBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName,
-                                       std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
-  MS_ASSERT(graph != nullptr);
-  if (matchedPath.size() != CONV_BIASADD_MATCH_PATH_LEN) {
-    MS_LOG(ERROR) << "Conv-BiasAdd-Fusion should have two NodeIndex in matchedPair";
-    return RET_PARAM_INVALID;
-  }
-
-  auto convPath = matchedPath[kConvName];
-  auto baPath = matchedPath[BIASADD_NAME];
-  auto &convNode = graph->nodes.at(convPath->nodeIdx);
-  auto &baNode = graph->nodes.at(baPath->nodeIdx);
-  // add/biasadd node the second tensor is not constant tensor, don't fusion
-  auto baNodeInputIndex = baNode->inputIndex;
-  if (baNodeInputIndex.size() != BIASADD_OP_INPUT_NUM) {
-    MS_LOG(ERROR) << baNode->name.c_str() << " node tensors number is invalid! ";
-    return RET_ERROR;
-  }
-  auto baNodeBiasTensor = graph->allTensors.at(baNodeInputIndex[BIASADD_OP_CONST_TENSOR_INDEX]).get();
-  MS_ASSERT(baNodeBiasTensor != nullptr);
-  if (baNodeBiasTensor->nodeType != schema::NodeType_ValueNode) {
-    // dont fusion, return
-    return RET_OK;
-  }
-
-  // 1. generate newBiasTensor for conv
-  auto status = GenConvBiasTensor(convPath, baPath, graph);
-  if (RET_OK != status) {
-    MS_LOG(ERROR) << "GenConvBiasTensor failed, " << status;
-    return status;
-  }
-  if (this->newBiasTensor != nullptr) {
-    status = AddTensor2Node(graph, convPath->nodeIdx, std::move(this->newBiasTensor));
-    this->newBiasTensor = nullptr;
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "AddTensor2Node failed, node: " << convPath->nodeIdx << ", error: " << status;
-      return status;
-    }
-    // add bias quantParam
-    // todo add quantParam for tensors
-
-    // if (convNode->quantParam.size() == convNode->inputIndex.size() + convNode->outputIndex.size() - 1) {
-    //   std::unique_ptr<QuantParamArrayT> quantParamArray(new QuantParamArrayT());
-    //   if (quantParamArray == nullptr) {
-    //     MS_LOG(ERROR) << "new QuantParamArrayT failed");
-    //     return RET_ERROR;
-    //   }
-    //   std::unique_ptr<QuantParamT> quantParam(new QuantParamT());
-    //   if (quantParam == nullptr) {
-    //     MS_LOG(ERROR) << "new QuantParamT failed");
-    //     return RET_ERROR;
-    //   }
-    //   quantParam->numBits = -1;
-    //   quantParam->scale = FLT_MAX;
-    //   quantParam->zeroPoint = 0;
-    //   quantParam->narrowRange = true;
-    //   quantParam->min = FLT_MAX;
-    //   quantParam->max = FLT_MAX;
-    //   quantParamArray->param.emplace_back(quantParam.release());
-    //   convNode->quantParam.emplace_back(quantParamArray.release());
-    // }
-  }
-
-  // 2. change attr of conv
-  if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) {
-    convNode->primitive->value.AsConv2D()->hasBias = true;
-  } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
-    convNode->primitive->value.AsDepthwiseConv2D()->hasBias = true;
-  } else if (convNode->primitive->value.type == schema::PrimitiveType_DeConv2D) {
-    convNode->primitive->value.AsDeConv2D()->hasBias = true;
-  } else {
-    MS_LOG(ERROR) << "Unsupported opType, " << convNode->primitive->value.type;
-    return RET_ERROR;
-  }
-
-  // 5. delete BiasAdd node
-  MergeNodeAttrFromPost(convNode, baNode);
-  status = IsolateOneWayNode(graph, baPath->nodeIdx);
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "IsolateOneWayNode failed, graph: %zu, node: %zu, error: %d";
-    //, baPath->subGraphIdx, baPath->nodeIdx, status);
-    return status;
-  }
-
-  return RET_OK;
-}
-
-#define BIASADD_WEIGHT_SHAPE_SIZE 1
-#define BIASADD_BIAS_DIM_INDEX 0
-
-STATUS ConvBiasAddFusionPass::GenConvBiasTensor(std::shared_ptr<Path> convPath, std::shared_ptr<Path> baPath,
-                                                MetaGraphT *graph) {
-  MS_ASSERT(convPath != nullptr);
-  MS_ASSERT(baPath != nullptr);
-  MS_ASSERT(graph != nullptr);
-
-  auto convNode = graph->nodes.at(convPath->nodeIdx).get();
-  MS_ASSERT(convNode != nullptr);
-  auto baNode = graph->nodes.at(baPath->nodeIdx).get();
-  MS_ASSERT(baNode != nullptr);
-  int32_t kernelNum = 0;
-  if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) {
-    kernelNum = convNode->primitive->value.AsConv2D()->channelOut;
-  } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
-    kernelNum = convNode->primitive->value.AsDepthwiseConv2D()->channelIn *
-                convNode->primitive->value.AsDepthwiseConv2D()->channelMultiplier;
-  } else if (convNode->primitive->value.type == schema::PrimitiveType_DeConv2D) {
-    kernelNum = convNode->primitive->value.AsDeConv2D()->channelOut;
-  }
-  auto convWeightTensorIdxes = convNode->inputIndex;
-  if (convWeightTensorIdxes.size() < CONV_OP_NO_BIAS_INPUT_NUM) {
-    MS_LOG(ERROR) << convNode->name.c_str() << " node tensors number is invalid! ";
-    return RET_ERROR;
-  }
-  convWeightTensorIdxes.erase(convWeightTensorIdxes.begin());
-  auto baWeightTensorIdxes = baNode->inputIndex;
-  if (baWeightTensorIdxes.size() != BIASADD_OP_INPUT_NUM) {
-    MS_LOG(ERROR) << baNode->name.c_str() << " node tensors number is invalid! ";
-    return RET_ERROR;
-  }
-  baWeightTensorIdxes.erase(baWeightTensorIdxes.begin());
-
-  if (convWeightTensorIdxes.empty()) {
-    MS_LOG(ERROR) << "Conv2D should has one weight tensors at least, current number of weight tensors "
-                  << convWeightTensorIdxes.size();
-    return RET_ERROR;
-  }
-
-  if (baWeightTensorIdxes.empty()) {
-    MS_LOG(ERROR) << "BiasAdd should has one weight tensors at least, current number of weight tensors "
-                  << baWeightTensorIdxes.size();
-    return RET_ERROR;
-  }
-
-  TensorT *oldBiasTensor = nullptr;
-  TensorT *biasTensor = nullptr;
-
-  if (convWeightTensorIdxes.size() == CONV_OP_HAS_BIAS_WEIGHT_NUM) {
-    oldBiasTensor = graph->allTensors.at(convWeightTensorIdxes[CONV_OP_BIAS_INDEX_IN_WEIGHT]).get();
-    MS_ASSERT(oldBiasTensor != nullptr);
-  }
-  biasTensor = graph->allTensors.at(baWeightTensorIdxes.at(BIASADD_OP_BIAS_INDEX_IN_WEIGHT)).get();
-  MS_ASSERT(biasTensor != nullptr);
-  auto biasDims = biasTensor->dims;
-  // if biasTensor is a scaler
-  if (biasDims.empty() && biasTensor->data.data() == nullptr) {
-    MS_LOG(ERROR) << "BiasAdd node %s bias tensor is invalid" << baNode->name.c_str();
-    return RET_ERROR;
-  }
-  if (!biasDims.empty() && biasDims.size() != BIASADD_WEIGHT_SHAPE_SIZE) {
-    MS_LOG(ERROR) << "BiasAdd bias tensor should has one dimension, current number of dimension " << biasDims.size()
-                  << ". or bias tensor is a scaler";
-    return RET_ERROR;
-  }
-
-  bool bias_const = !biasDims.empty() && biasDims.size() == 1 && biasDims[0] == 1;
-  if (!biasDims.empty() && !bias_const && biasDims.at(BIASADD_BIAS_DIM_INDEX) != kernelNum) {
-    MS_LOG(ERROR) << "Size(%d) of BiasAdd(%s) bias tensor should be equal to kernelNum(%d)"
-                  << biasDims.at(BIASADD_BIAS_DIM_INDEX) << baNode->name.c_str() << kernelNum;
-    return RET_ERROR;
-  }
-
-  // cal new biasData
-  this->newBiasData = new (std::nothrow) float[kernelNum];
-  if (newBiasData == nullptr) {
-    MS_LOG(ERROR) << "new newBiasData failed";
-    return RET_ERROR;
-  }
-
-  if (biasDims.empty() && biasTensor->data.data() != nullptr) {
-    auto *biasData = reinterpret_cast<float *>(biasTensor->data.data());
-    if (0 != memset_s(newBiasData, kernelNum * sizeof(float), *biasData, kernelNum * sizeof(float))) {
-      MS_LOG(ERROR) << "memset_s newBiasData failed";
-      return RET_ERROR;
-    }
-  } else if (bias_const) {
-    auto *biasData = reinterpret_cast<float *>(biasTensor->data.data());
-    for (size_t i = 0; i < kernelNum; i++) {
-      newBiasData[i] = *biasData;
-    }
-  } else {
-    if (0 != memcpy_s(newBiasData, kernelNum * sizeof(float), biasTensor->data.data(), kernelNum * sizeof(float))) {
-      MS_LOG(ERROR) << "memcpy_s newBiasData failed";
-      return RET_ERROR;
-    }
-  }
-  if (oldBiasTensor != nullptr) {
-    auto oldBiasDims = oldBiasTensor->dims;
-    if (oldBiasDims.size() != 1) {
-      MS_LOG(ERROR)
-        << "Conv bias tensor should has one dimension, current number of dimension %zu";  // oldBiasDims.size());
-      return RET_ERROR;
-    }
-    if (oldBiasDims.at(0) != kernelNum) {
-      MS_LOG(ERROR)
-        << "Size(%zu) of Conv bias tensor should be equal to kernelNum(%d), current number of dimension %zu";
-      // oldBiasDims.size(), kernelNum);
-      return RET_ERROR;
-    }
-    auto *oldBiasData = reinterpret_cast<float *>(oldBiasTensor->data.data());
-    for (size_t i = 0; i < kernelNum; i++) {
-      oldBiasData[i] += newBiasData[i];
-    }
-  } else {
-    auto *newCharBiasData = reinterpret_cast<uint8_t *>(newBiasData);
-    std::vector<uint8_t> tmpBiasVec(newCharBiasData, newCharBiasData + kernelNum * sizeof(float) / sizeof(uint8_t));
-
-    auto weightTensor = graph->allTensors.at(convWeightTensorIdxes[CONV_OP_FILTER_INDEX_IN_WEIGHT]).get();
-    this->newBiasTensor = std::unique_ptr<TensorT>(new (std::nothrow) TensorT);
-    // todo biasShape
-    this->newBiasTensor->dims = {kernelNum};
-    this->newBiasTensor->dataType = weightTensor->dataType;
-    this->newBiasTensor->format = weightTensor->format;
-    this->newBiasTensor->refCount = weightTensor->refCount;
-    this->newBiasTensor->data.swap(tmpBiasVec);
-    newCharBiasData = nullptr;
-  }
-
-  delete (this->newBiasData);
-  newBiasData = nullptr;
-
-  return RET_OK;
-}
-
-ConvBiasAddFusionPass::~ConvBiasAddFusionPass() {
-  if (this->newBiasData != nullptr) {
-    delete (this->newBiasData);
-  }
-}
-
-}  // namespace lite
-}  // namespace mindspore
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h
deleted file mode 100644
index 1f104af2c3..0000000000
--- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h
+++ /dev/null
@@ -1,51 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_PREDICT_CONV_BIASADD_FUSION_PASS_H
-#define MINDSPORE_PREDICT_CONV_BIASADD_FUSION_PASS_H
-
-#include <string>
-#include <unordered_map>
-#include <memory>
-#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"
-
-namespace mindspore {
-namespace lite {
-class ConvBiasAddFusionPass : public FusionPass {
- public:
-  ConvBiasAddFusionPass() = default;
-
-  ~ConvBiasAddFusionPass() override;
-
-  STATUS DefinePattern() override;
-
-  STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
-                  std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;
-
-  STATUS Run(schema::MetaGraphT *graph) override;
-
- protected:
-  // gen this->newBiasTensor if conv has no bias before
-  STATUS GenConvBiasTensor(std::shared_ptr<Path> convPath, std::shared_ptr<Path> dstPath, schema::MetaGraphT *graph);
-
- protected:
-  float *newBiasData = nullptr;
-  std::unique_ptr<schema::TensorT> newBiasTensor = nullptr;
-};
-}  // namespace lite
-}  // namespace mindspore
-
-#endif  // MINDSPORE_PREDICT_CONV_BIASADD_FUSION_PASS_H
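The arithmetic at the heart of the deleted GenConvBiasTensor: broadcast the BiasAdd operand (scalar or per-channel) to kernelNum floats and add it to the conv's existing bias, or install it as a fresh bias tensor. A self-contained restatement (plain vectors instead of schema::TensorT; sketch only):

    #include <cstdio>
    #include <vector>

    // oldBias may be empty (conv had no bias); addBias holds 1 (scalar) or kernelNum values.
    std::vector<float> FoldBias(const std::vector<float> &oldBias,
                                const std::vector<float> &addBias, int kernelNum) {
      std::vector<float> out(kernelNum, 0.0f);
      for (int i = 0; i < kernelNum; ++i) {
        float b = (addBias.size() == 1) ? addBias[0] : addBias[i];  // scalar broadcast
        out[i] = b + (oldBias.empty() ? 0.0f : oldBias[i]);
      }
      return out;
    }

    int main() {
      auto fused = FoldBias({1.0f, 2.0f}, {0.5f}, 2);
      std::printf("%.1f %.1f\n", fused[0], fused[1]);  // 1.5 2.5
      return 0;
    }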
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.cc
deleted file mode 100644
index ae63da7a79..0000000000
--- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.cc
+++ /dev/null
@@ -1,224 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <cmath>
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include "tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h"
-#include "securec/include/securec.h"
-#include "include/errorcode.h"
-#include "schema/inner/model_generated.h"
-#include "utils/log_adapter.h"
-
-namespace mindspore {
-namespace lite {
-#define CAFFE_BATCHNORM_OP_WEIGHT_NUM 2
-#define TF_BATCHNORM_OP_WEIGHT_NUM 4
-#define CAFFE_BATCHNORM_MEAN_INDEX 0
-#define CAFFE_BATCHNORM_VARIANCE_INDEX 1
-#define TF_BATCHNORM_SCALE_INDEX 0
-#define TF_BATCHNORM_BIAS_INDEX 1
-#define TF_BATCHNORM_MEAN_INDEX 2
-#define TF_BATCHNORM_VARIANCE_INDEX 3
-
-constexpr const float EPS = 1e-8;
-constexpr const float EPS_DEFAULT_FLOAT = 1e-5;
-constexpr const float POW_NUM = 0.5;
-
-STATUS ConvBNFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
-                                  std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
-  return ConvScaleBiasFusionPass::DoFusion(graph, patternName, matchedPath);
-}
-
-STATUS ConvBNFusionPass::DefinePattern() {
-  auto convOp = std::make_shared<PatternOp>();
-  convOp->id = kConvName;
-  convOp->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D};
-  auto bnOp = std::make_shared<PatternOp>();
-  bnOp->id = DST_NAME;
-  bnOp->types = {schema::PrimitiveType_FusedBatchNorm, schema::PrimitiveType_BatchNorm};
-  bnOp->left = convOp;
-
-  std::unique_ptr<FusionPattern> fusionPattern(new (std::nothrow) FusionPattern("ConvBatchNormFusion"));
-  if (fusionPattern == nullptr) {
-    MS_LOG(ERROR) << "new fusionPattern failed";
-    return RET_ERROR;
-  }
-  fusionPattern->AddPatternOp(convOp);
-  fusionPattern->AddPatternOp(bnOp);
-  fusionPattern->Finish();
-
-  this->patterns.emplace_back(fusionPattern.release());
-
-  return RET_OK;
-}
-
-STATUS ConvBNFusionPass::Run(schema::MetaGraphT *graph) { return ConvScaleBiasFusionPass::Run(graph); }
-
-STATUS ConvBNFusionPass::GetTransParam(schema::MetaGraphT *graph, std::shared_ptr<Path> bnPath, int32_t kernelNum) {
-  MS_ASSERT(graph != nullptr);
-  MS_ASSERT(bnPath != nullptr);
-
-  BNWeightTensors bnWeightTensors;
-
-  auto status = GetBnWeightTensors(graph, bnPath, kernelNum, bnWeightTensors);
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "GetBnWeightTensors error " << status;
-    return status;
-  }
-  schema::TensorT *meanTensor = bnWeightTensors.meanTensor;
-  schema::TensorT *varianceTensor = bnWeightTensors.varianceTensor;
-  schema::TensorT *scaleTensor = bnWeightTensors.scaleTensor;
-  schema::TensorT *biasTensor = bnWeightTensors.biasTensor;
-
-  auto *meanData = reinterpret_cast<float *>(meanTensor->data.data());
-  auto *varianceData = reinterpret_cast<float *>(varianceTensor->data.data());
-
-  float eps = EPS_DEFAULT_FLOAT;
-  status = GetBnEpsilon(graph, bnPath, eps);
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "GetBnEpsilon failed " << status;
-    return status;
-  }
-
-  // cal transScale, tf : scale/sqrt(variance + eps); caffe : 1/sqrt(variance + eps)
-  if (memcpy_s(transScale, kernelNum * sizeof(float), varianceData, kernelNum * sizeof(float)) != 0) {
-    MS_LOG(ERROR) << "memcpy_s transScale error";
-    return RET_ERROR;
-  }
-  // 1/sqrt(variance + eps)
-  for (int32_t i = 0; i < kernelNum; i++) {
-    float tmp = transScale[i] + eps;
-    tmp = pow(tmp, POW_NUM);
-    transScale[i] = 1 / tmp;
-  }
-
-  if (scaleTensor != nullptr) {
-    auto *scaleData = reinterpret_cast<float *>(scaleTensor->data.data());
-    // scale/sqrt(variance + eps)
-    for (int32_t i = 0; i < kernelNum; i++) {
-      transScale[i] *= scaleData[i];
-    }
-  }
-
-  // cal transBias, tf : -scale*mean/sqrt(variance + eps) + bias; caffe : -mean/sqrt(variance + eps)
-  // -mean/sqrt(variance + eps)
-  for (int32_t i = 0; i < kernelNum; i++) {
-    transBias[i] = -meanData[i] * transScale[i];
-  }
-
-  if (biasTensor != nullptr) {
-    auto *biasData = reinterpret_cast<float *>(biasTensor->data.data());
-    // -scale*mean/sqrt(variance + eps) + bias
-    for (int32_t i = 0; i < kernelNum; i++) {
-      transBias[i] += biasData[i];
-    }
-  }
-
-  return RET_OK;
-}
-
-// BatchNorm weight Tensor definition:
-// caffe
-//   estimated_mean  --0
-//   estimated_variance  --1
-// tensorflow
-//   scale  -- 0
-//   bias  --1
-//   estimated_mean  --2
-//   estimated_variance  --3
-STATUS ConvBNFusionPass::GetBnWeightTensors(schema::MetaGraphT *graph, std::shared_ptr<Path> bnPath, int32_t kernelNum,
-                                            BNWeightTensors &bnWeightTensors) {
-  MS_ASSERT(graph != nullptr);
-  MS_ASSERT(bnPath != nullptr);
-  auto bnNode = graph->nodes.at(bnPath->nodeIdx).get();
-  auto bnWeightTensorIdxes = bnNode->inputIndex;
-  bnWeightTensorIdxes.erase(bnWeightTensorIdxes.begin());
-  if (bnWeightTensorIdxes.size() == CAFFE_BATCHNORM_OP_WEIGHT_NUM) {
-    bnWeightTensors.meanTensor = graph->allTensors.at(bnWeightTensorIdxes[CAFFE_BATCHNORM_MEAN_INDEX]).get();
-    bnWeightTensors.varianceTensor = graph->allTensors.at(bnWeightTensorIdxes[CAFFE_BATCHNORM_VARIANCE_INDEX]).get();
-  } else if (bnWeightTensorIdxes.size() == TF_BATCHNORM_OP_WEIGHT_NUM) {
-    bnWeightTensors.scaleTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_SCALE_INDEX]).get();
-    bnWeightTensors.biasTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_BIAS_INDEX]).get();
-    bnWeightTensors.meanTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_MEAN_INDEX]).get();
-    bnWeightTensors.varianceTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_VARIANCE_INDEX]).get();
-  } else {
-    MS_LOG(ERROR) << "BatchNorm should has " << CAFFE_BATCHNORM_OP_WEIGHT_NUM << " or " << TF_BATCHNORM_OP_WEIGHT_NUM
-                  << " weight tensors, current number of weight tensors " << bnWeightTensorIdxes.size();
-    return RET_ERROR;
-  }
-
-  if (bnWeightTensors.meanTensor == nullptr) {
-    MS_LOG(ERROR) << "BatchNorm's mean tensor is nullptr";
-    return RET_ERROR;
-  }
-
-  if (bnWeightTensors.varianceTensor == nullptr) {
-    MS_LOG(ERROR) << "BatchNorm's variance tensor is nullptr";
-    return RET_ERROR;
-  }
-
-  if (kernelNum != bnWeightTensors.meanTensor->data.size() * sizeof(uint8_t) / sizeof(float)) {
-    MS_LOG(ERROR) << "conv kernel num " << kernelNum << " is expected to be equal to mean size("
-                  << bnWeightTensors.meanTensor->data.size() * sizeof(uint8_t) / sizeof(float) << ")";
-    return RET_ERROR;
-  }
-
-  if (kernelNum != bnWeightTensors.varianceTensor->data.size() * sizeof(uint8_t) / sizeof(float)) {
-    MS_LOG(ERROR) << "conv kernel num " << kernelNum << " is expected to be equal to mean size("
-                  << bnWeightTensors.meanTensor->data.size() * sizeof(uint8_t) / sizeof(float) << ")";
-    return RET_ERROR;
-  }
-
-  if (bnWeightTensors.scaleTensor != nullptr) {
-    if (kernelNum != bnWeightTensors.scaleTensor->data.size() * sizeof(uint8_t) / sizeof(float)) {
-      MS_LOG(ERROR) << "conv kernel num " << kernelNum << " is expected to be equal to mean size("
-                    << bnWeightTensors.meanTensor->data.size() * sizeof(uint8_t) / sizeof(float) << ")";
-      return RET_ERROR;
-    }
-  }
-
-  if (bnWeightTensors.biasTensor != nullptr) {
-    if (kernelNum != bnWeightTensors.biasTensor->data.size() * sizeof(uint8_t) / sizeof(float)) {
-      MS_LOG(ERROR) << "conv kernel num " << kernelNum << " is expected to be equal to mean size("
-                    << bnWeightTensors.meanTensor->data.size() * sizeof(uint8_t) / sizeof(float) << ")";
-      return RET_ERROR;
-    }
-  }
-  return RET_OK;
-}
-
-STATUS ConvBNFusionPass::GetBnEpsilon(schema::MetaGraphT *graph, std::shared_ptr<Path> bnPath, float &eps) {
-  MS_ASSERT(graph != nullptr);
-  auto bnNode = graph->nodes.at(bnPath->nodeIdx).get();
-  MS_ASSERT(bnNode != nullptr);
-  if (bnNode->primitive->value.type == schema::PrimitiveType_FusedBatchNorm) {
-    eps = bnNode->primitive->value.AsFusedBatchNorm()->epsilon;
-  } else if (bnNode->primitive->value.type == schema::PrimitiveType_BatchNorm) {
-    eps = bnNode->primitive->value.AsBatchNorm()->epsilon;
-  } else {
-    MS_LOG(ERROR) << "match pattern has error, " << bnNode->name.c_str() << " not BatchNorm node";
-    return RET_ERROR;
-  }
-
-  if (eps < EPS) {
-    eps = EPS_DEFAULT_FLOAT;
-  }
-  return RET_OK;
-}
-}  // namespace lite
-}  // namespace mindspore
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h
deleted file mode 100644
index b7eb7c1d26..0000000000
--- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h
+++ /dev/null
@@ -1,54 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <string>
-#include <unordered_map>
-#include <memory>
-#ifndef MINDSPORE_CONV_BN_FUSION_PASS_H
-#define MINDSPORE_CONV_BN_FUSION_PASS_H
-
-#include "tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h"
-#include "tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.h"
-
-namespace mindspore {
-namespace lite {
-class ConvBNFusionPass : public ConvScaleBiasFusionPass {
- public:
-  ConvBNFusionPass() = default;
-
-  ~ConvBNFusionPass() override = default;
-
-  STATUS DefinePattern() override;
-
-  STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
-                  std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;
-
-  STATUS Run(schema::MetaGraphT *graph) override;
-
- protected:
-  STATUS GetTransParam(schema::MetaGraphT *graph, std::shared_ptr<Path> bnPath, int32_t kernelNum) override;
-
-  // Get and check BNNode weight tensor
-  STATUS GetBnWeightTensors(schema::MetaGraphT *graph, std::shared_ptr<Path> bnPath, int32_t kernelNum,
-                            BNWeightTensors &bnWeightTensors);
-
-  STATUS GetBnEpsilon(schema::MetaGraphT *graph, std::shared_ptr<Path> bnPath, float &eps);
-};
-}  // namespace lite
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CONV_BN_FUSION_PASS_H
-
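GetTransParam's folding math, per output channel: transScale = scale / sqrt(variance + eps) (caffe has no scale tensor, so scale is effectively 1) and transBias = -mean * transScale + bias. A self-contained restatement of that formula (sketch, not part of the patch):

    #include <cmath>
    #include <cstdio>

    void BnTransParam(float mean, float variance, float scale, float bias, float eps,
                      float *transScale, float *transBias) {
      *transScale = scale / std::sqrt(variance + eps);  // tf: scale/sqrt(var+eps); caffe: 1/sqrt(var+eps)
      *transBias = -mean * (*transScale) + bias;        // tf: -scale*mean/sqrt(var+eps) + bias
    }

    int main() {
      float s = 0.0f, b = 0.0f;
      BnTransParam(0.1f, 0.04f, 1.0f, 0.0f, 1e-5f, &s, &b);
      std::printf("transScale=%f transBias=%f\n", s, b);  // ~5.0, ~-0.5
      return 0;
    }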
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.cc
deleted file mode 100644
index 8c22b52772..0000000000
--- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.cc
+++ /dev/null
@@ -1,41 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <string>
-#include <unordered_map>
-#include <memory>
-#include "tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.h"
-#include "include/errorcode.h"
-#include "schema/inner/model_generated.h"
-
-namespace mindspore {
-namespace lite {
-STATUS ConvRelu6FusionPass::DefinePattern() { return ConvActivationFusionPass::DefinePattern(); }
-
-STATUS ConvRelu6FusionPass::SetActivationType() {
-  this->activationType = ActivationType_RELU6;
-  return RET_OK;
-}
-
-STATUS ConvRelu6FusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName,
-                                     std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
-  return ConvActivationFusionPass::DoFusion(graph, patternName, matchedPath);
-}
-
-STATUS ConvRelu6FusionPass::Run(MetaGraphT *graph) { return ConvActivationFusionPass::Run(graph); }
-
-}  // namespace lite
-}  // namespace mindspore
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.h
deleted file mode 100644
index a500633351..0000000000
--- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.h
+++ /dev/null
@@ -1,46 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_PREDICT_CONV_RELU6_FUSION_PASS_H
-#define MINDSPORE_PREDICT_CONV_RELU6_FUSION_PASS_H
-
-#include "tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.h"
-#include <string>
-#include <unordered_map>
-#include <memory>
-
-namespace mindspore {
-namespace lite {
-class ConvRelu6FusionPass : public ConvActivationFusionPass {
- public:
-  ConvRelu6FusionPass() = default;
-
-  ~ConvRelu6FusionPass() override = default;
-
-  STATUS DefinePattern() override;
-
-  STATUS SetActivationType() override;
-
-  STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
-                  std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;
-
-  STATUS Run(schema::MetaGraphT *graph) override;
-};
-}  // namespace lite
-}  // namespace mindspore
-
-#endif  // MINDSPORE_PREDICT_CONV_RELU6_FUSION_PASS_H
-
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.cc
deleted file mode 100644
index 05a7880bd4..0000000000
--- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.cc
+++ /dev/null
@@ -1,40 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <string>
-#include <unordered_map>
-#include <memory>
-#include "tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.h"
-#include "include/errorcode.h"
-#include "schema/inner/model_generated.h"
-
-namespace mindspore {
-namespace lite {
-STATUS ConvReluFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
-                                    std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
-  return ConvActivationFusionPass::DoFusion(graph, patternName, matchedPath);
-}
-
-STATUS ConvReluFusionPass::Run(schema::MetaGraphT *graph) { return ConvActivationFusionPass::Run(graph); }
-
-STATUS ConvReluFusionPass::SetActivationType() {
-  this->activationType = schema::ActivationType_RELU;
-  return RET_OK;
-}
-
-STATUS ConvReluFusionPass::DefinePattern() { return ConvActivationFusionPass::DefinePattern(); }
-}  // namespace lite
-}  // namespace mindspore
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.h
deleted file mode 100644
index e7c87cd197..0000000000
--- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.h
+++ /dev/null
@@ -1,45 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_PREDICT_CONV_RELU_FUSION_PASS_H
-#define MINDSPORE_PREDICT_CONV_RELU_FUSION_PASS_H
-
-#include "tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.h"
-#include <string>
-#include <unordered_map>
-#include <memory>
-
-namespace mindspore {
-namespace lite {
-class ConvReluFusionPass : public ConvActivationFusionPass {
- public:
-  ConvReluFusionPass() = default;
-
-  ~ConvReluFusionPass() override = default;
-
-  STATUS DefinePattern() override;
-
-  STATUS SetActivationType() override;
-
-  STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
-                  std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;
-
-  STATUS Run(schema::MetaGraphT *graph) override;
-};
-}  // namespace lite
-}  // namespace mindspore
-
-#endif  // MINDSPORE_PREDICT_CONV_RELU_FUSION_PASS_H
- * Description: mslite - * Author: mslite - * Create: 2019-12-13 - */ - -#include -#include -#include -#include -#include -#include -#include "tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.h" -#include "securec/include/securec.h" -#include "utils/log_adapter.h" -#include "include/errorcode.h" -#include "schema/inner/model_generated.h" -#include "src/common/op_utils.h" -#include "tools/common/graph_util.h" -#include "tools/common/tensor_util.h" - -namespace mindspore { -namespace lite { - -#define CONV_SCALE_BIAS_MATCH_PATH_LEN 2 - -// 1. generate biasTensor according to BN weightTensor -// 2. change attr of conv -// 3. delete BN node -STATUS ConvScaleBiasFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, - std::unordered_map> &matchedPath) { - MS_ASSERT(graph != nullptr); - if (matchedPath.size() != CONV_SCALE_BIAS_MATCH_PATH_LEN) { - MS_LOG(ERROR) << "Conv-Scale-Bias-Fusion should have two NodeIndex in matchedPair"; - return RET_PARAM_INVALID; - } - - auto convPath = matchedPath[kConvName]; - MS_ASSERT(convPath != nullptr); - auto dstPath = matchedPath[DST_NAME]; - MS_ASSERT(dstPath != nullptr); - MS_ASSERT(subGraph != nullptr); - auto &convNode = graph->nodes.at(convPath->nodeIdx); - MS_ASSERT(convNode != nullptr); - auto &dstNode = graph->nodes.at(dstPath->nodeIdx); - MS_ASSERT(dstNode != nullptr); - - // 1. generate new weightTensor and biasTensor for conv - auto status = GenConvWeightTensors(graph, convPath, dstPath); - if (RET_OK != status) { - MS_LOG(ERROR) << "GenConvWeightTensors failed, " << status; - return status; - } - if (convNode->inputIndex.size() == CONV_OP_HAS_BIAS_INPUT_NUM) { - status = ReplaceTensorOfNode(graph, convPath->nodeIdx, convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT), - std::move(this->newWeightTensor)); - this->newWeightTensor = nullptr; - if (status != RET_OK) { - MS_LOG(ERROR) << "ReplaceTensorOfNode failed, subGraph: " << convPath->subGraphIdx - << ", node: " << convPath->nodeIdx << ", tensor " - << convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT) << ", error: " << status; - return status; - } - status = ReplaceTensorOfNode(graph, convPath->nodeIdx, convNode->inputIndex.at(CONV_OP_BIAS_INDEX_IN_INPUT), - std::move(this->newBiasTensor)); - this->newBiasTensor = nullptr; - if (status != RET_OK) { - MS_LOG(ERROR) << "ReplaceTensorOfNode failed, subGraph: " << convPath->subGraphIdx - << ", node: " << convPath->nodeIdx << ", tensor " - << convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT) << ", error: " << status; - return status; - } - } else if (convNode->inputIndex.size() == CONV_OP_NO_BIAS_INPUT_NUM) { - status = ReplaceTensorOfNode(graph, convPath->nodeIdx, convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT), - std::move(this->newWeightTensor)); - this->newWeightTensor = nullptr; - if (status != RET_OK) { - MS_LOG(ERROR) << "ReplaceTensorOfNode failed, subGraph: " << convPath->subGraphIdx - << ", node: " << convPath->nodeIdx << ", tensor " - << convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT) << ", error: " << status; - return status; - } - status = AddTensor2Node(graph, convPath->nodeIdx, std::move(this->newBiasTensor)); - this->newBiasTensor = nullptr; - if (status != RET_OK) { - MS_LOG(ERROR) << "ReplaceTensorOfNode failed, subGraph: " << convPath->subGraphIdx - << ", node: " << convPath->nodeIdx << ", tensor " - << convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT) << ", error: " << status; - return status; - } - // if (convNode->name == "Conv_461") { - // } - // add 
bias quantParam - // todo use tensor quant param - // if (convNode->quantParam.size() == convNode->inputIndex.size() + convNode->outputIndex.size() - 1) { - // std::unique_ptr quantParamArray(new QuantParamArrayT()); - // if (quantParamArray == nullptr) { - // MS_LOG(ERROR) << "new QuantParamArrayT failed"; - // return RET_ERROR; - // } - // std::unique_ptr quantParam(new QuantParamT()); - // if (quantParam == nullptr) { - // MS_LOG(ERROR) << "new QuantParamT failed"; - // return RET_ERROR; - // } - // quantParam->numBits = -1; - // quantParam->scale = FLT_MAX; - // quantParam->zeroPoint = 0; - // quantParam->narrowRange = true; - // quantParam->min = FLT_MAX; - // quantParam->max = FLT_MAX; - // quantParamArray->param.emplace_back(quantParam.release()); - // convNode->quantParam.emplace_back(quantParamArray.release()); - // } - } else { - MS_LOG(ERROR) << "Conv node should has 2 or 3 weight tensors rather than " << convNode->inputIndex.size(); - return RET_ERROR; - } - - // 2. change attr of conv - if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) { - convNode->primitive->value.AsConv2D()->hasBias = true; - } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { - convNode->primitive->value.AsDepthwiseConv2D()->hasBias = true; - } else { - MS_LOG(ERROR) << "Unsupported opType, " << convNode->primitive->value.type; - return RET_ERROR; - } - - // 3. delete DST node - MergeNodeAttrFromPost(convNode, dstNode); - status = IsolateOneWayNode(graph, dstPath->nodeIdx); - if (status != RET_OK) { - MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstPath->nodeIdx << ", error: " << status; - return status; - } - - return RET_OK; -} - -STATUS ConvScaleBiasFusionPass::GenConvWeightTensors(schema::MetaGraphT *graph, const std::shared_ptr &convPath, - std::shared_ptr dstPath) { - MS_ASSERT(graph != nullptr); - MS_ASSERT(convPath != nullptr); - MS_ASSERT(dstPath != nullptr); - MS_ASSERT(subGraph != nullptr); - auto &convNode = graph->nodes.at(convPath->nodeIdx); - MS_ASSERT(convNode != nullptr); - int32_t kernelNum = -1; - if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) { - kernelNum = convNode->primitive->value.AsConv2D()->channelOut; - } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { - kernelNum = convNode->primitive->value.AsDepthwiseConv2D()->channelMultiplier * - convNode->primitive->value.AsDepthwiseConv2D()->channelIn; - } else { - MS_LOG(ERROR) << "Unsupported opType, " << convNode->primitive->value.type; - return RET_ERROR; - } - if (kernelNum <= 0) { - MS_LOG(ERROR) << "KernelNum should be positive, " << kernelNum; - return RET_ERROR; - } - - this->transScale = new (std::nothrow) float[kernelNum]; - this->transBias = new (std::nothrow) float[kernelNum]; - - if (transScale == nullptr) { - MS_LOG(ERROR) << "new transScale failed"; - return RET_ERROR; - } - - if (transBias == nullptr) { - MS_LOG(ERROR) << "new transBias failed"; - return RET_ERROR; - } - - if (0 != memset_s(transScale, kernelNum * sizeof(float), 0, kernelNum * sizeof(float))) { - MS_LOG(ERROR) << "memset transScale failed"; - return RET_ERROR; - } - - if (0 != memset_s(transBias, kernelNum * sizeof(float), 0, kernelNum * sizeof(float))) { - MS_LOG(ERROR) << "memset transBias failed"; - return RET_ERROR; - } - - auto status = GetTransParam(graph, dstPath, kernelNum); - if (RET_OK != status) { - MS_LOG(ERROR) << "GetTransParam failed, " << status; - return status; - } - - status = CalConvWeightTensors(graph, convPath, 
kernelNum); - if (RET_OK != status) { - MS_LOG(ERROR) << "GenConvWeightTensors failed, " << status; - return status; - } - return RET_OK; -} - -STATUS ConvScaleBiasFusionPass::CalNewWeightTensor(TensorT *oldWeightTensor, const int32_t kernelNum, - const size_t kernelSize) { - MS_ASSERT(oldWeightTensor != nullptr); - auto weightData = reinterpret_cast(oldWeightTensor->data.data()); - size_t kernelDataCount = kernelNum * kernelSize; - if (kernelDataCount == 0) { - MS_LOG(ERROR) << "KernelDataCount should be positive, " << kernelDataCount; - return RET_ERROR; - } - this->newWeightData = new (std::nothrow) float[kernelDataCount]; - if (newWeightData == nullptr) { - MS_LOG(ERROR) << "new newWeightData failed"; - return RET_ERROR; - } - - if (0 != memset_s(newWeightData, kernelDataCount * sizeof(float), 0, kernelDataCount * sizeof(float))) { - MS_LOG(ERROR) << "memset newWeightData failed"; - return RET_ERROR; - } - - for (size_t i = 0; i < kernelNum; i++) { - for (size_t j = 0; j < kernelSize; j++) { - newWeightData[i * kernelSize + j] = weightData[i * kernelSize + j] * transScale[i]; - } - } - auto newCharWeightData = reinterpret_cast(newWeightData); - std::vector tmpWeightVec(newCharWeightData, - newCharWeightData + kernelDataCount * sizeof(float) / sizeof(uint8_t)); - - this->newWeightTensor = std::unique_ptr(new (std::nothrow) TensorT); - if (this->newWeightTensor == nullptr) { - MS_LOG(ERROR) << "new newWeightTensor failed"; - return RET_ERROR; - } - this->newWeightTensor->dims.insert(this->newWeightTensor->dims.begin(), oldWeightTensor->dims.begin(), - oldWeightTensor->dims.end()); - this->newWeightTensor->dataType = oldWeightTensor->dataType; - this->newWeightTensor->format = oldWeightTensor->format; - this->newWeightTensor->refCount = oldWeightTensor->refCount; - this->newWeightTensor->data.swap(tmpWeightVec); - delete (this->newWeightData); - newWeightData = nullptr; - - return RET_OK; -} - -STATUS ConvScaleBiasFusionPass::CalNewBiasTensor(TensorT *oldWeightTensor, TensorT *oldBiasTensor, - const int32_t kernelNum) { - MS_ASSERT(oldWeightTensor != nullptr); - this->newBiasData = new (std::nothrow) float[kernelNum]; - if (newBiasData == nullptr) { - MS_LOG(ERROR) << "new newBiasData failed"; - return RET_ERROR; - } - if (0 != memset_s(newBiasData, kernelNum * sizeof(float), 0, kernelNum * sizeof(float))) { - MS_LOG(ERROR) << "memset newBiasData failed"; - return RET_ERROR; - } - - if (oldBiasTensor != nullptr) { - auto *biasData = reinterpret_cast(oldBiasTensor->data.data()); - - for (size_t i = 0; i < kernelNum; i++) { - this->newBiasData[i] = biasData[i] * transScale[i] + transBias[i]; - } - } else { - if (0 != memcpy_s(newBiasData, kernelNum * sizeof(float), transBias, kernelNum * sizeof(float))) { - MS_LOG(ERROR) << "memcpy_s newBiasData failed"; - return RET_ERROR; - } - } - auto *newCharBiasData = reinterpret_cast(newBiasData); - std::vector tmpBiasVec(newCharBiasData, newCharBiasData + kernelNum * sizeof(float) / sizeof(uint8_t)); - - this->newBiasTensor = std::unique_ptr(new (std::nothrow) TensorT); - if (this->newBiasTensor == nullptr) { - MS_LOG(ERROR) << "new newBiasTensor failed"; - return RET_ERROR; - } - // todo biasShape - this->newBiasTensor->dims = {kernelNum}; - this->newBiasTensor->dataType = oldWeightTensor->dataType; - this->newBiasTensor->format = oldWeightTensor->format; - this->newBiasTensor->refCount = oldWeightTensor->refCount; - this->newBiasTensor->data.swap(tmpBiasVec); - delete (this->newBiasData); - newCharBiasData = nullptr; - newBiasData = nullptr; - 
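// The two helpers above (CalNewWeightTensor / CalNewBiasTensor) implement the usual per-output-channel fold of a trailing Scale/BatchNorm into the convolution: W'[k][j] = W[k][j] * transScale[k] and b'[k] = b[k] * transScale[k] + transBias[k], with b == 0 when the conv has no bias input. A minimal standalone sketch of the same arithmetic; FoldScaleIntoConv and its parameter names are illustrative, not from this patch, and plain std::vector buffers stand in for TensorT:
#include <cstddef>
#include <vector>
// Folds per-channel scale/shift into conv weights and bias.
// weights: kernelNum x kernelSize, row-major; bias: kernelNum entries, may be empty.
void FoldScaleIntoConv(std::vector<float> *weights, std::vector<float> *bias,
                       const std::vector<float> &scale, const std::vector<float> &shift,
                       std::size_t kernelNum, std::size_t kernelSize) {
  if (bias->empty()) {
    bias->assign(kernelNum, 0.0f);  // conv without bias: fold into a fresh zero bias
  }
  for (std::size_t k = 0; k < kernelNum; ++k) {
    for (std::size_t j = 0; j < kernelSize; ++j) {
      (*weights)[k * kernelSize + j] *= scale[k];  // W'[k][j] = W[k][j] * s[k]
    }
    (*bias)[k] = (*bias)[k] * scale[k] + shift[k];  // b'[k] = b[k] * s[k] + t[k]
  }
}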
return RET_OK; -} - -STATUS ConvScaleBiasFusionPass::CalConvWeightTensors(schema::MetaGraphT *graph, const std::shared_ptr &convPath, - int32_t kernelNum) { - MS_ASSERT(graph != nullptr); - MS_ASSERT(convPath != nullptr); - - auto convNode = graph->nodes.at(convPath->nodeIdx).get(); - MS_ASSERT(convNode != nullptr); - auto convWeightTensorIdxes = convNode->inputIndex; - convWeightTensorIdxes.erase(convWeightTensorIdxes.begin()); - - TensorT *weightTensor = nullptr; - TensorT *biasTensor = nullptr; - if (convWeightTensorIdxes.size() == CONV_OP_NO_BIAS_WEIGHT_NUM) { - weightTensor = graph->allTensors.at(convWeightTensorIdxes[CONV_OP_FILTER_INDEX_IN_WEIGHT]).get(); - } else if (convWeightTensorIdxes.size() == CONV_OP_HAS_BIAS_WEIGHT_NUM) { - weightTensor = graph->allTensors.at(convWeightTensorIdxes[CONV_OP_FILTER_INDEX_IN_WEIGHT]).get(); - biasTensor = graph->allTensors.at(convWeightTensorIdxes[CONV_OP_BIAS_INDEX_IN_WEIGHT]).get(); - } else { - MS_LOG(ERROR) << "Conv2D should has " << CONV_OP_NO_BIAS_WEIGHT_NUM << " or " << CONV_OP_HAS_BIAS_WEIGHT_NUM - << " weight tensors, current number of weight tensors " << convWeightTensorIdxes.size(); - return RET_ERROR; - } - if (weightTensor == nullptr) { - MS_LOG(ERROR) << "Conv2D's weight tensor is nullptr"; - return RET_ERROR; - } - - auto weightShape = weightTensor->dims; - if (weightShape.size() != CONV_FILTER_SHAPE_SIZE) { - MS_LOG(ERROR) << "Size of dims of weight tensor should be " << CONV_FILTER_SHAPE_SIZE << " rather than " - << weightShape.size(); - return RET_ERROR; - } - size_t kernelSize = GetShapeSize(*weightTensor) / kernelNum; - - // cal new weightData - auto status = CalNewWeightTensor(weightTensor, kernelNum, kernelSize); - if (status != RET_OK) { - MS_LOG(ERROR) << "CalNewWeightTensor error " << status; - return status; - } - // cal new biasData - status = CalNewBiasTensor(weightTensor, biasTensor, kernelNum); - if (status != RET_OK) { - MS_LOG(ERROR) << "CalNewBiasTensor error " << status; - return status; - } - return RET_OK; -} - -STATUS ConvScaleBiasFusionPass::Run(schema::MetaGraphT *graph) { return FusionPass::Run(graph); } - -ConvScaleBiasFusionPass::~ConvScaleBiasFusionPass() { - if (this->transScale != nullptr) { - delete (this->transScale); - } - if (this->transBias != nullptr) { - delete (this->transBias); - } - if (this->newWeightData != nullptr) { - delete (this->newWeightData); - } - if (this->newBiasData != nullptr) { - delete (this->newBiasData); - } -} - -} // namespace lite -} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.h deleted file mode 100644 index 1a9ce07b06..0000000000 --- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.h +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2018-2019. All rights reserved. 
- * Description: mslite - * Author: mslite - * Create: 2019-12-13 - */ - -#ifndef MINDSPORE_PREDICT_CONV_SCALE_BIAS_FUSION_PASS_H -#define MINDSPORE_PREDICT_CONV_SCALE_BIAS_FUSION_PASS_H - -#include -#include -#include -#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" - -namespace mindspore { -namespace lite { -struct BNWeightTensors { - schema::TensorT *meanTensor = nullptr; - schema::TensorT *varianceTensor = nullptr; - schema::TensorT *scaleTensor = nullptr; - schema::TensorT *biasTensor = nullptr; -}; - -class ConvScaleBiasFusionPass : public FusionPass { - public: - ConvScaleBiasFusionPass() = default; - - ~ConvScaleBiasFusionPass() override; - - STATUS DefinePattern() override = 0; - - // 1. generate biasTensor according to BN weightTensor - // 2. change attr of conv - // 3. delete BN node - STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, - std::unordered_map> &matchedPath) override; - - STATUS Run(schema::MetaGraphT *graph) override; - - protected: - // call GetTransParam() and CalConvWeightTensors() - STATUS GenConvWeightTensors(schema::MetaGraphT *graph, const std::shared_ptr &convPath, - std::shared_ptr dstPath); - - // fill this->transScale and this->transBias - virtual STATUS GetTransParam(schema::MetaGraphT *graph, std::shared_ptr dstPath, int32_t kernelNum) = 0; - - // fill this->newWeightTensor and this->newBiasTensor according to this->transScale and this->transBias - STATUS CalConvWeightTensors(schema::MetaGraphT *graph, const std::shared_ptr &convPath, int32_t kernelNum); - - STATUS CalNewWeightTensor(schema::TensorT *oldWeightTensor, int32_t kernelNum, size_t kernelSize); - - STATUS CalNewBiasTensor(schema::TensorT *oldWeightTensor, schema::TensorT *oldBiasTensor, int32_t kernelNum); - - protected: - float *transScale = nullptr; - float *transBias = nullptr; - float *newWeightData = nullptr; - float *newBiasData = nullptr; - std::unique_ptr newWeightTensor = nullptr; - std::unique_ptr newBiasTensor = nullptr; -}; -} // namespace lite -} // namespace mindspore - -#endif // MINDSPORE_PREDICT_CONV_SCALE_BIAS_FUSION_PASS_H diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.cc deleted file mode 100644 index 56d5b3d262..0000000000 --- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.cc +++ /dev/null @@ -1,126 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include -#include -#include "tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.h" -#include "securec/include/securec.h" -#include "utils/log_adapter.h" -#include "include/errorcode.h" -#include "schema/inner/model_generated.h" - -namespace mindspore { -namespace lite { -#define SCALE_OP_NO_BIAS_WEIGHT_NUM 1 -#define SCALE_OP_HAS_BIAS_WEIGHT_NUM 2 - -#define SCALE_OP_SCALE_INDEX_IN_WEIGHT 0 -#define SCALE_OP_BIAS_INDEX_IN_WEIGHT 1 - -STATUS ConvScaleFusionPass::DefinePattern() { - auto convOp = std::make_shared(); - convOp->id = kConvName; - convOp->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; - auto scaleOp = std::make_shared(); - scaleOp->id = DST_NAME; - scaleOp->types = {schema::PrimitiveType_Scale}; - scaleOp->left = convOp; - - std::unique_ptr fusionPattern(new (std::nothrow) FusionPattern("ConvScaleFusion")); - if (fusionPattern == nullptr) { - MS_LOG(ERROR) << "new fusionPattern failed"; - return RET_ERROR; - } - fusionPattern->AddPatternOp(convOp); - fusionPattern->AddPatternOp(scaleOp); - fusionPattern->Finish(); - - this->patterns.emplace_back(fusionPattern.release()); - - return RET_OK; -} - -STATUS ConvScaleFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, - std::unordered_map> &matchedPath) { - return ConvScaleBiasFusionPass::DoFusion(graph, patternName, matchedPath); -} - -STATUS ConvScaleFusionPass::Run(schema::MetaGraphT *graph) { return ConvScaleBiasFusionPass::Run(graph); } - -STATUS ConvScaleFusionPass::GetTransParam(schema::MetaGraphT *graph, std::shared_ptr scalePath, - int32_t kernelNum) { - MS_ASSERT(graph != nullptr); - MS_ASSERT(scalePath != nullptr); - - auto scaleNode = graph->nodes.at(scalePath->nodeIdx).get(); - MS_ASSERT(scaleNode != nullptr); - auto scaleWeightTensorIdxes = scaleNode->inputIndex; - scaleWeightTensorIdxes.erase(scaleWeightTensorIdxes.begin()); - - schema::TensorT *scaleTensor = nullptr; - schema::TensorT *biasTensor = nullptr; - - if (scaleWeightTensorIdxes.size() == SCALE_OP_NO_BIAS_WEIGHT_NUM) { - scaleTensor = graph->allTensors.at(scaleWeightTensorIdxes[SCALE_OP_SCALE_INDEX_IN_WEIGHT]).get(); - } else if (scaleWeightTensorIdxes.size() == SCALE_OP_HAS_BIAS_WEIGHT_NUM) { - scaleTensor = graph->allTensors.at(scaleWeightTensorIdxes[SCALE_OP_SCALE_INDEX_IN_WEIGHT]).get(); - biasTensor = graph->allTensors.at(scaleWeightTensorIdxes[SCALE_OP_BIAS_INDEX_IN_WEIGHT]).get(); - } else { - MS_LOG(ERROR) << "Scale should has %d or %d weight tensors, current number of weight tensors %zu"; - // SCALE_OP_NO_BIAS_WEIGHT_NUM, SCALE_OP_HAS_BIAS_WEIGHT_NUM, scaleWeightTensorIdxes.size()); - return RET_ERROR; - } - - if (scaleTensor == nullptr) { - MS_LOG(ERROR) << "Scale's scale tensor is nullptr"; - return RET_ERROR; - } - - if (kernelNum != scaleTensor->data.size() * sizeof(uint8_t) / sizeof(float)) { - MS_LOG(ERROR) << "conv kernel num %u is expected to be equal to scale size(%lu)"; - //, kernelNum, scaleTensor->data.size() * sizeof(uint8_t) / sizeof(float)); - return RET_ERROR; - } - - const float *scaleData = reinterpret_cast(scaleTensor->data.data()); - - if (0 != memcpy_s(transScale, kernelNum * sizeof(float), scaleData, kernelNum * sizeof(float))) { - MS_LOG(ERROR) << "memcpy_s transScale failed"; - return RET_ERROR; - } - - if (biasTensor != nullptr) { - if (kernelNum != biasTensor->data.size() * sizeof(uint8_t) / sizeof(float)) { - MS_LOG(ERROR) << "conv kernel num %u is expected to be equal to bias size(%lu)"; - //, kernelNum, biasTensor->data.size() * 
sizeof(uint8_t) / sizeof(float)); - return RET_ERROR; - } - - const float *biasData = reinterpret_cast(biasTensor->data.data()); - - if (0 != memcpy_s(transBias, kernelNum * sizeof(float), biasData, kernelNum * sizeof(float))) { - MS_LOG(ERROR) << "memcpy_s transBias failed"; - return RET_ERROR; - } - } - - return RET_OK; -} -} // namespace lite -} // namespace mindspore - - diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.h deleted file mode 100644 index 8c2ed2808c..0000000000 --- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_PREDICT_CONV_SCALE_FUSION_PASS_H -#define MINDSPORE_PREDICT_CONV_SCALE_FUSION_PASS_H - -#include "tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.h" -#include -#include -#include - -namespace mindspore { -namespace lite { -class ConvScaleFusionPass : public ConvScaleBiasFusionPass { - public: - ConvScaleFusionPass() = default; - - ~ConvScaleFusionPass() override = default; - - STATUS DefinePattern() override; - - STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, - std::unordered_map> &matchedPath) override; - - STATUS Run(schema::MetaGraphT *graph) override; - - private: - STATUS GetTransParam(schema::MetaGraphT *graph, std::shared_ptr scalePath, int32_t kernelNum) override; -}; -} // namespace lite -} // namespace mindspore - -#endif // MINDSPORE_PREDICT_CONV_SCALE_FUSION_PASS_H diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt index c519b768ca..cbc98adc13 100755 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt @@ -4,6 +4,8 @@ add_library(graph_pass_mid OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/dtype_trans_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/isolated_node_remove_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/model_input_format_preprocess_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/weight_format_hardcode_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/weight_format_transform_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/topological_sort_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/unused_node_remove_pass.cc ) diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.cc index f84d4859aa..30fcea3566 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.cc @@ -25,147 +25,134 @@ namespace mindspore { namespace lite { - -STATUS EltwiseFormatTransPass::Run(schema::MetaGraphT *graph) { - MS_ASSERT(graph != nullptr); - 
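// The hunk below hoists this inline eligibility scan out of Run() into CanFusion(). The rule it preserves: an Eltwise node is only rewritten when enough of its neighbouring nodes are already Nchw2Nhwc/Nhwc2Nchw transposes; an even neighbour count demands a strict majority, an odd count accepts reaching floor(total/2). A sketch of just that counting rule (EnoughTransNeighbours is a hypothetical free function, not part of the patch):
#include <cstddef>
// true if the neighbouring format-transpose nodes justify fusing the Eltwise node.
bool EnoughTransNeighbours(std::size_t input_nodes, std::size_t output_nodes, std::size_t trans_count) {
  const std::size_t total = input_nodes + output_nodes;
  const std::size_t half = total / 2;
  return (total % 2 == 0) ? trans_count > half : trans_count >= half;
}
// e.g. 3 inputs + 1 output with 2 transpose neighbours: total = 4, half = 2,
// and 2 > 2 is false, so the Eltwise node is left untouched.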
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { - auto &node = *iter; - if (node->primitive->value.type != PrimitiveType_Eltwise) { - continue; - } - auto node_name = node->name; - auto input_node_indexes = GetInputNodeIdx(*graph, *node); - auto pre_type = schema::PrimitiveType_NONE; - size_t has_trans_count = 0; - auto can_fusion = true; - for (auto input_node_index : input_node_indexes) { - MS_ASSERT(graph->nodes.size() > input_node_index); - auto &pre_node = graph->nodes.at(input_node_index); - MS_ASSERT(pre_node != nullptr); - if (pre_type == schema::PrimitiveType_NONE) { - if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || - pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { - pre_type = pre_node->primitive->value.type; +bool EltwiseFormatTransPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr &node) { + auto input_node_indexes = GetInputNodeIdx(*graph, *node); + pre_type_ = schema::PrimitiveType_NONE; + size_t has_trans_count = 0; + auto can_fusion = true; + for (auto input_node_index : input_node_indexes) { + MS_ASSERT(graph->nodes.size() > input_node_index); + auto &pre_node = graph->nodes.at(input_node_index); + MS_ASSERT(pre_node != nullptr); + if (pre_type_ == schema::PrimitiveType_NONE) { + if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || + pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { + pre_type_ = pre_node->primitive->value.type; + has_trans_count++; + } + } else { + if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || + pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { + if (pre_type_ != pre_node->primitive->value.type) { + can_fusion = false; + break; + } else { has_trans_count++; } - } else { - if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || - pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { - if (pre_type != pre_node->primitive->value.type) { - can_fusion = false; - break; - } else { - has_trans_count++; - } - } } } - if (!can_fusion) { - continue; - } - auto output_node_indexes = GetOutputNodeIdx(*graph, *node); - auto post_type = schema::PrimitiveType_NONE; - for (auto output_node_index : output_node_indexes) { - MS_ASSERT(graph->nodes.size() > output_node_index); - auto &post_node = graph->nodes.at(output_node_index); - MS_ASSERT(post_node != nullptr); - if (post_type == schema::PrimitiveType_NONE) { - if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || - post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { - post_type = post_node->primitive->value.type; + } + if (!can_fusion) { + return false; + } + auto output_node_indexes = GetOutputNodeIdx(*graph, *node); + post_type_ = schema::PrimitiveType_NONE; + for (auto output_node_index : output_node_indexes) { + MS_ASSERT(graph->nodes.size() > output_node_index); + auto &post_node = graph->nodes.at(output_node_index); + MS_ASSERT(post_node != nullptr); + if (post_type_ == schema::PrimitiveType_NONE) { + if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || + post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { + post_type_ = post_node->primitive->value.type; + has_trans_count++; + } + } else { + if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || + post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { + if (post_type_ != post_node->primitive->value.type) { + can_fusion = false; + break; + 
} else { has_trans_count++; } - } else { - if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || - post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { - if (post_type != post_node->primitive->value.type) { - can_fusion = false; - break; - } else { - has_trans_count++; - } - } } } - if (!can_fusion) { - continue; - } - auto total_node_count = input_node_indexes.size() + output_node_indexes.size(); - size_t half_count = total_node_count / 2; - if (total_node_count % 2 == 0) { - can_fusion = has_trans_count > half_count; - } else { - can_fusion = has_trans_count >= half_count; + } + if (!can_fusion) { + return false; + } + if (pre_type_ == PrimitiveType_NONE && post_type_ == PrimitiveType_NONE) { + return false; + } + auto total_node_count = input_node_indexes.size() + output_node_indexes.size(); + size_t half_count = total_node_count / 2; + if (total_node_count % 2 == 0) { + can_fusion = has_trans_count > half_count; + } else { + can_fusion = has_trans_count >= half_count; + } + return can_fusion; +} + +STATUS EltwiseFormatTransPass::FindOutTransType() { + pre_insert_trans_type_ = kNHWC2NCHW; + post_insert_trans_type_ = kNHWC2NCHW; + if (pre_type_ == PrimitiveType_NONE && post_type_ != PrimitiveType_NONE) { + pre_insert_trans_type_ = post_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC; + post_insert_trans_type_ = post_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; + } else if (pre_type_ != PrimitiveType_NONE && post_type_ == PrimitiveType_NONE) { + pre_insert_trans_type_ = pre_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; + post_insert_trans_type_ = pre_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC; + } else if (pre_type_ == PrimitiveType_NONE && post_type_ == PrimitiveType_NONE) { + MS_ASSERT(false); + } else { + if (pre_type_ == post_type_) { + MS_LOG(ERROR) << "Unknow error"; + return RET_ERROR; } - if (!can_fusion) { + pre_insert_trans_type_ = pre_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; + post_insert_trans_type_ = post_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; + } + return RET_OK; +} + +STATUS EltwiseFormatTransPass::Run(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { + auto &node = *iter; + if (node->primitive->value.type != PrimitiveType_Eltwise) { continue; } - FormatTransNodeType pre_insert_trans_type = kNHWC2NCHW; - FormatTransNodeType post_insert_trans_type = kNHWC2NCHW; - if (pre_type == PrimitiveType_NONE && post_type != PrimitiveType_NONE) { - pre_insert_trans_type = post_type == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC; - post_insert_trans_type = post_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; - } else if (pre_type != PrimitiveType_NONE && post_type == PrimitiveType_NONE) { - pre_insert_trans_type = pre_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; - post_insert_trans_type = pre_type == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC; - } else if (pre_type == PrimitiveType_NONE && post_type == PrimitiveType_NONE) { + auto node_name = node->name; + if (!CanFusion(graph, node)) { continue; - } else { - if (pre_type == post_type) { - MS_LOG(ERROR) << "Unknow error"; - return RET_ERROR; - } - pre_insert_trans_type = pre_type == schema::PrimitiveType_Nhwc2Nchw ? 
kNCHW2NHWC : kNHWC2NCHW; - post_insert_trans_type = post_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; + } + auto ret = FindOutTransType(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "FindOutTransType error"; + return ret; } STATUS status = RET_OK; auto input_tensor_size = (*iter)->inputIndex.size(); for (auto i = 0; i < input_tensor_size; i++) { - iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type, &status); + iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type_, &status); if (status != RET_OK) { - MS_LOG(ERROR) << "Insert" << pre_insert_trans_type << "before " << (*iter)->name << " failed"; + MS_LOG(ERROR) << "Insert" << pre_insert_trans_type_ << "before " << (*iter)->name << " failed"; return status; } } auto output_tensor_size = (*iter)->outputIndex.size(); for (auto i = 0; i < output_tensor_size; i++) { - iter = InsertFormatTransNode(graph, iter, kAfter, i, post_insert_trans_type, &status); + iter = InsertFormatTransNode(graph, iter, kAfter, i, post_insert_trans_type_, &status); if (status != RET_OK) { - MS_LOG(ERROR) << "Insert" << post_insert_trans_type << "Node before " << (*iter)->name << " failed"; + MS_LOG(ERROR) << "Insert" << post_insert_trans_type_ << "Node before " << (*iter)->name << " failed"; return status; } } } return RET_OK; } - -NodeIter EltwiseFormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, - InsertPlace place, size_t inoutIdx, FormatTransNodeType nodeType, - STATUS *errorCode) { - MS_ASSERT((*existNodeIter) != nullptr); - auto existNodeName = (*existNodeIter)->name; - std::string tileName; - if (place == kBefore) { - tileName = existNodeName + "_pre"; - } else { - tileName = existNodeName + "_post"; - } - auto transNode = std::make_unique(); - transNode->primitive = std::make_unique(); - - if (nodeType == kNCHW2NHWC) { - transNode->name = "nchw2nhwc_" + tileName + std::to_string(id++); - transNode->primitive->value.type = schema::PrimitiveType_Nchw2Nhwc; - } else { - transNode->name = "nhwc2nchw_" + tileName + std::to_string(id++); - transNode->primitive->value.type = schema::PrimitiveType_Nhwc2Nchw; - } - return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode); -} - -void EltwiseFormatTransPass::SetQuantType(QuantType quantType) { this->quantType = quantType; } - -void EltwiseFormatTransPass::SetFmk(converter::FmkType fmkType) { this->fmkType = fmkType; } } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.h index a8de9c9797..3580aa9260 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.h @@ -17,7 +17,7 @@ #ifndef MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H #define MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H -#include "tools/converter/optimizer.h" +#include #include "tools/common/graph_util.h" #include "tools/converter/converter_flags.h" #include "tools/converter/legacy_optimizer/graph/format_trans_pass.h" @@ -25,26 +25,24 @@ namespace mindspore { namespace lite { -class EltwiseFormatTransPass : public GraphPass { +class EltwiseFormatTransPass : public FormatTransPass { public: - EltwiseFormatTransPass() : id(0) {} + EltwiseFormatTransPass() : FormatTransPass() {} ~EltwiseFormatTransPass() override = default; STATUS 
Run(schema::MetaGraphT *graph) override; - void SetQuantType(QuantType quantType); - - void SetFmk(converter::FmkType fmkType); - private: - NodeIter InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, - FormatTransNodeType nodeType, STATUS *errorCode); + bool CanFusion(schema::MetaGraphT *graph, const std::unique_ptr &node); + + STATUS FindOutTransType(); private: - size_t id; - QuantType quantType = QuantType_QUANT_NONE; - converter::FmkType fmkType = converter::FmkType_TF; + FormatTransNodeType pre_insert_trans_type_ = kNHWC2NCHW; + FormatTransNodeType post_insert_trans_type_ = kNHWC2NCHW; + schema::PrimitiveType pre_type_ = schema::PrimitiveType_NONE; + schema::PrimitiveType post_type_ = schema::PrimitiveType_NONE; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.h index 2fc754a36d..cc8ce850b7 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.h @@ -37,16 +37,19 @@ class FormatTransPass : public GraphPass { void SetFmk(converter::FmkType fmkType); + protected: + NodeIter InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, + FormatTransNodeType nodeType, STATUS *errorCode); + private: STATUS DoModelInputFormatTrans(schema::MetaGraphT *graph); STATUS DoNodeInoutFormatTrans(schema::MetaGraphT *graph); - NodeIter InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, - FormatTransNodeType nodeType, STATUS *errorCode); + protected: + size_t id = 0; private: - size_t id; QuantType quantType = QuantType_QUANT_NONE; converter::FmkType fmkType = converter::FmkType_TF; }; @@ -54,4 +57,3 @@ class FormatTransPass : public GraphPass { } // namespace mindspore #endif // MINDSPORE_PREDICT_FORMAT_TRANS_PASS_H - diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.cc new file mode 100644 index 0000000000..b8aa778b90 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.cc @@ -0,0 +1,213 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h" +#include "tools/common/converter_op_utils.h" +#include "utils/log_adapter.h" +#include "src/common/utils.h" + +namespace mindspore { +namespace lite { +void WeightFormatHardCodePass::SetQuantType(QuantType quantType) { this->quantType = quantType; } + +void WeightFormatHardCodePass::SetFmkType(converter::FmkType fmkType) { this->fmkType = fmkType; } + +// pre set tensor format +// non quant, filterFormat: +// conv deconv depth dedepth +// caffe K(C/g)HW C(K/g)HW / / // todo with deconvOp +// tf HWCK HWKC HWCK HWKC +// onnx K(C/g)HW C(K/g)HW / / + +// aware quant, filterFormat: +// conv deconv depth dedepth +// onnx KHWC ? CHWK ? +// tf HWCK ? HWCK ? +STATUS WeightFormatHardCodePass::Run(MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + for (auto &node : graph->nodes) { + MS_ASSERT(node != nullptr); + MS_ASSERT(node->primitive != nullptr); + auto opType = node->primitive->value.type; + if (opType != PrimitiveType_Conv2D && opType != PrimitiveType_DepthwiseConv2D && opType != PrimitiveType_DeConv2D && + opType != PrimitiveType_DeDepthwiseConv2D) { + continue; + } + MS_ASSERT(node->inputIndex.size() >= 2); + auto weightIndex = node->inputIndex.at(1); + MS_ASSERT(graph->allTensors.size() > weightIndex); + auto &weightTensor = graph->allTensors[weightIndex]; + MS_ASSERT(weightTensor->dims.size() == 4 || weightTensor->dims.empty()); // for conv with fakeQuant before weight + STATUS status; + switch (fmkType) { + case converter::FmkType_CAFFE: + status = HardCodeCAFFE(node, weightTensor); + break; + case converter::FmkType_TFLITE: + status = HardCodeTFLITE(node, weightTensor); + break; + case converter::FmkType_ONNX: + status = HardCodeONNX(node, weightTensor); + break; + case converter::FmkType_MS: + status = HardCodeMS(node, weightTensor); + break; + default: + MS_LOG(ERROR) << "Unsupported fmkType: " << fmkType << ", node: " << node->name; + return RET_ERROR; + } + if (status != RET_OK) { + MS_LOG(ERROR) << "Format hardcode failed: " << status << ", node: " << node->name; + return RET_ERROR; + } + } + return RET_OK; +} + +STATUS WeightFormatHardCodePass::HardCodeCAFFE(const std::unique_ptr &node, + const std::unique_ptr &weightTensor) { + MS_ASSERT(node != nullptr); + MS_ASSERT(weightTensor != nullptr); + MS_ASSERT(node != nullptr); + MS_ASSERT(node->primitive != nullptr); + auto opType = node->primitive->value.type; + switch (this->quantType) { + case QuantType_QUANT_NONE: { + if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D || + opType == schema::PrimitiveType_DeConv2D || opType == schema::PrimitiveType_DeDepthwiseConv2D) { + weightTensor->format = Format_KCHW; + } else { + MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name; + } + } break; + default: { + MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(node->quantType) << ", node: " << node->name; + return RET_ERROR; + } + } + return RET_OK; +} + +STATUS WeightFormatHardCodePass::HardCodeONNX(const std::unique_ptr &node, + const std::unique_ptr &weightTensor) { + MS_ASSERT(node != nullptr); + MS_ASSERT(weightTensor != nullptr); + MS_ASSERT(node != nullptr); + MS_ASSERT(node->primitive != nullptr); + auto opType = node->primitive->value.type; + switch (this->quantType) { + case QuantType_AwareTraining: { + // sum up from current onnx quant models + if (opType == PrimitiveType_Conv2D) { + weightTensor->format = Format_KHWC; + } else if (opType == PrimitiveType_DepthwiseConv2D) { + weightTensor->format = Format_CHWK; + } else if (opType == PrimitiveType_DeConv2D) { + weightTensor->format = Format_CKHW; + } else { + MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name; + return RET_ERROR; + } + } break; + case QuantType_QUANT_NONE: { + // conv (K x C/group x kH x kW) group = 1 + // depth (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W) + // deconv (C x K/group x kH x kW) group = 1 + // dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W) + if (opType == PrimitiveType_Conv2D || opType == PrimitiveType_DepthwiseConv2D) { + weightTensor->format = Format_KCHW; + } else if (opType == PrimitiveType_DeConv2D) { + weightTensor->format = Format_CKHW; + } else { + MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name; + return RET_ERROR; + } + } break; + default: { + MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(node->quantType) << ", node: " << node->name; + return RET_ERROR; + } + } + return RET_OK; +} + +STATUS WeightFormatHardCodePass::HardCodeMS(const std::unique_ptr &node, + const std::unique_ptr &weightTensor) { + MS_ASSERT(node != nullptr); + MS_ASSERT(weightTensor != nullptr); + MS_ASSERT(node != nullptr); + MS_ASSERT(node->primitive != nullptr); + auto opType = node->primitive->value.type; + switch (this->quantType) { + case QuantType_AwareTraining: { + if (opType == schema::PrimitiveType_Conv2D) { + weightTensor->format = schema::Format_HWCK; + } else if (opType == PrimitiveType_DepthwiseConv2D) { + weightTensor->format = Format_CKHW; + } else { + weightTensor->format = schema::Format_HWKC; + } + } break; + case QuantType_QUANT_NONE: { + // sum up from current ms quant models + if (opType == PrimitiveType_Conv2D) { + weightTensor->format = Format_KCHW; + } else if (opType == PrimitiveType_DepthwiseConv2D) { + weightTensor->format = Format_CKHW; + } else { + MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name; + return RET_ERROR; + } + } break; + default: { + MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(node->quantType) << ", node: " << node->name; + return RET_ERROR; + } + } + return RET_OK; +} + +STATUS WeightFormatHardCodePass::HardCodeTFLITE(const std::unique_ptr &node, + const std::unique_ptr &weightTensor) { + MS_ASSERT(node != nullptr); + MS_ASSERT(weightTensor != nullptr); + MS_ASSERT(node != nullptr); + MS_ASSERT(node->primitive != nullptr); + auto opType = node->primitive->value.type; + switch (this->quantType) { + case QuantType_AwareTraining: + case QuantType_PostTraining: + case QuantType_QUANT_NONE: { + if (opType == schema::PrimitiveType_Conv2D) { + weightTensor->format = schema::Format_KHWC; + } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { + weightTensor->format = schema::Format_CHWK; + } else if (opType == schema::PrimitiveType_DeConv2D) { + weightTensor->format = schema::Format_CHWK; + } else { + MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name; + return RET_ERROR; + } + } break; + default: { + MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(node->quantType) << ", node: " << node->name; + return RET_ERROR; + } + } + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h 
b/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h new file mode 100644 index 0000000000..9abdc79efb --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_HARDCODE_PASS_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_HARDCODE_PASS_H + +#include +#include "tools/converter/converter_flags.h" +#include "tools/converter/optimizer.h" +#include "tools/common/graph_util.h" + +namespace mindspore { +namespace lite { +class WeightFormatHardCodePass : public GraphPass { + public: + WeightFormatHardCodePass() = default; + + ~WeightFormatHardCodePass() override = default; + + void SetQuantType(QuantType quantType); + + void SetFmkType(converter::FmkType fmkType); + + STATUS Run(MetaGraphT *graph) override; + + private: + STATUS HardCodeCAFFE(const std::unique_ptr &node, const std::unique_ptr &weightTensor); + STATUS HardCodeTFLITE(const std::unique_ptr &node, const std::unique_ptr &weightTensor); + STATUS HardCodeONNX(const std::unique_ptr &node, const std::unique_ptr &weightTensor); + STATUS HardCodeMS(const std::unique_ptr &node, const std::unique_ptr &weightTensor); + + private: + QuantType quantType = QuantType_QUANT_NONE; + converter::FmkType fmkType = converter::FmkType_TF; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_HARDCODE_PASS_H diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.cc new file mode 100644 index 0000000000..62c25aa673 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.cc @@ -0,0 +1,140 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h" +#include +#include "tools/common/node_util.h" +#include "tools/common/converter_op_utils.h" +#include "utils/log_adapter.h" +#include "src/common/utils.h" + +namespace mindspore { +namespace lite { +void WeightFormatTransformPass::SetQuantType(QuantType quantType) { this->quantType = quantType; } + +void WeightFormatTransformPass::SetFmkType(converter::FmkType fmkType) { this->fmkType = fmkType; } + +void WeightFormatTransformPass::SetDstFormat(Format format) { this->dstFormat = format; } + +STATUS WeightFormatTransformPass::Run(MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + if (this->quantType == QuantType_AwareTraining) { + auto status = QuantDataFormatTrans(graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status; + return status; + } + } else { + auto status = NonQuantDataFormatTrans(graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "NonQuantDataFormatTrans failed: " << status; + return status; + } + } + return RET_OK; +} + +STATUS WeightFormatTransformPass::QuantDataFormatTrans(MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + for (auto &node : graph->nodes) { + MS_ASSERT(node != nullptr); + MS_ASSERT(node->primitive != nullptr); + auto opType = node->primitive->value.type; + if (opType != PrimitiveType_Conv2D && opType != PrimitiveType_DepthwiseConv2D) { + continue; + } + MS_ASSERT(node->inputIndex.size() >= 2); + auto weightIndex = node->inputIndex.at(1); + MS_ASSERT(graph->allTensors.size() > weightIndex); + auto &weightTensor = graph->allTensors[weightIndex]; + MS_ASSERT(weightTensor->dataType == DataType_DT_UINT8 || weightTensor->dataType == DataType_DT_FLOAT); + STATUS status; + if (opType == PrimitiveType_Conv2D || opType == PrimitiveType_DepthwiseConv2D) { // weight should be HWCK + Format curDstFormat; + if (this->dstFormat == Format_NUM_OF_FORMAT) { + curDstFormat = Format_KHWC; + } else { + curDstFormat = this->dstFormat; + } + status = TransFilterFormat(weightTensor.get(), curDstFormat); + if (status == RET_OK) { + // node->primitive->value.AsConv2D()->format = schema::Format_NHWC; + weightTensor->format = curDstFormat; + } else { + MS_LOG(WARNING) << "TransFilter " << EnumNameFormat(weightTensor->format) << "To" + << EnumNameFormat(curDstFormat) << " failed, node : " << node->name; + // todo(00445839): consider variable weight condition + } + } + } + return RET_OK; +} + +STATUS WeightFormatTransformPass::NonQuantDataFormatTrans(MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + for (auto &node : graph->nodes) { + MS_ASSERT(node != nullptr); + MS_ASSERT(node->primitive != nullptr); + auto opType = node->primitive->value.type; + if (opType != PrimitiveType_Conv2D && opType != PrimitiveType_DepthwiseConv2D && opType != PrimitiveType_DeConv2D && + opType != PrimitiveType_DeDepthwiseConv2D) { + continue; + } + MS_ASSERT(node->inputIndex.size() >= 2); + auto weightIndex = node->inputIndex.at(1); + MS_ASSERT(graph->allTensors.size() > weightIndex); + auto &weightTensor = graph->allTensors[weightIndex]; + MS_ASSERT(weightTensor->dataType == DataType_DT_UINT8 || weightTensor->dataType == DataType_DT_FLOAT); + STATUS status; + if (opType == PrimitiveType_Conv2D || opType == PrimitiveType_DepthwiseConv2D || + opType == schema::PrimitiveType_DeConv2D) { + Format curDstFormat; + if (this->dstFormat == Format_NUM_OF_FORMAT) { + curDstFormat = Format_KHWC; + } else { + curDstFormat = this->dstFormat; + } + status = TransFilterFormat(weightTensor.get(), curDstFormat); + if (status == RET_OK) { + // node->attr.AsConv2D()->format = Format_NCHW; + weightTensor->format = curDstFormat; + } else { + MS_LOG(WARNING) << "TransFilter " << EnumNameFormat(weightTensor->format) << "To" + << EnumNameFormat(curDstFormat) << " failed, node : " << node->name; + // todo(00445839): consider variable weight condition + } + } else { // weight should be CKHW + Format curDstFormat; + if (this->dstFormat == Format_NUM_OF_FORMAT) { + curDstFormat = Format_KHWC; + } else { + curDstFormat = this->dstFormat; + } + status = TransFilterFormat(weightTensor.get(), curDstFormat); + if (status == RET_OK) { + // node->attr.AsDepthwiseConv2D()->format = Format_NCHW; + weightTensor->format = curDstFormat; + } else { + MS_LOG(WARNING) << "TransFilter " << EnumNameFormat(weightTensor->format) << "To" + << EnumNameFormat(curDstFormat) << " failed, node : " << node->name; + // todo(00445839): consider variable weight condition + } + } + } + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h similarity index 56% rename from mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.h rename to mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h index cf7d3d462d..110b1df58a 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h @@ -14,45 +14,40 @@ * limitations under the License. */ -#ifndef MINDSPORE_PREDICT_WEIGHT_FORMAT_PASS_H -#define MINDSPORE_PREDICT_WEIGHT_FORMAT_PASS_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_TRANSFORM_PASS_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_TRANSFORM_PASS_H #include "tools/converter/optimizer.h" +#include "tools/common/graph_util.h" #include "tools/converter/converter_flags.h" -#include "utils/log_adapter.h" namespace mindspore { namespace lite { -class WeightFormatPass : public NodePass { +class WeightFormatTransformPass : public GraphPass { public: - WeightFormatPass() = default; + WeightFormatTransformPass() = default; - ~WeightFormatPass() override = default; + ~WeightFormatTransformPass() override = default; void SetQuantType(QuantType quantType); void SetFmkType(converter::FmkType fmkType); - int Run(GraphNode *graphNode) override; + void SetDstFormat(Format format); - private: - // correct weightTensor->Format - int ShapeFormatTrans(GraphNode *graphNode); + STATUS Run(MetaGraphT *graph) override; - // transform weightTensor data and format - // if quant : conv transform dataFormat to NHWC, weight format to HWCK - // if quant : depth transform dataFormat to NCHW, weight format to CKHW - int QuantDataFormatTrans(GraphNode *graphNode); + private: + STATUS QuantDataFormatTrans(MetaGraphT *graph); - // if no quant : transform dataFormat to NCHW, weight format to KCHW/CKHW - int NonQuantDataFormatTrans(GraphNode *graphNode); + STATUS NonQuantDataFormatTrans(MetaGraphT *graph); private: QuantType quantType = QuantType_QUANT_NONE; converter::FmkType fmkType = converter::FmkType_TF; + Format dstFormat = Format_NUM_OF_FORMAT; }; } // namespace lite } // namespace mindspore -#endif // MINDSPORE_PREDICT_WEIGHT_FORMAT_PASS_H - +#endif // 
MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_TRANSFORM_PASS_H diff --git a/mindspore/lite/tools/converter/legacy_optimizer/node/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/node/CMakeLists.txt deleted file mode 100755 index 6288071c81..0000000000 --- a/mindspore/lite/tools/converter/legacy_optimizer/node/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -add_library(node_mid OBJECT - ${CMAKE_CURRENT_SOURCE_DIR}/weight_format_pass.cc - ) diff --git a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc deleted file mode 100644 index c863c502a8..0000000000 --- a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc +++ /dev/null @@ -1,407 +0,0 @@ -/** - * Copyright 201+ Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/converter/legacy_optimizer/node/weight_format_pass.h" -#include "tools/common/node_util.h" -#include "tools/common/tensor_util.h" - -namespace mindspore { -namespace lite { -int WeightFormatPass::Run(GraphNode *graphNode) { - MS_ASSERT(graphNode != nullptr); - auto status = ShapeFormatTrans(graphNode); - if (status != 0) { - MS_LOG(ERROR) << "ShapeFormatTrans failed: " << status; - return status; - } - if (this->quantType == QuantType_AwareTraining || this->quantType == QuantType_PostTraining) { - status = QuantDataFormatTrans(graphNode); - if (status != 0) { - MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status; - return status; - } - } else { - status = NonQuantDataFormatTrans(graphNode); - if (status != 0) { - MS_LOG(ERROR) << "NonQuantDataFormatTrans failed: " << status; - return status; - } - } - return 0; -} - -void WeightFormatPass::SetQuantType(QuantType quantType) { this->quantType = quantType; } - -void WeightFormatPass::SetFmkType(converter::FmkType fmkType) { this->fmkType = fmkType; } - -// pre set tensor format -// non quant, filterFormat: -// conv deconv depth dedepth -// caffe K(C/g)HW C(K/g)HW / / -// tf HWCK HWKC HWCK HWKC -// onnx K(C/g)HW C(K/g)HW / / - -// awareing quant, filterFormat: -// conv deconv depth dedepth -// onnx KHWC ? CHWK ? -// tf HWCK ? HWCK ? 
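// ShapeFormatTrans below (like its replacement, WeightFormatHardCodePass) only relabels weightTensor->format; the actual data shuffle is left to TransFilterFormat with permutations such as kKCHW2KHWC. For intuition, a minimal sketch of one such permutation on a dense float buffer; TransKCHW2KHWC here is a hypothetical free function, whereas the real helper rewrites the TensorT in place:
#include <cstddef>
#include <vector>
// Reorders a dense KCHW filter buffer into KHWC layout; K, C, H, W are the filter dims.
std::vector<float> TransKCHW2KHWC(const std::vector<float> &src, std::size_t K, std::size_t C,
                                  std::size_t H, std::size_t W) {
  std::vector<float> dst(src.size());
  for (std::size_t k = 0; k < K; ++k) {
    for (std::size_t c = 0; c < C; ++c) {
      for (std::size_t h = 0; h < H; ++h) {
        for (std::size_t w = 0; w < W; ++w) {
          // source index in KCHW layout -> destination index in KHWC layout
          dst[((k * H + h) * W + w) * C + c] = src[((k * C + c) * H + h) * W + w];
        }
      }
    }
  }
  return dst;
}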
-int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { - MS_ASSERT(graphNode != nullptr); - auto &subGraph = graphNode->subGraph; - auto &node = graphNode->opDef; - MS_ASSERT(subGraph != nullptr); - MS_ASSERT(node != nullptr); - auto opType = node->primitive->value.type; - if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D && - opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) { - return 0; - } - MS_ASSERT(node->inputIndex.size() >= 2); - auto weightIndex = node->inputIndex.at(1); - MS_ASSERT(subGraph->allTensors.size() > weightIndex); - auto &weightTensor = subGraph->allTensors[weightIndex]; - auto &shape = weightTensor->dims; - MS_ASSERT(shape.size() == 4); - if (fmkType == converter::FmkType_CAFFE) { - switch (node->quantType) { - case QuantType_QUANT_NONE: { - if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D || - opType == schema::PrimitiveType_DeConv2D || opType == schema::PrimitiveType_DeDepthwiseConv2D) { - weightTensor->format = schema::Format_KCHW; - } else { - MS_LOG(ERROR) << "Invalid opType: " << schema::EnumNamePrimitiveType(opType) - << ", node: " << node->name.c_str(); - return -1; - } - } break; - default: { - MS_LOG(ERROR) << "Invalid quantType: " << schema::EnumNameQuantType(node->quantType) - << ", node: " << node->name.c_str(); - return -1; - } - } - return 0; - } else if (fmkType == converter::FmkType_MS) { - switch (node->quantType) { - case QuantType_AwareTraining: { - if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { - weightTensor->format = schema::Format_HWCK; - } else { - weightTensor->format = schema::Format_HWKC; - } - } break; - case QuantType_QUANT_NONE: { - // conv [filter_height, filter_width, in_channels, out_channels] - // depthwise [filter_height, filter_width, in_channels, channel_multiplier] - if (opType == schema::PrimitiveType_Conv2D) { - weightTensor->format = schema::Format_KCHW; - } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { - weightTensor->format = schema::Format_KCHW; - } else { - MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name; - return -1; - } - } break; - default: { - MS_LOG(ERROR) << "Invalid quantType: " << schema::EnumNameQuantType(node->quantType) - << ", node: " << node->name.c_str(); - return -1; - } - } - return 0; - } else if (fmkType == converter::FmkType_TF) { - switch (node->quantType) { - case QuantType_AwareTraining: { - if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { - weightTensor->format = schema::Format_HWCK; - } else { - weightTensor->format = schema::Format_HWKC; - } - } break; - case QuantType_QUANT_NONE: { - // conv [filter_height, filter_width, in_channels, out_channels] - // depthwise [filter_height, filter_width, in_channels, channel_multiplier] - if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { - weightTensor->format = schema::Format_HWCK; - } else { - weightTensor->format = schema::Format_HWKC; - } - } break; - default: { - MS_LOG(ERROR) << "Invalid quantType: " << schema::EnumNameQuantType(node->quantType) - << ", node: " << node->name.c_str(); - return -1; - } - } - return 0; - } else if (fmkType == converter::FmkType_TFLITE) { - switch (node->quantType) { - case QuantType_QUANT_NONE: - case QuantType_AwareTraining: - case QuantType_PostTraining: { - if (opType == schema::PrimitiveType_Conv2D) { - weightTensor->format = schema::Format_KHWC; - } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { - weightTensor->format = schema::Format_CHWK; - } else if (opType == schema::PrimitiveType_DeConv2D) { - weightTensor->format = schema::Format_CHWK; - } else { - MS_LOG(ERROR) << "Unsupported format"; - return -1; - } - } break; - default: { - MS_LOG(ERROR) << "Invalid quantType: " << schema::EnumNameQuantType(node->quantType) - << ", node: " << node->name.c_str(); - return -1; - } - } - MS_LOG(DEBUG) << "weight_tensor_format: " << weightTensor->format; - return 0; - } else if (fmkType == converter::FmkType_ONNX) { - switch (node->quantType) { - case QuantType_AwareTraining: { - // sum up from current onnx quant models - if (opType == schema::PrimitiveType_Conv2D) { - weightTensor->format = schema::Format_KHWC; - } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { - weightTensor->format = schema::Format_CHWK; - } else { - MS_LOG(ERROR) << "Invalid opType: " << schema::EnumNamePrimitiveType(opType) - << ", node: " << node->name.c_str(); - return -1; - } - } break; - case QuantType_QUANT_NONE: { - // conv (K x C/group x kH x kW) group = 1 - // depth (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W) - // deconv (C x K/group x kH x kW) group = 1 - // dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W) - if (opType == schema::PrimitiveType_Conv2D) { - weightTensor->format = schema::Format_KCHW; - } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { - weightTensor->format = schema::Format_KCHW; - } else if (opType == schema::PrimitiveType_DeConv2D) { - weightTensor->format = schema::Format_CKHW; - } else { - MS_LOG(ERROR) << "Invalid opType: " << schema::EnumNamePrimitiveType(opType) - << ", node: " << node->name.c_str(); - return -1; - } - } break; - default: { - MS_LOG(ERROR) << "Unsupported quantType: " << schema::EnumNameQuantType(node->quantType) - << ", node: " << node->name.c_str(); - return -1; - } - } - } else { - MS_LOG(ERROR) << "Invalid fmkType: " << static_cast<int>(fmkType) << ", node: " << node->name.c_str(); - return -1; - } - return 0; -} - -// inference needed filterFormat: -// conv deconv depth dedepth -// uint8 KHWC KHWC KHWC KHWC -int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { - MS_ASSERT(graphNode != nullptr); - auto &subGraph = graphNode->subGraph; - auto &node = graphNode->opDef; - MS_ASSERT(subGraph != nullptr); - MS_ASSERT(node != nullptr); - auto opType = node->primitive->value.type; - if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D && - opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) { - return RET_OK; - } - - MS_ASSERT(node->inputIndex.size() >= 2); - auto weightIndex = node->inputIndex.at(1); - MS_ASSERT(subGraph->allTensors.size() > weightIndex); - auto &weightTensor = subGraph->allTensors[weightIndex]; - MS_ASSERT(weightTensor->dataType == kNumberTypeInt8); // DataType_DT_FLOAT - STATUS status = RET_OK; - if (opType == schema::PrimitiveType_Conv2D) { // weight should be KHWC - if (weightTensor->format == schema::Format_KCHW) { // from caffe - if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { - MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << ", format: " << weightTensor->format - << ", datatype: " << weightTensor->dataType; - status = TransFilterFormat(weightTensor.get(), kKCHW2HWCK); - } else { - MS_LOG(DEBUG) << "--weight tensor index: " << weightIndex << ", format: " << weightTensor->format - << ", datatype: " << weightTensor->dataType; - status = TransFilterFormat(weightTensor.get(), kKCHW2KHWC); - } - } else if (weightTensor->format != schema::Format_KHWC) { - MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
weightTensor format: " << weightTensor->format; - return -1; - } - if (status == 0) { - node->primitive->value.AsConv2D()->format = schema::Format_NHWC; - weightTensor->format = schema::Format_KHWC; - } else { - MS_LOG(WARNING) << "TransFilter %sToKHWC failed, node : " - << (weightTensor->format == schema::Format_KHWC ? "KHWC" : "KCHW") << node->name.c_str(); - // todo(00445839): consider varible weight condition - } - } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be KHWC - if (weightTensor->format == schema::Format_CKHW) { // from caffe - if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { - MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format - << "datatype: " << weightTensor->dataType; - status = TransFilterFormat(weightTensor.get(), kCKHW2KHWC); - } else if (weightTensor->dataType == kNumberTypeUInt8) { - MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format - << "datatype: " << weightTensor->dataType; - status = TransFilterFormat(weightTensor.get(), kCKHW2KHWC); - } else { - MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format - << "datatype: " << weightTensor->dataType; - status = TransFilterFormat(weightTensor.get(), kCKHW2KHWC); - } - - } else if (weightTensor->format == schema::Format_CHWK) { // from onnx - if (weightTensor->dataType == kNumberTypeInt8) { - MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format - << "datatype: " << weightTensor->dataType; - status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); - } else if (weightTensor->dataType == kNumberTypeUInt8) { - MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format - << "datatype: " << weightTensor->dataType; - status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); - } else { - MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format - << "datatype: " << weightTensor->dataType; - status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); - } - } else if (weightTensor->format != schema::Format_KHWC) { - MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; - return -1; - } - if (status == 0) { - node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC; - weightTensor->format = schema::Format_KHWC; - } else { - MS_LOG(WARNING) << "TransFilter" << (weightTensor->format == schema::Format_KHWC ? 
"KHWC" : "CKHW") - << "To KHWC failed, node : " << node->name.c_str(); - // todo(00445839): consider varible weight condition - } - } else { // weight should be HWCK - node->primitive->value.AsDeConv2D()->format = schema::Format_NHWC; - weightTensor->format = schema::Format_KHWC; - } - return 0; -} - -// inference needed filterFormat: -// conv deconv depth dedepth -// fp32 KCHW CKHW CKHW CKHW -int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { - MS_ASSERT(graphNode != nullptr); - auto &subGraph = graphNode->subGraph; - auto &node = graphNode->opDef; - MS_ASSERT(subGraph != nullptr); - MS_ASSERT(node != nullptr); - auto opType = node->primitive->value.type; - if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D && - opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) { - return 0; - } - - MS_ASSERT(node->inputIndex.size() >= 2); - auto weightIndex = node->inputIndex.at(1); - MS_ASSERT(subGraph->allTensors.size() > weightIndex); - auto &weightTensor = subGraph->allTensors[weightIndex]; - if (weightTensor->dataType != TypeId::kNumberTypeFloat32) { - MS_LOG(ERROR) << "weight tensor data should be float"; - // return -1; - } - STATUS status = RET_OK; - if (opType == schema::PrimitiveType_Conv2D) { // weight should be KCHW - if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms - status = TransFilterFormat(weightTensor.get(), kKCHW2KHWC); - } else if (weightTensor->format == schema::Format_KHWC) { - status = RET_OK; - } else if (weightTensor->format == schema::Format_CHWK) { - status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); - } else { - MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; - return -1; - } - if (status == 0) { - node->primitive->value.AsConv2D()->format = schema::Format_NHWC; - weightTensor->format = schema::Format_KHWC; - } else { - MS_LOG(WARNING) << "TransFilter " << ((weightTensor->format == schema::Format_HWCK) ? 
"HWCK" : "NHWC") - << "ToKCHW failed, node : " << node->name.c_str(); - // todo(00445839): consider varible weight condition - } - } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be CKHW - if (fmkType == converter::FmkType_MS) { - weightTensor->format = schema::Format_CKHW; - } - if (weightTensor->format == schema::Format_CKHW) { // from caffe or onnx or ms - status = TransFilterFormat(weightTensor.get(), kCKHW2KHWC); - } else if (weightTensor->format == schema::Format_KCHW) { - status = TransFilterFormat(weightTensor.get(), kKCHW2KHWC); - } else if (weightTensor->format == schema::Format_CHWK) { // from tflite - status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); - } else { - MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; - return -1; - } - if (status == 0) { - node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC; - weightTensor->format = schema::Format_KHWC; - } else { - MS_LOG(WARNING) << "TransFilter HWCKToCKHW failed, node : " << node->name.c_str(); - // todo(00445839): consider varible weight condition - } - } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC - if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms - status = TransFilterFormat(weightTensor.get(), kKCHW2KHWC); - } else if (weightTensor->format == schema::Format_CHWK) { // from tflite - status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); - } else { - MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; - return -1; - } - if (status == 0) { - node->primitive->value.AsDeConv2D()->format = schema::Format_NCHW; - weightTensor->format = schema::Format_KHWC; - } else { - MS_LOG(WARNING) << "TransFilter HWKCToKCHW failed, node : " << node->name.c_str(); - // todo(00445839): consider varible weight condition - } - } else if (opType == schema::PrimitiveType_DeDepthwiseConv2D) { // weight should be KHWC - if (weightTensor->format == schema::Format_KHWC) { - return 0; - } else if (weightTensor->format == schema::Format_KCHW) { // from caffe - status = TransFilterFormat(weightTensor.get(), kKCHW2KHWC); - } else if (weightTensor->format == schema::Format_HWKC) { // from tf or onnx - status = TransFilterFormat(weightTensor.get(), kHWKC2CKHW); - } else { - MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; - return -1; - } - if (status == 0) { - node->primitive->value.AsDeDepthwiseConv2D()->format = schema::Format_NHWC; - weightTensor->format = schema::Format_CKHW; - } else { - MS_LOG(WARNING) << "TransFilter HWKCToCKHW failed, node : " << node->name.c_str(); - // todo(00445839): consider varible weight condition - } - } - return 0; -} -} // namespace lite -} // namespace mindspore