From 668db1dd7d9ab5f94a955ccc5a40cb8321a36780 Mon Sep 17 00:00:00 2001
From: guohongzilong <2713219276@qq.com>
Date: Thu, 30 Jul 2020 20:21:22 +0800
Subject: [PATCH] deconv adapter

---
 mindspore/lite/src/model_impl.cc                     |  2 ++
 .../optimizer/fusion/conv_biasadd_fusion_pass.cc     | 11 +++++++++--
 .../optimizer/node/weight_format_pass.cc             | 16 +++++++++-------
 .../converter/parser/tflite/tflite_add_parser.cc     | 10 +++++++++-
 4 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/mindspore/lite/src/model_impl.cc b/mindspore/lite/src/model_impl.cc
index c9265e6c2d..c05c16ae2f 100644
--- a/mindspore/lite/src/model_impl.cc
+++ b/mindspore/lite/src/model_impl.cc
@@ -76,6 +76,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
       return new lite::Activation(const_cast<schema::Primitive *>(srcPrim));
     case schema::PrimitiveType_Conv2D:
       return new lite::Conv2D(const_cast<schema::Primitive *>(srcPrim));
+    case schema::PrimitiveType_DeConv2D:
+      return new lite::DeConv2D(const_cast<schema::Primitive *>(srcPrim));
     case schema::PrimitiveType_Reduce:
       return new lite::Reduce(const_cast<schema::Primitive *>(srcPrim));
     case schema::PrimitiveType_Pooling:
diff --git a/mindspore/lite/tools/converter/optimizer/fusion/conv_biasadd_fusion_pass.cc b/mindspore/lite/tools/converter/optimizer/fusion/conv_biasadd_fusion_pass.cc
index c353f2ca94..f51789f4e3 100644
--- a/mindspore/lite/tools/converter/optimizer/fusion/conv_biasadd_fusion_pass.cc
+++ b/mindspore/lite/tools/converter/optimizer/fusion/conv_biasadd_fusion_pass.cc
@@ -81,7 +81,7 @@ STATUS ConvBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &pat
   }
   auto baNodeBiasTensor = graph->allTensors.at(baNodeInputIndex[BIASADD_OP_CONST_TENSOR_INDEX]).get();
   MS_ASSERT(baNodeBiasTensor != nullptr);
-  if (baNodeBiasTensor->refCount != schema::NodeType_ValueNode) {
+  if (baNodeBiasTensor->nodeType != schema::NodeType_ValueNode) {
     // dont fusion, return
     return RET_OK;
   }
@@ -215,7 +215,9 @@ STATUS ConvBiasAddFusionPass::GenConvBiasTensor(std::shared_ptr<Path> convPath,
                   << ". or bias tensor is a scaler";
     return RET_ERROR;
   }
-  if (!biasDims.empty() && biasDims.at(BIASADD_BIAS_DIM_INDEX) != kernelNum) {
+
+  bool bias_const = !biasDims.empty() && biasDims.size() == 1 && biasDims[0] == 1;
+  if (!biasDims.empty() && !bias_const && biasDims.at(BIASADD_BIAS_DIM_INDEX) != kernelNum) {
     MS_LOG(ERROR) << "Size(%d) of BiasAdd(%s) bias tensor should be equal to kernelNum(%d)"
                   << biasDims.at(BIASADD_BIAS_DIM_INDEX) << baNode->name.c_str() << kernelNum;
     return RET_ERROR;
@@ -234,6 +236,11 @@ STATUS ConvBiasAddFusionPass::GenConvBiasTensor(std::shared_ptr<Path> convPath,
       MS_LOG(ERROR) << "memset_s newBiasData failed";
       return RET_ERROR;
     }
+  } else if (bias_const) {
+    auto *biasData = reinterpret_cast<float *>(biasTensor->data.data());
+    for (size_t i = 0; i < kernelNum; i++) {
+      newBiasData[i] = *biasData;
+    }
   } else {
     if (0 != memcpy_s(newBiasData, kernelNum * sizeof(float), biasTensor->data.data(), kernelNum * sizeof(float))) {
       MS_LOG(ERROR) << "memcpy_s newBiasData failed";
diff --git a/mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc b/mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc
index 3400c88451..0bd06696f0 100644
--- a/mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc
+++ b/mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc
@@ -152,6 +152,8 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
     weightTensor->format = schema::Format_KHWC;
   } else if (opType == schema::PrimitiveType_DepthwiseConv2D) {
     weightTensor->format = schema::Format_CHWK;
+  } else if (opType == schema::PrimitiveType_DeConv2D) {
+    weightTensor->format = schema::Format_KHWC;
   } else {
     MS_LOG(ERROR) << "unsupport format";
     return -1;
@@ -355,18 +357,18 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
       MS_LOG(WARNING) << "TransFilter HWCKToCKHW failed, node : " << node->name.c_str();
       // todo(00445839): consider varible weight condition
     }
-  } else if (opType == schema::PrimitiveType_DeConv2D) {  // weight should be KCHW
-    if (weightTensor->format == schema::Format_KCHW) {  // from caffe or onnx
-      return 0;
-    } else if (weightTensor->format == schema::Format_HWKC) {  // from tf
-      status = TransFilterFormat<float>(weightTensor.get(), kHWKC2KCHW);
+  } else if (opType == schema::PrimitiveType_DeConv2D) {  // weight should be KHWC
+    if (weightTensor->format == schema::Format_KCHW) {  // from caffe or onnx or ms
+      status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
+    } else if (weightTensor->format == schema::Format_KHWC) {  // from tf
+      status = RET_OK;
     } else {
       MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
       return -1;
     }
     if (status == 0) {
-      node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NCHW;
-      weightTensor->format = schema::Format_KCHW;
+      node->primitive->value.AsDeConv2D()->format = schema::Format_NCHW;
+      weightTensor->format = schema::Format_KHWC;
     } else {
       MS_LOG(WARNING) << "TransFilter HWKCToKCHW failed, node : " << node->name.c_str();
       // todo(00445839): consider varible weight condition
diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.cc
index 377ecab167..1a0e3aa01b 100644
--- a/mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.cc
+++ b/mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.cc
@@ -27,8 +27,16 @@ STATUS TfliteAddParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp
                               const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
                               const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
                               schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
-  // MS_LOGD("parse TfliteAddParser");
MS_LOGD("parse TfliteAddParser"); + MS_LOG(DEBUG) << "parse TfliteAddParser"; std::unique_ptr attr(new schema::AddT()); + auto weight_index = tfliteOp->inputs[1]; + const auto &weight_tensor = tfliteTensors[weight_index]; + std::vector weight_tensors{weight_tensor.get()}; + + if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) { + return RET_ERROR; + } + if (op != nullptr) { op->primitive = std::make_unique(); op->primitive->value.type = schema::PrimitiveType_Add;