diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc
index 5e20788188..65e09881e5 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc
@@ -231,9 +231,23 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                                const mindspore::lite::PrimitiveC *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);
+  auto *weight_tensor = inputs.at(kWeightIndex);
+  auto *restore_data = weight_tensor->MutableData();
+  if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
+    if (dequant_weight == nullptr) {
+      MS_LOG(ERROR) << "dequant data is nullptr.";
+      return nullptr;
+    }
+    weight_tensor->SetData(dequant_weight);
+  }
   auto kernel = new (std::nothrow) kernel::DeConvolutionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
+    if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+      weight_tensor->FreeData();
+      weight_tensor->SetData(restore_data);
+    }
     return nullptr;
   }
   auto ret = kernel->Init();
@@ -241,8 +255,18 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
     delete kernel;
     MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
                   << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
+    if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+      weight_tensor->FreeData();
+      weight_tensor->SetData(restore_data);
+    }
     return nullptr;
   }
+
+  if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    weight_tensor->FreeData();
+    weight_tensor->SetData(restore_data);
+  }
+
   return kernel;
 }
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc
index 79814f2b2e..cb0b7bf6ae 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc
@@ -199,10 +199,24 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                                  const mindspore::lite::PrimitiveC *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D);
+  auto *weight_tensor = inputs.at(kWeightIndex);
+  auto *restore_data = weight_tensor->MutableData();
+  if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
+    if (dequant_weight == nullptr) {
+      MS_LOG(ERROR) << "dequant data is nullptr.";
+      return nullptr;
+    }
+    weight_tensor->SetData(dequant_weight);
+  }
   auto kernel =
       new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
+    if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+      weight_tensor->FreeData();
+      weight_tensor->SetData(restore_data);
+    }
     return nullptr;
   }
   auto ret = kernel->Init();
@@ -210,8 +224,16 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
     delete kernel;
     MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
                   << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
+    if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+      weight_tensor->FreeData();
+      weight_tensor->SetData(restore_data);
+    }
     return nullptr;
   }
+  if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    weight_tensor->FreeData();
+    weight_tensor->SetData(restore_data);
+  }
   return kernel;
 }
 
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.cc
index ee26865341..0cb13c5cd4 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.cc
@@ -53,16 +53,19 @@ STATUS WeightFormatTransformPass::QuantDataFormatTrans(MetaGraphT *graph) {
     MS_ASSERT(node != nullptr);
     MS_ASSERT(node->primitive != nullptr);
     auto opType = node->primitive->value.type;
-    if (opType != PrimitiveType_Conv2D && opType != PrimitiveType_DepthwiseConv2D) {
+    if (opType != PrimitiveType_Conv2D && opType != PrimitiveType_DepthwiseConv2D &&
+        opType != PrimitiveType_DeConv2D && opType != PrimitiveType_DeDepthwiseConv2D) {
       continue;
     }
     MS_ASSERT(node->inputIndex.size() >= 2);
     auto weightIndex = node->inputIndex.at(1);
     MS_ASSERT(subGraph->allTensors.size() > weightIndex);
     auto &weightTensor = graph->allTensors[weightIndex];
-    MS_ASSERT(weightTensor->dataType == DataType_DT_UINT8 || weightTensor->dataType == DataType_DT_FLOAT);
+    MS_ASSERT(weightTensor->dataType == DataType_DT_UINT8 || weightTensor->dataType == DataType_DT_FLOAT ||
+              weightTensor->dataType == DataType_DT_INT8);
     STATUS status;
-    if (opType == PrimitiveType_Conv2D || opType == PrimitiveType_DepthwiseConv2D) {  // weight should be HWCK
+    if (opType == PrimitiveType_Conv2D || opType == PrimitiveType_DepthwiseConv2D ||
+        opType == PrimitiveType_DeConv2D || opType == PrimitiveType_DeDepthwiseConv2D) {  // weight should be HWCK
       Format curDstFormat;
       if (this->dstFormat == Format_NUM_OF_FORMAT) {
         curDstFormat = Format_KHWC;
diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc
index bd41b78072..5c1eac363a 100644
--- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc
+++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc
@@ -80,7 +80,7 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co
     return nullptr;
   }
 
-  status = ParseLayer(proto, weight, &tensorCache, metaGraph.get());
+  status = ParseLayer(proto, weight, &tensorCache, metaGraph.get(), quantType);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "ParseLayer failed " << status;
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@@ -177,7 +177,8 @@ STATUS CaffeModelParser::SetGraphTensorIndex(const caffe::NetParameter &proto, T
 }
 
 STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caffe::NetParameter &weight,
-                                    TensorCache *tensorCache, schema::MetaGraphT *subGraphDef) {
+                                    TensorCache *tensorCache, schema::MetaGraphT *subGraphDef,
+                                    const QuantType &quantType) {
   for (int i = 0; i < proto.layer_size(); i++) {
     auto layer = proto.layer(i);
 
@@ -214,7 +215,7 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caff
     std::unique_ptr<schema::CNodeT> op = std::make_unique<schema::CNodeT>();
     op->name = layer.name();
-
+    op->quantType = quantType;
     if (layer.type() == "Split") {
       for (int j = 0; j < layer.top_size(); ++j) {
         splitLayer.emplace(layer.top(j), layer.bottom(0));
diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h
index 4b69a363bb..f15d291799 100644
--- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h
+++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h
@@ -50,7 +50,7 @@ class CaffeModelParser : public ModelParser {
                          schema::MetaGraphT *subGraphDef);
 
   STATUS ParseLayer(const caffe::NetParameter &proto, const caffe::NetParameter &weight, TensorCache *tensorCache,
-                    schema::MetaGraphT *subGraphDef);
+                    schema::MetaGraphT *subGraphDef, const QuantType &quantType);
 
   STATUS GetModelInput(const caffe::NetParameter &proto, TensorCache *tensorCache);
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
index 4680d387f2..7f22dd07a4 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
@@ -247,9 +247,10 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
 
 STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                                              schema::CNodeT *dst_op, schema::TensorT *dst_tensor,
-                                             TensorCache *tensor_cache) {
+                                             TensorCache *tensor_cache, const QuantType &quantType) {
   // change op_type() to name(), that is unique
   dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0);
+  dst_op->quantType = quantType;
   // dst_op->fmkType = FmkType_ONNX;
   MS_LOG(DEBUG) << "onnx op name " << onnx_node.op_type() << ", dst op name: " << dst_op->name << ", input size "
                 << onnx_node.input_size();
@@ -520,7 +521,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
     std::unique_ptr<schema::CNodeT> dst_op = std::make_unique<schema::CNodeT>();
     std::unique_ptr<schema::TensorT> dst_tensor = std::make_unique<schema::TensorT>();
 
-    status = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache);
+    status = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed";
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h
index a53229e949..7b7b952a8c 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h
@@ -61,7 +61,8 @@ class OnnxModelParser : public ModelParser {
                               TensorCache *tensor_cache, int *index);
 
   STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
-                              schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache);
+                              schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache,
+                              const QuantType &quantType);
 
   void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                          schema::MetaGraphT *graph, TensorCache *tensor_cache);
diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
index 88c1594d58..6a0ea9b47b 100644
--- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc
+++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
@@ -32,22 +32,24 @@ using std::vector;
 namespace mindspore {
 namespace lite {
 namespace quant {
-const std::array<std::string, 4> QuantStrategy::mConvTypes = {
-  {"Conv2D", "DeConv2D", "DepthwiseConv2D", "DeDepthwiseConv2D"}};
-const std::array<std::string, 4> QuantStrategy::mMulTypes = {{"Mul", "MatMul", "BatchMatMul", "FullConnection"}};
-
+const std::vector<schema::PrimitiveType> QuantStrategy::conv_types = {
+  schema::PrimitiveType_DeConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
+  schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D};
+const std::vector<schema::PrimitiveType> QuantStrategy::mul_types = {
+  schema::PrimitiveType_Mul, schema::PrimitiveType_MatMul, schema::PrimitiveType_FullConnection};
 QuantStrategy::QuantStrategy(size_t weightSize, size_t convWeightQuantChannelThreshold)
     : mWeightSize(weightSize), mConvWeightQuantChannelThreshold(convWeightQuantChannelThreshold) {}
 
 bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const {
-  size_t i = 0;
-  for (i = 0; i < mConvTypes.size(); i++) {
-    if (node->fullname_with_scope().find(mConvTypes[i]) == 0) {
-      break;
-    }
+  auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(node->input(0));
+  if (primitive_c == nullptr) {
+    MS_LOG(ERROR) << "primitive_c is nullptr";
+    return false;
   }
-
-  if ((i == mConvTypes.size()) || (node->size() < 3)) {
+  if (!IsContain(conv_types, (schema::PrimitiveType)primitive_c->Type())) {
+    return false;
+  }
+  if (node->size() < 3) {
     return false;
   }
 
@@ -107,13 +109,13 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
 }
 
 bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const {
-  size_t i = 0;
-  for (i = 0; i < mMulTypes.size(); i++) {
-    if (node->fullname_with_scope().find(mMulTypes[i]) == 0) {
-      break;
-    }
+  auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(node->input(0));
+  if (primitive_c == nullptr) {
+    MS_LOG(ERROR) << "primitive_c is nullptr";
+    return false;
   }
-  if (i == mMulTypes.size()) {
+
+  if (!IsContain(mul_types, (schema::PrimitiveType)primitive_c->Type())) {
     return false;
   }
 
diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h
index e64e40bb9d..1c90eb7b03 100644
--- a/mindspore/lite/tools/converter/quantizer/quantize_util.h
+++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h
@@ -57,9 +57,8 @@ class QuantStrategy {
  private:
   size_t mWeightSize;
   size_t mConvWeightQuantChannelThreshold;
-
-  static const std::array<std::string, 4> mConvTypes;
-  static const std::array<std::string, 4> mMulTypes;
+  static const std::vector<schema::PrimitiveType> conv_types;
+  static const std::vector<schema::PrimitiveType> mul_types;
 };
 
 STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange, int quant_max,
diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
index de07c4ef2b..f1861be3a0 100644
--- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
+++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
@@ -69,13 +69,9 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
 
     std::vector<schema::QuantParamT> quant_params;
     primitive_c->AddInputQuantParam(quant_params);
-
-    auto op_type = (schema::PrimitiveType)primitive_c->Type();
-    bool depthwise = op_type == schema::PrimitiveType_DepthwiseConv2D ? true : false;
-
     auto status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant,
-                              quant_max, quant_min, bitNum, true, depthwise);
+                              quant_max, quant_min, bitNum, true, false);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "QuantFilter failed : " << status;
       return status;
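
Note on the kernel-creator changes above: the restore sequence (FreeData() followed by SetData(restore_data)) is repeated on all three exit paths of each creator. A minimal sketch of how that cleanup could be centralized in a scope guard; the WeightRestoreGuard class and the pared-down Tensor interface are illustrative assumptions, not part of this patch:

#include <cstdlib>

// Pared-down stand-in for lite::Tensor: only the three calls the patch uses.
struct Tensor {
  void *data = nullptr;
  void *MutableData() { return data; }
  void SetData(void *d) { data = d; }
  void FreeData() { std::free(data); data = nullptr; }
};

// Frees the dequantized buffer and puts the original quantized data back when
// control leaves the creator, covering every early return with one declaration.
class WeightRestoreGuard {
 public:
  // Pass tensor = nullptr when no dequantization happened; the guard is inert.
  WeightRestoreGuard(Tensor *tensor, void *restore_data)
      : tensor_(tensor), restore_data_(restore_data) {}
  ~WeightRestoreGuard() {
    if (tensor_ == nullptr) return;   // nothing was dequantized
    tensor_->FreeData();              // release the dequantized buffer
    tensor_->SetData(restore_data_);  // restore the original quantized data
  }
  WeightRestoreGuard(const WeightRestoreGuard &) = delete;
  WeightRestoreGuard &operator=(const WeightRestoreGuard &) = delete;

 private:
  Tensor *tensor_;
  void *restore_data_;
};

With such a guard declared once after a successful DequantWeight(), the three hand-written restore blocks in each creator would collapse into a single line.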
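
The quantize_util.cc hunks swap name-prefix matching (fullname_with_scope().find(...) == 0) for a membership test on the node's primitive type. A self-contained sketch of the idea; the enum is trimmed down here and IsContain is re-implemented with std::find as a stand-in for the utility the real code calls:

#include <algorithm>
#include <vector>

// Trimmed stand-in for schema::PrimitiveType.
enum PrimitiveType {
  PrimitiveType_Conv2D,
  PrimitiveType_DeConv2D,
  PrimitiveType_DepthwiseConv2D,
  PrimitiveType_DeDepthwiseConv2D,
};

// Stand-in for the IsContain helper the patch relies on.
template <typename T>
bool IsContain(const std::vector<T> &vec, const T &element) {
  return std::find(vec.begin(), vec.end(), element) != vec.end();
}

bool CanConvOpQuantized(PrimitiveType type) {
  static const std::vector<PrimitiveType> conv_types = {
      PrimitiveType_DeConv2D, PrimitiveType_DeDepthwiseConv2D,
      PrimitiveType_Conv2D, PrimitiveType_DepthwiseConv2D};
  // A renamed node ("conv1", "backbone.conv2d_3", ...) still matches, because
  // the decision no longer depends on how the node happens to be named.
  return IsContain(conv_types, type);
}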
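
Finally, the weight_quantizer.cc hunk drops the depthwise special case and always passes false as QuantFilter's last argument, so deconvolution weights take the same per-channel path as ordinary convolutions. As background for what per-channel weight quantization computes, here is a simplified sketch under the assumption that the channel axis is the first dimension; the real QuantFilter additionally handles bit width, narrow-range, and quant-param bookkeeping:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Quantizes each output channel independently: per-channel min/max define an
// affine mapping of that channel's floats onto [quant_min, quant_max].
std::vector<int8_t> QuantizePerChannel(const std::vector<float> &weights,
                                       size_t channels, int quant_min, int quant_max) {
  const size_t per_channel = weights.size() / channels;
  std::vector<int8_t> out(weights.size());
  for (size_t c = 0; c < channels; ++c) {
    const float *w = weights.data() + c * per_channel;
    float lo = *std::min_element(w, w + per_channel);
    float hi = *std::max_element(w, w + per_channel);
    float scale = (hi - lo) / static_cast<float>(quant_max - quant_min);
    if (scale == 0.f) scale = 1.f;  // constant channel: avoid division by zero
    int zero_point = quant_min - static_cast<int>(std::round(lo / scale));
    for (size_t i = 0; i < per_channel; ++i) {
      int q = static_cast<int>(std::round(w[i] / scale)) + zero_point;
      out[c * per_channel + i] = static_cast<int8_t>(std::clamp(q, quant_min, quant_max));
    }
  }
  return out;
}

Per-channel scales matter most for depthwise and transposed convolutions, where the dynamic range of individual filters can differ by orders of magnitude from the tensor-wide range.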