[MS][LITE] use set quant params instead of add quant params

Branch: pull/15226/head
Author: cjh9368, 4 years ago
Commit: 8c2909b66f
15 changed files with 150 additions and 130 deletions:

 1. mindspore/lite/tools/anf_exporter/anf_exporter.cc (+2, -2)
 2. mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc (+2, -13)
 3. mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc (+3, -3)
 4. mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc (+16, -0)
 5. mindspore/lite/tools/converter/parser/tf/tf_model_parser.h (+3, -0)
 6. mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc (+7, -3)
 7. mindspore/lite/tools/converter/quant_param_holder.h (+51, -38)
 8. mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc (+18, -18)
 9. mindspore/lite/tools/converter/quantizer/quant_cast.cc (+4, -4)
10. mindspore/lite/tools/converter/quantizer/quantize_util.cc (+2, -2)
11. mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc (+5, -6)
12. mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc (+2, -2)
13. mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc (+19, -30)
14. mindspore/lite/tools/optimizer/graph/tflite_inputs_adjust_pass.cc (+15, -6)
15. mindspore/lite/tools/optimizer/graph/transpose_strategy.cc (+1, -3)

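The core of the change, as a before/after sketch of the caller-side pattern (illustrative only; input_count, output_count, and quant_params stand in for the per-parser values such as onnx_node.input_size()):

  // Old pattern, now marked deprecated: default-construct the holder, then
  // append one placeholder entry per tensor.
  auto holder = std::make_shared<QuantParamHolder>();
  for (int i = 0; i < input_count; ++i) {
    std::vector<schema::QuantParamT> notinited_quant_params(1);
    holder->AddInputQuantParam(notinited_quant_params);
  }

  // New pattern: size the holder up front (every slot starts as one
  // uninitialized QuantParamT), then overwrite individual slots by index.
  auto sized_holder = std::make_shared<QuantParamHolder>(input_count, output_count);
  sized_holder->set_input_quant_param(0, quant_params);  // replaces the placeholder in slot 0
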
mindspore/lite/tools/anf_exporter/anf_exporter.cc (+2, -2)

@@ -274,8 +274,8 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
       MS_LOG(ERROR) << "quant param is invalid.";
       return RET_ERROR;
     }
-    input_quant_params = quant_param_holder->input_quant_params();
-    output_quant_params = quant_param_holder->output_quant_params();
+    input_quant_params = quant_param_holder->get_input_quant_params();
+    output_quant_params = quant_param_holder->get_output_quant_params();
     dst_node->quantType = quant_param_holder->quant_type();
   }
   // add quant param


mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc (+2, -13)

@@ -311,19 +311,8 @@ STATUS CaffeModelParser::ConvertLayerQuantParams(const caffe::LayerParameter &la
     MS_LOG(ERROR) << "primitive_c is null, get quant params failed.";
     return RET_NULL_PTR;
   }
-  auto quant_params_holder = std::make_shared<QuantParamHolder>();
-  for (auto input_idx : layer.bottom()) {
-    std::vector<schema::QuantParamT> notinited_quant_params(1);
-    quant_params_holder->AddInputQuantParam(notinited_quant_params);
-  }
-  for (auto input_idx : weight.blobs()) {
-    std::vector<schema::QuantParamT> notinited_quant_params(1);
-    quant_params_holder->AddInputQuantParam(notinited_quant_params);
-  }
-  for (auto output_idx : layer.top()) {
-    std::vector<schema::QuantParamT> notinited_quant_params(1);
-    quant_params_holder->AddOutputQuantParam(notinited_quant_params);
-  }
+  auto quant_params_holder =
+    std::make_shared<QuantParamHolder>(layer.bottom_size() + weight.blobs_size(), layer.top_size());
   primitive_c->AddAttr("quant_params", quant_params_holder);
   return RET_OK;
 }


mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc (+3, -3)

@@ -535,7 +535,7 @@ STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, o
     return RET_ERROR;
   }
   // set input tensors
-  auto quant_params_holder = std::make_shared<QuantParamHolder>();
+  auto quant_params_holder = std::make_shared<QuantParamHolder>(onnx_node.input_size(), onnx_node.output_size());
   for (int i = 0; i < onnx_node.input_size(); ++i) {
     const auto &input_name = onnx_node.input(i);
     std::vector<schema::QuantParamT> quant_params;
@@ -544,7 +544,7 @@ STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, o
       MS_LOG(ERROR) << "set input tensor quant param failed.";
       return status;
     }
-    quant_params_holder->AddInputQuantParam(quant_params);
+    quant_params_holder->set_input_quant_param(i, quant_params);
   }
   // set out tensors
   for (int i = 0; i < onnx_node.output_size(); ++i) {
@@ -555,7 +555,7 @@ STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, o
       MS_LOG(ERROR) << "set output tensor quant param failed.";
       return status;
     }
-    quant_params_holder->AddOutputQuantParam(quant_params);
+    quant_params_holder->set_output_quant_param(i, quant_params);
   }
   primitive_c->AddAttr("quant_params", quant_params_holder);
   return RET_OK;


mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc (+16, -0)

@@ -953,9 +953,25 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
   if (status != RET_OK) {
     MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed.";
   }
+
+  status = ConvertQuantParams(inputs.size() - 1, output_size, primitiveC);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "Convert quant params for " << anf_node->fullname_with_scope() << " failed.";
+  }
   return status;
 }

+STATUS TFModelParser::ConvertQuantParams(const size_t &input_size, const size_t &output_size,
+                                         ops::PrimitiveC *primitive_c) {
+  if (primitive_c == nullptr) {
+    MS_LOG(ERROR) << "primitive_c is null, get quant params failed.";
+    return RET_NULL_PTR;
+  }
+  auto quant_params_holder = std::make_shared<QuantParamHolder>(input_size, output_size);
+  primitive_c->AddAttr("quant_params", quant_params_holder);
+  return RET_OK;
+}
+
 STATUS TFModelParser::ConvertRootGraphOutputs() {
   // because output of intermediate node in anf graph may also be output tensors, we search output tensors in
   // tf_root_graph_nodes_ but not anf_root_node_map_


mindspore/lite/tools/converter/parser/tf/tf_model_parser.h (+3, -0)

@@ -29,6 +29,7 @@
 #include "securec/include/securec.h"
 #include "tools/common/tensor_util.h"
 #include "tools/converter/model_parser.h"
+#include "ops/primitive_c.h"

 namespace mindspore {
 namespace lite {
@@ -75,6 +76,8 @@ class TFModelParser : public ModelParser {
   STATUS ControlFlowNodePostProcess(const std::map<CNodePtr, FuncGraphPtr> &first_func_map,
                                     const std::map<CNodePtr, FuncGraphPtr> &second_func_map);

+  STATUS ConvertQuantParams(const size_t &input_size, const size_t &output_size, ops::PrimitiveC *primitive_c);
+
   static STATUS MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes, const FuncGraphPtr &anf_graph);

   STATUS RecordNullInput(const CNodePtr &node, const std::vector<std::string> &input_name_not_found);


mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc (+7, -3)

@@ -243,7 +243,8 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const tflite::OperatorT *op, ops:
     round_type = 2;
   }
   const auto &tflite_subgraph = tflite_model_->subgraphs.front();
-  auto quant_params_holder = std::make_shared<QuantParamHolder>();
+  auto quant_params_holder = std::make_shared<QuantParamHolder>(op->inputs.size(), op->outputs.size());
+  size_t idx = 0;
   for (auto input_idx : op->inputs) {
     if (input_idx < 0) {
       input_idx += tflite_subgraph->tensors.size();
@@ -255,8 +256,10 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const tflite::OperatorT *op, ops:
       MS_LOG(ERROR) << "set input tensor quant param failed.";
       return status;
     }
-    quant_params_holder->AddInputQuantParam(quant_params);
+    quant_params_holder->set_input_quant_param(idx, quant_params);
+    idx++;
   }
+  idx = 0;
   for (auto output_idx : op->outputs) {
     if (output_idx < 0) {
       output_idx += tflite_subgraph->tensors.size();
@@ -268,7 +271,8 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const tflite::OperatorT *op, ops:
       MS_LOG(ERROR) << "set output tensor quant param failed.";
       return status;
     }
-    quant_params_holder->AddOutputQuantParam(quant_params);
+    quant_params_holder->set_output_quant_param(idx, quant_params);
+    idx++;
   }
   primitive_c->AddAttr("quant_params", quant_params_holder);
   return RET_OK;


mindspore/lite/tools/converter/quant_param_holder.h (+51, -38)

@@ -17,6 +17,7 @@
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANT_PARAM_CONTEXT_H
 #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANT_PARAM_CONTEXT_H

+#include <utility>
 #include <vector>
 #include <memory>
 #include "ir/anf.h"
@@ -27,7 +28,24 @@ namespace lite {
 using QuantParamsVector = std::vector<std::vector<schema::QuantParamT>>;
 class QuantParamHolder : public Value {
  public:
-  QuantParamHolder() = default;
+  QuantParamHolder(size_t input_size, size_t output_size) {
+    input_quant_params_.resize(input_size);
+    output_quant_params_.resize(output_size);
+    for (size_t i = 0; i < input_size; i++) {
+      std::vector<schema::QuantParamT> notinited_quant_params(1);
+      set_input_quant_param(i, notinited_quant_params);
+    }
+
+    for (size_t i = 0; i < output_size; i++) {
+      std::vector<schema::QuantParamT> notinited_quant_params(1);
+      set_output_quant_param(i, notinited_quant_params);
+    }
+  }
+
+  QuantParamHolder(const QuantParamsVector &input_quant_params, const QuantParamsVector &output_quant_params) {
+    input_quant_params_ = input_quant_params;
+    output_quant_params_ = output_quant_params;
+  }

   ~QuantParamHolder() override = default;

@@ -36,17 +54,17 @@ class QuantParamHolder : public Value {
   bool operator==(const Value &rhs) const override {  // unused
     if (rhs.isa<QuantParamHolder>()) {
       auto other_holder = dynamic_cast<const QuantParamHolder &>(rhs);
-      auto input_quant_params_rhs = other_holder.input_quant_params();
-      auto output_quant_params_rhs = other_holder.output_quant_params();
-      if (input_quant_params_rhs.size() != this->input_quant_param_.size() ||
-          output_quant_params_rhs.size() != this->output_quant_param_.size()) {
+      auto input_quant_params_rhs = other_holder.get_input_quant_params();
+      auto output_quant_params_rhs = other_holder.get_output_quant_params();
+      if (input_quant_params_rhs.size() != this->input_quant_params_.size() ||
+          output_quant_params_rhs.size() != this->output_quant_params_.size()) {
         return false;
       }
       for (size_t i = 0; i < input_quant_params_rhs.size(); ++i) {
-        if (input_quant_params_rhs.at(i).size() != this->input_quant_param_.at(i).size()) {
+        if (input_quant_params_rhs.at(i).size() != this->input_quant_params_.at(i).size()) {
           return false;
         }
-        auto *params = reinterpret_cast<const char *>(this->input_quant_param_.at(i).data());
+        auto *params = reinterpret_cast<const char *>(this->input_quant_params_.at(i).data());
         auto *params_rhs = reinterpret_cast<const char *>(input_quant_params_rhs.at(i).data());
         for (size_t j = 0; j < input_quant_params_rhs.at(i).size() * sizeof(schema::QuantParamT); ++j) {
           if (params[j] != params_rhs[j]) {
@@ -55,10 +73,10 @@ class QuantParamHolder : public Value {
         }
       }
       for (size_t i = 0; i < output_quant_params_rhs.size(); ++i) {
-        if (output_quant_params_rhs.at(i).size() != this->output_quant_param_.at(i).size()) {
+        if (output_quant_params_rhs.at(i).size() != this->output_quant_params_.at(i).size()) {
          return false;
         }
-        auto *params = reinterpret_cast<const char *>(this->output_quant_param_.at(i).data());
+        auto *params = reinterpret_cast<const char *>(this->output_quant_params_.at(i).data());
         auto *params_rhs = reinterpret_cast<const char *>(output_quant_params_rhs.at(i).data());
         for (size_t j = 0; j < output_quant_params_rhs.at(i).size() * sizeof(schema::QuantParamT); ++j) {
           if (params[j] != params_rhs[j]) {
@@ -76,58 +94,53 @@ class QuantParamHolder : public Value {

   schema::QuantType quant_type() const { return quant_type_; }

-  void set_input_quant_params(const QuantParamsVector &input_quant_param) {
-    this->input_quant_param_ = input_quant_param;
-  }
-
   void set_input_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param) {
-    if (index >= this->input_quant_param_.size()) {
+    if (index >= this->input_quant_params_.size()) {
       std::vector<schema::QuantParamT> place_quant(1);
-      this->input_quant_param_.insert(this->input_quant_param_.end(), index + 1 - input_quant_param_.size(),
-                                      place_quant);
+      this->input_quant_params_.insert(this->input_quant_params_.end(), index + 1 - input_quant_params_.size(),
+                                       place_quant);
     }
-    this->input_quant_param_.at(index) = input_quant_param;
-  }
-
-  void set_output_quant_params(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param) {
-    this->output_quant_param_ = output_quant_param;
+    this->input_quant_params_.at(index) = input_quant_param;
   }

   void set_output_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &output_quant_param) {
-    if (index >= this->output_quant_param_.size()) {
+    if (index >= this->output_quant_params_.size()) {
       std::vector<schema::QuantParamT> place_quant(1);
-      this->output_quant_param_.insert(this->output_quant_param_.end(), index + 1 - output_quant_param_.size(),
-                                      place_quant);
+      this->output_quant_params_.insert(this->output_quant_params_.end(), index + 1 - output_quant_params_.size(),
+                                        place_quant);
     }
-    this->output_quant_param_.at(index) = output_quant_param;
+    this->output_quant_params_.at(index) = output_quant_param;
   }

   void set_enable_huffman_code(bool enable_huffman_code) { enable_huffman_code_ = enable_huffman_code; }

   bool enable_huffman_code() const { return enable_huffman_code_; }

+  // deprecated
   void AddInputQuantParam(const std::vector<schema::QuantParamT> &quant_param) {
-    this->input_quant_param_.emplace_back(quant_param);
+    this->input_quant_params_.emplace_back(quant_param);
   }

-  std::vector<std::vector<schema::QuantParamT>> input_quant_params() const { return this->input_quant_param_; }
-
+  // deprecated
   void AddOutputQuantParam(const std::vector<schema::QuantParamT> &quant_param) {
-    this->output_quant_param_.emplace_back(quant_param);
+    this->output_quant_params_.emplace_back(quant_param);
   }

-  std::vector<std::vector<schema::QuantParamT>> output_quant_params() const { return this->output_quant_param_; }
+  std::vector<std::vector<schema::QuantParamT>> get_input_quant_params() const { return this->input_quant_params_; }
+
+  std::vector<std::vector<schema::QuantParamT>> get_output_quant_params() const { return this->output_quant_params_; }

+  // deprecated
   void ClearInputOutputQuantParam() {
-    input_quant_param_.clear();
-    output_quant_param_.clear();
+    input_quant_params_.clear();
+    output_quant_params_.clear();
   }

   bool IsInputQuantParamsInited() {
-    if (this->input_quant_param_.empty()) {
+    if (this->input_quant_params_.empty()) {
       return false;
     }
-    for (auto &quant_param : this->input_quant_param_) {
+    for (auto &quant_param : this->input_quant_params_) {
       if (!quant_param.front().inited) {
         return false;
       }
@@ -136,10 +149,10 @@ class QuantParamHolder : public Value {
   }

   bool IsOutputQuantParamsInited() {
-    if (this->output_quant_param_.empty()) {
+    if (this->output_quant_params_.empty()) {
       return false;
     }
-    for (auto &quant_param : this->output_quant_param_) {
+    for (auto &quant_param : this->output_quant_params_) {
       if (!quant_param.front().inited) {
         return false;
       }
@@ -149,8 +162,8 @@ class QuantParamHolder : public Value {

  private:
   schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
-  QuantParamsVector input_quant_param_;
-  QuantParamsVector output_quant_param_;
+  QuantParamsVector input_quant_params_;
+  QuantParamsVector output_quant_params_;
   bool enable_huffman_code_ = false;
 };
 using QuantParamHolderPtr = std::shared_ptr<QuantParamHolder>;

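A self-contained analogue of the new holder semantics (simplified types; the real class stores schema::QuantParamT and derives from Value): the sized constructor pre-fills every slot with a single uninitialized entry, and set_input_quant_param grows the vector with placeholders when the index is past the end.

  #include <cassert>
  #include <vector>

  struct QuantParam {  // stand-in for schema::QuantParamT
    bool inited = false;
    double scale = 1.0;
  };
  using QuantParams = std::vector<QuantParam>;

  class Holder {  // stand-in for QuantParamHolder
   public:
    Holder(size_t input_size, size_t output_size)
        : inputs_(input_size, QuantParams(1)), outputs_(output_size, QuantParams(1)) {}

    void set_input_quant_param(size_t index, const QuantParams &param) {
      // Mirrors the real setter: grow with single-element placeholders when
      // the index is past the end, then overwrite the slot.
      if (index >= inputs_.size()) {
        inputs_.insert(inputs_.end(), index + 1 - inputs_.size(), QuantParams(1));
      }
      inputs_.at(index) = param;
    }

    const std::vector<QuantParams> &inputs() const { return inputs_; }

   private:
    std::vector<QuantParams> inputs_;
    std::vector<QuantParams> outputs_;
  };

  int main() {
    Holder holder(2, 1);  // two input slots, one output slot, all placeholders
    assert(!holder.inputs().at(0).front().inited);
    holder.set_input_quant_param(3, {{true, 0.5}});  // grows to four slots
    assert(holder.inputs().size() == 4 && holder.inputs().at(3).front().inited);
    return 0;
  }
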

mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc (+18, -18)

@@ -59,7 +59,7 @@ STATUS ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, cons
   MS_ASSERT(quant_params != nullptr && quant_datas != nullptr);
   double bias_scale_tmp;
   const constexpr int32_t quanted_bias_abs_limit = 0.5 * INT32_MAX;
-  auto active_weight_quant_params = quant_param_holder->input_quant_params();
+  auto weight_quant_params = quant_param_holder->get_input_quant_params().at(1);
   auto shape_size = quant_datas->size();
   if (bias_scales.size() == shape_size) {
     for (size_t i = 0; i < shape_size; i++) {
@@ -69,14 +69,14 @@ STATUS ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, cons
         return RET_ERROR;
       }
       if (std::abs(raw_datas[i] / bias_scale_tmp) >= quanted_bias_abs_limit) {
-        MS_LOG(DEBUG) << "quanted bias over flow, maybe the scale of weight: " << active_weight_quant_params[1][i].scale
+        MS_LOG(DEBUG) << "quanted bias over flow, maybe the scale of weight: " << weight_quant_params[i].scale
                       << " is too small, need to update";
         // update filter scale and zp
         double activate_scale = input_scales[0];
         double filter_scale = std::abs(raw_datas[i]) / (activate_scale * quanted_bias_abs_limit);
-        active_weight_quant_params[1][i].scale = filter_scale;
-        active_weight_quant_params[1][i].zeroPoint = 0;
-        quant_param_holder->set_input_quant_params(active_weight_quant_params);
+        weight_quant_params[i].scale = filter_scale;
+        weight_quant_params[i].zeroPoint = 0;
+        quant_param_holder->set_input_quant_param(1, weight_quant_params);
         bias_scale_tmp = std::abs(raw_datas[i]) / quanted_bias_abs_limit;
         quant_params->at(i).scale = bias_scale_tmp;
         MS_LOG(DEBUG) << "new filter scale: " << filter_scale;
@@ -99,13 +99,13 @@ STATUS ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, cons
       return RET_ERROR;
     }
     if (std::abs(max_raw_data / bias_scale_tmp) >= quanted_bias_abs_limit) {
-      MS_LOG(DEBUG) << "quanted bias over flow, maybe the scale of weight: " << active_weight_quant_params[1][0].scale
+      MS_LOG(DEBUG) << "quanted bias over flow, maybe the scale of weight: " << weight_quant_params[0].scale
                     << " is too small, need to update";
       double activate_scale = input_scales[0];
       double filter_scale = std::abs(max_raw_data) / (activate_scale * quanted_bias_abs_limit);
-      active_weight_quant_params[1][0].scale = filter_scale;
-      active_weight_quant_params[1][0].zeroPoint = 0;
-      quant_param_holder->set_input_quant_params(active_weight_quant_params);
+      weight_quant_params[0].scale = filter_scale;
+      weight_quant_params[0].zeroPoint = 0;
+      quant_param_holder->set_input_quant_param(1, weight_quant_params);
       bias_scale_tmp = max_raw_data / quanted_bias_abs_limit;
       quant_params->front().scale = bias_scale_tmp;
       MS_LOG(DEBUG) << "new filter scale: " << filter_scale;
@@ -117,7 +117,7 @@ STATUS ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, cons
     return RET_OK;
   }
   MS_LOG(ERROR) << "unexpected input_scales size: " << input_scales.size()
-                << " weight_scales size: " << active_weight_quant_params[1].size();
+                << " weight_scales size: " << weight_quant_params.size();
   return RET_ERROR;
 }
 }  // namespace
@@ -620,7 +620,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(const AnfNodePtr &bias, const Primitiv
   MS_ASSERT(bias_parameter != nullptr);
   auto quant_param_holder = GetCNodeQuantHolder(primitive);
   MS_ASSERT(quant_param_holder != nullptr);
-  auto active_weight_quant_params = quant_param_holder->input_quant_params();
+  auto active_weight_quant_params = quant_param_holder->get_input_quant_params();
   if (active_weight_quant_params.size() != 2) {
     MS_LOG(ERROR) << "unexpected active_weight_quant_params size: " << active_weight_quant_params.size();
     return RET_ERROR;
@@ -731,7 +731,7 @@ STATUS PostTrainingQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
   auto input_primitive_quant_holder = GetCNodeQuantHolder(input_cnode_primitive);
   MS_ASSERT(input_primitive_quant_holder != nullptr);
   if (input_primitive_quant_holder->IsOutputQuantParamsInited()) {
-    auto quant_param = input_primitive_quant_holder->output_quant_params().front();
+    auto quant_param = input_primitive_quant_holder->get_output_quant_params().front();
     primitive_quant_holder->AddInputQuantParam(quant_param);
   } else {
     // do input quant
@@ -820,14 +820,14 @@ STATUS PostTrainingQuantizer::QuantNode() {
       }
       auto input_primitive_quant_holder = GetCNodeQuantHolder(input_cnode_primitive);
       MS_ASSERT(input_primitive_quant_holder != nullptr);
-      if (input_primitive_quant_holder->output_quant_params().size() > index) {
-        auto quant_param = input_primitive_quant_holder->output_quant_params()[index];
+      if (input_primitive_quant_holder->get_output_quant_params().size() > index) {
+        auto quant_param = input_primitive_quant_holder->get_output_quant_params()[index];
         primitive_quant_holder->AddInputQuantParam(quant_param);
         primitive_quant_holder->AddOutputQuantParam(quant_param);
       } else {
         MS_LOG(WARNING) << "this TupleGetItem node's input node: " << input_cnode->fullname_with_scope()
-                        << "'s output quant_params size: " << input_primitive_quant_holder->output_quant_params().size()
-                        << ", but index: " << index;
+                        << "'s output quant_params size: "
+                        << input_primitive_quant_holder->get_output_quant_params().size() << ", but index: " << index;
       }
       primitive_quant_holder->set_quant_type(schema::QuantType_QUANT_ALL);
       continue;
@@ -1125,7 +1125,7 @@ STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph, con
   }
   auto quant_param_holder = GetCNodeQuantHolder(primitive);
   MS_ASSERT(quant_param_holder != nullptr);
-  auto input_quant_params = quant_param_holder->input_quant_params();
+  auto input_quant_params = quant_param_holder->get_input_quant_params();
   if (input_quant_params.size() == 3) {
     // compensate the existed
     auto bias_quant_params = input_quant_params[2];
@@ -1191,7 +1191,7 @@ STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph, con
     cnode->add_input(parameter);
     DoBiasQuant(parameter, primitive);
   } else {
-    MS_LOG(ERROR) << "unexpected input_quant_params size: " << input_quant_params.size();
+    MS_LOG(ERROR) << "unexpected get_input_quant_params size: " << input_quant_params.size();
   }
   return RET_OK;
 }

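For reference, the overflow guard touched above does the following arithmetic (a standalone sketch of my reading of the hunk; the values are made up): bias_scale = activation_scale * weight_scale, and when |raw_bias / bias_scale| would exceed 0.5 * INT32_MAX, the weight scale is enlarged so the quantized bias lands on the limit and its zero point is reset to 0.

  #include <cmath>
  #include <cstdint>
  #include <iostream>

  int main() {
    const double limit = 0.5 * INT32_MAX;  // quanted_bias_abs_limit in the source
    double activation_scale = 0.004;
    double weight_scale = 1e-12;  // pathologically small scale triggers the guard
    double raw_bias = 3.2;

    double bias_scale = activation_scale * weight_scale;
    if (std::abs(raw_bias / bias_scale) >= limit) {
      weight_scale = std::abs(raw_bias) / (activation_scale * limit);  // new filter scale
      bias_scale = std::abs(raw_bias) / limit;
    }
    // The quantized bias now fits in the int32 range.
    std::cout << "quantized bias: " << std::llround(raw_bias / bias_scale) << "\n";
    return 0;
  }
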

mindspore/lite/tools/converter/quantizer/quant_cast.cc (+4, -4)

@@ -25,7 +25,7 @@ namespace mindspore::lite::quant {
 ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector<schema::QuantParamT> &quant_params) {
   auto prim_c = std::make_shared<ops::QuantDTypeCast>();
   prim_c->Init(src_type, dst_type);
-  auto quant_params_holder = std::make_shared<QuantParamHolder>();
+  auto quant_params_holder = std::make_shared<QuantParamHolder>(0, 0);
   quant_params_holder->set_quant_type(schema::QuantType_QUANT_ALL);
   for (auto &quant_param : quant_params) {
     std::vector<schema::QuantParamT> quant_params_in = {quant_param};
@@ -82,17 +82,17 @@ STATUS QuantCast::Run(const FuncGraphPtr &graph) {
     ValueNodePtr value_node = nullptr;
     if (curnode_quant_type == schema::QuantType_QUANT_ALL &&
         input_cnode_quant_type == schema::QuantType_QUANT_NONE) {
-      if (primitive_quant_param_holder->input_quant_params().size() < i) {
+      if (primitive_quant_param_holder->get_input_quant_params().size() < i) {
        MS_LOG(ERROR) << "quant param is invalid.";
        return RET_ERROR;
       }
       value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8,
-                                         primitive_quant_param_holder->input_quant_params()[i - 1]);
+                                         primitive_quant_param_holder->get_input_quant_params()[i - 1]);
     } else if (curnode_quant_type == schema::QuantType_QUANT_NONE &&
               input_cnode_quant_type == schema::QuantType_QUANT_ALL) {
       auto input_primitive_quant_param_holder = GetCNodeQuantHolder(input_cnode_primitive_c);
       value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32,
-                                         input_primitive_quant_param_holder->output_quant_params().front());
+                                         input_primitive_quant_param_holder->get_output_quant_params().front());
     }
     if (value_node == nullptr) {
       MS_LOG(WARNING) << "value_node is null! "


mindspore/lite/tools/converter/quantizer/quantize_util.cc (+2, -2)

@@ -186,12 +186,12 @@ QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive) {
   QuantParamHolderPtr quant_params_holder = nullptr;
   auto quant_params_valueptr = primitive->GetAttr("quant_params");
   if (quant_params_valueptr == nullptr) {
-    quant_params_holder = std::make_shared<QuantParamHolder>();
+    quant_params_holder = std::make_shared<QuantParamHolder>(0, 0);
     primitive->AddAttr("quant_params", quant_params_holder);
   } else {
     quant_params_holder = quant_params_valueptr->cast<QuantParamHolderPtr>();
     if (quant_params_holder == nullptr) {
-      quant_params_holder = std::make_shared<QuantParamHolder>();
+      quant_params_holder = std::make_shared<QuantParamHolder>(0, 0);
       primitive->AddAttr("quant_params", quant_params_holder);
     }
   }


mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc (+5, -6)

@@ -133,7 +133,6 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
   }
   auto left_matmul_input = left_slice_cnode->input(1);
   auto right_reshape_node = fullconnect_cnode->input(2);
-
   auto matmul_cvalue = new (std::nothrow) mindspore::ops::MatMul();
   if (matmul_cvalue == nullptr) {
     MS_LOG(ERROR) << "new MatMul failed";
@@ -153,29 +152,29 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
       MS_LOG(ERROR) << "quant param is invalid.";
       return nullptr;
     }
-    auto fc_input_quantParams = fc_input_quantParams_holder->input_quant_params();
+    auto fc_input_quantParams = fc_input_quantParams_holder->get_input_quant_params();
     if (fc_input_quantParams.size() > 1 && !fc_input_quantParams[1].empty()) {
       jointed_quant_params.push_back(fc_input_quantParams[1][0]);
     }
   }
-  auto quant_params_holder = std::make_shared<lite::QuantParamHolder>();
   auto fc_prim = GetValueNode<PrimitiveCPtr>(fullconnect_cnode->input(0));
   lite::QuantParamsVector rmatmul_quant_params;
   auto rmatmul_quant_params_valueptr = fc_prim->GetAttr("quant_params");
+  lite::QuantParamsVector output_quant_params;
   if (rmatmul_quant_params_valueptr != nullptr) {
     auto rmatmul_quant_params_holder = rmatmul_quant_params_valueptr->cast<lite::QuantParamHolderPtr>();
     if (rmatmul_quant_params_holder == nullptr) {
       MS_LOG(ERROR) << "quant param is invalid.";
       return nullptr;
     }
-    rmatmul_quant_params = rmatmul_quant_params_holder->input_quant_params();
-    quant_params_holder->set_output_quant_params(rmatmul_quant_params_holder->output_quant_params());
+    rmatmul_quant_params = rmatmul_quant_params_holder->get_input_quant_params();
+    output_quant_params = rmatmul_quant_params_holder->get_output_quant_params();
   }
   rmatmul_quant_params.pop_back();
   rmatmul_quant_params.pop_back();
   // no bias quantParams
   rmatmul_quant_params.emplace_back(jointed_quant_params);
-  quant_params_holder->set_input_quant_params(rmatmul_quant_params);
+  auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(rmatmul_quant_params, output_quant_params);
   matmul_cvalue->AddAttr("quant_params", quant_params_holder);
   auto matmul_value_node = NewValueNode(std::shared_ptr<ops::PrimitiveC>(matmul_cvalue));
   std::vector<AnfNodePtr> matmul_inputs = {matmul_value_node, left_matmul_input};

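Note that this fusion exercises the second new constructor, which takes pre-built input/output vectors instead of sizes. A minimal sketch of that usage (illustrative; gathered_inputs and gathered_outputs stand for the vectors assembled in Process above, and matmul_prim for the fused MatMul primitive):

  lite::QuantParamsVector gathered_inputs;   // e.g. rmatmul_quant_params after the pop/emplace surgery
  lite::QuantParamsVector gathered_outputs;  // e.g. the fullconnect node's output quant params
  auto holder = std::make_shared<lite::QuantParamHolder>(gathered_inputs, gathered_outputs);
  matmul_prim->AddAttr("quant_params", holder);
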

mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc (+2, -2)

@@ -205,7 +205,7 @@ lite::STATUS CopyQuantParams(const CNodePtr &cnode, const std::vector<Tensor *>
     MS_LOG(ERROR) << "quant param is invalid.";
     return lite::RET_ERROR;
   }
-  auto input_quant_params = quant_param_holder->input_quant_params();
+  auto input_quant_params = quant_param_holder->get_input_quant_params();
   for (size_t m = 0; m < input_quant_params.size(); m++) {
     for (auto inputQuantParam : input_quant_params[m]) {
       lite::QuantArg quant_arg{};
@@ -216,7 +216,7 @@ lite::STATUS CopyQuantParams(const CNodePtr &cnode, const std::vector<Tensor *>
       inputs[m]->AddQuantParam(quant_arg);
     }
   }
-  auto output_quant_params = quant_param_holder->output_quant_params();
+  auto output_quant_params = quant_param_holder->get_output_quant_params();
   for (size_t m = 0; m < output_quant_params.size(); m++) {
     for (auto outputQuantParam : output_quant_params[m]) {
       lite::QuantArg quant_arg{};


mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc (+19, -30)

@@ -25,12 +25,17 @@
 namespace mindspore {
 namespace opt {
 namespace {
-constexpr size_t kDoubleNum = 2;
-int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits) {
-  auto quant_tensor_info_ptr = prim->GetAttr("quant_params");
-  if (quant_tensor_info_ptr == nullptr) {
-    prim->AddAttr("quant_params", std::make_shared<lite::QuantParamHolder>());
+size_t GetCNodeOutputsSize(std::shared_ptr<AnfNode> anf_node) {
+  auto cnode = anf_node->cast<CNodePtr>();
+  if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
+    auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
+    return tuple->elements().size();
+  } else {
+    return 1;
+  }
 }
+
+int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits) {
   auto quant_param_holder = prim->GetAttr("quant_params")->cast<lite::QuantParamHolderPtr>();
   std::vector<schema::QuantParamT> quants;
   schema::QuantParamT quant_param;
@@ -50,10 +55,7 @@ int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t
       return ret;
     }
     quants.emplace_back(quant_param);
-    quant_param_holder->AddInputQuantParam(quants);
-  } else {
-    std::vector<schema::QuantParamT> notinited_quant_params(1);
-    quant_param_holder->AddInputQuantParam(notinited_quant_params);
+    quant_param_holder->set_input_quant_param(0, quants);
   }

   quants.clear();
@@ -78,19 +80,12 @@ int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t
       return ret;
     }
     quants.emplace_back(quant_param);
-    quant_param_holder->AddInputQuantParam(quants);
-  } else {
-    std::vector<schema::QuantParamT> notinited_quant_params(1);
-    quant_param_holder->AddInputQuantParam(notinited_quant_params);
+    quant_param_holder->set_input_quant_param(1, quants);
   }
   return lite::RET_OK;
 }

 int ConvertOutputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits) {
-  auto quant_tensor_info_ptr = prim->GetAttr("quant_params");
-  if (quant_tensor_info_ptr == nullptr) {
-    prim->AddAttr("quant_params", std::make_shared<lite::QuantParamHolder>());
-  }
   auto quant_param_holder = prim->GetAttr("quant_params")->cast<lite::QuantParamHolderPtr>();
   std::vector<schema::QuantParamT> quants;
   schema::QuantParamT quant_param;
@@ -110,22 +105,14 @@ int ConvertOutputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t
       return ret;
     }
     quants.emplace_back(quant_param);
-    quant_param_holder->AddOutputQuantParam(quants);
-  } else {
-    schema::QuantParamT tmpQuantParam;
-    quants.emplace_back(tmpQuantParam);
-    quant_param_holder->AddOutputQuantParam(quants);
+    quant_param_holder->set_output_quant_param(0, quants);
   }
   return lite::RET_OK;
 }

 void CheckQuantParams(const PrimitivePtr &prim) {
-  auto quant_tensor_info_ptr = prim->GetAttr("quant_params");
-  if (quant_tensor_info_ptr == nullptr) {
-    prim->AddAttr("quant_params", std::make_shared<lite::QuantParamHolder>());
-  }
   auto quant_param_holder = prim->GetAttr("quant_params")->cast<lite::QuantParamHolderPtr>();
-  auto input_quant_params = quant_param_holder->input_quant_params();
+  auto input_quant_params = quant_param_holder->get_input_quant_params();
   bool is_quant = false;
   for (size_t i = 0; i < input_quant_params.size(); ++i) {
     if (!input_quant_params.at(i).empty() && input_quant_params.at(i).at(0).inited) {
@@ -133,7 +120,7 @@ void CheckQuantParams(const PrimitivePtr &prim) {
       break;
     }
   }
-  auto output_quant_params = quant_param_holder->output_quant_params();
+  auto output_quant_params = quant_param_holder->get_output_quant_params();
   for (size_t i = 0; i < output_quant_params.size(); ++i) {
     if (!output_quant_params.at(i).empty() && output_quant_params.at(i).at(0).inited) {
       is_quant = true;
@@ -145,8 +132,6 @@ void CheckQuantParams(const PrimitivePtr &prim) {
 }

 int ConvertQuantParam(const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs) {
-  auto quant_param_holder = std::make_shared<lite::QuantParamHolder>();
-  prim->AddAttr("quant_params", quant_param_holder);
   auto narrow_range = prim->GetAttr("narrow_range");
   bool narrow_range_param = false;
   if (narrow_range != nullptr) {
@@ -248,6 +233,10 @@ int MindirAdjustPass::ComputeQuantParams(std::shared_ptr<AnfNode> anf_node) {
   }
   auto inputs = cnode->inputs();
   inputs.erase(inputs.begin());
+
+  auto quant_param_holder = std::make_shared<lite::QuantParamHolder>(inputs.size(), GetCNodeOutputsSize(anf_node));
+  primitive->AddAttr("quant_params", quant_param_holder);
+
   if (ConvertQuantParam(primitive, inputs) != lite::RET_OK) {
     MS_LOG(ERROR) << "compute quant param failed.";
     return lite::RET_ERROR;


mindspore/lite/tools/optimizer/graph/tflite_inputs_adjust_pass.cc (+15, -6)

@@ -56,12 +56,16 @@ lite::STATUS ReorderCnodeInputs(CNode *cnode, const std::vector<size_t> &perm) {
   std::vector<AnfNodePtr> new_inputs = {cnode->input(0)};
   auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
   auto input_quant_params = primitive->GetAttr("quant_params");
-  auto input_quant_params_holder = input_quant_params == nullptr
-                                     ? std::make_shared<lite::QuantParamHolder>()
-                                     : input_quant_params->cast<lite::QuantParamHolderPtr>();
-  auto old_quant_params = input_quant_params_holder->input_quant_params();
-  auto new_input_quant_holder = std::make_shared<lite::QuantParamHolder>();
+  if (input_quant_params == nullptr) {
+    MS_LOG(ERROR) << "quant params holder is null";
+    return RET_ERROR;
+  }
+  auto input_quant_params_holder = input_quant_params->cast<lite::QuantParamHolderPtr>();
+  auto old_quant_params = input_quant_params_holder->get_input_quant_params();
+  auto new_input_quant_holder =
+    std::make_shared<lite::QuantParamHolder>(perm.size(), input_quant_params_holder->get_output_quant_params().size());
   // add inputs as perm order
+  size_t new_idx = 0;
   for (size_t idx : perm) {
     if (idx > cnode->inputs().size() - 1) {
       MS_LOG(ERROR) << "Idx " << idx << " is larger than inputs size: " << cnode->inputs().size() - 1;
@@ -69,7 +73,12 @@ lite::STATUS ReorderCnodeInputs(CNode *cnode, const std::vector<size_t> &perm) {
     }
     new_inputs.emplace_back(cnode->input(idx));
     auto quant_param = idx < old_quant_params.size() ? old_quant_params.at(idx) : std::vector<schema::QuantParamT>();
-    new_input_quant_holder->AddInputQuantParam(quant_param);
+    new_input_quant_holder->set_input_quant_param(new_idx, quant_param);
+    new_idx++;
   }
+
+  for (size_t i = 0; i < input_quant_params_holder->get_output_quant_params().size(); i++) {
+    new_input_quant_holder->set_output_quant_param(i, input_quant_params_holder->get_output_quant_params().at(i));
+  }
   cnode->set_inputs(new_inputs);
   primitive->set_attr("quant_params", new_input_quant_holder);

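The reordering logic above boils down to walking a permutation of source indices and writing each carried entry at the next destination slot. A standalone analogue (strings instead of quant-param vectors, purely illustrative):

  #include <cstddef>
  #include <iostream>
  #include <string>
  #include <vector>

  int main() {
    std::vector<std::string> old_params = {"in0", "in1", "in2"};
    std::vector<size_t> perm = {2, 0};  // keep inputs 2 and 0, in that order

    std::vector<std::string> new_params(perm.size());
    size_t new_idx = 0;
    for (size_t idx : perm) {
      // Fall back to an empty placeholder when the source slot is missing.
      new_params[new_idx] = idx < old_params.size() ? old_params[idx] : std::string();
      ++new_idx;
    }
    for (const auto &p : new_params) std::cout << p << "\n";  // prints: in2, in0
    return 0;
  }
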

mindspore/lite/tools/optimizer/graph/transpose_strategy.cc (+1, -3)

@@ -93,9 +93,7 @@ AnfNodePtr TransposeStrategy::TransposeDependOnShape(const FuncGraphPtr &func_gr
   std::string trans_name =
     before ? cnode->fullname_with_scope() + "_pre" + std::to_string(index - 1) : cnode->fullname_with_scope() + "_post";
   auto trans_insert_node = GenTransposeNode(func_graph, trans_input_node, perm, trans_name);
-  auto quant_params_holder = std::make_shared<lite::QuantParamHolder>();
-  quant_params_holder->AddInputQuantParam(std::vector<schema::QuantParamT>(1));
-  quant_params_holder->AddOutputQuantParam(std::vector<schema::QuantParamT>(1));
+  auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(1, 1);
   auto trans_insert_prim = GetValueNode<PrimitivePtr>(trans_insert_node->input(0));
   trans_insert_prim->AddAttr("quant_params", quant_params_holder);
   return trans_insert_node;

