From 8f495b7aa8fc39ea364cb570417a08d2a70f15de Mon Sep 17 00:00:00 2001
From: xutianchun
Date: Fri, 19 Mar 2021 14:54:47 +0800
Subject: [PATCH] fix LSTM quant bug and fix ToD: check input param

---
 mindspore/lite/src/train/transfer_session.cc       | 12 ++++++++++++
 .../tools/converter/quantizer/weight_quantizer.cc  |  4 ++--
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/mindspore/lite/src/train/transfer_session.cc b/mindspore/lite/src/train/transfer_session.cc
index 900bb1ff3b..71fbe8c809 100644
--- a/mindspore/lite/src/train/transfer_session.cc
+++ b/mindspore/lite/src/train/transfer_session.cc
@@ -159,6 +159,18 @@ session::TrainSession *session::TrainSession::CreateTransferSession(const char *
                                                                      size_t size_backbone, const char *model_buf_head,
                                                                      size_t size_head, lite::Context *context,
                                                                      bool train_mode) {
+  auto ValidModelSize = [](size_t size) -> bool {
+    constexpr size_t MaxModelSize = 1024 * 1024 * 1024ULL;  // 1 GB
+    return size < MaxModelSize && size > 0;
+  };
+  if (!ValidModelSize(size_backbone)) {
+    MS_LOG(ERROR) << "size_backbone too large: " << size_backbone;
+    return nullptr;
+  }
+  if (!ValidModelSize(size_head)) {
+    MS_LOG(ERROR) << "size_head too large: " << size_head;
+    return nullptr;
+  }
   auto session = new (std::nothrow) lite::TransferSession(model_buf_backbone, size_backbone, context);
   if (session == nullptr) {
     MS_LOG(ERROR) << "create transfer session failed";
diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
index 6eaf64a77f..68706d6df5 100644
--- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
+++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
@@ -275,8 +275,8 @@ STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const Pr
   ParamValueLitePtr param_value;
   GetLiteParameter(weight_i, &param_node, &param_value);
   if (param_node == nullptr || param_value == nullptr) {
-    MS_LOG(ERROR) << "GetLiteParameter error";
-    return RET_ERROR;
+    MS_LOG(INFO) << "LSTM input index " << index << " is not weight";
+    return RET_OK;
   }
   if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
     MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
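
Note on the transfer_session.cc change: the new guard rejects empty or oversized model buffers before any allocation is attempted. Below is a minimal sketch of how a caller might apply the same check on its side before constructing a transfer session. ReadFile, the .ms paths, and the commented-out session call are illustrative assumptions, not MindSpore Lite API.

#include <cstddef>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

namespace {
// Same bound and predicate as the ValidModelSize lambda in the patch:
// a model buffer must be non-empty and strictly smaller than 1 GB.
constexpr size_t kMaxModelSize = 1024 * 1024 * 1024ULL;
bool ValidModelSize(size_t size) { return size > 0 && size < kMaxModelSize; }

// Hypothetical helper (not MindSpore Lite API): read a whole file into memory.
std::vector<char> ReadFile(const std::string &path) {
  std::ifstream ifs(path, std::ios::binary | std::ios::ate);
  if (!ifs) {
    return {};
  }
  const auto size = static_cast<size_t>(ifs.tellg());
  std::vector<char> buf(size);
  ifs.seekg(0);
  ifs.read(buf.data(), static_cast<std::streamsize>(size));
  return buf;
}
}  // namespace

int main() {
  const auto backbone = ReadFile("backbone.ms");  // hypothetical model paths
  const auto head = ReadFile("head.ms");
  // With the patch applied, CreateTransferSession returns nullptr for these
  // cases; checking up front lets the caller report a clearer error.
  if (!ValidModelSize(backbone.size()) || !ValidModelSize(head.size())) {
    std::cerr << "model buffer is empty or larger than 1 GB" << std::endl;
    return 1;
  }
  // auto *session = mindspore::session::TrainSession::CreateTransferSession(
  //     backbone.data(), backbone.size(), head.data(), head.size(), &context, true);
  std::cout << "backbone: " << backbone.size() << " bytes, head: " << head.size() << " bytes" << std::endl;
  return 0;
}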
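
Note on the weight_quantizer.cc change: it turns a hard failure into a skip. For LSTM nodes, not every input index refers to a float32 weight Parameter, so a non-weight input is now logged at INFO level and left unquantized instead of aborting the whole pass. The sketch below illustrates that skip-instead-of-fail pattern with simplified stand-ins; the Input struct, this ProcessLstmWeightByIndex signature, and the sample inputs are hypothetical, not the converter's real ParamValueLite machinery.

#include <cstddef>
#include <iostream>
#include <vector>

// Simplified status codes mirroring the converter's RET_OK / RET_ERROR.
enum Status { RET_OK = 0, RET_ERROR = 1 };

// Hypothetical stand-in for an LSTM node input: only some indices carry a
// constant float32 weight tensor that weight quantization can touch.
struct Input {
  bool is_weight;
  bool is_float32;
};

Status ProcessLstmWeightByIndex(const std::vector<Input> &inputs, size_t index) {
  if (index >= inputs.size()) {
    std::cerr << "invalid LSTM input index " << index << std::endl;
    return RET_ERROR;
  }
  const Input &in = inputs[index];
  if (!in.is_weight) {
    // Before the patch this case returned RET_ERROR and aborted quantization;
    // a non-weight input (e.g. an activation) is expected, so just skip it.
    std::cout << "LSTM input index " << index << " is not weight" << std::endl;
    return RET_OK;
  }
  if (!in.is_float32) {
    std::cout << "weight at index " << index << " is not float32, not quantized" << std::endl;
    return RET_OK;
  }
  std::cout << "quantizing weight at index " << index << std::endl;
  return RET_OK;
}

int main() {
  const std::vector<Input> lstm_inputs = {
      {false, false},  // activation input, not a weight
      {true, true},    // weight_i
      {true, true},    // weight_h
  };
  for (size_t i = 0; i < lstm_inputs.size(); ++i) {
    if (ProcessLstmWeightByIndex(lstm_inputs, i) != RET_OK) {
      return 1;
    }
  }
  return 0;
}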