| @@ -50,7 +50,7 @@ void WeightFormatPass::SetFmkType(converter::FmkType fmkType) { this->fmkType = | |||||
| // pre set tensor format | // pre set tensor format | ||||
| // non quant, filterFormat: | // non quant, filterFormat: | ||||
| // conv deconv depth dedepth | // conv deconv depth dedepth | ||||
| // caffe K(C/g)HW C(K/g)HW / / // todo with deconvOp | |||||
| // caffe K(C/g)HW C(K/g)HW / / | |||||
| // tf HWCK HWKC HWCK HWKC | // tf HWCK HWKC HWCK HWKC | ||||
| // onnx K(C/g)HW C(K/g)HW / / | // onnx K(C/g)HW C(K/g)HW / / | ||||
| @@ -78,7 +78,8 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { | |||||
| if (fmkType == converter::FmkType_CAFFE) { | if (fmkType == converter::FmkType_CAFFE) { | ||||
| switch (node->quantType) { | switch (node->quantType) { | ||||
| case QuantType_QUANT_NONE: { | case QuantType_QUANT_NONE: { | ||||
| if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { | |||||
| if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D || | |||||
| opType == schema::PrimitiveType_DeConv2D) { | |||||
| weightTensor->format = schema::Format_KCHW; | weightTensor->format = schema::Format_KCHW; | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Invalid opType: " << schema::EnumNamePrimitiveType(opType) | MS_LOG(ERROR) << "Invalid opType: " << schema::EnumNamePrimitiveType(opType) | ||||
| @@ -227,7 +228,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { | |||||
| STATUS status; | STATUS status; | ||||
| if (opType == schema::PrimitiveType_Conv2D) { // weight should be HWCK | if (opType == schema::PrimitiveType_Conv2D) { // weight should be HWCK | ||||
| if (weightTensor->format == schema::Format_KCHW) { // from caffe | if (weightTensor->format == schema::Format_KCHW) { // from caffe | ||||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||||
| MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format | MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format | ||||
| << weightTensor->dataType; | << weightTensor->dataType; | ||||
| status = TransFilterFormat<int8_t>(weightTensor.get(), kKCHW2HWCK); | status = TransFilterFormat<int8_t>(weightTensor.get(), kKCHW2HWCK); | ||||
| @@ -237,7 +238,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { | |||||
| status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK); | status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK); | ||||
| } | } | ||||
| } else if (weightTensor->format == schema::Format_KHWC) { // from onnx | } else if (weightTensor->format == schema::Format_KHWC) { // from onnx | ||||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||||
| status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK); | status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK); | ||||
| } else { | } else { | ||||
| status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK); | status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK); | ||||
| @@ -259,7 +260,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { | |||||
| } | } | ||||
| } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be HWCK | } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be HWCK | ||||
| if (weightTensor->format == schema::Format_CKHW) { // from caffe | if (weightTensor->format == schema::Format_CKHW) { // from caffe | ||||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||||
| MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex, weightTensor->format, | MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex, weightTensor->format, | ||||
| weightTensor->dataType; | weightTensor->dataType; | ||||
| status = TransFilterFormat<uint8_t>(weightTensor.get(), kCKHW2HWCK); | status = TransFilterFormat<uint8_t>(weightTensor.get(), kCKHW2HWCK); | ||||
| @@ -272,13 +273,13 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { | |||||
| } else if (weightTensor->format == schema::Format_HWCK) { // from tf | } else if (weightTensor->format == schema::Format_HWCK) { // from tf | ||||
| return 0; | return 0; | ||||
| } else if (weightTensor->format == schema::Format_CHWK) { // from onnx | } else if (weightTensor->format == schema::Format_CHWK) { // from onnx | ||||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||||
| status = TransFilterFormat<uint8_t>(weightTensor.get(), kCHWK2HWCK); | status = TransFilterFormat<uint8_t>(weightTensor.get(), kCHWK2HWCK); | ||||
| } else { | } else { | ||||
| status = TransFilterFormat<float>(weightTensor.get(), kCHWK2HWCK); | status = TransFilterFormat<float>(weightTensor.get(), kCHWK2HWCK); | ||||
| } | } | ||||
| } else if (weightTensor->format == schema::Format_KCHW) { | } else if (weightTensor->format == schema::Format_KCHW) { | ||||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||||
| status = TransFilterFormat<uint8_t>(weightTensor.get(), kKCHW2HWCK); | status = TransFilterFormat<uint8_t>(weightTensor.get(), kKCHW2HWCK); | ||||
| } else { | } else { | ||||
| status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK); | status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK); | ||||
| @@ -365,7 +366,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { | |||||
| // todo(00445839): consider varible weight condition | // todo(00445839): consider varible weight condition | ||||
| } | } | ||||
| } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC | } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC | ||||
| if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms | |||||
| if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms | |||||
| status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC); | status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC); | ||||
| } else if (weightTensor->format == schema::Format_CHWK) { // from tf | } else if (weightTensor->format == schema::Format_CHWK) { // from tf | ||||
| status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC); | status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC); | ||||