Merge pull request !4118 from xutianchun/quant_0807 (tags/v0.7.0-beta)
| @@ -177,31 +177,44 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { | |||
| if (node->quantType == schema::QuantType_PostTraining) { | |||
| MS_LOG(INFO) << "node: " << node->name << " add QuantParam"; | |||
| // activation | |||
| auto activate_index = node->inputIndex[0]; | |||
| auto tensor_input = metaGraphT->allTensors[activate_index].get(); | |||
| auto input_quant_params = primitiveT_value->GetInputQuantParams(); | |||
| if (input_quant_params.empty()) { | |||
| MS_LOG(WARNING) << "node: " << node->name | |||
| << " input quant params is empty"; | |||
| } else { | |||
| auto node_type = primitiveT_value->GetPrimitiveT()->value.type; | |||
| for (int i = 0; i < input_quant_params.size(); i++) { | |||
| if (i >= node->inputIndex.size()) { | |||
| MS_LOG(ERROR) << "node: " << node->name << " input has " << input_quant_params.size() | |||
| << " quant_params; but only " << node->inputIndex.size() << " input"; | |||
| break; | |||
| } | |||
| auto activate_index = node->inputIndex[i]; | |||
| auto tensor_input = metaGraphT->allTensors[activate_index].get(); | |||
| std::unique_ptr<schema::QuantParamT> input_quant_param = | |||
| std::make_unique<schema::QuantParamT>(input_quant_params[0]); | |||
| std::make_unique<schema::QuantParamT>(input_quant_params[i]); | |||
| MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param->scale | |||
| << " zp: " << input_quant_param->zeroPoint; | |||
| tensor_input->quantParams.emplace_back(std::move(input_quant_param)); | |||
| if (!(node_type == schema::PrimitiveType_QuantDTypeCast && | |||
| primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->srcT == kNumberTypeFloat32)) { | |||
| tensor_input->dataType = kNumberTypeInt8; | |||
| } | |||
| } | |||
| tensor_input->dataType = kNumberTypeInt8; | |||
| // output | |||
| auto output_index = node->outputIndex[0]; | |||
| auto tensor_output = metaGraphT->allTensors[output_index].get(); | |||
| auto output_quant_params = primitiveT_value->GetOutputQuantParams(); | |||
| if (output_quant_params.empty()) { | |||
| MS_LOG(WARNING) << "node: " << node->name | |||
| << " output quant params is empty"; | |||
| MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty"; | |||
| } else { | |||
| std::unique_ptr<schema::QuantParamT> output_quant_param = | |||
| std::make_unique<schema::QuantParamT>(output_quant_params[0]); | |||
| std::make_unique<schema::QuantParamT>(output_quant_params[0]); | |||
| MS_LOG(DEBUG) << "[output]node: " << node->name << " scale: " << output_quant_param->scale | |||
| << " zp: " << output_quant_param->zeroPoint; | |||
| tensor_output->quantParams.emplace_back(std::move(output_quant_param)); | |||
| } | |||
| tensor_output->dataType = kNumberTypeInt8; | |||
| if (!(node_type == schema::PrimitiveType_QuantDTypeCast && | |||
| primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->dstT == kNumberTypeFloat32)) { | |||
| tensor_output->dataType = kNumberTypeInt8; | |||
| } | |||
| // // TensorType | |||
| // valuePtr = primitive->GetAttr(kInputTensorDataType); | |||
| // if (valuePtr != nullptr) { | |||
| @@ -64,6 +64,16 @@ int LiteSession::ConvertTensors(const lite::Model *model) { | |||
| // no copy data, do copy when call LiteKernel::Init | |||
| dstTensor->SetData(const_cast<unsigned char *>(srcTensor->data()->data())); | |||
| } | |||
| auto quant_params = srcTensor->quantParams(); | |||
| if (quant_params != nullptr) { | |||
| for (int j = 0; j < quant_params->size(); j++) { | |||
| tensor::QuantArg quant_arg{}; | |||
| quant_arg.scale = quant_params->Get(j)->scale(); | |||
| quant_arg.zeroPoint = quant_params->Get(j)->zeroPoint(); | |||
| dstTensor->AddQuantParam(quant_arg); | |||
| } | |||
| } | |||
| this->tensors.emplace_back(dstTensor); | |||
| } | |||
| return RET_OK; | |||
| @@ -30,6 +30,7 @@ int QuantDTypeCast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto | |||
| auto param = primitive->value_as_QuantDTypeCast(); | |||
| MS_ASSERT(input->data_type() == param->srcT); | |||
| output->set_data_type(static_cast<TypeId>(param->dstT())); | |||
| output->SetFormat(input->GetFormat()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -62,7 +62,7 @@ int QuantDTypeCastCPUKernel::Init() { | |||
| } | |||
| inverse_ = true; | |||
| } else { | |||
| MS_LOG(ERROR) << "param data type not supported."; | |||
| MS_LOG(ERROR) << "param data type not supported:" << " src: " << param->srcT << " dst: " << param->dstT; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -148,7 +148,6 @@ kernel::LiteKernel *CpuQuantDTypeCastFp32KernelCreator(const std::vector<lite::t | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -23,7 +23,7 @@ int DequantizeInt8(int8_t *quant_values, float *real_values, float scale, int32_ | |||
| } | |||
| for (int i = 0; i < size; ++i) { | |||
| real_values[i] = (quant_values[i] + zp) * scale; | |||
| real_values[i] = (quant_values[i] - zp) * scale; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| @@ -34,7 +34,14 @@ int QuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int32_ | |||
| } | |||
| for (int i = 0; i < size; ++i) { | |||
| quant_values[i] = (int8_t)round(real_values[i] / scale + zp); | |||
| float temp = round(real_values[i] / scale + zp); | |||
| if (temp > 127) { | |||
| quant_values[i] = 127; | |||
| } else if (temp < -128) { | |||
| quant_values[i] = -128; | |||
| } else { | |||
| quant_values[i] = (int8_t)temp; | |||
| } | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| @@ -166,6 +166,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { | |||
| return -1; | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "weight_tensor_format: " << weightTensor->format; | |||
| return 0; | |||
| } else if (fmkType == converter::FmkType_ONNX) { | |||
| switch (node->quantType) { | |||
| @@ -217,7 +218,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { | |||
| auto opType = node->primitive->value.type; | |||
| if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D && | |||
| opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) { | |||
| return 0; | |||
| return RET_OK; | |||
| } | |||
| MS_ASSERT(node->inputIndex.size() >= 2); | |||
| @@ -225,7 +226,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { | |||
| MS_ASSERT(subGraph->allTensors.size() > weightIndex); | |||
| auto &weightTensor = subGraph->allTensors[weightIndex]; | |||
| MS_ASSERT(weightTensor->dataType == kNumberTypeInt8); // DataType_DT_FLOAT | |||
| STATUS status; | |||
| STATUS status = RET_OK; | |||
| if (opType == schema::PrimitiveType_Conv2D) { // weight should be HWCK | |||
| if (weightTensor->format == schema::Format_KCHW) { // from caffe | |||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||
| @@ -238,11 +239,12 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { | |||
| status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK); | |||
| } | |||
| } else if (weightTensor->format == schema::Format_KHWC) { // from onnx | |||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||
| status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK); | |||
| } else { | |||
| status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK); | |||
| } | |||
| return RET_OK; | |||
| // if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||
| // status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK); | |||
| // } else { | |||
| // status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK); | |||
| // } | |||
| } else if (weightTensor->format == schema::Format_HWCK) { // from tf | |||
| return 0; | |||
| } else { | |||
| @@ -273,8 +275,8 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { | |||
| } else if (weightTensor->format == schema::Format_HWCK) { // from tf | |||
| return 0; | |||
| } else if (weightTensor->format == schema::Format_CHWK) { // from onnx | |||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||
| status = TransFilterFormat<uint8_t>(weightTensor.get(), kCHWK2HWCK); | |||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||
| status = TransFilterFormat<int8_t>(weightTensor.get(), kCHWK2KHWC); | |||
| } else { | |||
| status = TransFilterFormat<float>(weightTensor.get(), kCHWK2HWCK); | |||
| } | |||
| @@ -54,7 +54,7 @@ struct DivergInfo { | |||
| size_t bit_num; | |||
| int quant_max = 255; | |||
| int quant_min = 0; | |||
| DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max = 255, int quant_min = 0) { | |||
| DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min) { | |||
| this->cnode = cnode; | |||
| this->bin_num = bins; | |||
| this->bit_num = bits; | |||
| @@ -81,6 +81,9 @@ struct DivergInfo { | |||
| STATUS UpdateHistogram(const std::vector<float> &data, const std::vector<int> &shape) { | |||
| for (auto value : data) { | |||
| if (value == 0) { | |||
| continue; | |||
| } | |||
| int bin_index = std::min(static_cast<int>(std::fabs(value) / this->interval), bin_num - 1); | |||
| this->histogram[bin_index]++; | |||
| } | |||
| @@ -470,8 +473,10 @@ STATUS Calibrator::ReadConfig() { | |||
| Calibrator::Calibrator(string path, size_t bitNum, int quantMax, int quantMin) | |||
| : config_path_(path), bit_num_(bitNum), quant_max_(quantMax), quant_min_(quantMin) {} | |||
| PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, int bit_num, TypeId target_type) | |||
| PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, int bit_num, TypeId target_type, | |||
| bool per_channel) | |||
| : Quantizer(graph) { | |||
| this->per_channel_ = per_channel; | |||
| this->bit_num = bit_num; | |||
| this->target_type_ = target_type; | |||
| if (target_type == kNumberTypeInt8) { | |||
| @@ -533,7 +538,7 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr node) { | |||
| } | |||
| auto parameter = std::dynamic_pointer_cast<Parameter>(node); | |||
| ParamValueLitePtr paramValue = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param()); | |||
| auto status = QuantFilter(paramValue, QuantType_PostTraining, quant_max, quant_min, bit_num); | |||
| auto status = QuantFilter(paramValue, QuantType_PostTraining, quant_max, quant_min, bit_num, per_channel_); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantFilter failed: " << status; | |||
| return status; | |||
| @@ -670,18 +675,32 @@ STATUS PostTrainingQuantizer::QuantNode() { | |||
| MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; | |||
| continue; | |||
| } | |||
| if (input_scale.find(cnode) == input_scale.end()) { | |||
| primitiveT_value->SetQuantType(schema::QuantType_QUANT_NONE); | |||
| continue; | |||
| } | |||
| auto input_vec = cnode->inputs(); | |||
| auto op_name = cnode->fullname_with_scope(); | |||
| auto op_type = primitiveT_value->GetPrimitiveT()->value.type; | |||
| MS_LOG(INFO) << "OpName: " << op_name; | |||
| if (input_vec.size() <= 3 && op_name != "Conv2D" && op_name != "DepthwiseConv2D") { | |||
| MS_LOG(INFO) << "todo(x): "; | |||
| // int32_t qnodeOutputZeropoint = outputZeropoint[cnode]; | |||
| // p->AddAttr(kInputTensorDataType, MakeValue((int)targetType)); | |||
| if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D) { | |||
| for (auto i = 1; i < cnode->inputs().size(); i++) { | |||
| auto input_node = cnode->input(i); | |||
| if (!input_node->isa<mindspore::CNode>()) { | |||
| MS_LOG(WARNING) << "node: " << cnode_name << " input " << i << " not a cnode"; | |||
| continue; | |||
| } | |||
| auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node); | |||
| auto input_cnode_primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(input_cnode->input(0)); | |||
| if (input_cnode_primitiveT_value == nullptr) { | |||
| MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": " | |||
| << " PrimitiveTValue is null"; | |||
| continue; | |||
| } | |||
| for (auto &quant_param : input_cnode_primitiveT_value->GetOutputQuantParams()) { | |||
| primitiveT_value->AddInputQuantParam(quant_param); | |||
| } | |||
| } | |||
| } else { | |||
| // do input quant | |||
| double scale = input_scale[cnode]; | |||
| @@ -55,15 +55,18 @@ struct ConfigParam { | |||
| class PostTrainingQuantizer : public Quantizer { | |||
| public: | |||
| PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8); | |||
| PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8, | |||
| bool per_channel = false); | |||
| STATUS DoQuantize(FuncGraphPtr funcGraph) override; | |||
| size_t bit_num; | |||
| int quant_max{255}; | |||
| int quant_min{0}; | |||
| int quant_max{127}; | |||
| int quant_min{-128}; | |||
| private: | |||
| bool per_channel_; | |||
| TypeId target_type_{kNumberTypeInt8}; | |||
| std::unique_ptr<Calibrator> calibrator_; | |||
| @@ -25,10 +25,11 @@ namespace mindspore::lite::quant { | |||
| ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector<schema::QuantParamT> &quant_params) { | |||
| std::unique_ptr<schema::PrimitiveT> primitive = std::make_unique<schema::PrimitiveT>(); | |||
| schema::QuantDTypeCastT quant_dtype_cast; | |||
| quant_dtype_cast.srcT = src_type; // kNumberTypeUInt8; | |||
| quant_dtype_cast.srcT = src_type; // kNumberTypeInt8; | |||
| quant_dtype_cast.dstT = dst_type; // kNumberTypeFloat32; | |||
| primitive->value.Set(quant_dtype_cast); | |||
| auto primTValue = std::make_shared<PrimitiveTValue>(primitive.release()); | |||
| primTValue->SetQuantType(schema::QuantType_PostTraining); | |||
| for (auto &quant_param : quant_params) { | |||
| primTValue->AddInputQuantParam(quant_param); | |||
| } | |||
| @@ -52,7 +53,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { | |||
| if (first) { | |||
| if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) { | |||
| auto value_node = | |||
| NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8, primitiveT_value->GetInputQuantParams()); | |||
| NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitiveT_value->GetInputQuantParams()); | |||
| std::vector<AnfNodePtr> op_inputs = {value_node, cnode->input(1)}; | |||
| auto quant_cast_cnode = graph->NewCNode(op_inputs); | |||
| quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast"); | |||
| @@ -82,11 +83,11 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { | |||
| ValueNodePtr value_node = nullptr; | |||
| if (curnode_quant_type == schema::QuantType_PostTraining && | |||
| input_cnode_quant_type == schema::QuantType_QUANT_NONE) { | |||
| value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8, | |||
| input_cnode_primitiveT_value->GetInputQuantParams()); | |||
| value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, | |||
| primitiveT_value->GetInputQuantParams()); | |||
| } else if (curnode_quant_type == schema::QuantType_QUANT_NONE && | |||
| input_cnode_quant_type == schema::QuantType_PostTraining) { | |||
| value_node = NewQuantCastValueNode(kNumberTypeUInt8, kNumberTypeFloat32, | |||
| value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32, | |||
| input_cnode_primitiveT_value->GetInputQuantParams()); | |||
| } | |||
| if (value_node == nullptr) { | |||
| @@ -98,7 +98,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const { | |||
| static const std::vector<schema::PrimitiveType> uint8OpList = { | |||
| schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D, | |||
| schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Pooling, | |||
| schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax, schema::PrimitiveType_Reshape, | |||
| schema::PrimitiveType_Concat, /*schema::PrimitiveType_SoftMax,*/ schema::PrimitiveType_Reshape, | |||
| schema::PrimitiveType_Activation}; | |||
| return IsContain(uint8OpList, type); | |||
| } | |||
| @@ -242,64 +242,122 @@ STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double | |||
| return RET_OK; | |||
| } | |||
| STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum) { | |||
| STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum, | |||
| bool per_channel) { | |||
| if (per_channel) { | |||
| // per channel | |||
| auto dims = weightPtr->tensor_shape(); | |||
| if (dims.size() < 1) { | |||
| MS_LOG(ERROR) << "weight dims size error"; | |||
| return RET_ERROR; | |||
| MS_LOG(ERROR) << "weight dims size error"; | |||
| return RET_ERROR; | |||
| } | |||
| uint32_t channels = dims[0]; | |||
| // todo(x) | |||
| uint32_t channels = dims[3]; | |||
| if (channels == 0) { | |||
| MS_LOG(ERROR) << "channels error 0"; | |||
| return RET_ERROR; | |||
| MS_LOG(ERROR) << "channels error 0"; | |||
| return RET_ERROR; | |||
| } | |||
| size_t shapeSize = weightPtr->tensor_shape_size(); | |||
| size_t oneFilterSize = shapeSize / channels; | |||
| auto *rawDatas = reinterpret_cast<const float *>(weightPtr->tensor_addr()); | |||
| if (rawDatas == nullptr) { | |||
| MS_LOG(ERROR) << "rawDatas is nullptr"; | |||
| return RET_ERROR; | |||
| MS_LOG(ERROR) << "rawDatas is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| weightPtr->quant_param().clear(); | |||
| vector<uint8_t> qDatas(shapeSize); | |||
| vector<int8_t> qDatas(shapeSize); | |||
| for (uint32_t i = 0; i < channels; i++) { | |||
| float min = 0; | |||
| float max = 0; | |||
| // find min and max | |||
| for (uint32_t j = 0; j < oneFilterSize; j++) { | |||
| min = std::min(min, rawDatas[j + i * oneFilterSize]); | |||
| max = std::max(max, rawDatas[j + i * oneFilterSize]); | |||
| } | |||
| float min = 0; | |||
| float max = 0; | |||
| // find min and max | |||
| for (uint32_t j = 0; j < oneFilterSize; j++) { | |||
| min = std::min(min, rawDatas[j + i * oneFilterSize]); | |||
| max = std::max(max, rawDatas[j + i * oneFilterSize]); | |||
| } | |||
| std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam); | |||
| STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "CalQuantizationParams failed" << status; | |||
| return status; | |||
| } | |||
| // update data and datatype | |||
| for (uint32_t j = 0; j < oneFilterSize; j++) { | |||
| float rawData = rawDatas[j + i * oneFilterSize]; | |||
| auto qData = QuantizeData<int8_t>(rawData, quantParam.get(), quant_max, quant_min); | |||
| qDatas[j + i * oneFilterSize] = qData; | |||
| } | |||
| weightPtr->set_quant_param(quantParam); | |||
| } | |||
| auto ret = memcpy_s(const_cast<float*>(rawDatas), weightPtr->tensor_size(), | |||
| qDatas.data(), shapeSize * sizeof(int8_t)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy error: " << ret; | |||
| return RET_ERROR; | |||
| } | |||
| if (quantType == QuantType_WeightQuant) { | |||
| PostBitPack(const_cast<float*>(rawDatas), shapeSize, bitNum); | |||
| } | |||
| std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam); | |||
| STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "CalQuantizationParams failed" << status; | |||
| return status; | |||
| } | |||
| // update data and datatype | |||
| for (uint32_t j = 0; j < oneFilterSize; j++) { | |||
| float rawData = rawDatas[j + i * oneFilterSize]; | |||
| auto qData = QuantizeData<uint8_t>(rawData, quantParam.get()); | |||
| qDatas[j + i * oneFilterSize] = qData; | |||
| } | |||
| weightPtr->set_tensor_type(kNumberTypeInt8); | |||
| weightPtr->set_tensor_size(shapeSize * sizeof(int8_t)); | |||
| } else { | |||
| // per layer | |||
| size_t shapeSize = weightPtr->tensor_shape_size(); | |||
| auto *rawDatas = static_cast<float *>(weightPtr->tensor_addr()); | |||
| if (rawDatas == nullptr) { | |||
| MS_LOG(ERROR) << "rawDatas is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| weightPtr->set_quant_param(quantParam); | |||
| weightPtr->quant_param().clear(); | |||
| vector<int8_t> qDatas(shapeSize); | |||
| float min = 0; | |||
| float max = 0; | |||
| for (uint32_t i = 0; i < shapeSize; i++) { | |||
| // find max min | |||
| min = std::min(min, rawDatas[i]); | |||
| max = std::max(max, rawDatas[i]); | |||
| } | |||
| auto ret = memcpy_s(const_cast<float*>(rawDatas), weightPtr->tensor_size(), | |||
| qDatas.data(), shapeSize * sizeof(uint8_t)); | |||
| std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam); | |||
| STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "CalQuantizationParams failed" << status; | |||
| return status; | |||
| } | |||
| // update data and datatype | |||
| for (uint32_t i = 0; i < shapeSize; i++) { | |||
| float rawData = rawDatas[i]; | |||
| auto quant_data = std::round(rawData / quantParam->scale + quantParam->zeroPoint); | |||
| if (quant_data > quant_max) { | |||
| qDatas[i] = quant_max; | |||
| } else if (quant_data < quant_min) { | |||
| qDatas[i] = quant_min; | |||
| } else { | |||
| qDatas[i] = static_cast<int8_t>(quant_data); | |||
| } | |||
| } | |||
| weightPtr->set_quant_param(quantParam); | |||
| auto ret = memcpy_s(rawDatas, weightPtr->tensor_size() * sizeof(int8_t), | |||
| qDatas.data(), shapeSize * sizeof(int8_t)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy error: " << ret; | |||
| return RET_ERROR; | |||
| MS_LOG(ERROR) << "memcpy error: " << ret; | |||
| return RET_ERROR; | |||
| } | |||
| if (quantType == QuantType_WeightQuant) { | |||
| PostBitPack(const_cast<float*>(rawDatas), shapeSize, bitNum); | |||
| PostBitPack(rawDatas, shapeSize, bitNum); | |||
| } | |||
| weightPtr->set_tensor_type(kNumberTypeInt8); | |||
| weightPtr->set_tensor_size(shapeSize * sizeof(int8_t)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -63,41 +63,30 @@ STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double | |||
| bool narrowRange, int quant_max, int quant_min, int num_bits); | |||
| template <typename T> | |||
| T QuantizeData(const float originData, const AnfQuantParam *quantParam) { | |||
| T QuantizeData(float originData, const AnfQuantParam *quantParam, int quant_max, int quant_min) { | |||
| MS_ASSERT(quantParam != nullptr); | |||
| MS_ASSERT(quantParam->inited); | |||
| const auto scale = quantParam->scale; | |||
| const auto zeroPoint = quantParam->zeroPoint; | |||
| const auto numBit = quantParam->numBits; | |||
| const int zeroPoint = quantParam->zeroPoint; | |||
| const auto narrowRange = quantParam->narrowRange; | |||
| const double maxLimit = static_cast<float>((1 << (unsigned int)numBit) - 1 - zeroPoint) * scale; | |||
| double minLimit; | |||
| if (narrowRange) { | |||
| minLimit = static_cast<float>(1 - zeroPoint) * scale; | |||
| } else { | |||
| minLimit = static_cast<float>(0 - zeroPoint) * scale; | |||
| } | |||
| const int maxLimit = quant_max; | |||
| const int minLimit = quant_min; | |||
| return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] { | |||
| double tmp = 0.0f; | |||
| if (originData > maxLimit) { | |||
| tmp = maxLimit; | |||
| } else if (originData < minLimit) { | |||
| tmp = minLimit; | |||
| } else { | |||
| tmp = originData; | |||
| } | |||
| auto quantData = static_cast<T>(std::round(tmp / scale + zeroPoint)); | |||
| if (quantData == 0 && narrowRange) { | |||
| quantData++; | |||
| int quant_data = std::round(originData / scale + zeroPoint); | |||
| if (quant_data > maxLimit) { | |||
| quant_data = maxLimit; | |||
| } else if (quant_data < minLimit) { | |||
| quant_data = minLimit; | |||
| } | |||
| return quantData; | |||
| return static_cast<T>(quant_data); | |||
| }(); | |||
| } | |||
| void CalFakeNode(const AnfNodePtr &inTensor); | |||
| STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, | |||
| size_t bitNum = UINT8_QUANTIZATION); | |||
| size_t bitNum = UINT8_QUANTIZATION, bool per_channel = false); | |||
| STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION); | |||