diff --git a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc index 5432f58ae3..a79072281e 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc @@ -50,7 +50,7 @@ void WeightFormatPass::SetFmkType(converter::FmkType fmkType) { this->fmkType = // pre set tensor format // non quant, filterFormat: // conv deconv depth dedepth -// caffe K(C/g)HW C(K/g)HW / / // todo with deconvOp +// caffe K(C/g)HW C(K/g)HW / / // tf HWCK HWKC HWCK HWKC // onnx K(C/g)HW C(K/g)HW / / @@ -78,7 +78,8 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { if (fmkType == converter::FmkType_CAFFE) { switch (node->quantType) { case QuantType_QUANT_NONE: { - if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { + if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D || + opType == schema::PrimitiveType_DeConv2D) { weightTensor->format = schema::Format_KCHW; } else { MS_LOG(ERROR) << "Invalid opType: " << schema::EnumNamePrimitiveType(opType) @@ -227,7 +228,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { STATUS status; if (opType == schema::PrimitiveType_Conv2D) { // weight should be HWCK if (weightTensor->format == schema::Format_KCHW) { // from caffe - if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { + if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format << weightTensor->dataType; status = TransFilterFormat(weightTensor.get(), kKCHW2HWCK); @@ -237,7 +238,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { status = TransFilterFormat(weightTensor.get(), kKCHW2HWCK); } } else if (weightTensor->format == schema::Format_KHWC) { // from onnx - if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { + if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { status = TransFilterFormat(weightTensor.get(), kKHWC2HWCK); } else { status = TransFilterFormat(weightTensor.get(), kKHWC2HWCK); @@ -259,7 +260,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { } } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be HWCK if (weightTensor->format == schema::Format_CKHW) { // from caffe - if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { + if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex, weightTensor->format, weightTensor->dataType; status = TransFilterFormat(weightTensor.get(), kCKHW2HWCK); @@ -272,13 +273,13 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { } else if (weightTensor->format == schema::Format_HWCK) { // from tf return 0; } else if (weightTensor->format == schema::Format_CHWK) { // from onnx - if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { + if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { status = TransFilterFormat(weightTensor.get(), kCHWK2HWCK); } else { status = TransFilterFormat(weightTensor.get(), kCHWK2HWCK); } } else if (weightTensor->format == schema::Format_KCHW) { - if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { + if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { status = TransFilterFormat(weightTensor.get(), kKCHW2HWCK); } else { status = TransFilterFormat(weightTensor.get(), kKCHW2HWCK); @@ -365,7 +366,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { // todo(00445839): consider varible weight condition } } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC - if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms + if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms status = TransFilterFormat(weightTensor.get(), kKCHW2KHWC); } else if (weightTensor->format == schema::Format_CHWK) { // from tf status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC);