/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "mindspore/lite/tools/converter/quantizer/quantize_util.h"
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "src/ops/primitive_c.h"
#include "mindspore/lite/tools/converter/quantizer/general_bitpacking.h"
#include "src/common/utils.h"
#include "abstract/abstract_value.h"
#include "securec/include/securec.h"

using std::string;
using std::vector;

namespace mindspore {
namespace lite {
namespace quant {
// Primitive types handled by CanConvOpQuantized.
const std::vector<schema::PrimitiveType> QuantStrategy::conv_types = {
  schema::PrimitiveType_DeConv2D, schema::PrimitiveType_DeDepthwiseConv2D, schema::PrimitiveType_Conv2D,
  schema::PrimitiveType_DepthwiseConv2D};
// Primitive types handled by CanMulOpQuantized.
const std::vector<schema::PrimitiveType> QuantStrategy::mul_types = {schema::PrimitiveType_MatMul,
                                                                     schema::PrimitiveType_FullConnection};

QuantStrategy::QuantStrategy(size_t weightSize, size_t convWeightQuantChannelThreshold)
    : mWeightSize(weightSize), mConvWeightQuantChannelThreshold(convWeightQuantChannelThreshold) {}

// Returns true when `node` is one of `conv_types`, its input(2) is a Parameter with a known shape,
// the product of the shape dims is at least mWeightSize and the first dim is greater than
// mConvWeightQuantChannelThreshold.
bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const {
  auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(node->input(0));
  if (primitive_c == nullptr) {
    MS_LOG(ERROR) << "primitive_c is nullptr";
    return false;
  }
  if (!IsContain(conv_types, static_cast<schema::PrimitiveType>(primitive_c->Type()))) {
    return false;
  }
  if (node->size() < 3) {
    return false;
  }
  auto inputNode = node->input(2);
  if (inputNode == nullptr || !inputNode->isa<Parameter>()) {
    return false;
  }
  auto paramNode = inputNode->cast<ParameterPtr>();
  auto abstract_base = paramNode->abstract();
  if (abstract_base == nullptr) {
    return false;
  }
  if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
    MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name();
    return false;
  }
  auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
  size_t shapeSize = 1;
  for (auto dim : weight_shape) {
    shapeSize = shapeSize * dim;
  }
  if (shapeSize < mWeightSize) {
    MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize;
    return false;
  }
  // A rank-0 shape has no first dimension to compare against the threshold.
  if (weight_shape.empty()) {
    MS_LOG(INFO) << "weight shape is empty " << paramNode->name();
    return false;
  }
  if (weight_shape[0] <= static_cast<int>(mConvWeightQuantChannelThreshold)) {
    MS_LOG(INFO) << "channel less mConvWeightQuantChannelThreshold!" << weight_shape[0];
    return false;
  }
  return true;
}

// Returns true when `node` is a CNode whose primitive type is in the post-training int8 allow list.
bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
  if (node == nullptr || !node->isa<CNode>()) {
    return false;
  }
  auto cnode = std::dynamic_pointer_cast<CNode>(node);
  auto type = NodePrimitiveType(cnode);
  static const std::vector<schema::PrimitiveType> int8OpList = {
    schema::PrimitiveType_Nchw2Nhwc,       schema::PrimitiveType_Nhwc2Nchw,
    schema::PrimitiveType_Conv2D,          schema::PrimitiveType_DepthwiseConv2D,
    schema::PrimitiveType_Add,             schema::PrimitiveType_Pooling,
    schema::PrimitiveType_Concat,          schema::PrimitiveType_Split,
    schema::PrimitiveType_TupleGetItem,    schema::PrimitiveType_Reshape,
    schema::PrimitiveType_FullConnection,  schema::PrimitiveType_MatMul,
    schema::PrimitiveType_Crop,            schema::PrimitiveType_DeDepthwiseConv2D,
    schema::PrimitiveType_DeConv2D,        schema::PrimitiveType_Activation,
  };
  bool contain = IsContain(int8OpList, type);
  if (!contain) {
    MS_LOG(INFO) << "not quant, " << cnode->fullname_with_scope()
                 << " of type: " << schema::EnumNamePrimitiveType(type);
  }
  return contain;
}

// Returns true when `node` is one of `mul_types`, one of its first two inputs is a Parameter with a
// known shape and the product of the shape dims is at least mWeightSize. input(1) is preferred over
// input(2) when both are Parameters.
bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const {
  auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(node->input(0));
  if (primitive_c == nullptr) {
    MS_LOG(ERROR) << "primitive_c is nullptr";
    return false;
  }
  if (!IsContain(mul_types, static_cast<schema::PrimitiveType>(primitive_c->Type()))) {
    return false;
  }
  if (node->size() < 3) {
    MS_LOG(INFO) << "input size less!";
    return false;
  }
  auto inputNode1 = node->input(1);
  auto inputNode2 = node->input(2);
  if (inputNode1 == nullptr || inputNode2 == nullptr) {
    MS_LOG(INFO) << "mul input is nullptr!";
    return false;
  }
  ParameterPtr paramNode = nullptr;
  if (inputNode1->isa<Parameter>()) {
    paramNode = inputNode1->cast<ParameterPtr>();
  } else if (inputNode2->isa<Parameter>()) {
    paramNode = inputNode2->cast<ParameterPtr>();
  }
  if (paramNode == nullptr) {
    MS_LOG(INFO) << "invalid paramNode!";
    return false;
  }
  auto abstract_base = paramNode->abstract();
  if (abstract_base == nullptr) {
    MS_LOG(INFO) << "abstract is nullptr";
    return false;
  }
  if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
    MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name();
    return false;
  }
  auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
  size_t shapeSize = 1;
  for (auto dim : weight_shape) {
    shapeSize = shapeSize * dim;
  }
  if (shapeSize < mWeightSize) {
    MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize;
    return false;
  }
  return true;
}

