| @@ -32,8 +32,9 @@ table QuantParam { | |||||
| narrowRange: bool = true; | narrowRange: bool = true; | ||||
| numBits: int = 8; | numBits: int = 8; | ||||
| inited: bool = false; | inited: bool = false; | ||||
| var_corr: double = 1; | |||||
| mean_corr: double = 0; | |||||
| varCorr: double = 1; | |||||
| meanCorr: double = 0; | |||||
| dstDtype: int = 32; | |||||
| clusters: [float]; | clusters: [float]; | ||||
| } | } | ||||
| @@ -27,6 +27,9 @@ | |||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #include "tools/common/option.h" | #include "tools/common/option.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #ifdef ENABLE_ARM64 | |||||
| #include "nnacl/optimized_kernel.h" | |||||
| #endif | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -186,6 +189,38 @@ inline Option<bool> GenericParseValue(const std::string &value) { | |||||
| return Option<bool>(None()); | return Option<bool>(None()); | ||||
| } | } | ||||
| using Float16CastFunc = void (*)(const void *, void *, int); | |||||
| class Float16CastUtil { | |||||
| public: | |||||
| static Float16CastUtil *GetInstance() { | |||||
| static Float16CastUtil float16_cast_util; | |||||
| return &float16_cast_util; | |||||
| } | |||||
| private: | |||||
| Float16CastUtil() { | |||||
| #ifdef ENABLE_ARM64 | |||||
| void *fp16_op_handler = Float16Module::GetInstance()->float16_op_handler_; | |||||
| if (fp16_op_handler != nullptr) { | |||||
| dlerror(); | |||||
| *(reinterpret_cast<void **>(&float16_to_float32_func_)) = dlsym(fp16_op_handler, "Float16ToFloat32_fp16_handler"); | |||||
| *(reinterpret_cast<void **>(&float32_to_float16_func_)) = dlsym(fp16_op_handler, "Float32ToFloat16_fp16_handler"); | |||||
| auto dlopen_error = dlerror(); | |||||
| if (dlopen_error != nullptr) { | |||||
| MS_LOG(ERROR) << "load float16 cast func failed! " << dlopen_error << "."; | |||||
| } | |||||
| } | |||||
| #endif | |||||
| } | |||||
| ~Float16CastUtil() = default; | |||||
| public: | |||||
| Float16CastFunc float16_to_float32_func_ = nullptr; | |||||
| Float16CastFunc float32_to_float16_func_ = nullptr; | |||||
| }; | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -108,8 +108,9 @@ int LiteSession::ConvertTensors(const lite::Model *model) { | |||||
| QuantArg quant_arg{}; | QuantArg quant_arg{}; | ||||
| quant_arg.scale = quant_params->Get(j)->scale(); | quant_arg.scale = quant_params->Get(j)->scale(); | ||||
| quant_arg.zeroPoint = quant_params->Get(j)->zeroPoint(); | quant_arg.zeroPoint = quant_params->Get(j)->zeroPoint(); | ||||
| quant_arg.var_corr = quant_params->Get(j)->var_corr(); | |||||
| quant_arg.mean_corr = quant_params->Get(j)->mean_corr(); | |||||
| quant_arg.var_corr = quant_params->Get(j)->varCorr(); | |||||
| quant_arg.mean_corr = quant_params->Get(j)->meanCorr(); | |||||
| quant_arg.inited = quant_params->Get(j)->inited(); | |||||
| auto quant_clusters = quant_params->Get(j)->clusters(); | auto quant_clusters = quant_params->Get(j)->clusters(); | ||||
| if (quant_clusters != nullptr) { | if (quant_clusters != nullptr) { | ||||
| for (size_t k = 0; k < quant_clusters->size(); k++) { | for (size_t k = 0; k < quant_clusters->size(); k++) { | ||||
| @@ -49,12 +49,7 @@ int Add::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| if (GetQuantType() == schema::QuantType_AwareTraining) { | |||||
| std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam; | |||||
| std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam; | |||||
| PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); | |||||
| SetOutputQuantParam(vecOutputQuantParam); | |||||
| } | |||||
| PopulaterQuantParam(prim, inputs); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -277,13 +277,7 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp | |||||
| PopulaterConv2DSingleGroup(prim, this->primitive_, group); | PopulaterConv2DSingleGroup(prim, this->primitive_, group); | ||||
| } | } | ||||
| if (GetQuantType() == schema::QuantType_AwareTraining) { | |||||
| std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam; | |||||
| std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam; | |||||
| PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); | |||||
| SetInputQuantParam(vecInputQuantParam); | |||||
| SetOutputQuantParam(vecOutputQuantParam); | |||||
| } | |||||
| PopulaterQuantParam(prim, inputs); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -254,14 +254,7 @@ int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i | |||||
| } else if (group > 1) { | } else if (group > 1) { | ||||
| PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs); | PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs); | ||||
| } | } | ||||
| if (GetQuantType() == schema::QuantType_AwareTraining) { | |||||
| std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam; | |||||
| std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam; | |||||
| PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); | |||||
| SetInputQuantParam(vecInputQuantParam); | |||||
| SetOutputQuantParam(vecOutputQuantParam); | |||||
| } | |||||
| PopulaterQuantParam(prim, inputs); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| #else | #else | ||||
| @@ -146,14 +146,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode | |||||
| this->primitive_->value.type = schema::PrimitiveType_DepthwiseConv2D; | this->primitive_->value.type = schema::PrimitiveType_DepthwiseConv2D; | ||||
| this->primitive_->value.value = attr.release(); | this->primitive_->value.value = attr.release(); | ||||
| if (GetQuantType() == schema::QuantType_AwareTraining) { | |||||
| std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam; | |||||
| std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam; | |||||
| PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); | |||||
| SetInputQuantParam(vecInputQuantParam); | |||||
| SetOutputQuantParam(vecOutputQuantParam); | |||||
| } | |||||
| PopulaterQuantParam(prim, inputs); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -61,13 +61,8 @@ int MatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| if (GetQuantType() == schema::QuantType_AwareTraining) { | |||||
| std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam; | |||||
| std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam; | |||||
| PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); | |||||
| SetInputQuantParam(vecInputQuantParam); | |||||
| SetOutputQuantParam(vecOutputQuantParam); | |||||
| } | |||||
| PopulaterQuantParam(prim, inputs); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -164,32 +164,29 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| void PrimitiveC::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) { | |||||
| void PrimitiveC::CalFloatScopeByMeanAndStddev(const double &mean, const double &stdDev, float *mMin, float *mMax) { | |||||
| const float qmin = 0; | const float qmin = 0; | ||||
| const float qmax = 255; | const float qmax = 255; | ||||
| *mMin = static_cast<float>((qmin - mean) / stdDev); | *mMin = static_cast<float>((qmin - mean) / stdDev); | ||||
| *mMax = static_cast<float>((qmax - mean) / stdDev); | *mMax = static_cast<float>((qmax - mean) / stdDev); | ||||
| } | } | ||||
| void PrimitiveC::PopulaterQuantParam(const Primitive &prim, | |||||
| std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam, | |||||
| std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| void PrimitiveC::PopulaterQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| auto narrow_range = prim.GetAttr("narrow_range"); | auto narrow_range = prim.GetAttr("narrow_range"); | ||||
| bool narrowRangeQuantParam = GetValue<bool>(narrow_range); | |||||
| bool narrowRangeQuantParam = narrow_range != nullptr ? GetValue<bool>(narrow_range) : false; | |||||
| auto num_bits = prim.GetAttr("num_bits"); | auto num_bits = prim.GetAttr("num_bits"); | ||||
| int32_t numbitsRangeQuantParam = GetValue<int32_t>(num_bits); | |||||
| int32_t numbitsRangeQuantParam = num_bits != nullptr ? GetValue<int32_t>(num_bits) : 8; | |||||
| std::vector<schema::QuantParamT> quants; | std::vector<schema::QuantParamT> quants; | ||||
| schema::QuantParamT quantParam; | schema::QuantParamT quantParam; | ||||
| auto mean = prim.GetAttr("mean"); | auto mean = prim.GetAttr("mean"); | ||||
| auto std_dev = prim.GetAttr("std_dev"); | auto std_dev = prim.GetAttr("std_dev"); | ||||
| if (mean != nullptr && std_dev != nullptr) { | if (mean != nullptr && std_dev != nullptr) { | ||||
| auto meanQuantOaram = GetValue<double>(mean); | |||||
| double stddevQuantOaram = GetValue<double>(std_dev); | |||||
| auto meanValue = GetValue<double>(mean); | |||||
| auto stddevValue = GetValue<double>(std_dev); | |||||
| float mMin = 0.0; | float mMin = 0.0; | ||||
| float mMax = 0.0; | float mMax = 0.0; | ||||
| CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax); | |||||
| CalFloatScopeByMeanAndStddev(meanValue, stddevValue, &mMin, &mMax); | |||||
| quantParam.min = mMin; | quantParam.min = mMin; | ||||
| quantParam.max = mMax; | quantParam.max = mMax; | ||||
| } else { | } else { | ||||
| @@ -198,8 +195,8 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim, | |||||
| if (inputMin != nullptr && inputMax != nullptr) { | if (inputMin != nullptr && inputMax != nullptr) { | ||||
| auto inputMinPtr = inputMin->cast<TensorPtr>(); | auto inputMinPtr = inputMin->cast<TensorPtr>(); | ||||
| auto inputMaxPtr = inputMax->cast<TensorPtr>(); | auto inputMaxPtr = inputMax->cast<TensorPtr>(); | ||||
| float *minBuf = static_cast<float *>(inputMinPtr->data_c()); | |||||
| float *maxBuf = static_cast<float *>(inputMaxPtr->data_c()); | |||||
| auto *minBuf = static_cast<float *>(inputMinPtr->data_c()); | |||||
| auto *maxBuf = static_cast<float *>(inputMaxPtr->data_c()); | |||||
| quantParam.min = *minBuf; | quantParam.min = *minBuf; | ||||
| quantParam.max = *maxBuf; | quantParam.max = *maxBuf; | ||||
| } | } | ||||
| @@ -207,7 +204,7 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim, | |||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, | quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, | ||||
| numbitsRangeQuantParam); | numbitsRangeQuantParam); | ||||
| quants.emplace_back(quantParam); | quants.emplace_back(quantParam); | ||||
| vecInputQuantParam->emplace_back(quants); | |||||
| input_quant_param_.emplace_back(quants); | |||||
| quants.clear(); | quants.clear(); | ||||
| auto filterMin = prim.GetAttr("filter_minq"); | auto filterMin = prim.GetAttr("filter_minq"); | ||||
| @@ -227,17 +224,25 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim, | |||||
| } | } | ||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, true, numbitsRangeQuantParam); | quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, true, numbitsRangeQuantParam); | ||||
| quants.emplace_back(quantParam); | quants.emplace_back(quantParam); | ||||
| vecInputQuantParam->emplace_back(quants); | |||||
| input_quant_param_.emplace_back(quants); | |||||
| } | } | ||||
| if (vecInputQuantParam->size() == kDoubleNum) { | |||||
| if (input_quant_param_.size() == kDoubleNum) { | |||||
| quants.clear(); | quants.clear(); | ||||
| quantParam.min = 0.0; | quantParam.min = 0.0; | ||||
| quantParam.max = 0.0; | quantParam.max = 0.0; | ||||
| quantParam.zeroPoint = 0; | quantParam.zeroPoint = 0; | ||||
| quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(0).scale; | |||||
| quantParam.scale = input_quant_param_.at(0).at(0).scale * input_quant_param_.at(1).at(0).scale; | |||||
| quants.emplace_back(quantParam); | quants.emplace_back(quantParam); | ||||
| vecInputQuantParam->emplace_back(quants); | |||||
| input_quant_param_.emplace_back(quants); | |||||
| } | |||||
| // fill input_quant_param_ by not inited quant_parm | |||||
| if (input_quant_param_.size() < inputs.size()) { | |||||
| quants.clear(); | |||||
| schema::QuantParamT tmpQuantParam; | |||||
| quants.emplace_back(tmpQuantParam); | |||||
| input_quant_param_.insert(input_quant_param_.end(), inputs.size() - 1 - input_quant_param_.size(), quants); | |||||
| } | } | ||||
| quants.clear(); | quants.clear(); | ||||
| @@ -253,7 +258,11 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim, | |||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, | quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, | ||||
| numbitsRangeQuantParam); | numbitsRangeQuantParam); | ||||
| quants.emplace_back(quantParam); | quants.emplace_back(quantParam); | ||||
| vecOutputQuantParam->emplace_back(quants); | |||||
| output_quant_param_.emplace_back(quants); | |||||
| } else { | |||||
| schema::QuantParamT tmpQuantParam; | |||||
| quants.emplace_back(tmpQuantParam); | |||||
| output_quant_param_.emplace_back(quants); | |||||
| } | } | ||||
| } | } | ||||
| @@ -279,14 +288,48 @@ schema::PrimitiveT *PrimitiveC::GetPrimitiveT() const { return this->primitive_; | |||||
| void PrimitiveC::ClearPrimitiveT() { this->primitive_ = nullptr; } | void PrimitiveC::ClearPrimitiveT() { this->primitive_ = nullptr; } | ||||
| void PrimitiveC::SetInputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param) { | |||||
| void PrimitiveC::SetInputQuantParams(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param) { | |||||
| this->input_quant_param_ = input_quant_param; | this->input_quant_param_ = input_quant_param; | ||||
| } | } | ||||
| void PrimitiveC::SetOutputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param) { | |||||
| void PrimitiveC::SetInputQuantParam(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param) { | |||||
| MS_ASSERT(index < this->input_quant_param_.size()); | |||||
| this->input_quant_param_[index] = input_quant_param; | |||||
| } | |||||
| void PrimitiveC::SetOutputQuantParams(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param) { | |||||
| this->output_quant_param_ = output_quant_param; | this->output_quant_param_ = output_quant_param; | ||||
| } | } | ||||
| void PrimitiveC::SetOutputQuantParam(const size_t &index, const std::vector<schema::QuantParamT> &output_quant_param) { | |||||
| MS_ASSERT(index < this->output_quant_param_.size()); | |||||
| this->output_quant_param_[index] = output_quant_param; | |||||
| } | |||||
| bool PrimitiveC::IsInputQuantParamsInited() { | |||||
| if (this->input_quant_param_.empty()) { | |||||
| return false; | |||||
| } | |||||
| for (auto &quant_param : this->input_quant_param_) { | |||||
| if (!quant_param.front().inited) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool PrimitiveC::IsOutputQuantParamsInited() { | |||||
| if (this->output_quant_param_.empty()) { | |||||
| return false; | |||||
| } | |||||
| for (auto &quant_param : this->output_quant_param_) { | |||||
| if (!quant_param.front().inited) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void PrimitiveC::ClearInputOutputQuantParam() { | void PrimitiveC::ClearInputOutputQuantParam() { | ||||
| input_quant_param_.clear(); | input_quant_param_.clear(); | ||||
| output_quant_param_.clear(); | output_quant_param_.clear(); | ||||
| @@ -88,9 +88,17 @@ class PrimitiveC : public mindspore::Primitive { | |||||
| } | } | ||||
| } | } | ||||
| void SetInputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param); | |||||
| void SetInputQuantParams(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param); | |||||
| void SetOutputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param); | |||||
| void SetInputQuantParam(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param); | |||||
| void SetOutputQuantParams(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param); | |||||
| void SetOutputQuantParam(const size_t &index, const std::vector<schema::QuantParamT> &output_quant_param); | |||||
| bool IsInputQuantParamsInited(); | |||||
| bool IsOutputQuantParamsInited(); | |||||
| void ClearInputOutputQuantParam(); | void ClearInputOutputQuantParam(); | ||||
| @@ -120,10 +128,8 @@ class PrimitiveC : public mindspore::Primitive { | |||||
| static std::shared_ptr<PrimitiveC> Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs, | static std::shared_ptr<PrimitiveC> Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs, | ||||
| const schema::QuantType &quantType); | const schema::QuantType &quantType); | ||||
| void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam, | |||||
| std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam, | |||||
| const std::vector<AnfNodePtr> &inputs); | |||||
| void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); | |||||
| void PopulaterQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs); | |||||
| void CalFloatScopeByMeanAndStddev(const double &mean, const double &stdDev, float *mMin, float *mMax); | |||||
| protected: | protected: | ||||
| virtual int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { return RET_ERROR; } | virtual int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { return RET_ERROR; } | ||||
| @@ -42,7 +42,9 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T | |||||
| auto *weight_tensor = inputs.at(kWeightIndex); | auto *weight_tensor = inputs.at(kWeightIndex); | ||||
| // data of second tensor of fc may be nullptr | // data of second tensor of fc may be nullptr | ||||
| auto *restore_data = weight_tensor->data_c(); | auto *restore_data = weight_tensor->data_c(); | ||||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||||
| bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && | |||||
| restore_data != nullptr; | |||||
| if (dequant_flag) { | |||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | ||||
| if (dequant_weight == nullptr) { | if (dequant_weight == nullptr) { | ||||
| MS_LOG(ERROR) << "dequant data is nullptr."; | MS_LOG(ERROR) << "dequant data is nullptr."; | ||||
| @@ -53,7 +55,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T | |||||
| auto kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx, primitive); | auto kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx, primitive); | ||||
| if (!kernel) { | if (!kernel) { | ||||
| MS_LOG(ERROR) << "kernel is nullptr."; | MS_LOG(ERROR) << "kernel is nullptr."; | ||||
| if (!weight_tensor->GetQuantParams().empty()) { | |||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | weight_tensor->FreeData(); | ||||
| weight_tensor->SetData(restore_data); | weight_tensor->SetData(restore_data); | ||||
| } | } | ||||
| @@ -65,13 +67,13 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T | |||||
| delete kernel; | delete kernel; | ||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | ||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | ||||
| if (!weight_tensor->GetQuantParams().empty()) { | |||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | weight_tensor->FreeData(); | ||||
| weight_tensor->SetData(restore_data); | weight_tensor->SetData(restore_data); | ||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | weight_tensor->FreeData(); | ||||
| weight_tensor->SetData(restore_data); | weight_tensor->SetData(restore_data); | ||||
| } | } | ||||
| @@ -98,8 +98,9 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) { | |||||
| MS_LOG(ERROR) << "QuantDTypeCast need quantization parameters which is not found."; | MS_LOG(ERROR) << "QuantDTypeCast need quantization parameters which is not found."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto quant_arg = !out_tensors_.front()->GetQuantParams().empty() ? out_tensors_.front()->GetQuantParams().front() | |||||
| : in_tensors_.front()->GetQuantParams().front(); | |||||
| auto quant_arg = out_tensors_.front()->GetQuantParams().front().inited | |||||
| ? out_tensors_.front()->GetQuantParams().front() | |||||
| : in_tensors_.front()->GetQuantParams().front(); | |||||
| int ret = RET_OK; | int ret = RET_OK; | ||||
| if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeFloat32) { | if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeFloat32) { | ||||
| ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale, | ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale, | ||||
| @@ -140,8 +140,8 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *> | |||||
| auto *weight_tensor = inputs.at(kWeightIndex); | auto *weight_tensor = inputs.at(kWeightIndex); | ||||
| auto *restore_data = weight_tensor->MutableData(); | auto *restore_data = weight_tensor->MutableData(); | ||||
| auto dequant_flag = | |||||
| (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false; | |||||
| bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && | |||||
| restore_data != nullptr; | |||||
| if (dequant_flag) { | if (dequant_flag) { | ||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | ||||
| if (dequant_weight == nullptr) { | if (dequant_weight == nullptr) { | ||||
| @@ -182,8 +182,8 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> & | |||||
| auto *weight_tensor = inputs.at(kWeightIndex); | auto *weight_tensor = inputs.at(kWeightIndex); | ||||
| auto *restore_data = weight_tensor->MutableData(); | auto *restore_data = weight_tensor->MutableData(); | ||||
| auto dequant_flag = | |||||
| (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false; | |||||
| bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && | |||||
| restore_data != nullptr; | |||||
| if (dequant_flag) { | if (dequant_flag) { | ||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | ||||
| if (dequant_weight == nullptr) { | if (dequant_weight == nullptr) { | ||||
| @@ -204,8 +204,8 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor | |||||
| auto *weight_tensor = inputs.at(kWeightIndex); | auto *weight_tensor = inputs.at(kWeightIndex); | ||||
| auto *restore_data = weight_tensor->MutableData(); | auto *restore_data = weight_tensor->MutableData(); | ||||
| auto dequant_flag = | |||||
| (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false; | |||||
| auto dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && | |||||
| restore_data != nullptr; | |||||
| if (dequant_flag) { | if (dequant_flag) { | ||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | ||||
| if (dequant_weight == nullptr) { | if (dequant_weight == nullptr) { | ||||
| @@ -216,8 +216,8 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *> | |||||
| auto *weight_tensor = inputs.at(kWeightIndex); | auto *weight_tensor = inputs.at(kWeightIndex); | ||||
| auto *restore_data = weight_tensor->MutableData(); | auto *restore_data = weight_tensor->MutableData(); | ||||
| auto dequant_flag = | |||||
| (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false; | |||||
| auto dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && | |||||
| restore_data != nullptr; | |||||
| if (dequant_flag) { | if (dequant_flag) { | ||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | ||||
| if (dequant_weight == nullptr) { | if (dequant_weight == nullptr) { | ||||
| @@ -237,7 +237,9 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T | |||||
| auto *weight_tensor = inputs.at(kWeightIndex); | auto *weight_tensor = inputs.at(kWeightIndex); | ||||
| // data of second tensor of fc may be nullptr | // data of second tensor of fc may be nullptr | ||||
| auto *restore_data = weight_tensor->data_c(); | auto *restore_data = weight_tensor->data_c(); | ||||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||||
| bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && | |||||
| restore_data != nullptr; | |||||
| if (dequant_flag) { | |||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | ||||
| if (dequant_weight == nullptr) { | if (dequant_weight == nullptr) { | ||||
| MS_LOG(ERROR) << "dequant data is nullptr."; | MS_LOG(ERROR) << "dequant data is nullptr."; | ||||
| @@ -250,7 +252,7 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T | |||||
| auto *kernel = new (std::nothrow) FullconnectionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | auto *kernel = new (std::nothrow) FullconnectionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | ||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "kernel is nullptr."; | MS_LOG(ERROR) << "kernel is nullptr."; | ||||
| if (!weight_tensor->GetQuantParams().empty()) { | |||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | weight_tensor->FreeData(); | ||||
| weight_tensor->SetData(restore_data); | weight_tensor->SetData(restore_data); | ||||
| } | } | ||||
| @@ -262,13 +264,13 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | ||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | ||||
| delete kernel; | delete kernel; | ||||
| if (!weight_tensor->GetQuantParams().empty()) { | |||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | weight_tensor->FreeData(); | ||||
| weight_tensor->SetData(restore_data); | weight_tensor->SetData(restore_data); | ||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (!weight_tensor->GetQuantParams().empty()) { | |||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | weight_tensor->FreeData(); | ||||
| weight_tensor->SetData(restore_data); | weight_tensor->SetData(restore_data); | ||||
| } | } | ||||
| @@ -251,7 +251,9 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *> | |||||
| const mindspore::lite::PrimitiveC *primitive) { | const mindspore::lite::PrimitiveC *primitive) { | ||||
| auto *weight_tensor = inputs.at(kWeightIndex); | auto *weight_tensor = inputs.at(kWeightIndex); | ||||
| auto *restore_data = weight_tensor->data_c(); | auto *restore_data = weight_tensor->data_c(); | ||||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||||
| bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && | |||||
| restore_data != nullptr; | |||||
| if (dequant_flag) { | |||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | ||||
| if (dequant_weight == nullptr) { | if (dequant_weight == nullptr) { | ||||
| MS_LOG(ERROR) << "dequant data is nullptr."; | MS_LOG(ERROR) << "dequant data is nullptr."; | ||||
| @@ -263,7 +265,7 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *> | |||||
| auto *kernel = new (std::nothrow) MatmulFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | auto *kernel = new (std::nothrow) MatmulFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | ||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "kernel is nullptr."; | MS_LOG(ERROR) << "kernel is nullptr."; | ||||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | weight_tensor->FreeData(); | ||||
| weight_tensor->SetData(restore_data); | weight_tensor->SetData(restore_data); | ||||
| } | } | ||||
| @@ -275,13 +277,13 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *> | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | ||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | ||||
| delete kernel; | delete kernel; | ||||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | weight_tensor->FreeData(); | ||||
| weight_tensor->SetData(restore_data); | weight_tensor->SetData(restore_data); | ||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | weight_tensor->FreeData(); | ||||
| weight_tensor->SetData(restore_data); | weight_tensor->SetData(restore_data); | ||||
| } | } | ||||
| @@ -284,7 +284,9 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> & | |||||
| auto *weight_tensor = inputs.at(kWeightIndex); | auto *weight_tensor = inputs.at(kWeightIndex); | ||||
| auto *restore_data = weight_tensor->MutableData(); | auto *restore_data = weight_tensor->MutableData(); | ||||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||||
| bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && | |||||
| restore_data != nullptr; | |||||
| if (dequant_flag) { | |||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | ||||
| if (dequant_weight == nullptr) { | if (dequant_weight == nullptr) { | ||||
| MS_LOG(ERROR) << "dequant data is nullptr."; | MS_LOG(ERROR) << "dequant data is nullptr."; | ||||
| @@ -303,7 +305,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> & | |||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "kernel is nullptr."; | MS_LOG(ERROR) << "kernel is nullptr."; | ||||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | weight_tensor->FreeData(); | ||||
| weight_tensor->SetData(restore_data); | weight_tensor->SetData(restore_data); | ||||
| } | } | ||||
| @@ -315,14 +317,14 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> & | |||||
| delete kernel; | delete kernel; | ||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " | MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " | ||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_)); | << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_)); | ||||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | weight_tensor->FreeData(); | ||||
| weight_tensor->SetData(restore_data); | weight_tensor->SetData(restore_data); | ||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | weight_tensor->FreeData(); | ||||
| weight_tensor->SetData(restore_data); | weight_tensor->SetData(restore_data); | ||||
| } | } | ||||
| @@ -124,6 +124,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *> | |||||
| auto *weight_tensor = inputs.at(kWeightIndex); | auto *weight_tensor = inputs.at(kWeightIndex); | ||||
| auto *restore_data = weight_tensor->MutableData(); | auto *restore_data = weight_tensor->MutableData(); | ||||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | ||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | ||||
| if (dequant_weight == nullptr) { | if (dequant_weight == nullptr) { | ||||
| @@ -234,7 +234,9 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *> | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D); | MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D); | ||||
| auto *weight_tensor = inputs.at(kWeightIndex); | auto *weight_tensor = inputs.at(kWeightIndex); | ||||
| auto *restore_data = weight_tensor->MutableData(); | auto *restore_data = weight_tensor->MutableData(); | ||||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||||
| bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && | |||||
| restore_data != nullptr; | |||||
| if (dequant_flag) { | |||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | ||||
| if (dequant_weight == nullptr) { | if (dequant_weight == nullptr) { | ||||
| MS_LOG(ERROR) << "dequant data is nullptr."; | MS_LOG(ERROR) << "dequant data is nullptr."; | ||||
| @@ -255,7 +257,7 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *> | |||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "kernel is nullptr."; | MS_LOG(ERROR) << "kernel is nullptr."; | ||||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | weight_tensor->FreeData(); | ||||
| weight_tensor->SetData(restore_data); | weight_tensor->SetData(restore_data); | ||||
| } | } | ||||
| @@ -267,14 +269,14 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *> | |||||
| delete kernel; | delete kernel; | ||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | ||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | ||||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | weight_tensor->FreeData(); | ||||
| weight_tensor->SetData(restore_data); | weight_tensor->SetData(restore_data); | ||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | weight_tensor->FreeData(); | ||||
| weight_tensor->SetData(restore_data); | weight_tensor->SetData(restore_data); | ||||
| } | } | ||||
| @@ -196,7 +196,9 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D); | MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D); | ||||
| auto *weight_tensor = inputs.at(kWeightIndex); | auto *weight_tensor = inputs.at(kWeightIndex); | ||||
| auto *restore_data = weight_tensor->MutableData(); | auto *restore_data = weight_tensor->MutableData(); | ||||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||||
| bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited && | |||||
| restore_data != nullptr; | |||||
| if (dequant_flag) { | |||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | ||||
| if (dequant_weight == nullptr) { | if (dequant_weight == nullptr) { | ||||
| MS_LOG(ERROR) << "dequant data is nullptr."; | MS_LOG(ERROR) << "dequant data is nullptr."; | ||||
| @@ -209,7 +211,7 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor | |||||
| new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx, primitive); | new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx, primitive); | ||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "kernel is nullptr."; | MS_LOG(ERROR) << "kernel is nullptr."; | ||||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | weight_tensor->FreeData(); | ||||
| weight_tensor->SetData(restore_data); | weight_tensor->SetData(restore_data); | ||||
| } | } | ||||
| @@ -221,13 +223,13 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor | |||||
| delete kernel; | delete kernel; | ||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | ||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | ||||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | weight_tensor->FreeData(); | ||||
| weight_tensor->SetData(restore_data); | weight_tensor->SetData(restore_data); | ||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | weight_tensor->FreeData(); | ||||
| weight_tensor->SetData(restore_data); | weight_tensor->SetData(restore_data); | ||||
| } | } | ||||
| @@ -208,9 +208,8 @@ kernel::LiteKernel *CpuMatmulInt8KernelCreator(const std::vector<lite::Tensor *> | |||||
| auto *weight_tensor = inputs.at(kWeightIndex); | auto *weight_tensor = inputs.at(kWeightIndex); | ||||
| auto *restore_data = weight_tensor->data_c(); | auto *restore_data = weight_tensor->data_c(); | ||||
| auto is_const_quant_weight = | |||||
| (restore_data != nullptr) && | |||||
| ((weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16)); | |||||
| bool is_const_quant_weight = !weight_tensor->GetQuantParams().empty() && | |||||
| weight_tensor->GetQuantParams().front().inited && restore_data != nullptr; | |||||
| if (is_const_quant_weight) { | if (is_const_quant_weight) { | ||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | ||||
| if (dequant_weight == nullptr) { | if (dequant_weight == nullptr) { | ||||
| @@ -25,36 +25,6 @@ using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_INFER_ERR; | using mindspore::lite::RET_INFER_ERR; | ||||
| using mindspore::lite::RET_INFER_INVALID; | using mindspore::lite::RET_INFER_INVALID; | ||||
| using mindspore::lite::RET_OK; | using mindspore::lite::RET_OK; | ||||
| using Float16CastFunc = void (*)(const void *, void *, int); | |||||
| class Float16CastUtil { | |||||
| public: | |||||
| static Float16CastUtil *GetInstance() { | |||||
| static Float16CastUtil float16_cast_util; | |||||
| return &float16_cast_util; | |||||
| } | |||||
| private: | |||||
| Float16CastUtil() { | |||||
| #ifdef ENABLE_ARM64 | |||||
| void *fp16_op_handler = Float16Module::GetInstance()->float16_op_handler_; | |||||
| if (fp16_op_handler != nullptr) { | |||||
| dlerror(); | |||||
| *(reinterpret_cast<void **>(&float16_to_float32_func_)) = dlsym(fp16_op_handler, "Float16ToFloat32_fp16_handler"); | |||||
| *(reinterpret_cast<void **>(&float32_to_float16_func_)) = dlsym(fp16_op_handler, "Float32ToFloat16_fp16_handler"); | |||||
| auto dlopen_error = dlerror(); | |||||
| if (dlopen_error != nullptr) { | |||||
| MS_LOG(ERROR) << "load float16 cast func failed! " << dlopen_error << "."; | |||||
| } | |||||
| } | |||||
| #endif | |||||
| } | |||||
| ~Float16CastUtil() = default; | |||||
| public: | |||||
| Float16CastFunc float16_to_float32_func_ = nullptr; | |||||
| Float16CastFunc float32_to_float16_func_ = nullptr; | |||||
| }; | |||||
| int SubGraphKernel::Prepare() { | int SubGraphKernel::Prepare() { | ||||
| for (auto node : this->nodes_) { | for (auto node : this->nodes_) { | ||||
| @@ -208,7 +178,7 @@ int CpuFp16SubGraph::PreProcess() { | |||||
| } | } | ||||
| int CpuFp16SubGraph::PostProcess() { | int CpuFp16SubGraph::PostProcess() { | ||||
| auto fp16_to_fp32_cast_func = Float16CastUtil::GetInstance()->float16_to_float32_func_; | |||||
| auto fp16_to_fp32_cast_func = lite::Float16CastUtil::GetInstance()->float16_to_float32_func_; | |||||
| if (fp16_to_fp32_cast_func == nullptr) { | if (fp16_to_fp32_cast_func == nullptr) { | ||||
| MS_LOG(ERROR) << "Can not find cast fp16 to fp32 func"; | MS_LOG(ERROR) << "Can not find cast fp16 to fp32 func"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -35,6 +35,7 @@ struct QuantArg { | |||||
| int32_t zeroPoint; | int32_t zeroPoint; | ||||
| double var_corr{1}; | double var_corr{1}; | ||||
| double mean_corr{0}; | double mean_corr{0}; | ||||
| bool inited; | |||||
| std::vector<float> clusters{}; | std::vector<float> clusters{}; | ||||
| }; | }; | ||||
| @@ -33,6 +33,6 @@ lite-model_on_device_vision_classifier_landmarks_classifier_oceania_antarctica_V | |||||
| lite-model_on_device_vision_classifier_landmarks_classifier_europe_V1_1.tflite | lite-model_on_device_vision_classifier_landmarks_classifier_europe_V1_1.tflite | ||||
| lite-model_on_device_vision_classifier_landmarks_classifier_south_america_V1_1.tflite | lite-model_on_device_vision_classifier_landmarks_classifier_south_america_V1_1.tflite | ||||
| vision_classifier_fungi_mobile_V1_1_default_1.tflite | vision_classifier_fungi_mobile_V1_1_default_1.tflite | ||||
| detect.tflite | |||||
| ssd_mobilenet_v1_1_default_1.tflite | |||||
| object_detection_mobile_object_localizer_v1_1_default_1.tflite | |||||
| #detect.tflite | |||||
| #ssd_mobilenet_v1_1_default_1.tflite | |||||
| #object_detection_mobile_object_localizer_v1_1_default_1.tflite | |||||
| @@ -121,8 +121,8 @@ function Run_Converter() { | |||||
| continue | continue | ||||
| fi | fi | ||||
| echo ${model_name} >> "${run_converter_log_file}" | echo ${model_name} >> "${run_converter_log_file}" | ||||
| echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}' --quantType=AwareTraining' >> "${run_converter_log_file}" | |||||
| ./converter_lite --fmk=TFLITE --modelFile=${models_path}/${model_name} --outputFile=${ms_models_path}/${model_name} --quantType=AwareTraining | |||||
| echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}' --inputDataType=FLOAT --outputDataType=FLOAT' >> "${run_converter_log_file}" | |||||
| ./converter_lite --fmk=TFLITE --modelFile=${models_path}/${model_name} --outputFile=${ms_models_path}/${model_name} --inputDataType=FLOAT --outputDataType=FLOAT | |||||
| if [ $? = 0 ]; then | if [ $? = 0 ]; then | ||||
| converter_result='converter aware_training '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file} | converter_result='converter aware_training '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file} | ||||
| else | else | ||||
| @@ -544,6 +544,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s | |||||
| } else { | } else { | ||||
| auto ms_tensor = new schema::TensorT(); | auto ms_tensor = new schema::TensorT(); | ||||
| ms_tensor->nodeType = schema::NodeType_CNode; | ms_tensor->nodeType = schema::NodeType_CNode; | ||||
| ms_tensor->dataType = TypeId::kNumberTypeFloat32; | |||||
| fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size()); | fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size()); | ||||
| node_id_map_[cnode_name] = meta_graphT->allTensors.size(); | node_id_map_[cnode_name] = meta_graphT->allTensors.size(); | ||||
| meta_graphT->allTensors.emplace_back(ms_tensor); | meta_graphT->allTensors.emplace_back(ms_tensor); | ||||
| @@ -73,30 +73,30 @@ ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr<s | |||||
| auto primitiveCValue = PrimitiveC::Create(cNode->primitive.release()); | auto primitiveCValue = PrimitiveC::Create(cNode->primitive.release()); | ||||
| cNode->primitive = nullptr; | cNode->primitive = nullptr; | ||||
| // add quant parameter | // add quant parameter | ||||
| if (cNode->quantType != schema::QuantType_PostTraining && cNode->quantType != schema::QuantType_WeightQuant) { | |||||
| primitiveCValue->SetQuantType(cNode->quantType); | |||||
| for (int index : cNode->inputIndex) { | |||||
| if (!meta_graph_->allTensors[index]->quantParams.empty()) { | |||||
| std::vector<schema::QuantParamT> quant_params(meta_graph_->allTensors[index]->quantParams.size()); | |||||
| std::transform( | |||||
| meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(), | |||||
| quant_params.begin(), | |||||
| [](std::unique_ptr<schema::QuantParamT> &quant_param) -> schema::QuantParamT { return *quant_param; }); | |||||
| primitiveCValue->AddInputQuantParam(quant_params); | |||||
| } else { | |||||
| std::vector<schema::QuantParamT> empty_quant_params; | |||||
| primitiveCValue->AddInputQuantParam(empty_quant_params); | |||||
| } | |||||
| for (auto index : cNode->inputIndex) { | |||||
| if (!meta_graph_->allTensors[index]->quantParams.empty()) { | |||||
| std::vector<schema::QuantParamT> quant_params(meta_graph_->allTensors[index]->quantParams.size()); | |||||
| std::transform( | |||||
| meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(), | |||||
| quant_params.begin(), | |||||
| [](std::unique_ptr<schema::QuantParamT> &quant_param) -> schema::QuantParamT { return *quant_param; }); | |||||
| primitiveCValue->AddInputQuantParam(quant_params); | |||||
| } else { | |||||
| std::vector<schema::QuantParamT> notinited_quant_params(1); | |||||
| primitiveCValue->AddInputQuantParam(notinited_quant_params); | |||||
| } | } | ||||
| for (int index : cNode->outputIndex) { | |||||
| if (!meta_graph_->allTensors[index]->quantParams.empty()) { | |||||
| std::vector<schema::QuantParamT> quant_params(meta_graph_->allTensors[index]->quantParams.size()); | |||||
| std::transform( | |||||
| meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(), | |||||
| quant_params.begin(), | |||||
| [](std::unique_ptr<schema::QuantParamT> &quant_param) -> schema::QuantParamT { return *quant_param; }); | |||||
| primitiveCValue->AddOutputQuantParam(quant_params); | |||||
| } | |||||
| } | |||||
| for (auto index : cNode->outputIndex) { | |||||
| if (!meta_graph_->allTensors[index]->quantParams.empty()) { | |||||
| std::vector<schema::QuantParamT> quant_params(meta_graph_->allTensors[index]->quantParams.size()); | |||||
| std::transform( | |||||
| meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(), | |||||
| quant_params.begin(), | |||||
| [](std::unique_ptr<schema::QuantParamT> &quant_param) -> schema::QuantParamT { return *quant_param; }); | |||||
| primitiveCValue->AddOutputQuantParam(quant_params); | |||||
| } else { | |||||
| std::vector<schema::QuantParamT> notinited_quant_params(1); | |||||
| primitiveCValue->AddOutputQuantParam(notinited_quant_params); | |||||
| } | } | ||||
| } | } | ||||
| auto value_node = NewValueNode(std::shared_ptr<PrimitiveC>(primitiveCValue)); | auto value_node = NewValueNode(std::shared_ptr<PrimitiveC>(primitiveCValue)); | ||||
| @@ -52,16 +52,13 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||||
| auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true); | auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true); | ||||
| // for now - trainning is not supporting fuse operations | // for now - trainning is not supporting fuse operations | ||||
| if (config != nullptr && config->trainModel == false) { | |||||
| if (config != nullptr && !config->trainModel) { | |||||
| // remove quantdtype when awaretraining | // remove quantdtype when awaretraining | ||||
| if (config->fmk == lite::converter::FmkType_ONNX) { | if (config->fmk == lite::converter::FmkType_ONNX) { | ||||
| auto remove_identity_pass = std::make_shared<opt::RemoveIdentityOpPass>(); | auto remove_identity_pass = std::make_shared<opt::RemoveIdentityOpPass>(); | ||||
| remove_identity_pass->SetFmkType(config->fmk); | remove_identity_pass->SetFmkType(config->fmk); | ||||
| pm->AddPass(remove_identity_pass); | pm->AddPass(remove_identity_pass); | ||||
| } | } | ||||
| if (config->quantType == QuantType_AwareTraining) { | |||||
| pm->AddPass(std::make_shared<opt::QuantDtypeCastFusion>()); | |||||
| } | |||||
| pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>()); | pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>()); | ||||
| pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>()); | pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>()); | ||||
| pm->AddPass(std::make_shared<opt::ConvScaleFusion>()); | pm->AddPass(std::make_shared<opt::ConvScaleFusion>()); | ||||
| @@ -101,27 +98,25 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| // quant | // quant | ||||
| if (config != nullptr) { | |||||
| if (config->quantType == schema::QuantType_PostTraining) { | |||||
| this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, 8); | |||||
| if (mQuantizer == nullptr) { | |||||
| MS_LOG(ERROR) << "New PostTrainingQuantizer failed"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); | |||||
| return nullptr; | |||||
| } | |||||
| } else if (config->quantType == schema::QuantType_WeightQuant) { | |||||
| if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) { | |||||
| MS_LOG(ERROR) << "weight quant input param error"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||||
| return nullptr; | |||||
| } | |||||
| this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->quantWeightSize, | |||||
| config->quantWeightChannel, config->bitNum); | |||||
| if (mQuantizer == nullptr) { | |||||
| MS_LOG(ERROR) << "New WeightQuantizer failed"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); | |||||
| return nullptr; | |||||
| } | |||||
| if (config->quantType == schema::QuantType_PostTraining) { | |||||
| this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, 8); | |||||
| if (mQuantizer == nullptr) { | |||||
| MS_LOG(ERROR) << "New PostTrainingQuantizer failed"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); | |||||
| return nullptr; | |||||
| } | |||||
| } else if (config->quantType == schema::QuantType_WeightQuant) { | |||||
| if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) { | |||||
| MS_LOG(ERROR) << "weight quant input param error"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||||
| return nullptr; | |||||
| } | |||||
| this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->quantWeightSize, | |||||
| config->quantWeightChannel, config->bitNum); | |||||
| if (mQuantizer == nullptr) { | |||||
| MS_LOG(ERROR) << "New WeightQuantizer failed"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); | |||||
| return nullptr; | |||||
| } | } | ||||
| } | } | ||||
| if (mQuantizer != nullptr) { | if (mQuantizer != nullptr) { | ||||
| @@ -93,7 +93,6 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { | |||||
| } | } | ||||
| // transform | // transform | ||||
| transform->SetGraphDef(meta_graph); | transform->SetGraphDef(meta_graph); | ||||
| transform->CreateQuantizer(flag); | |||||
| auto status = transform->Transform(*flag); | auto status = transform->Transform(*flag); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "Transform meta graph failed " << status; | MS_LOG(ERROR) << "Transform meta graph failed " << status; | ||||
| @@ -29,9 +29,14 @@ Flags::Flags() { | |||||
| AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", ""); | AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", ""); | ||||
| AddFlag(&Flags::weightFile, "weightFile", "Input model weight file. Needed when fmk is CAFFE. CAFFE: *.caffemodel", | AddFlag(&Flags::weightFile, "weightFile", "Input model weight file. Needed when fmk is CAFFE. CAFFE: *.caffemodel", | ||||
| ""); | ""); | ||||
| AddFlag(&Flags::inferenceTypeIn, "inferenceType", "Data type of input and output tensors. FLOAT | INT8 | UINT8", | |||||
| "FLOAT"); | |||||
| AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | PostTraining | WeightQuant", ""); | |||||
| AddFlag(&Flags::inputDataTypeIn, "inputDataType", | |||||
| "Data type of input tensors, default is same with the type defined in model. FLOAT | INT8 | UINT8 | DEFAULT", | |||||
| "DEFAULT"); | |||||
| AddFlag(&Flags::outputDataTypeIn, "outputDataType", | |||||
| "Data type of output and output tensors, default is same with the type defined in model. FLOAT | INT8 | " | |||||
| "UINT8 | DEFAULT", | |||||
| "DEFAULT"); | |||||
| AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. PostTraining | WeightQuant", ""); | |||||
| AddFlag(&Flags::bitNum, "bitNum", "Weight quantization bitNum", "8"); | AddFlag(&Flags::bitNum, "bitNum", "Weight quantization bitNum", "8"); | ||||
| AddFlag(&Flags::quantWeightSize, "quantWeightSize", "Weight quantization size threshold", "0"); | AddFlag(&Flags::quantWeightSize, "quantWeightSize", "Weight quantization size threshold", "0"); | ||||
| AddFlag(&Flags::quantWeightChannel, "quantWeightChannel", "Channel threshold for weight quantization", "16"); | AddFlag(&Flags::quantWeightChannel, "quantWeightChannel", "Channel threshold for weight quantization", "16"); | ||||
| @@ -78,15 +83,32 @@ int Flags::Init(int argc, const char **argv) { | |||||
| return RET_INPUT_PARAM_INVALID; | return RET_INPUT_PARAM_INVALID; | ||||
| } | } | ||||
| if (this->inferenceTypeIn == "FLOAT") { | |||||
| this->inferenceType = TypeId::kNumberTypeFloat; | |||||
| } else if (this->inferenceTypeIn == "INT8") { | |||||
| this->inferenceType = TypeId::kNumberTypeInt8; | |||||
| } else if (this->inferenceTypeIn == "UINT8") { | |||||
| this->inferenceType = TypeId::kNumberTypeUInt8; | |||||
| if (this->inputDataTypeIn == "FLOAT") { | |||||
| this->inputDataType = TypeId::kNumberTypeFloat; | |||||
| } else if (this->inputDataTypeIn == "INT8") { | |||||
| this->inputDataType = TypeId::kNumberTypeInt8; | |||||
| } else if (this->inputDataTypeIn == "UINT8") { | |||||
| this->inputDataType = TypeId::kNumberTypeUInt8; | |||||
| } else if (this->inputDataTypeIn == "DEFAULT") { | |||||
| this->inputDataType = TypeId::kTypeUnknown; | |||||
| } else { | } else { | ||||
| std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8 | UINT8", | |||||
| this->inferenceTypeIn.c_str(); | |||||
| std::cerr << "INPUT INVALID: inputDataType is invalid: %s, supported inputDataType: FLOAT | INT8 | UINT8 | DEFAULT", | |||||
| this->inputDataTypeIn.c_str(); | |||||
| return RET_INPUT_PARAM_INVALID; | |||||
| } | |||||
| if (this->outputDataTypeIn == "FLOAT") { | |||||
| this->outputDataType = TypeId::kNumberTypeFloat; | |||||
| } else if (this->outputDataTypeIn == "INT8") { | |||||
| this->outputDataType = TypeId::kNumberTypeInt8; | |||||
| } else if (this->outputDataTypeIn == "UINT8") { | |||||
| this->outputDataType = TypeId::kNumberTypeUInt8; | |||||
| } else if (this->outputDataTypeIn == "DEFAULT") { | |||||
| this->outputDataType = TypeId::kTypeUnknown; | |||||
| } else { | |||||
| std::cerr | |||||
| << "INPUT INVALID: outputDataType is invalid: %s, supported outputDataType: FLOAT | INT8 | UINT8 | DEFAULT", | |||||
| this->outputDataTypeIn.c_str(); | |||||
| return RET_INPUT_PARAM_INVALID; | return RET_INPUT_PARAM_INVALID; | ||||
| } | } | ||||
| @@ -107,9 +129,8 @@ int Flags::Init(int argc, const char **argv) { | |||||
| std::cerr << "INPUT ILLEGAL: weightFile is not a valid flag"; | std::cerr << "INPUT ILLEGAL: weightFile is not a valid flag"; | ||||
| return RET_INPUT_PARAM_INVALID; | return RET_INPUT_PARAM_INVALID; | ||||
| } | } | ||||
| if (this->quantTypeIn == "AwareTraining") { | |||||
| this->quantType = QuantType_AwareTraining; | |||||
| } else if (this->quantTypeIn == "WeightQuant") { | |||||
| if (this->quantTypeIn == "WeightQuant") { | |||||
| this->quantType = QuantType_WeightQuant; | this->quantType = QuantType_WeightQuant; | ||||
| } else if (this->quantTypeIn == "PostTraining") { | } else if (this->quantTypeIn == "PostTraining") { | ||||
| this->quantType = QuantType_PostTraining; | this->quantType = QuantType_PostTraining; | ||||
| @@ -53,8 +53,11 @@ class Flags : public virtual mindspore::lite::FlagParser { | |||||
| std::string quantTypeIn; | std::string quantTypeIn; | ||||
| QuantType quantType; | QuantType quantType; | ||||
| std::string inferenceTypeIn; | std::string inferenceTypeIn; | ||||
| std::string inputDataTypeIn; | |||||
| std::string outputDataTypeIn; | |||||
| // used for parse aware trainning | // used for parse aware trainning | ||||
| TypeId inferenceType = TypeId::kNumberTypeFloat; | |||||
| TypeId inputDataType; | |||||
| TypeId outputDataType; | |||||
| // used for post-trainning-weight | // used for post-trainning-weight | ||||
| std::string quantWeightSize; | std::string quantWeightSize; | ||||
| std::string bitNum; | std::string bitNum; | ||||
| @@ -34,6 +34,8 @@ | |||||
| #include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h" | #include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h" | ||||
| #include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h" | #include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h" | ||||
| #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h" | #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h" | ||||
| #include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h" | |||||
| #include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h" | |||||
| #include "tools/converter/quantizer/aware_quantizer.h" | #include "tools/converter/quantizer/aware_quantizer.h" | ||||
| using std::string; | using std::string; | ||||
| @@ -44,20 +46,6 @@ GraphDefTransform::~GraphDefTransform() = default; | |||||
| void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _dstDef; } | void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _dstDef; } | ||||
| void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) { | |||||
| auto type = flags->quantType; | |||||
| switch (type) { | |||||
| case QuantType::QuantType_AwareTraining: { | |||||
| MS_LOG(INFO) << "create AwareTrainingQuantizer!"; | |||||
| fbQuantizer = std::make_unique<quant::AwareQuantizer>(graphDefT, flags->inferenceType); | |||||
| break; | |||||
| } | |||||
| default: | |||||
| MS_LOG(INFO) << "will support quantizer type " << flags->quantTypeIn << " in the future"; | |||||
| break; | |||||
| } | |||||
| } | |||||
| int GraphDefTransform::Transform(const converter::Flags &ctx) { | int GraphDefTransform::Transform(const converter::Flags &ctx) { | ||||
| STATUS status; | STATUS status; | ||||
| { | { | ||||
| @@ -84,26 +72,13 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||||
| // generate and infer quant parameters | // generate and infer quant parameters | ||||
| { | { | ||||
| if (fbQuantizer != nullptr) { | |||||
| Optimizer topologicalOptimizer; | |||||
| topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||||
| status = topologicalOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; | |||||
| return status; | |||||
| } | |||||
| if (ctx.quantType == QuantType_AwareTraining) { | |||||
| status = fbQuantizer->GenerateQuantParam(); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "GenerateQuantParam failed"; | |||||
| return status; | |||||
| } | |||||
| status = fbQuantizer->DetermineNodeQuantType(); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "DetermineNodeQuant failed"; | |||||
| return status; | |||||
| } | |||||
| } | |||||
| Optimizer inferQuantParamPass; | |||||
| inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass()); | |||||
| inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass()); | |||||
| status = inferQuantParamPass.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; | |||||
| return status; | |||||
| } | } | ||||
| } | } | ||||
| @@ -146,12 +121,11 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||||
| } | } | ||||
| } | } | ||||
| { | { | ||||
| Optimizer fusionOptimizer; | |||||
| fusionOptimizer.AddPass(new (std::nothrow) FormatTransPermuteFusionPass()); | |||||
| fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||||
| status = fusionOptimizer.Run(graphDefT); | |||||
| Optimizer inferQuantParamOtimizer; | |||||
| inferQuantParamOtimizer.AddPass(new (std::nothrow) InferQuantParamPass()); | |||||
| status = inferQuantParamOtimizer.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | if (status != RET_OK && status != RET_NO_CHANGE) { | ||||
| MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed"; | |||||
| MS_LOG(ERROR) << "Run tensorQuantOptimizer graphPasses Failed"; | |||||
| return status; | return status; | ||||
| } | } | ||||
| } | } | ||||
| @@ -168,8 +142,10 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||||
| } | } | ||||
| // do quantization | // do quantization | ||||
| if (fbQuantizer != nullptr) { | |||||
| status = fbQuantizer->DoQuantize(); | |||||
| { | |||||
| Optimizer fusionOptimizer; | |||||
| fusionOptimizer.AddPass(new (std::nothrow) TensorQuantPass()); | |||||
| status = fusionOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "DoQuantize failed!"; | MS_LOG(ERROR) << "DoQuantize failed!"; | ||||
| return status; | return status; | ||||
| @@ -177,11 +153,11 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||||
| } | } | ||||
| // insert quantNode and deQuantNode | // insert quantNode and deQuantNode | ||||
| if (ctx.quantType == QuantType_AwareTraining) { | |||||
| { | |||||
| Optimizer quantNodeOptimizer; | Optimizer quantNodeOptimizer; | ||||
| auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); | auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); | ||||
| dTypeTransPass->SetInputDataDType(ctx.inferenceType); | |||||
| dTypeTransPass->SetOutputDataDType(ctx.inferenceType); | |||||
| dTypeTransPass->SetInputDataDType(ctx.inputDataType); | |||||
| dTypeTransPass->SetOutputDataDType(ctx.outputDataType); | |||||
| quantNodeOptimizer.AddPass(dTypeTransPass); | quantNodeOptimizer.AddPass(dTypeTransPass); | ||||
| quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); | quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); | ||||
| quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | ||||
| @@ -37,14 +37,10 @@ class GraphDefTransform { | |||||
| virtual int Transform(const converter::Flags &ctx); | virtual int Transform(const converter::Flags &ctx); | ||||
| void SetGraphDef(schema::MetaGraphT *dstDef); | void SetGraphDef(schema::MetaGraphT *dstDef); | ||||
| inline schema::MetaGraphT *GetOutput() { return graphDefT; } | inline schema::MetaGraphT *GetOutput() { return graphDefT; } | ||||
| void CreateQuantizer(const converter::Flags *flags); | |||||
| protected: | protected: | ||||
| schema::MetaGraphT *graphDefT = nullptr; | schema::MetaGraphT *graphDefT = nullptr; | ||||
| Optimizer *optimizer = nullptr; | Optimizer *optimizer = nullptr; | ||||
| std::unique_ptr<quant::Quantizer> mQuantizer; | |||||
| std::unique_ptr<quant::FbQuantizer> fbQuantizer; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -10,6 +10,8 @@ file(GLOB GRAPH_PASS | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_convert_scale_pass.cc | ${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_convert_scale_pass.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/trans_format_remove_pass.cc | ${CMAKE_CURRENT_SOURCE_DIR}/trans_format_remove_pass.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/infershape_pass.cc | ${CMAKE_CURRENT_SOURCE_DIR}/infershape_pass.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/tensor_quant_pass.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/infer_quant_param_pass.cc | |||||
| ) | ) | ||||
| set_property(SOURCE ${GRAPH_PASS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) | set_property(SOURCE ${GRAPH_PASS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) | ||||
| add_library(graph_pass_mid OBJECT ${GRAPH_PASS}) | add_library(graph_pass_mid OBJECT ${GRAPH_PASS}) | ||||
| @@ -27,9 +27,6 @@ namespace lite { | |||||
| #define kMinInputNum 1 | #define kMinInputNum 1 | ||||
| #define kOutputNum 1 | #define kOutputNum 1 | ||||
| static const std::set<schema::PrimitiveType> NoNeedDtypeTransList = { | |||||
| PrimitiveType_QuantDTypeCast, PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw}; | |||||
| STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) { | STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) { | ||||
| MS_ASSERT(graph != nullptr); | MS_ASSERT(graph != nullptr); | ||||
| @@ -44,12 +41,6 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) { | |||||
| MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status; | MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status; | ||||
| return status; | return status; | ||||
| } | } | ||||
| status = DoNodeInoutDTypeTrans(graph); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status; | |||||
| return status; | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -57,7 +48,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { | |||||
| MS_ASSERT(graph != nullptr); | MS_ASSERT(graph != nullptr); | ||||
| auto &graphInIdxes = graph->inputIndex; | auto &graphInIdxes = graph->inputIndex; | ||||
| if (this->inputDataDType == TypeId::kNumberTypeInt8) { | |||||
| if (this->inputDataDType == TypeId::kNumberTypeInt8 || this->inputDataDType == TypeId::kTypeUnknown) { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| if (this->inputDataDType != TypeId::kNumberTypeFloat && this->inputDataDType != TypeId::kNumberTypeUInt8) { | if (this->inputDataDType != TypeId::kNumberTypeFloat && this->inputDataDType != TypeId::kNumberTypeUInt8) { | ||||
| @@ -68,7 +59,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { | |||||
| for (auto graphInIdx : graphInIdxes) { | for (auto graphInIdx : graphInIdxes) { | ||||
| MS_ASSERT(graphInIdx < graph->allTensors.size()); | MS_ASSERT(graphInIdx < graph->allTensors.size()); | ||||
| auto &tensor = graph->allTensors.at(graphInIdx); | auto &tensor = graph->allTensors.at(graphInIdx); | ||||
| if (tensor->dims.size() != kNHWCDimNumber || tensor->dataType != kNumberTypeInt8) { | |||||
| if (tensor->dataType != kNumberTypeInt8 || tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -98,7 +89,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { | |||||
| STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { | STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { | ||||
| MS_ASSERT(graph != nullptr); | MS_ASSERT(graph != nullptr); | ||||
| if (outputDataDType == TypeId::kNumberTypeInt8) { | |||||
| if (outputDataDType == TypeId::kNumberTypeInt8 || outputDataDType == TypeId::kTypeUnknown) { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| if (this->outputDataDType != TypeId::kNumberTypeFloat && this->outputDataDType != TypeId::kNumberTypeUInt8) { | if (this->outputDataDType != TypeId::kNumberTypeFloat && this->outputDataDType != TypeId::kNumberTypeUInt8) { | ||||
| @@ -107,6 +98,11 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { | |||||
| } | } | ||||
| auto &graphOutIdxes = graph->outputIndex; | auto &graphOutIdxes = graph->outputIndex; | ||||
| for (auto graphOutIdx : graphOutIdxes) { | for (auto graphOutIdx : graphOutIdxes) { | ||||
| MS_ASSERT(graphOutIdx < graph->allTensors.size()); | |||||
| auto &tensor = graph->allTensors.at(graphOutIdx); | |||||
| if (tensor->dataType != kNumberTypeInt8 || tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { | |||||
| continue; | |||||
| } | |||||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | ||||
| auto nodeName = (*iter)->name; | auto nodeName = (*iter)->name; | ||||
| MS_ASSERT(node != nullptr); | MS_ASSERT(node != nullptr); | ||||
| @@ -131,67 +127,6 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { | |||||
| MS_ASSERT(graph != nullptr); | |||||
| // insert transNode before and after existNode | |||||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||||
| if (IsContain(GetInt8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) { | |||||
| continue; | |||||
| } | |||||
| auto iterType = GetCNodeTType(**iter); | |||||
| if (NoNeedDtypeTransList.find(iterType) != NoNeedDtypeTransList.end()) { | |||||
| continue; | |||||
| } | |||||
| bool needInsertPost = true; | |||||
| if (GetCNodeTType(**iter) == PrimitiveType_Shape) { | |||||
| needInsertPost = false; | |||||
| } | |||||
| auto nodeName = (*iter)->name; | |||||
| if ((*iter)->inputIndex.size() < kMinInputNum) { | |||||
| MS_LOG(ERROR) << "Op " << nodeName.c_str() << " should have " << kMinInputNum << " input tensor at least"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| STATUS status; | |||||
| // insert pre | |||||
| for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) { | |||||
| MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i)); | |||||
| auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i)); | |||||
| if (preTensor->dataType == TypeId::kNumberTypeInt || preTensor->dataType == TypeId::kNumberTypeInt32) { | |||||
| continue; | |||||
| } | |||||
| auto &graphInIdxes = graph->inputIndex; | |||||
| if (!preTensor->data.empty() && !IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) { | |||||
| continue; | |||||
| } | |||||
| if ((preTensor->dataType != TypeId::kNumberTypeInt8) && (IsContain(graphInIdxes, (*iter)->inputIndex.at(i)))) { | |||||
| continue; | |||||
| } | |||||
| iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| if (needInsertPost) { | |||||
| for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) { | |||||
| auto &postTensor = graph->allTensors.at((*iter)->outputIndex.at(i)); | |||||
| if (postTensor->dataType == TypeId::kNumberTypeInt || postTensor->dataType == TypeId::kNumberTypeInt32) { | |||||
| continue; | |||||
| } | |||||
| iter = InsertDTypeTransNode(graph, iter, kAfter, i, kFP32ToInt8, &status); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| } | |||||
| (*iter)->quantType = QuantType_QUANT_NONE; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, | NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, | ||||
| size_t inoutIdx, DTypeTransNodeType nodeType, STATUS *errorCode) { | size_t inoutIdx, DTypeTransNodeType nodeType, STATUS *errorCode) { | ||||
| MS_ASSERT((*existNodeIter) != nullptr); | MS_ASSERT((*existNodeIter) != nullptr); | ||||
| @@ -45,8 +45,6 @@ class DTypeTransPass : public GraphPass { | |||||
| STATUS DoModelOutputDTypeTrans(schema::MetaGraphT *graph); | STATUS DoModelOutputDTypeTrans(schema::MetaGraphT *graph); | ||||
| STATUS DoNodeInoutDTypeTrans(schema::MetaGraphT *graph); | |||||
| NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, | NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, | ||||
| DTypeTransNodeType nodeType, STATUS *errorCode); | DTypeTransNodeType nodeType, STATUS *errorCode); | ||||
| @@ -0,0 +1,87 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "src/common/utils.h" | |||||
| #include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h" | |||||
| #include "tools/converter/quantizer/calc_quant_param.h" | |||||
| #include "tools/common/node_util.h" | |||||
| #include "tools/common/converter_op_utils.h" | |||||
| namespace mindspore::lite { | |||||
| STATUS InferQuantParamPass::Run(schema::MetaGraphT *graph) { | |||||
| auto *quantParamRegister = QuantParamCalcRegister::GetInstance(); | |||||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||||
| auto &node = *iter; | |||||
| MS_ASSERT(node != nullptr); | |||||
| if (node->quantType == schema::QuantType_WeightQuant) { | |||||
| continue; | |||||
| } | |||||
| if (GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMax || | |||||
| GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) { | |||||
| MS_ASSERT(false); | |||||
| } | |||||
| auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node)); | |||||
| if (quantParamCalcer == nullptr) { | |||||
| MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str() | |||||
| << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; | |||||
| node->quantType = static_cast<schema::QuantType>(schema::QuantType_QUANT_NONE); | |||||
| } else { | |||||
| auto status = quantParamCalcer->Calc(graph, *node); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(WARNING) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); | |||||
| node->quantType = schema::QuantType_QUANT_NONE; | |||||
| } else { | |||||
| DetermineNodeQuantType(*graph, node.get()); | |||||
| } | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| void InferQuantParamPass::DetermineNodeQuantType(const schema::MetaGraphT &graph, schema::CNodeT *cnode) { | |||||
| MS_ASSERT(cnode != nullptr); | |||||
| bool canQuant = true; | |||||
| for (auto &inputTensorIdx : cnode->inputIndex) { | |||||
| MS_ASSERT(graph.allTensors.size() > inputTensorIdx); | |||||
| auto &inTensor = graph.allTensors.at(inputTensorIdx); | |||||
| MS_ASSERT(inTensor != nullptr); | |||||
| if (inTensor->quantParams.empty() || inTensor->quantParams.front() == nullptr || | |||||
| !inTensor->quantParams.front()->inited) { | |||||
| canQuant = false; | |||||
| break; | |||||
| } | |||||
| } | |||||
| for (auto &outTensorIdx : cnode->outputIndex) { | |||||
| MS_ASSERT(graph.allTensors.size() > outTensorIdx); | |||||
| auto &outTensor = graph.allTensors.at(outTensorIdx); | |||||
| MS_ASSERT(outTensor != nullptr); | |||||
| if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr || | |||||
| !outTensor->quantParams.front()->inited) { | |||||
| canQuant = false; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (canQuant && IsContain(GetInt8OpList(), GetCNodeTType(*cnode))) { | |||||
| cnode->quantType = schema::QuantType_AwareTraining; | |||||
| } else { | |||||
| cnode->quantType = schema::QuantType_QUANT_NONE; | |||||
| } | |||||
| } | |||||
| } // namespace mindspore::lite | |||||
| @@ -0,0 +1,39 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_INFER_QUANT_PARAM_PASS_H | |||||
| #define LITE_INFER_QUANT_PARAM_PASS_H | |||||
| #include <memory> | |||||
| #include "tools/converter/optimizer.h" | |||||
| #include "tools/common/graph_util.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class InferQuantParamPass : public GraphPass { | |||||
| public: | |||||
| InferQuantParamPass() {} | |||||
| ~InferQuantParamPass() override = default; | |||||
| STATUS Run(schema::MetaGraphT *graph) override; | |||||
| private: | |||||
| void DetermineNodeQuantType(const schema::MetaGraphT &graph, schema::CNodeT *cnode); | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_INFER_QUANT_PARAM_PASS_H | |||||
| @@ -0,0 +1,87 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <cmath> | |||||
| #include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h" | |||||
| #include "tools/converter/quantizer/quantize_util.h" | |||||
| #include "tools/common/tensor_util.h" | |||||
| namespace mindspore::lite { | |||||
| STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) { | |||||
| for (auto &tensor : graph->allTensors) { | |||||
| if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited || tensor->data.empty()) { | |||||
| continue; | |||||
| } | |||||
| if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat && | |||||
| tensor->dataType != TypeId::kNumberTypeUInt8) { | |||||
| continue; | |||||
| } | |||||
| // perlayer | |||||
| if (tensor->quantParams.size() == 1) { | |||||
| auto &quantParam = tensor->quantParams.front(); | |||||
| size_t wShapeSize = GetShapeSize(*(tensor.get())); | |||||
| void *oriWeightData = tensor->data.data(); | |||||
| if (quantParam->dstDtype == TypeId::kNumberTypeInt8) { | |||||
| std::vector<int8_t> qDatas(wShapeSize); | |||||
| auto weightQauntParam = GetTensorQuantParam(tensor); | |||||
| if (tensor->dataType == TypeId::kNumberTypeFloat || | |||||
| tensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant | |||||
| auto *weightData = static_cast<float *>(oriWeightData); | |||||
| for (size_t j = 0; j < wShapeSize; j++) { | |||||
| qDatas[j] = quant::QuantizeData<int8_t>(weightData[j], weightQauntParam.get()); | |||||
| } | |||||
| } else { // tflite awareing quant | |||||
| auto *weightData = static_cast<uint8_t *>(oriWeightData); | |||||
| for (size_t j = 0; j < wShapeSize; j++) { | |||||
| qDatas[j] = (int32_t)weightData[j] - 128; | |||||
| } | |||||
| weightQauntParam->zeroPoint -= 128; | |||||
| tensor->quantParams.clear(); | |||||
| tensor->quantParams.emplace_back(weightQauntParam.release()); | |||||
| } | |||||
| tensor->dataType = TypeId::kNumberTypeInt8; | |||||
| ::memcpy(tensor->data.data(), qDatas.data(), wShapeSize); | |||||
| } else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) { | |||||
| // quant bias data | |||||
| auto bShapeSize = GetShapeSize(*(tensor.get())); | |||||
| std::unique_ptr<int32_t[]> qDatas(new (std::nothrow) int32_t[bShapeSize]); | |||||
| if (qDatas == nullptr) { | |||||
| MS_LOG(ERROR) << "new qDatas failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| void *biasData = tensor->data.data(); | |||||
| auto *rawDatas = static_cast<float *>(biasData); | |||||
| for (size_t i = 0; i < bShapeSize; ++i) { | |||||
| qDatas[i] = (int32_t)std::round(rawDatas[i] / quantParam->scale); | |||||
| } | |||||
| tensor->dataType = TypeId::kNumberTypeInt32; | |||||
| tensor->data.clear(); | |||||
| tensor->data.resize(bShapeSize * sizeof(int32_t)); | |||||
| auto ret = | |||||
| memcpy_s(tensor->data.data(), bShapeSize * sizeof(int32_t), qDatas.get(), bShapeSize * sizeof(int32_t)); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy_s failed: " << ret; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| } else { // pertensor | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace mindspore::lite | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_TENSOR_QUANT_PASS_H | |||||
| #define LITE_TENSOR_QUANT_PASS_H | |||||
| #include <memory> | |||||
| #include "tools/converter/optimizer.h" | |||||
| #include "tools/common/graph_util.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TensorQuantPass : public GraphPass { | |||||
| public: | |||||
| TensorQuantPass() {} | |||||
| ~TensorQuantPass() override = default; | |||||
| STATUS Run(schema::MetaGraphT *graph) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_TENSOR_QUANT_PASS_H | |||||
| @@ -52,7 +52,7 @@ STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| attr->srcT = GetTfliteDataType(in_tensor->type); | |||||
| attr->srcT = kNumberTypeInt8; | |||||
| attr->dstT = GetTfliteDataType(out_tensor->type); | attr->dstT = GetTfliteDataType(out_tensor->type); | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; | op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; | ||||
| @@ -52,7 +52,7 @@ STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::u | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| attr->srcT = GetTfliteDataType(in_tensor->type); | attr->srcT = GetTfliteDataType(in_tensor->type); | ||||
| attr->dstT = GetTfliteDataType(out_tensor->type); | |||||
| attr->dstT = kNumberTypeInt8; | |||||
| op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; | op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| } else { | } else { | ||||
| @@ -35,29 +35,10 @@ using std::string; | |||||
| using std::vector; | using std::vector; | ||||
| namespace mindspore::lite::quant { | namespace mindspore::lite::quant { | ||||
| const std::array<schema::PrimitiveType, 7> AwareQuantizer::propagatedOps = { | |||||
| {schema::PrimitiveType_Concat, schema::PrimitiveType_Resize, schema::PrimitiveType_Reshape, | |||||
| schema::PrimitiveType_Squeeze, schema::PrimitiveType_RealDiv, schema::PrimitiveType_Activation, | |||||
| schema::PrimitiveType_DetectionPostProcess}}; | |||||
| AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferType) : FbQuantizer(graph) {} | AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferType) : FbQuantizer(graph) {} | ||||
| STATUS AwareQuantizer::RemoveFakeQuant() { return RET_OK; } | STATUS AwareQuantizer::RemoveFakeQuant() { return RET_OK; } | ||||
| STATUS AwareQuantizer::GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph) { | |||||
| MS_ASSERT(subGraph != nullptr); | |||||
| for (const auto &tensor : subGraph->allTensors) { | |||||
| if (!tensor->quantParams.empty()) { | |||||
| continue; | |||||
| } | |||||
| std::unique_ptr<schema::QuantParamT> defaultQuantParam(new QuantParamT()); | |||||
| tensor->quantParams.emplace_back(std::move(defaultQuantParam)); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { return RET_OK; } | |||||
| STATUS AwareQuantizer::GenerateQuantParam() { | STATUS AwareQuantizer::GenerateQuantParam() { | ||||
| auto *quantParamRegister = QuantParamCalcRegister::GetInstance(); | auto *quantParamRegister = QuantParamCalcRegister::GetInstance(); | ||||
| @@ -70,13 +51,13 @@ STATUS AwareQuantizer::GenerateQuantParam() { | |||||
| } | } | ||||
| auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node)); | auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node)); | ||||
| if (quantParamCalcer == nullptr) { | if (quantParamCalcer == nullptr) { | ||||
| MS_LOG(INFO) << "Can not find QuantParamCalcer for " << node->name.c_str() | |||||
| << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; | |||||
| MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str() | |||||
| << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; | |||||
| node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE); | node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE); | ||||
| } else { | } else { | ||||
| auto status = quantParamCalcer->Calc(graph, *node); | auto status = quantParamCalcer->Calc(graph, *node); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(INFO) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); | |||||
| MS_LOG(WARNING) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); | |||||
| node->quantType = schema::QuantType_QUANT_NONE; | node->quantType = schema::QuantType_QUANT_NONE; | ||||
| } else { | } else { | ||||
| node->quantType = schema::QuantType_AwareTraining; | node->quantType = schema::QuantType_AwareTraining; | ||||
| @@ -87,250 +68,65 @@ STATUS AwareQuantizer::GenerateQuantParam() { | |||||
| } | } | ||||
| STATUS AwareQuantizer::DoQuantize() { | STATUS AwareQuantizer::DoQuantize() { | ||||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||||
| auto &node = *iter; | |||||
| if (!IsContain(GetInt8OpList(), GetCNodeTType(*node))) { | |||||
| for (auto &tensor : graph->allTensors) { | |||||
| if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited || tensor->data.empty()) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (node->quantType != schema::QuantType_AwareTraining) { | |||||
| if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat && | |||||
| tensor->dataType != TypeId::kNumberTypeUInt8) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| STATUS status; | |||||
| if (GetCNodeTType(*node) == schema::PrimitiveType_Conv2D || | |||||
| GetCNodeTType(*node) == schema::PrimitiveType_DepthwiseConv2D || | |||||
| GetCNodeTType(*node) == schema::PrimitiveType_DeConv2D || | |||||
| GetCNodeTType(*node) == schema::PrimitiveType_FullConnection || | |||||
| GetCNodeTType(*node) == schema::PrimitiveType_MatMul) { | |||||
| auto inputIndexes = node->inputIndex; | |||||
| if (inputIndexes.size() < 2) { | |||||
| MS_LOG(ERROR) << node->name.c_str() << " node input has invalid inputs tensor count"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| // quant weight | |||||
| auto &weightTensor = graph->allTensors.at(node->inputIndex.at(1)); | |||||
| if (!weightTensor->quantParams.empty() && weightTensor->quantParams.at(0)->inited) { | |||||
| status = QuantConvWeight(graph, node.get()); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "QuantConvWeight failed!"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| // quant bias | |||||
| if (inputIndexes.size() == 3) { | |||||
| auto &biasTensor = graph->allTensors.at(node->inputIndex.at(2)); | |||||
| if (!biasTensor->quantParams.empty() && biasTensor->quantParams.at(0)->inited) { | |||||
| status = QuantConvBias(graph, node.get()); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "QuantConvBias failed!"; | |||||
| return RET_ERROR; | |||||
| // perlayer | |||||
| if (tensor->quantParams.size() == 1) { | |||||
| auto &quantParam = tensor->quantParams.front(); | |||||
| size_t wShapeSize = GetShapeSize(*(tensor.get())); | |||||
| void *oriWeightData = tensor->data.data(); | |||||
| if (quantParam->dstDtype == TypeId::kNumberTypeInt8) { | |||||
| vector<int8_t> qDatas(wShapeSize); | |||||
| auto weightQauntParam = GetTensorQuantParam(tensor); | |||||
| if (tensor->dataType == TypeId::kNumberTypeFloat || | |||||
| tensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant | |||||
| auto *weightData = static_cast<float *>(oriWeightData); | |||||
| for (size_t j = 0; j < wShapeSize; j++) { | |||||
| qDatas[j] = QuantizeData<int8_t>(weightData[j], weightQauntParam.get()); | |||||
| } | } | ||||
| } else { // tflite awareing quant | |||||
| auto *weightData = static_cast<uint8_t *>(oriWeightData); | |||||
| for (size_t j = 0; j < wShapeSize; j++) { | |||||
| qDatas[j] = (int32_t)weightData[j] - 128; | |||||
| } | |||||
| weightQauntParam->zeroPoint -= 128; | |||||
| tensor->quantParams.clear(); | |||||
| tensor->quantParams.emplace_back(weightQauntParam.release()); | |||||
| } | } | ||||
| } | |||||
| } else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) { | |||||
| status = QuantDetectionPostProcessConstTensor(graph, node.get()); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } else if (GetCNodeTType(*node) == schema::PrimitiveType_Add || | |||||
| GetCNodeTType(*node) == schema::PrimitiveType_Scale || | |||||
| GetCNodeTType(*node) == schema::PrimitiveType_Mul) { | |||||
| status = QuantArithmeticConstTensor(graph, node.get()); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "QuantArithmeticConstTensor failed!"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| const auto nodeType = GetCNodeTType(*node); | |||||
| auto find = std::find(propagatedOps.begin(), propagatedOps.end(), nodeType); | |||||
| if (find != propagatedOps.end()) { | |||||
| auto inputTensor = graph->allTensors.at(node->inputIndex[0]).get(); | |||||
| auto outputTensor = graph->allTensors.at(node->outputIndex[0]).get(); | |||||
| MS_ASSERT(inputTensor != nullptr); | |||||
| MS_ASSERT(outputTensor != nullptr); | |||||
| outputTensor->dataType = inputTensor->dataType; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS AwareQuantizer::QuantArithmeticConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node) { | |||||
| MS_ASSERT(graph != nullptr); | |||||
| MS_ASSERT(node != nullptr); | |||||
| for (size_t i = 0; i < node->inputIndex.size(); i++) { | |||||
| auto inTensorIdx = node->inputIndex.at(i); | |||||
| MS_ASSERT(graph->allTensors.size() > inTensorIdx); | |||||
| auto &inTensor = graph->allTensors.at(inTensorIdx); | |||||
| MS_ASSERT(inTensor != nullptr); | |||||
| if (!inTensor->data.empty()) { | |||||
| if (inTensor->dataType == TypeId::kNumberTypeInt8) { | |||||
| continue; | |||||
| } | |||||
| if (inTensor->dataType != TypeId::kNumberTypeFloat32 && inTensor->dataType != TypeId::kNumberTypeFloat && | |||||
| inTensor->dataType != TypeId::kNumberTypeUInt8) { | |||||
| MS_LOG(ERROR) << node->name.c_str() << "'s weight data is not float or uint8"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto quantParam = GetTensorQuantParam(inTensor); | |||||
| MS_ASSERT(quantParam != nullptr); | |||||
| MS_ASSERT(quantParam->inited); | |||||
| auto constTensorShapeSize = GetShapeSize(*(inTensor.get())); | |||||
| vector<int8_t> qDatas(constTensorShapeSize); | |||||
| void *inData = inTensor->data.data(); | |||||
| if (inTensor->dataType == TypeId::kNumberTypeFloat || | |||||
| inTensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant | |||||
| auto *weightData = static_cast<float *>(inData); | |||||
| for (size_t j = 0; j < constTensorShapeSize; j++) { | |||||
| qDatas[j] = QuantizeData<int8_t>(weightData[j], quantParam.get()); | |||||
| ::memcpy(tensor->data.data(), qDatas.data(), wShapeSize); | |||||
| } else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) { | |||||
| // quant bias data | |||||
| auto bShapeSize = GetShapeSize(*(tensor.get())); | |||||
| std::unique_ptr<int32_t[]> qDatas(new (std::nothrow) int32_t[bShapeSize]); | |||||
| if (qDatas == nullptr) { | |||||
| MS_LOG(ERROR) << "new qDatas failed"; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| } else { // tflite awareing quant | |||||
| auto *weightData = static_cast<uint8_t *>(inData); | |||||
| for (size_t j = 0; j < constTensorShapeSize; j++) { | |||||
| qDatas[j] = (int32_t)weightData[j] - 128; | |||||
| void *biasData = tensor->data.data(); | |||||
| auto *rawDatas = static_cast<float *>(biasData); | |||||
| for (size_t i = 0; i < bShapeSize; ++i) { | |||||
| qDatas[i] = (int32_t)std::round(rawDatas[i] / quantParam->scale); | |||||
| } | |||||
| tensor->dataType = TypeId::kNumberTypeInt32; | |||||
| tensor->data.clear(); | |||||
| tensor->data.resize(bShapeSize * sizeof(int32_t)); | |||||
| auto ret = | |||||
| memcpy_s(tensor->data.data(), bShapeSize * sizeof(int32_t), qDatas.get(), bShapeSize * sizeof(int32_t)); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy_s failed: " << ret; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| quantParam->zeroPoint -= 128; | |||||
| inTensor->quantParams.clear(); | |||||
| inTensor->quantParams.emplace_back(quantParam.release()); | |||||
| } | } | ||||
| ::memcpy(inTensor->data.data(), qDatas.data(), constTensorShapeSize); | |||||
| inTensor->dataType = TypeId::kNumberTypeInt8; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { | |||||
| MS_ASSERT(subGraph != nullptr); | |||||
| MS_ASSERT(node != nullptr); | |||||
| auto &constTensor = subGraph->allTensors.at(node->inputIndex[2]); | |||||
| MS_ASSERT(constTensor != nullptr); | |||||
| const auto *constData = reinterpret_cast<const float *>(constTensor->data.data()); | |||||
| if (!constTensor->data.empty() && | |||||
| (constTensor->dataType == TypeId::kNumberTypeFloat || constTensor->dataType == TypeId::kNumberTypeFloat32)) { | |||||
| size_t constTensorShapeSize = GetShapeSize(*constTensor); | |||||
| std::unique_ptr<QuantParamT> quantParam = GetTensorQuantParam(constTensor); | |||||
| if (quantParam == nullptr) { | |||||
| MS_LOG(ERROR) << "new QuantParamT failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| vector<int8_t> qDatas(constTensorShapeSize); | |||||
| for (size_t j = 0; j < constTensorShapeSize; j++) { | |||||
| float rawData = constData[j]; | |||||
| qDatas[j] = QuantizeData<int8_t>(rawData, quantParam.get()); | |||||
| } | |||||
| ::memcpy(constTensor->data.data(), qDatas.data(), constTensorShapeSize); | |||||
| constTensor->dataType = TypeId::kNumberTypeInt8; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph, mindspore::schema::CNodeT *node) { | |||||
| MS_ASSERT(graph != nullptr); | |||||
| MS_ASSERT(node != nullptr); | |||||
| auto inputIndexes = node->inputIndex; | |||||
| MS_ASSERT(inputIndexes.size() >= 3); | |||||
| MS_ASSERT(graph->allTensors.size() > inputIndexes.at(0)); | |||||
| MS_ASSERT(graph->allTensors.size() > inputIndexes.at(1)); | |||||
| MS_ASSERT(graph->allTensors.size() > inputIndexes.at(2)); | |||||
| auto &biasTensor = graph->allTensors.at(inputIndexes.at(2)); | |||||
| MS_ASSERT(biasTensor != nullptr); | |||||
| if (biasTensor->dataType == TypeId::kNumberTypeInt32) { | |||||
| return RET_OK; | |||||
| } | |||||
| if (biasTensor->dataType != TypeId::kNumberTypeFloat && biasTensor->dataType != TypeId::kNumberTypeFloat32) { | |||||
| MS_LOG(ERROR) << "conv " << node->name << "'s bias data is not float"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto &inputTensor = graph->allTensors.at(inputIndexes.at(0)); | |||||
| auto &weightTensor = graph->allTensors.at(inputIndexes.at(1)); | |||||
| MS_ASSERT(inputTensor != nullptr); | |||||
| MS_ASSERT(weightTensor != nullptr); | |||||
| auto inputScale = inputTensor->quantParams.front()->scale; | |||||
| auto weightScale = weightTensor->quantParams.front()->scale; | |||||
| auto scale = inputScale * weightScale; | |||||
| // set bias quant param | |||||
| std::unique_ptr<QuantParamT> biasQuantParam = GetTensorQuantParam(biasTensor); | |||||
| if (biasQuantParam == nullptr) { | |||||
| MS_LOG(ERROR) << "new QuantParamT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| biasQuantParam->inited = true; | |||||
| biasQuantParam->scale = scale; | |||||
| biasQuantParam->zeroPoint = 0; | |||||
| biasQuantParam->numBits = 8; | |||||
| biasQuantParam->narrowRange = false; | |||||
| biasQuantParam->min = 0.0; | |||||
| biasQuantParam->max = 0.0; | |||||
| // quant bias data | |||||
| auto bShapeSize = GetShapeSize(*(biasTensor.get())); | |||||
| std::unique_ptr<int32_t[]> qDatas(new (std::nothrow) int32_t[bShapeSize]); | |||||
| if (qDatas == nullptr) { | |||||
| MS_LOG(ERROR) << "new qDatas failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| void *biasData = biasTensor->data.data(); | |||||
| auto *rawDatas = static_cast<float *>(biasData); | |||||
| for (size_t i = 0; i < bShapeSize; ++i) { | |||||
| qDatas[i] = (int32_t)std::round(rawDatas[i] / scale); | |||||
| } | |||||
| biasTensor->dataType = TypeId::kNumberTypeInt32; | |||||
| biasTensor->data.clear(); | |||||
| biasTensor->data.resize(bShapeSize * sizeof(int32_t)); | |||||
| auto ret = | |||||
| memcpy_s(biasTensor->data.data(), bShapeSize * sizeof(int32_t), qDatas.get(), bShapeSize * sizeof(int32_t)); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy_s failed: " << ret; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { | |||||
| MS_ASSERT(subGraph != nullptr); | |||||
| MS_ASSERT(node != nullptr); | |||||
| MS_ASSERT(node->quantParam.size() == node->inputIndex.size() + node->outputIndex.size()); | |||||
| auto inputIndexes = node->inputIndex; | |||||
| MS_ASSERT(inputIndexes.size() >= 2); | |||||
| MS_ASSERT(subGraph->allTensors.size() > inputIndexes.at(1)); | |||||
| auto &weightTensor = subGraph->allTensors.at(inputIndexes.at(1)); | |||||
| if (weightTensor->dataType == TypeId::kNumberTypeInt8) { | |||||
| return RET_OK; | |||||
| } | |||||
| if (weightTensor->dataType != TypeId::kNumberTypeFloat32 && weightTensor->dataType != TypeId::kNumberTypeFloat && | |||||
| weightTensor->dataType != TypeId::kNumberTypeUInt8) { | |||||
| MS_LOG(ERROR) << "conv " << node->name.c_str() << "'s weight data is not float or uint8"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| size_t wShapeSize = GetShapeSize(*(weightTensor.get())); | |||||
| void *oriWeightData = weightTensor->data.data(); | |||||
| MS_ASSERT(node->quantParam.at(1)->param.front() != nullptr); | |||||
| vector<int8_t> qDatas(wShapeSize); | |||||
| // todo support perchannel | |||||
| auto weightQauntParam = GetTensorQuantParam(weightTensor); | |||||
| if (weightTensor->dataType == TypeId::kNumberTypeFloat || | |||||
| weightTensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant | |||||
| auto *weightData = static_cast<float *>(oriWeightData); | |||||
| for (size_t j = 0; j < wShapeSize; j++) { | |||||
| qDatas[j] = QuantizeData<int8_t>(weightData[j], weightQauntParam.get()); | |||||
| } | |||||
| } else { // tflite awareing quant | |||||
| auto *weightData = static_cast<uint8_t *>(oriWeightData); | |||||
| for (size_t j = 0; j < wShapeSize; j++) { | |||||
| qDatas[j] = (int32_t)weightData[j] - 128; | |||||
| } else { // pertensor | |||||
| } | } | ||||
| weightQauntParam->zeroPoint -= 128; | |||||
| weightTensor->quantParams.clear(); | |||||
| weightTensor->quantParams.emplace_back(weightQauntParam.release()); | |||||
| } | } | ||||
| weightTensor->data.resize(wShapeSize * sizeof(uint8_t)); | |||||
| ::memcpy(weightTensor->data.data(), qDatas.data(), wShapeSize); | |||||
| weightTensor->dataType = TypeId::kNumberTypeInt8; | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS AwareQuantizer::DetermineNodeQuantType() { | STATUS AwareQuantizer::DetermineNodeQuantType() { | ||||
| @@ -39,24 +39,6 @@ class AwareQuantizer : public FbQuantizer { | |||||
| STATUS DetermineNodeQuantType() override; | STATUS DetermineNodeQuantType() override; | ||||
| STATUS DoQuantize() override; // override; | STATUS DoQuantize() override; // override; | ||||
| private: | |||||
| // RemoveFakeQuant | |||||
| STATUS SetAttrToConvolution(const schema::MetaGraphT *subGraph, schema::CNodeT *node); | |||||
| STATUS GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph); | |||||
| STATUS QuantArithmeticConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node); | |||||
| STATUS QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node); | |||||
| STATUS QuantConvBias(const schema::MetaGraphT *graph, schema::CNodeT *node); | |||||
| STATUS QuantConvWeight(const schema::MetaGraphT *subGraph, schema::CNodeT *node); | |||||
| float inputScale = 0.0f; | |||||
| static const std::array<schema::PrimitiveType, 7> propagatedOps; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite::quant | } // namespace mindspore::lite::quant | ||||
| #endif | #endif | ||||
| @@ -26,6 +26,9 @@ | |||||
| #include "tools/converter/quantizer/quantize_util.h" | #include "tools/converter/quantizer/quantize_util.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| static constexpr size_t BIAS_SIZE = 3; | |||||
| static constexpr size_t BIAS_ADD_SIZE = 2; | |||||
| STATUS QuantParamCalcer::ComputeConstQuantParam(const schema::TensorT &tensor, QuantParamT *quantParam) { | STATUS QuantParamCalcer::ComputeConstQuantParam(const schema::TensorT &tensor, QuantParamT *quantParam) { | ||||
| MS_ASSERT(quantParam != nullptr); | MS_ASSERT(quantParam != nullptr); | ||||
| // int32 weight no need to quant | // int32 weight no need to quant | ||||
| @@ -126,6 +129,36 @@ int CommonCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int ConvCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) { | |||||
| auto status = CommonCalcer::Calc(subGraph, node); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(WARNING) << "Call CommonCalcer::Calc failed: " << status; | |||||
| return status; | |||||
| } | |||||
| if (node.inputIndex.size() == BIAS_SIZE) { | |||||
| auto &biasTensor = subGraph->allTensors.at(node.inputIndex.at(BIAS_SIZE - 1)); | |||||
| for (auto &quantParam : biasTensor->quantParams) { | |||||
| quantParam->dstDtype = TypeId::kNumberTypeInt32; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int BiasAddCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) { | |||||
| auto status = CommonCalcer::Calc(subGraph, node); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(WARNING) << "Call CommonCalcer::Calc failed: " << status; | |||||
| return status; | |||||
| } | |||||
| if (node.inputIndex.size() == BIAS_ADD_SIZE) { | |||||
| auto &biasTensor = subGraph->allTensors.at(node.inputIndex.at(BIAS_ADD_SIZE - 1)); | |||||
| for (auto &quantParam : biasTensor->quantParams) { | |||||
| quantParam->dstDtype = TypeId::kNumberTypeInt32; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { | int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { | ||||
| auto status = QuantParamCalcer::Calc(graph, node); | auto status = QuantParamCalcer::Calc(graph, node); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| @@ -474,10 +507,10 @@ QuantParamCalcRegister::QuantParamCalcRegister() { | |||||
| _registerMap[schema::PrimitiveType_Activation] = std::make_shared<CalcActivation>(); | _registerMap[schema::PrimitiveType_Activation] = std::make_shared<CalcActivation>(); | ||||
| _registerMap[schema::PrimitiveType_Add] = std::make_shared<CalcAdd>(); | _registerMap[schema::PrimitiveType_Add] = std::make_shared<CalcAdd>(); | ||||
| _registerMap[schema::PrimitiveType_Mul] = commonCalcer; | _registerMap[schema::PrimitiveType_Mul] = commonCalcer; | ||||
| _registerMap[schema::PrimitiveType_Scale] = commonCalcer; | |||||
| _registerMap[schema::PrimitiveType_Conv2D] = commonCalcer; | |||||
| _registerMap[schema::PrimitiveType_DeConv2D] = commonCalcer; | |||||
| _registerMap[schema::PrimitiveType_DepthwiseConv2D] = commonCalcer; | |||||
| _registerMap[schema::PrimitiveType_Scale] = std::make_shared<ConvCalcer>(); | |||||
| _registerMap[schema::PrimitiveType_Conv2D] = std::make_shared<ConvCalcer>(); | |||||
| _registerMap[schema::PrimitiveType_DeConv2D] = std::make_shared<ConvCalcer>(); | |||||
| _registerMap[schema::PrimitiveType_DepthwiseConv2D] = std::make_shared<ConvCalcer>(); | |||||
| _registerMap[schema::PrimitiveType_Pooling] = linearCalcer; | _registerMap[schema::PrimitiveType_Pooling] = linearCalcer; | ||||
| _registerMap[schema::PrimitiveType_Resize] = linearCalcer; | _registerMap[schema::PrimitiveType_Resize] = linearCalcer; | ||||
| _registerMap[schema::PrimitiveType_Reshape] = linearCalcer; | _registerMap[schema::PrimitiveType_Reshape] = linearCalcer; | ||||
| @@ -487,11 +520,11 @@ QuantParamCalcRegister::QuantParamCalcRegister() { | |||||
| _registerMap[schema::PrimitiveType_Squeeze] = linearCalcer; | _registerMap[schema::PrimitiveType_Squeeze] = linearCalcer; | ||||
| _registerMap[schema::PrimitiveType_RealDiv] = std::make_shared<CalcRealDiv>(); | _registerMap[schema::PrimitiveType_RealDiv] = std::make_shared<CalcRealDiv>(); | ||||
| _registerMap[schema::PrimitiveType_Reduce] = commonCalcer; | _registerMap[schema::PrimitiveType_Reduce] = commonCalcer; | ||||
| _registerMap[schema::PrimitiveType_BiasAdd] = commonCalcer; | |||||
| _registerMap[schema::PrimitiveType_BiasAdd] = std::make_shared<BiasAddCalcer>(); | |||||
| _registerMap[schema::PrimitiveType_Mean] = linearCalcer; | _registerMap[schema::PrimitiveType_Mean] = linearCalcer; | ||||
| _registerMap[schema::PrimitiveType_Transpose] = linearCalcer; | _registerMap[schema::PrimitiveType_Transpose] = linearCalcer; | ||||
| _registerMap[schema::PrimitiveType_MatMul] = commonCalcer; | |||||
| _registerMap[schema::PrimitiveType_FullConnection] = commonCalcer; | |||||
| _registerMap[schema::PrimitiveType_MatMul] = std::make_shared<ConvCalcer>(); | |||||
| _registerMap[schema::PrimitiveType_FullConnection] = std::make_shared<ConvCalcer>(); | |||||
| _registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer; | _registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer; | ||||
| _registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer; | _registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer; | ||||
| // detection_postprocess op's quant param will not infer only fetch from preNode or postNode | // detection_postprocess op's quant param will not infer only fetch from preNode or postNode | ||||
| @@ -46,6 +46,20 @@ class CommonCalcer : public QuantParamCalcer { | |||||
| int Calc(schema::MetaGraphT *subGraph, const schema::CNodeT &node) override; | int Calc(schema::MetaGraphT *subGraph, const schema::CNodeT &node) override; | ||||
| }; | }; | ||||
// ConvCalcer: quant-param calcer registered for conv-like ops
// (Conv2D/DeConv2D/DepthwiseConv2D/MatMul/FullConnection/Scale).
// Extends CommonCalcer; its Calc additionally tags the bias tensor's
// quant params with an int32 destination dtype.
| class ConvCalcer : public CommonCalcer { | |||||
| public: | |||||
| ConvCalcer() = default; | |||||
| ~ConvCalcer() override = default; | |||||
| int Calc(schema::MetaGraphT *subGraph, const schema::CNodeT &node) override; | |||||
| }; | |||||
// BiasAddCalcer: quant-param calcer registered for BiasAdd.
// Extends CommonCalcer; its Calc additionally tags the bias tensor's
// quant params with an int32 destination dtype.
| class BiasAddCalcer : public CommonCalcer { | |||||
| public: | |||||
| BiasAddCalcer() = default; | |||||
| ~BiasAddCalcer() override = default; | |||||
| int Calc(schema::MetaGraphT *subGraph, const schema::CNodeT &node) override; | |||||
| }; | |||||
| class LinearCalcer : public QuantParamCalcer { | class LinearCalcer : public QuantParamCalcer { | ||||
| public: | public: | ||||
| LinearCalcer() = default; | LinearCalcer() = default; | ||||
| @@ -564,11 +564,8 @@ PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, in | |||||
| } | } | ||||
| } | } | ||||
| STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct MaxMin *max_min, | |||||
| std::shared_ptr<PrimitiveC> lite_primitive) { | |||||
| if (!lite_primitive->GetInputQuantParams().empty()) { | |||||
| MS_LOG(DEBUG) << "input quant params not empty"; // multi-input op: like concat | |||||
| } | |||||
| STATUS PostTrainingQuantizer::DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, | |||||
| std::shared_ptr<PrimitiveC> lite_primitive, const size_t &index) { | |||||
| schema::QuantParamT quant_param; | schema::QuantParamT quant_param; | ||||
| quant_param.scale = scale; | quant_param.scale = scale; | ||||
| quant_param.zeroPoint = zeropoint; | quant_param.zeroPoint = zeropoint; | ||||
| @@ -577,15 +574,12 @@ STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct M | |||||
| quant_param.numBits = bit_num; | quant_param.numBits = bit_num; | ||||
| quant_param.narrowRange = false; | quant_param.narrowRange = false; | ||||
| std::vector<schema::QuantParamT> quant_params = {quant_param}; | std::vector<schema::QuantParamT> quant_params = {quant_param}; | ||||
| lite_primitive->AddInputQuantParam(quant_params); | |||||
| lite_primitive->SetInputQuantParam(index, quant_params); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct MaxMin *max_min, | STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct MaxMin *max_min, | ||||
| std::shared_ptr<PrimitiveC> lite_primitive) { | std::shared_ptr<PrimitiveC> lite_primitive) { | ||||
| if (!lite_primitive->GetOutputQuantParams().empty()) { | |||||
| MS_LOG(DEBUG) << "output quant params not empty"; // multi-output op: like split | |||||
| } | |||||
| schema::QuantParamT quant_param; | schema::QuantParamT quant_param; | ||||
| quant_param.scale = scale; | quant_param.scale = scale; | ||||
| quant_param.zeroPoint = zeropoint; | quant_param.zeroPoint = zeropoint; | ||||
| @@ -593,8 +587,9 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct | |||||
| quant_param.min = max_min->min; | quant_param.min = max_min->min; | ||||
| quant_param.numBits = bit_num; | quant_param.numBits = bit_num; | ||||
| quant_param.narrowRange = false; | quant_param.narrowRange = false; | ||||
| quant_param.inited = true; | |||||
| std::vector<schema::QuantParamT> quant_params = {quant_param}; | std::vector<schema::QuantParamT> quant_params = {quant_param}; | ||||
| lite_primitive->AddOutputQuantParam(quant_params); | |||||
| lite_primitive->SetOutputQuantParam(0, quant_params); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -647,7 +642,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi | |||||
| auto bias_param = std::dynamic_pointer_cast<ParamValueLite>(bias_default_param); | auto bias_param = std::dynamic_pointer_cast<ParamValueLite>(bias_default_param); | ||||
| auto active_weight_quant_params = primitive_c->GetInputQuantParams(); | auto active_weight_quant_params = primitive_c->GetInputQuantParams(); | ||||
| if (active_weight_quant_params.size() != 2) { | |||||
| if (active_weight_quant_params.size() != 3) { | |||||
| MS_LOG(ERROR) << "unexpected active_weight_quant_params size: " << active_weight_quant_params.size(); | MS_LOG(ERROR) << "unexpected active_weight_quant_params size: " << active_weight_quant_params.size(); | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -714,7 +709,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi | |||||
| double filter_scale = std::abs(raw_datas[i]) / (activate_scale * quanted_bias_abs_limit); | double filter_scale = std::abs(raw_datas[i]) / (activate_scale * quanted_bias_abs_limit); | ||||
| active_weight_quant_params[1][i].scale = filter_scale; | active_weight_quant_params[1][i].scale = filter_scale; | ||||
| active_weight_quant_params[1][i].zeroPoint = 0; | active_weight_quant_params[1][i].zeroPoint = 0; | ||||
| primitive_c->SetInputQuantParam(active_weight_quant_params); | |||||
| primitive_c->SetInputQuantParams(active_weight_quant_params); | |||||
| bias_scale_tmp = std::abs(raw_datas[i]) / quanted_bias_abs_limit; | bias_scale_tmp = std::abs(raw_datas[i]) / quanted_bias_abs_limit; | ||||
| quant_params[i].scale = bias_scale_tmp; | quant_params[i].scale = bias_scale_tmp; | ||||
| MS_LOG(DEBUG) << "new filter scale: " << filter_scale; | MS_LOG(DEBUG) << "new filter scale: " << filter_scale; | ||||
| @@ -726,7 +721,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi | |||||
| auto quant_data = (int32_t)std::round(raw_datas[i] / bias_scale_tmp); | auto quant_data = (int32_t)std::round(raw_datas[i] / bias_scale_tmp); | ||||
| quant_datas[i] = quant_data; | quant_datas[i] = quant_data; | ||||
| } | } | ||||
| primitive_c->AddInputQuantParam(quant_params); | |||||
| primitive_c->SetInputQuantParam(2, quant_params); | |||||
| auto ret = memcpy_s(bias_param->tensor_addr(), bias_param->tensor_size(), quant_datas, shape_size * sizeof(int32_t)); | auto ret = memcpy_s(bias_param->tensor_addr(), bias_param->tensor_size(), quant_datas, shape_size * sizeof(int32_t)); | ||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| MS_LOG(ERROR) << "memcpy_s failed."; | MS_LOG(ERROR) << "memcpy_s failed."; | ||||
| @@ -834,22 +829,21 @@ STATUS PostTrainingQuantizer::QuantNode() { | |||||
| << " PrimitiveC is null"; | << " PrimitiveC is null"; | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (!input_cnode_primitive_c->GetOutputQuantParams().empty()) { | |||||
| for (auto &quant_param : input_cnode_primitive_c->GetOutputQuantParams()) { | |||||
| primitive_c->AddInputQuantParam(quant_param); | |||||
| } | |||||
| if (input_cnode_primitive_c->IsOutputQuantParamsInited()) { | |||||
| auto quant_param = input_cnode_primitive_c->GetOutputQuantParams().front(); | |||||
| primitive_c->SetInputQuantParam(i - 1, quant_param); | |||||
| } else { | } else { | ||||
| // do input quant | // do input quant | ||||
| double scale = input_scale[cnode]; | double scale = input_scale[cnode]; | ||||
| int32_t zp = input_zero_point[cnode]; | int32_t zp = input_zero_point[cnode]; | ||||
| DoQuantInput(scale, zp, &input_min_max[cnode], primitive_c); | |||||
| DoQuantInput(scale, zp, &input_min_max[cnode], primitive_c, i - 1); | |||||
| } | } | ||||
| } | } | ||||
| } else { | } else { | ||||
| // do input quant | // do input quant | ||||
| double scale = input_scale[cnode]; | double scale = input_scale[cnode]; | ||||
| int32_t convInputzeropoint = input_zero_point[cnode]; | int32_t convInputzeropoint = input_zero_point[cnode]; | ||||
| DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitive_c); | |||||
| DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitive_c, 0); | |||||
| // do weight quant | // do weight quant | ||||
| auto weight = cnode->input(2); | auto weight = cnode->input(2); | ||||
| bool perchannel = per_channel_; | bool perchannel = per_channel_; | ||||
| @@ -106,7 +106,8 @@ class PostTrainingQuantizer : public Quantizer { | |||||
| STATUS QuantNode(); | STATUS QuantNode(); | ||||
| STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>); | |||||
| STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, | |||||
| std::shared_ptr<PrimitiveC> lite_primitive, const size_t &index); | |||||
| STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>); | STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>); | ||||
| STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel); | STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel); | ||||
| @@ -246,7 +246,7 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl | |||||
| MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other"; | MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| quantParam->inited = true; | |||||
| quantParam->inited = false; | |||||
| quantParam->min = mMin; | quantParam->min = mMin; | ||||
| quantParam->max = mMax; | quantParam->max = mMax; | ||||
| quantParam->scale = 0.0f; | quantParam->scale = 0.0f; | ||||
| @@ -39,6 +39,7 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| namespace quant { | namespace quant { | ||||
| static constexpr size_t UINT8_QUANTIZATION = 8; | static constexpr size_t UINT8_QUANTIZATION = 8; | ||||
| static constexpr size_t WEIGHT_INDEX = 1; | |||||
| /** | /** | ||||
| * 1. when op's weight size > mWeightSize just skip | * 1. when op's weight size > mWeightSize just skip | ||||
| @@ -225,16 +226,16 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti | |||||
| } | } | ||||
| variance_dequant = std::sqrt(variance_dequant / one_filter_size); | variance_dequant = std::sqrt(variance_dequant / one_filter_size); | ||||
| variance_raw = std::sqrt(variance_raw / one_filter_size); | variance_raw = std::sqrt(variance_raw / one_filter_size); | ||||
| quant_param.var_corr = 1; | |||||
| quant_param.varCorr = 1; | |||||
| if (variance_raw != 0 && variance_dequant != 0) { | if (variance_raw != 0 && variance_dequant != 0) { | ||||
| auto temp_var_corr = variance_raw / variance_dequant; | auto temp_var_corr = variance_raw / variance_dequant; | ||||
| if (temp_var_corr > 0 && temp_var_corr < 10) { | if (temp_var_corr > 0 && temp_var_corr < 10) { | ||||
| quant_param.var_corr = temp_var_corr; | |||||
| quant_param.varCorr = temp_var_corr; | |||||
| } else { | } else { | ||||
| MS_LOG(WARNING) << "unexpected var_corr: " << temp_var_corr; | MS_LOG(WARNING) << "unexpected var_corr: " << temp_var_corr; | ||||
| } | } | ||||
| } | } | ||||
| quant_param.mean_corr = average_raw - average_dequant * quant_param.var_corr; | |||||
| quant_param.meanCorr = average_raw - average_dequant * quant_param.varCorr; | |||||
| } | } | ||||
| quant_params.emplace_back(quant_param); | quant_params.emplace_back(quant_param); | ||||
| } | } | ||||
| @@ -282,7 +283,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti | |||||
| MS_LOG(ERROR) << "quant_params empty"; | MS_LOG(ERROR) << "quant_params empty"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| primitive_c->AddInputQuantParam(quant_params); | |||||
| primitive_c->SetInputQuantParam(WEIGHT_INDEX, quant_params); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -101,8 +101,6 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| std::vector<schema::QuantParamT> quant_params; | |||||
| primitive_c->AddInputQuantParam(quant_params); | |||||
| auto status = RET_ERROR; | auto status = RET_ERROR; | ||||
| if (type_id == kNumberTypeInt8) { | if (type_id == kNumberTypeInt8) { | ||||
| status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); | status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); | ||||
| @@ -143,9 +141,9 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) { | |||||
| ParameterPtr param_node = nullptr; | ParameterPtr param_node = nullptr; | ||||
| for (size_t i = 1; i < node->size(); i++) { | for (size_t i = 1; i < node->size(); i++) { | ||||
| auto inputNode = node->input(i); | auto inputNode = node->input(i); | ||||
| if (inputNode->isa<Parameter>() == true) { | |||||
| if (inputNode->isa<Parameter>()) { | |||||
| param_node = inputNode->cast<ParameterPtr>(); | param_node = inputNode->cast<ParameterPtr>(); | ||||
| if ((param_node != nullptr) && (param_node->has_default() == true)) { | |||||
| if ((param_node != nullptr) && param_node->has_default()) { | |||||
| param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param()); | param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param()); | ||||
| if ((param_value == nullptr) || (param_value->tensor_size() == 0) || | if ((param_value == nullptr) || (param_value->tensor_size() == 0) || | ||||
| (param_value->tensor_addr() == nullptr) || | (param_value->tensor_addr() == nullptr) || | ||||
| @@ -169,8 +167,6 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| std::vector<schema::QuantParamT> quant_params; | |||||
| primitive_c->AddInputQuantParam(quant_params); | |||||
| auto status = RET_ERROR; | auto status = RET_ERROR; | ||||
| if (type_id == kNumberTypeInt8) { | if (type_id == kNumberTypeInt8) { | ||||
| status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); | status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); | ||||
| @@ -619,7 +619,7 @@ STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int3 | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| template<typename T> | |||||
| template <typename T> | |||||
| static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, | static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, | ||||
| int32_t filterH, int32_t filterW) { | int32_t filterH, int32_t filterW) { | ||||
| MS_ASSERT(tensor != nullptr); | MS_ASSERT(tensor != nullptr); | ||||
| @@ -628,7 +628,7 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType | |||||
| MS_LOG(ERROR) << "Dim size invalid"; | MS_LOG(ERROR) << "Dim size invalid"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| std::unique_ptr<T[]> buf(new(std::nothrow) T[count]); | |||||
| std::unique_ptr<T[]> buf(new (std::nothrow) T[count]); | |||||
| if (buf == nullptr) { | if (buf == nullptr) { | ||||
| MS_LOG(ERROR) << "new buf failed"; | MS_LOG(ERROR) << "new buf failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -653,18 +653,17 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType | |||||
| p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k)); | p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k)); | ||||
| if (type == kCHWK2HWCK) { | if (type == kCHWK2HWCK) { | ||||
| p2Buff = | p2Buff = | ||||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||||
| } else if (type == kCHWK2KHWC) { | } else if (type == kCHWK2KHWC) { | ||||
| p2Buff = | p2Buff = | ||||
| buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||||
| buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||||
| } | } | ||||
| *p2Buff = *p1Buff; | *p2Buff = *p1Buff; | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| break; | |||||
| } break; | |||||
| case kKHWC2HWCK: { | case kKHWC2HWCK: { | ||||
| for (int k = 0; k < filterK; ++k) { | for (int k = 0; k < filterK; ++k) { | ||||
| for (int h = 0; h < filterH; ++h) { | for (int h = 0; h < filterH; ++h) { | ||||
| @@ -677,8 +676,7 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| break; | |||||
| } break; | |||||
| case kKCHW2HWCK: | case kKCHW2HWCK: | ||||
| case kKCHW2CKHW: | case kKCHW2CKHW: | ||||
| case kKCHW2KHWC: | case kKCHW2KHWC: | ||||
| @@ -690,24 +688,23 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType | |||||
| p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | ||||
| if (type == kKCHW2HWCK) { | if (type == kKCHW2HWCK) { | ||||
| p2Buff = | p2Buff = | ||||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||||
| } else if (type == kKCHW2KHWC) { | } else if (type == kKCHW2KHWC) { | ||||
| p2Buff = | p2Buff = | ||||
| buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||||
| buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||||
| } else if (type == kKCHW2CKHW) { | } else if (type == kKCHW2CKHW) { | ||||
| p2Buff = | p2Buff = | ||||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||||
| } else { | } else { | ||||
| p2Buff = | p2Buff = | ||||
| buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); | |||||
| buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); | |||||
| } | } | ||||
| *p2Buff = *p1Buff; | *p2Buff = *p1Buff; | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| break; | |||||
| } break; | |||||
| case kCKHW2HWCK: | case kCKHW2HWCK: | ||||
| case kCKHW2KHWC: | case kCKHW2KHWC: | ||||
| case kCKHW2HWKC: { | case kCKHW2HWKC: { | ||||
| @@ -718,21 +715,20 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType | |||||
| p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | ||||
| if (type == kCKHW2HWCK) { | if (type == kCKHW2HWCK) { | ||||
| p2Buff = | p2Buff = | ||||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||||
| } else if (type == kCKHW2KHWC) { | } else if (type == kCKHW2KHWC) { | ||||
| p2Buff = | p2Buff = | ||||
| buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||||
| buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||||
| } else { | } else { | ||||
| p2Buff = | p2Buff = | ||||
| buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); | |||||
| buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); | |||||
| } | } | ||||
| *p2Buff = *p1Buff; | *p2Buff = *p1Buff; | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| break; | |||||
| } break; | |||||
| case kHWCK2KCHW: | case kHWCK2KCHW: | ||||
| case kHWCK2CKHW: { | case kHWCK2CKHW: { | ||||
| for (int h = 0; h < filterH; ++h) { | for (int h = 0; h < filterH; ++h) { | ||||
| @@ -742,18 +738,17 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType | |||||
| p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | ||||
| if (type == kHWCK2KCHW) { | if (type == kHWCK2KCHW) { | ||||
| p2Buff = | p2Buff = | ||||
| buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||||
| buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||||
| } else { | } else { | ||||
| p2Buff = | p2Buff = | ||||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||||
| } | } | ||||
| *p2Buff = *p1Buff; | *p2Buff = *p1Buff; | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| break; | |||||
| } break; | |||||
| case kHWKC2KCHW: | case kHWKC2KCHW: | ||||
| case kHWKC2CKHW: { | case kHWKC2CKHW: { | ||||
| for (int h = 0; h < filterH; ++h) { | for (int h = 0; h < filterH; ++h) { | ||||
| @@ -763,18 +758,17 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType | |||||
| p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); | p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); | ||||
| if (type == kHWKC2KCHW) { | if (type == kHWKC2KCHW) { | ||||
| p2Buff = | p2Buff = | ||||
| buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||||
| buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||||
| } else { | } else { | ||||
| p2Buff = | p2Buff = | ||||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||||
| } | } | ||||
| *p2Buff = *p1Buff; | *p2Buff = *p1Buff; | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| break; | |||||
| } break; | |||||
| case kNHWC2HWCK: | case kNHWC2HWCK: | ||||
| case kNHWC2KCHW: | case kNHWC2KCHW: | ||||
| case kNHWC2CKHW: { | case kNHWC2CKHW: { | ||||
| @@ -785,21 +779,20 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType | |||||
| p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); | p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); | ||||
| if (type == kNHWC2HWCK) { | if (type == kNHWC2HWCK) { | ||||
| p2Buff = | p2Buff = | ||||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||||
| } else if (type == kNHWC2CKHW) { | } else if (type == kNHWC2CKHW) { | ||||
| p2Buff = | p2Buff = | ||||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||||
| } else { | } else { | ||||
| p2Buff = | p2Buff = | ||||
| buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||||
| buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||||
| } | } | ||||
| *p2Buff = *p1Buff; | *p2Buff = *p1Buff; | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| break; | |||||
| } break; | |||||
| case kKHWC2CHWK: { | case kKHWC2CHWK: { | ||||
| for (int k = 0; k < filterK; ++k) { | for (int k = 0; k < filterK; ++k) { | ||||
| for (int h = 0; h < filterH; ++h) { | for (int h = 0; h < filterH; ++h) { | ||||
| @@ -812,8 +805,7 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| break; | |||||
| } break; | |||||
| default: { | default: { | ||||
| MS_LOG(ERROR) << "Unsupported transFilterType: " << type; | MS_LOG(ERROR) << "Unsupported transFilterType: " << type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -828,7 +820,7 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| template<typename T> | |||||
| template <typename T> | |||||
| static STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type) { | static STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type) { | ||||
| MS_ASSERT(tensor != nullptr); | MS_ASSERT(tensor != nullptr); | ||||
| auto oriDims = tensor->tensor_shape(); | auto oriDims = tensor->tensor_shape(); | ||||
| @@ -882,6 +874,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||||
| status = TransFilterFormat<uint8_t>(tensor, kKCHW2KHWC); | status = TransFilterFormat<uint8_t>(tensor, kKCHW2KHWC); | ||||
| } else if (data_type == kNumberTypeInt8) { | } else if (data_type == kNumberTypeInt8) { | ||||
| status = TransFilterFormat<int8_t>(tensor, kKCHW2KHWC); | status = TransFilterFormat<int8_t>(tensor, kKCHW2KHWC); | ||||
| } else if (data_type == kNumberTypeFloat16) { | |||||
| status = TransFilterFormat<float16>(tensor, kKCHW2KHWC); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -894,6 +888,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||||
| status = TransFilterFormat<uint8_t>(tensor, kCKHW2KHWC); | status = TransFilterFormat<uint8_t>(tensor, kCKHW2KHWC); | ||||
| } else if (data_type == kNumberTypeInt8) { | } else if (data_type == kNumberTypeInt8) { | ||||
| status = TransFilterFormat<int8_t>(tensor, kCKHW2KHWC); | status = TransFilterFormat<int8_t>(tensor, kCKHW2KHWC); | ||||
| } else if (data_type == kNumberTypeFloat16) { | |||||
| status = TransFilterFormat<float16>(tensor, kCKHW2KHWC); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -906,18 +902,20 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||||
| status = TransFilterFormat<uint8_t>(tensor, kCHWK2KHWC); | status = TransFilterFormat<uint8_t>(tensor, kCHWK2KHWC); | ||||
| } else if (data_type == kNumberTypeInt8) { | } else if (data_type == kNumberTypeInt8) { | ||||
| status = TransFilterFormat<int8_t>(tensor, kCHWK2KHWC); | status = TransFilterFormat<int8_t>(tensor, kCHWK2KHWC); | ||||
| } else if (data_type == kNumberTypeFloat16) { | |||||
| status = TransFilterFormat<float16>(tensor, kCHWK2KHWC); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| break; | break; | ||||
| case schema::Format::Format_KHWC:return RET_OK; | |||||
| default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " | |||||
| << EnumNameFormat(dst_format); | |||||
| case schema::Format::Format_KHWC: | |||||
| return RET_OK; | |||||
| default: | |||||
| MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format); | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } | |||||
| break; | |||||
| } break; | |||||
| case schema::Format::Format_HWCK: { | case schema::Format::Format_HWCK: { | ||||
| switch (src_format) { | switch (src_format) { | ||||
| case schema::Format::Format_KCHW: | case schema::Format::Format_KCHW: | ||||
| @@ -927,6 +925,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||||
| status = TransFilterFormat<uint8_t>(tensor, kKCHW2HWCK); | status = TransFilterFormat<uint8_t>(tensor, kKCHW2HWCK); | ||||
| } else if (data_type == kNumberTypeInt8) { | } else if (data_type == kNumberTypeInt8) { | ||||
| status = TransFilterFormat<int8_t>(tensor, kKCHW2HWCK); | status = TransFilterFormat<int8_t>(tensor, kKCHW2HWCK); | ||||
| } else if (data_type == kNumberTypeFloat16) { | |||||
| status = TransFilterFormat<float16>(tensor, kKCHW2HWCK); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -939,6 +939,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||||
| status = TransFilterFormat<uint8_t>(tensor, kKHWC2HWCK); | status = TransFilterFormat<uint8_t>(tensor, kKHWC2HWCK); | ||||
| } else if (data_type == kNumberTypeInt8) { | } else if (data_type == kNumberTypeInt8) { | ||||
| status = TransFilterFormat<int8_t>(tensor, kKHWC2HWCK); | status = TransFilterFormat<int8_t>(tensor, kKHWC2HWCK); | ||||
| } else if (data_type == kNumberTypeFloat16) { | |||||
| status = TransFilterFormat<float16>(tensor, kKHWC2HWCK); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -951,6 +953,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||||
| status = TransFilterFormat<uint8_t>(tensor, kCKHW2HWCK); | status = TransFilterFormat<uint8_t>(tensor, kCKHW2HWCK); | ||||
| } else if (data_type == kNumberTypeInt8) { | } else if (data_type == kNumberTypeInt8) { | ||||
| status = TransFilterFormat<int8_t>(tensor, kCKHW2HWCK); | status = TransFilterFormat<int8_t>(tensor, kCKHW2HWCK); | ||||
| } else if (data_type == kNumberTypeFloat16) { | |||||
| status = TransFilterFormat<float16>(tensor, kCKHW2HWCK); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -963,21 +967,24 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||||
| status = TransFilterFormat<uint8_t>(tensor, kCHWK2HWCK); | status = TransFilterFormat<uint8_t>(tensor, kCHWK2HWCK); | ||||
| } else if (data_type == kNumberTypeInt8) { | } else if (data_type == kNumberTypeInt8) { | ||||
| status = TransFilterFormat<int8_t>(tensor, kCHWK2HWCK); | status = TransFilterFormat<int8_t>(tensor, kCHWK2HWCK); | ||||
| } else if (data_type == kNumberTypeFloat16) { | |||||
| status = TransFilterFormat<float16>(tensor, kCHWK2HWCK); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | ||||
| return lite::RET_ERROR; | return lite::RET_ERROR; | ||||
| } | } | ||||
| break; | break; | ||||
| case schema::Format::Format_HWCK:return RET_OK; | |||||
| default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " | |||||
| << EnumNameFormat(dst_format); | |||||
| case schema::Format::Format_HWCK: | |||||
| return RET_OK; | |||||
| default: | |||||
| MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format); | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } | |||||
| break; | |||||
| } break; | |||||
| case schema::Format::Format_KCHW: { | case schema::Format::Format_KCHW: { | ||||
| switch (src_format) { | switch (src_format) { | ||||
| case schema::Format::Format_KCHW:return RET_OK; | |||||
| case schema::Format::Format_KCHW: | |||||
| return RET_OK; | |||||
| case schema::Format::Format_HWCK: | case schema::Format::Format_HWCK: | ||||
| if (data_type == kNumberTypeFloat32) { | if (data_type == kNumberTypeFloat32) { | ||||
| status = TransFilterFormat<float>(tensor, kHWCK2KCHW); | status = TransFilterFormat<float>(tensor, kHWCK2KCHW); | ||||
| @@ -985,6 +992,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||||
| status = TransFilterFormat<uint8_t>(tensor, kHWCK2KCHW); | status = TransFilterFormat<uint8_t>(tensor, kHWCK2KCHW); | ||||
| } else if (data_type == kNumberTypeInt8) { | } else if (data_type == kNumberTypeInt8) { | ||||
| status = TransFilterFormat<int8_t>(tensor, kHWCK2KCHW); | status = TransFilterFormat<int8_t>(tensor, kHWCK2KCHW); | ||||
| } else if (data_type == kNumberTypeFloat16) { | |||||
| status = TransFilterFormat<float16>(tensor, kHWCK2KCHW); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -997,6 +1006,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||||
| status = TransFilterFormat<uint8_t>(tensor, kHWKC2KCHW); | status = TransFilterFormat<uint8_t>(tensor, kHWKC2KCHW); | ||||
| } else if (data_type == kNumberTypeInt8) { | } else if (data_type == kNumberTypeInt8) { | ||||
| status = TransFilterFormat<int8_t>(tensor, kHWKC2KCHW); | status = TransFilterFormat<int8_t>(tensor, kHWKC2KCHW); | ||||
| } else if (data_type == kNumberTypeFloat16) { | |||||
| status = TransFilterFormat<float16>(tensor, kHWCK2KCHW); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -1009,6 +1020,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||||
| status = TransFilterFormat<uint8_t>(tensor, kKHWC2KCHW); | status = TransFilterFormat<uint8_t>(tensor, kKHWC2KCHW); | ||||
| } else if (data_type == kNumberTypeInt8) { | } else if (data_type == kNumberTypeInt8) { | ||||
| status = TransFilterFormat<int8_t>(tensor, kKHWC2KCHW); | status = TransFilterFormat<int8_t>(tensor, kKHWC2KCHW); | ||||
| } else if (data_type == kNumberTypeFloat16) { | |||||
| status = TransFilterFormat<float16>(tensor, kKHWC2KCHW); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -1021,6 +1034,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||||
| status = TransFilterFormat<uint8_t>(tensor, kCKHW2KCHW); | status = TransFilterFormat<uint8_t>(tensor, kCKHW2KCHW); | ||||
| } else if (data_type == kNumberTypeInt8) { | } else if (data_type == kNumberTypeInt8) { | ||||
| status = TransFilterFormat<int8_t>(tensor, kCKHW2KCHW); | status = TransFilterFormat<int8_t>(tensor, kCKHW2KCHW); | ||||
| } else if (data_type == kNumberTypeFloat16) { | |||||
| status = TransFilterFormat<float16>(tensor, kCKHW2KCHW); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -1033,17 +1048,18 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||||
| status = TransFilterFormat<uint8_t>(tensor, kCHWK2KCHW); | status = TransFilterFormat<uint8_t>(tensor, kCHWK2KCHW); | ||||
| } else if (data_type == kNumberTypeInt8) { | } else if (data_type == kNumberTypeInt8) { | ||||
| status = TransFilterFormat<int8_t>(tensor, kCHWK2KCHW); | status = TransFilterFormat<int8_t>(tensor, kCHWK2KCHW); | ||||
| } else if (data_type == kNumberTypeFloat16) { | |||||
| status = TransFilterFormat<float16>(tensor, kCKHW2KCHW); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| break; | break; | ||||
| default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " | |||||
| << EnumNameFormat(dst_format); | |||||
| default: | |||||
| MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format); | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } | |||||
| break; | |||||
| } break; | |||||
| case schema::Format::Format_CKHW: { | case schema::Format::Format_CKHW: { | ||||
| switch (src_format) { | switch (src_format) { | ||||
| case schema::Format::Format_HWCK: | case schema::Format::Format_HWCK: | ||||
| @@ -1053,6 +1069,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||||
| status = TransFilterFormat<uint8_t>(tensor, kHWCK2CKHW); | status = TransFilterFormat<uint8_t>(tensor, kHWCK2CKHW); | ||||
| } else if (data_type == kNumberTypeInt8) { | } else if (data_type == kNumberTypeInt8) { | ||||
| status = TransFilterFormat<int8_t>(tensor, kHWCK2CKHW); | status = TransFilterFormat<int8_t>(tensor, kHWCK2CKHW); | ||||
| } else if (data_type == kNumberTypeFloat16) { | |||||
| status = TransFilterFormat<float16>(tensor, kHWCK2CKHW); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -1065,6 +1083,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||||
| status = TransFilterFormat<uint8_t>(tensor, kHWKC2CKHW); | status = TransFilterFormat<uint8_t>(tensor, kHWKC2CKHW); | ||||
| } else if (data_type == kNumberTypeInt8) { | } else if (data_type == kNumberTypeInt8) { | ||||
| status = TransFilterFormat<int8_t>(tensor, kHWKC2CKHW); | status = TransFilterFormat<int8_t>(tensor, kHWKC2CKHW); | ||||
| } else if (data_type == kNumberTypeFloat16) { | |||||
| status = TransFilterFormat<float16>(tensor, kHWKC2CKHW); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -1077,20 +1097,22 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for | |||||
| status = TransFilterFormat<uint8_t>(tensor, kKCHW2CKHW); | status = TransFilterFormat<uint8_t>(tensor, kKCHW2CKHW); | ||||
| } else if (data_type == kNumberTypeInt8) { | } else if (data_type == kNumberTypeInt8) { | ||||
| status = TransFilterFormat<int8_t>(tensor, kKCHW2CKHW); | status = TransFilterFormat<int8_t>(tensor, kKCHW2CKHW); | ||||
| } else if (data_type == kNumberTypeFloat16) { | |||||
| status = TransFilterFormat<float16>(tensor, kKCHW2CKHW); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| break; | break; | ||||
| case schema::Format::Format_CKHW:return RET_OK; | |||||
| default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " | |||||
| << EnumNameFormat(dst_format); | |||||
| case schema::Format::Format_CKHW: | |||||
| return RET_OK; | |||||
| default: | |||||
| MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format); | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } | |||||
| break; | |||||
| default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " | |||||
| << EnumNameFormat(dst_format); | |||||
| } break; | |||||
| default: | |||||
| MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format); | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| @@ -155,8 +155,8 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons | |||||
| rmatmul_quant_params.pop_back(); | rmatmul_quant_params.pop_back(); | ||||
| // no bias quantParams | // no bias quantParams | ||||
| rmatmul_quant_params.emplace_back(jointed_quant_params); | rmatmul_quant_params.emplace_back(jointed_quant_params); | ||||
| matmul_cvalue->SetInputQuantParam(rmatmul_quant_params); | |||||
| matmul_cvalue->SetOutputQuantParam(fc_prim->GetOutputQuantParams()); | |||||
| matmul_cvalue->SetInputQuantParams(rmatmul_quant_params); | |||||
| matmul_cvalue->SetOutputQuantParams(fc_prim->GetOutputQuantParams()); | |||||
| auto matmul_value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(matmul_cvalue)); | auto matmul_value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(matmul_cvalue)); | ||||
| std::vector<AnfNodePtr> matmul_inputs = {matmul_value_node, left_matmul_input}; | std::vector<AnfNodePtr> matmul_inputs = {matmul_value_node, left_matmul_input}; | ||||