| @@ -32,8 +32,9 @@ table QuantParam { | |||
| narrowRange: bool = true; | |||
| numBits: int = 8; | |||
| inited: bool = false; | |||
| var_corr: double = 1; | |||
| mean_corr: double = 0; | |||
| varCorr: double = 1; | |||
| meanCorr: double = 0; | |||
| dstDtype: int = 32; | |||
| clusters: [float]; | |||
| } | |||
| @@ -27,6 +27,9 @@ | |||
| #include "src/common/log_adapter.h" | |||
| #include "tools/common/option.h" | |||
| #include "include/errorcode.h" | |||
| #ifdef ENABLE_ARM64 | |||
| #include "nnacl/optimized_kernel.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -186,6 +189,38 @@ inline Option<bool> GenericParseValue(const std::string &value) { | |||
| return Option<bool>(None()); | |||
| } | |||
| using Float16CastFunc = void (*)(const void *, void *, int); | |||
| class Float16CastUtil { | |||
| public: | |||
| static Float16CastUtil *GetInstance() { | |||
| static Float16CastUtil float16_cast_util; | |||
| return &float16_cast_util; | |||
| } | |||
| private: | |||
| Float16CastUtil() { | |||
| #ifdef ENABLE_ARM64 | |||
| void *fp16_op_handler = Float16Module::GetInstance()->float16_op_handler_; | |||
| if (fp16_op_handler != nullptr) { | |||
| dlerror(); | |||
| *(reinterpret_cast<void **>(&float16_to_float32_func_)) = dlsym(fp16_op_handler, "Float16ToFloat32_fp16_handler"); | |||
| *(reinterpret_cast<void **>(&float32_to_float16_func_)) = dlsym(fp16_op_handler, "Float32ToFloat16_fp16_handler"); | |||
| auto dlopen_error = dlerror(); | |||
| if (dlopen_error != nullptr) { | |||
| MS_LOG(ERROR) << "load float16 cast func failed! " << dlopen_error << "."; | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| ~Float16CastUtil() = default; | |||
| public: | |||
| Float16CastFunc float16_to_float32_func_ = nullptr; | |||
| Float16CastFunc float32_to_float16_func_ = nullptr; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -108,8 +108,9 @@ int LiteSession::ConvertTensors(const lite::Model *model) { | |||
| QuantArg quant_arg{}; | |||
| quant_arg.scale = quant_params->Get(j)->scale(); | |||
| quant_arg.zeroPoint = quant_params->Get(j)->zeroPoint(); | |||
| quant_arg.var_corr = quant_params->Get(j)->var_corr(); | |||
| quant_arg.mean_corr = quant_params->Get(j)->mean_corr(); | |||
| quant_arg.var_corr = quant_params->Get(j)->varCorr(); | |||
| quant_arg.mean_corr = quant_params->Get(j)->meanCorr(); | |||
| quant_arg.inited = quant_params->Get(j)->inited(); | |||
| auto quant_clusters = quant_params->Get(j)->clusters(); | |||
| if (quant_clusters != nullptr) { | |||
| for (size_t k = 0; k < quant_clusters->size(); k++) { | |||
| @@ -49,12 +49,7 @@ int Add::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| if (GetQuantType() == schema::QuantType_AwareTraining) { | |||
| std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam; | |||
| std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam; | |||
| PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); | |||
| SetOutputQuantParam(vecOutputQuantParam); | |||
| } | |||
| PopulaterQuantParam(prim, inputs); | |||
| return RET_OK; | |||
| } | |||
| @@ -277,13 +277,7 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp | |||
| PopulaterConv2DSingleGroup(prim, this->primitive_, group); | |||
| } | |||
| if (GetQuantType() == schema::QuantType_AwareTraining) { | |||
| std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam; | |||
| std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam; | |||
| PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); | |||
| SetInputQuantParam(vecInputQuantParam); | |||
| SetOutputQuantParam(vecOutputQuantParam); | |||
| } | |||
| PopulaterQuantParam(prim, inputs); | |||
| return RET_OK; | |||
| } | |||
| @@ -254,14 +254,7 @@ int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i | |||
| } else if (group > 1) { | |||
| PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs); | |||
| } | |||
| if (GetQuantType() == schema::QuantType_AwareTraining) { | |||
| std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam; | |||
| std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam; | |||
| PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); | |||
| SetInputQuantParam(vecInputQuantParam); | |||
| SetOutputQuantParam(vecOutputQuantParam); | |||
| } | |||
| PopulaterQuantParam(prim, inputs); | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| @@ -146,14 +146,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode | |||
| this->primitive_->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||
| this->primitive_->value.value = attr.release(); | |||
| if (GetQuantType() == schema::QuantType_AwareTraining) { | |||
| std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam; | |||
| std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam; | |||
| PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); | |||
| SetInputQuantParam(vecInputQuantParam); | |||
| SetOutputQuantParam(vecOutputQuantParam); | |||
| } | |||
| PopulaterQuantParam(prim, inputs); | |||
| return RET_OK; | |||
| } | |||
| @@ -61,13 +61,8 @@ int MatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| if (GetQuantType() == schema::QuantType_AwareTraining) { | |||
| std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam; | |||
| std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam; | |||
| PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); | |||
| SetInputQuantParam(vecInputQuantParam); | |||
| SetOutputQuantParam(vecOutputQuantParam); | |||
| } | |||
| PopulaterQuantParam(prim, inputs); | |||
| return RET_OK; | |||
| } | |||
| @@ -164,32 +164,29 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| void PrimitiveC::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) { | |||
| void PrimitiveC::CalFloatScopeByMeanAndStddev(const double &mean, const double &stdDev, float *mMin, float *mMax) { | |||
| const float qmin = 0; | |||
| const float qmax = 255; | |||
| *mMin = static_cast<float>((qmin - mean) / stdDev); | |||
| *mMax = static_cast<float>((qmax - mean) / stdDev); | |||
| } | |||
| void PrimitiveC::PopulaterQuantParam(const Primitive &prim, | |||
| std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam, | |||
| std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| void PrimitiveC::PopulaterQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| auto narrow_range = prim.GetAttr("narrow_range"); | |||
| bool narrowRangeQuantParam = GetValue<bool>(narrow_range); | |||
| bool narrowRangeQuantParam = narrow_range != nullptr ? GetValue<bool>(narrow_range) : false; | |||
| auto num_bits = prim.GetAttr("num_bits"); | |||
| int32_t numbitsRangeQuantParam = GetValue<int32_t>(num_bits); | |||
| int32_t numbitsRangeQuantParam = num_bits != nullptr ? GetValue<int32_t>(num_bits) : 8; | |||
| std::vector<schema::QuantParamT> quants; | |||
| schema::QuantParamT quantParam; | |||
| auto mean = prim.GetAttr("mean"); | |||
| auto std_dev = prim.GetAttr("std_dev"); | |||
| if (mean != nullptr && std_dev != nullptr) { | |||
| auto meanQuantOaram = GetValue<double>(mean); | |||
| double stddevQuantOaram = GetValue<double>(std_dev); | |||
| auto meanValue = GetValue<double>(mean); | |||
| auto stddevValue = GetValue<double>(std_dev); | |||
| float mMin = 0.0; | |||
| float mMax = 0.0; | |||
| CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax); | |||
| CalFloatScopeByMeanAndStddev(meanValue, stddevValue, &mMin, &mMax); | |||
| quantParam.min = mMin; | |||
| quantParam.max = mMax; | |||
| } else { | |||
| @@ -198,8 +195,8 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim, | |||
| if (inputMin != nullptr && inputMax != nullptr) { | |||
| auto inputMinPtr = inputMin->cast<TensorPtr>(); | |||
| auto inputMaxPtr = inputMax->cast<TensorPtr>(); | |||
| float *minBuf = static_cast<float *>(inputMinPtr->data_c()); | |||
| float *maxBuf = static_cast<float *>(inputMaxPtr->data_c()); | |||
| auto *minBuf = static_cast<float *>(inputMinPtr->data_c()); | |||
| auto *maxBuf = static_cast<float *>(inputMaxPtr->data_c()); | |||
| quantParam.min = *minBuf; | |||
| quantParam.max = *maxBuf; | |||
| } | |||
| @@ -207,7 +204,7 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim, | |||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, | |||
| numbitsRangeQuantParam); | |||
| quants.emplace_back(quantParam); | |||
| vecInputQuantParam->emplace_back(quants); | |||
| input_quant_param_.emplace_back(quants); | |||
| quants.clear(); | |||
| auto filterMin = prim.GetAttr("filter_minq"); | |||
| @@ -227,17 +224,25 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim, | |||
| } | |||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, true, numbitsRangeQuantParam); | |||
| quants.emplace_back(quantParam); | |||
| vecInputQuantParam->emplace_back(quants); | |||
| input_quant_param_.emplace_back(quants); | |||
| } | |||
| if (vecInputQuantParam->size() == kDoubleNum) { | |||
| if (input_quant_param_.size() == kDoubleNum) { | |||
| quants.clear(); | |||
| quantParam.min = 0.0; | |||
| quantParam.max = 0.0; | |||
| quantParam.zeroPoint = 0; | |||
| quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(0).scale; | |||
| quantParam.scale = input_quant_param_.at(0).at(0).scale * input_quant_param_.at(1).at(0).scale; | |||
| quants.emplace_back(quantParam); | |||
| vecInputQuantParam->emplace_back(quants); | |||
| input_quant_param_.emplace_back(quants); | |||
| } | |||
| // fill input_quant_param_ by not inited quant_parm | |||
| if (input_quant_param_.size() < inputs.size()) { | |||
| quants.clear(); | |||
| schema::QuantParamT tmpQuantParam; | |||
| quants.emplace_back(tmpQuantParam); | |||
| input_quant_param_.insert(input_quant_param_.end(), inputs.size() - 1 - input_quant_param_.size(), quants); | |||
| } | |||
| quants.clear(); | |||
| @@ -253,7 +258,11 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim, | |||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, | |||
| numbitsRangeQuantParam); | |||
| quants.emplace_back(quantParam); | |||
| vecOutputQuantParam->emplace_back(quants); | |||
| output_quant_param_.emplace_back(quants); | |||
| } else { | |||
| schema::QuantParamT tmpQuantParam; | |||
| quants.emplace_back(tmpQuantParam); | |||
| output_quant_param_.emplace_back(quants); | |||
| } | |||
| } | |||
| @@ -279,14 +288,48 @@ schema::PrimitiveT *PrimitiveC::GetPrimitiveT() const { return this->primitive_; | |||
| void PrimitiveC::ClearPrimitiveT() { this->primitive_ = nullptr; } | |||
| void PrimitiveC::SetInputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param) { | |||
| void PrimitiveC::SetInputQuantParams(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param) { | |||
| this->input_quant_param_ = input_quant_param; | |||
| } | |||
| void PrimitiveC::SetOutputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param) { | |||
| void PrimitiveC::SetInputQuantParam(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param) { | |||
| MS_ASSERT(index < this->input_quant_param_.size()); | |||
| this->input_quant_param_[index] = input_quant_param; | |||
| } | |||
| void PrimitiveC::SetOutputQuantParams(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param) { | |||
| this->output_quant_param_ = output_quant_param; | |||
| } | |||
| void PrimitiveC::SetOutputQuantParam(const size_t &index, const std::vector<schema::QuantParamT> &output_quant_param) { | |||
| MS_ASSERT(index < this->output_quant_param_.size()); | |||
| this->output_quant_param_[index] = output_quant_param; | |||
| } | |||
| bool PrimitiveC::IsInputQuantParamsInited() { | |||
| if (this->input_quant_param_.empty()) { | |||
| return false; | |||
| } | |||
| for (auto &quant_param : this->input_quant_param_) { | |||
| if (!quant_param.front().inited) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool PrimitiveC::IsOutputQuantParamsInited() { | |||
| if (this->output_quant_param_.empty()) { | |||
| return false; | |||
| } | |||
| for (auto &quant_param : this->output_quant_param_) { | |||
| if (!quant_param.front().inited) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| void PrimitiveC::ClearInputOutputQuantParam() { | |||
| input_quant_param_.clear(); | |||
| output_quant_param_.clear(); | |||
| @@ -88,9 +88,17 @@ class PrimitiveC : public mindspore::Primitive { | |||
| } | |||
| } | |||
| void SetInputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param); | |||
| void SetInputQuantParams(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param); | |||
| void SetOutputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param); | |||
| void SetInputQuantParam(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param); | |||
| void SetOutputQuantParams(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param); | |||
| void SetOutputQuantParam(const size_t &index, const std::vector<schema::QuantParamT> &output_quant_param); | |||
| bool IsInputQuantParamsInited(); | |||
| bool IsOutputQuantParamsInited(); | |||
| void ClearInputOutputQuantParam(); | |||
| @@ -120,10 +128,8 @@ class PrimitiveC : public mindspore::Primitive { | |||
| static std::shared_ptr<PrimitiveC> Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs, | |||
| const schema::QuantType &quantType); | |||
| void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam, | |||
| std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam, | |||
| const std::vector<AnfNodePtr> &inputs); | |||
| void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); | |||
| void PopulaterQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs); | |||
| void CalFloatScopeByMeanAndStddev(const double &mean, const double &stdDev, float *mMin, float *mMax); | |||
| protected: | |||
| virtual int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { return RET_ERROR; } | |||
| @@ -42,7 +42,9 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| // data of second tensor of fc may be nullptr | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||
| bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && | |||
| restore_data != nullptr; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| @@ -53,7 +55,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T | |||
| auto kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| if (!kernel) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (!weight_tensor->GetQuantParams().empty()) { | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| @@ -65,13 +67,13 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T | |||
| delete kernel; | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (!weight_tensor->GetQuantParams().empty()) { | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| @@ -98,8 +98,9 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) { | |||
| MS_LOG(ERROR) << "QuantDTypeCast need quantization parameters which is not found."; | |||
| return RET_ERROR; | |||
| } | |||
| auto quant_arg = !out_tensors_.front()->GetQuantParams().empty() ? out_tensors_.front()->GetQuantParams().front() | |||
| : in_tensors_.front()->GetQuantParams().front(); | |||
| auto quant_arg = out_tensors_.front()->GetQuantParams().front().inited | |||
| ? out_tensors_.front()->GetQuantParams().front() | |||
| : in_tensors_.front()->GetQuantParams().front(); | |||
| int ret = RET_OK; | |||
| if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeFloat32) { | |||
| ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale, | |||
| @@ -140,8 +140,8 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->MutableData(); | |||
| auto dequant_flag = | |||
| (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false; | |||
| bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && | |||
| restore_data != nullptr; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| @@ -182,8 +182,8 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> & | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->MutableData(); | |||
| auto dequant_flag = | |||
| (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false; | |||
| bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && | |||
| restore_data != nullptr; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| @@ -204,8 +204,8 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->MutableData(); | |||
| auto dequant_flag = | |||
| (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false; | |||
| auto dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && | |||
| restore_data != nullptr; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| @@ -216,8 +216,8 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->MutableData(); | |||
| auto dequant_flag = | |||
| (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false; | |||
| auto dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && | |||
| restore_data != nullptr; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| @@ -237,7 +237,9 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| // data of second tensor of fc may be nullptr | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||
| bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && | |||
| restore_data != nullptr; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| @@ -250,7 +252,7 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T | |||
| auto *kernel = new (std::nothrow) FullconnectionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (!weight_tensor->GetQuantParams().empty()) { | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| @@ -262,13 +264,13 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| delete kernel; | |||
| if (!weight_tensor->GetQuantParams().empty()) { | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (!weight_tensor->GetQuantParams().empty()) { | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| @@ -251,7 +251,9 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||
| bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && | |||
| restore_data != nullptr; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| @@ -263,7 +265,7 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| auto *kernel = new (std::nothrow) MatmulFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| @@ -275,13 +277,13 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| delete kernel; | |||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| @@ -284,7 +284,9 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> & | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->MutableData(); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && | |||
| restore_data != nullptr; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| @@ -303,7 +305,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> & | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| @@ -315,14 +317,14 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> & | |||
| delete kernel; | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_)); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| @@ -124,6 +124,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *> | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->MutableData(); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| @@ -234,7 +234,9 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *> | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D); | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->MutableData(); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && | |||
| restore_data != nullptr; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| @@ -255,7 +257,7 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *> | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| @@ -267,14 +269,14 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *> | |||
| delete kernel; | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| @@ -196,7 +196,9 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D); | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->MutableData(); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && | |||
| restore_data != nullptr; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| @@ -209,7 +211,7 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor | |||
| new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| @@ -221,13 +223,13 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor | |||
| delete kernel; | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| @@ -208,9 +208,8 @@ kernel::LiteKernel *CpuMatmulInt8KernelCreator(const std::vector<lite::Tensor *> | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| auto is_const_quant_weight = | |||
| (restore_data != nullptr) && | |||
| ((weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16)); | |||
| bool is_const_quant_weight = !weight_tensor->GetQuantParams().empty() && | |||
| weight_tensor->GetQuantParams().front().inited && restore_data != nullptr; | |||
| if (is_const_quant_weight) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| @@ -25,36 +25,6 @@ using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_INFER_ERR; | |||
| using mindspore::lite::RET_INFER_INVALID; | |||
| using mindspore::lite::RET_OK; | |||
| using Float16CastFunc = void (*)(const void *, void *, int); | |||
| class Float16CastUtil { | |||
| public: | |||
| static Float16CastUtil *GetInstance() { | |||
| static Float16CastUtil float16_cast_util; | |||
| return &float16_cast_util; | |||
| } | |||
| private: | |||
| Float16CastUtil() { | |||
| #ifdef ENABLE_ARM64 | |||
| void *fp16_op_handler = Float16Module::GetInstance()->float16_op_handler_; | |||
| if (fp16_op_handler != nullptr) { | |||
| dlerror(); | |||
| *(reinterpret_cast<void **>(&float16_to_float32_func_)) = dlsym(fp16_op_handler, "Float16ToFloat32_fp16_handler"); | |||
| *(reinterpret_cast<void **>(&float32_to_float16_func_)) = dlsym(fp16_op_handler, "Float32ToFloat16_fp16_handler"); | |||
| auto dlopen_error = dlerror(); | |||
| if (dlopen_error != nullptr) { | |||
| MS_LOG(ERROR) << "load float16 cast func failed! " << dlopen_error << "."; | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| ~Float16CastUtil() = default; | |||
| public: | |||
| Float16CastFunc float16_to_float32_func_ = nullptr; | |||
| Float16CastFunc float32_to_float16_func_ = nullptr; | |||
| }; | |||
| int SubGraphKernel::Prepare() { | |||
| for (auto node : this->nodes_) { | |||
| @@ -208,7 +178,7 @@ int CpuFp16SubGraph::PreProcess() { | |||
| } | |||
| int CpuFp16SubGraph::PostProcess() { | |||
| auto fp16_to_fp32_cast_func = Float16CastUtil::GetInstance()->float16_to_float32_func_; | |||
| auto fp16_to_fp32_cast_func = lite::Float16CastUtil::GetInstance()->float16_to_float32_func_; | |||
| if (fp16_to_fp32_cast_func == nullptr) { | |||
| MS_LOG(ERROR) << "Can not find cast fp16 to fp32 func"; | |||
| return RET_ERROR; | |||
| @@ -35,6 +35,7 @@ struct QuantArg { | |||
| int32_t zeroPoint; | |||
| double var_corr{1}; | |||
| double mean_corr{0}; | |||
| bool inited; | |||
| std::vector<float> clusters{}; | |||
| }; | |||
| @@ -33,6 +33,6 @@ lite-model_on_device_vision_classifier_landmarks_classifier_oceania_antarctica_V | |||
| lite-model_on_device_vision_classifier_landmarks_classifier_europe_V1_1.tflite | |||
| lite-model_on_device_vision_classifier_landmarks_classifier_south_america_V1_1.tflite | |||
| vision_classifier_fungi_mobile_V1_1_default_1.tflite | |||
| detect.tflite | |||
| ssd_mobilenet_v1_1_default_1.tflite | |||
| object_detection_mobile_object_localizer_v1_1_default_1.tflite | |||
| #detect.tflite | |||
| #ssd_mobilenet_v1_1_default_1.tflite | |||
| #object_detection_mobile_object_localizer_v1_1_default_1.tflite | |||
| @@ -121,8 +121,8 @@ function Run_Converter() { | |||
| continue | |||
| fi | |||
| echo ${model_name} >> "${run_converter_log_file}" | |||
| echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}' --quantType=AwareTraining' >> "${run_converter_log_file}" | |||
| ./converter_lite --fmk=TFLITE --modelFile=${models_path}/${model_name} --outputFile=${ms_models_path}/${model_name} --quantType=AwareTraining | |||
| echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}' --inputDataType=FLOAT --outputDataType=FLOAT' >> "${run_converter_log_file}" | |||
| ./converter_lite --fmk=TFLITE --modelFile=${models_path}/${model_name} --outputFile=${ms_models_path}/${model_name} --inputDataType=FLOAT --outputDataType=FLOAT | |||
| if [ $? = 0 ]; then | |||
| converter_result='converter aware_training '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file} | |||
| else | |||
| @@ -544,6 +544,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s | |||
| } else { | |||
| auto ms_tensor = new schema::TensorT(); | |||
| ms_tensor->nodeType = schema::NodeType_CNode; | |||
| ms_tensor->dataType = TypeId::kNumberTypeFloat32; | |||
| fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size()); | |||
| node_id_map_[cnode_name] = meta_graphT->allTensors.size(); | |||
| meta_graphT->allTensors.emplace_back(ms_tensor); | |||
| @@ -73,30 +73,30 @@ ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr<s | |||
| auto primitiveCValue = PrimitiveC::Create(cNode->primitive.release()); | |||
| cNode->primitive = nullptr; | |||
| // add quant parameter | |||
| if (cNode->quantType != schema::QuantType_PostTraining && cNode->quantType != schema::QuantType_WeightQuant) { | |||
| primitiveCValue->SetQuantType(cNode->quantType); | |||
| for (int index : cNode->inputIndex) { | |||
| if (!meta_graph_->allTensors[index]->quantParams.empty()) { | |||
| std::vector<schema::QuantParamT> quant_params(meta_graph_->allTensors[index]->quantParams.size()); | |||
| std::transform( | |||
| meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(), | |||
| quant_params.begin(), | |||
| [](std::unique_ptr<schema::QuantParamT> &quant_param) -> schema::QuantParamT { return *quant_param; }); | |||
| primitiveCValue->AddInputQuantParam(quant_params); | |||
| } else { | |||
| std::vector<schema::QuantParamT> empty_quant_params; | |||
| primitiveCValue->AddInputQuantParam(empty_quant_params); | |||
| } | |||
| for (auto index : cNode->inputIndex) { | |||
| if (!meta_graph_->allTensors[index]->quantParams.empty()) { | |||
| std::vector<schema::QuantParamT> quant_params(meta_graph_->allTensors[index]->quantParams.size()); | |||
| std::transform( | |||
| meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(), | |||
| quant_params.begin(), | |||
| [](std::unique_ptr<schema::QuantParamT> &quant_param) -> schema::QuantParamT { return *quant_param; }); | |||
| primitiveCValue->AddInputQuantParam(quant_params); | |||
| } else { | |||
| std::vector<schema::QuantParamT> notinited_quant_params(1); | |||
| primitiveCValue->AddInputQuantParam(notinited_quant_params); | |||
| } | |||
| for (int index : cNode->outputIndex) { | |||
| if (!meta_graph_->allTensors[index]->quantParams.empty()) { | |||
| std::vector<schema::QuantParamT> quant_params(meta_graph_->allTensors[index]->quantParams.size()); | |||
| std::transform( | |||
| meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(), | |||
| quant_params.begin(), | |||
| [](std::unique_ptr<schema::QuantParamT> &quant_param) -> schema::QuantParamT { return *quant_param; }); | |||
| primitiveCValue->AddOutputQuantParam(quant_params); | |||
| } | |||
| } | |||
| for (auto index : cNode->outputIndex) { | |||
| if (!meta_graph_->allTensors[index]->quantParams.empty()) { | |||
| std::vector<schema::QuantParamT> quant_params(meta_graph_->allTensors[index]->quantParams.size()); | |||
| std::transform( | |||
| meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(), | |||
| quant_params.begin(), | |||
| [](std::unique_ptr<schema::QuantParamT> &quant_param) -> schema::QuantParamT { return *quant_param; }); | |||
| primitiveCValue->AddOutputQuantParam(quant_params); | |||
| } else { | |||
| std::vector<schema::QuantParamT> notinited_quant_params(1); | |||
| primitiveCValue->AddOutputQuantParam(notinited_quant_params); | |||
| } | |||
| } | |||
| auto value_node = NewValueNode(std::shared_ptr<PrimitiveC>(primitiveCValue)); | |||
| @@ -52,16 +52,13 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||
| auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true); | |||
| // for now - trainning is not supporting fuse operations | |||
| if (config != nullptr && config->trainModel == false) { | |||
| if (config != nullptr && !config->trainModel) { | |||
| // remove quantdtype when awaretraining | |||
| if (config->fmk == lite::converter::FmkType_ONNX) { | |||
| auto remove_identity_pass = std::make_shared<opt::RemoveIdentityOpPass>(); | |||
| remove_identity_pass->SetFmkType(config->fmk); | |||
| pm->AddPass(remove_identity_pass); | |||
| } | |||
| if (config->quantType == QuantType_AwareTraining) { | |||
| pm->AddPass(std::make_shared<opt::QuantDtypeCastFusion>()); | |||
| } | |||
| pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ConvScaleFusion>()); | |||
| @@ -101,27 +98,25 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||
| return nullptr; | |||
| } | |||
| // quant | |||
| if (config != nullptr) { | |||
| if (config->quantType == schema::QuantType_PostTraining) { | |||
| this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, 8); | |||
| if (mQuantizer == nullptr) { | |||
| MS_LOG(ERROR) << "New PostTrainingQuantizer failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); | |||
| return nullptr; | |||
| } | |||
| } else if (config->quantType == schema::QuantType_WeightQuant) { | |||
| if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) { | |||
| MS_LOG(ERROR) << "weight quant input param error"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return nullptr; | |||
| } | |||
| this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->quantWeightSize, | |||
| config->quantWeightChannel, config->bitNum); | |||
| if (mQuantizer == nullptr) { | |||
| MS_LOG(ERROR) << "New WeightQuantizer failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); | |||
| return nullptr; | |||
| } | |||
| if (config->quantType == schema::QuantType_PostTraining) { | |||
| this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, 8); | |||
| if (mQuantizer == nullptr) { | |||
| MS_LOG(ERROR) << "New PostTrainingQuantizer failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); | |||
| return nullptr; | |||
| } | |||
| } else if (config->quantType == schema::QuantType_WeightQuant) { | |||
| if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) { | |||
| MS_LOG(ERROR) << "weight quant input param error"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return nullptr; | |||
| } | |||
| this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->quantWeightSize, | |||
| config->quantWeightChannel, config->bitNum); | |||
| if (mQuantizer == nullptr) { | |||
| MS_LOG(ERROR) << "New WeightQuantizer failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); | |||
| return nullptr; | |||
| } | |||
| } | |||
| if (mQuantizer != nullptr) { | |||
| @@ -93,7 +93,6 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { | |||
| } | |||
| // transform | |||
| transform->SetGraphDef(meta_graph); | |||
| transform->CreateQuantizer(flag); | |||
| auto status = transform->Transform(*flag); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Transform meta graph failed " << status; | |||
| @@ -29,9 +29,14 @@ Flags::Flags() { | |||
| AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", ""); | |||
| AddFlag(&Flags::weightFile, "weightFile", "Input model weight file. Needed when fmk is CAFFE. CAFFE: *.caffemodel", | |||
| ""); | |||
| AddFlag(&Flags::inferenceTypeIn, "inferenceType", "Data type of input and output tensors. FLOAT | INT8 | UINT8", | |||
| "FLOAT"); | |||
| AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | PostTraining | WeightQuant", ""); | |||
| AddFlag(&Flags::inputDataTypeIn, "inputDataType", | |||
| "Data type of input tensors, default is same with the type defined in model. FLOAT | INT8 | UINT8 | DEFAULT", | |||
| "DEFAULT"); | |||
| AddFlag(&Flags::outputDataTypeIn, "outputDataType", | |||
| "Data type of output and output tensors, default is same with the type defined in model. FLOAT | INT8 | " | |||
| "UINT8 | DEFAULT", | |||
| "DEFAULT"); | |||
| AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. PostTraining | WeightQuant", ""); | |||
| AddFlag(&Flags::bitNum, "bitNum", "Weight quantization bitNum", "8"); | |||
| AddFlag(&Flags::quantWeightSize, "quantWeightSize", "Weight quantization size threshold", "0"); | |||
| AddFlag(&Flags::quantWeightChannel, "quantWeightChannel", "Channel threshold for weight quantization", "16"); | |||
| @@ -78,15 +83,32 @@ int Flags::Init(int argc, const char **argv) { | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| if (this->inferenceTypeIn == "FLOAT") { | |||
| this->inferenceType = TypeId::kNumberTypeFloat; | |||
| } else if (this->inferenceTypeIn == "INT8") { | |||
| this->inferenceType = TypeId::kNumberTypeInt8; | |||
| } else if (this->inferenceTypeIn == "UINT8") { | |||
| this->inferenceType = TypeId::kNumberTypeUInt8; | |||
| if (this->inputDataTypeIn == "FLOAT") { | |||
| this->inputDataType = TypeId::kNumberTypeFloat; | |||
| } else if (this->inputDataTypeIn == "INT8") { | |||
| this->inputDataType = TypeId::kNumberTypeInt8; | |||
| } else if (this->inputDataTypeIn == "UINT8") { | |||
| this->inputDataType = TypeId::kNumberTypeUInt8; | |||
| } else if (this->inputDataTypeIn == "DEFAULT") { | |||
| this->inputDataType = TypeId::kTypeUnknown; | |||
| } else { | |||
| std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8 | UINT8", | |||
| this->inferenceTypeIn.c_str(); | |||
| std::cerr << "INPUT INVALID: inputDataType is invalid: %s, supported inputDataType: FLOAT | INT8 | UINT8 | DEFAULT", | |||
| this->inputDataTypeIn.c_str(); | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| if (this->outputDataTypeIn == "FLOAT") { | |||
| this->outputDataType = TypeId::kNumberTypeFloat; | |||
| } else if (this->outputDataTypeIn == "INT8") { | |||
| this->outputDataType = TypeId::kNumberTypeInt8; | |||
| } else if (this->outputDataTypeIn == "UINT8") { | |||
| this->outputDataType = TypeId::kNumberTypeUInt8; | |||
| } else if (this->outputDataTypeIn == "DEFAULT") { | |||
| this->outputDataType = TypeId::kTypeUnknown; | |||
| } else { | |||
| std::cerr | |||
| << "INPUT INVALID: outputDataType is invalid: %s, supported outputDataType: FLOAT | INT8 | UINT8 | DEFAULT", | |||
| this->outputDataTypeIn.c_str(); | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| @@ -107,9 +129,8 @@ int Flags::Init(int argc, const char **argv) { | |||
| std::cerr << "INPUT ILLEGAL: weightFile is not a valid flag"; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| if (this->quantTypeIn == "AwareTraining") { | |||
| this->quantType = QuantType_AwareTraining; | |||
| } else if (this->quantTypeIn == "WeightQuant") { | |||
| if (this->quantTypeIn == "WeightQuant") { | |||
| this->quantType = QuantType_WeightQuant; | |||
| } else if (this->quantTypeIn == "PostTraining") { | |||
| this->quantType = QuantType_PostTraining; | |||
| @@ -53,8 +53,11 @@ class Flags : public virtual mindspore::lite::FlagParser { | |||
| std::string quantTypeIn; | |||
| QuantType quantType; | |||
| std::string inferenceTypeIn; | |||
| std::string inputDataTypeIn; | |||
| std::string outputDataTypeIn; | |||
| // used for parse aware trainning | |||
| TypeId inferenceType = TypeId::kNumberTypeFloat; | |||
| TypeId inputDataType; | |||
| TypeId outputDataType; | |||
| // used for post-trainning-weight | |||
| std::string quantWeightSize; | |||
| std::string bitNum; | |||
| @@ -34,6 +34,8 @@ | |||
| #include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h" | |||
| #include "tools/converter/quantizer/aware_quantizer.h" | |||
| using std::string; | |||
| @@ -44,20 +46,6 @@ GraphDefTransform::~GraphDefTransform() = default; | |||
| void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _dstDef; } | |||
| void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) { | |||
| auto type = flags->quantType; | |||
| switch (type) { | |||
| case QuantType::QuantType_AwareTraining: { | |||
| MS_LOG(INFO) << "create AwareTrainingQuantizer!"; | |||
| fbQuantizer = std::make_unique<quant::AwareQuantizer>(graphDefT, flags->inferenceType); | |||
| break; | |||
| } | |||
| default: | |||
| MS_LOG(INFO) << "will support quantizer type " << flags->quantTypeIn << " in the future"; | |||
| break; | |||
| } | |||
| } | |||
| int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| STATUS status; | |||
| { | |||
| @@ -84,26 +72,13 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| // generate and infer quant parameters | |||
| { | |||
| if (fbQuantizer != nullptr) { | |||
| Optimizer topologicalOptimizer; | |||
| topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| status = topologicalOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| if (ctx.quantType == QuantType_AwareTraining) { | |||
| status = fbQuantizer->GenerateQuantParam(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "GenerateQuantParam failed"; | |||
| return status; | |||
| } | |||
| status = fbQuantizer->DetermineNodeQuantType(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DetermineNodeQuant failed"; | |||
| return status; | |||
| } | |||
| } | |||
| Optimizer inferQuantParamPass; | |||
| inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass()); | |||
| status = inferQuantParamPass.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| @@ -146,12 +121,11 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| } | |||
| } | |||
| { | |||
| Optimizer fusionOptimizer; | |||
| fusionOptimizer.AddPass(new (std::nothrow) FormatTransPermuteFusionPass()); | |||
| fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| status = fusionOptimizer.Run(graphDefT); | |||
| Optimizer inferQuantParamOtimizer; | |||
| inferQuantParamOtimizer.AddPass(new (std::nothrow) InferQuantParamPass()); | |||
| status = inferQuantParamOtimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed"; | |||
| MS_LOG(ERROR) << "Run tensorQuantOptimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| @@ -168,8 +142,10 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| } | |||
| // do quantization | |||
| if (fbQuantizer != nullptr) { | |||
| status = fbQuantizer->DoQuantize(); | |||
| { | |||
| Optimizer fusionOptimizer; | |||
| fusionOptimizer.AddPass(new (std::nothrow) TensorQuantPass()); | |||
| status = fusionOptimizer.Run(graphDefT); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoQuantize failed!"; | |||
| return status; | |||
| @@ -177,11 +153,11 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| } | |||
| // insert quantNode and deQuantNode | |||
| if (ctx.quantType == QuantType_AwareTraining) { | |||
| { | |||
| Optimizer quantNodeOptimizer; | |||
| auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); | |||
| dTypeTransPass->SetInputDataDType(ctx.inferenceType); | |||
| dTypeTransPass->SetOutputDataDType(ctx.inferenceType); | |||
| dTypeTransPass->SetInputDataDType(ctx.inputDataType); | |||
| dTypeTransPass->SetOutputDataDType(ctx.outputDataType); | |||
| quantNodeOptimizer.AddPass(dTypeTransPass); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| @@ -37,14 +37,10 @@ class GraphDefTransform { | |||
| virtual int Transform(const converter::Flags &ctx); | |||
| void SetGraphDef(schema::MetaGraphT *dstDef); | |||
| inline schema::MetaGraphT *GetOutput() { return graphDefT; } | |||
| void CreateQuantizer(const converter::Flags *flags); | |||
| protected: | |||
| schema::MetaGraphT *graphDefT = nullptr; | |||
| Optimizer *optimizer = nullptr; | |||
| std::unique_ptr<quant::Quantizer> mQuantizer; | |||
| std::unique_ptr<quant::FbQuantizer> fbQuantizer; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -10,6 +10,8 @@ file(GLOB GRAPH_PASS | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_convert_scale_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/trans_format_remove_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/infershape_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/tensor_quant_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/infer_quant_param_pass.cc | |||
| ) | |||
| set_property(SOURCE ${GRAPH_PASS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) | |||
| add_library(graph_pass_mid OBJECT ${GRAPH_PASS}) | |||
| @@ -27,9 +27,6 @@ namespace lite { | |||
| #define kMinInputNum 1 | |||
| #define kOutputNum 1 | |||
| static const std::set<schema::PrimitiveType> NoNeedDtypeTransList = { | |||
| PrimitiveType_QuantDTypeCast, PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw}; | |||
| STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| @@ -44,12 +41,6 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) { | |||
| MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status; | |||
| return status; | |||
| } | |||
| status = DoNodeInoutDTypeTrans(graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status; | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -57,7 +48,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| auto &graphInIdxes = graph->inputIndex; | |||
| if (this->inputDataDType == TypeId::kNumberTypeInt8) { | |||
| if (this->inputDataDType == TypeId::kNumberTypeInt8 || this->inputDataDType == TypeId::kTypeUnknown) { | |||
| return RET_OK; | |||
| } | |||
| if (this->inputDataDType != TypeId::kNumberTypeFloat && this->inputDataDType != TypeId::kNumberTypeUInt8) { | |||
| @@ -68,7 +59,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { | |||
| for (auto graphInIdx : graphInIdxes) { | |||
| MS_ASSERT(graphInIdx < graph->allTensors.size()); | |||
| auto &tensor = graph->allTensors.at(graphInIdx); | |||
| if (tensor->dims.size() != kNHWCDimNumber || tensor->dataType != kNumberTypeInt8) { | |||
| if (tensor->dataType != kNumberTypeInt8 || tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { | |||
| continue; | |||
| } | |||
| @@ -98,7 +89,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { | |||
| STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| if (outputDataDType == TypeId::kNumberTypeInt8) { | |||
| if (outputDataDType == TypeId::kNumberTypeInt8 || outputDataDType == TypeId::kTypeUnknown) { | |||
| return RET_OK; | |||
| } | |||
| if (this->outputDataDType != TypeId::kNumberTypeFloat && this->outputDataDType != TypeId::kNumberTypeUInt8) { | |||
| @@ -107,6 +98,11 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { | |||
| } | |||
| auto &graphOutIdxes = graph->outputIndex; | |||
| for (auto graphOutIdx : graphOutIdxes) { | |||
| MS_ASSERT(graphOutIdx < graph->allTensors.size()); | |||
| auto &tensor = graph->allTensors.at(graphOutIdx); | |||
| if (tensor->dataType != kNumberTypeInt8 || tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { | |||
| continue; | |||
| } | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| auto nodeName = (*iter)->name; | |||
| MS_ASSERT(node != nullptr); | |||
| @@ -131,67 +127,6 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { | |||
| return RET_OK; | |||
| } | |||
| STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| // insert transNode before and after existNode | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| if (IsContain(GetInt8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) { | |||
| continue; | |||
| } | |||
| auto iterType = GetCNodeTType(**iter); | |||
| if (NoNeedDtypeTransList.find(iterType) != NoNeedDtypeTransList.end()) { | |||
| continue; | |||
| } | |||
| bool needInsertPost = true; | |||
| if (GetCNodeTType(**iter) == PrimitiveType_Shape) { | |||
| needInsertPost = false; | |||
| } | |||
| auto nodeName = (*iter)->name; | |||
| if ((*iter)->inputIndex.size() < kMinInputNum) { | |||
| MS_LOG(ERROR) << "Op " << nodeName.c_str() << " should have " << kMinInputNum << " input tensor at least"; | |||
| return RET_ERROR; | |||
| } | |||
| STATUS status; | |||
| // insert pre | |||
| for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) { | |||
| MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i)); | |||
| auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i)); | |||
| if (preTensor->dataType == TypeId::kNumberTypeInt || preTensor->dataType == TypeId::kNumberTypeInt32) { | |||
| continue; | |||
| } | |||
| auto &graphInIdxes = graph->inputIndex; | |||
| if (!preTensor->data.empty() && !IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) { | |||
| continue; | |||
| } | |||
| if ((preTensor->dataType != TypeId::kNumberTypeInt8) && (IsContain(graphInIdxes, (*iter)->inputIndex.at(i)))) { | |||
| continue; | |||
| } | |||
| iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| if (needInsertPost) { | |||
| for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) { | |||
| auto &postTensor = graph->allTensors.at((*iter)->outputIndex.at(i)); | |||
| if (postTensor->dataType == TypeId::kNumberTypeInt || postTensor->dataType == TypeId::kNumberTypeInt32) { | |||
| continue; | |||
| } | |||
| iter = InsertDTypeTransNode(graph, iter, kAfter, i, kFP32ToInt8, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } | |||
| (*iter)->quantType = QuantType_QUANT_NONE; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, | |||
| size_t inoutIdx, DTypeTransNodeType nodeType, STATUS *errorCode) { | |||
| MS_ASSERT((*existNodeIter) != nullptr); | |||
| @@ -45,8 +45,6 @@ class DTypeTransPass : public GraphPass { | |||
| STATUS DoModelOutputDTypeTrans(schema::MetaGraphT *graph); | |||
| STATUS DoNodeInoutDTypeTrans(schema::MetaGraphT *graph); | |||
| NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, | |||
| DTypeTransNodeType nodeType, STATUS *errorCode); | |||
| @@ -0,0 +1,87 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "src/common/utils.h" | |||
| #include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h" | |||
| #include "tools/converter/quantizer/calc_quant_param.h" | |||
| #include "tools/common/node_util.h" | |||
| #include "tools/common/converter_op_utils.h" | |||
| namespace mindspore::lite { | |||
| STATUS InferQuantParamPass::Run(schema::MetaGraphT *graph) { | |||
| auto *quantParamRegister = QuantParamCalcRegister::GetInstance(); | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| auto &node = *iter; | |||
| MS_ASSERT(node != nullptr); | |||
| if (node->quantType == schema::QuantType_WeightQuant) { | |||
| continue; | |||
| } | |||
| if (GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMax || | |||
| GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) { | |||
| MS_ASSERT(false); | |||
| } | |||
| auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node)); | |||
| if (quantParamCalcer == nullptr) { | |||
| MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str() | |||
| << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; | |||
| node->quantType = static_cast<schema::QuantType>(schema::QuantType_QUANT_NONE); | |||
| } else { | |||
| auto status = quantParamCalcer->Calc(graph, *node); | |||
| if (status != RET_OK) { | |||
| MS_LOG(WARNING) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); | |||
| node->quantType = schema::QuantType_QUANT_NONE; | |||
| } else { | |||
| DetermineNodeQuantType(*graph, node.get()); | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| void InferQuantParamPass::DetermineNodeQuantType(const schema::MetaGraphT &graph, schema::CNodeT *cnode) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| bool canQuant = true; | |||
| for (auto &inputTensorIdx : cnode->inputIndex) { | |||
| MS_ASSERT(graph.allTensors.size() > inputTensorIdx); | |||
| auto &inTensor = graph.allTensors.at(inputTensorIdx); | |||
| MS_ASSERT(inTensor != nullptr); | |||
| if (inTensor->quantParams.empty() || inTensor->quantParams.front() == nullptr || | |||
| !inTensor->quantParams.front()->inited) { | |||
| canQuant = false; | |||
| break; | |||
| } | |||
| } | |||
| for (auto &outTensorIdx : cnode->outputIndex) { | |||
| MS_ASSERT(graph.allTensors.size() > outTensorIdx); | |||
| auto &outTensor = graph.allTensors.at(outTensorIdx); | |||
| MS_ASSERT(outTensor != nullptr); | |||
| if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr || | |||
| !outTensor->quantParams.front()->inited) { | |||
| canQuant = false; | |||
| break; | |||
| } | |||
| } | |||
| if (canQuant && IsContain(GetInt8OpList(), GetCNodeTType(*cnode))) { | |||
| cnode->quantType = schema::QuantType_AwareTraining; | |||
| } else { | |||
| cnode->quantType = schema::QuantType_QUANT_NONE; | |||
| } | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_INFER_QUANT_PARAM_PASS_H | |||
| #define LITE_INFER_QUANT_PARAM_PASS_H | |||
| #include <memory> | |||
| #include "tools/converter/optimizer.h" | |||
| #include "tools/common/graph_util.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class InferQuantParamPass : public GraphPass { | |||
| public: | |||
| InferQuantParamPass() {} | |||
| ~InferQuantParamPass() override = default; | |||
| STATUS Run(schema::MetaGraphT *graph) override; | |||
| private: | |||
| void DetermineNodeQuantType(const schema::MetaGraphT &graph, schema::CNodeT *cnode); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_INFER_QUANT_PARAM_PASS_H | |||
| @@ -0,0 +1,87 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <cmath> | |||
| #include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h" | |||
| #include "tools/converter/quantizer/quantize_util.h" | |||
| #include "tools/common/tensor_util.h" | |||
| namespace mindspore::lite { | |||
| STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) { | |||
| for (auto &tensor : graph->allTensors) { | |||
| if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited || tensor->data.empty()) { | |||
| continue; | |||
| } | |||
| if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat && | |||
| tensor->dataType != TypeId::kNumberTypeUInt8) { | |||
| continue; | |||
| } | |||
| // perlayer | |||
| if (tensor->quantParams.size() == 1) { | |||
| auto &quantParam = tensor->quantParams.front(); | |||
| size_t wShapeSize = GetShapeSize(*(tensor.get())); | |||
| void *oriWeightData = tensor->data.data(); | |||
| if (quantParam->dstDtype == TypeId::kNumberTypeInt8) { | |||
| std::vector<int8_t> qDatas(wShapeSize); | |||
| auto weightQauntParam = GetTensorQuantParam(tensor); | |||
| if (tensor->dataType == TypeId::kNumberTypeFloat || | |||
| tensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant | |||
| auto *weightData = static_cast<float *>(oriWeightData); | |||
| for (size_t j = 0; j < wShapeSize; j++) { | |||
| qDatas[j] = quant::QuantizeData<int8_t>(weightData[j], weightQauntParam.get()); | |||
| } | |||
| } else { // tflite awareing quant | |||
| auto *weightData = static_cast<uint8_t *>(oriWeightData); | |||
| for (size_t j = 0; j < wShapeSize; j++) { | |||
| qDatas[j] = (int32_t)weightData[j] - 128; | |||
| } | |||
| weightQauntParam->zeroPoint -= 128; | |||
| tensor->quantParams.clear(); | |||
| tensor->quantParams.emplace_back(weightQauntParam.release()); | |||
| } | |||
| tensor->dataType = TypeId::kNumberTypeInt8; | |||
| ::memcpy(tensor->data.data(), qDatas.data(), wShapeSize); | |||
| } else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) { | |||
| // quant bias data | |||
| auto bShapeSize = GetShapeSize(*(tensor.get())); | |||
| std::unique_ptr<int32_t[]> qDatas(new (std::nothrow) int32_t[bShapeSize]); | |||
| if (qDatas == nullptr) { | |||
| MS_LOG(ERROR) << "new qDatas failed"; | |||
| return RET_ERROR; | |||
| } | |||
| void *biasData = tensor->data.data(); | |||
| auto *rawDatas = static_cast<float *>(biasData); | |||
| for (size_t i = 0; i < bShapeSize; ++i) { | |||
| qDatas[i] = (int32_t)std::round(rawDatas[i] / quantParam->scale); | |||
| } | |||
| tensor->dataType = TypeId::kNumberTypeInt32; | |||
| tensor->data.clear(); | |||
| tensor->data.resize(bShapeSize * sizeof(int32_t)); | |||
| auto ret = | |||
| memcpy_s(tensor->data.data(), bShapeSize * sizeof(int32_t), qDatas.get(), bShapeSize * sizeof(int32_t)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed: " << ret; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } else { // pertensor | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_TENSOR_QUANT_PASS_H | |||
| #define LITE_TENSOR_QUANT_PASS_H | |||
| #include <memory> | |||
| #include "tools/converter/optimizer.h" | |||
| #include "tools/common/graph_util.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TensorQuantPass : public GraphPass { | |||
| public: | |||
| TensorQuantPass() {} | |||
| ~TensorQuantPass() override = default; | |||
| STATUS Run(schema::MetaGraphT *graph) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_TENSOR_QUANT_PASS_H | |||
| @@ -52,7 +52,7 @@ STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->srcT = GetTfliteDataType(in_tensor->type); | |||
| attr->srcT = kNumberTypeInt8; | |||
| attr->dstT = GetTfliteDataType(out_tensor->type); | |||
| op->primitive->value.value = attr.release(); | |||
| op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; | |||
| @@ -52,7 +52,7 @@ STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::u | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->srcT = GetTfliteDataType(in_tensor->type); | |||
| attr->dstT = GetTfliteDataType(out_tensor->type); | |||
| attr->dstT = kNumberTypeInt8; | |||
| op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; | |||
| op->primitive->value.value = attr.release(); | |||
| } else { | |||
| @@ -35,29 +35,10 @@ using std::string; | |||
| using std::vector; | |||
| namespace mindspore::lite::quant { | |||
| const std::array<schema::PrimitiveType, 7> AwareQuantizer::propagatedOps = { | |||
| {schema::PrimitiveType_Concat, schema::PrimitiveType_Resize, schema::PrimitiveType_Reshape, | |||
| schema::PrimitiveType_Squeeze, schema::PrimitiveType_RealDiv, schema::PrimitiveType_Activation, | |||
| schema::PrimitiveType_DetectionPostProcess}}; | |||
| AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferType) : FbQuantizer(graph) {} | |||
| STATUS AwareQuantizer::RemoveFakeQuant() { return RET_OK; } | |||
| STATUS AwareQuantizer::GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph) { | |||
| MS_ASSERT(subGraph != nullptr); | |||
| for (const auto &tensor : subGraph->allTensors) { | |||
| if (!tensor->quantParams.empty()) { | |||
| continue; | |||
| } | |||
| std::unique_ptr<schema::QuantParamT> defaultQuantParam(new QuantParamT()); | |||
| tensor->quantParams.emplace_back(std::move(defaultQuantParam)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { return RET_OK; } | |||
| STATUS AwareQuantizer::GenerateQuantParam() { | |||
| auto *quantParamRegister = QuantParamCalcRegister::GetInstance(); | |||
| @@ -70,13 +51,13 @@ STATUS AwareQuantizer::GenerateQuantParam() { | |||
| } | |||
| auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node)); | |||
| if (quantParamCalcer == nullptr) { | |||
| MS_LOG(INFO) << "Can not find QuantParamCalcer for " << node->name.c_str() | |||
| << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; | |||
| MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str() | |||
| << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; | |||
| node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE); | |||
| } else { | |||
| auto status = quantParamCalcer->Calc(graph, *node); | |||
| if (status != RET_OK) { | |||
| MS_LOG(INFO) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); | |||
| MS_LOG(WARNING) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); | |||
| node->quantType = schema::QuantType_QUANT_NONE; | |||
| } else { | |||
| node->quantType = schema::QuantType_AwareTraining; | |||
| @@ -87,250 +68,65 @@ STATUS AwareQuantizer::GenerateQuantParam() { | |||
| } | |||
| STATUS AwareQuantizer::DoQuantize() { | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| auto &node = *iter; | |||
| if (!IsContain(GetInt8OpList(), GetCNodeTType(*node))) { | |||
| for (auto &tensor : graph->allTensors) { | |||
| if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited || tensor->data.empty()) { | |||
| continue; | |||
| } | |||
| if (node->quantType != schema::QuantType_AwareTraining) { | |||
| if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat && | |||
| tensor->dataType != TypeId::kNumberTypeUInt8) { | |||
| continue; | |||
| } | |||
| STATUS status; | |||
| if (GetCNodeTType(*node) == schema::PrimitiveType_Conv2D || | |||
| GetCNodeTType(*node) == schema::PrimitiveType_DepthwiseConv2D || | |||
| GetCNodeTType(*node) == schema::PrimitiveType_DeConv2D || | |||
| GetCNodeTType(*node) == schema::PrimitiveType_FullConnection || | |||
| GetCNodeTType(*node) == schema::PrimitiveType_MatMul) { | |||
| auto inputIndexes = node->inputIndex; | |||
| if (inputIndexes.size() < 2) { | |||
| MS_LOG(ERROR) << node->name.c_str() << " node input has invalid inputs tensor count"; | |||
| return RET_ERROR; | |||
| } | |||
| // quant weight | |||
| auto &weightTensor = graph->allTensors.at(node->inputIndex.at(1)); | |||
| if (!weightTensor->quantParams.empty() && weightTensor->quantParams.at(0)->inited) { | |||
| status = QuantConvWeight(graph, node.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantConvWeight failed!"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| // quant bias | |||
| if (inputIndexes.size() == 3) { | |||
| auto &biasTensor = graph->allTensors.at(node->inputIndex.at(2)); | |||
| if (!biasTensor->quantParams.empty() && biasTensor->quantParams.at(0)->inited) { | |||
| status = QuantConvBias(graph, node.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantConvBias failed!"; | |||
| return RET_ERROR; | |||
| // perlayer | |||
| if (tensor->quantParams.size() == 1) { | |||
| auto &quantParam = tensor->quantParams.front(); | |||
| size_t wShapeSize = GetShapeSize(*(tensor.get())); | |||
| void *oriWeightData = tensor->data.data(); | |||
| if (quantParam->dstDtype == TypeId::kNumberTypeInt8) { | |||
| vector<int8_t> qDatas(wShapeSize); | |||
| auto weightQauntParam = GetTensorQuantParam(tensor); | |||
| if (tensor->dataType == TypeId::kNumberTypeFloat || | |||
| tensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant | |||
| auto *weightData = static_cast<float *>(oriWeightData); | |||
| for (size_t j = 0; j < wShapeSize; j++) { | |||
| qDatas[j] = QuantizeData<int8_t>(weightData[j], weightQauntParam.get()); | |||
| } | |||
| } else { // tflite awareing quant | |||
| auto *weightData = static_cast<uint8_t *>(oriWeightData); | |||
| for (size_t j = 0; j < wShapeSize; j++) { | |||
| qDatas[j] = (int32_t)weightData[j] - 128; | |||
| } | |||
| weightQauntParam->zeroPoint -= 128; | |||
| tensor->quantParams.clear(); | |||
| tensor->quantParams.emplace_back(weightQauntParam.release()); | |||
| } | |||
| } | |||
| } else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) { | |||
| status = QuantDetectionPostProcessConstTensor(graph, node.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!"; | |||
| return RET_ERROR; | |||
| } | |||
| } else if (GetCNodeTType(*node) == schema::PrimitiveType_Add || | |||
| GetCNodeTType(*node) == schema::PrimitiveType_Scale || | |||
| GetCNodeTType(*node) == schema::PrimitiveType_Mul) { | |||
| status = QuantArithmeticConstTensor(graph, node.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantArithmeticConstTensor failed!"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| const auto nodeType = GetCNodeTType(*node); | |||
| auto find = std::find(propagatedOps.begin(), propagatedOps.end(), nodeType); | |||
| if (find != propagatedOps.end()) { | |||
| auto inputTensor = graph->allTensors.at(node->inputIndex[0]).get(); | |||
| auto outputTensor = graph->allTensors.at(node->outputIndex[0]).get(); | |||
| MS_ASSERT(inputTensor != nullptr); | |||
| MS_ASSERT(outputTensor != nullptr); | |||
| outputTensor->dataType = inputTensor->dataType; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS AwareQuantizer::QuantArithmeticConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node) { | |||
| MS_ASSERT(graph != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| for (size_t i = 0; i < node->inputIndex.size(); i++) { | |||
| auto inTensorIdx = node->inputIndex.at(i); | |||
| MS_ASSERT(graph->allTensors.size() > inTensorIdx); | |||
| auto &inTensor = graph->allTensors.at(inTensorIdx); | |||
| MS_ASSERT(inTensor != nullptr); | |||
| if (!inTensor->data.empty()) { | |||
| if (inTensor->dataType == TypeId::kNumberTypeInt8) { | |||
| continue; | |||
| } | |||
| if (inTensor->dataType != TypeId::kNumberTypeFloat32 && inTensor->dataType != TypeId::kNumberTypeFloat && | |||
| inTensor->dataType != TypeId::kNumberTypeUInt8) { | |||
| MS_LOG(ERROR) << node->name.c_str() << "'s weight data is not float or uint8"; | |||
| return RET_ERROR; | |||
| } | |||
| auto quantParam = GetTensorQuantParam(inTensor); | |||
| MS_ASSERT(quantParam != nullptr); | |||
| MS_ASSERT(quantParam->inited); | |||
| auto constTensorShapeSize = GetShapeSize(*(inTensor.get())); | |||
| vector<int8_t> qDatas(constTensorShapeSize); | |||
| void *inData = inTensor->data.data(); | |||
| if (inTensor->dataType == TypeId::kNumberTypeFloat || | |||
| inTensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant | |||
| auto *weightData = static_cast<float *>(inData); | |||
| for (size_t j = 0; j < constTensorShapeSize; j++) { | |||
| qDatas[j] = QuantizeData<int8_t>(weightData[j], quantParam.get()); | |||
| ::memcpy(tensor->data.data(), qDatas.data(), wShapeSize); | |||
| } else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) { | |||
| // quant bias data | |||
| auto bShapeSize = GetShapeSize(*(tensor.get())); | |||
| std::unique_ptr<int32_t[]> qDatas(new (std::nothrow) int32_t[bShapeSize]); | |||
| if (qDatas == nullptr) { | |||
| MS_LOG(ERROR) << "new qDatas failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } else { // tflite awareing quant | |||
| auto *weightData = static_cast<uint8_t *>(inData); | |||
| for (size_t j = 0; j < constTensorShapeSize; j++) { | |||
| qDatas[j] = (int32_t)weightData[j] - 128; | |||
| void *biasData = tensor->data.data(); | |||
| auto *rawDatas = static_cast<float *>(biasData); | |||
| for (size_t i = 0; i < bShapeSize; ++i) { | |||
| qDatas[i] = (int32_t)std::round(rawDatas[i] / quantParam->scale); | |||
| } | |||
| tensor->dataType = TypeId::kNumberTypeInt32; | |||
| tensor->data.clear(); | |||
| tensor->data.resize(bShapeSize * sizeof(int32_t)); | |||
| auto ret = | |||
| memcpy_s(tensor->data.data(), bShapeSize * sizeof(int32_t), qDatas.get(), bShapeSize * sizeof(int32_t)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed: " << ret; | |||
| return RET_ERROR; | |||
| } | |||
| quantParam->zeroPoint -= 128; | |||
| inTensor->quantParams.clear(); | |||
| inTensor->quantParams.emplace_back(quantParam.release()); | |||
| } | |||
| ::memcpy(inTensor->data.data(), qDatas.data(), constTensorShapeSize); | |||
| inTensor->dataType = TypeId::kNumberTypeInt8; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { | |||
| MS_ASSERT(subGraph != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| auto &constTensor = subGraph->allTensors.at(node->inputIndex[2]); | |||
| MS_ASSERT(constTensor != nullptr); | |||
| const auto *constData = reinterpret_cast<const float *>(constTensor->data.data()); | |||
| if (!constTensor->data.empty() && | |||
| (constTensor->dataType == TypeId::kNumberTypeFloat || constTensor->dataType == TypeId::kNumberTypeFloat32)) { | |||
| size_t constTensorShapeSize = GetShapeSize(*constTensor); | |||
| std::unique_ptr<QuantParamT> quantParam = GetTensorQuantParam(constTensor); | |||
| if (quantParam == nullptr) { | |||
| MS_LOG(ERROR) << "new QuantParamT failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| vector<int8_t> qDatas(constTensorShapeSize); | |||
| for (size_t j = 0; j < constTensorShapeSize; j++) { | |||
| float rawData = constData[j]; | |||
| qDatas[j] = QuantizeData<int8_t>(rawData, quantParam.get()); | |||
| } | |||
| ::memcpy(constTensor->data.data(), qDatas.data(), constTensorShapeSize); | |||
| constTensor->dataType = TypeId::kNumberTypeInt8; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph, mindspore::schema::CNodeT *node) { | |||
| MS_ASSERT(graph != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| auto inputIndexes = node->inputIndex; | |||
| MS_ASSERT(inputIndexes.size() >= 3); | |||
| MS_ASSERT(graph->allTensors.size() > inputIndexes.at(0)); | |||
| MS_ASSERT(graph->allTensors.size() > inputIndexes.at(1)); | |||
| MS_ASSERT(graph->allTensors.size() > inputIndexes.at(2)); | |||
| auto &biasTensor = graph->allTensors.at(inputIndexes.at(2)); | |||
| MS_ASSERT(biasTensor != nullptr); | |||
| if (biasTensor->dataType == TypeId::kNumberTypeInt32) { | |||
| return RET_OK; | |||
| } | |||
| if (biasTensor->dataType != TypeId::kNumberTypeFloat && biasTensor->dataType != TypeId::kNumberTypeFloat32) { | |||
| MS_LOG(ERROR) << "conv " << node->name << "'s bias data is not float"; | |||
| return RET_ERROR; | |||
| } | |||
| auto &inputTensor = graph->allTensors.at(inputIndexes.at(0)); | |||
| auto &weightTensor = graph->allTensors.at(inputIndexes.at(1)); | |||
| MS_ASSERT(inputTensor != nullptr); | |||
| MS_ASSERT(weightTensor != nullptr); | |||
| auto inputScale = inputTensor->quantParams.front()->scale; | |||
| auto weightScale = weightTensor->quantParams.front()->scale; | |||
| auto scale = inputScale * weightScale; | |||
| // set bias quant param | |||
| std::unique_ptr<QuantParamT> biasQuantParam = GetTensorQuantParam(biasTensor); | |||
| if (biasQuantParam == nullptr) { | |||
| MS_LOG(ERROR) << "new QuantParamT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| biasQuantParam->inited = true; | |||
| biasQuantParam->scale = scale; | |||
| biasQuantParam->zeroPoint = 0; | |||
| biasQuantParam->numBits = 8; | |||
| biasQuantParam->narrowRange = false; | |||
| biasQuantParam->min = 0.0; | |||
| biasQuantParam->max = 0.0; | |||
| // quant bias data | |||
| auto bShapeSize = GetShapeSize(*(biasTensor.get())); | |||
| std::unique_ptr<int32_t[]> qDatas(new (std::nothrow) int32_t[bShapeSize]); | |||
| if (qDatas == nullptr) { | |||
| MS_LOG(ERROR) << "new qDatas failed"; | |||
| return RET_ERROR; | |||
| } | |||
| void *biasData = biasTensor->data.data(); | |||
| auto *rawDatas = static_cast<float *>(biasData); | |||
| for (size_t i = 0; i < bShapeSize; ++i) { | |||
| qDatas[i] = (int32_t)std::round(rawDatas[i] / scale); | |||
| } | |||
| biasTensor->dataType = TypeId::kNumberTypeInt32; | |||
| biasTensor->data.clear(); | |||
| biasTensor->data.resize(bShapeSize * sizeof(int32_t)); | |||
| auto ret = | |||
| memcpy_s(biasTensor->data.data(), bShapeSize * sizeof(int32_t), qDatas.get(), bShapeSize * sizeof(int32_t)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed: " << ret; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { | |||
| MS_ASSERT(subGraph != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(node->quantParam.size() == node->inputIndex.size() + node->outputIndex.size()); | |||
| auto inputIndexes = node->inputIndex; | |||
| MS_ASSERT(inputIndexes.size() >= 2); | |||
| MS_ASSERT(subGraph->allTensors.size() > inputIndexes.at(1)); | |||
| auto &weightTensor = subGraph->allTensors.at(inputIndexes.at(1)); | |||
| if (weightTensor->dataType == TypeId::kNumberTypeInt8) { | |||
| return RET_OK; | |||
| } | |||
| if (weightTensor->dataType != TypeId::kNumberTypeFloat32 && weightTensor->dataType != TypeId::kNumberTypeFloat && | |||
| weightTensor->dataType != TypeId::kNumberTypeUInt8) { | |||
| MS_LOG(ERROR) << "conv " << node->name.c_str() << "'s weight data is not float or uint8"; | |||
| return RET_ERROR; | |||
| } | |||
| size_t wShapeSize = GetShapeSize(*(weightTensor.get())); | |||
| void *oriWeightData = weightTensor->data.data(); | |||
| MS_ASSERT(node->quantParam.at(1)->param.front() != nullptr); | |||
| vector<int8_t> qDatas(wShapeSize); | |||
| // todo support perchannel | |||
| auto weightQauntParam = GetTensorQuantParam(weightTensor); | |||
| if (weightTensor->dataType == TypeId::kNumberTypeFloat || | |||
| weightTensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant | |||
| auto *weightData = static_cast<float *>(oriWeightData); | |||
| for (size_t j = 0; j < wShapeSize; j++) { | |||
| qDatas[j] = QuantizeData<int8_t>(weightData[j], weightQauntParam.get()); | |||
| } | |||
| } else { // tflite awareing quant | |||
| auto *weightData = static_cast<uint8_t *>(oriWeightData); | |||
| for (size_t j = 0; j < wShapeSize; j++) { | |||
| qDatas[j] = (int32_t)weightData[j] - 128; | |||
| } else { // pertensor | |||
| } | |||
| weightQauntParam->zeroPoint -= 128; | |||
| weightTensor->quantParams.clear(); | |||
| weightTensor->quantParams.emplace_back(weightQauntParam.release()); | |||
| } | |||
| weightTensor->data.resize(wShapeSize * sizeof(uint8_t)); | |||
| ::memcpy(weightTensor->data.data(), qDatas.data(), wShapeSize); | |||
| weightTensor->dataType = TypeId::kNumberTypeInt8; | |||
| return RET_OK; | |||
| } | |||
| STATUS AwareQuantizer::DetermineNodeQuantType() { | |||
| @@ -39,24 +39,6 @@ class AwareQuantizer : public FbQuantizer { | |||
| STATUS DetermineNodeQuantType() override; | |||
| STATUS DoQuantize() override; // override; | |||
| private: | |||
| // RemoveFakeQuant | |||
| STATUS SetAttrToConvolution(const schema::MetaGraphT *subGraph, schema::CNodeT *node); | |||
| STATUS GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph); | |||
| STATUS QuantArithmeticConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node); | |||
| STATUS QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node); | |||
| STATUS QuantConvBias(const schema::MetaGraphT *graph, schema::CNodeT *node); | |||
| STATUS QuantConvWeight(const schema::MetaGraphT *subGraph, schema::CNodeT *node); | |||
| float inputScale = 0.0f; | |||
| static const std::array<schema::PrimitiveType, 7> propagatedOps; | |||
| }; | |||
| } // namespace mindspore::lite::quant | |||
| #endif | |||
| @@ -26,6 +26,9 @@ | |||
| #include "tools/converter/quantizer/quantize_util.h" | |||
| namespace mindspore::lite { | |||
| static constexpr size_t BIAS_SIZE = 3; | |||
| static constexpr size_t BIAS_ADD_SIZE = 2; | |||
| STATUS QuantParamCalcer::ComputeConstQuantParam(const schema::TensorT &tensor, QuantParamT *quantParam) { | |||
| MS_ASSERT(quantParam != nullptr); | |||
| // int32 weight no need to quant | |||
| @@ -126,6 +129,36 @@ int CommonCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) { | |||
| return RET_OK; | |||
| } | |||
| int ConvCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) { | |||
| auto status = CommonCalcer::Calc(subGraph, node); | |||
| if (status != RET_OK) { | |||
| MS_LOG(WARNING) << "Call CommonCalcer::Calc failed: " << status; | |||
| return status; | |||
| } | |||
| if (node.inputIndex.size() == BIAS_SIZE) { | |||
| auto &biasTensor = subGraph->allTensors.at(node.inputIndex.at(BIAS_SIZE - 1)); | |||
| for (auto &quantParam : biasTensor->quantParams) { | |||
| quantParam->dstDtype = TypeId::kNumberTypeInt32; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int BiasAddCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) { | |||
| auto status = CommonCalcer::Calc(subGraph, node); | |||
| if (status != RET_OK) { | |||
| MS_LOG(WARNING) << "Call CommonCalcer::Calc failed: " << status; | |||
| return status; | |||
| } | |||
| if (node.inputIndex.size() == BIAS_ADD_SIZE) { | |||
| auto &biasTensor = subGraph->allTensors.at(node.inputIndex.at(BIAS_ADD_SIZE - 1)); | |||
| for (auto &quantParam : biasTensor->quantParams) { | |||
| quantParam->dstDtype = TypeId::kNumberTypeInt32; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { | |||
| auto status = QuantParamCalcer::Calc(graph, node); | |||
| if (status != RET_OK) { | |||
| @@ -474,10 +507,10 @@ QuantParamCalcRegister::QuantParamCalcRegister() { | |||
| _registerMap[schema::PrimitiveType_Activation] = std::make_shared<CalcActivation>(); | |||
| _registerMap[schema::PrimitiveType_Add] = std::make_shared<CalcAdd>(); | |||
| _registerMap[schema::PrimitiveType_Mul] = commonCalcer; | |||
| _registerMap[schema::PrimitiveType_Scale] = commonCalcer; | |||
| _registerMap[schema::PrimitiveType_Conv2D] = commonCalcer; | |||
| _registerMap[schema::PrimitiveType_DeConv2D] = commonCalcer; | |||
| _registerMap[schema::PrimitiveType_DepthwiseConv2D] = commonCalcer; | |||
| _registerMap[schema::PrimitiveType_Scale] = std::make_shared<ConvCalcer>(); | |||
| _registerMap[schema::PrimitiveType_Conv2D] = std::make_shared<ConvCalcer>(); | |||
| _registerMap[schema::PrimitiveType_DeConv2D] = std::make_shared<ConvCalcer>(); | |||
| _registerMap[schema::PrimitiveType_DepthwiseConv2D] = std::make_shared<ConvCalcer>(); | |||
| _registerMap[schema::PrimitiveType_Pooling] = linearCalcer; | |||
| _registerMap[schema::PrimitiveType_Resize] = linearCalcer; | |||
| _registerMap[schema::PrimitiveType_Reshape] = linearCalcer; | |||
| @@ -487,11 +520,11 @@ QuantParamCalcRegister::QuantParamCalcRegister() { | |||
| _registerMap[schema::PrimitiveType_Squeeze] = linearCalcer; | |||
| _registerMap[schema::PrimitiveType_RealDiv] = std::make_shared<CalcRealDiv>(); | |||
| _registerMap[schema::PrimitiveType_Reduce] = commonCalcer; | |||
| _registerMap[schema::PrimitiveType_BiasAdd] = commonCalcer; | |||
| _registerMap[schema::PrimitiveType_BiasAdd] = std::make_shared<BiasAddCalcer>(); | |||
| _registerMap[schema::PrimitiveType_Mean] = linearCalcer; | |||
| _registerMap[schema::PrimitiveType_Transpose] = linearCalcer; | |||
| _registerMap[schema::PrimitiveType_MatMul] = commonCalcer; | |||
| _registerMap[schema::PrimitiveType_FullConnection] = commonCalcer; | |||
| _registerMap[schema::PrimitiveType_MatMul] = std::make_shared<ConvCalcer>(); | |||
| _registerMap[schema::PrimitiveType_FullConnection] = std::make_shared<ConvCalcer>(); | |||
| _registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer; | |||
| _registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer; | |||
| // detection_postprocess op's quant param will not infer only fetch from preNode or postNode | |||
| @@ -46,6 +46,20 @@ class CommonCalcer : public QuantParamCalcer { | |||
| int Calc(schema::MetaGraphT *subGraph, const schema::CNodeT &node) override; | |||
| }; | |||
| class ConvCalcer : public CommonCalcer { | |||
| public: | |||
| ConvCalcer() = default; | |||
| ~ConvCalcer() override = default; | |||
| int Calc(schema::MetaGraphT *subGraph, const schema::CNodeT &node) override; | |||
| }; | |||
| class BiasAddCalcer : public CommonCalcer { | |||
| public: | |||
| BiasAddCalcer() = default; | |||
| ~BiasAddCalcer() override = default; | |||
| int Calc(schema::MetaGraphT *subGraph, const schema::CNodeT &node) override; | |||
| }; | |||
| class LinearCalcer : public QuantParamCalcer { | |||
| public: | |||
| LinearCalcer() = default; | |||
| @@ -564,11 +564,8 @@ PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, in | |||
| } | |||
| } | |||
| STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct MaxMin *max_min, | |||
| std::shared_ptr<PrimitiveC> lite_primitive) { | |||
| if (!lite_primitive->GetInputQuantParams().empty()) { | |||
| MS_LOG(DEBUG) << "input quant params not empty"; // multi-input op: like concat | |||
| } | |||
| STATUS PostTrainingQuantizer::DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, | |||
| std::shared_ptr<PrimitiveC> lite_primitive, const size_t &index) { | |||
| schema::QuantParamT quant_param; | |||
| quant_param.scale = scale; | |||
| quant_param.zeroPoint = zeropoint; | |||
| @@ -577,15 +574,12 @@ STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct M | |||
| quant_param.numBits = bit_num; | |||
| quant_param.narrowRange = false; | |||
| std::vector<schema::QuantParamT> quant_params = {quant_param}; | |||
| lite_primitive->AddInputQuantParam(quant_params); | |||
| lite_primitive->SetInputQuantParam(index, quant_params); | |||
| return RET_OK; | |||
| } | |||
| STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct MaxMin *max_min, | |||
| std::shared_ptr<PrimitiveC> lite_primitive) { | |||
| if (!lite_primitive->GetOutputQuantParams().empty()) { | |||
| MS_LOG(DEBUG) << "output quant params not empty"; // multi-output op: like split | |||
| } | |||
| schema::QuantParamT quant_param; | |||
| quant_param.scale = scale; | |||
| quant_param.zeroPoint = zeropoint; | |||
| @@ -593,8 +587,9 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct | |||
| quant_param.min = max_min->min; | |||
| quant_param.numBits = bit_num; | |||
| quant_param.narrowRange = false; | |||
| quant_param.inited = true; | |||
| std::vector<schema::QuantParamT> quant_params = {quant_param}; | |||
| lite_primitive->AddOutputQuantParam(quant_params); | |||
| lite_primitive->SetOutputQuantParam(0, quant_params); | |||
| return RET_OK; | |||
| } | |||
| @@ -647,7 +642,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi | |||
| auto bias_param = std::dynamic_pointer_cast<ParamValueLite>(bias_default_param); | |||
| auto active_weight_quant_params = primitive_c->GetInputQuantParams(); | |||
| if (active_weight_quant_params.size() != 2) { | |||
| if (active_weight_quant_params.size() != 3) { | |||
| MS_LOG(ERROR) << "unexpected active_weight_quant_params size: " << active_weight_quant_params.size(); | |||
| return RET_ERROR; | |||
| } | |||
| @@ -714,7 +709,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi | |||
| double filter_scale = std::abs(raw_datas[i]) / (activate_scale * quanted_bias_abs_limit); | |||
| active_weight_quant_params[1][i].scale = filter_scale; | |||
| active_weight_quant_params[1][i].zeroPoint = 0; | |||
| primitive_c->SetInputQuantParam(active_weight_quant_params); | |||
| primitive_c->SetInputQuantParams(active_weight_quant_params); | |||
| bias_scale_tmp = std::abs(raw_datas[i]) / quanted_bias_abs_limit; | |||
| quant_params[i].scale = bias_scale_tmp; | |||
| MS_LOG(DEBUG) << "new filter scale: " << filter_scale; | |||
| @@ -726,7 +721,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi | |||
| auto quant_data = (int32_t)std::round(raw_datas[i] / bias_scale_tmp); | |||
| quant_datas[i] = quant_data; | |||
| } | |||
| primitive_c->AddInputQuantParam(quant_params); | |||
| primitive_c->SetInputQuantParam(2, quant_params); | |||
| auto ret = memcpy_s(bias_param->tensor_addr(), bias_param->tensor_size(), quant_datas, shape_size * sizeof(int32_t)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed."; | |||
| @@ -834,22 +829,21 @@ STATUS PostTrainingQuantizer::QuantNode() { | |||
| << " PrimitiveC is null"; | |||
| continue; | |||
| } | |||
| if (!input_cnode_primitive_c->GetOutputQuantParams().empty()) { | |||
| for (auto &quant_param : input_cnode_primitive_c->GetOutputQuantParams()) { | |||
| primitive_c->AddInputQuantParam(quant_param); | |||
| } | |||
| if (input_cnode_primitive_c->IsOutputQuantParamsInited()) { | |||
| auto quant_param = input_cnode_primitive_c->GetOutputQuantParams().front(); | |||
| primitive_c->SetInputQuantParam(i - 1, quant_param); | |||
| } else { | |||
| // do input quant | |||
| double scale = input_scale[cnode]; | |||
| int32_t zp = input_zero_point[cnode]; | |||
| DoQuantInput(scale, zp, &input_min_max[cnode], primitive_c); | |||
| DoQuantInput(scale, zp, &input_min_max[cnode], primitive_c, i - 1); | |||
| } | |||
| } | |||
| } else { | |||
| // do input quant | |||
| double scale = input_scale[cnode]; | |||
| int32_t convInputzeropoint = input_zero_point[cnode]; | |||
| DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitive_c); | |||
| DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitive_c, 0); | |||
| // do weight quant | |||
| auto weight = cnode->input(2); | |||
| bool perchannel = per_channel_; | |||
| @@ -106,7 +106,8 @@ class PostTrainingQuantizer : public Quantizer { | |||
| STATUS QuantNode(); | |||
| STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>); | |||
| STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, | |||
| std::shared_ptr<PrimitiveC> lite_primitive, const size_t &index); | |||
| STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>); | |||
| STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel); | |||
| @@ -246,7 +246,7 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl | |||
| MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other"; | |||
| return RET_ERROR; | |||
| } | |||
| quantParam->inited = true; | |||
| quantParam->inited = false; | |||
| quantParam->min = mMin; | |||
| quantParam->max = mMax; | |||
| quantParam->scale = 0.0f; | |||
| @@ -39,6 +39,7 @@ namespace mindspore { | |||
| namespace lite { | |||
| namespace quant { | |||
| static constexpr size_t UINT8_QUANTIZATION = 8; | |||
| static constexpr size_t WEIGHT_INDEX = 1; | |||
| /** | |||
| * 1. when op's weight size > mWeightSize just skip | |||
| @@ -225,16 +226,16 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti | |||
| } | |||
| variance_dequant = std::sqrt(variance_dequant / one_filter_size); | |||
| variance_raw = std::sqrt(variance_raw / one_filter_size); | |||
| quant_param.var_corr = 1; | |||
| quant_param.varCorr = 1; | |||
| if (variance_raw != 0 && variance_dequant != 0) { | |||
| auto temp_var_corr = variance_raw / variance_dequant; | |||
| if (temp_var_corr > 0 && temp_var_corr < 10) { | |||
| quant_param.var_corr = temp_var_corr; | |||
| quant_param.varCorr = temp_var_corr; | |||
| } else { | |||
| MS_LOG(WARNING) << "unexpected var_corr: " << temp_var_corr; | |||
| } | |||
| } | |||
| quant_param.mean_corr = average_raw - average_dequant * quant_param.var_corr; | |||
| quant_param.meanCorr = average_raw - average_dequant * quant_param.varCorr; | |||
| } | |||
| quant_params.emplace_back(quant_param); | |||
| } | |||
| @@ -282,7 +283,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti | |||
| MS_LOG(ERROR) << "quant_params empty"; | |||
| return RET_ERROR; | |||
| } | |||
| primitive_c->AddInputQuantParam(quant_params); | |||
| primitive_c->SetInputQuantParam(WEIGHT_INDEX, quant_params); | |||
| return RET_OK; | |||
| } | |||
| @@ -101,8 +101,6 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) { | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<schema::QuantParamT> quant_params; | |||
| primitive_c->AddInputQuantParam(quant_params); | |||
| auto status = RET_ERROR; | |||
| if (type_id == kNumberTypeInt8) { | |||
| status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); | |||
| @@ -143,9 +141,9 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) { | |||
| ParameterPtr param_node = nullptr; | |||
| for (size_t i = 1; i < node->size(); i++) { | |||
| auto inputNode = node->input(i); | |||
| if (inputNode->isa<Parameter>() == true) { | |||
| if (inputNode->isa<Parameter>()) { | |||
| param_node = inputNode->cast<ParameterPtr>(); | |||
| if ((param_node != nullptr) && (param_node->has_default() == true)) { | |||
| if ((param_node != nullptr) && param_node->has_default()) { | |||
| param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param()); | |||
| if ((param_value == nullptr) || (param_value->tensor_size() == 0) || | |||
| (param_value->tensor_addr() == nullptr) || | |||
| @@ -169,8 +167,6 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) { | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<schema::QuantParamT> quant_params; | |||
| primitive_c->AddInputQuantParam(quant_params); | |||
| auto status = RET_ERROR; | |||
| if (type_id == kNumberTypeInt8) { | |||
| status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); | |||
| @@ -619,7 +619,7 @@ STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int3 | |||
| } | |||
| return RET_OK; | |||
| } | |||
| template<typename T> | |||
| template <typename T> | |||
| static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, | |||
| int32_t filterH, int32_t filterW) { | |||
| MS_ASSERT(tensor != nullptr); | |||
| @@ -628,7 +628,7 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType | |||
| MS_LOG(ERROR) << "Dim size invalid"; | |||
| return RET_ERROR; | |||
| } | |||
| std::unique_ptr<T[]> buf(new(std::nothrow) T[count]); | |||
| std::unique_ptr<T[]> buf(new (std::nothrow) T[count]); | |||
| if (buf == nullptr) { | |||
| MS_LOG(ERROR) << "new buf failed"; | |||
| return RET_ERROR; | |||
| @@ -653,18 +653,17 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType | |||
| p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k)); | |||
| if (type == kCHWK2HWCK) { | |||
| p2Buff = | |||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| } else if (type == kCHWK2KHWC) { | |||
| p2Buff = | |||
| buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| break; | |||
| } break; | |||
| case kKHWC2HWCK: { | |||
| for (int k = 0; k < filterK; ++k) { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| @@ -677,8 +676,7 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType | |||
| } | |||
| } | |||
| } | |||
| } | |||
| break; | |||
| } break; | |||
| case kKCHW2HWCK: | |||
| case kKCHW2CKHW: | |||
| case kKCHW2KHWC: | |||
| @@ -690,24 +688,23 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType | |||
| p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||
| if (type == kKCHW2HWCK) { | |||
| p2Buff = | |||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| } else if (type == kKCHW2KHWC) { | |||
| p2Buff = | |||
| buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| } else if (type == kKCHW2CKHW) { | |||
| p2Buff = | |||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| } else { | |||
| p2Buff = | |||
| buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); | |||
| buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| break; | |||
| } break; | |||
| case kCKHW2HWCK: | |||
| case kCKHW2KHWC: | |||
| case kCKHW2HWKC: { | |||
| @@ -718,21 +715,20 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType | |||
| p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| if (type == kCKHW2HWCK) { | |||
| p2Buff = | |||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| } else if (type == kCKHW2KHWC) { | |||
| p2Buff = | |||
| buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| } else { | |||
| p2Buff = | |||
| buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); | |||
| buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| break; | |||
| } break; | |||
| case kHWCK2KCHW: | |||
| case kHWCK2CKHW: { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| @@ -742,18 +738,17 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType | |||
| p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| if (type == kHWCK2KCHW) { | |||
| p2Buff = | |||
| buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||
| buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||
| } else { | |||
| p2Buff = | |||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| break; | |||
| } break; | |||
| case kHWKC2KCHW: | |||
| case kHWKC2CKHW: { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| @@ -763,18 +758,17 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType | |||
| p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); | |||
| if (type == kHWKC2KCHW) { | |||
| p2Buff = | |||
| buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||
| buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||
| } else { | |||
| p2Buff = | |||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| break; | |||
| } break; | |||
| case kNHWC2HWCK: | |||
| case kNHWC2KCHW: | |||
| case kNHWC2CKHW: { | |||
| @@ -785,21 +779,20 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType | |||
| p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); | |||
| if (type == kNHWC2HWCK) { | |||
| p2Buff = | |||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| } else if (type == kNHWC2CKHW) { | |||
| p2Buff = | |||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| } else { | |||
| p2Buff = | |||
| buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||
| buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| break; | |||
| } break; | |||
| case kKHWC2CHWK: { | |||
| for (int k = 0; k < filterK; ++k) { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| @@ -812,8 +805,7 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType | |||
| } | |||
| } | |||
| } | |||
| } | |||
| break; | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Unsupported transFilterType: " << type; | |||
| return RET_ERROR; | |||
| @@ -828,7 +820,7 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType | |||
| return RET_OK; | |||
| } | |||
| template<typename T> | |||
| template <typename T> | |||
| static STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type) { | |||
| MS_ASSERT(tensor != nullptr); | |||
| auto oriDims = tensor->tensor_shape(); | |||
| @@ -882,6 +874,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||
| status = TransFilterFormat<uint8_t>(tensor, kKCHW2KHWC); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kKCHW2KHWC); | |||
| } else if (data_type == kNumberTypeFloat16) { | |||
| status = TransFilterFormat<float16>(tensor, kKCHW2KHWC); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| @@ -894,6 +888,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||
| status = TransFilterFormat<uint8_t>(tensor, kCKHW2KHWC); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kCKHW2KHWC); | |||
| } else if (data_type == kNumberTypeFloat16) { | |||
| status = TransFilterFormat<float16>(tensor, kCKHW2KHWC); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| @@ -906,18 +902,20 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||
| status = TransFilterFormat<uint8_t>(tensor, kCHWK2KHWC); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kCHWK2KHWC); | |||
| } else if (data_type == kNumberTypeFloat16) { | |||
| status = TransFilterFormat<float16>(tensor, kCHWK2KHWC); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format::Format_KHWC:return RET_OK; | |||
| default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " | |||
| << EnumNameFormat(dst_format); | |||
| case schema::Format::Format_KHWC: | |||
| return RET_OK; | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format); | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| break; | |||
| } break; | |||
| case schema::Format::Format_HWCK: { | |||
| switch (src_format) { | |||
| case schema::Format::Format_KCHW: | |||
| @@ -927,6 +925,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||
| status = TransFilterFormat<uint8_t>(tensor, kKCHW2HWCK); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kKCHW2HWCK); | |||
| } else if (data_type == kNumberTypeFloat16) { | |||
| status = TransFilterFormat<float16>(tensor, kKCHW2HWCK); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| @@ -939,6 +939,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||
| status = TransFilterFormat<uint8_t>(tensor, kKHWC2HWCK); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kKHWC2HWCK); | |||
| } else if (data_type == kNumberTypeFloat16) { | |||
| status = TransFilterFormat<float16>(tensor, kKHWC2HWCK); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| @@ -951,6 +953,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||
| status = TransFilterFormat<uint8_t>(tensor, kCKHW2HWCK); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kCKHW2HWCK); | |||
| } else if (data_type == kNumberTypeFloat16) { | |||
| status = TransFilterFormat<float16>(tensor, kCKHW2HWCK); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| @@ -963,21 +967,24 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||
| status = TransFilterFormat<uint8_t>(tensor, kCHWK2HWCK); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kCHWK2HWCK); | |||
| } else if (data_type == kNumberTypeFloat16) { | |||
| status = TransFilterFormat<float16>(tensor, kCHWK2HWCK); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return lite::RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format::Format_HWCK:return RET_OK; | |||
| default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " | |||
| << EnumNameFormat(dst_format); | |||
| case schema::Format::Format_HWCK: | |||
| return RET_OK; | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format); | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| break; | |||
| } break; | |||
| case schema::Format::Format_KCHW: { | |||
| switch (src_format) { | |||
| case schema::Format::Format_KCHW:return RET_OK; | |||
| case schema::Format::Format_KCHW: | |||
| return RET_OK; | |||
| case schema::Format::Format_HWCK: | |||
| if (data_type == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kHWCK2KCHW); | |||
| @@ -985,6 +992,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||
| status = TransFilterFormat<uint8_t>(tensor, kHWCK2KCHW); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kHWCK2KCHW); | |||
| } else if (data_type == kNumberTypeFloat16) { | |||
| status = TransFilterFormat<float16>(tensor, kHWCK2KCHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| @@ -997,6 +1006,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||
| status = TransFilterFormat<uint8_t>(tensor, kHWKC2KCHW); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kHWKC2KCHW); | |||
| } else if (data_type == kNumberTypeFloat16) { | |||
| status = TransFilterFormat<float16>(tensor, kHWCK2KCHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| @@ -1009,6 +1020,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||
| status = TransFilterFormat<uint8_t>(tensor, kKHWC2KCHW); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kKHWC2KCHW); | |||
| } else if (data_type == kNumberTypeFloat16) { | |||
| status = TransFilterFormat<float16>(tensor, kKHWC2KCHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| @@ -1021,6 +1034,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||
| status = TransFilterFormat<uint8_t>(tensor, kCKHW2KCHW); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kCKHW2KCHW); | |||
| } else if (data_type == kNumberTypeFloat16) { | |||
| status = TransFilterFormat<float16>(tensor, kCKHW2KCHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| @@ -1033,17 +1048,18 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||
| status = TransFilterFormat<uint8_t>(tensor, kCHWK2KCHW); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kCHWK2KCHW); | |||
| } else if (data_type == kNumberTypeFloat16) { | |||
| status = TransFilterFormat<float16>(tensor, kCKHW2KCHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " | |||
| << EnumNameFormat(dst_format); | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format); | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| break; | |||
| } break; | |||
| case schema::Format::Format_CKHW: { | |||
| switch (src_format) { | |||
| case schema::Format::Format_HWCK: | |||
| @@ -1053,6 +1069,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||
| status = TransFilterFormat<uint8_t>(tensor, kHWCK2CKHW); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kHWCK2CKHW); | |||
| } else if (data_type == kNumberTypeFloat16) { | |||
| status = TransFilterFormat<float16>(tensor, kHWCK2CKHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| @@ -1065,6 +1083,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||
| status = TransFilterFormat<uint8_t>(tensor, kHWKC2CKHW); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kHWKC2CKHW); | |||
| } else if (data_type == kNumberTypeFloat16) { | |||
| status = TransFilterFormat<float16>(tensor, kHWKC2CKHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| @@ -1077,20 +1097,22 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||
| status = TransFilterFormat<uint8_t>(tensor, kKCHW2CKHW); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kKCHW2CKHW); | |||
| } else if (data_type == kNumberTypeFloat16) { | |||
| status = TransFilterFormat<float16>(tensor, kKCHW2CKHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format::Format_CKHW:return RET_OK; | |||
| default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " | |||
| << EnumNameFormat(dst_format); | |||
| case schema::Format::Format_CKHW: | |||
| return RET_OK; | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format); | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| break; | |||
| default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " | |||
| << EnumNameFormat(dst_format); | |||
| } break; | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format); | |||
| return RET_ERROR; | |||
| } | |||
| if (status != RET_OK) { | |||
| @@ -155,8 +155,8 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons | |||
| rmatmul_quant_params.pop_back(); | |||
| // no bias quantParams | |||
| rmatmul_quant_params.emplace_back(jointed_quant_params); | |||
| matmul_cvalue->SetInputQuantParam(rmatmul_quant_params); | |||
| matmul_cvalue->SetOutputQuantParam(fc_prim->GetOutputQuantParams()); | |||
| matmul_cvalue->SetInputQuantParams(rmatmul_quant_params); | |||
| matmul_cvalue->SetOutputQuantParams(fc_prim->GetOutputQuantParams()); | |||
| auto matmul_value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(matmul_cvalue)); | |||
| std::vector<AnfNodePtr> matmul_inputs = {matmul_value_node, left_matmul_input}; | |||