From bed056aabbdb123d6fa3a62df8d000700db482fe Mon Sep 17 00:00:00 2001 From: yeyunpeng Date: Mon, 10 Aug 2020 19:34:46 +0800 Subject: [PATCH] Fix DeDepthwiseConv2D problem --- .../node/weight_format_pass.cc | 22 ++++++++++--------- .../caffe/caffe_deconvolution_parser.cc | 3 +-- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc index 83db1e2b02..2e894a24f4 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc @@ -79,7 +79,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { switch (node->quantType) { case QuantType_QUANT_NONE: { if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D || - opType == schema::PrimitiveType_DeConv2D) { + opType == schema::PrimitiveType_DeConv2D || opType == schema::PrimitiveType_DeDepthwiseConv2D) { weightTensor->format = schema::Format_KCHW; } else { MS_LOG(ERROR) << "Invalid opType: " << schema::EnumNamePrimitiveType(opType) @@ -240,11 +240,11 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { } } else if (weightTensor->format == schema::Format_KHWC) { // from onnx return RET_OK; -// if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { -// status = TransFilterFormat(weightTensor.get(), kKHWC2HWCK); -// } else { -// status = TransFilterFormat(weightTensor.get(), kKHWC2HWCK); -// } + // if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { + // status = TransFilterFormat(weightTensor.get(), kKHWC2HWCK); + // } else { + // status = TransFilterFormat(weightTensor.get(), kKHWC2HWCK); + // } } else if (weightTensor->format == schema::Format_HWCK) { // from tf return 0; } else { @@ -275,7 +275,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { } else if (weightTensor->format == schema::Format_HWCK) { // from tf return 0; } else if (weightTensor->format == schema::Format_CHWK) { // from onnx - if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { + if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); } else { status = TransFilterFormat(weightTensor.get(), kCHWK2HWCK); @@ -383,9 +383,11 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { MS_LOG(WARNING) << "TransFilter HWKCToKCHW failed, node : " << node->name.c_str(); // todo(00445839): consider varible weight condition } - } else if (opType == schema::PrimitiveType_DeDepthwiseConv2D) { // weight should be CKHW - if (weightTensor->format == schema::Format_CKHW) { // from caffe + } else if (opType == schema::PrimitiveType_DeDepthwiseConv2D) { // weight should be KHWC + if (weightTensor->format == schema::Format_KHWC) { return 0; + } else if (weightTensor->format == schema::Format_KCHW) { // from caffe + status = TransFilterFormat(weightTensor.get(), kKCHW2KHWC); } else if (weightTensor->format == schema::Format_HWKC) { // from tf or onnx status = TransFilterFormat(weightTensor.get(), kHWKC2CKHW); } else { @@ -393,7 +395,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { return -1; } if (status == 0) { - node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC; + node->primitive->value.AsDeDepthwiseConv2D()->format = schema::Format_NHWC; weightTensor->format = schema::Format_CKHW; } else { MS_LOG(WARNING) << "TransFilter HWKCToCKHW failed, node : " << node->name.c_str(); diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc index be9682f2fd..c219a14eed 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc @@ -46,14 +46,13 @@ void CaffeDeconvolutionParser::ParseGroupDeconvolution(schema::CNodeT *op, schem deDepthwiseConv2DParam->hasBias = attr->hasBias; deDepthwiseConv2DParam->activationType = attr->activationType; delete attr; - op->primitive = std::make_unique(); op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D; op->primitive->value.value = deDepthwiseConv2DParam.release(); } STATUS CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, std::vector *weightVec) { op->name = proto.name(); - schema::DeConv2DT *attr = new schema::DeConv2DT(); + auto *attr = new schema::DeConv2DT(); attr->format = schema::Format_NCHW; const caffe::ConvolutionParameter convParam = proto.convolution_param();