Merge pull request !4668 from xutianchun/quant_ch (tag: v0.7.0-beta)
@@ -47,7 +47,6 @@ class PrimitiveTValue : public Value {
     }
   }
   void SetInputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param) {
     this->input_quant_param_ = input_quant_param;
   }
@@ -56,6 +55,10 @@ class PrimitiveTValue : public Value {
     this->output_quant_param_ = output_quant_param;
   }
+  void ClearInputOutputQuantParam() {
+    input_quant_param_.clear();
+    output_quant_param_.clear();
+  }
   void AddInputQuantParam(std::vector<schema::QuantParamT> quant_param) {
     this->input_quant_param_.emplace_back(quant_param);
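`AddInputQuantParam` appends one parameter set per input tensor, so quantizing a node twice would accumulate stale entries; the new `ClearInputOutputQuantParam` resets the holder before `QuantNode()` repopulates it (see the call added further down in `post_training_quantizer.cc`). A minimal standalone sketch of that append/clear contract, with an illustrative stand-in for `schema::QuantParamT`:

```cpp
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

// Illustrative stand-in for schema::QuantParamT.
struct QuantParam {
  double scale = 1.0;
  int zero_point = 0;
};

class QuantParamHolder {
 public:
  // Appends one parameter set per input tensor, mirroring AddInputQuantParam.
  void AddInputQuantParam(std::vector<QuantParam> p) { input_params_.emplace_back(std::move(p)); }
  // Drops accumulated state so a re-run starts clean, mirroring ClearInputOutputQuantParam.
  void ClearInputOutputQuantParam() { input_params_.clear(); }
  std::size_t InputCount() const { return input_params_.size(); }

 private:
  std::vector<std::vector<QuantParam>> input_params_;
};

int main() {
  QuantParamHolder holder;
  holder.AddInputQuantParam({{0.5, 0}});    // first quantization pass
  holder.ClearInputOutputQuantParam();      // re-quantize: discard stale entries first
  holder.AddInputQuantParam({{0.25, 0}});
  std::cout << holder.InputCount() << "\n";  // prints 1, not 2
}
```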
@@ -25,16 +25,6 @@
 #include "ir/dtype/type_id.h"
 namespace mindspore {
-struct AnfQuantParam {
-  double scale;
-  int32_t zeroPoint;
-  double min;
-  double max;
-  bool narrowRange;
-  bool inited;
-  int32_t numBits;
-  AnfQuantParam() : scale(1.0), zeroPoint(0), min(0.0), max(0.0), narrowRange(false), numBits(8), inited(false) {}
-};
 class ParamValueLite : public Value {
  public:
   ParamValueLite() : tensor_addr_(nullptr), tensor_size_(0) {}
@@ -59,10 +49,6 @@ class ParamValueLite : public Value {
     }
     return size;
   }
-  std::vector<std::unique_ptr<AnfQuantParam>> &quant_param() { return quant_params_; }
-  void set_quant_param(std::unique_ptr<AnfQuantParam> &quant_param) {
-    quant_params_.emplace_back(std::move(quant_param));
-  }
   bool operator==(const Value &other) const override {
     return this == &other;
@@ -73,7 +59,6 @@ class ParamValueLite : public Value {
   size_t tensor_size_;
   std::vector<int> tensor_shape_;
   TypeId type_id_;
-  std::vector<std::unique_ptr<AnfQuantParam>> quant_params_;
 };
 using ParamValueLitePtr = std::shared_ptr<ParamValueLite>;
@@ -159,16 +159,6 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
                primitive->GetPrimitiveT()->value.AsQuantDTypeCast()->dstT == kNumberTypeFloat32)) {
       tensor_output->dataType = kNumberTypeInt8;
     }
-    // // TensorType
-    // valuePtr = primitive->GetAttr(kInputTensorDataType);
-    // if (valuePtr != nullptr) {
-    //   MS_LOG(INFO) << "node: " << node->name << " input tensor data
-    //   type: " << GetValue<int>(valuePtr); for (auto input :
-    //   node->inputIndex) {
-    //     auto tensor = subGraph->allTensors[input].get();
-    //     tensor->dataType = kNumberTypeUInt8;
-    //   }
-    // }
   }
   return RET_OK;
 }
@@ -295,18 +285,6 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> input_anod
     paramTensor->nodeType = schema::NodeType_ValueNode;
     paramTensor->data.resize(paramValue->tensor_size());
     memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size());
-    for (auto &ite : paramValue->quant_param()) {
-      auto quantPar = std::make_unique<schema::QuantParamT>();
-      quantPar->scale = ite->scale;
-      quantPar->zeroPoint = ite->zeroPoint;
-      quantPar->min = ite->zeroPoint;
-      quantPar->max = ite->max;
-      quantPar->narrowRange = ite->narrowRange;
-      quantPar->inited = ite->inited;
-      quantPar->numBits = ite->numBits;
-      paramTensor->quantParams.emplace_back(std::move(quantPar));
-      paramTensor->dataType = paramValue->tensor_type();
-    }
   }
   node_id_map_[paramNode->fullname_with_scope()] = meta_graphT->allTensors.size();
   output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
@@ -61,17 +61,17 @@ int AnfImporterFromMetaGraphT::ConverterConstTensor() {
       param_value->set_tensor_addr(tensor_data);
       param_value->set_tensor_size(size);
     }
-    if (!tensor->quantParams.empty()) {
-      std::unique_ptr<AnfQuantParam> quantParam = std::make_unique<AnfQuantParam>();
-      quantParam->scale = tensor->quantParams[0]->scale;
-      quantParam->zeroPoint = tensor->quantParams[0]->zeroPoint;
-      quantParam->min = tensor->quantParams[0]->min;
-      quantParam->max = tensor->quantParams[0]->max;
-      quantParam->narrowRange = tensor->quantParams[0]->narrowRange;
-      quantParam->numBits = tensor->quantParams[0]->numBits;
-      quantParam->inited = tensor->quantParams[0]->inited;
-      param_value->set_quant_param(quantParam);
-    }
+    // if (!tensor->quantParams.empty()) {
+    //   std::unique_ptr<AnfQuantParam> quantParam = std::make_unique<AnfQuantParam>();
+    //   quantParam->scale = tensor->quantParams[0]->scale;
+    //   quantParam->zeroPoint = tensor->quantParams[0]->zeroPoint;
+    //   quantParam->min = tensor->quantParams[0]->min;
+    //   quantParam->max = tensor->quantParams[0]->max;
+    //   quantParam->narrowRange = tensor->quantParams[0]->narrowRange;
+    //   quantParam->numBits = tensor->quantParams[0]->numBits;
+    //   quantParam->inited = tensor->quantParams[0]->inited;
+    //   param_value->set_quant_param(quantParam);
+    // }
     parameter->set_default_param(param_value);
     AddNode(i, parameter);
   }
@@ -31,7 +31,6 @@
 #include "tools/anf_exporter/anf_exporter.h"
 #include "tools/anf_importer/import_from_protobuf.h"
 #include "tools/converter/parser/onnx/onnx.pb.h"
-#include "tools/converter/quantizer/weight_quantizer.h"
 #include "tools/converter/quantizer/post_training_quantizer.h"
 #include "tools/converter/quantizer/quant_cast.h"
@@ -7,7 +7,6 @@ add_library(quantizer_mid OBJECT
     ${CMAKE_CURRENT_SOURCE_DIR}/calc_quant_param.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/aware_quantizer.cc
-    ${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc
@@ -510,7 +510,6 @@ STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct M
   quant_param.narrowRange = false;
   std::vector<schema::QuantParamT> quant_params = {quant_param};
   lite_primitive->AddInputQuantParam(quant_params);
-  // p->AddAttr("quant_input_dataType", MakeValue((int)DataType_DT_FLOAT));
   return RET_OK;
 }
@@ -528,51 +527,67 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct
   quant_param.narrowRange = false;
   std::vector<schema::QuantParamT> quant_params = {quant_param};
   lite_primitive->AddOutputQuantParam(quant_params);
-  // p->AddAttr("quant_output_dataType", MakeValue((int)DataType_DT_FLOAT));
   return RET_OK;
 }
-STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr node) {
+STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveTValue> primitiveT_value,
+                                            bool depthwise) {
   // const vector<int> dims = filter->dims;
   // perlayer
-  if (!node->isa<Parameter>()) {
+  if (!weight->isa<Parameter>()) {
     MS_LOG(ERROR) << "not a parameter";
     return RET_PARAM_INVALID;
   }
-  auto parameter = std::dynamic_pointer_cast<Parameter>(node);
+  auto parameter = std::dynamic_pointer_cast<Parameter>(weight);
   ParamValueLitePtr paramValue = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param());
-  auto status = QuantFilter(paramValue, QuantType_PostTraining, quant_max, quant_min, bit_num, per_channel_);
+  auto status = QuantFilter(paramValue, primitiveT_value, QuantType_PostTraining, quant_max, quant_min, bit_num,
+                            per_channel_, depthwise);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed: " << status;
     return status;
   }
+  // set dtype
+  auto abstractBase = parameter->abstract();
+  if (abstractBase == nullptr) {
+    MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << parameter->name();
+    return RET_ERROR;
+  }
+  if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
+    MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << parameter->name();
+    return RET_ERROR;
+  }
+  auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
+  abstractTensor->element()->set_type(TypeIdToType(kNumberTypeInt8));
   return RET_OK;
 }
-STATUS PostTrainingQuantizer::DoBiasQuant(std::shared_ptr<PrimitiveTValue> input, AnfNodePtr weight, AnfNodePtr bias) {
-  if (input == nullptr || weight == nullptr || bias == nullptr) {
+STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveTValue> primitiveT_value) {
+  if (primitiveT_value == nullptr || bias == nullptr) {
     MS_LOG(ERROR) << "null pointer!";
     return RET_NULL_PTR;
   }
-  ParameterPtr weightParameterPtr = std::dynamic_pointer_cast<Parameter>(weight);
-  auto default_param = weightParameterPtr->default_param();
-  auto weight_param = std::dynamic_pointer_cast<ParamValueLite>(default_param);
-  // std::vector<std::unique_ptr<mindspore::QuantParamT>> weight_quant_params = weight_param->get_quant_params();
-  ParameterPtr biasParameterPtr = std::dynamic_pointer_cast<Parameter>(bias);
-  auto bias_default_param = biasParameterPtr->default_param();
+  auto bias_parameter_ptr = std::dynamic_pointer_cast<Parameter>(bias);
+  auto bias_default_param = bias_parameter_ptr->default_param();
   auto bias_param = std::dynamic_pointer_cast<ParamValueLite>(bias_default_param);
+  auto active_weight_quant_params = primitiveT_value->GetInputQuantParams();
+  if (active_weight_quant_params.size() != 2) {
+    MS_LOG(ERROR) << "unexpected active_weight_quant_params size: " << active_weight_quant_params.size();
+    return RET_ERROR;
+  }
+  auto active_params = active_weight_quant_params[0];
+  auto weight_params = active_weight_quant_params[1];
   vector<double> input_scales;
   vector<double> filter_scales;
   vector<double> bias_scales;
-  auto quant_params = input->GetInputQuantParams();
-  size_t sizeX = quant_params.size();
+  size_t sizeX = active_params.size();
   for (size_t i = 0; i < sizeX; i++) {
-    input_scales.emplace_back(quant_params[i].front().scale);
+    input_scales.emplace_back(active_params[i].scale);
   }
-  size_t sizeY = weight_param->quant_param().size();
+  size_t sizeY = weight_params.size();
   if (sizeX != sizeY) {
     if (sizeX > 1 && sizeY > 1) {
       MS_LOG(ERROR) << "input and filter's scale count cannot match!";
@@ -580,8 +595,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(std::shared_ptr<PrimitiveTValue> input
     }
   }
   for (size_t i = 0; i < sizeY; i++) {
-    auto scale = weight_param->quant_param()[i]->scale;
-    filter_scales.push_back(scale);
+    filter_scales.emplace_back(weight_params[i].scale);
   }
   size_t size = std::max(sizeX, sizeY);
   for (size_t i = 0; i < size; i++) {
@@ -593,20 +607,22 @@ STATUS PostTrainingQuantizer::DoBiasQuant(std::shared_ptr<PrimitiveTValue> input
   size_t shape_size = bias_param->tensor_shape_size();
   // set bias quant param
-  bias_param->quant_param().clear();
+  vector<schema::QuantParamT> quant_params;
   for (size_t i = 0; i < bias_scales.size(); i++) {
-    std::unique_ptr<AnfQuantParam> param(new (std::nothrow) AnfQuantParam());
-    param->scale = bias_scales[i];
-    param->zeroPoint = 0;
-    bias_param->quant_param().emplace_back(std::move(param));
+    schema::QuantParamT quant_param;
+    quant_param.scale = bias_scales[i];
+    quant_param.zeroPoint = 0;
+    quant_param.inited = true;
+    quant_params.emplace_back(quant_param);
   }
+  primitiveT_value->AddInputQuantParam(quant_params);
   // quant bias data
   int32_t *quant_datas = new (std::nothrow) int32_t[shape_size];
   if (quant_datas == nullptr) {
     MS_LOG(ERROR) << "null pointer dereferencing.";
     return RET_NULL_PTR;
   }
-  float *raw_datas = reinterpret_cast<float *>(bias_param->tensor_addr());
+  float *raw_datas = static_cast<float *>(bias_param->tensor_addr());
   double bias_scale_tmp;
   for (size_t i = 0; i < shape_size; i++) {
     if (bias_scales.size() == 1) {
@@ -625,38 +641,21 @@ STATUS PostTrainingQuantizer::DoBiasQuant(std::shared_ptr<PrimitiveTValue> input
     return RET_ERROR;
   }
   delete[] quant_datas;
+  bias_param->set_tensor_type(kNumberTypeInt32);
+  // set dtype
+  auto abstractBase = bias_parameter_ptr->abstract();
+  if (abstractBase == nullptr) {
+    MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << bias_parameter_ptr->name();
+    return RET_ERROR;
+  }
+  if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
+    MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << bias_parameter_ptr->name();
+    return RET_ERROR;
+  }
+  auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
+  abstractTensor->element()->set_type(TypeIdToType(kNumberTypeInt32));
   return RET_OK;
 }
-// STATUS PostTrainingQuantizer::reformatConvWeight(GraphDefT *graph) {
-//   for (auto &subGraph : graphDefT->subgraphs) {
-//     for (auto iter = subGraph->nodes.begin(); iter != subGraph->nodes.end(); iter++) {
-//       OpDefT *node = (*iter).get();
-//       bool isConv = false;
-//       kTransFilterType tansType;
-//       if ((*node).attr.type == OpT_Conv2D) {
-//         tansType = kKCHW2HWCK;
-//         isConv = true;
-//       }
-//       else if ((*node).attr.type == OpT_DepthwiseConv2D) {
-//         tansType = kCKHW2HWCK;
-//         isConv = true;
-//       }
-//       if (isConv) {
-//         auto status = TransFilterFormat<uint8_t>(&(*subGraph.get()->allTensors.at(node->inputIndex[1])),
-//                                                  tansType);
-//         if (status != RET_OK) {
-//           return status;
-//         }
-//         TensorDefT *weight = subGraph->allTensors.at(node->inputIndex[1]).get();
-//         weight->format = Format_HWCK;
-//         PostBitPack(weight, bitNum);
-//       }
-//     }
-//   }
-// }
 STATUS PostTrainingQuantizer::QuantNode() {
   auto input_min_max = this->calibrator_->GetMinMax(this->calibrator_->GetInputDivergInfo());
   auto input_scale = this->calibrator_->GetResult(this->calibrator_->GetInputDivergInfo());
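A note on the bias path above: an int8 convolution accumulates `int8 * int8` products into `int32`, and the accumulator's effective scale is `input_scale * filter_scale`, so `DoBiasQuant` re-expresses the float bias on exactly that scale with zero point 0. A self-contained sketch of the rule (the helper name is illustrative, not the converter's API):

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Quantize a float bias to int32 on the accumulator scale
// S_bias = S_input * S_filter, zero point 0.
std::vector<int32_t> QuantizeBias(const std::vector<float> &bias, double input_scale, double filter_scale) {
  const double bias_scale = input_scale * filter_scale;
  std::vector<int32_t> out(bias.size());
  for (std::size_t i = 0; i < bias.size(); ++i) {
    out[i] = static_cast<int32_t>(std::round(bias[i] / bias_scale));
  }
  return out;
}

int main() {
  // input scale 0.02, per-layer filter scale 0.005 -> bias scale 1e-4
  auto q = QuantizeBias({0.25f, -0.1f}, 0.02, 0.005);
  std::printf("%d %d\n", q[0], q[1]);  // 2500 -1000
}
```

This is also why the function insists on exactly two input-quant-param entries on the primitive (activation first, weight second): both scales are needed to derive the bias scale.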
@@ -682,7 +681,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
       primitiveT_value->SetQuantType(schema::QuantType_QUANT_NONE);
       continue;
     }
-    auto input_vec = cnode->inputs();
+    primitiveT_value->ClearInputOutputQuantParam();
     auto op_name = cnode->fullname_with_scope();
     auto op_type = primitiveT_value->GetPrimitiveT()->value.type;
     MS_LOG(INFO) << "OpName: " << op_name;
@@ -711,11 +710,12 @@ STATUS PostTrainingQuantizer::QuantNode() {
       DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitiveT_value);
       // do weight quant
       auto weight = cnode->input(2);
-      DoWeightQuant(weight);
+      bool depthwise = op_type == PrimitiveType_DeDepthwiseConv2D;
+      DoWeightQuant(weight, primitiveT_value, depthwise);
       // do bias quant
       if (cnode->inputs().size() == 4) {
         auto bias = cnode->input(3);
-        DoBiasQuant(primitiveT_value, weight, bias);
+        DoBiasQuant(bias, primitiveT_value);
       }
     }
     // do output quant
@@ -65,8 +65,8 @@ class PostTrainingQuantizer : public Quantizer {
   STATUS DoQuantize(FuncGraphPtr funcGraph) override;
   size_t bit_num;
-  int quant_max{127};
-  int quant_min{-128};
+  int quant_max{INT8_MAX};
+  int quant_min{INT8_MIN};
  private:
   bool per_channel_;
@@ -96,9 +96,9 @@ class PostTrainingQuantizer : public Quantizer {
   STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveTValue>);
   STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveTValue>);
-  STATUS DoWeightQuant(AnfNodePtr node);
+  STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveTValue> primitiveT_value, bool depthwise);
-  STATUS DoBiasQuant(std::shared_ptr<PrimitiveTValue> input, AnfNodePtr weight, AnfNodePtr bias);
+  STATUS DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveTValue> primitiveT_value);
 };
 struct DivergInfo;
@@ -99,7 +99,9 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
     schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw,
     schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D,
     schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
-    schema::PrimitiveType_Concat, /*schema::PrimitiveType_SoftMax,*/ schema::PrimitiveType_Reshape,
+    schema::PrimitiveType_Concat, /*schema::PrimitiveType_SoftMax,*/
+    schema::PrimitiveType_Reshape, /*schema::PrimitiveType_FullConnection,*/
+    schema::PrimitiveType_MatMul,
     schema::PrimitiveType_Activation};
   return IsContain(uint8OpList, type);
 }
@@ -191,7 +193,7 @@ void CalFakeNode(const AnfNodePtr &inTensor) {
   // }
 }
-STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double mMin, double mMax, bool narrowRange,
+STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange,
                              int quant_max, int quant_min, int num_bits) {
   MS_ASSERT(quantParam != nullptr);
   if (mMin > 0.0f) {
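`CalQuantizationParams` derives the affine parameters from an observed float range: the range is first widened to include zero (so real 0.0 maps exactly onto an integer code), then `scale = (max - min) / (quant_max - quant_min)` and `zeroPoint = round(quant_min - min / scale)`. A worked sketch under those standard formulas (the struct is illustrative, not `schema::QuantParamT` itself):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

struct QuantParam {  // illustrative subset of schema::QuantParamT
  double scale;
  int zeroPoint;
};

QuantParam CalcQuantParams(double min, double max, int quant_min, int quant_max) {
  min = std::min(min, 0.0);  // range must contain zero so 0.0f quantizes exactly
  max = std::max(max, 0.0);
  QuantParam p;
  p.scale = (max - min) / (quant_max - quant_min);
  p.zeroPoint = static_cast<int>(std::round(quant_min - min / p.scale));
  return p;
}

int main() {
  // weights in [-0.8, 1.2] mapped onto int8 [-128, 127]
  auto p = CalcQuantParams(-0.8, 1.2, -128, 127);
  std::printf("scale=%f zeroPoint=%d\n", p.scale, p.zeroPoint);  // scale~=0.007843 zeroPoint=-26
}
```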
@@ -306,133 +308,178 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
   return RET_OK;
 }
-STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum,
-                   bool per_channel) {
-  auto dims = weightPtr->tensor_shape();
-  if (dims.size() != 4) {
-    MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer.";
-    per_channel = false;
-  } else {
-    uint32_t channels = dims[0];
-    if (channels == 0) {
-      MS_LOG(ERROR) << "channels is 0";
-      return RET_ERROR;
+STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveTValue> primitiveT_value, QuantType quantType,
+                   int quant_max, int quant_min, size_t bitNum, bool per_channel, bool depth_wise) {
+  auto dims = weight->tensor_shape();
+  if (per_channel) {
+    if (dims.size() != 4) {
+      MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer.";
+      per_channel = false;
+    } else {
+      uint32_t channels = dims[0];
+      if (channels == 0) {
+        MS_LOG(ERROR) << "channels is 0";
+        return RET_ERROR;
+      }
    }
  }
+  vector<schema::QuantParamT> quant_params;
+  size_t elem_count = weight->tensor_shape_size();
+  auto *raw_datas = static_cast<float *>(weight->tensor_addr());
+  if (raw_datas == nullptr) {
+    MS_LOG(ERROR) << "rawDatas is nullptr";
+    return RET_ERROR;
+  }
+  vector<int8_t> quant_datas(elem_count);
   if (per_channel) {
     // notice:
     // at now for tflite model, Conv2D's weight format is KHWC, so is DepthwiseConv2D
     // if TransWeightFormat is done before PostTraingingQuantization, the DepthwiseCon2D's weight is CHWK
-    size_t shapeSize = weightPtr->tensor_shape_size();
-    auto channels = dims[0];
-    size_t oneFilterSize = shapeSize / channels;
-    auto *rawDatas = reinterpret_cast<const float *>(weightPtr->tensor_addr());
-    if (rawDatas == nullptr) {
-      MS_LOG(ERROR) << "rawDatas is nullptr";
-      return RET_ERROR;
-    }
-    float min = FLT_MAX;
-    float max = -FLT_MAX;
-    weightPtr->quant_param().clear();
-    vector<int8_t> qDatas(shapeSize);
-    for (uint32_t i = 0; i < channels; i++) {
-      // find min and max
-      for (uint32_t j = 0; j < oneFilterSize; j++) {
-        auto index = j + i * channels;
-        if (index >= shapeSize) {
-          MS_LOG(ERROR) << "over flow!";
-          return RET_ERROR;
+    if (depth_wise) {
+      // channel at last
+      auto channels = dims[3];
+      if (channels == 0) {
+        MS_LOG(ERROR) << "channels is zero";
+        return RET_ERROR;
+      }
+      size_t one_filter_size = elem_count / channels;
+      for (uint32_t i = 0; i < channels; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+        // find min and max
+        for (uint32_t j = 0; j < one_filter_size; j++) {
+          auto index = i + j * channels;
+          if (index >= elem_count) {
+            MS_LOG(ERROR) << "over flow!";
+            return RET_ERROR;
+          }
+          min = std::min(min, raw_datas[index]);
+          max = std::max(max, raw_datas[index]);
+        }
+        schema::QuantParamT quant_param;
+        STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
+        if (status != RET_OK) {
+          MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
+          return status;
        }
-        min = std::min(min, rawDatas[index]);
-        max = std::max(max, rawDatas[index]);
+        quant_params.emplace_back(quant_param);
+        // do quantization
+        for (uint32_t j = 0; j < one_filter_size; j++) {
+          auto index = i + j * channels;
+          if (index >= elem_count) {
+            MS_LOG(ERROR) << "over flow!";
+            return RET_ERROR;
+          }
+          float raw_data = raw_datas[index];
+          auto quant_data = QuantizeData<int8_t>(raw_data, quant_param, quant_max, quant_min);
+          quant_datas[index] = quant_data;
+        }
+      }
+      auto ret = memcpy_s(const_cast<float *>(raw_datas), weight->tensor_size(), quant_datas.data(),
+                          elem_count * sizeof(int8_t));
+      if (ret != EOK) {
+        MS_LOG(ERROR) << "memcpy error: " << ret;
+        return RET_ERROR;
      }
-      std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam);
-      STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum);
-      if (status != RET_OK) {
-        MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
-        return status;
+      if (quantType == QuantType_WeightQuant) {
+        PostBitPack(const_cast<float *>(raw_datas), elem_count, bitNum);
      }
-      // do quantization
-      for (uint32_t j = 0; j < oneFilterSize; j++) {
-        auto index = j + i * channels;
-        if (index >= shapeSize) {
-          MS_LOG(ERROR) << "over flow!";
-          return RET_ERROR;
+      weight->set_tensor_size(elem_count * sizeof(int8_t));
+    } else {
+      // channel at first
+      auto channels = dims[0];
+      if (channels == 0) {
+        MS_LOG(ERROR) << "channels is zero";
+        return RET_ERROR;
+      }
+      size_t one_filter_size = elem_count / channels;
+      for (uint32_t i = 0; i < channels; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+        // find min and max
+        for (uint32_t j = 0; j < one_filter_size; j++) {
+          auto index = j + i * one_filter_size;
+          if (index >= elem_count) {
+            MS_LOG(ERROR) << "over flow!";
+            return RET_ERROR;
+          }
+          min = std::min(min, raw_datas[index]);
+          max = std::max(max, raw_datas[index]);
+        }
+        schema::QuantParamT quant_param;
+        STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
+        if (status != RET_OK) {
+          MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
+          return status;
+        }
+        quant_params.emplace_back(quant_param);
+        // do quantization
+        for (uint32_t j = 0; j < one_filter_size; j++) {
+          auto index = j + i * one_filter_size;
+          if (index >= elem_count) {
+            MS_LOG(ERROR) << "over flow!";
+            return RET_ERROR;
+          }
+          float raw_data = raw_datas[index];
+          auto quant_data = QuantizeData<int8_t>(raw_data, quant_param, quant_max, quant_min);
+          quant_datas[index] = quant_data;
        }
-        float rawData = rawDatas[index];
-        auto qData = QuantizeData<int8_t>(rawData, quantParam.get(), quant_max, quant_min);
-        qDatas[index] = qData;
      }
-      weightPtr->set_quant_param(quantParam);
-    }
-    auto ret =
-      memcpy_s(const_cast<float *>(rawDatas), weightPtr->tensor_size(), qDatas.data(), shapeSize * sizeof(int8_t));
-    if (ret != EOK) {
-      MS_LOG(ERROR) << "memcpy error: " << ret;
-      return RET_ERROR;
-    }
-    if (quantType == QuantType_WeightQuant) {
-      PostBitPack(const_cast<float *>(rawDatas), shapeSize, bitNum);
+      auto ret =
+        memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t));
+      if (ret != EOK) {
+        MS_LOG(ERROR) << "memcpy error: " << ret;
+        return RET_ERROR;
+      }
+      if (quantType == QuantType_WeightQuant) {
+        PostBitPack(const_cast<float *>(raw_datas), elem_count, bitNum);
+      }
+      weight->set_tensor_size(elem_count * sizeof(int8_t));
    }
-    weightPtr->set_tensor_type(kNumberTypeInt8);
-    weightPtr->set_tensor_size(shapeSize * sizeof(int8_t));
  } else {
    // per layer
-    size_t shapeSize = weightPtr->tensor_shape_size();
-    auto *rawDatas = static_cast<float *>(weightPtr->tensor_addr());
-    if (rawDatas == nullptr) {
-      MS_LOG(ERROR) << "rawDatas is nullptr";
-      return RET_ERROR;
-    }
-    weightPtr->quant_param().clear();
-    vector<int8_t> qDatas(shapeSize);
-    float min = 0;
-    float max = 0;
-    for (uint32_t i = 0; i < shapeSize; i++) {
+    float min = FLT_MAX;
+    float max = -FLT_MAX;
+    for (uint32_t i = 0; i < elem_count; i++) {
       // find max min
-      min = std::min(min, rawDatas[i]);
-      max = std::max(max, rawDatas[i]);
+      min = std::min(min, raw_datas[i]);
+      max = std::max(max, raw_datas[i]);
    }
-    std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam);
-    STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum);
+    schema::QuantParamT quant_param;
+    STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
      return status;
    }
+    quant_params.emplace_back(quant_param);
    // update data and datatype
-    for (uint32_t i = 0; i < shapeSize; i++) {
-      float rawData = rawDatas[i];
-      auto quant_data = std::round(rawData / quantParam->scale + quantParam->zeroPoint);
-      if (quant_data > quant_max) {
-        qDatas[i] = quant_max;
-      } else if (quant_data < quant_min) {
-        qDatas[i] = quant_min;
-      } else {
-        qDatas[i] = static_cast<int8_t>(quant_data);
-      }
+    for (uint32_t i = 0; i < elem_count; i++) {
+      float raw_data = raw_datas[i];
+      auto quant_data = QuantizeData<int8_t>(raw_data, quant_param, quant_max, quant_min);
+      quant_datas[i] = quant_data;
    }
-    weightPtr->set_quant_param(quantParam);
-    auto ret = memcpy_s(rawDatas, weightPtr->tensor_size(), qDatas.data(), shapeSize * sizeof(int8_t));
+    auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t));
    if (ret != EOK) {
      MS_LOG(ERROR) << "memcpy error: " << ret;
      return RET_ERROR;
    }
    if (quantType == QuantType_WeightQuant) {
-      PostBitPack(rawDatas, shapeSize, bitNum);
+      PostBitPack(raw_datas, elem_count, bitNum);
    }
-    weightPtr->set_tensor_type(kNumberTypeInt8);
-    weightPtr->set_tensor_size(shapeSize * sizeof(int8_t));
+    weight->set_tensor_size(elem_count * sizeof(int8_t));
  }
+  if (quant_params.empty()) {
+    MS_LOG(ERROR) << "quant_params empty";
+    return RET_ERROR;
+  }
+  primitiveT_value->AddInputQuantParam(quant_params);
   return RET_OK;
 }
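The two per-channel branches added above differ only in where the channel axis sits. In the depthwise (channel-last) layout, filter `i` occupies the strided positions `i + j * channels`; in the channel-first layout it is the contiguous block indexed by `j + i * one_filter_size`. The removed code's `j + i * channels` matched neither layout, which is presumably why both branches were rewritten. A small sketch of the two index schemes on plain arrays:

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Element indices belonging to filter i under the two layouts used in QuantFilter.
std::vector<std::size_t> FilterIndices(std::size_t elem_count, std::size_t channels, std::size_t i,
                                       bool channel_last) {
  std::vector<std::size_t> idx;
  const std::size_t one_filter_size = elem_count / channels;
  for (std::size_t j = 0; j < one_filter_size; ++j) {
    idx.push_back(channel_last ? i + j * channels            // depthwise: strided
                               : j + i * one_filter_size);   // channel-first: contiguous
  }
  return idx;
}

int main() {
  // 6 elements, 2 channels; filter 0 of a channel-last tensor -> 0 2 4
  for (auto k : FilterIndices(6, 2, 0, true)) std::printf("%zu ", k);
  std::printf("\n");
  // filter 0 of a channel-first tensor -> 0 1 2
  for (auto k : FilterIndices(6, 2, 0, false)) std::printf("%zu ", k);
  std::printf("\n");
}
```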
@@ -29,6 +29,7 @@
 #include "ir/primitive.h"
 #include "abstract/dshape.h"
 #include "mindspore/lite/tools/converter/quantizer/quantizer.h"
+#include "mindspore/lite/src/ir/primitive_t_value.h"
 namespace mindspore {
 namespace lite {
@@ -58,7 +59,7 @@ class QuantStrategy {
   static const std::array<std::string, 4> mMulTypes;
 };
-STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double mMin, double mMax,
+STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax,
                              bool narrowRange, int quant_max, int quant_min, int num_bits);
 STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax,
@@ -97,12 +98,12 @@ T QuantizeData(const float originData, const schema::QuantParamT *quantParam) {
 }
 template <typename T>
-T QuantizeData(float originData, const AnfQuantParam *quantParam, int quant_max, int quant_min) {
+T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quant_max, int quant_min) {
   MS_ASSERT(quantParam != nullptr);
   MS_ASSERT(quantParam->inited);
-  const auto scale = quantParam->scale;
-  const int zeroPoint = quantParam->zeroPoint;
-  const auto narrowRange = quantParam->narrowRange;
+  const auto scale = quantParam.scale;
+  const int zeroPoint = quantParam.zeroPoint;
+  const auto narrowRange = quantParam.narrowRange;
   const int maxLimit = quant_max;
   const int minLimit = quant_min;
@@ -119,8 +120,9 @@ T QuantizeData(float originData, const AnfQuantParam *quantParam, int quant_max,
 void CalFakeNode(const AnfNodePtr &inTensor);
-STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min,
-                   size_t bitNum = UINT8_QUANTIZATION, bool per_channel = false);
+STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveTValue> primitiveT_value, QuantType quantType,
+                   int quant_max, int quant_min, size_t bitNum = UINT8_QUANTIZATION, bool per_channel = false,
+                   bool depth_wise = false);
 STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION);
 }  // namespace quant
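For reference, the clamp-and-round that `QuantizeData<int8_t>` applies throughout the hunks above reduces to the standard affine rule `q = clamp(round(x / scale) + zeroPoint, quant_min, quant_max)`; the real template additionally honors `narrowRange`. A minimal sketch of that rule:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// q = clamp(round(x / scale) + zeroPoint, quant_min, quant_max)
int8_t QuantizeValue(float x, double scale, int zero_point, int quant_min, int quant_max) {
  int q = static_cast<int>(std::round(x / scale)) + zero_point;
  q = std::max(quant_min, std::min(quant_max, q));
  return static_cast<int8_t>(q);
}

int main() {
  // scale ~0.007843, zero point -26: the parameters from the CalQuantizationParams example
  std::printf("%d\n", QuantizeValue(1.2f, 0.007843, -26, -128, 127));   // 127
  std::printf("%d\n", QuantizeValue(-0.8f, 0.007843, -26, -128, 127));  // -128
}
```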
@@ -1,150 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "tools/converter/quantizer/weight_quantizer.h"
-#include <list>
-#include <string>
-#include "src/common/common.h"
-#include "ir/dtype/type_id.h"
-using std::string;
-using std::vector;
-namespace mindspore {
-namespace lite {
-namespace quant {
-WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize,
-                                 const std::string &convWeightChannelThreshold, const std::string &bitNum)
-    : Quantizer(graph) {
-  auto quantSize = static_cast<size_t>(std::stoull(weightSize));
-  this->bitNum = static_cast<size_t>(std::stoull(bitNum));
-  auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold));
-  // TODO(...): update stractory
-  mStrategy.reset(new QuantStrategy(quantSize, convQuantWeightChannelThreshold));
-}
-// uint32_t GetConvChannel(TensorDefT *weight) {
-//   uint32_t channel = 0;
-//   const vector<int> dims = weight->dims;
-//   switch (weight->format) {
-//     case Format_NCHW:
-//     case Format_KCHW:
-//     case Format_NC4HW4:
-//       channel = static_cast<uint32_t>(dims[NCHW_N]);
-//       break;
-//     case Format_NHWC:
-//     case Format_HWKC:
-//       channel = static_cast<uint32_t>(dims[NHWC_N]);
-//       break;
-//     case Format_HWCK:
-//       channel = static_cast<uint32_t>(dims[HWCK_K]);
-//       break;
-//     case Format_CKHW:
-//       channel = static_cast<uint32_t>(dims[CKHW_K]);
-//       break;
-//     default:
-//       MS_LOGE("Unsupported format: %d", weight->format);
-//       return 0;
-//   }
-//   return channel;
-// }
-STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
-  for (auto &cnode : nodes) {
-    if (!mStrategy->CanConvOpQuantized(cnode)) {
-      continue;
-    }
-    auto inputNode = cnode->input(2);
-    if (!inputNode->isa<Parameter>()) {
-      return RET_ERROR;
-    }
-    auto paramNode = inputNode->cast<ParameterPtr>();
-    if (!paramNode->has_default()) {
-      return RET_ERROR;
-    }
-    ParamValueLitePtr paramValue = std::static_pointer_cast<ParamValueLite>(paramNode->default_param());
-    auto status = QuantFilter(paramValue, QuantType_WeightQuant, 127, -128, bitNum);
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "QuantFilter failed : " << status;
-      return status;
-    }
-  }
-  return RET_OK;
-}
-STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
-  for (auto &node : nodes) {
-    if (!mStrategy->CanMulOpQuantized(node)) {
-      continue;
-    }
-    ParamValueLitePtr paramValue = nullptr;
-    for (size_t i = 1; i < node->size(); i++) {
-      auto inputNode = node->input(i);
-      if (inputNode->isa<Parameter>() == true) {
-        auto paramNode = inputNode->cast<ParameterPtr>();
-        if ((paramNode != nullptr) && (paramNode->has_default() == true)) {
-          paramValue = std::static_pointer_cast<ParamValueLite>(paramNode->default_param());
-          if ((paramValue == nullptr) || (paramValue->tensor_size() == 0)
-              || (paramValue->tensor_shape().size() != 4)
-              || (paramValue->tensor_addr() == nullptr)
-              || (paramValue->tensor_type() != mindspore::kNumberTypeFloat32)) {
-            paramValue = nullptr;
-            continue;
-          } else {
-            break;
-          }
-        }
-      }
-    }
-    if (paramValue == nullptr) {
-      MS_LOG(ERROR) << "No valid input param node !";
-      continue;
-    }
-    auto status = QuantFilter(paramValue, QuantType_WeightQuant, 127, -128, bitNum);
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "QunatFilter failed" << status;
-      return RET_ERROR;
-    }
-  }
-  return RET_OK;
-}
-STATUS WeightQuantizer::DoQuantize(FuncGraphPtr funcGraph) {
-  auto ret = RET_OK;
-  auto cnodes = funcGraph->GetOrderedCnodes();
-  ret = DoConvQuantize(cnodes);
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "DoConvQuantize failed :" << ret;
-    return ret;
-  }
-  ret = DoMulQuantize(cnodes);
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "DoMulQuantize failed :" << ret;
-    return ret;
-  }
-  return ret;
-}
-}  // namespace quant
-}  // namespace lite
-}  // namespace mindspore
@@ -1,53 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef WEIGHT_QUANTIZER_H
-#define WEIGHT_QUANTIZER_H
-#include <memory>
-#include <list>
-#include <string>
-#include "tools/converter/quantizer/quantizer.h"
-#include "tools/converter/quantizer/quantize_util.h"
-#include "ir/func_graph.h"
-#include "ir/anf.h"
-#include "include/model.h"
-#include "base/base.h"
-#include "abstract/dshape.h"
-namespace mindspore {
-namespace lite {
-namespace quant {
-class WeightQuantizer : public Quantizer {
- public:
-  WeightQuantizer(FuncGraphPtr graph, const std::string &weightSize,
-                  const std::string &covWeightChannelThreshold, const std::string &bitNum);
-  ~WeightQuantizer() = default;
-  STATUS DoQuantize(FuncGraphPtr funcGraph) override;
-  STATUS DoConvQuantize(const std::list<CNodePtr> &nodes);
-  STATUS DoMulQuantize(const std::list<CNodePtr> &nodes);
- private:
-  std::unique_ptr<QuantStrategy> mStrategy;
-  size_t bitNum;
-};
-}  // namespace quant
-}  // namespace lite
-}  // namespace mindspore
-#endif