|
|
|
@@ -154,7 +154,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { |
|
|
|
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) { |
|
|
|
weightTensor->format = schema::Format_CHWK; |
|
|
|
} else if (opType == schema::PrimitiveType_DeConv2D) { |
|
|
|
weightTensor->format = schema::Format_KHWC; |
|
|
|
weightTensor->format = schema::Format_CHWK; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "unsupport format"; |
|
|
|
return -1; |
|
|
|
@@ -367,8 +367,8 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { |
|
|
|
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC |
|
|
|
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms |
|
|
|
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC); |
|
|
|
} else if (weightTensor->format == schema::Format_KHWC) { // from tf |
|
|
|
status = RET_OK; |
|
|
|
} else if (weightTensor->format == schema::Format_CHWK) { // from tf |
|
|
|
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC); |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; |
|
|
|
return -1; |
|
|
|
@@ -390,7 +390,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { |
|
|
|
return -1; |
|
|
|
} |
|
|
|
if (status == 0) { |
|
|
|
node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NCHW; |
|
|
|
node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC; |
|
|
|
weightTensor->format = schema::Format_CKHW; |
|
|
|
} else { |
|
|
|
MS_LOG(WARNING) << "TransFilter HWKCToCKHW failed, node : " << node->name.c_str(); |
|
|
|
|