@@ -53,9 +53,6 @@ int AddnInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o
     if ((inputs[i]->shape_size_ != max_dims) && (GetElementNum(inputs[i]) != GetElementNum(inputs[max_dims_idx]))) {
       return NNACL_ERR;
     }
-    if (inputs[i]->data_type_ != inputs[0]->data_type_) {
-      return NNACL_ERR;
-    }
   }
   for (size_t d = 0; d < input->shape_size_; ++d) {

@@ -39,6 +39,10 @@ int ArithmeticInferShape(const TensorC *const *inputs, size_t inputs_size, Tenso
   size_t input_shape1_size = input1->shape_size_;
   output->format_ = input0->format_;
   output->data_type_ = input0->data_type_;
+  if ((input0->data_type_ == kNumberTypeInt8) && (input1->data_type_ == kNumberTypeFloat32)) {
+    output->data_type_ = input1->data_type_;
+  }
   if (!parameter->infer_flag_) {
     return NNACL_INFER_INVALID;
   }

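Note: the added branch handles the mixed inputs that weight quantization produces: an int8 (quantized) tensor combined with a float32 tensor infers a float32 output. A hypothetical standalone restatement of the rule, with a toy enum rather than nnacl's real TypeId:

    // Toy restatement of the promotion rule above (names are stand-ins, not the nnacl API).
    enum TypeId { kNumberTypeInt8, kNumberTypeFloat32 };
    TypeId PromoteOutputType(TypeId in0, TypeId in1) {
      // a weight-quantized int8 input meeting float32 data computes in float32
      if (in0 == kNumberTypeInt8 && in1 == kNumberTypeFloat32) {
        return in1;
      }
      return in0;  // otherwise the output follows the first input
    }
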
@@ -28,5 +28,10 @@ if [ ! -f "$CONVERTER" ]; then
 fi

 echo "============Converting========="
-LD_LIBRARY_PATH=./ $CONVERTER --fmk=MINDIR --trainModel=true --modelFile=lenet_tod.mindir --outputFile=lenet_tod
+QUANT_OPTIONS=""
+if [[ ! -z ${QUANTIZE} ]]; then
+  echo "Quantizing weights"
+  QUANT_OPTIONS="--quantType=WeightQuant --bitNum=8 --quantWeightSize=100 --quantWeightChannel=15"
+fi
+LD_LIBRARY_PATH=./ $CONVERTER --fmk=MINDIR --trainModel=true --modelFile=lenet_tod.mindir --outputFile=lenet_tod $QUANT_OPTIONS

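Note: with QUANTIZE set, the converter call expands to $CONVERTER --fmk=MINDIR --trainModel=true --modelFile=lenet_tod.mindir --outputFile=lenet_tod --quantType=WeightQuant --bitNum=8 --quantWeightSize=100 --quantWeightChannel=15, i.e. 8-bit weight quantization limited to tensors of at least 100 elements and, for convolutions, more than 15 output channels (assuming the two size flags map onto QuantStrategy's m_weight_size_ and m_conv_weight_quant_channel_threshold_, as the CanTensorQuantized change below suggests).
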
@@ -2,7 +2,7 @@
 display_usage()
 {
-  echo -e "\nUsage: prepare_and_run.sh -D dataset_path [-d mindspore_docker] [-r release.tar.gz] [-t arm64|x86]\n"
+  echo -e "\nUsage: prepare_and_run.sh -D dataset_path [-d mindspore_docker] [-r release.tar.gz] [-t arm64|x86] [-q]\n"
 }

 checkopts()

@@ -10,7 +10,8 @@ checkopts()
   TARGET="arm64"
   DOCKER=""
   MNIST_DATA_PATH=""
-  while getopts 'D:d:r:t:' opt
+  QUANTIZE=""
+  while getopts 'D:d:r:t:q' opt
   do
     case "${opt}" in
       D)

@@ -31,6 +32,9 @@ checkopts()
       r)
         TARBALL=$OPTARG
         ;;
+      q)
+        QUANTIZE="QUANTIZE"
+        ;;
       *)
         echo "Unknown option ${opt}!"
         display_usage

@@ -64,7 +68,7 @@ fi
 # Prepare the model
 cd model/ || exit 1
 rm -f *.ms
-./prepare_model.sh $DOCKER || exit 1
+QUANTIZE=${QUANTIZE} ./prepare_model.sh $DOCKER || exit 1
 cd ../

 # Copy the .ms model to the package folder

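Note: a quantized end-to-end run then looks like ./prepare_and_run.sh -D /path/to/MNIST -r release.tar.gz -t arm64 -q (dataset path and tarball name are placeholders); without -q the behavior is unchanged, since QUANT_OPTIONS stays empty.
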
@@ -110,6 +110,9 @@ void NetRunner::InitAndFigureInputs() {
   MS_ASSERT(nullptr != session_);

   loop_ = mindspore::session::TrainLoop::CreateTrainLoop(session_);
+  if (verbose_) {
+    loop_->SetKernelCallBack(nullptr, after_callback);
+  }

   acc_metrics_ = std::shared_ptr<AccuracyMetrics>(new AccuracyMetrics);
   loop_->Init({acc_metrics_.get()});

@@ -125,11 +128,11 @@ void NetRunner::InitAndFigureInputs() {
 float NetRunner::CalculateAccuracy(int max_tests) {
   test_ds_ = Mnist(data_dir_ + "/test", "all");
-  TypeCast typecast_f("float32");
+  TypeCast typecast_f(mindspore::DataType::kNumberTypeFloat32);
   Resize resize({h_, w_});
   test_ds_ = test_ds_->Map({&resize, &typecast_f}, {"image"});

-  TypeCast typecast("int32");
+  TypeCast typecast(mindspore::DataType::kNumberTypeInt32);
   test_ds_ = test_ds_->Map({&typecast}, {"label"});
   test_ds_ = test_ds_->Batch(batch_size_, true);

@@ -144,14 +147,14 @@ float NetRunner::CalculateAccuracy(int max_tests) {
 int NetRunner::InitDB() {
   train_ds_ = Mnist(data_dir_ + "/train", "all");

-  TypeCast typecast_f("float32");
+  TypeCast typecast_f(mindspore::DataType::kNumberTypeFloat32);
   Resize resize({h_, w_});
   train_ds_ = train_ds_->Map({&resize, &typecast_f}, {"image"});

-  TypeCast typecast("int32");
+  TypeCast typecast(mindspore::DataType::kNumberTypeInt32);
   train_ds_ = train_ds_->Map({&typecast}, {"label"});

-  train_ds_ = train_ds_->Shuffle(2);
+  // train_ds_ = train_ds_->Shuffle(2);
   train_ds_ = train_ds_->Batch(batch_size_, true);

   if (verbose_) {

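Note: TypeCast is now constructed from the mindspore::DataType enum instead of a type-name string, in both CalculateAccuracy and InitDB. Commenting out Shuffle(2) also makes the training input order deterministic across runs.
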
@@ -187,7 +187,7 @@ int NetRunner::TrainLoop() {
     if (save_checkpoint_ != 0 && (i + 1) % save_checkpoint_ == 0) {
       auto cpkt_fn =
         ms_head_file_.substr(0, ms_head_file_.find_last_of('.')) + "_trained_" + std::to_string(i + 1) + ".ms";
-      session_->SaveToFile(cpkt_fn);
+      mindspore::lite::Model::Export(head_model_, cpkt_fn.c_str());
     }
     std::cout << i + 1 << ": Loss is " << loss << " [min=" << min_loss << "]" << std::endl;

@@ -213,7 +213,7 @@ int NetRunner::Main() {
   if (cycles_ > 0) {
     auto trained_fn = ms_head_file_.substr(0, ms_head_file_.find_last_of('.')) + "_trained.ms";
-    session_->SaveToFile(trained_fn);
+    mindspore::lite::Model::Export(head_model_, trained_fn.c_str());
   }
   return 0;
 }

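Note: both the periodic checkpoint and the final save now go through the static mindspore::lite::Model::Export on the head model rather than session_->SaveToFile, which is why NetRunner keeps the backbone_model_/head_model_ pointers added in the header below.
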
@@ -44,6 +44,8 @@ class NetRunner {
   DataSet ds_;
   mindspore::session::TrainSession *session_ = nullptr;
+  mindspore::lite::Model *backbone_model_ = nullptr;
+  mindspore::lite::Model *head_model_ = nullptr;

   std::string ms_backbone_file_ = "";
   std::string ms_head_file_ = "";

@@ -176,10 +176,6 @@ int Flags::InitTrainModel() {
       std::cerr << "INPUT ILLEGAL: train model converter supporting only FP32 output tensors";
       return RET_INPUT_PARAM_INVALID;
     }
-    if (this->quantType != QuantType_QUANT_NONE) {
-      std::cerr << "INPUT ILLEGAL: train model converter is not supporting quantization";
-      return RET_INPUT_PARAM_INVALID;
-    }
   }
   return RET_OK;
 }

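Note: deleting this guard is the converter-side change that lets --trainModel=true be combined with --quantType=WeightQuant, which the -q path in prepare_model.sh above depends on.
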
@@ -181,6 +181,57 @@ bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const {
   return true;
 }

+bool QuantStrategy::CanTensorQuantized(const AnfNodePtr &inputNode) const {
+  if (inputNode == nullptr) {
+    MS_LOG(INFO) << "CanTensorQuantized input is nullptr!";
+    return false;
+  }
+  ParameterPtr paramNode = nullptr;
+  if (inputNode->isa<Parameter>()) {
+    paramNode = inputNode->cast<ParameterPtr>();
+  }
+  if (paramNode == nullptr) {
+    MS_LOG(INFO) << "CanTensorQuantized invalid paramNode!";
+    return false;
+  }
+  auto abstract_base = paramNode->abstract();
+  if (abstract_base == nullptr) {
+    MS_LOG(INFO) << "abstract is nullptr";
+    return false;
+  }
+  if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
+    MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name();
+    return false;
+  }
+  auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
+  if (weight_shape.size() < 2) {  // do not quant single dim tensors
+    return false;
+  }
+  size_t shapeSize = 1;
+  for (auto dim : weight_shape) {
+    shapeSize = shapeSize * dim;
+  }
+  if (shapeSize < m_weight_size_) {
+    MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize;
+    return false;
+  }
+  if (weight_shape.size() == 4) {  // assume Convolution
+    if (weight_shape[0] <= static_cast<int>(m_conv_weight_quant_channel_threshold_)) {
+      MS_LOG(INFO) << "channel less m_conv_weight_quant_channel_threshold_!" << weight_shape[0];
+      return false;
+    }
+  }
+  return true;
+}
+
 QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive) {
   MS_ASSERT(primitive != nullptr);
   QuantParamHolderPtr quant_params_holder = nullptr;

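Worked example of the new gate, assuming --quantWeightSize=100 maps to m_weight_size_ and --quantWeightChannel=15 to m_conv_weight_quant_channel_threshold_: a conv weight of shape [16, 3, 3, 3] has shapeSize = 16 * 3 * 3 * 3 = 432 >= 100 and weight_shape[0] = 16 > 15, so it can be quantized; an [8, 3, 3, 3] weight fails the channel threshold and a [100] bias fails the rank check, so both stay float32.
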
@@ -14,8 +14,8 @@
  * limitations under the License.
  */

-#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZER_UTIL_H
-#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZER_UTIL_H
+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZE_UTIL_H_
+#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZE_UTIL_H_

 #include <dirent.h>
 #include <sys/stat.h>

@@ -83,6 +83,7 @@ class QuantStrategy {
   bool CanConvOpQuantized(const CNodePtr &node) const;
   bool CanMulOpQuantized(const CNodePtr &node) const;
   bool CanOpPostQuantized(AnfNodePtr &node) const;
+  bool CanTensorQuantized(const AnfNodePtr &inputNode) const;

   size_t m_weight_size_;
   size_t m_conv_weight_quant_channel_threshold_;

@@ -417,4 +418,4 @@ FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &);
 void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info);

 }  // namespace mindspore::lite::quant
-#endif
+#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZE_UTIL_H_

@@ -75,6 +75,7 @@ STATUS WeightQuantizer::SetAbstract(const tensor::TensorPtr &tensor_info, const
   auto quant_param_holder = GetCNodeQuantHolder(primitive);
   quant_param_holder->set_quant_type(schema::QuantType_QUANT_WEIGHT);
+  weight_quantized_tensors.insert({tensor_info, param_node});

   return RET_OK;
 }

@@ -244,6 +245,82 @@ STATUS WeightQuantizer::DoGatherQuantize(const CNodePtr &cnode) {
   return RET_OK;
 }

+STATUS WeightQuantizer::DoOptimizerQuantize(const CNodePtr &cnode) {
+  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
+  MS_ASSERT(primitive != nullptr);
+
+  std::vector<int> weight_indices = {2};
+  if (opt::CheckPrimitiveType(cnode, prim::kPrimAdam)) {
+    weight_indices = {2, 3};
+  }
+  if (opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) {
+    weight_indices = {4, 6};
+  }
+
+  for (int idx : weight_indices) {
+    auto input = cnode->input(idx);
+    if (!quant_strategy_->CanTensorQuantized(input)) {
+      MS_LOG(INFO) << "Input " << idx << " of Optimizer is not quantizable";
+      continue;
+    }
+
+    ParameterPtr param_node;
+    tensor::TensorPtr tensor_info;
+    GetLiteParameter(input, &param_node, &tensor_info);
+    if (param_node == nullptr || tensor_info == nullptr || tensor_info->data_type() != TypeId::kNumberTypeFloat32) {
+      MS_LOG(INFO) << "This Optimizer op " << cnode->fullname_with_scope() << " can not quant weight";
+      return RET_OK;
+    }
+
+    auto status = RET_ERROR;
+    if (type_id_ == kNumberTypeInt8) {
+      status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
+                                   false, type_id_, idx - 1);
+    } else if (type_id_ == kNumberTypeInt16) {
+      status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
+                                    false, type_id_, idx - 1);
+    }
+    if (status != RET_OK && status != RET_CONTINUE) {
+      MS_LOG(ERROR) << "QuantFilter failed : " << status;
+      return status;
+    }
+    status = SetAbstract(tensor_info, param_node, primitive);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "SetAbstract failed : " << status;
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+
+STATUS WeightQuantizer::DoMarkWeightQuantizeIfQuantized(const CNodePtr &cnode) {
+  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
+  if (primitive == nullptr) {
+    MS_LOG(ERROR) << "primitive is nullptr";
+    return RET_ERROR;
+  }
+
+  auto quant_param_holder = GetCNodeQuantHolder(primitive);
+  if (quant_param_holder->quant_type() == schema::QuantType_QUANT_WEIGHT) {
+    // already marked with QUANT_WEIGHT
+    return RET_OK;
+  }
+
+  for (size_t i = 1; i < cnode->size(); i++) {
+    auto inputNode = cnode->input(i);
+    if (inputNode->isa<Parameter>()) {
+      ParameterPtr param_node;
+      tensor::TensorPtr tensor_info;
+      GetLiteParameter(inputNode, &param_node, &tensor_info);
+      auto param = weight_quantized_tensors.find(tensor_info);
+      if (param != weight_quantized_tensors.end()) {
+        quant_param_holder->set_quant_type(schema::QuantType_QUANT_WEIGHT);
+        continue;
+      }
+    }
+  }
+  return RET_OK;
+}
+
 STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const PrimitivePtr &primitive,
                                                  const int &index) {
   auto op_name = cnode->fullname_with_scope();

@@ -649,6 +726,8 @@ STATUS WeightQuantizer::DoMixedQuant(const FuncGraphPtr &func_graph) {
 STATUS WeightQuantizer::DoFixedQuant(const FuncGraphPtr &func_graph) {
   MS_ASSERT(func_graph != nullptr);
+  weight_quantized_tensors.clear();
+
   for (auto &cnode : func_graph->GetOrderedCnodes()) {
     auto primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
     if (primitive == nullptr) {

@@ -681,10 +760,34 @@ STATUS WeightQuantizer::DoFixedQuant(const FuncGraphPtr &func_graph) {
         MS_LOG(ERROR) << "DoGatherQuantize error";
         return RET_ERROR;
       }
+    } else if ((opt::CheckPrimitiveType(cnode, prim::kPrimAdam)) || (opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) ||
+               (opt::CheckPrimitiveType(cnode, prim::kPrimApplyMomentum))) {
+      auto status = DoOptimizerQuantize(cnode);
+      if (status != RET_OK) {
+        MS_LOG(ERROR) << "DoOptimizerQuantize error";
+        return RET_ERROR;
+      }
     } else {
       MS_LOG(DEBUG) << op_name << " of type: " << primitive->name() << " no need quant";
     }
   }
+  return MarkWeightQuantizationInNodes(func_graph);
+}
+
+STATUS WeightQuantizer::MarkWeightQuantizationInNodes(const FuncGraphPtr &func_graph) {
+  MS_ASSERT(func_graph != nullptr);
+  for (auto &cnode : func_graph->GetOrderedCnodes()) {
+    auto primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
+    if (primitive == nullptr) {
+      MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive is nullptr";
+      continue;
+    }
+    auto status = DoMarkWeightQuantizeIfQuantized(cnode);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "MarkWeightQuantizationInNodes error marking " << cnode->fullname_with_scope();
+      return RET_ERROR;
+    }
+  }
   return RET_OK;
 }

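Note: the quantizer now works in two passes: SetAbstract records every quantized weight in weight_quantized_tensors, and MarkWeightQuantizationInNodes then marks any node that shares one of those tensors (e.g. an optimizer reading a quantized conv weight). A minimal self-contained sketch of this record-then-mark pattern, using toy stand-in types rather than the converter's classes:

    #include <map>
    #include <memory>
    #include <vector>

    struct Tensor {};                            // stand-in for tensor::Tensor
    using TensorPtr = std::shared_ptr<Tensor>;

    struct Node {                                // stand-in for a CNode
      std::vector<TensorPtr> inputs;
      bool quant_weight;
    };

    int main() {
      std::map<TensorPtr, bool> recorded;        // pass 1: record quantized tensors
      auto w = std::make_shared<Tensor>();
      recorded[w] = true;

      Node consumer{{w}, false};                 // pass 2: mark nodes sharing a recorded tensor
      for (const auto &in : consumer.inputs) {
        if (recorded.find(in) != recorded.end()) {
          consumer.quant_weight = true;
        }
      }
      return consumer.quant_weight ? 0 : 1;
    }
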
@@ -14,8 +14,8 @@
  * limitations under the License.
  */

-#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H
-#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H
+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_
+#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_

 #include <future>
 #include <memory>

@@ -43,6 +43,7 @@ class WeightQuantizer : public Quantizer {
   STATUS DoQuantize(FuncGraphPtr func_graph) override;
   STATUS DoConvQuantize(const CNodePtr &);
   STATUS DoMulQuantize(const CNodePtr &);
+  STATUS DoOptimizerQuantize(const CNodePtr &);
   STATUS DoLstmQuantize(const CNodePtr &cnode);
   STATUS DoGatherQuantize(const CNodePtr &cnode);

@@ -57,6 +58,7 @@ class WeightQuantizer : public Quantizer {
   std::unique_ptr<QuantStrategy> quant_strategy_;
   size_t bit_num_{8};
   std::string config_file_;
+  std::map<tensor::TensorPtr, ParameterPtr> weight_quantized_tensors;
   PostQuantConfig config_param_;
   std::vector<std::vector<std::string>> images_;  // multi_input, [[mode_input_0], [model_input_1]...]
   std::vector<std::unordered_map<std::string, mindspore::tensor::MSTensor *>> fp32_output_tensors_;

@@ -65,6 +67,8 @@ class WeightQuantizer : public Quantizer {
   STATUS SetAbstract(const tensor::TensorPtr &tensor_info, const ParameterPtr &param_node,
                      const PrimitivePtr &primitive);
   STATUS DoFixedQuant(const FuncGraphPtr &);
+  STATUS MarkWeightQuantizationInNodes(const FuncGraphPtr &);
+  STATUS DoMarkWeightQuantizeIfQuantized(const CNodePtr &);
   STATUS RunFp32Graph(const FuncGraphPtr &);

   STATUS DoMixedQuantize(const FuncGraphPtr &func_graph);

@@ -74,6 +78,7 @@ class WeightQuantizer : public Quantizer {
   STATUS TryQuant(const int &bit_num_t, const ParameterPtr &param_node, const tensor::TensorPtr &tensor_info,
                   const PrimitivePtr &primitive);
   STATUS DoQuantSearch(const FuncGraphPtr &func_graph);
+  STATUS DoTensorQuantize(const CNodePtr &);
 };
 }  // namespace mindspore::lite::quant
-#endif
+#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_