// Widens [*mMin, *mMax] so that it contains 0 and validates the result.
// Returns RET_PARAM_INVALID when min > max, RET_ERROR when min == max != 0, RET_OK otherwise.
static STATUS ClampRangeToIncludeZero(double *mMin, double *mMax) {
  if (*mMin > 0.0f) {
    MS_LOG(DEBUG) << "min " << *mMin << " is bigger then 0, set to 0, this may course low precision";
    *mMin = 0.0f;
  }
  if (*mMax < 0.0f) {
    MS_LOG(DEBUG) << "mMax " << *mMax << " is smaller than 0, set to 0, this may course low precision";
    *mMax = 0.0f;
  }
  if (*mMin > *mMax) {
    MS_LOG(ERROR) << "cal error while min" << *mMin << ">" << *mMax;
    return RET_PARAM_INVALID;
  }
  if (*mMin == *mMax && *mMin != 0.0f) {
    MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other";
    return RET_ERROR;
  }
  return RET_OK;
}

// Writes every field that the CalQuantizationParams overloads set on `quantParam`.
static void FillQuantParam(schema::QuantParamT *quantParam, bool inited, double min, double max, double scale,
                           int zeroPoint, bool narrowRange, int numBits) {
  quantParam->inited = inited;
  quantParam->min = min;
  quantParam->max = max;
  quantParam->scale = scale;
  quantParam->zeroPoint = zeroPoint;
  quantParam->narrowRange = narrowRange;
  quantParam->numBits = numBits;
}

// Computes scale and zero point for the real range [mMin, mMax] mapped onto the caller supplied
// integer range [quant_min, quant_max]. The real range is first widened to contain 0.
// A zero-width range (min == max == 0) yields scale 0, zero point 0 and inited == true.
// Returns RET_OK on success, RET_PARAM_INVALID / RET_ERROR for an invalid range.
STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange,
                             int quant_max, int quant_min, int num_bits) {
  MS_ASSERT(quantParam != nullptr);
  const STATUS status = ClampRangeToIncludeZero(&mMin, &mMax);
  if (status != RET_OK) {
    return status;
  }
  if (mMin == mMax) {
    FillQuantParam(quantParam, true, mMin, mMax, 0.0f, 0, narrowRange, num_bits);
    return RET_OK;
  }
  // An empty or inverted integer range would give an infinite or negative scale.
  if (quant_max <= quant_min) {
    MS_LOG(ERROR) << "quant_max " << quant_max << " should be bigger than quant_min " << quant_min;
    return RET_PARAM_INVALID;
  }
  auto quantMinFloat = static_cast<double>(quant_min);
  auto quantMaxFloat = static_cast<double>(quant_max);
  double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
  const double zeroPointFromMin = quantMinFloat - mMin / scale;
  int zeroPoint = static_cast<int32_t>(std::round(zeroPointFromMin));

  // The zero point should always be in the range of quantized value,
  // [qmin, qmax].
  MS_ASSERT(zeroPoint >= quant_min);
  MS_ASSERT(zeroPoint <= quant_max);
  FillQuantParam(quantParam, true, mMin, mMax, scale, zeroPoint, narrowRange, num_bits);
  return RET_OK;
}

