/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "tools/converter/quantizer/aware_quantizer.h" #include #include #include #include #include #include "schema/inner/model_generated.h" #include "securec/include/securec.h" #include "src/common/utils.h" #include "tools/common/node_util.h" #include "tools/common/tensor_util.h" #include "tools/converter/quantizer/calc_quant_param.h" #include "src/common/log_adapter.h" using std::string; using std::vector; namespace mindspore::lite::quant { AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferType) : FbQuantizer(graph) {} STATUS AwareQuantizer::RemoveFakeQuant() { return RET_OK; } STATUS AwareQuantizer::GenerateQuantParam() { auto *quantParamRegister = QuantParamCalcRegister::GetInstance(); for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { auto &node = *iter; MS_ASSERT(node != nullptr); if (GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMax || GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) { MS_ASSERT(false); } auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node)); if (quantParamCalcer == nullptr) { MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str() << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; node->quantType = static_cast(QuantType_QUANT_NONE); } else { auto status = quantParamCalcer->Calc(graph, *node); if (status != RET_OK) { MS_LOG(WARNING) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); node->quantType = schema::QuantType_QUANT_NONE; } else { node->quantType = schema::QuantType_AwareTraining; } } } return RET_OK; } STATUS AwareQuantizer::DoQuantize() { for (auto &tensor : graph->allTensors) { if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited || tensor->data.empty()) { continue; } if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat && tensor->dataType != TypeId::kNumberTypeUInt8) { continue; } // perlayer if (tensor->quantParams.size() == 1) { auto &quantParam = tensor->quantParams.front(); size_t wShapeSize = GetShapeSize(*(tensor.get())); void *oriWeightData = tensor->data.data(); if (quantParam->dstDtype == TypeId::kNumberTypeInt8) { vector qDatas(wShapeSize); auto weightQauntParam = GetTensorQuantParam(tensor); if (tensor->dataType == TypeId::kNumberTypeFloat || tensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant auto *weightData = static_cast(oriWeightData); for (size_t j = 0; j < wShapeSize; j++) { qDatas[j] = QuantizeData(weightData[j], weightQauntParam.get()); } } else { // tflite awareing quant auto *weightData = static_cast(oriWeightData); for (size_t j = 0; j < wShapeSize; j++) { qDatas[j] = (int32_t)weightData[j] - 128; } weightQauntParam->zeroPoint -= 128; tensor->quantParams.clear(); tensor->quantParams.emplace_back(weightQauntParam.release()); } ::memcpy(tensor->data.data(), qDatas.data(), wShapeSize); } else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) { // quant bias data auto bShapeSize = GetShapeSize(*(tensor.get())); std::unique_ptr qDatas(new (std::nothrow) int32_t[bShapeSize]); if (qDatas == nullptr) { MS_LOG(ERROR) << "new qDatas failed"; return RET_ERROR; } void *biasData = tensor->data.data(); auto *rawDatas = static_cast(biasData); for (size_t i = 0; i < bShapeSize; ++i) { qDatas[i] = (int32_t)std::round(rawDatas[i] / quantParam->scale); } tensor->dataType = TypeId::kNumberTypeInt32; tensor->data.clear(); tensor->data.resize(bShapeSize * sizeof(int32_t)); auto ret = memcpy_s(tensor->data.data(), bShapeSize * sizeof(int32_t), qDatas.get(), bShapeSize * sizeof(int32_t)); if (ret != EOK) { MS_LOG(ERROR) << "memcpy_s failed: " << ret; return RET_ERROR; } } } else { // pertensor } } return RET_OK; } STATUS AwareQuantizer::DetermineNodeQuantType() { MS_ASSERT(graph != nullptr); for (auto &node : graph->nodes) { MS_ASSERT(node != nullptr); bool canQuant = true; for (auto &outTensorIdx : node->outputIndex) { MS_ASSERT(graph->allTensors.size() > outTensorIdx); auto &outTensor = graph->allTensors.at(outTensorIdx); MS_ASSERT(outTensor != nullptr); if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr || !outTensor->quantParams.front()->inited) { canQuant = false; break; } } if (canQuant && IsContain(GetInt8OpList(), GetCNodeTType(*node))) { node->quantType = schema::QuantType_AwareTraining; } else { node->quantType = schema::QuantType_QUANT_NONE; } } return RET_OK; } } // namespace mindspore::lite::quant