| @@ -18,9 +18,10 @@ | |||
| #include <memory> | |||
| #include "src/dequant.h" | |||
| #include "src/huffman_decode.h" | |||
| #include "src/ops/matmul.h" | |||
| namespace mindspore::lite { | |||
| float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) { | |||
| float *DequantUtil::DequantWeight(lite::Tensor *input_tensor, bool channel_first) { | |||
| MS_ASSERT(input_tensor != nullptr); | |||
| if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) { | |||
| MS_LOG(ERROR) << "Conv weight input type error." << input_tensor->data_type(); | |||
| @@ -31,9 +32,9 @@ float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) { | |||
| return nullptr; | |||
| } | |||
| if (input_tensor->data_type() == kNumberTypeInt16) { | |||
| return DequantData<int16_t>(input_tensor); | |||
| return DequantData<int16_t>(input_tensor, channel_first); | |||
| } else { | |||
| return DequantData<int8_t>(input_tensor); | |||
| return DequantData<int8_t>(input_tensor, channel_first); | |||
| } | |||
| } | |||
| @@ -65,19 +66,35 @@ int DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_in | |||
| return RET_OK; | |||
| } | |||
| std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const std::vector<Tensor *> &in_tensors, | |||
| std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const mindspore::lite::PrimitiveC *primitive, | |||
| const std::vector<Tensor *> &in_tensors, | |||
| TypeId data_type, bool need_restore) { | |||
| std::map<Tensor *, std::pair<TypeId, void *>> tensor_origin_data; | |||
| if (data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat16) { | |||
| auto input_i = 0; | |||
| for (auto weight_tensor : in_tensors) { | |||
| MS_ASSERT(weight_tensor != nullptr); | |||
| input_i++; | |||
| auto channel_first = true; | |||
| if ((schema::PrimitiveType)primitive->Type() == schema::PrimitiveType_MatMul && | |||
| weight_tensor->shape().size() == 2) { | |||
| auto param = reinterpret_cast<mindspore::lite::MatMul *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||
| if (input_i == 1) { | |||
| channel_first = !param->GetTransposeA(); | |||
| } else if (input_i == 2) { | |||
| channel_first = param->GetTransposeB(); | |||
| } else { | |||
| MS_LOG(WARNING) << "unexpected input_i"; | |||
| } | |||
| } | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| auto restore_type = weight_tensor->data_type(); | |||
| bool dequant_flag = !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && | |||
| restore_data != nullptr && | |||
| (restore_type == kNumberTypeInt8 || restore_type == kNumberTypeInt16); | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = DequantUtil::DequantWeight(weight_tensor); | |||
| auto *dequant_weight = DequantUtil::DequantWeight(weight_tensor, channel_first); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| return tensor_origin_data; | |||
| @@ -29,17 +29,18 @@ | |||
| namespace mindspore::lite { | |||
| class DequantUtil { | |||
| public: | |||
| static float *DequantWeight(lite::Tensor *input_tensor); | |||
| static float *DequantWeight(lite::Tensor *input_tensor, bool); | |||
| static int UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data); | |||
| static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const std::vector<Tensor *> &in_tensors, | |||
| static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const mindspore::lite::PrimitiveC *primitive, | |||
| const std::vector<Tensor *> &in_tensors, | |||
| TypeId data_type, bool need_restore = true); | |||
| static void RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map); | |||
| template <typename ST, typename DT = float> | |||
| static DT *DequantData(lite::Tensor *input_tensor) { | |||
| static DT *DequantData(lite::Tensor *input_tensor, bool channel_first = true) { | |||
| const auto *quant_datas = static_cast<const ST *>(input_tensor->MutableData()); | |||
| if (quant_datas == nullptr) { | |||
| MS_LOG(ERROR) << "Get quant tensor failed."; | |||
| @@ -65,6 +66,13 @@ class DequantUtil { | |||
| } | |||
| } else if (input_tensor->quant_params().size() != kPerTensor) { | |||
| auto channels = static_cast<size_t>(input_tensor->Batch()); | |||
| if (!channel_first) { | |||
| if (input_tensor->shape().size() != 2) { | |||
| MS_LOG(ERROR) << "unexpected shape size: " << input_tensor->shape().size(); | |||
| return nullptr; | |||
| } | |||
| channels = input_tensor->shape()[1]; | |||
| } | |||
| if (input_tensor->quant_params().size() != channels) { | |||
| MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->quant_params().size() << channels; | |||
| free(dequant_datas); | |||
| @@ -83,8 +91,12 @@ class DequantUtil { | |||
| var_corr = 1; | |||
| } | |||
| for (size_t j = 0; j < per_channel_size; j++) { | |||
| auto dequant_data = (quant_datas[per_channel_size * i + j] - zero_point) * scale; | |||
| dequant_datas[per_channel_size * i + j] = static_cast<DT>(dequant_data * var_corr + mean_corr); | |||
| auto index = per_channel_size * i + j; | |||
| if (!channel_first) { | |||
| index = channels * j + i; | |||
| } | |||
| auto dequant_data = (quant_datas[index] - zero_point) * scale; | |||
| dequant_datas[index] = static_cast<DT>(dequant_data * var_corr + mean_corr); | |||
| } | |||
| } | |||
| } else { | |||
| @@ -223,7 +223,8 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in | |||
| if (mindspore::lite::IsSupportFloat16() && | |||
| ((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) { | |||
| kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type}; | |||
| auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, fp16_cpu_desc.data_type, need_restore); | |||
| auto tensor_origin_data_map = | |||
| DequantUtil::DequantTensor(primitive, in_tensors, fp16_cpu_desc.data_type, need_restore); | |||
| auto *kernel = | |||
| KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc); | |||
| DequantUtil::RestoreTensorData(tensor_origin_data_map); | |||
| @@ -237,7 +238,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in | |||
| MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op."; | |||
| desc.data_type = kNumberTypeFloat32; | |||
| } | |||
| auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, desc.data_type, need_restore); | |||
| auto tensor_origin_data_map = DequantUtil::DequantTensor(primitive, in_tensors, desc.data_type, need_restore); | |||
| auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); | |||
| DequantUtil::RestoreTensorData(tensor_origin_data_map); | |||
| if (kernel != nullptr) { | |||
| @@ -358,15 +358,15 @@ static bool SearchUpperBound(const std::vector<float> &data, const size_t &index | |||
| return true; | |||
| } | |||
// Returns the value at the given percentile of `data`.
// `data` is assumed sorted ascending (callers sort before calling — TODO confirm).
// When the percentile boundary lands exactly on an element, the result is the
// midpoint of that element and the next one.
static float CalPercentile(const std::vector<float> &data, const int &outlier_percent) {
  if (data.empty()) {
    return 0.0f;  // previously threw std::out_of_range via data.at()
  }
  const int size = data.size();
  const float val = outlier_percent / 100.0 * size;
  const int index = std::ceil(val);
  if (index <= 0) {
    return data.front();  // outlier_percent == 0 previously indexed data.at(-1)
  }
  if (index >= size && index - val <= 0) {
    return data.back();  // outlier_percent == 100 previously indexed data.at(size)
  }
  if (index - val > 0) {
    return data.at(index - 1);
  }
  return (data.at(index - 1) + data.at(index)) / 2;
}
| @@ -522,11 +522,78 @@ std::vector<std::vector<int>> DataToVectors(const string &str) { | |||
| return result; | |||
| } | |||
| STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_config) { | |||
| if (post_quant_config == nullptr) { | |||
| MS_LOG(ERROR) << "post_quant_config is null."; | |||
| return RET_PARAM_INVALID; | |||
| void ParseInputShape(PostQuantConfig *post_quant_config, std::string raw_shape) { | |||
| MS_ASSERT(post_quant_config != nullptr); | |||
| auto ind = raw_shape.find('/'); | |||
| while (ind != std::string::npos) { | |||
| auto shape = raw_shape.substr(0, ind); | |||
| Trim(&shape); | |||
| post_quant_config->input_shapes.push_back(DataToVectors(shape)); | |||
| raw_shape = raw_shape.substr(ind + 1); | |||
| Trim(&raw_shape); | |||
| ind = raw_shape.find('/'); | |||
| } | |||
| if (!raw_shape.empty()) { | |||
| post_quant_config->input_shapes.push_back(DataToVectors(raw_shape)); | |||
| } | |||
| } | |||
| void ParseImagePath(PostQuantConfig *post_quant_config, std::string raw_image_paths) { | |||
| MS_ASSERT(post_quant_config != nullptr); | |||
| auto ind = raw_image_paths.find(','); | |||
| while (ind != std::string::npos) { | |||
| auto image_path = raw_image_paths.substr(0, ind); | |||
| Trim(&image_path); | |||
| post_quant_config->image_paths.push_back(image_path); | |||
| raw_image_paths = raw_image_paths.substr(ind + 1); | |||
| Trim(&raw_image_paths); | |||
| ind = raw_image_paths.find(','); | |||
| } | |||
| post_quant_config->image_paths.push_back(raw_image_paths); | |||
| } | |||
| void ParseBatchCount(PostQuantConfig *post_quant_config, std::string value) { | |||
| MS_ASSERT(post_quant_config != nullptr); | |||
| post_quant_config->batch_count = std::stoul(value); | |||
| } | |||
| void ParseThreadNum(PostQuantConfig *post_quant_config, std::string value) { | |||
| MS_ASSERT(post_quant_config != nullptr); | |||
| post_quant_config->thread_num = std::stoul(value); | |||
| } | |||
| void ParseMethodX(PostQuantConfig *post_quant_config, const std::string &value) { | |||
| MS_ASSERT(post_quant_config != nullptr); | |||
| if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) { | |||
| MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value."; | |||
| } else { | |||
| post_quant_config->method_x = value; | |||
| } | |||
| } | |||
| void ParseMixed(PostQuantConfig *post_quant_config, std::string value) { | |||
| MS_ASSERT(post_quant_config != nullptr); | |||
| std::for_each(value.begin(), value.end(), ::tolower); | |||
| if (value == "true") { | |||
| post_quant_config->mixed = true; | |||
| } | |||
| } | |||
| void ParseMeanErrorThreshold(PostQuantConfig *post_quant_config, std::string value) { | |||
| MS_ASSERT(post_quant_config != nullptr); | |||
| post_quant_config->mean_error_threshold = std::stof(value); | |||
| } | |||
| void ParseBiasCorrection(PostQuantConfig *post_quant_config, std::string value) { | |||
| MS_ASSERT(post_quant_config != nullptr); | |||
| std::for_each(value.begin(), value.end(), ::tolower); | |||
| if (value == "true") { | |||
| post_quant_config->bias_correction = true; | |||
| } | |||
| } | |||
| STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_config) { | |||
| MS_ASSERT(post_quant_config != nullptr); | |||
| if (config_file.empty() || config_file.length() > PATH_MAX) { | |||
| MS_LOG(ERROR) << "invalid config path!"; | |||
| @@ -552,6 +619,26 @@ STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_conf | |||
| MS_LOG(ERROR) << "config file open failed: " << config_file; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| std::string INPUT_SHAPES = "input_shapes"; | |||
| std::string IMAGE_PATH = "image_path"; | |||
| std::string BATCH_COUNT = "batch_count"; | |||
| std::string THREAD_NUM = "thread_num"; | |||
| std::string METHOD_X = "method_x"; | |||
| std::string MIXED = "mixed"; | |||
| std::string MEAN_ERROR_THRESHOLD = "mean_error_threshold"; | |||
| std::string BIAS_CORRECTION = "bias_correction"; | |||
| std::map<std::string, std::function<void(PostQuantConfig *, std::string)>> value_parser; | |||
| value_parser[INPUT_SHAPES] = ParseInputShape; | |||
| value_parser[IMAGE_PATH] = ParseImagePath; | |||
| value_parser[BATCH_COUNT] = ParseBatchCount; | |||
| value_parser[THREAD_NUM] = ParseThreadNum; | |||
| value_parser[METHOD_X] = ParseMethodX; | |||
| value_parser[MIXED] = ParseMixed; | |||
| value_parser[MEAN_ERROR_THRESHOLD] = ParseMeanErrorThreshold; | |||
| value_parser[BIAS_CORRECTION] = ParseBiasCorrection; | |||
| std::string line; | |||
| while (std::getline(fs, line)) { | |||
| Trim(&line); | |||
| @@ -567,54 +654,9 @@ STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_conf | |||
| auto value = line.substr(index + 1); | |||
| Trim(&key); | |||
| Trim(&value); | |||
| if (key == "image_path") { | |||
| auto &raw_image_paths = value; | |||
| auto ind = raw_image_paths.find(','); | |||
| while (ind != std::string::npos) { | |||
| auto image_path = raw_image_paths.substr(0, ind); | |||
| Trim(&image_path); | |||
| post_quant_config->image_paths.push_back(image_path); | |||
| raw_image_paths = raw_image_paths.substr(ind + 1); | |||
| Trim(&raw_image_paths); | |||
| ind = raw_image_paths.find(','); | |||
| } | |||
| post_quant_config->image_paths.push_back(raw_image_paths); | |||
| } else if (key == "batch_count") { | |||
| post_quant_config->batch_count = std::stoul(value); | |||
| } else if (key == "thread_num") { | |||
| post_quant_config->thread_num = std::stoul(value); | |||
| } else if (key == "method_x") { | |||
| if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) { | |||
| MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value."; | |||
| } else { | |||
| post_quant_config->method_x = value; | |||
| } | |||
| } else if (key == "bias_correction") { | |||
| std::for_each(value.begin(), value.end(), ::tolower); | |||
| if (value == "true") { | |||
| post_quant_config->bias_correction = true; | |||
| } | |||
| } else if (key == "mixed") { | |||
| std::for_each(value.begin(), value.end(), ::tolower); | |||
| if (value == "true") { | |||
| post_quant_config->mixed = true; | |||
| } | |||
| } else if (key == "mean_error_threshold") { | |||
| post_quant_config->mean_error_threshold = std::stof(value); | |||
| } else if (key == "input_shapes") { | |||
| auto &raw_shape = value; | |||
| auto ind = raw_shape.find('/'); | |||
| while (ind != std::string::npos) { | |||
| auto shape = raw_shape.substr(0, ind); | |||
| Trim(&shape); | |||
| post_quant_config->input_shapes.push_back(DataToVectors(shape)); | |||
| raw_shape = raw_shape.substr(ind + 1); | |||
| Trim(&raw_shape); | |||
| ind = raw_shape.find('/'); | |||
| } | |||
| if (!raw_shape.empty()) { | |||
| post_quant_config->input_shapes.push_back(DataToVectors(raw_shape)); | |||
| } | |||
| auto it = value_parser.find(key); | |||
| if (it != value_parser.end()) { | |||
| it->second(post_quant_config, value); | |||
| } else { | |||
| MS_LOG(WARNING) << "unsupported parameter: " << key; | |||
| } | |||
| @@ -881,4 +923,24 @@ STATUS UpdateTensorDataAndSize(ParamValueLitePtr weight, void *quant_datas, int | |||
| return RET_OK; | |||
| } | |||
| void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, float *raw_datas, | |||
| bool channel_at_first, float *desired_max, float *desired_min) { | |||
| float min = FLT_MAX; | |||
| float max = -FLT_MAX; | |||
| // find min and max | |||
| for (int j = 0; j < one_filter_size; j++) { | |||
| auto index = j + i * one_filter_size; | |||
| if (!channel_at_first) { | |||
| index = j * channels + i; | |||
| } | |||
| if (index >= elem_count) { | |||
| MS_LOG(ERROR) << "over flow!"; | |||
| } | |||
| min = std::min(min, raw_datas[index]); | |||
| max = std::max(max, raw_datas[index]); | |||
| } | |||
| *desired_max = max; | |||
| *desired_min = min; | |||
| } | |||
| } // namespace mindspore::lite::quant | |||
| @@ -107,6 +107,9 @@ std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epoc | |||
| STATUS UpdateTensorDataAndSize(ParamValueLitePtr weight, void *quant_datas, int new_size); | |||
| void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, float *raw_datas, | |||
| bool channel_at_first, float *desired_max, float *desired_min); | |||
| template <typename T> | |||
| T QuantizeData(const float originData, const schema::QuantParamT *quantParam) { | |||
| MS_ASSERT(quantParam != nullptr); | |||
| @@ -163,11 +166,19 @@ template <typename T> | |||
| STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant_type, | |||
| std::vector<schema::QuantParamT> *quant_params, const int &quant_max, const int &quant_min, | |||
| const size_t &bit_num, const bool &k_means, std::vector<T> *quant_datas, | |||
| std::vector<float> *dequant_datas) { | |||
| std::vector<float> *dequant_datas, bool channel_at_first = true) { | |||
| auto dims = weight->tensor_shape(); | |||
| size_t elem_count = weight->tensor_shape_size(); | |||
| auto *raw_datas = static_cast<float *>(weight->tensor_addr()); | |||
| auto channels = dims[0]; | |||
| if (!channel_at_first) { | |||
| if (dims.size() != 2) { | |||
| MS_LOG(ERROR) << "unexpected dims size: " << dims.size(); | |||
| channel_at_first = true; | |||
| } else { | |||
| channels = dims[1]; | |||
| } | |||
| } | |||
| if (channels == 0) { | |||
| MS_LOG(ERROR) << "channels is zero"; | |||
| return RET_ERROR; | |||
| @@ -181,16 +192,7 @@ STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant | |||
| for (int i = 0; i < channels; i++) { | |||
| float min = FLT_MAX; | |||
| float max = -FLT_MAX; | |||
| // find min and max | |||
| for (size_t j = 0; j < one_filter_size; j++) { | |||
| auto index = j + i * one_filter_size; | |||
| if (index >= elem_count) { | |||
| MS_LOG(ERROR) << "over flow!"; | |||
| return RET_ERROR; | |||
| } | |||
| min = std::min(min, raw_datas[index]); | |||
| max = std::max(max, raw_datas[index]); | |||
| } | |||
| GetMaxMinPerchannel(channels, one_filter_size, i, elem_count, raw_datas, channel_at_first, &max, &min); | |||
| schema::QuantParamT quant_param; | |||
| STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bit_num); | |||
| if (status != RET_OK) { | |||
| @@ -202,10 +204,10 @@ STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant | |||
| double average_raw = 0; | |||
| for (uint32_t j = 0; j < one_filter_size; j++) { | |||
| auto index = j + i * one_filter_size; | |||
| if (index >= elem_count) { | |||
| MS_LOG(ERROR) << "over flow!"; | |||
| return RET_ERROR; | |||
| if (!channel_at_first) { | |||
| index = j * channels + i; | |||
| } | |||
| MS_ASSERT(index < elem_count); | |||
| float raw_data = raw_datas[index]; | |||
| auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min); | |||
| (*quant_datas)[index] = quant_data; | |||
| @@ -226,10 +228,10 @@ STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant | |||
| double variance_raw = 0; | |||
| for (uint32_t j = 0; j < one_filter_size; j++) { | |||
| auto index = j + i * one_filter_size; | |||
| if (index >= elem_count) { | |||
| MS_LOG(ERROR) << "over flow!"; | |||
| return RET_ERROR; | |||
| if (!channel_at_first) { | |||
| index = j * channels + i; | |||
| } | |||
| MS_ASSERT(index < elem_count); | |||
| variance_dequant += std::pow((*dequant_datas)[index] - average_dequant, 2); | |||
| variance_raw += std::pow(raw_datas[index] - average_raw, 2); | |||
| } | |||
| @@ -339,20 +341,26 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr<Primit | |||
| std::vector<schema::QuantParamT> quant_params; | |||
| size_t elem_count = weight->tensor_shape_size(); | |||
| auto *raw_datas = static_cast<float *>(weight->tensor_addr()); | |||
| if (raw_datas == nullptr) { | |||
| auto *raw_data = static_cast<float *>(weight->tensor_addr()); | |||
| if (raw_data == nullptr) { | |||
| MS_LOG(ERROR) << "rawDatas is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<T> quant_datas(elem_count); | |||
| std::vector<T> quant_data(elem_count); | |||
| std::vector<float> dequant_datas(elem_count); | |||
| int ret = RET_OK; | |||
| if (per_channel) { | |||
| // notice: assume Con2D\DepthwiseConv2D's weight format are same: KHWC | |||
| bool channel_at_first = true; | |||
| auto op_type = (schema::PrimitiveType)primitive_c->Type(); | |||
| if (op_type == schema::PrimitiveType_MatMul && weight->tensor_shape().size() == 2) { | |||
| auto matmul_op = primitive_c->primitiveT()->value.AsMatMul(); | |||
| MS_ASSERT(matmul_op != nullptr); | |||
| channel_at_first = !(index == 1 && !matmul_op->transposeB); | |||
| } | |||
| // channel at first | |||
| ret = DoPerChannelQuant<T>(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_datas, | |||
| &dequant_datas); | |||
| ret = DoPerChannelQuant<T>(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_data, | |||
| &dequant_datas, channel_at_first); | |||
| if (ret == RET_CONTINUE) { | |||
| return ret; | |||
| } else if (ret != RET_OK) { | |||
| @@ -360,7 +368,7 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr<Primit | |||
| return ret; | |||
| } | |||
| } else { | |||
| ret = DoPerLayerQuant<T>(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_datas); | |||
| ret = DoPerLayerQuant<T>(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_data); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Do per layer quant failed."; | |||
| return ret; | |||
| @@ -376,7 +384,7 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr<Primit | |||
| } | |||
| #else | |||
| // do bit pack | |||
| ret = DoBitPack(weight, bit_num, quant_datas); | |||
| ret = DoBitPack(weight, bit_num, quant_data); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Do bit pack failed."; | |||
| return ret; | |||
| @@ -127,6 +127,7 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) { | |||
| auto already_quant = false; | |||
| ParamValueLitePtr param_value = nullptr; | |||
| ParameterPtr param_node = nullptr; | |||
| int index = 0; | |||
| for (size_t i = 1; i < cnode->size(); i++) { | |||
| auto inputNode = cnode->input(i); | |||
| if (inputNode->isa<Parameter>()) { | |||
| @@ -146,6 +147,7 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) { | |||
| param_value = nullptr; | |||
| continue; | |||
| } else { | |||
| index = i; | |||
| break; | |||
| } | |||
| } | |||
| @@ -169,11 +171,11 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) { | |||
| auto status = RET_ERROR; | |||
| if (type_id_ == kNumberTypeInt8) { | |||
| status = | |||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true); | |||
| status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, | |||
| true, index - 1); | |||
| } else if (type_id_ == kNumberTypeInt16) { | |||
| status = | |||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true); | |||
| status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, | |||
| true, index - 1); | |||
| } | |||
| if (status == RET_CONTINUE) { | |||
| return RET_OK; | |||