// Computes scale and zero point for the real range [mMin, mMax] mapped onto the int8 range
// ([-128, 127], or [-127, 127] when narrowRange is set). `numBits` is only stored in `quantParam`;
// it does not influence the integer range used here.
// A zero-width range (min == max == 0) yields scale 0, zero point 0 and inited == false.
// NOTE(review): the 7-argument overload reports inited == true for the same case - confirm which
// one callers rely on.
// Returns RET_OK on success, RET_PARAM_INVALID / RET_ERROR for an invalid range.
STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange,
                             int numBits) {
  MS_ASSERT(quantParam != nullptr);
  const STATUS status = ClampRangeToIncludeZero(&mMin, &mMax);
  if (status != RET_OK) {
    return status;
  }
  if (mMin == mMax) {
    FillQuantParam(quantParam, false, mMin, mMax, 0.0f, 0, narrowRange, numBits);
    return RET_OK;
  }

  const int8_t quantMin = std::numeric_limits<int8_t>::min() + (narrowRange ? 1 : 0);
  const int8_t quantMax = std::numeric_limits<int8_t>::max();
  auto quantMinFloat = static_cast<double>(quantMin);
  auto quantMaxFloat = static_cast<double>(quantMax);
  double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
  // Derive the zero point from whichever end of the range has the smaller magnitude of terms.
  const double zeroPointFromMin = quantMinFloat - mMin / scale;
  const double zeroPointFromMax = quantMaxFloat - mMax / scale;
  const double zpFromMinError = std::abs(quantMinFloat) + std::abs(mMin / scale);
  const double zpFromMaxError = std::abs(quantMaxFloat) + std::abs(mMax / scale);
  const double zpDouble = zpFromMinError < zpFromMaxError ? zeroPointFromMin : zeroPointFromMax;
  int zeroPoint;
  if (zpDouble < quantMinFloat) {
    zeroPoint = quantMin;
  } else if (zpDouble > quantMaxFloat) {
    zeroPoint = quantMax;
  } else {
    zeroPoint = static_cast<int32_t>(std::round(zpDouble));
  }
  // A symmetric real range always maps 0.0 to 0.
  if (std::abs(mMin) == std::abs(mMax)) {
    zeroPoint = 0;
  }
  // The zero point should always be in the range of quantized value,
  // [qmin, qmax].
  MS_ASSERT(zeroPoint >= quantMin);
  MS_ASSERT(zeroPoint <= quantMax);
  FillQuantParam(quantParam, true, mMin, mMax, scale, zeroPoint, narrowRange, numBits);
  return RET_OK;
}

// Bit-packs the first `shapeSize` bytes of `weight` in place.
// For 1 < bitNum < 8 the bytes are packed with BitPack and the packed bytes are written back to the
// start of the buffer; for bitNum == 8 the bytes are written back unchanged.
// Returns RET_OK on success, RET_ERROR for an unsupported bitNum, a null buffer or a failed copy.
STATUS PostBitPack(float *weight, size_t shapeSize, size_t bitNum) {
  if (weight == nullptr) {
    MS_LOG(ERROR) << "weight is nullptr";
    return RET_ERROR;
  }
  if (bitNum > 8 || bitNum <= 1) {
    MS_LOG(ERROR) << "bitNum must be between 0 and 8 : " << bitNum;
    return RET_ERROR;
  }
  auto *rawDatas = reinterpret_cast<uint8_t *>(weight);
  vector<uint8_t> qDatas(rawDatas, rawDatas + shapeSize);
  if (bitNum == 8) {
    if (EOK != memcpy_s(rawDatas, shapeSize, qDatas.data(), shapeSize)) {
      MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas failed";
      return RET_ERROR;
    }
    return RET_OK;
  }
  vector<uint8_t> qDatas_packed;
  BitPack weight_bitpack(bitNum);
  weight_bitpack.BitPacking(qDatas, qDatas_packed);
  // Never read past the end of the packed buffer, which is shorter than the input.
  const size_t copySize = std::min(qDatas_packed.size(), shapeSize);
  if (EOK != memcpy_s(rawDatas, shapeSize, qDatas_packed.data(), copySize)) {
    MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas_packed failed";
    return RET_ERROR;
  }
  return RET_OK;
}

