|
|
|
@@ -118,7 +118,7 @@ STATUS WeightFormatHardCodePass::HardCodeONNX(const std::unique_ptr<CNodeT> &nod |
|
|
|
} else if (opType == PrimitiveType_DepthwiseConv2D) { |
|
|
|
weightTensor->format = schema::Format::Format_CHWK; |
|
|
|
} else if (opType == PrimitiveType_DeConv2D) { |
|
|
|
weightTensor->format = schema::Format::Format_CKHW; |
|
|
|
weightTensor->format = schema::Format::Format_KCHW; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name; |
|
|
|
return RET_ERROR; |
|
|
|
@@ -133,7 +133,7 @@ STATUS WeightFormatHardCodePass::HardCodeONNX(const std::unique_ptr<CNodeT> &nod |
|
|
|
if (opType == PrimitiveType_Conv2D || opType == PrimitiveType_DepthwiseConv2D) { |
|
|
|
weightTensor->format = schema::Format::Format_KCHW; |
|
|
|
} else if (opType == PrimitiveType_DeConv2D) { |
|
|
|
weightTensor->format = schema::Format::Format_CKHW; |
|
|
|
weightTensor->format = schema::Format::Format_KCHW; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name; |
|
|
|
return RET_ERROR; |
|
|
|
|