|
|
|
@@ -153,6 +153,8 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { |
|
|
|
weightTensor->format = schema::Format_KHWC; |
|
|
|
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) { |
|
|
|
weightTensor->format = schema::Format_CHWK; |
|
|
|
} else if (opType == schema::PrimitiveType_DeConv2D) { |
|
|
|
weightTensor->format = schema::Format_KHWC; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "unsupport format"; |
|
|
|
return -1; |
|
|
|
@@ -356,18 +358,18 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { |
|
|
|
MS_LOG(WARNING) << "TransFilter HWCKToCKHW failed, node : " << node->name.c_str(); |
|
|
|
// todo(00445839): consider varible weight condition |
|
|
|
} |
|
|
|
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KCHW |
|
|
|
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx |
|
|
|
return 0; |
|
|
|
} else if (weightTensor->format == schema::Format_HWKC) { // from tf |
|
|
|
status = TransFilterFormat<float>(weightTensor.get(), kHWKC2KCHW); |
|
|
|
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC |
|
|
|
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms |
|
|
|
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC); |
|
|
|
} else if (weightTensor->format == schema::Format_KHWC) { // from tf |
|
|
|
status = RET_OK; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; |
|
|
|
return -1; |
|
|
|
} |
|
|
|
if (status == 0) { |
|
|
|
node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NCHW; |
|
|
|
weightTensor->format = schema::Format_KCHW; |
|
|
|
node->primitive->value.AsDeConv2D()->format = schema::Format_NCHW; |
|
|
|
weightTensor->format = schema::Format_KHWC; |
|
|
|
} else { |
|
|
|
MS_LOG(WARNING) << "TransFilter HWKCToKCHW failed, node : " << node->name.c_str(); |
|
|
|
// todo(00445839): consider varible weight condition |
|
|
|
|