// One step of the lower-bound search used by OutlierMethod.
// Returns false (stop searching) once data[index] is within `delta` of `max_tmp`; otherwise moves
// (*min_tmp, *min_idx) to `index` when the value jump relative to the index jump exceeds `ratio`.
static bool SearchLowerBound(const std::vector<float> &data, const size_t &index, const float &max_tmp,
                             float *min_tmp, size_t *min_idx) {
  const float candidate = data.at(index);
  if (max_tmp - candidate < delta) {
    return false;
  }
  const size_t length = data.size();
  const float range_ratio = (candidate - *min_tmp) / (max_tmp - *min_tmp);
  const float index_ratio = static_cast<float>(index - *min_idx) / (length - *min_idx);
  if (index_ratio > 0 && range_ratio / index_ratio > ratio) {
    *min_idx = index;
    *min_tmp = candidate;
  }
  return true;
}

// One step of the upper-bound search used by OutlierMethod.
// Returns false (stop searching) once data[index] is within `delta` of `min_tmp`; otherwise moves
// (*max_tmp, *max_idx) to `index` when the value jump relative to the index jump exceeds `ratio`.
static bool SearchUpperBound(const std::vector<float> &data, const size_t &index, float *max_tmp,
                             const float &min_tmp, size_t *max_idx) {
  const float candidate = data.at(index);
  if (candidate - min_tmp < delta) {
    return false;
  }
  const size_t length = data.size();
  const float range_ratio = (*max_tmp - candidate) / (*max_tmp - min_tmp);
  const float index_ratio = static_cast<float>(index - *max_idx) / (length - *max_idx);
  if (index_ratio > 0 && range_ratio / index_ratio > ratio) {
    *max_idx = index;
    *max_tmp = candidate;
  }
  return true;
}

// Returns the `outlier_percent`-th percentile of the ascending-sorted `datas`.
// When the percentile position falls exactly on an element boundary the two neighbouring elements
// are averaged. Positions outside the data are clamped to the first/last element; empty input
// yields 0.
static float CalPercentile(const std::vector<float> &datas, const int &outlier_percent) {
  if (datas.empty()) {
    return 0.0f;
  }
  const int size = datas.size();
  const float val = outlier_percent / 100.0 * size;
  const int index = std::ceil(val);
  const int lower = std::min(std::max(index - 1, 0), size - 1);
  if (index - val > 0) {
    return datas.at(lower);
  }
  const int upper = std::min(std::max(index, 0), size - 1);
  return (datas.at(lower) + datas.at(upper)) / 2;
}

// Estimates a (min, max) pair from per-sample minima and maxima while discarding outliers:
// starts from the `percent` / `100 - percent` percentiles and refines them with
// SearchLowerBound / SearchUpperBound.
std::pair<float, float> OutlierMethod(std::vector<float> min_datas, std::vector<float> max_datas) {
  std::sort(max_datas.begin(), max_datas.end());
  std::sort(min_datas.begin(), min_datas.end());
  const float min_val = CalPercentile(min_datas, percent);
  const float max_val = CalPercentile(max_datas, 100 - percent);
  std::reverse(max_datas.begin(), max_datas.end());
  MS_ASSERT(min_val < max_val);
  MS_ASSERT(min_datas.size() == max_datas.size());
  float min_tmp = min_val;
  float max_tmp = max_val;
  size_t min_idx = 0;
  size_t max_idx = 0;
  const size_t length = min_datas.size();
  for (size_t i = 0; i < length; i++) {
    if (!SearchLowerBound(min_datas, i, max_tmp, &min_tmp, &min_idx)) {
      break;
    }
    // NOTE(review): the upper bound is searched in min_datas although max_datas was sorted and
    // reversed above and is otherwise unused - confirm whether max_datas was intended here.
    if (!SearchUpperBound(min_datas, i, &max_tmp, min_tmp, &max_idx)) {
      break;
    }
  }
  return std::pair<float, float>{min_tmp, max_tmp};
}

