Merge pull request !3713 from ghzl/deconv-adaptertags/v0.7.0-beta
| @@ -80,6 +80,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) { | |||||
| return new lite::Activation(const_cast<schema::Primitive *>(srcPrim)); | return new lite::Activation(const_cast<schema::Primitive *>(srcPrim)); | ||||
| case schema::PrimitiveType_Conv2D: | case schema::PrimitiveType_Conv2D: | ||||
| return new lite::Conv2D(const_cast<schema::Primitive *>(srcPrim)); | return new lite::Conv2D(const_cast<schema::Primitive *>(srcPrim)); | ||||
| case schema::PrimitiveType_DeConv2D: | |||||
| return new lite::DeConv2D(const_cast<schema::Primitive *>(srcPrim)); | |||||
| case schema::PrimitiveType_Reduce: | case schema::PrimitiveType_Reduce: | ||||
| return new lite::Reduce(const_cast<schema::Primitive *>(srcPrim)); | return new lite::Reduce(const_cast<schema::Primitive *>(srcPrim)); | ||||
| case schema::PrimitiveType_Pooling: | case schema::PrimitiveType_Pooling: | ||||
| @@ -81,7 +81,7 @@ STATUS ConvBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &pat | |||||
| } | } | ||||
| auto baNodeBiasTensor = graph->allTensors.at(baNodeInputIndex[BIASADD_OP_CONST_TENSOR_INDEX]).get(); | auto baNodeBiasTensor = graph->allTensors.at(baNodeInputIndex[BIASADD_OP_CONST_TENSOR_INDEX]).get(); | ||||
| MS_ASSERT(baNodeBiasTensor != nullptr); | MS_ASSERT(baNodeBiasTensor != nullptr); | ||||
| if (baNodeBiasTensor->refCount != schema::NodeType_ValueNode) { | |||||
| if (baNodeBiasTensor->nodeType != schema::NodeType_ValueNode) { | |||||
| // dont fusion, return | // dont fusion, return | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -215,7 +215,9 @@ STATUS ConvBiasAddFusionPass::GenConvBiasTensor(std::shared_ptr<Path> convPath, | |||||
| << ". or bias tensor is a scaler"; | << ". or bias tensor is a scaler"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (!biasDims.empty() && biasDims.at(BIASADD_BIAS_DIM_INDEX) != kernelNum) { | |||||
| bool bias_const = !biasDims.empty() && biasDims.size() == 1 && biasDims[0] == 1; | |||||
| if (!biasDims.empty() && !bias_const && biasDims.at(BIASADD_BIAS_DIM_INDEX) != kernelNum) { | |||||
| MS_LOG(ERROR) << "Size(%d) of BiasAdd(%s) bias tensor should be equal to kernelNum(%d)" | MS_LOG(ERROR) << "Size(%d) of BiasAdd(%s) bias tensor should be equal to kernelNum(%d)" | ||||
| << biasDims.at(BIASADD_BIAS_DIM_INDEX) << baNode->name.c_str() << kernelNum; | << biasDims.at(BIASADD_BIAS_DIM_INDEX) << baNode->name.c_str() << kernelNum; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -234,6 +236,11 @@ STATUS ConvBiasAddFusionPass::GenConvBiasTensor(std::shared_ptr<Path> convPath, | |||||
| MS_LOG(ERROR) << "memset_s newBiasData failed"; | MS_LOG(ERROR) << "memset_s newBiasData failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } else if (bias_const) { | |||||
| auto *biasData = reinterpret_cast<float *>(biasTensor->data.data()); | |||||
| for (size_t i = 0; i < kernelNum; i++) { | |||||
| newBiasData[i] = *biasData; | |||||
| } | |||||
| } else { | } else { | ||||
| if (0 != memcpy_s(newBiasData, kernelNum * sizeof(float), biasTensor->data.data(), kernelNum * sizeof(float))) { | if (0 != memcpy_s(newBiasData, kernelNum * sizeof(float), biasTensor->data.data(), kernelNum * sizeof(float))) { | ||||
| MS_LOG(ERROR) << "memcpy_s newBiasData failed"; | MS_LOG(ERROR) << "memcpy_s newBiasData failed"; | ||||
| @@ -153,6 +153,8 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { | |||||
| weightTensor->format = schema::Format_KHWC; | weightTensor->format = schema::Format_KHWC; | ||||
| } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { | } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { | ||||
| weightTensor->format = schema::Format_CHWK; | weightTensor->format = schema::Format_CHWK; | ||||
| } else if (opType == schema::PrimitiveType_DeConv2D) { | |||||
| weightTensor->format = schema::Format_KHWC; | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "unsupport format"; | MS_LOG(ERROR) << "unsupport format"; | ||||
| return -1; | return -1; | ||||
| @@ -356,18 +358,18 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { | |||||
| MS_LOG(WARNING) << "TransFilter HWCKToCKHW failed, node : " << node->name.c_str(); | MS_LOG(WARNING) << "TransFilter HWCKToCKHW failed, node : " << node->name.c_str(); | ||||
| // todo(00445839): consider varible weight condition | // todo(00445839): consider varible weight condition | ||||
| } | } | ||||
| } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KCHW | |||||
| if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx | |||||
| return 0; | |||||
| } else if (weightTensor->format == schema::Format_HWKC) { // from tf | |||||
| status = TransFilterFormat<float>(weightTensor.get(), kHWKC2KCHW); | |||||
| } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC | |||||
| if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms | |||||
| status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC); | |||||
| } else if (weightTensor->format == schema::Format_KHWC) { // from tf | |||||
| status = RET_OK; | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; | MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; | ||||
| return -1; | return -1; | ||||
| } | } | ||||
| if (status == 0) { | if (status == 0) { | ||||
| node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NCHW; | |||||
| weightTensor->format = schema::Format_KCHW; | |||||
| node->primitive->value.AsDeConv2D()->format = schema::Format_NCHW; | |||||
| weightTensor->format = schema::Format_KHWC; | |||||
| } else { | } else { | ||||
| MS_LOG(WARNING) << "TransFilter HWKCToKCHW failed, node : " << node->name.c_str(); | MS_LOG(WARNING) << "TransFilter HWKCToKCHW failed, node : " << node->name.c_str(); | ||||
| // todo(00445839): consider varible weight condition | // todo(00445839): consider varible weight condition | ||||
| @@ -27,8 +27,16 @@ STATUS TfliteAddParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp | |||||
| schema::CNodeT *op, | schema::CNodeT *op, | ||||
| TensorCache *tensor_cache, | TensorCache *tensor_cache, | ||||
| bool quantizedModel) { | bool quantizedModel) { | ||||
| // MS_LOGD("parse TfliteAddParser"); | |||||
| MS_LOG(DEBUG) << "parse TfliteAddParser"; | |||||
| std::unique_ptr<schema::AddT> attr(new schema::AddT()); | std::unique_ptr<schema::AddT> attr(new schema::AddT()); | ||||
| auto weight_index = tfliteOp->inputs[1]; | |||||
| const auto &weight_tensor = tfliteTensors[weight_index]; | |||||
| std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | |||||
| if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) { | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (op != nullptr) { | if (op != nullptr) { | ||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | op->primitive = std::make_unique<schema::PrimitiveT>(); | ||||
| op->primitive->value.type = schema::PrimitiveType_Add; | op->primitive->value.type = schema::PrimitiveType_Add; | ||||