Merge pull request !4118 from xutianchun/quant_0807tags/v0.7.0-beta
| @@ -177,31 +177,44 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { | |||||
| if (node->quantType == schema::QuantType_PostTraining) { | if (node->quantType == schema::QuantType_PostTraining) { | ||||
| MS_LOG(INFO) << "node: " << node->name << " add QuantParam"; | MS_LOG(INFO) << "node: " << node->name << " add QuantParam"; | ||||
| // activation | // activation | ||||
| auto activate_index = node->inputIndex[0]; | |||||
| auto tensor_input = metaGraphT->allTensors[activate_index].get(); | |||||
| auto input_quant_params = primitiveT_value->GetInputQuantParams(); | auto input_quant_params = primitiveT_value->GetInputQuantParams(); | ||||
| if (input_quant_params.empty()) { | |||||
| MS_LOG(WARNING) << "node: " << node->name | |||||
| << " input quant params is empty"; | |||||
| } else { | |||||
| auto node_type = primitiveT_value->GetPrimitiveT()->value.type; | |||||
| for (int i = 0; i < input_quant_params.size(); i++) { | |||||
| if (i >= node->inputIndex.size()) { | |||||
| MS_LOG(ERROR) << "node: " << node->name << " input has " << input_quant_params.size() | |||||
| << " quant_params; but only " << node->inputIndex.size() << " input"; | |||||
| break; | |||||
| } | |||||
| auto activate_index = node->inputIndex[i]; | |||||
| auto tensor_input = metaGraphT->allTensors[activate_index].get(); | |||||
| std::unique_ptr<schema::QuantParamT> input_quant_param = | std::unique_ptr<schema::QuantParamT> input_quant_param = | ||||
| std::make_unique<schema::QuantParamT>(input_quant_params[0]); | |||||
| std::make_unique<schema::QuantParamT>(input_quant_params[i]); | |||||
| MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param->scale | |||||
| << " zp: " << input_quant_param->zeroPoint; | |||||
| tensor_input->quantParams.emplace_back(std::move(input_quant_param)); | tensor_input->quantParams.emplace_back(std::move(input_quant_param)); | ||||
| if (!(node_type == schema::PrimitiveType_QuantDTypeCast && | |||||
| primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->srcT == kNumberTypeFloat32)) { | |||||
| tensor_input->dataType = kNumberTypeInt8; | |||||
| } | |||||
| } | } | ||||
| tensor_input->dataType = kNumberTypeInt8; | |||||
| // output | // output | ||||
| auto output_index = node->outputIndex[0]; | auto output_index = node->outputIndex[0]; | ||||
| auto tensor_output = metaGraphT->allTensors[output_index].get(); | auto tensor_output = metaGraphT->allTensors[output_index].get(); | ||||
| auto output_quant_params = primitiveT_value->GetOutputQuantParams(); | auto output_quant_params = primitiveT_value->GetOutputQuantParams(); | ||||
| if (output_quant_params.empty()) { | if (output_quant_params.empty()) { | ||||
| MS_LOG(WARNING) << "node: " << node->name | |||||
| << " output quant params is empty"; | |||||
| MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty"; | |||||
| } else { | } else { | ||||
| std::unique_ptr<schema::QuantParamT> output_quant_param = | std::unique_ptr<schema::QuantParamT> output_quant_param = | ||||
| std::make_unique<schema::QuantParamT>(output_quant_params[0]); | |||||
| std::make_unique<schema::QuantParamT>(output_quant_params[0]); | |||||
| MS_LOG(DEBUG) << "[output]node: " << node->name << " scale: " << output_quant_param->scale | |||||
| << " zp: " << output_quant_param->zeroPoint; | |||||
| tensor_output->quantParams.emplace_back(std::move(output_quant_param)); | tensor_output->quantParams.emplace_back(std::move(output_quant_param)); | ||||
| } | } | ||||
| tensor_output->dataType = kNumberTypeInt8; | |||||
| if (!(node_type == schema::PrimitiveType_QuantDTypeCast && | |||||
| primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->dstT == kNumberTypeFloat32)) { | |||||
| tensor_output->dataType = kNumberTypeInt8; | |||||
| } | |||||
| // // TensorType | // // TensorType | ||||
| // valuePtr = primitive->GetAttr(kInputTensorDataType); | // valuePtr = primitive->GetAttr(kInputTensorDataType); | ||||
| // if (valuePtr != nullptr) { | // if (valuePtr != nullptr) { | ||||
| @@ -64,6 +64,16 @@ int LiteSession::ConvertTensors(const lite::Model *model) { | |||||
| // no copy data, do copy when call LiteKernel::Init | // no copy data, do copy when call LiteKernel::Init | ||||
| dstTensor->SetData(const_cast<unsigned char *>(srcTensor->data()->data())); | dstTensor->SetData(const_cast<unsigned char *>(srcTensor->data()->data())); | ||||
| } | } | ||||
| auto quant_params = srcTensor->quantParams(); | |||||
| if (quant_params != nullptr) { | |||||
| for (int j = 0; j < quant_params->size(); j++) { | |||||
| tensor::QuantArg quant_arg{}; | |||||
| quant_arg.scale = quant_params->Get(j)->scale(); | |||||
| quant_arg.zeroPoint = quant_params->Get(j)->zeroPoint(); | |||||
| dstTensor->AddQuantParam(quant_arg); | |||||
| } | |||||
| } | |||||
| this->tensors.emplace_back(dstTensor); | this->tensors.emplace_back(dstTensor); | ||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -30,6 +30,7 @@ int QuantDTypeCast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto | |||||
| auto param = primitive->value_as_QuantDTypeCast(); | auto param = primitive->value_as_QuantDTypeCast(); | ||||
| MS_ASSERT(input->data_type() == param->srcT); | MS_ASSERT(input->data_type() == param->srcT); | ||||
| output->set_data_type(static_cast<TypeId>(param->dstT())); | output->set_data_type(static_cast<TypeId>(param->dstT())); | ||||
| output->SetFormat(input->GetFormat()); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -62,7 +62,7 @@ int QuantDTypeCastCPUKernel::Init() { | |||||
| } | } | ||||
| inverse_ = true; | inverse_ = true; | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "param data type not supported."; | |||||
| MS_LOG(ERROR) << "param data type not supported:" << " src: " << param->srcT << " dst: " << param->dstT; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -148,7 +148,6 @@ kernel::LiteKernel *CpuQuantDTypeCastFp32KernelCreator(const std::vector<lite::t | |||||
| } | } | ||||
| return kernel; | return kernel; | ||||
| } | } | ||||
| REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator) | REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator) | ||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator) | REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator) | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -23,7 +23,7 @@ int DequantizeInt8(int8_t *quant_values, float *real_values, float scale, int32_ | |||||
| } | } | ||||
| for (int i = 0; i < size; ++i) { | for (int i = 0; i < size; ++i) { | ||||
| real_values[i] = (quant_values[i] + zp) * scale; | |||||
| real_values[i] = (quant_values[i] - zp) * scale; | |||||
| } | } | ||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| @@ -34,7 +34,14 @@ int QuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int32_ | |||||
| } | } | ||||
| for (int i = 0; i < size; ++i) { | for (int i = 0; i < size; ++i) { | ||||
| quant_values[i] = (int8_t)round(real_values[i] / scale + zp); | |||||
| float temp = round(real_values[i] / scale + zp); | |||||
| if (temp > 127) { | |||||
| quant_values[i] = 127; | |||||
| } else if (temp < -128) { | |||||
| quant_values[i] = -128; | |||||
| } else { | |||||
| quant_values[i] = (int8_t)temp; | |||||
| } | |||||
| } | } | ||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| @@ -166,6 +166,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { | |||||
| return -1; | return -1; | ||||
| } | } | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "weight_tensor_format: " << weightTensor->format; | |||||
| return 0; | return 0; | ||||
| } else if (fmkType == converter::FmkType_ONNX) { | } else if (fmkType == converter::FmkType_ONNX) { | ||||
| switch (node->quantType) { | switch (node->quantType) { | ||||
| @@ -217,7 +218,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { | |||||
| auto opType = node->primitive->value.type; | auto opType = node->primitive->value.type; | ||||
| if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D && | if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D && | ||||
| opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) { | opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) { | ||||
| return 0; | |||||
| return RET_OK; | |||||
| } | } | ||||
| MS_ASSERT(node->inputIndex.size() >= 2); | MS_ASSERT(node->inputIndex.size() >= 2); | ||||
| @@ -225,7 +226,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { | |||||
| MS_ASSERT(subGraph->allTensors.size() > weightIndex); | MS_ASSERT(subGraph->allTensors.size() > weightIndex); | ||||
| auto &weightTensor = subGraph->allTensors[weightIndex]; | auto &weightTensor = subGraph->allTensors[weightIndex]; | ||||
| MS_ASSERT(weightTensor->dataType == kNumberTypeInt8); // DataType_DT_FLOAT | MS_ASSERT(weightTensor->dataType == kNumberTypeInt8); // DataType_DT_FLOAT | ||||
| STATUS status; | |||||
| STATUS status = RET_OK; | |||||
| if (opType == schema::PrimitiveType_Conv2D) { // weight should be HWCK | if (opType == schema::PrimitiveType_Conv2D) { // weight should be HWCK | ||||
| if (weightTensor->format == schema::Format_KCHW) { // from caffe | if (weightTensor->format == schema::Format_KCHW) { // from caffe | ||||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | ||||
| @@ -238,11 +239,12 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { | |||||
| status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK); | status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK); | ||||
| } | } | ||||
| } else if (weightTensor->format == schema::Format_KHWC) { // from onnx | } else if (weightTensor->format == schema::Format_KHWC) { // from onnx | ||||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||||
| status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK); | |||||
| } else { | |||||
| status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK); | |||||
| } | |||||
| return RET_OK; | |||||
| // if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||||
| // status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK); | |||||
| // } else { | |||||
| // status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK); | |||||
| // } | |||||
| } else if (weightTensor->format == schema::Format_HWCK) { // from tf | } else if (weightTensor->format == schema::Format_HWCK) { // from tf | ||||
| return 0; | return 0; | ||||
| } else { | } else { | ||||
| @@ -273,8 +275,8 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { | |||||
| } else if (weightTensor->format == schema::Format_HWCK) { // from tf | } else if (weightTensor->format == schema::Format_HWCK) { // from tf | ||||
| return 0; | return 0; | ||||
| } else if (weightTensor->format == schema::Format_CHWK) { // from onnx | } else if (weightTensor->format == schema::Format_CHWK) { // from onnx | ||||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||||
| status = TransFilterFormat<uint8_t>(weightTensor.get(), kCHWK2HWCK); | |||||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||||
| status = TransFilterFormat<int8_t>(weightTensor.get(), kCHWK2KHWC); | |||||
| } else { | } else { | ||||
| status = TransFilterFormat<float>(weightTensor.get(), kCHWK2HWCK); | status = TransFilterFormat<float>(weightTensor.get(), kCHWK2HWCK); | ||||
| } | } | ||||
| @@ -54,7 +54,7 @@ struct DivergInfo { | |||||
| size_t bit_num; | size_t bit_num; | ||||
| int quant_max = 255; | int quant_max = 255; | ||||
| int quant_min = 0; | int quant_min = 0; | ||||
| DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max = 255, int quant_min = 0) { | |||||
| DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min) { | |||||
| this->cnode = cnode; | this->cnode = cnode; | ||||
| this->bin_num = bins; | this->bin_num = bins; | ||||
| this->bit_num = bits; | this->bit_num = bits; | ||||
| @@ -81,6 +81,9 @@ struct DivergInfo { | |||||
| STATUS UpdateHistogram(const std::vector<float> &data, const std::vector<int> &shape) { | STATUS UpdateHistogram(const std::vector<float> &data, const std::vector<int> &shape) { | ||||
| for (auto value : data) { | for (auto value : data) { | ||||
| if (value == 0) { | |||||
| continue; | |||||
| } | |||||
| int bin_index = std::min(static_cast<int>(std::fabs(value) / this->interval), bin_num - 1); | int bin_index = std::min(static_cast<int>(std::fabs(value) / this->interval), bin_num - 1); | ||||
| this->histogram[bin_index]++; | this->histogram[bin_index]++; | ||||
| } | } | ||||
| @@ -470,8 +473,10 @@ STATUS Calibrator::ReadConfig() { | |||||
| Calibrator::Calibrator(string path, size_t bitNum, int quantMax, int quantMin) | Calibrator::Calibrator(string path, size_t bitNum, int quantMax, int quantMin) | ||||
| : config_path_(path), bit_num_(bitNum), quant_max_(quantMax), quant_min_(quantMin) {} | : config_path_(path), bit_num_(bitNum), quant_max_(quantMax), quant_min_(quantMin) {} | ||||
| PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, int bit_num, TypeId target_type) | |||||
| PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, int bit_num, TypeId target_type, | |||||
| bool per_channel) | |||||
| : Quantizer(graph) { | : Quantizer(graph) { | ||||
| this->per_channel_ = per_channel; | |||||
| this->bit_num = bit_num; | this->bit_num = bit_num; | ||||
| this->target_type_ = target_type; | this->target_type_ = target_type; | ||||
| if (target_type == kNumberTypeInt8) { | if (target_type == kNumberTypeInt8) { | ||||
| @@ -533,7 +538,7 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr node) { | |||||
| } | } | ||||
| auto parameter = std::dynamic_pointer_cast<Parameter>(node); | auto parameter = std::dynamic_pointer_cast<Parameter>(node); | ||||
| ParamValueLitePtr paramValue = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param()); | ParamValueLitePtr paramValue = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param()); | ||||
| auto status = QuantFilter(paramValue, QuantType_PostTraining, quant_max, quant_min, bit_num); | |||||
| auto status = QuantFilter(paramValue, QuantType_PostTraining, quant_max, quant_min, bit_num, per_channel_); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "QuantFilter failed: " << status; | MS_LOG(ERROR) << "QuantFilter failed: " << status; | ||||
| return status; | return status; | ||||
| @@ -670,18 +675,32 @@ STATUS PostTrainingQuantizer::QuantNode() { | |||||
| MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; | MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (input_scale.find(cnode) == input_scale.end()) { | if (input_scale.find(cnode) == input_scale.end()) { | ||||
| primitiveT_value->SetQuantType(schema::QuantType_QUANT_NONE); | primitiveT_value->SetQuantType(schema::QuantType_QUANT_NONE); | ||||
| continue; | continue; | ||||
| } | } | ||||
| auto input_vec = cnode->inputs(); | auto input_vec = cnode->inputs(); | ||||
| auto op_name = cnode->fullname_with_scope(); | auto op_name = cnode->fullname_with_scope(); | ||||
| auto op_type = primitiveT_value->GetPrimitiveT()->value.type; | |||||
| MS_LOG(INFO) << "OpName: " << op_name; | MS_LOG(INFO) << "OpName: " << op_name; | ||||
| if (input_vec.size() <= 3 && op_name != "Conv2D" && op_name != "DepthwiseConv2D") { | |||||
| MS_LOG(INFO) << "todo(x): "; | |||||
| // int32_t qnodeOutputZeropoint = outputZeropoint[cnode]; | |||||
| // p->AddAttr(kInputTensorDataType, MakeValue((int)targetType)); | |||||
| if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D) { | |||||
| for (auto i = 1; i < cnode->inputs().size(); i++) { | |||||
| auto input_node = cnode->input(i); | |||||
| if (!input_node->isa<mindspore::CNode>()) { | |||||
| MS_LOG(WARNING) << "node: " << cnode_name << " input " << i << " not a cnode"; | |||||
| continue; | |||||
| } | |||||
| auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node); | |||||
| auto input_cnode_primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(input_cnode->input(0)); | |||||
| if (input_cnode_primitiveT_value == nullptr) { | |||||
| MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": " | |||||
| << " PrimitiveTValue is null"; | |||||
| continue; | |||||
| } | |||||
| for (auto &quant_param : input_cnode_primitiveT_value->GetOutputQuantParams()) { | |||||
| primitiveT_value->AddInputQuantParam(quant_param); | |||||
| } | |||||
| } | |||||
| } else { | } else { | ||||
| // do input quant | // do input quant | ||||
| double scale = input_scale[cnode]; | double scale = input_scale[cnode]; | ||||
| @@ -55,15 +55,18 @@ struct ConfigParam { | |||||
| class PostTrainingQuantizer : public Quantizer { | class PostTrainingQuantizer : public Quantizer { | ||||
| public: | public: | ||||
| PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8); | |||||
| PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8, | |||||
| bool per_channel = false); | |||||
| STATUS DoQuantize(FuncGraphPtr funcGraph) override; | STATUS DoQuantize(FuncGraphPtr funcGraph) override; | ||||
| size_t bit_num; | size_t bit_num; | ||||
| int quant_max{255}; | |||||
| int quant_min{0}; | |||||
| int quant_max{127}; | |||||
| int quant_min{-128}; | |||||
| private: | private: | ||||
| bool per_channel_; | |||||
| TypeId target_type_{kNumberTypeInt8}; | TypeId target_type_{kNumberTypeInt8}; | ||||
| std::unique_ptr<Calibrator> calibrator_; | std::unique_ptr<Calibrator> calibrator_; | ||||
| @@ -25,10 +25,11 @@ namespace mindspore::lite::quant { | |||||
| ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector<schema::QuantParamT> &quant_params) { | ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector<schema::QuantParamT> &quant_params) { | ||||
| std::unique_ptr<schema::PrimitiveT> primitive = std::make_unique<schema::PrimitiveT>(); | std::unique_ptr<schema::PrimitiveT> primitive = std::make_unique<schema::PrimitiveT>(); | ||||
| schema::QuantDTypeCastT quant_dtype_cast; | schema::QuantDTypeCastT quant_dtype_cast; | ||||
| quant_dtype_cast.srcT = src_type; // kNumberTypeUInt8; | |||||
| quant_dtype_cast.srcT = src_type; // kNumberTypeInt8; | |||||
| quant_dtype_cast.dstT = dst_type; // kNumberTypeFloat32; | quant_dtype_cast.dstT = dst_type; // kNumberTypeFloat32; | ||||
| primitive->value.Set(quant_dtype_cast); | primitive->value.Set(quant_dtype_cast); | ||||
| auto primTValue = std::make_shared<PrimitiveTValue>(primitive.release()); | auto primTValue = std::make_shared<PrimitiveTValue>(primitive.release()); | ||||
| primTValue->SetQuantType(schema::QuantType_PostTraining); | |||||
| for (auto &quant_param : quant_params) { | for (auto &quant_param : quant_params) { | ||||
| primTValue->AddInputQuantParam(quant_param); | primTValue->AddInputQuantParam(quant_param); | ||||
| } | } | ||||
| @@ -52,7 +53,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { | |||||
| if (first) { | if (first) { | ||||
| if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) { | if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) { | ||||
| auto value_node = | auto value_node = | ||||
| NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8, primitiveT_value->GetInputQuantParams()); | |||||
| NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitiveT_value->GetInputQuantParams()); | |||||
| std::vector<AnfNodePtr> op_inputs = {value_node, cnode->input(1)}; | std::vector<AnfNodePtr> op_inputs = {value_node, cnode->input(1)}; | ||||
| auto quant_cast_cnode = graph->NewCNode(op_inputs); | auto quant_cast_cnode = graph->NewCNode(op_inputs); | ||||
| quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast"); | quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast"); | ||||
| @@ -82,11 +83,11 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { | |||||
| ValueNodePtr value_node = nullptr; | ValueNodePtr value_node = nullptr; | ||||
| if (curnode_quant_type == schema::QuantType_PostTraining && | if (curnode_quant_type == schema::QuantType_PostTraining && | ||||
| input_cnode_quant_type == schema::QuantType_QUANT_NONE) { | input_cnode_quant_type == schema::QuantType_QUANT_NONE) { | ||||
| value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8, | |||||
| input_cnode_primitiveT_value->GetInputQuantParams()); | |||||
| value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, | |||||
| primitiveT_value->GetInputQuantParams()); | |||||
| } else if (curnode_quant_type == schema::QuantType_QUANT_NONE && | } else if (curnode_quant_type == schema::QuantType_QUANT_NONE && | ||||
| input_cnode_quant_type == schema::QuantType_PostTraining) { | input_cnode_quant_type == schema::QuantType_PostTraining) { | ||||
| value_node = NewQuantCastValueNode(kNumberTypeUInt8, kNumberTypeFloat32, | |||||
| value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32, | |||||
| input_cnode_primitiveT_value->GetInputQuantParams()); | input_cnode_primitiveT_value->GetInputQuantParams()); | ||||
| } | } | ||||
| if (value_node == nullptr) { | if (value_node == nullptr) { | ||||
| @@ -98,7 +98,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const { | |||||
| static const std::vector<schema::PrimitiveType> uint8OpList = { | static const std::vector<schema::PrimitiveType> uint8OpList = { | ||||
| schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D, | schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D, | ||||
| schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Pooling, | schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Pooling, | ||||
| schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax, schema::PrimitiveType_Reshape, | |||||
| schema::PrimitiveType_Concat, /*schema::PrimitiveType_SoftMax,*/ schema::PrimitiveType_Reshape, | |||||
| schema::PrimitiveType_Activation}; | schema::PrimitiveType_Activation}; | ||||
| return IsContain(uint8OpList, type); | return IsContain(uint8OpList, type); | ||||
| } | } | ||||
| @@ -242,64 +242,122 @@ STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum) { | |||||
| STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum, | |||||
| bool per_channel) { | |||||
| if (per_channel) { | |||||
| // per channel | |||||
| auto dims = weightPtr->tensor_shape(); | auto dims = weightPtr->tensor_shape(); | ||||
| if (dims.size() < 1) { | if (dims.size() < 1) { | ||||
| MS_LOG(ERROR) << "weight dims size error"; | |||||
| return RET_ERROR; | |||||
| MS_LOG(ERROR) << "weight dims size error"; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| uint32_t channels = dims[0]; | |||||
| // todo(x) | |||||
| uint32_t channels = dims[3]; | |||||
| if (channels == 0) { | if (channels == 0) { | ||||
| MS_LOG(ERROR) << "channels error 0"; | |||||
| return RET_ERROR; | |||||
| MS_LOG(ERROR) << "channels error 0"; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| size_t shapeSize = weightPtr->tensor_shape_size(); | size_t shapeSize = weightPtr->tensor_shape_size(); | ||||
| size_t oneFilterSize = shapeSize / channels; | size_t oneFilterSize = shapeSize / channels; | ||||
| auto *rawDatas = reinterpret_cast<const float *>(weightPtr->tensor_addr()); | auto *rawDatas = reinterpret_cast<const float *>(weightPtr->tensor_addr()); | ||||
| if (rawDatas == nullptr) { | if (rawDatas == nullptr) { | ||||
| MS_LOG(ERROR) << "rawDatas is nullptr"; | |||||
| return RET_ERROR; | |||||
| MS_LOG(ERROR) << "rawDatas is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| weightPtr->quant_param().clear(); | weightPtr->quant_param().clear(); | ||||
| vector<uint8_t> qDatas(shapeSize); | |||||
| vector<int8_t> qDatas(shapeSize); | |||||
| for (uint32_t i = 0; i < channels; i++) { | for (uint32_t i = 0; i < channels; i++) { | ||||
| float min = 0; | |||||
| float max = 0; | |||||
| // find min and max | |||||
| for (uint32_t j = 0; j < oneFilterSize; j++) { | |||||
| min = std::min(min, rawDatas[j + i * oneFilterSize]); | |||||
| max = std::max(max, rawDatas[j + i * oneFilterSize]); | |||||
| } | |||||
| float min = 0; | |||||
| float max = 0; | |||||
| // find min and max | |||||
| for (uint32_t j = 0; j < oneFilterSize; j++) { | |||||
| min = std::min(min, rawDatas[j + i * oneFilterSize]); | |||||
| max = std::max(max, rawDatas[j + i * oneFilterSize]); | |||||
| } | |||||
| std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam); | |||||
| STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "CalQuantizationParams failed" << status; | |||||
| return status; | |||||
| } | |||||
| // update data and datatype | |||||
| for (uint32_t j = 0; j < oneFilterSize; j++) { | |||||
| float rawData = rawDatas[j + i * oneFilterSize]; | |||||
| auto qData = QuantizeData<int8_t>(rawData, quantParam.get(), quant_max, quant_min); | |||||
| qDatas[j + i * oneFilterSize] = qData; | |||||
| } | |||||
| weightPtr->set_quant_param(quantParam); | |||||
| } | |||||
| auto ret = memcpy_s(const_cast<float*>(rawDatas), weightPtr->tensor_size(), | |||||
| qDatas.data(), shapeSize * sizeof(int8_t)); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy error: " << ret; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (quantType == QuantType_WeightQuant) { | |||||
| PostBitPack(const_cast<float*>(rawDatas), shapeSize, bitNum); | |||||
| } | |||||
| std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam); | |||||
| STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "CalQuantizationParams failed" << status; | |||||
| return status; | |||||
| } | |||||
| // update data and datatype | |||||
| for (uint32_t j = 0; j < oneFilterSize; j++) { | |||||
| float rawData = rawDatas[j + i * oneFilterSize]; | |||||
| auto qData = QuantizeData<uint8_t>(rawData, quantParam.get()); | |||||
| qDatas[j + i * oneFilterSize] = qData; | |||||
| } | |||||
| weightPtr->set_tensor_type(kNumberTypeInt8); | |||||
| weightPtr->set_tensor_size(shapeSize * sizeof(int8_t)); | |||||
| } else { | |||||
| // per layer | |||||
| size_t shapeSize = weightPtr->tensor_shape_size(); | |||||
| auto *rawDatas = static_cast<float *>(weightPtr->tensor_addr()); | |||||
| if (rawDatas == nullptr) { | |||||
| MS_LOG(ERROR) << "rawDatas is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| weightPtr->set_quant_param(quantParam); | |||||
| weightPtr->quant_param().clear(); | |||||
| vector<int8_t> qDatas(shapeSize); | |||||
| float min = 0; | |||||
| float max = 0; | |||||
| for (uint32_t i = 0; i < shapeSize; i++) { | |||||
| // find max min | |||||
| min = std::min(min, rawDatas[i]); | |||||
| max = std::max(max, rawDatas[i]); | |||||
| } | } | ||||
| auto ret = memcpy_s(const_cast<float*>(rawDatas), weightPtr->tensor_size(), | |||||
| qDatas.data(), shapeSize * sizeof(uint8_t)); | |||||
| std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam); | |||||
| STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "CalQuantizationParams failed" << status; | |||||
| return status; | |||||
| } | |||||
| // update data and datatype | |||||
| for (uint32_t i = 0; i < shapeSize; i++) { | |||||
| float rawData = rawDatas[i]; | |||||
| auto quant_data = std::round(rawData / quantParam->scale + quantParam->zeroPoint); | |||||
| if (quant_data > quant_max) { | |||||
| qDatas[i] = quant_max; | |||||
| } else if (quant_data < quant_min) { | |||||
| qDatas[i] = quant_min; | |||||
| } else { | |||||
| qDatas[i] = static_cast<int8_t>(quant_data); | |||||
| } | |||||
| } | |||||
| weightPtr->set_quant_param(quantParam); | |||||
| auto ret = memcpy_s(rawDatas, weightPtr->tensor_size() * sizeof(int8_t), | |||||
| qDatas.data(), shapeSize * sizeof(int8_t)); | |||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| MS_LOG(ERROR) << "memcpy error: " << ret; | |||||
| return RET_ERROR; | |||||
| MS_LOG(ERROR) << "memcpy error: " << ret; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| if (quantType == QuantType_WeightQuant) { | if (quantType == QuantType_WeightQuant) { | ||||
| PostBitPack(const_cast<float*>(rawDatas), shapeSize, bitNum); | |||||
| PostBitPack(rawDatas, shapeSize, bitNum); | |||||
| } | } | ||||
| weightPtr->set_tensor_type(kNumberTypeInt8); | weightPtr->set_tensor_type(kNumberTypeInt8); | ||||
| weightPtr->set_tensor_size(shapeSize * sizeof(int8_t)); | weightPtr->set_tensor_size(shapeSize * sizeof(int8_t)); | ||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -63,41 +63,30 @@ STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double | |||||
| bool narrowRange, int quant_max, int quant_min, int num_bits); | bool narrowRange, int quant_max, int quant_min, int num_bits); | ||||
| template <typename T> | template <typename T> | ||||
| T QuantizeData(const float originData, const AnfQuantParam *quantParam) { | |||||
| T QuantizeData(float originData, const AnfQuantParam *quantParam, int quant_max, int quant_min) { | |||||
| MS_ASSERT(quantParam != nullptr); | MS_ASSERT(quantParam != nullptr); | ||||
| MS_ASSERT(quantParam->inited); | MS_ASSERT(quantParam->inited); | ||||
| const auto scale = quantParam->scale; | const auto scale = quantParam->scale; | ||||
| const auto zeroPoint = quantParam->zeroPoint; | |||||
| const auto numBit = quantParam->numBits; | |||||
| const int zeroPoint = quantParam->zeroPoint; | |||||
| const auto narrowRange = quantParam->narrowRange; | const auto narrowRange = quantParam->narrowRange; | ||||
| const double maxLimit = static_cast<float>((1 << (unsigned int)numBit) - 1 - zeroPoint) * scale; | |||||
| double minLimit; | |||||
| if (narrowRange) { | |||||
| minLimit = static_cast<float>(1 - zeroPoint) * scale; | |||||
| } else { | |||||
| minLimit = static_cast<float>(0 - zeroPoint) * scale; | |||||
| } | |||||
| const int maxLimit = quant_max; | |||||
| const int minLimit = quant_min; | |||||
| return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] { | return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] { | ||||
| double tmp = 0.0f; | |||||
| if (originData > maxLimit) { | |||||
| tmp = maxLimit; | |||||
| } else if (originData < minLimit) { | |||||
| tmp = minLimit; | |||||
| } else { | |||||
| tmp = originData; | |||||
| } | |||||
| auto quantData = static_cast<T>(std::round(tmp / scale + zeroPoint)); | |||||
| if (quantData == 0 && narrowRange) { | |||||
| quantData++; | |||||
| int quant_data = std::round(originData / scale + zeroPoint); | |||||
| if (quant_data > maxLimit) { | |||||
| quant_data = maxLimit; | |||||
| } else if (quant_data < minLimit) { | |||||
| quant_data = minLimit; | |||||
| } | } | ||||
| return quantData; | |||||
| return static_cast<T>(quant_data); | |||||
| }(); | }(); | ||||
| } | } | ||||
| void CalFakeNode(const AnfNodePtr &inTensor); | void CalFakeNode(const AnfNodePtr &inTensor); | ||||
| STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, | STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, | ||||
| size_t bitNum = UINT8_QUANTIZATION); | |||||
| size_t bitNum = UINT8_QUANTIZATION, bool per_channel = false); | |||||
| STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION); | STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION); | ||||