// Picks `k` initial cluster centres spread over the sorted unique values of `data`.
// Returns an empty vector when k is 0 or there are fewer than k unique values.
static std::vector<float> InitClusters(float *data, size_t elem_count, size_t k) {
  std::set<float> set_unique{};
  for (size_t i = 0; i < elem_count; i++) {
    set_unique.emplace(data[i]);
  }
  std::vector<float> clusters{};
  if (k == 0 || set_unique.size() < k) {
    return clusters;
  }
  // std::set iterates in ascending order, so the copy is already sorted.
  const std::vector<float> data_unique(set_unique.begin(), set_unique.end());
  const size_t last = data_unique.size() - 1;
  // init cluster
  const float stride = k > 1 ? static_cast<float>(data_unique.size()) / (k - 1) : 0.0f;
  clusters.reserve(k);
  for (size_t i = 0; i < k; i++) {
    const float position = i * stride;
    const size_t index = static_cast<size_t>(std::floor(position));
    if (index >= last) {
      // The final position equals data_unique.size(); clamp it to the largest value.
      clusters.emplace_back(data_unique[last]);
    } else if (position - index > 0) {
      clusters.emplace_back((data_unique[index] + data_unique[index + 1]) / 2);
    } else {
      clusters.emplace_back(data_unique[index]);
    }
  }
  return clusters;
}

// Runs at most `epochs` iterations of k-means over `data[0, elem_count)`.
// Returns, for every element, its cluster index offset by INT8_MIN (cluster 0 -> -128), and stores
// the final cluster centres in `quantParam->clusters`.
// Returns an empty vector, leaving `quantParam` untouched, when quantParam is null, k is not in
// [1, 256] or the data has fewer than k unique values.
std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epochs,
                           schema::QuantParamT *quantParam) {
  std::vector<int8_t> clusters_index{};
  if (quantParam == nullptr) {
    MS_LOG(ERROR) << "quantParam is nullptr";
    return clusters_index;
  }
  // Cluster indices are returned as int8_t, so at most 256 clusters can be represented.
  constexpr size_t kMaxClusters = static_cast<size_t>(INT8_MAX - INT8_MIN + 1);
  if (k == 0 || k > kMaxClusters) {
    MS_LOG(ERROR) << "k must be between 1 and 256 : " << k;
    return clusters_index;
  }
  std::vector<float> clusters = InitClusters(data, elem_count, k);
  if (clusters.size() < k) {
    MS_LOG(WARNING) << "K is less than the size of data so KMeans function is not executed.";
    return clusters_index;
  }
  double error{0};
  for (size_t epoch = 0; epoch < epochs; epoch++) {
    clusters_index.clear();
    clusters_index.reserve(elem_count);
    std::vector<std::vector<float>> clusters_data(clusters.size());
    // Assign every element to its nearest centre.
    for (size_t i = 0; i < elem_count; i++) {
      size_t index = 0;
      const float first_diff = data[i] - clusters[0];
      double min_distance = static_cast<double>(first_diff) * first_diff;
      for (size_t j = 1; j < clusters.size(); j++) {
        const float diff = data[i] - clusters[j];
        const double distance = static_cast<double>(diff) * diff;
        if (distance < min_distance) {
          min_distance = distance;
          index = j;
        }
      }
      clusters_index.emplace_back(static_cast<int8_t>(static_cast<int>(index) + INT8_MIN));
      clusters_data[index].emplace_back(data[i]);
    }
    // Move every non-empty centre to the mean of its members.
    for (size_t j = 0; j < clusters.size(); j++) {
      if (!clusters_data[j].empty()) {
        clusters[j] =
          std::accumulate(clusters_data[j].begin(), clusters_data[j].end(), 0.0) / clusters_data[j].size();
      }
    }
    // compare error
    double error_cur{0};
    for (size_t j = 0; j < elem_count; j++) {
      // Undo the INT8_MIN offset before using the stored value as an index.
      const size_t centre = static_cast<size_t>(clusters_index[j] - INT8_MIN);
      const float diff = data[j] - clusters[centre];
      error_cur += static_cast<double>(diff) * diff;
    }
    error_cur = std::sqrt(error_cur / elem_count);
    // A zero error is a perfect fit; test it first so the relative check never divides by zero.
    if (error_cur == 0.0 || std::abs((error_cur - error) / error_cur) < 1e-6) {
      break;
    }
    error = error_cur;
  }
  // update data
  quantParam->clusters = clusters;
  return clusters_index;
}

// Returns the schema primitive type of `cnode`, or PrimitiveType_NONE when `cnode` is null or its
// first input does not hold a primitive.
schema::PrimitiveType NodePrimitiveType(CNodePtr cnode) {
  if (cnode == nullptr) {
    MS_LOG(ERROR) << "cnode is null";
    return schema::PrimitiveType_NONE;
  }
  auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
  if (primitive_c == nullptr) {
    MS_LOG(ERROR) << "primitive_c is null";
    return schema::PrimitiveType_NONE;
  }
  return static_cast<schema::PrimitiveType>(primitive_c->Type());
}
}  // namespace quant
}  // namespace lite
}  // namespace mindspore