|
|
|
@@ -79,7 +79,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { |
|
|
|
switch (node->quantType) { |
|
|
|
case QuantType_QUANT_NONE: { |
|
|
|
if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D || |
|
|
|
opType == schema::PrimitiveType_DeConv2D) { |
|
|
|
opType == schema::PrimitiveType_DeConv2D || opType == schema::PrimitiveType_DeDepthwiseConv2D) { |
|
|
|
weightTensor->format = schema::Format_KCHW; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Invalid opType: " << schema::EnumNamePrimitiveType(opType) |
|
|
|
@@ -240,11 +240,11 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { |
|
|
|
} |
|
|
|
} else if (weightTensor->format == schema::Format_KHWC) { // from onnx |
|
|
|
return RET_OK; |
|
|
|
// if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { |
|
|
|
// status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK); |
|
|
|
// } else { |
|
|
|
// status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK); |
|
|
|
// } |
|
|
|
// if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { |
|
|
|
// status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK); |
|
|
|
// } else { |
|
|
|
// status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK); |
|
|
|
// } |
|
|
|
} else if (weightTensor->format == schema::Format_HWCK) { // from tf |
|
|
|
return 0; |
|
|
|
} else { |
|
|
|
@@ -275,7 +275,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { |
|
|
|
} else if (weightTensor->format == schema::Format_HWCK) { // from tf |
|
|
|
return 0; |
|
|
|
} else if (weightTensor->format == schema::Format_CHWK) { // from onnx |
|
|
|
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { |
|
|
|
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { |
|
|
|
status = TransFilterFormat<int8_t>(weightTensor.get(), kCHWK2KHWC); |
|
|
|
} else { |
|
|
|
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2HWCK); |
|
|
|
@@ -383,9 +383,11 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { |
|
|
|
MS_LOG(WARNING) << "TransFilter HWKCToKCHW failed, node : " << node->name.c_str(); |
|
|
|
// todo(00445839): consider varible weight condition |
|
|
|
} |
|
|
|
} else if (opType == schema::PrimitiveType_DeDepthwiseConv2D) { // weight should be CKHW |
|
|
|
if (weightTensor->format == schema::Format_CKHW) { // from caffe |
|
|
|
} else if (opType == schema::PrimitiveType_DeDepthwiseConv2D) { // weight should be KHWC |
|
|
|
if (weightTensor->format == schema::Format_KHWC) { |
|
|
|
return 0; |
|
|
|
} else if (weightTensor->format == schema::Format_KCHW) { // from caffe |
|
|
|
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC); |
|
|
|
} else if (weightTensor->format == schema::Format_HWKC) { // from tf or onnx |
|
|
|
status = TransFilterFormat<float>(weightTensor.get(), kHWKC2CKHW); |
|
|
|
} else { |
|
|
|
@@ -393,7 +395,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { |
|
|
|
return -1; |
|
|
|
} |
|
|
|
if (status == 0) { |
|
|
|
node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC; |
|
|
|
node->primitive->value.AsDeDepthwiseConv2D()->format = schema::Format_NHWC; |
|
|
|
weightTensor->format = schema::Format_CKHW; |
|
|
|
} else { |
|
|
|
MS_LOG(WARNING) << "TransFilter HWKCToCKHW failed, node : " << node->name.c_str(); |
|
|
|
|