@@ -51,7 +51,7 @@ void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_i
 }

 std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const std::vector<Tensor *> &in_tensors,
-                                                                         TypeId data_type) {
+                                                                         TypeId data_type, bool need_restore) {
   std::map<Tensor *, std::pair<TypeId, void *>> tensor_origin_data;
   if (data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat16) {
     for (auto weight_tensor : in_tensors) {
@@ -59,16 +59,21 @@ std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const s
       auto *restore_data = weight_tensor->data_c();
       auto restore_type = weight_tensor->data_type();
       bool dequant_flag = !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited &&
-                          restore_data != nullptr;
+                          restore_data != nullptr &&
+                          (restore_type == kNumberTypeInt8 || restore_type == kNumberTypeInt16);
       if (dequant_flag) {
         auto *dequant_weight = DequantUtil::DequantWeight(weight_tensor);
         if (dequant_weight == nullptr) {
           MS_LOG(ERROR) << "dequant data is nullptr.";
           return tensor_origin_data;
         }
+        if (need_restore) {
+          tensor_origin_data[weight_tensor] = {restore_type, restore_data};
+        } else {
+          weight_tensor->FreeData();
+        }
         weight_tensor->set_data(dequant_weight);
         weight_tensor->set_data_type(kNumberTypeFloat32);
-        tensor_origin_data[weight_tensor] = {restore_type, restore_data};
       }
     }
   }
@@ -34,7 +34,7 @@ class DequantUtil {
   static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data);

   static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const std::vector<Tensor *> &in_tensors,
-                                                                     TypeId data_type);
+                                                                     TypeId data_type, bool need_restore = true);

   static void RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map);
@@ -79,7 +79,7 @@ class DequantUtil {
       auto var_corr = param.var_corr;
       auto mean_corr = param.mean_corr;
       if (var_corr < 0 || var_corr > 10) {
-        MS_LOG(WARNING) << "unexpeted var_corr: " << var_corr;
+        MS_LOG(WARNING) << "unexpected var_corr: " << var_corr;
         var_corr = 1;
       }
       for (size_t j = 0; j < per_channel_size; j++) {
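
[Annotation, not part of the diff] The var_corr/mean_corr guard above sits in front of the per-channel dequant loop. A minimal standalone sketch of that loop, assuming the usual affine dequant mapping with the correction terms applied on top (the exact expression lives in DequantUtil::DequantWeight):

#include <cstddef>
#include <cstdint>
#include <vector>

// Per-channel dequant sketch: affine mapping plus var/mean correction, with the
// same sanity clamp on var_corr that the hunk above warns about.
std::vector<float> DequantChannel(const std::vector<int8_t> &quant, float scale, int zero_point,
                                  float var_corr, float mean_corr) {
  if (var_corr < 0 || var_corr > 10) {
    var_corr = 1;  // unexpected correction factor, fall back to identity
  }
  std::vector<float> dequant(quant.size());
  for (size_t j = 0; j < quant.size(); ++j) {
    dequant[j] = (quant[j] - zero_point) * scale * var_corr + mean_corr;
  }
  return dequant;
}
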
@@ -38,10 +38,6 @@
 namespace mindspore {
 namespace lite {
-static std::vector<schema::PrimitiveType> packed_op = {
-  schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D, schema::PrimitiveType_DepthwiseConv2D,
-  schema::PrimitiveType_DeDepthwiseConv2D, schema::PrimitiveType_MatMul};
-
 // this method will not check whether tensor_idx is a weight tensor index, caller should ensure this.
 static bool WeightTensorNeedCopy(const lite::Model *model, const uint32_t tensor_idx) {
 #ifdef SUPPORT_TRAIN
@@ -92,8 +88,13 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
                                     lite::Tensor *dst_tensor) {
   MS_ASSERT(src_tensor != nullptr);
   MS_ASSERT(dst_tensor != nullptr);
+  auto NeedUnPack = [&src_tensor, &dst_tensor]() -> bool {
+    auto data_type = src_tensor->dataType();
+    int pack_size = src_tensor->data()->size();
+    int org_size = dst_tensor->Size();
+    return (pack_size != org_size) && (data_type == kNumberTypeInt8 || data_type == kNumberTypeInt16);
+  };
   auto src_category = TensorCategory(src_tensor);
-  auto data_type = src_tensor->dataType();
   if ((src_category == Tensor::Category::CONST_TENSOR || src_category == Tensor::Category::CONST_SCALAR) &&
       src_tensor->data() != nullptr && src_tensor->data()->size() > 0) {
     if (src_tensor->dataType() == kObjectTypeTensorType) {
@@ -112,18 +113,20 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
         MS_LOG(ERROR) << "Data from tensor is nullptr";
         return RET_NULL_PTR;
       }
-      memcpy(dst_data, src_tensor->data()->data(), dst_tensor->Size());
+      if (NeedUnPack()) {
+        DequantUtil::UnPackToInt(src_tensor, dst_data);
+      } else {
+        memcpy(dst_data, src_tensor->data()->data(), dst_tensor->Size());
+      }
       copyed_tensor_idxes_.emplace_back(tensor_index);
     } else {
-      int pack_size = src_tensor->data()->size();
-      int org_size = dst_tensor->Size();
-      if (pack_size != org_size && (data_type == kNumberTypeInt8 || data_type == kNumberTypeInt16)) {
-        auto ret = dst_tensor->MallocData();
-        if (ret != RET_OK) {
-          MS_LOG(ERROR) << "Malloc data for tensor failed ";
-          return RET_ERROR;
+      if (NeedUnPack()) {
+        auto dst_data = dst_tensor->MutableData();
+        if (dst_data == nullptr) {
+          MS_LOG(ERROR) << "Data from tensor is nullptr";
+          return RET_NULL_PTR;
         }
-        DequantUtil::UnPackToInt(src_tensor, dst_tensor->MutableData());
+        DequantUtil::UnPackToInt(src_tensor, dst_data);
         copyed_tensor_idxes_.emplace_back(tensor_index);
       } else {
         dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data()));
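
[Annotation, not part of the diff] The NeedUnPack lambda centralizes the check that previously lived only in the non-copy branch: int8/int16 weights whose serialized size differs from the destination tensor's size are bit-packed and must go through UnPackToInt rather than a raw memcpy. A self-contained sketch of the idea; the 4-bit expansion below is purely illustrative and not the framework's actual packing scheme:

#include <cstddef>
#include <cstdint>
#include <vector>

// Decision mirrored from the NeedUnPack lambda: bit-packed int8/int16 weights
// occupy fewer bytes in the flatbuffer than the unpacked tensor expects, so a
// size mismatch on a quantized type means the data must be expanded.
bool NeedsUnpack(size_t serialized_bytes, size_t tensor_bytes, bool is_int8_or_int16) {
  return is_int8_or_int16 && serialized_bytes != tensor_bytes;
}

// Illustrative 4-bit -> int8 expansion; the real UnPackToInt reads bit width
// and sign information from the schema::Tensor and is not shown here.
// 'packed' must hold at least (count + 1) / 2 bytes.
std::vector<int8_t> Unpack4Bit(const std::vector<uint8_t> &packed, size_t count) {
  std::vector<int8_t> out(count);
  for (size_t i = 0; i < count; ++i) {
    uint8_t nibble = (i % 2 == 0) ? (packed[i / 2] & 0x0F) : (packed[i / 2] >> 4);
    out[i] = static_cast<int8_t>(static_cast<int8_t>(nibble << 4) >> 4);  // sign-extend the low 4 bits
  }
  return out;
}
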
@@ -713,12 +716,12 @@ int LiteSession::InitGPURuntime() {
 session::LiteSession *session::LiteSession::CreateSession(const lite::Context *context) {
   auto session = new (std::nothrow) lite::LiteSession();
   if (session == nullptr) {
-    MS_LOG(ERROR) << "create sesssion failed";
+    MS_LOG(ERROR) << "create session failed";
     return nullptr;
   }
   auto ret = session->Init(context);
   if (ret != mindspore::lite::RET_OK) {
-    MS_LOG(ERROR) << "init sesssion failed";
+    MS_LOG(ERROR) << "init session failed";
     delete session;
     return nullptr;
   }
@@ -729,7 +732,7 @@ session::LiteSession *session::LiteSession::CreateSession(const char *model_buf,
                                                           const lite::Context *context) {
   auto *session = LiteSession::CreateSession(context);
   if (session == nullptr) {
-    MS_LOG(ERROR) << "Create sesssion failed";
+    MS_LOG(ERROR) << "Create session failed";
     return nullptr;
   }
   auto *model = lite::ImportFromBuffer(model_buf, size, true);
@@ -107,8 +107,10 @@ int LstmCPUKernel::InitWeightBias() {
   }
   memcpy(weight_h_ptr_, weight_h->MutableData(), weight_h->ElementsNum() * sizeof(float));

+  std::vector<int> w_shape = weight_i->shape();
+  auto hidden_size = w_shape.at(1) / 4;
   // init bias
-  int bias_num = lstm_parm_->bidirectional_ ? 2 * 4 * lstm_parm_->hidden_size_ : 4 * lstm_parm_->hidden_size_;
+  int bias_num = lstm_parm_->bidirectional_ ? 2 * 4 * hidden_size : 4 * hidden_size;
   bias_ptr_ = reinterpret_cast<float *>(malloc(bias_num * sizeof(float)));
   if (bias_ptr_ == nullptr) {
     MS_LOG(ERROR) << "LstmCPUKernel malloc bias_ptr_ error.";
@@ -116,13 +118,13 @@ int LstmCPUKernel::InitWeightBias() {
   }

   auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->MutableData());
-  const int state_bias_offset = 4 * lstm_parm_->hidden_size_;
+  const int state_bias_offset = 4 * hidden_size;
   for (int i = 0; i < state_bias_offset; i++) {
     bias_ptr_[i] = bias_data[i] + bias_data[i + state_bias_offset];
   }
   if (lstm_parm_->bidirectional_) {
-    bias_data += 4 * lstm_parm_->hidden_size_ * 2;
-    auto backward_bias = bias_ptr_ + 4 * lstm_parm_->hidden_size_;
+    bias_data += 4 * hidden_size * 2;
+    auto backward_bias = bias_ptr_ + 4 * hidden_size;
     for (int i = 0; i < state_bias_offset; i++) {
       backward_bias[i] = bias_data[i] + bias_data[i + state_bias_offset];
     }
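
[Annotation, not part of the diff] hidden_size is now derived from the weight_i shape (the four LSTM gates are stacked along axis 1) instead of lstm_parm_->hidden_size_, presumably because lstm_parm_ is only filled in by InitParam, which now runs after InitWeightBias. The bias folding itself is unchanged: the input bias and the recurrent (state) bias can be pre-summed because the gate computation adds them anyway. A standalone sketch of that folding for one direction:

#include <vector>

// One-direction bias folding as done in InitWeightBias: the bias tensor holds
// 4 * hidden_size input-gate biases followed by 4 * hidden_size state-gate
// biases, and the two halves are summed into one buffer up front.
std::vector<float> FoldLstmBias(const std::vector<float> &bias, int hidden_size) {
  const int state_bias_offset = 4 * hidden_size;  // i, f, g, o gates stacked
  std::vector<float> folded(state_bias_offset);
  for (int i = 0; i < state_bias_offset; ++i) {
    folded[i] = bias[i] + bias[i + state_bias_offset];
  }
  return folded;
}
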
@@ -131,6 +133,14 @@ int LstmCPUKernel::InitWeightBias() {
 }

 int LstmCPUKernel::Init() {
+  FreeTmpBuffer();
+  auto ret = InitWeightBias();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "LstmCPUKernel InitWeightBias error.";
+    FreeTmpBuffer();
+    return RET_ERROR;
+  }
   if (!InferShapeDone()) {
     return RET_OK;
   }
@@ -138,20 +148,12 @@ int LstmCPUKernel::Init() {
 }

 int LstmCPUKernel::ReSize() {
-  FreeTmpBuffer();
   auto ret = InitParam();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "LstmCPUKernel InitParam error.";
     return RET_ERROR;
   }
-  ret = InitWeightBias();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "LstmCPUKernel InitWeightBias error.";
-    FreeTmpBuffer();
-    return RET_ERROR;
-  }
   ret = InitBuffer();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "LstmCPUKernel InitBuffer error.";
@@ -184,6 +184,13 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
                                                  const Model::Node *node) {
   MS_ASSERT(primitive != nullptr);
   TypeId data_type = GetFirstFp32Fp16OrInt8Type(in_tensors);
+  bool need_restore = true;
+  if (primitive->quant_type() == schema::QuantType_WeightQuant) {
+    data_type = kNumberTypeFloat32;
+  }
+  if (!IsContain(packed_op, (schema::PrimitiveType)primitive->Type())) {
+    need_restore = false;
+  }
   kernel::KernelKey desc{kCPU, data_type, static_cast<schema::PrimitiveType>(primitive->Type())};
 #if SUPPORT_GPU
   if (context_->IsGpuEnabled()) {
@@ -216,7 +223,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
   if (mindspore::lite::IsSupportFloat16() &&
       ((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) {
     kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type};
-    auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, fp16_cpu_desc.data_type);
+    auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, fp16_cpu_desc.data_type, need_restore);
     auto *kernel =
       KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc);
     DequantUtil::RestoreTensorData(tensor_origin_data_map);
@@ -230,7 +237,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
     MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op.";
     desc.data_type = kNumberTypeFloat32;
   }
-  auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, desc.data_type);
+  auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, desc.data_type, need_restore);
   auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
   DequantUtil::RestoreTensorData(tensor_origin_data_map);
   if (kernel != nullptr) {
@@ -26,6 +26,12 @@
 #include "src/ops/primitive_c.h"

 namespace mindspore::lite {
+static std::vector<schema::PrimitiveType> packed_op = {
+  schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D,
+  schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
+  schema::PrimitiveType_MatMul, schema::PrimitiveType_Lstm};
+
 class Scheduler {
  public:
   Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors)
@@ -253,11 +253,11 @@ STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) {
     }
     auto status = RET_ERROR;
     if (type_id_ == kNumberTypeInt8) {
-      status =
-        QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
+      status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
+                                   false, 1);
     } else if (type_id_ == kNumberTypeInt16) {
-      status =
-        QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
+      status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
+                                    false, 1);
     }
     if (status != RET_OK) {
       MS_LOG(ERROR) << "QuantFilter failed : " << status;
@@ -316,11 +316,11 @@ STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) {
     }
     auto status = RET_ERROR;
     if (type_id_ == kNumberTypeInt8) {
-      status =
-        QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
+      status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
+                                   false, 3);
     } else if (type_id_ == kNumberTypeInt16) {
       status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
-                                    false);
+                                    false, 3);
     }
     if (status != RET_OK) {
       MS_LOG(ERROR) << "QuantFilter failed : " << status;
@@ -340,10 +340,10 @@ STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) {
   auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
   MS_ASSERT(primitive_c != nullptr);

-  auto weight_h = cnode->input(1);
+  auto first_input = cnode->input(1);
   ParameterPtr param_node;
   ParamValueLitePtr param_value;
-  GetLiteParameter(weight_h, &param_node, &param_value);
+  GetLiteParameter(first_input, &param_node, &param_value);
   if (param_node == nullptr || param_value == nullptr || param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
     MS_LOG(INFO) << "This Gather op " << cnode->fullname_with_scope() << " can not quant weight";
     return RET_OK;
@@ -358,10 +358,10 @@ STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) {
   auto status = RET_ERROR;
   if (type_id_ == kNumberTypeInt8) {
     status =
-      QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
+      QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false, 0);
   } else if (type_id_ == kNumberTypeInt16) {
     status =
-      QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
+      QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false, 0);
   }
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed : " << status;
@@ -510,7 +510,7 @@ STATUS WeightQuantizer::RunFp32Graph(FuncGraphPtr func_graph) {
   return RET_OK;
 }

-STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) {
+STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
   // 0.2 Parse input calib files
   auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_);
   if (status != RET_OK) {
@@ -652,7 +652,7 @@ STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) {
       delete quant_sm.model;
       return RET_ERROR;
     }
-    // 3. compare betwen quant and fp32
+    // 3. compare between quant and fp32
     auto quant_outputs = quant_session->GetOutputs();
     mean_error += CompareOutputData<float>(fp32_output_tensors_[i], quant_outputs);
   }  // end_for: calib data loop
@@ -690,8 +690,8 @@ STATUS WeightQuantizer::DoFixedQuant(FuncGraphPtr func_graph) {
   for (auto &cnode : func_graph->GetOrderedCnodes()) {
     auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
     if (primitive_c == nullptr) {
-      MS_LOG(ERROR) << "primitive_c is nullptr";
-      return RET_ERROR;
+      MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive_c is nullptr";
+      continue;
     }
     auto op_name = cnode->fullname_with_scope();
     auto op_type = (schema::PrimitiveType)primitive_c->Type();
@@ -744,7 +744,7 @@ STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
     quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
     type_id_ = kNumberTypeInt8;
     MS_LOG(INFO) << "Do mixed bit quantization";
-    return DoMiexedQuant(func_graph);
+    return DoMixedQuant(func_graph);
   }

   return DoFixedQuant(func_graph);
@@ -62,7 +62,7 @@ class WeightQuantizer : public Quantizer {
   std::vector<std::vector<std::string>> images_;  // multi_input, [[mode_input_0], [model_input_1]...]
   std::vector<std::unordered_map<std::string, mindspore::tensor::MSTensor *>> fp32_output_tensors_;

-  STATUS DoMiexedQuant(FuncGraphPtr);
+  STATUS DoMixedQuant(FuncGraphPtr);
   STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr<PrimitiveC> primitive_c);
   STATUS DoFixedQuant(FuncGraphPtr);
   STATUS RunFp32Graph(FuncGraphPtr);