|
|
|
@@ -50,7 +50,7 @@ void WeightFormatPass::SetFmkType(converter::FmkType fmkType) { this->fmkType = |
|
|
|
// pre set tensor format |
|
|
|
// non quant, filterFormat: |
|
|
|
// conv deconv depth dedepth |
|
|
|
// caffe K(C/g)HW C(K/g)HW / / // todo with deconvOp |
|
|
|
// caffe K(C/g)HW C(K/g)HW / / |
|
|
|
// tf HWCK HWKC HWCK HWKC |
|
|
|
// onnx K(C/g)HW C(K/g)HW / / |
|
|
|
|
|
|
|
@@ -78,7 +78,8 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { |
|
|
|
if (fmkType == converter::FmkType_CAFFE) { |
|
|
|
switch (node->quantType) { |
|
|
|
case QuantType_QUANT_NONE: { |
|
|
|
if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { |
|
|
|
if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D || |
|
|
|
opType == schema::PrimitiveType_DeConv2D) { |
|
|
|
weightTensor->format = schema::Format_KCHW; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Invalid opType: " << schema::EnumNamePrimitiveType(opType) |
|
|
|
@@ -227,7 +228,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { |
|
|
|
STATUS status; |
|
|
|
if (opType == schema::PrimitiveType_Conv2D) { // weight should be HWCK |
|
|
|
if (weightTensor->format == schema::Format_KCHW) { // from caffe |
|
|
|
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { |
|
|
|
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { |
|
|
|
MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format |
|
|
|
<< weightTensor->dataType; |
|
|
|
status = TransFilterFormat<int8_t>(weightTensor.get(), kKCHW2HWCK); |
|
|
|
@@ -237,7 +238,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { |
|
|
|
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK); |
|
|
|
} |
|
|
|
} else if (weightTensor->format == schema::Format_KHWC) { // from onnx |
|
|
|
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { |
|
|
|
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { |
|
|
|
status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK); |
|
|
|
} else { |
|
|
|
status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK); |
|
|
|
@@ -259,7 +260,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { |
|
|
|
} |
|
|
|
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be HWCK |
|
|
|
if (weightTensor->format == schema::Format_CKHW) { // from caffe |
|
|
|
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { |
|
|
|
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { |
|
|
|
MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex, weightTensor->format, |
|
|
|
weightTensor->dataType; |
|
|
|
status = TransFilterFormat<uint8_t>(weightTensor.get(), kCKHW2HWCK); |
|
|
|
@@ -272,13 +273,13 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { |
|
|
|
} else if (weightTensor->format == schema::Format_HWCK) { // from tf |
|
|
|
return 0; |
|
|
|
} else if (weightTensor->format == schema::Format_CHWK) { // from onnx |
|
|
|
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { |
|
|
|
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { |
|
|
|
status = TransFilterFormat<uint8_t>(weightTensor.get(), kCHWK2HWCK); |
|
|
|
} else { |
|
|
|
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2HWCK); |
|
|
|
} |
|
|
|
} else if (weightTensor->format == schema::Format_KCHW) { |
|
|
|
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { |
|
|
|
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { |
|
|
|
status = TransFilterFormat<uint8_t>(weightTensor.get(), kKCHW2HWCK); |
|
|
|
} else { |
|
|
|
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK); |
|
|
|
@@ -365,7 +366,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { |
|
|
|
// todo(00445839): consider varible weight condition |
|
|
|
} |
|
|
|
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC |
|
|
|
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms |
|
|
|
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms |
|
|
|
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC); |
|
|
|
} else if (weightTensor->format == schema::Format_CHWK) { // from tf |
|
|
|
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC); |
|
|
|
|