Merge pull request !8125 from ghzl/mix-bit-pack-and-unpacktags/v1.1.0
| @@ -27,6 +27,7 @@ | |||
| #include "src/common/graph_util.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/model_common.h" | |||
| #include "mindspore/lite/src/runtime/kernel/arm/base/dequant.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -95,13 +96,26 @@ int LiteSession::ConvertTensors(const lite::Model *model) { | |||
| memcpy(dst_data, srcTensor->data()->data(), dstTensor->Size()); | |||
| copyed_tensor_idxes_.emplace_back(i); | |||
| } else { | |||
| dstTensor->set_data(const_cast<unsigned char *>(srcTensor->data()->data())); | |||
| int pack_size = srcTensor->data()->size(); | |||
| int org_size = dstTensor->Size(); | |||
| if (pack_size != org_size && (dataType == kNumberTypeInt8 || dataType == kNumberTypeInt16)) { | |||
| auto ret = dstTensor->MallocData(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Malloc data for " << i << "tensor failed "; | |||
| delete dstTensor; | |||
| return RET_ERROR; | |||
| } | |||
| kernel::DequantUtil::UnPackToInt(srcTensor, dstTensor->MutableData()); | |||
| } else { | |||
| dstTensor->set_data(const_cast<unsigned char *>(srcTensor->data()->data())); | |||
| } | |||
| } | |||
| } | |||
| auto quant_params = srcTensor->quantParams(); | |||
| if (quant_params != nullptr) { | |||
| for (size_t j = 0; j < quant_params->size(); j++) { | |||
| QuantArg quant_arg{}; | |||
| quant_arg.bitNum = quant_params->Get(j)->numBits(); | |||
| quant_arg.scale = quant_params->Get(j)->scale(); | |||
| quant_arg.zeroPoint = quant_params->Get(j)->zeroPoint(); | |||
| quant_arg.var_corr = quant_params->Get(j)->varCorr(); | |||
| @@ -13,6 +13,7 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <cmath> | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| namespace mindspore::kernel { | |||
| @@ -32,4 +33,18 @@ float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) { | |||
| return DequantData<int8_t>(input_tensor); | |||
| } | |||
| } | |||
| void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) { | |||
| auto quant_params = input_tensor->quantParams(); | |||
| if (quant_params == nullptr) { | |||
| MS_LOG(ERROR) << "low bits quantparams is empty."; | |||
| return; | |||
| } | |||
| int origin_bit = quant_params->Get(0)->numBits(); | |||
| if (origin_bit < 8 && origin_bit > 0) { | |||
| UnPackUtil<int8_t, uint8_t>(input_tensor, origin_bit, unpack_int_data); | |||
| } else if (origin_bit < 16 && origin_bit > 8) { | |||
| UnPackUtil<int16_t, uint16_t>(input_tensor, origin_bit, unpack_int_data); | |||
| } | |||
| } | |||
| } // namespace mindspore::kernel | |||
| @@ -18,6 +18,8 @@ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_ | |||
| #include <vector> | |||
| #include <queue> | |||
| #include <cmath> | |||
| #include "src/lite_kernel.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/tensor.h" | |||
| @@ -27,6 +29,8 @@ class DequantUtil { | |||
| public: | |||
| static float *DequantWeight(lite::Tensor *input_tensor); | |||
| static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data); | |||
| template <typename T> | |||
| static float *DequantData(lite::Tensor *input_tensor) { | |||
| const auto *quant_datas = static_cast<const T *>(input_tensor->MutableData()); | |||
| @@ -99,6 +103,62 @@ class DequantUtil { | |||
| } | |||
| return dequant_datas; | |||
| } | |||
| private: | |||
// Feeds one packed word into the bit queue and decodes every complete value.
//
// T1 is the signed element type written to the output buffer, T2 the unsigned
// word type the data was packed into.
//
// origin_bit:      number of bits used by each packed value (must be > 0).
// packed_data:     the next packed word; its bits are queued least-significant first.
// unpack_bit_data: bits carried over between calls; a value may straddle two words.
// unpack_int:      output buffer, interpreted as an array of T1.
// count:           index of the next output element; advanced for every value written.
// is_last:         true for the final packed word of the tensor.
//
// Each value is rebuilt from origin_bit queued bits (first bit = bit 0) and then
// shifted back into the signed range by subtracting 2^(origin_bit - 1).
//
// NOTE(review): zero padding of origin_bit bits or more in the final word is
// still decoded as extra values, so the output buffer needs room for them. This
// function does not know the real element count; the caller would have to
// bound it.
template <typename T1, typename T2>
static void UnPackData(int origin_bit, const T2 &packed_data, std::queue<bool> *unpack_bit_data, void *unpack_int,
                       size_t *count, bool is_last) {
  if (origin_bit <= 0) {
    // A non-positive width would never drain the queue and would shift by a negative amount.
    return;
  }
  // Queue the bits of the packed word, least-significant bit first.
  T2 remaining_word = packed_data;
  for (size_t bit = 0; bit < sizeof(T2) * 8; ++bit) {
    unpack_bit_data->push((remaining_word & 1) != 0);
    remaining_word = static_cast<T2>(remaining_word >> 1);
  }

  const int sign_offset = 1 << (origin_bit - 1);
  auto *unpacked = static_cast<T1 *>(unpack_int);
  while (static_cast<int>(unpack_bit_data->size()) >= origin_bit) {
    int raw_value = 0;
    for (int k = 0; k < origin_bit; k++) {
      raw_value |= static_cast<int>(unpack_bit_data->front()) << k;
      unpack_bit_data->pop();
    }
    unpacked[*count] = static_cast<T1>(raw_value - sign_offset);
    (*count)++;
  }

  if (is_last) {
    // Fewer than origin_bit bits remain, so they cannot form a value: they are
    // the zero padding appended to fill the final word. Drop them instead of
    // storing a bogus element past the last decoded value.
    while (!unpack_bit_data->empty()) {
      unpack_bit_data->pop();
    }
  }
}
| template <typename T1, typename T2> | |||
| static void UnPackUtil(const schema::Tensor *input_tensor, int origin_bit, void *unpack_int_data) { | |||
| auto weight_data = input_tensor->data()->data(); | |||
| int pack_size = | |||
| input_tensor->dataType() == kNumberTypeInt8 ? input_tensor->data()->size() : input_tensor->data()->size() / 2; | |||
| std::queue<bool> unpack_bit_data; | |||
| size_t count = 0; | |||
| for (int i = 0; i < pack_size; ++i) { | |||
| T2 pack_data = (static_cast<const T2 *>(static_cast<const void *>(weight_data)))[i]; | |||
| bool is_last = i == pack_size - 1; | |||
| UnPackData<T1, T2>(origin_bit, pack_data, &unpack_bit_data, unpack_int_data, &count, is_last); | |||
| } | |||
| } | |||
// Appends every bit of packed_data to unpack_bit_data, starting with the
// least-significant bit. Exactly sizeof(T2) * 8 bits are pushed per call.
template <typename T2>
static void UnPackFromUintToOrigin(const T2 &packed_data, std::queue<bool> *unpack_bit_data) {
  T2 remaining = packed_data;
  for (size_t pos = 0; pos < sizeof(T2) * 8; ++pos) {
    const bool lowest_bit = remaining % 2;
    unpack_bit_data->push(lowest_bit);
    remaining = remaining >> 1;
  }
}
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -37,6 +37,7 @@ struct QuantArg { | |||
| float mean_corr{0}; | |||
| bool inited; | |||
| std::vector<float> clusters{}; | |||
| int bitNum; | |||
| }; | |||
| class Tensor : public mindspore::tensor::MSTensor { | |||
| @@ -8,7 +8,6 @@ file(GLOB QUANTIZER | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/aware_quantizer.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc | |||
| @@ -0,0 +1,74 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER__GENERAL_BITPACKING_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER__GENERAL_BITPACKING_H | |||
| #include <stdint.h> | |||
| #include <stack> | |||
| #include <queue> | |||
| #include <vector> | |||
| #include <cassert> | |||
| namespace mindspore { | |||
| namespace lite { | |||
// Packs low-bit quantized values into a dense stream of unsigned words.
//
// Every value contributes exactly bit_num bits. Bits are laid out so that the
// first bit of the first value ends up in the least-significant bit of the
// first packed word.
class BitPack {
 public:
  ~BitPack() = default;

  // Packs origin_data_vec into packed_data_vec.
  //
  // T1 is the signed source element type, T2 the unsigned packed word type.
  //
  // bit_num:         number of bits kept per value (must be > 0).
  // origin_data_vec: signed quantized values; each is offset by 2^(bit_num - 1)
  //                  to make it non-negative before its low bit_num bits are stored.
  // packed_data_vec: receives the packed words (appended). A final partial word
  //                  is completed with zero bits in its high positions.
  template <typename T1, typename T2>
  static void BitPacking(int bit_num, const std::vector<T1> &origin_data_vec, std::vector<T2> *packed_data_vec) {
    if (bit_num <= 0 || packed_data_vec == nullptr) {
      return;
    }
    // Words are flushed at the width of T2, so the tail must be padded to that
    // same width (not the width of T1).
    constexpr size_t kWordBits = sizeof(T2) * 8;
    const int sign_offset = 1 << (bit_num - 1);
    std::stack<bool> bit_data_vec;
    for (const T1 &value : origin_data_vec) {
      T2 tmp = static_cast<T2>(value + sign_offset);
      DoBinary<T2>(bit_num, tmp, &bit_data_vec, packed_data_vec);
    }
    // DoBinary flushes every full word, so fewer than kWordBits bits can remain.
    size_t remain_bit_data = bit_data_vec.size();
    if (remain_bit_data > 0 && remain_bit_data < kWordBits) {
      for (size_t i = remain_bit_data; i < kWordBits; i++) {
        bit_data_vec.push(false);
      }
      PackFromOriginToUint<T2>(&bit_data_vec, packed_data_vec);
    }
  }

 private:
  // Pops one full word worth of bits from ans and appends the word to
  // packed_data_vec. The most recently pushed bit becomes the most-significant bit.
  template <typename T2>
  static void PackFromOriginToUint(std::stack<bool> *ans, std::vector<T2> *packed_data_vec) {
    uint32_t result = 0;
    for (size_t i = 0; i < sizeof(T2) * 8; i++) {
      result = (result << 1) + static_cast<uint32_t>(ans->top());
      ans->pop();
    }
    packed_data_vec->push_back(static_cast<T2>(result));
  }

  // Pushes the low bin_num bits of n onto ans, least-significant bit first,
  // and emits a packed word whenever a full word of bits has accumulated.
  template <typename T2>
  static void DoBinary(int bin_num, T2 n, std::stack<bool> *ans, std::vector<T2> *packed_data_vec) {
    for (int bit_count = 0; bit_count < bin_num; bit_count++) {
      ans->push((n & 1) != 0);
      n = static_cast<T2>(n >> 1);
      if (ans->size() == sizeof(T2) * 8) {
        PackFromOriginToUint<T2>(ans, packed_data_vec);
      }
    }
  }
};
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif | |||
| @@ -1,84 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/quantizer/general_bitpacking.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| BitPack::BitPack(const uint8_t &bitnum) { this->bitnum = bitnum; } | |||
| void BitPack::UnPackFromUint8ToOrigin(uint8_t &n, std::queue<bool> &unpackBitData) { | |||
| int bitCount = 0; | |||
| while (bitCount < 8) { | |||
| bool a = n % 2; | |||
| n = n >> 1; | |||
| bitCount++; | |||
| unpackBitData.push(a); | |||
| } | |||
| } | |||
| void BitPack::UnPack(uint8_t bitnum, uint8_t &packedData, std::vector<uint8_t> &originData, | |||
| std::queue<bool> &unpackBitData) { | |||
| UnPackFromUint8ToOrigin(packedData, unpackBitData); | |||
| // std::queue<bool> unpackBitTmpData; | |||
| while (unpackBitData.size() > bitnum) { | |||
| uint32_t result = 0; | |||
| for (int k = 0; k < bitnum; k++) { | |||
| bool bitTmp = unpackBitData.front(); | |||
| result = (result << 1) + static_cast<int>(bitTmp); | |||
| unpackBitData.pop(); | |||
| } | |||
| originData.push_back(result); | |||
| } | |||
| } | |||
| void BitPack::PackFromOriginToUint8(std::stack<bool> &ans, std::vector<uint8_t> &packedDataVec) { | |||
| uint32_t result = 0; | |||
| for (size_t i = 0; i < 8; i++) { | |||
| bool bit_tmp = ans.top(); | |||
| result = (result << 1) + static_cast<int>(bit_tmp); | |||
| ans.pop(); | |||
| } | |||
| packedDataVec.push_back(result); | |||
| } | |||
| void BitPack::DoBinary(uint8_t &n, std::stack<bool> &ans, std::vector<uint8_t> &packedDataVec) { | |||
| int bitCount = 0; | |||
| while (bitCount < bitnum) { | |||
| bool a = n / (1 << (unsigned int)(bitnum - bitCount - 1)); | |||
| n = n - a * (1 << (unsigned int)(bitnum - bitCount - 1)); | |||
| bitCount++; | |||
| ans.push(a); | |||
| if (ans.size() == 8) { | |||
| PackFromOriginToUint8(ans, packedDataVec); | |||
| } | |||
| } | |||
| } | |||
| void BitPack::BitPacking(const std::vector<uint8_t> &originDataVec, std::vector<uint8_t> &packedDataVec) { | |||
| std::stack<bool> bitDataVec; | |||
| for (size_t i = 0; i < originDataVec.size(); i++) { | |||
| uint8_t tmp = originDataVec[i]; | |||
| DoBinary(tmp, bitDataVec, packedDataVec); | |||
| } | |||
| size_t remainBitData = bitDataVec.size(); | |||
| if (8 > remainBitData && remainBitData > 0) { | |||
| for (size_t i = 0; i < 8 - remainBitData; i++) { | |||
| bitDataVec.push(0); | |||
| } | |||
| PackFromOriginToUint8(bitDataVec, packedDataVec); | |||
| } | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,43 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER__GENERAL_BITPACKING_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER__GENERAL_BITPACKING_H | |||
| #include <stdint.h> | |||
| #include <stack> | |||
| #include <queue> | |||
| #include <vector> | |||
| #include <cassert> | |||
| namespace mindspore { | |||
| namespace lite { | |||
// Byte-oriented bit packer: stores values of `bitnum` bits each back to back
// in a uint8_t stream, and unpacks such a stream again.
class BitPack {
 public:
  // bitnum: number of bits each value occupies (defaults to 8).
  explicit BitPack(const uint8_t &bitnum = 8);
  ~BitPack() = default;

  // Packs originDataVec into packedDataVec using the configured bit width.
  void BitPacking(const std::vector<uint8_t> &originDataVec, std::vector<uint8_t> &packedDataVec);

  // Unpacks one packed byte into originData; unpackBitData carries bits that
  // have not yet been decoded from one call to the next.
  void UnPack(uint8_t bitnum, uint8_t &packedData, std::vector<uint8_t> &originData, std::queue<bool> &unpackBitData);

 private:
  // Appends the bits of n to unpackBitData.
  void UnPackFromUint8ToOrigin(uint8_t &n, std::queue<bool> &unpackBitData);
  // Pops eight bits from ans and appends them to packedDataVec as one byte.
  void PackFromOriginToUint8(std::stack<bool> &ans, std::vector<uint8_t> &packedDataVec);
  // Pushes the bits of n onto ans, emitting a byte whenever eight have accumulated.
  void DoBinary(uint8_t &n, std::stack<bool> &ans, std::vector<uint8_t> &packed_data_vec);

  uint8_t bitnum;  // bits per packed value
};
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif | |||
| @@ -22,7 +22,7 @@ | |||
| #include <vector> | |||
| #include <set> | |||
| #include "src/ops/primitive_c.h" | |||
| #include "mindspore/lite/tools/converter/quantizer/general_bitpacking.h" | |||
| #include "mindspore/lite/tools/converter/quantizer/bitpacking.h" | |||
| #include "src/common/utils.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "securec/include/securec.h" | |||
| @@ -292,30 +292,6 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl | |||
| return RET_OK; | |||
| } | |||
| STATUS PostBitPack(float *weight, size_t shapeSize, size_t bitNum) { | |||
| auto *rawDatas = reinterpret_cast<uint8_t *>(weight); | |||
| vector<uint8_t> qDatas(rawDatas, rawDatas + shapeSize); | |||
| vector<uint8_t> qDatas_packed; | |||
| if (bitNum < 8 && bitNum > 1) { | |||
| BitPack weight_bitpack(bitNum); | |||
| weight_bitpack.BitPacking(qDatas, qDatas_packed); | |||
| if (EOK != memcpy_s(rawDatas, shapeSize, &qDatas_packed[0], shapeSize)) { | |||
| MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas_packed failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } else if (bitNum == 8) { | |||
| if (EOK != memcpy_s(rawDatas, shapeSize, &qDatas[0], shapeSize)) { | |||
| MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "bitNum must be between 0 and 8 : " << bitNum; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| static bool SearchLowerBound(const std::vector<float> &data, const size_t &index, const float &max_tmp, float *min_tmp, | |||
| size_t *min_idx) { | |||
| size_t length = data.size(); | |||
| @@ -34,6 +34,7 @@ | |||
| #include "base/base.h" | |||
| #include "ir/primitive.h" | |||
| #include "abstract/dshape.h" | |||
| #include "tools/converter/quantizer/bitpacking.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -279,6 +280,34 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti | |||
| } | |||
| weight->set_tensor_size(elem_count * sizeof(T)); | |||
| } | |||
| // do bit pack | |||
| if (bitNum != 8 && bitNum != 16) { | |||
| std::vector<T> data{}; | |||
| for (size_t i = 0; i < quant_datas.size(); ++i) { | |||
| data.emplace_back((static_cast<T>(quant_datas[i]))); | |||
| } | |||
| if (bitNum > 0 && bitNum < 8) { | |||
| std::vector<uint8_t> pack_data{}; | |||
| BitPack::BitPacking<T, uint8_t>(bitNum, data, &pack_data); | |||
| auto ret = memcpy_s(raw_datas, weight->tensor_size(), pack_data.data(), pack_data.size() * sizeof(uint8_t)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas_packed failed"; | |||
| return RET_ERROR; | |||
| } | |||
| weight->set_tensor_size(pack_data.size() * sizeof(uint8_t)); | |||
| } else if (bitNum > 8 && bitNum < 16) { | |||
| std::vector<uint16_t> pack_data{}; | |||
| BitPack::BitPacking<T, uint16_t>(bitNum, data, &pack_data); | |||
| auto ret = memcpy_s(raw_datas, weight->tensor_size(), pack_data.data(), pack_data.size() * sizeof(uint16_t)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas_packed failed"; | |||
| return RET_ERROR; | |||
| } | |||
| weight->set_tensor_size(pack_data.size() * sizeof(uint16_t)); | |||
| } | |||
| } | |||
| if (quant_params.empty()) { | |||
| MS_LOG(ERROR) << "quant_params empty"; | |||
| return RET_ERROR; | |||
| @@ -291,8 +320,6 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti | |||
| return RET_OK; | |||
| } | |||
| STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION); | |||
| schema::PrimitiveType NodePrimitiveType(CNodePtr cnode); | |||
| } // namespace quant | |||
| } // namespace lite | |||
| @@ -48,8 +48,13 @@ STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) { | |||
| MS_LOG(ERROR) << "quantSize must be valid pos num."; | |||
| return RET_ERROR; | |||
| } | |||
| if (!WeightQuantizer::IsPosNum(config->bitNum) || (config->bitNum != "8" && config->bitNum != "16")) { | |||
| MS_LOG(ERROR) << "bitNum must be valid pos num, current only support 8 or 16 bit weight quant."; | |||
| if (!WeightQuantizer::IsPosNum(config->bitNum)) { | |||
| MS_LOG(ERROR) << "bitNum must be valid pos num."; | |||
| return RET_ERROR; | |||
| } | |||
| int bitNum = std::stoi(config->bitNum); | |||
| if (bitNum <= 0 || bitNum > 16) { | |||
| MS_LOG(ERROR) << "bitNum should be more than 0 and less than 16 currently."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| @@ -63,10 +68,13 @@ WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize, | |||
| mStrategy.reset(new QuantStrategy(quantSize, convQuantWeightChannelThreshold)); | |||
| quant_max = (1 << (unsigned int)(this->bitNum - 1)) - 1; | |||
| quant_min = -(1 << (unsigned int)(this->bitNum - 1)); | |||
| if (this->bitNum == 8) { | |||
| // parse type_id | |||
| if (this->bitNum > 0 && this->bitNum <= 8) { | |||
| type_id = kNumberTypeInt8; | |||
| } else if (this->bitNum == 16) { | |||
| } else if (this->bitNum <= 16) { | |||
| type_id = kNumberTypeInt16; | |||
| } else { | |||
| MS_LOG(ERROR) << "invalid input bits"; | |||
| } | |||
| } | |||
| @@ -100,7 +108,6 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) { | |||
| MS_LOG(ERROR) << "model weight data type invalid which is " << param_value->tensor_type(); | |||
| return RET_ERROR; | |||
| } | |||
| auto status = RET_ERROR; | |||
| if (type_id == kNumberTypeInt8) { | |||
| status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); | |||
| @@ -127,7 +134,6 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) { | |||
| abstractTensor->element()->set_type(TypeIdToType(type_id)); | |||
| primitive_c->SetQuantType(schema::QuantType_WeightQuant); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -136,7 +142,6 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) { | |||
| if (!mStrategy->CanMulOpQuantized(node)) { | |||
| continue; | |||
| } | |||
| auto already_quant = false; | |||
| ParamValueLitePtr param_value = nullptr; | |||
| ParameterPtr param_node = nullptr; | |||
| for (size_t i = 1; i < node->size(); i++) { | |||
| @@ -146,16 +151,8 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) { | |||
| if ((param_node != nullptr) && param_node->has_default()) { | |||
| param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param()); | |||
| if ((param_value == nullptr) || (param_value->tensor_size() == 0) || | |||
| (param_value->tensor_addr() == nullptr)) { | |||
| param_value = nullptr; | |||
| continue; | |||
| } else if (param_value->tensor_type() == mindspore::kNumberTypeInt8 || | |||
| param_value->tensor_type() == mindspore::kNumberTypeInt16) { | |||
| MS_LOG(INFO) << "the node: " << node->fullname_with_scope() << " input_i: " << i << "has been " | |||
| << " quantized"; | |||
| already_quant = true; | |||
| break; | |||
| } else if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) { | |||
| (param_value->tensor_addr() == nullptr) || | |||
| (param_value->tensor_type() != mindspore::kNumberTypeFloat32)) { | |||
| param_value = nullptr; | |||
| continue; | |||
| } else { | |||
| @@ -164,11 +161,6 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) { | |||
| } | |||
| } | |||
| } | |||
| if (already_quant) { | |||
| continue; | |||
| } | |||
| if (param_value == nullptr) { | |||
| MS_LOG(ERROR) << "No valid input param node !"; | |||
| return RET_ERROR; | |||