| @@ -249,47 +249,52 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT* sub_graph) | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto data_shape = data_tensor->dims; | |||
| conv_attr->channelIn = data_shape[3]; | |||
| conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier; | |||
| // update attr | |||
| conv_attr->group = 0; | |||
| conv_attr->format = attr->format; | |||
| conv_attr->kernelH = attr->kernelH; | |||
| conv_attr->kernelW = attr->kernelW; | |||
| conv_attr->strideH = attr->strideH; | |||
| conv_attr->strideW = attr->strideW; | |||
| conv_attr->padMode = attr->padMode; | |||
| conv_attr->padUp = attr->padUp; | |||
| conv_attr->padDown = attr->padDown; | |||
| conv_attr->padLeft = attr->padLeft; | |||
| conv_attr->padRight = attr->padRight; | |||
| conv_attr->dilateH = attr->dilateH; | |||
| conv_attr->dilateW = attr->dilateW; | |||
| conv_attr->hasBias = attr->hasBias; | |||
| conv_attr->activationType = attr->activationType; | |||
| op->primitive->value.type = schema::PrimitiveType_Conv2D; | |||
| op->primitive->value.value = conv_attr.release(); | |||
| // update weight | |||
| auto weight_id = op->inputIndex[1]; | |||
| auto &weight_tensor = sub_graph->allTensors.at(weight_id); | |||
| if (weight_tensor->dataType == TypeId::kNumberTypeUInt8) { | |||
| auto status = TransFilterFormat<uint8_t>(weight_tensor.get(), kKHWC2CHWK); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) { | |||
| auto status = TransFilterFormat<float>(weight_tensor.get(), kKHWC2CHWK); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Trans filter format failed."; | |||
| if (data_shape[3] == 1) { | |||
| conv_attr->channelIn = data_shape[3]; | |||
| conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier; | |||
| // update attr | |||
| conv_attr->group = 1; | |||
| conv_attr->format = attr->format; | |||
| conv_attr->kernelH = attr->kernelH; | |||
| conv_attr->kernelW = attr->kernelW; | |||
| conv_attr->strideH = attr->strideH; | |||
| conv_attr->strideW = attr->strideW; | |||
| conv_attr->padMode = attr->padMode; | |||
| conv_attr->padUp = attr->padUp; | |||
| conv_attr->padDown = attr->padDown; | |||
| conv_attr->padLeft = attr->padLeft; | |||
| conv_attr->padRight = attr->padRight; | |||
| conv_attr->dilateH = attr->dilateH; | |||
| conv_attr->dilateW = attr->dilateW; | |||
| conv_attr->hasBias = attr->hasBias; | |||
| conv_attr->activationType = attr->activationType; | |||
| op->primitive->value.type = schema::PrimitiveType_Conv2D; | |||
| op->primitive->value.value = conv_attr.release(); | |||
| // update weight | |||
| auto weight_id = op->inputIndex[1]; | |||
| auto &weight_tensor = sub_graph->allTensors.at(weight_id); | |||
| if (weight_tensor->dataType == TypeId::kNumberTypeUInt8) { | |||
| auto status = TransFilterFormat<uint8_t>(weight_tensor.get(), kKHWC2CHWK); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; | |||
| return RET_ERROR; | |||
| } | |||
| } else if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) { | |||
| auto status = TransFilterFormat<float>(weight_tensor.get(), kKHWC2CHWK); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Trans filter format failed."; | |||
| return RET_ERROR; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "The dataType of weight tensor is unsupported."; | |||
| return RET_ERROR; | |||
| } | |||
| weight_tensor->format = schema::Format_CHWK; | |||
| } | |||
| weight_tensor->format = schema::Format_CHWK; | |||
| } | |||
| } | |||
| } | |||