| @@ -249,47 +249,52 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT* sub_graph) | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| auto data_shape = data_tensor->dims; | auto data_shape = data_tensor->dims; | ||||
| conv_attr->channelIn = data_shape[3]; | |||||
| conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier; | |||||
| // update attr | |||||
| conv_attr->group = 0; | |||||
| conv_attr->format = attr->format; | |||||
| conv_attr->kernelH = attr->kernelH; | |||||
| conv_attr->kernelW = attr->kernelW; | |||||
| conv_attr->strideH = attr->strideH; | |||||
| conv_attr->strideW = attr->strideW; | |||||
| conv_attr->padMode = attr->padMode; | |||||
| conv_attr->padUp = attr->padUp; | |||||
| conv_attr->padDown = attr->padDown; | |||||
| conv_attr->padLeft = attr->padLeft; | |||||
| conv_attr->padRight = attr->padRight; | |||||
| conv_attr->dilateH = attr->dilateH; | |||||
| conv_attr->dilateW = attr->dilateW; | |||||
| conv_attr->hasBias = attr->hasBias; | |||||
| conv_attr->activationType = attr->activationType; | |||||
| op->primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| op->primitive->value.value = conv_attr.release(); | |||||
| // update weight | |||||
| auto weight_id = op->inputIndex[1]; | |||||
| auto &weight_tensor = sub_graph->allTensors.at(weight_id); | |||||
| if (weight_tensor->dataType == TypeId::kNumberTypeUInt8) { | |||||
| auto status = TransFilterFormat<uint8_t>(weight_tensor.get(), kKHWC2CHWK); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) { | |||||
| auto status = TransFilterFormat<float>(weight_tensor.get(), kKHWC2CHWK); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Trans filter format failed."; | |||||
| if (data_shape[3] == 1) { | |||||
| conv_attr->channelIn = data_shape[3]; | |||||
| conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier; | |||||
| // update attr | |||||
| conv_attr->group = 1; | |||||
| conv_attr->format = attr->format; | |||||
| conv_attr->kernelH = attr->kernelH; | |||||
| conv_attr->kernelW = attr->kernelW; | |||||
| conv_attr->strideH = attr->strideH; | |||||
| conv_attr->strideW = attr->strideW; | |||||
| conv_attr->padMode = attr->padMode; | |||||
| conv_attr->padUp = attr->padUp; | |||||
| conv_attr->padDown = attr->padDown; | |||||
| conv_attr->padLeft = attr->padLeft; | |||||
| conv_attr->padRight = attr->padRight; | |||||
| conv_attr->dilateH = attr->dilateH; | |||||
| conv_attr->dilateW = attr->dilateW; | |||||
| conv_attr->hasBias = attr->hasBias; | |||||
| conv_attr->activationType = attr->activationType; | |||||
| op->primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| op->primitive->value.value = conv_attr.release(); | |||||
| // update weight | |||||
| auto weight_id = op->inputIndex[1]; | |||||
| auto &weight_tensor = sub_graph->allTensors.at(weight_id); | |||||
| if (weight_tensor->dataType == TypeId::kNumberTypeUInt8) { | |||||
| auto status = TransFilterFormat<uint8_t>(weight_tensor.get(), kKHWC2CHWK); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } else if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) { | |||||
| auto status = TransFilterFormat<float>(weight_tensor.get(), kKHWC2CHWK); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Trans filter format failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } else { | |||||
| MS_LOG(ERROR) << "The dataType of weight tensor is unsupported."; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| weight_tensor->format = schema::Format_CHWK; | |||||
| } | } | ||||
| weight_tensor->format = schema::Format_CHWK; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||