
[MS][LITE] use set quant params instead of add quant params

pull/15596/head
cjh9368 4 years ago
parent
commit
8d60d097e0
19 changed files with 305 additions and 203 deletions
  1. +11 -5 mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/arithmetic_infer.c
  2. +1 -1 mindspore/lite/test/models_tflite_awaretraining.cfg
  3. +63 -72 mindspore/lite/tools/anf_exporter/anf_exporter.cc
  4. +3 -3 mindspore/lite/tools/anf_exporter/anf_exporter.h
  5. +26 -10 mindspore/lite/tools/common/graph_util.cc
  6. +15 -0 mindspore/lite/tools/common/node_util.cc
  7. +3 -0 mindspore/lite/tools/common/node_util.h
  8. +15 -4 mindspore/lite/tools/converter/converter_context.h
  9. +58 -33 mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc
  10. +14 -37 mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc
  11. +34 -0 mindspore/lite/tools/converter/quantizer/quant_helper/quant_dtype_cast_quant_param_propogator.cc
  12. +27 -0 mindspore/lite/tools/converter/quantizer/quant_helper/quant_dtype_cast_quant_param_propogator.h
  13. +3 -1 mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_helper.cc
  14. +19 -0 mindspore/lite/tools/converter/quantizer/quantize_util.cc
  15. +4 -0 mindspore/lite/tools/converter/quantizer/quantize_util.h
  16. +4 -0 mindspore/lite/tools/optimizer/common/gllo_utils.cc
  17. +3 -33 mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc
  18. +0 -4 mindspore/lite/tools/optimizer/graph/transpose_strategy.cc
  19. +2 -0 mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc

+ 11 - 5 mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/arithmetic_infer.c

@@ -17,6 +17,16 @@
#include "nnacl/infer/arithmetic_infer.h"
#include "nnacl/infer/infer_register.h"

void SetOutputDtypeFormat(const TensorC *input0, const TensorC *input1, TensorC *output) {
output->format_ = input0->format_;
output->data_type_ = input0->data_type_;
// when input0 is const, it is quantized before the quant trans op is inserted, so use input1's data type instead
if (input0->data_ != NULL ||
((input0->data_type_ == kNumberTypeInt8) && (input1->data_type_ == kNumberTypeFloat32))) {
output->data_type_ = input1->data_type_;
}
}

void UpdateInputShape(const int input_shape0_size, const int input_shape1_size, int *ndim, const int *input_shape0,
const int *input_shape1, int *in_shape0, int *in_shape1) {
if (input_shape0_size < input_shape1_size) {
@@ -71,11 +81,7 @@ int ArithmeticInferShape(const TensorC *const *inputs, size_t inputs_size, Tenso
size_t input_shape0_size = input0->shape_size_;
const int *input_shape1 = input1->shape_;
size_t input_shape1_size = input1->shape_size_;
output->format_ = input0->format_;
output->data_type_ = input0->data_type_;
if ((input0->data_type_ == kNumberTypeInt8) && (input1->data_type_ == kNumberTypeFloat32)) {
output->data_type_ = input1->data_type_;
}
SetOutputDtypeFormat(input0, input1, output);

if (!parameter->infer_flag_) {
return NNACL_INFER_INVALID;
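
For context on the new helper: a standalone C++ mirror of the dtype-resolution rule above (not part of the diff; MiniTensor and the enum are illustrative stand-ins for nnacl's TensorC and TypeId):

#include <cstdio>

enum MiniDataType { kMiniFloat32, kMiniInt8 };

struct MiniTensor {
  const void *data;        // non-null marks a const tensor
  MiniDataType data_type;
};

MiniDataType ResolveOutputDtype(const MiniTensor &in0, const MiniTensor &in1) {
  MiniDataType out = in0.data_type;  // default: follow input0
  // A const input0 was already quantized before the quant trans op is
  // inserted, so the runtime dtype must come from input1; likewise for the
  // int8 + float32 mix.
  if (in0.data != nullptr || (in0.data_type == kMiniInt8 && in1.data_type == kMiniFloat32)) {
    out = in1.data_type;
  }
  return out;
}

int main() {
  const int weights[4] = {0};
  MiniTensor const_int8{weights, kMiniInt8};
  MiniTensor float_input{nullptr, kMiniFloat32};
  // const int8 input0 + float32 input1 -> float32 output
  std::printf("%d\n", ResolveOutputDtype(const_int8, float_input) == kMiniFloat32);
  return 0;
}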


+ 1 - 1 mindspore/lite/test/models_tflite_awaretraining.cfg

@@ -1,4 +1,4 @@
video_infer.tflite
video_infer2.tflite
mobilenet_v1_0.25_128_quant.tflite
mobilenet_v1_0.25_160_quant.tflite
mobilenet_v1_0.25_192_quant.tflite


+ 63 - 72 mindspore/lite/tools/anf_exporter/anf_exporter.cc

@@ -38,6 +38,8 @@
#include "ops/op_utils.h"
#include "tools/common/graph_util.h"
#include "src/ops/ops_utils.h"
#include "tools/common/node_util.h"
#include "tools/converter/converter_context.h"

using mindspore::ops::PrimitiveC;

@@ -81,9 +83,9 @@ std::list<CNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
}
} // namespace

int AnfExporter::SetQuantOutputTensorType(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
const std::shared_ptr<mindspore::Primitive> &primitive,
const std::unique_ptr<schema::CNodeT> &dst_node) {
int AnfExporter::SetPostTrainOutputTensorType(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
const std::shared_ptr<mindspore::Primitive> &primitive,
const std::unique_ptr<schema::CNodeT> &dst_node) {
auto first_output_index = dst_node->outputIndex[0];
auto first_tensor_output = meta_graph->allTensors[first_output_index].get();
if (dst_node->quantType == schema::QuantType_PostTraining) {
@@ -116,82 +118,63 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
QuantParamsVector output_quant_params;
dst_node->quantType = schema::QuantType_QUANT_NONE;
auto quant_tensor_info_ptr = primitive->GetAttr("quant_params");
if (quant_tensor_info_ptr != nullptr) {
auto quant_param_holder = quant_tensor_info_ptr->cast<QuantParamHolderPtr>();
if (quant_param_holder == nullptr) {
MS_LOG(ERROR) << "quant param is invalid.";
return RET_ERROR;
}
input_quant_params = quant_param_holder->get_input_quant_params();
output_quant_params = quant_param_holder->get_output_quant_params();
dst_node->quantType = quant_param_holder->quant_type();
QuantParamHolderPtr quant_param_holder = nullptr;
if (quant_tensor_info_ptr == nullptr ||
(quant_param_holder = quant_tensor_info_ptr->cast<QuantParamHolderPtr>()) == nullptr) {
quant_param_holder = std::make_shared<QuantParamHolder>(dst_node->inputIndex.size(), dst_node->outputIndex.size());
}
// add quant param
if (!input_quant_params.empty()) {
for (size_t i = 0; i < input_quant_params.size(); i++) {
if (i >= dst_node->inputIndex.size()) {
MS_LOG(INFO) << "node: " << dst_node->name << " input has " << input_quant_params.size()
<< " quant_params; but only " << dst_node->inputIndex.size() << " input";
break;
}
auto activate_index = dst_node->inputIndex[i];
auto tensor_input = meta_graph->allTensors[activate_index].get();
if (tensor_input->quantParams.empty()) {
for (auto input_quant_param : input_quant_params[i]) {
auto input_quant_param_ptr = std::make_unique<schema::QuantParamT>(input_quant_param);
MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
<< " zp: " << input_quant_param_ptr->zeroPoint;
input_quant_param_ptr->dstDtype = tensor_input->dataType;
tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
}

input_quant_params = quant_param_holder->get_input_quant_params();
output_quant_params = quant_param_holder->get_output_quant_params();
dst_node->quantType = quant_param_holder->quant_type();

// convert input quant param
for (size_t i = 0; i < dst_node->inputIndex.size(); i++) {
if (i >= input_quant_params.size()) {
MS_LOG(INFO) << "node: " << dst_node->name << " has " << dst_node->inputIndex.size() << ", but only has"
<< input_quant_params.size() << " quant params";
break;
}
auto activate_index = dst_node->inputIndex[i];
auto tensor_input = meta_graph->allTensors[activate_index].get();
if (tensor_input->quantParams.empty()) {
for (auto input_quant_param : input_quant_params[i]) {
auto input_quant_param_ptr = std::make_unique<schema::QuantParamT>(input_quant_param);
MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
<< " zp: " << input_quant_param_ptr->zeroPoint;
input_quant_param_ptr->dstDtype = tensor_input->dataType;
tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
}
if (!tensor_input->quantParams.empty()) {
int bit_num = tensor_input->quantParams.at(0)->numBits;
if (bit_num != 8 && bit_num != 16) {
auto status = DoBitPack(bit_num, tensor_input);
if (status != RET_OK) {
MS_LOG(ERROR) << "do bit pack failed. " << status;
return RET_ERROR;
}
}

if (!tensor_input->quantParams.empty()) {
int bit_num = tensor_input->quantParams.at(0)->numBits;
if (bit_num != 8 && bit_num != 16) {
auto status = DoBitPack(bit_num, tensor_input);
if (status != RET_OK) {
MS_LOG(ERROR) << "do bit pack failed. " << status;
return RET_ERROR;
}
}
}
} else {
MS_LOG(DEBUG) << "node: " << dst_node->name << " input quant params is empty";
}

// output
if (output_quant_params.empty()) {
if (primitive->name() != mindspore::ops::kNameQuantDTypeCast) {
MS_LOG(DEBUG) << "node: " << dst_node->name << " output quant params is empty";
}
} else {
if (dst_node->outputIndex.size() != output_quant_params.size()) {
MS_LOG(INFO) << "node: " << dst_node->name << " output has " << output_quant_params.size()
<< " quant_params; but only " << dst_node->outputIndex.size() << " output";
return RET_ERROR;
}
int output_idx = 0;
for (const auto &output_quant_param : output_quant_params) {
auto output_tensor = meta_graph->allTensors[dst_node->outputIndex[output_idx]].get();
output_idx++;
for (const auto &channel_quant_param : output_quant_param) {
if (output_tensor->quantParams.empty() && dst_node->quantType != schema::QuantType_WeightQuant) {
std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
std::make_unique<schema::QuantParamT>(channel_quant_param);
MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
<< " zp: " << output_quant_param_ptr->zeroPoint;
output_quant_param_ptr->dstDtype = output_tensor->dataType;
output_tensor->quantParams.emplace_back(std::move(output_quant_param_ptr));
}
int output_idx = 0;
for (const auto &output_quant_param : output_quant_params) {
auto output_tensor = meta_graph->allTensors[dst_node->outputIndex[output_idx]].get();
output_idx++;
for (const auto &channel_quant_param : output_quant_param) {
if (output_tensor->quantParams.empty() && dst_node->quantType != schema::QuantType_WeightQuant) {
std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
std::make_unique<schema::QuantParamT>(channel_quant_param);
MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
<< " zp: " << output_quant_param_ptr->zeroPoint;
output_quant_param_ptr->dstDtype = output_tensor->dataType;
output_tensor->quantParams.emplace_back(std::move(output_quant_param_ptr));
}
}
}

auto status = SetQuantOutputTensorType(meta_graph, primitive, dst_node);
if (status != RET_OK) {
MS_LOG(ERROR) << "set quant output tensor data type failed.";
return RET_ERROR;
}
return RET_OK;
}
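
This hunk carries the commit's core idea: when the primitive has no usable "quant_params" attr, a QuantParamHolder pre-sized to the node's input/output counts is created, so each param lands at its tensor's index instead of being appended. A toy sketch of that set-by-index contract (not part of the diff; MiniHolder and QuantParam are simplified stand-ins for QuantParamHolder and schema::QuantParamT):

#include <cstdio>
#include <vector>

struct QuantParam {
  double scale = 0.0;
  bool inited = false;
};

class MiniHolder {
 public:
  MiniHolder(size_t inputs, size_t outputs) : input_params_(inputs), output_params_(outputs) {}
  void SetInputQuantParam(size_t index, const QuantParam &param) {
    input_params_.at(index) = {param};  // positional set, not push_back
  }
  const std::vector<std::vector<QuantParam>> &input_params() const { return input_params_; }

 private:
  std::vector<std::vector<QuantParam>> input_params_;
  std::vector<std::vector<QuantParam>> output_params_;
};

int main() {
  MiniHolder holder(3, 1);
  holder.SetInputQuantParam(2, {0.25, true});  // only input 2 carries a param
  // inputs 0 and 1 keep empty slots; input 2 holds its param at its own index
  std::printf("%zu\n", holder.input_params().at(2).size());  // prints 1
  return 0;
}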

@@ -224,6 +207,7 @@ int AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &m
tensor->format = schema::Format_NHWC;
if (!IsContain(subgraph->inputIndices, input)) {
if (subgraph_index == kMainGraphIndex) {
TensorDataType::GetInstance()->UpdateGraphInputDType(meta_graphT->inputIndex.size(), tensor->dataType);
meta_graphT->inputIndex.push_back(input);
}
subgraph->inputIndices.push_back(input);
@@ -262,6 +246,8 @@ int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgrap
}
for (unsigned int &i : return_node->inputIndex) {
if (subgraph_index == kMainGraphIndex) {
auto &tensor = meta_graphT->allTensors.at(i);
TensorDataType::GetInstance()->UpdateGraphOutputDType(meta_graphT->outputIndex.size(), tensor->dataType);
meta_graphT->outputIndex.push_back(i);
}
meta_graphT->subGraph.at(subgraph_index)->outputIndices.push_back(i);
@@ -354,6 +340,13 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
MS_LOG(ERROR) << "ConvertQuantParam failed";
break;
}

auto status = SetPostTrainOutputTensorType(meta_graphT, prim, node);
if (status != RET_OK) {
MS_LOG(ERROR) << "set quant output tensor data type failed.";
break;
}

meta_graphT->nodes.push_back(std::move(node));
meta_graphT->subGraph.at(subgraph_index)->nodeIndices.push_back(node_idx_++);
}
@@ -615,7 +608,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
return;
}
auto elements = tuple->elements();
for (size_t i = 0; i < elements.size(); i++) {
for (size_t i = 0; i < lite::GetCNodeOutputsSize(cnode, train_flag_); i++) {
auto msTensor = new (std::nothrow) schema::TensorT();
if (msTensor == nullptr) {
MS_LOG(ERROR) << "new msTensor failed";
@@ -627,8 +620,6 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
std::string name = cnode_name + "_o:" + std::to_string(i);
node_id_map_[name] = meta_graphT->allTensors.size();
meta_graphT->allTensors.emplace_back(msTensor);
if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || opt::CheckPrimitiveType(cnode, prim::kPrimAdam))
break;
} else {
if (elements.size() == 1) {
node_id_map_[cnode_name] = meta_graphT->allTensors.size();


+ 3 - 3 mindspore/lite/tools/anf_exporter/anf_exporter.h

@@ -57,9 +57,9 @@ class AnfExporter {
int SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const size_t &subgraph_index);
int SetGraphoutputIndex(const CNodePtr &cnode, size_t subgraph_index,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *return_node);
static int SetQuantOutputTensorType(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
const std::shared_ptr<mindspore::Primitive> &primitive,
const std::unique_ptr<schema::CNodeT> &dst_node);
static int SetPostTrainOutputTensorType(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
const std::shared_ptr<mindspore::Primitive> &primitive,
const std::unique_ptr<schema::CNodeT> &dst_node);
static int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
const std::shared_ptr<mindspore::Primitive> &primitive,
const std::unique_ptr<schema::CNodeT> &dst_node);


+ 26 - 10 mindspore/lite/tools/common/graph_util.cc

@@ -495,13 +495,21 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
MS_ASSERT(prim != nullptr);
if (prim->src_t == TypeId::kNumberTypeUInt8) {
if (preTensor->dataType == TypeId::kNumberTypeUInt8) {
toAddTensor->quantParams.front()->zeroPoint -= 128;
} else {
preTensor->quantParams.front()->zeroPoint += 128;
}
} else if (prim->dst_t == TypeId::kNumberTypeUInt8) {
if (preTensor->dataType == TypeId::kNumberTypeInt8) {
toAddTensor->quantParams.front()->zeroPoint += 128;
} else {
preTensor->quantParams.front()->zeroPoint -= 128;
}
}
preTensor->dataType = prim->src_t;
toAddTensor->dataType = prim->dst_t;
if (prim->src_t == TypeId::kNumberTypeUInt8 && prim->dst_t == TypeId::kNumberTypeInt8) {
preTensor->quantParams.front()->zeroPoint += 128;
} else if (prim->src_t == TypeId::kNumberTypeInt8 && prim->dst_t == TypeId::kNumberTypeUInt8) {
toAddTensor->quantParams.front()->zeroPoint += 128;
}
}
graphT->allTensors.emplace_back(std::move(toAddTensor));
size_t toAddTensorIdx = graphT->allTensors.size() - 1;
@@ -565,13 +573,21 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
MS_ASSERT(prim != nullptr);
if (prim->dst_t == TypeId::kNumberTypeUInt8) {
if (postTensor->dataType == TypeId::kNumberTypeUInt8) {
postTensor->quantParams.front()->zeroPoint -= 128;
} else {
toAddTensor->quantParams.front()->zeroPoint += 128;
}
} else if (prim->src_t == TypeId::kNumberTypeUInt8) {
if (postTensor->dataType == TypeId::kNumberTypeUInt8) {
toAddTensor->quantParams.front()->zeroPoint -= 128;
} else {
postTensor->quantParams.front()->zeroPoint += 128;
}
}
postTensor->dataType = prim->src_t;
toAddTensor->dataType = prim->dst_t;
if (prim->src_t == TypeId::kNumberTypeInt8 && prim->dst_t == TypeId::kNumberTypeUInt8) {
toAddTensor->quantParams.front()->zeroPoint += 128;
} else if (prim->src_t == TypeId::kNumberTypeUInt8 && prim->dst_t == TypeId::kNumberTypeInt8) {
postTensor->quantParams.front()->zeroPoint += 128;
}
}
graphT->allTensors.emplace_back(std::move(toAddTensor));
size_t toAddTensorIdx = graphT->allTensors.size() - 1;
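
The reworked branches above differ in which tensor's zero point moves, but both rest on the same affine identity real = scale * (q - zp): shifting the stored values and the zero point by the same offset preserves the dequantized value. A worked check (plain C++, not converter code):

#include <cassert>

int main() {
  const float scale = 0.5f;
  const int zp_u8 = 130, q_u8 = 200;                 // a uint8 encoding
  const int zp_i8 = zp_u8 - 128, q_i8 = q_u8 - 128;  // the int8 re-encoding
  // Both sides dequantize to 35.0f, so the cast is value-preserving.
  assert(scale * (q_u8 - zp_u8) == scale * (q_i8 - zp_i8));
  return 0;
}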


+ 15 - 0 mindspore/lite/tools/common/node_util.cc

@@ -446,5 +446,20 @@ STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
}
return RET_OK;
}

size_t GetCNodeOutputsSize(const std::shared_ptr<AnfNode> &anf_node, bool train_flag) {
auto cnode = anf_node->cast<CNodePtr>();
if (train_flag &&
(opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || opt::CheckPrimitiveType(cnode, prim::kPrimAdam))) {
return 1;
}
if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
return tuple->elements().size();
} else {
return 1;
}
}

} // namespace lite
} // namespace mindspore
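
This helper previously lived in mindir_adjust_pass (see its deletion below) and now gains the training-mode special case. Restated as a decision table with plain flags in place of the real AnfNode/abstract machinery (illustrative only, not the diff's code):

#include <cstddef>

size_t CNodeOutputsSize(bool train_flag, bool is_conv2d_fusion_or_adam,
                        bool abstract_is_tuple, size_t tuple_size) {
  // In training export, Conv2DFusion and Adam are flattened to one output.
  if (train_flag && is_conv2d_fusion_or_adam) {
    return 1;
  }
  // Otherwise a tuple abstract contributes one output per element.
  return abstract_is_tuple ? tuple_size : 1;
}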

+ 3 - 0 mindspore/lite/tools/common/node_util.h

@@ -26,6 +26,7 @@
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
#include "securec/include/securec.h"
#include "tools/optimizer/common/gllo_utils.h"

namespace mindspore {
namespace lite {
@@ -401,6 +402,8 @@ static STATUS TransFilterFormat(schema::TensorT *tensor, kTransFilterType type)
}

STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat);

size_t GetCNodeOutputsSize(const std::shared_ptr<AnfNode> &anf_node, bool train_flag = false);
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H

+ 15 - 4 mindspore/lite/tools/converter/converter_context.h

@@ -77,18 +77,29 @@ class TensorDataType {
static TensorDataType tensor_data_type;
return &tensor_data_type;
}
void UpdateTensorType(int32_t index, int32_t type) { tensor_data_type_map_[index] = type; }
int32_t GetTensorType(int32_t index) const {
if (tensor_data_type_map_.find(index) == tensor_data_type_map_.end()) {

void UpdateGraphInputDType(int32_t index, int32_t dtype) { graph_input_data_type_map_[index] = dtype; }
int32_t GetGraphInputDType(int32_t index) const {
if (graph_input_data_type_map_.find(index) == graph_input_data_type_map_.end()) {
return TypeId::kTypeUnknown;
}
return graph_input_data_type_map_.at(index);
}

void UpdateGraphOutputDType(int32_t index, int32_t dtype) { graph_output_data_type_map_[index] = dtype; }
int32_t GetGraphOutputDType(int32_t index) const {
if (graph_output_data_type_map_.find(index) == graph_output_data_type_map_.end()) {
return TypeId::kTypeUnknown;
}
return tensor_data_type_map_.at(index);
return graph_output_data_type_map_.at(index);
}

private:
TensorDataType() {}
virtual ~TensorDataType() = default;
std::map<int32_t, int32_t> tensor_data_type_map_;
std::map<int32_t, int32_t> graph_input_data_type_map_;
std::map<int32_t, int32_t> graph_output_data_type_map_;
};
} // namespace lite
} // namespace mindspore
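
TensorDataType now records graph input and output dtypes in separate maps keyed by I/O position, returning kTypeUnknown for unrecorded indices. A minimal standalone version of that lookup pattern (hypothetical class, not the real one; 0 stands in for kTypeUnknown and 43 is just an illustrative dtype value):

#include <cstdint>
#include <cstdio>
#include <map>

class IoDtypeRecorder {
 public:
  static IoDtypeRecorder *GetInstance() {
    static IoDtypeRecorder instance;
    return &instance;
  }
  void UpdateGraphInputDType(int32_t index, int32_t dtype) { inputs_[index] = dtype; }
  int32_t GetGraphInputDType(int32_t index) const {
    auto it = inputs_.find(index);
    return it == inputs_.end() ? 0 /* unknown sentinel */ : it->second;
  }

 private:
  std::map<int32_t, int32_t> inputs_;
};

int main() {
  IoDtypeRecorder::GetInstance()->UpdateGraphInputDType(0, 43 /* say, float32 */);
  std::printf("%d %d\n", IoDtypeRecorder::GetInstance()->GetGraphInputDType(0),
              IoDtypeRecorder::GetInstance()->GetGraphInputDType(1));  // 43 0
  return 0;
}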


+ 58 - 33 mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc

@@ -23,6 +23,7 @@
#include "tools/converter/converter_context.h"
#include "src/common/common.h"
#include "src/common/utils.h"
#include "tools/converter/quantizer/quantize_util.h"

namespace mindspore {
namespace lite {
@@ -38,18 +39,17 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {
return status;
}

status = DoNodeInoutDTypeTrans(graph);
status = DoModelOutputDTypeTrans(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status;
MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status;
return status;
}

status = DoModelOutputDTypeTrans(graph);
status = DoNodeInoutDTypeTrans(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status;
MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status;
return status;
}

return RET_OK;
}

@@ -61,15 +61,23 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
MS_LOG(ERROR) << "Invalid inputDataType: " << this->input_data_dtype;
return RET_ERROR;
}
for (auto graph_in_idx : graph_in_idxes) {
for (size_t i = 0; i < graph_in_idxes.size(); i++) {
size_t graph_in_idx = graph_in_idxes.at(i);
MS_ASSERT(graph_in_idx < graph->allTensors.size());
auto &tensor = graph->allTensors.at(graph_in_idx);
if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
if (!quant::TensorQuantParamsInited(*tensor)) {
continue;
}

if (this->input_data_dtype == TypeId::kTypeUnknown) {
if (tensor->dataType != TensorDataType::GetInstance()->GetGraphInputDType(i)) {
MS_LOG(ERROR) << "Change graph input dtype is not allowed.";
return RET_ERROR;
}
continue;
}
int32_t tensor_data_type = this->input_data_dtype != TypeId::kTypeUnknown
? this->input_data_dtype
: TensorDataType::GetInstance()->GetTensorType(graph_in_idx);

int32_t tensor_data_type = this->input_data_dtype;
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto node_name = (*iter)->name;
for (size_t input_indexidx = 0; input_indexidx < (*iter)->inputIndex.size(); input_indexidx++) {
@@ -77,7 +85,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
STATUS status = RET_OK;

// insert dtype cast node between input tensor and input node
if (tensor_data_type != tensor->dataType && tensor_data_type != kTypeUnknown) {
if (tensor_data_type != tensor->dataType) {
iter =
InsertDTypeTransNode(graph, iter, kBefore, input_indexidx, tensor_data_type, tensor->dataType, &status);
}
@@ -101,15 +109,23 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
return RET_ERROR;
}
auto &graph_out_idxes = graph->outputIndex;
for (auto graph_out_idx : graph_out_idxes) {
for (size_t i = 0; i < graph_out_idxes.size(); i++) {
size_t graph_out_idx = graph_out_idxes.at(i);
MS_ASSERT(graph_out_idx < graph->allTensors.size());
auto &tensor = graph->allTensors.at(graph_out_idx);
if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
if (!quant::TensorQuantParamsInited(*tensor)) {
continue;
}
int32_t tensor_data_type = this->output_data_dtype != TypeId::kTypeUnknown
? this->output_data_dtype
: TensorDataType::GetInstance()->GetTensorType(graph_out_idx);

if (this->output_data_dtype == TypeId::kTypeUnknown) {
if (tensor->dataType != TensorDataType::GetInstance()->GetGraphOutputDType(i)) {
MS_LOG(ERROR) << "Change graph output dtype is not allowed.";
return RET_ERROR;
}
continue;
}

int32_t tensor_data_type = this->output_data_dtype;
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto node_name = (*iter)->name;
MS_ASSERT(node != nullptr);
@@ -117,7 +133,7 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
if ((*iter)->outputIndex.at(outputIndexIdx) == graph_out_idx) {
// insert transNode
STATUS status = RET_OK;
if (tensor_data_type != tensor->dataType && tensor_data_type != kTypeUnknown) {
if (tensor_data_type != tensor->dataType) {
iter =
InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, tensor->dataType, tensor_data_type, &status);
}
@@ -136,27 +152,34 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
STATUS DTypeTransPass::InsetDTypeTransNodeForWrongDtypeQuantOp(schema::MetaGraphT *graph, NodeIter *iter) {
auto node_name = (**iter)->name;
auto status = RET_OK;
// insert fp32 to int8 before
// insert fp32/uint8 to int8 before
for (size_t i = 0; i < (**iter)->inputIndex.size(); i++) {
auto &pre_tensor = graph->allTensors.at((**iter)->inputIndex.at(i));
if (pre_tensor->dataType == kNumberTypeFloat32 && !pre_tensor->quantParams.empty() &&
pre_tensor->quantParams.front()->inited) {
*iter = InsertDTypeTransNode(graph, *iter, kBefore, i, kNumberTypeFloat32, kNumberTypeInt8, &status);
// insert a quant cast op for tensors that should be int8
if ((pre_tensor->dataType == kNumberTypeFloat32 || pre_tensor->dataType == kNumberTypeUInt8) &&
quant::TensorQuantParamsInited(*pre_tensor)) {
if (!pre_tensor->data.empty()) {
MS_LOG(ERROR) << "tensor with float data should be quantized at tensor_quant_pass.";
return RET_ERROR;
}
*iter = InsertDTypeTransNode(graph, *iter, kBefore, i, pre_tensor->dataType, kNumberTypeInt8, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertFloat32ToInt8Node before " << node_name.c_str() << " failed";
MS_LOG(ERROR) << "Insert float32 or uint8 to int8 node after before " << node_name.c_str() << " failed";
return RET_ERROR;
}
}
}

// insert int8 to fp32 after
// insert int8 to fp32/uint8 after
for (size_t i = 0; i < (**iter)->outputIndex.size(); i++) {
auto &post_tensor = graph->allTensors.at((**iter)->outputIndex.at(i));
if (post_tensor->dataType == kNumberTypeFloat32 && !post_tensor->quantParams.empty() &&
post_tensor->quantParams.front()->inited) {
*iter = InsertDTypeTransNode(graph, *iter, kAfter, i, kNumberTypeInt8, kNumberTypeFloat32, &status);
// insert a quant cast op for tensors that should be int8
// e.g. reshape's shape tensor doesn't need a quant op inserted, so its quant param isn't inited
if ((post_tensor->dataType == kNumberTypeFloat32 || post_tensor->dataType == kNumberTypeUInt8) &&
quant::TensorQuantParamsInited(*post_tensor)) {
*iter = InsertDTypeTransNode(graph, *iter, kAfter, i, kNumberTypeInt8, post_tensor->dataType, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << node_name.c_str() << " failed";
MS_LOG(ERROR) << "Insert int8 to float32 or uint8 node after " << node_name.c_str() << " failed";
return RET_ERROR;
}
}
@@ -170,8 +193,7 @@ STATUS DTypeTransPass::InsetDTypeTransNodeForUnsupportedInt8Op(schema::MetaGraph
// insert int8 to fp32 before
for (size_t i = 0; i < (**iter)->inputIndex.size(); i++) {
auto &pre_tensor = graph->allTensors.at((**iter)->inputIndex.at(i));
if (pre_tensor->dataType == kNumberTypeInt8 && !pre_tensor->quantParams.empty() &&
pre_tensor->quantParams.front()->inited) {
if (pre_tensor->dataType == kNumberTypeInt8 && quant::TensorQuantParamsInited(*pre_tensor)) {
*iter = InsertDTypeTransNode(graph, *iter, kBefore, i, kNumberTypeInt8, kNumberTypeFloat32, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << node_name.c_str() << " failed";
@@ -183,8 +205,7 @@ STATUS DTypeTransPass::InsetDTypeTransNodeForUnsupportedInt8Op(schema::MetaGraph
// insert fp32 to int8 after
for (size_t i = 0; i < (**iter)->outputIndex.size(); i++) {
auto &post_tensor = graph->allTensors.at((**iter)->outputIndex.at(i));
if (post_tensor->dataType == kNumberTypeInt8 && !post_tensor->quantParams.empty() &&
post_tensor->quantParams.front()->inited) {
if (post_tensor->dataType == kNumberTypeInt8 && quant::TensorQuantParamsInited(*post_tensor)) {
*iter = InsertDTypeTransNode(graph, *iter, kAfter, i, kNumberTypeInt8, kNumberTypeFloat32, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertFloat32ToInt8Node before " << node_name.c_str() << " failed";
@@ -200,8 +221,8 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto node_name = (*iter)->name;
if ((*iter)->inputIndex.empty()) {
MS_LOG(ERROR) << "Op " << node_name.c_str() << " should have " << kMinInputNum << " input tensor at least";
return RET_ERROR;
MS_LOG(WARNING) << "Op " << node_name.c_str() << " should have " << kMinInputNum << " input tensor at least";
continue;
}

if ((*iter)->primitive->value.type == schema::PrimitiveType_QuantDTypeCast ||
@@ -270,6 +291,10 @@ NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIte
trans_node->name = "uint8toint8_" + tile_name + std::to_string(id_++);
} else if (input_data_type == TypeId::kNumberTypeInt8 && output_data_type == TypeId::kNumberTypeUInt8) {
trans_node->name = "int8touint8_" + tile_name + std::to_string(id_++);
} else if (input_data_type == TypeId::kNumberTypeUInt8 && output_data_type == TypeId::kNumberTypeFloat32) {
trans_node->name = "uint8toft32_" + tile_name + std::to_string(id_++);
} else if (input_data_type == TypeId::kNumberTypeFloat32 && output_data_type == TypeId::kNumberTypeUInt8) {
trans_node->name = "ft32touint8_" + tile_name + std::to_string(id_++);
}
trans_node->primitive->value.value = quant_dtype_cast_param;
int insert_num = 0;
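
The two new branches extend the trans-node naming scheme to uint8<->float32 casts. The uint8-related prefixes from this hunk, restated as a lookup (DataKind is an illustrative enum, not TypeId; the "ft32" spelling is copied from the diff as-is):

#include <string>

enum class DataKind { kFloat32, kInt8, kUInt8 };

std::string TransNodePrefix(DataKind src, DataKind dst) {
  if (src == DataKind::kUInt8 && dst == DataKind::kInt8) return "uint8toint8_";
  if (src == DataKind::kInt8 && dst == DataKind::kUInt8) return "int8touint8_";
  if (src == DataKind::kUInt8 && dst == DataKind::kFloat32) return "uint8toft32_";
  if (src == DataKind::kFloat32 && dst == DataKind::kUInt8) return "ft32touint8_";
  return "";  // other pairs are named in branches outside this hunk
}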


+ 14 - 37 mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc

@@ -25,29 +25,15 @@

namespace mindspore::lite {
namespace {
STATUS PreHandleQuantDtypeCast(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
for (auto &node : graph->nodes) {
if (node == nullptr || node->primitive == nullptr) {
MS_LOG(ERROR) << " node or node->primitive is nullptr";
return RET_ERROR;
}
if (node->primitive->value.type == PrimitiveType_QuantDTypeCast) {
auto attr = node->primitive->value.AsQuantDTypeCast();
auto &inputTensor = graph->allTensors.at(node->inputIndex.front());
inputTensor->dataType = attr->src_t;
auto &outputTensor = graph->allTensors.at(node->outputIndex.front());
outputTensor->dataType = attr->dst_t;

if (attr->src_t == TypeId::kNumberTypeUInt8) {
attr->src_t = TypeId::kNumberTypeInt8;
}
if (attr->dst_t == TypeId::kNumberTypeUInt8) {
attr->dst_t = TypeId::kNumberTypeInt8;
}
}
bool TensorNeedQuant(const std::unique_ptr<TensorT> &tensor) {
if (!quant::TensorQuantParamsInited(*tensor)) {
return false;
}
return RET_OK;
if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat &&
tensor->dataType != TypeId::kNumberTypeUInt8 && tensor->dataType != TypeId::kTypeUnknown) {
return false;
}
return !tensor->data.empty();
}

STATUS ComputeDataToInt8(const std::unique_ptr<TensorT> &tensor, int32_t index) {
@@ -73,7 +59,6 @@ STATUS ComputeDataToInt8(const std::unique_ptr<TensorT> &tensor, int32_t index)
weightQauntParam->zeroPoint -= 128;
tensor->quantParams.clear();
tensor->quantParams.emplace_back(weightQauntParam.release());
TensorDataType::GetInstance()->UpdateTensorType(index, TypeId::kNumberTypeUInt8);
}
tensor->dataType = TypeId::kNumberTypeInt8;
if (tensor->data.empty()) {
@@ -174,23 +159,15 @@ STATUS ComputeQuantTensorPerChannel(TensorT *tensor, const int &tensor_index, co

STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
auto status = PreHandleQuantDtypeCast(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "pre adjust failed.";
return status;
}
int32_t index = 0;
auto status = RET_OK;
for (auto &tensor : graph->allTensors) {
if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
index++;
continue;
}
if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat &&
tensor->dataType != TypeId::kNumberTypeUInt8 && tensor->dataType != TypeId::kTypeUnknown) {
if (!TensorNeedQuant(tensor)) {
index++;
continue;
}
if (tensor->quantParams.size() != 1) { // perchannel

if (tensor->quantParams.size() > 1) { // perchannel
status = ComputeQuantTensorPerChannel(tensor.get(), index, *graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "compute tensor to int8 prechannel failed.";
@@ -201,8 +178,8 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
}
// perlayer
auto &quantParam = tensor->quantParams.front();
if (quantParam->dstDtype == TypeId::kNumberTypeUInt8 || quantParam->dstDtype == TypeId::kNumberTypeFloat32 ||
quantParam->dstDtype == TypeId::kNumberTypeFloat) {
if (quantParam->dstDtype == TypeId::kNumberTypeInt8 || quantParam->dstDtype == TypeId::kNumberTypeUInt8 ||
quantParam->dstDtype == TypeId::kNumberTypeFloat32 || quantParam->dstDtype == TypeId::kNumberTypeFloat) {
status = ComputeDataToInt8(tensor, index);
} else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) {
// quant bias data
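
TensorNeedQuant bundles the three checks that were previously inlined in Run. Restated over a simplified tensor (MiniTensor is a hypothetical stand-in for schema::TensorT, dtype constants are passed in rather than assumed):

#include <vector>

struct MiniTensor {
  int data_type;                    // TypeId stand-in
  std::vector<bool> params_inited;  // one flag per quant param
  std::vector<unsigned char> data;  // const payload, empty for activations
};

bool NeedQuant(const MiniTensor &t, int kF32, int kF, int kU8, int kUnknown) {
  if (t.params_inited.empty()) return false;  // TensorQuantParamsInited(...)
  for (bool inited : t.params_inited) {
    if (!inited) return false;
  }
  if (t.data_type != kF32 && t.data_type != kF && t.data_type != kU8 && t.data_type != kUnknown) {
    return false;  // already int8/int32 etc., nothing to requantize here
  }
  return !t.data.empty();  // only const tensors carry data to quantize
}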


+ 34 - 0 mindspore/lite/tools/converter/quantizer/quant_helper/quant_dtype_cast_quant_param_propogator.cc

@@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_helper/quant_dtype_cast_quant_param_propogator.h"
#include "mindspore/core/utils/log_adapter.h"
#include "mindspore/core/ir/dtype/type_id.h"
namespace mindspore::lite {
static constexpr size_t kInputIndex = 0;
static constexpr size_t kWeightIndex = 1;

STATUS QuantDtypeCastQuantParamPropogator::PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) {
auto &input_tensor = graph->allTensors.at(node.inputIndex.at(0));
if (!input_tensor->quantParams.empty() && input_tensor->quantParams.front()->inited) {
input_tensor->quantParams.front()->dstDtype = input_tensor->dataType;
}
auto &output_tensor = graph->allTensors.at(node.outputIndex.at(0));
if (!output_tensor->quantParams.empty() && output_tensor->quantParams.front()->inited) {
output_tensor->quantParams.front()->dstDtype = output_tensor->dataType;
}
return RET_OK;
}
} // namespace mindspore::lite

+ 27 - 0 mindspore/lite/tools/converter/quantizer/quant_helper/quant_dtype_cast_quant_param_propogator.h

@@ -0,0 +1,27 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_QUANT_DTYPE_CAST_QUANT_PARAM_PROPOGATOR_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_QUANT_DTYPE_CAST_QUANT_PARAM_PROPOGATOR_H

#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
namespace mindspore::lite {
class QuantDtypeCastQuantParamPropogator : public QuantParamPropogator {
public:
STATUS PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
};
} // namespace mindspore::lite

#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_QUANT_DTYPE_CAST_QUANT_PARAM_PROPOGATOR_H

+ 3 - 1 mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_helper.cc

@@ -25,6 +25,7 @@
#include "tools/converter/quantizer/quant_helper/conv_quant_type_determiner.h"
#include "tools/converter/quantizer/quant_helper/default_quant_all_quant_type_determiner.h"
#include "tools/converter/quantizer/quant_helper/only_need_inputs_quant_type_determiner.h"
#include "tools/converter/quantizer/quant_helper/quant_dtype_cast_quant_param_propogator.h"

namespace mindspore::lite {
void QuantNodeBase::UpdateQuantParamsNum(const schema::MetaGraphT &graph, const schema::CNodeT &node) {
@@ -100,6 +101,7 @@ QuantNodeHelper *QuantHelperRegister::GetQuantHelper(schema::PrimitiveType op_ty
QuantHelperRegister::QuantHelperRegister() {
auto base_propogator = std::make_shared<QuantParamPropogator>();
auto base_determiner = std::make_shared<QuantTypeDeterminer>();
auto quant_dtype_cast_propogator = std::make_shared<QuantDtypeCastQuantParamPropogator>();
auto bias_add_propogator = std::make_shared<BiasAddQuantParamPropogator>();
auto carry_data_propogator = std::make_shared<CarryDataQuantParamPropogator>();
auto carry_data_determiner = std::make_shared<CarryDataQuantTypeDeterminer>();
@@ -127,7 +129,7 @@ QuantHelperRegister::QuantHelperRegister() {
register_map_[schema::PrimitiveType_MatMul] = new QuantNodeHelper(conv_propogator, conv_determiner);

register_map_[schema::PrimitiveType_QuantDTypeCast] =
new QuantNodeHelper(base_propogator, default_quant_all_determiner);
new QuantNodeHelper(quant_dtype_cast_propogator, default_quant_all_determiner);

register_map_[schema::PrimitiveType_DetectionPostProcess] =
new QuantNodeHelper(base_propogator, only_need_inputs_determiner);


+ 19 - 0 mindspore/lite/tools/converter/quantizer/quantize_util.cc

@@ -248,6 +248,25 @@ QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive) {
}
return quant_params_holder;
}
bool QuantParamEqual(const schema::QuantParamT &quant_param1, const schema::QuantParamT &quant_param2) {
return quant_param1.inited == quant_param2.inited && quant_param1.scale == quant_param2.scale &&
quant_param1.zeroPoint == quant_param2.zeroPoint && quant_param1.min == quant_param2.min &&
quant_param1.max == quant_param2.max && quant_param1.narrowRange == quant_param2.narrowRange &&
quant_param1.numBits == quant_param2.numBits && quant_param1.varCorr == quant_param2.varCorr &&
quant_param1.meanCorr == quant_param2.meanCorr;
}
bool TensorQuantParamsInited(const schema::TensorT &tensor) {
if (tensor.quantParams.empty()) {
return false;
}

for (auto &quant_param : tensor.quantParams) {
if (!quant_param->inited) {
return false;
}
}
return true;
}

STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange, int quant_max,
int quant_min, int num_bits) {
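
A quick standalone check of the new field-by-field comparison (QP is a plain stand-in for schema::QuantParamT; exact floating-point equality appears intended here, assuming params are copied verbatim between tensors rather than recomputed):

#include <cassert>

struct QP {
  bool inited; double scale; int zeroPoint; double min, max;
  bool narrowRange; int numBits; double varCorr, meanCorr;
};

bool Equal(const QP &a, const QP &b) {
  return a.inited == b.inited && a.scale == b.scale && a.zeroPoint == b.zeroPoint &&
         a.min == b.min && a.max == b.max && a.narrowRange == b.narrowRange &&
         a.numBits == b.numBits && a.varCorr == b.varCorr && a.meanCorr == b.meanCorr;
}

int main() {
  QP a{true, 0.5, 0, -1.0, 1.0, false, 8, 1.0, 0.0};
  QP b = a;
  assert(Equal(a, b));
  b.zeroPoint = 128;  // any single differing field breaks equality
  assert(!Equal(a, b));
  return 0;
}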


+ 4 - 0 mindspore/lite/tools/converter/quantizer/quantize_util.h

@@ -175,6 +175,10 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan
}();
}

bool QuantParamEqual(const schema::QuantParamT &quant_param1, const schema::QuantParamT &quant_param2);

bool TensorQuantParamsInited(const schema::TensorT &tensor);

template <typename T>
STATUS DoPerChannelQuant(const tensor::TensorPtr &weight, const QuantType &quant_type,
std::vector<schema::QuantParamT> *quant_params, const int &quant_max, const int &quant_min,


+ 4 - 0 mindspore/lite/tools/optimizer/common/gllo_utils.cc

@@ -28,6 +28,7 @@
#include "tools/common/tensor_util.h"
#include "frontend/operator/ops.h"
#include "backend/optimizer/common/helper.h"
#include "tools/converter/quant_param_holder.h"

using float16 = Eigen::half;

@@ -1416,6 +1417,9 @@ CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &inpu
auto cnode = func_graph->NewCNode(trans_prim, {input_node, perm_node});
MS_ASSERT(cnode != nullptr);
cnode->set_fullname_with_scope(cnode_name);
auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(2, 1);
auto trans_insert_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
trans_insert_prim->AddAttr("quant_params", quant_params_holder);
return cnode;
}



+ 3 - 33 mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc

@@ -21,20 +21,11 @@
#include "tools/converter/quant_param_holder.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "src/common/log_adapter.h"
#include "tools/common/node_util.h"

namespace mindspore {
namespace opt {
namespace {
size_t GetCNodeOutputsSize(std::shared_ptr<AnfNode> anf_node) {
auto cnode = anf_node->cast<CNodePtr>();
if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
return tuple->elements().size();
} else {
return 1;
}
}

int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits) {
auto quant_param_holder = prim->GetAttr("quant_params")->cast<lite::QuantParamHolderPtr>();
std::vector<schema::QuantParamT> quants;
@@ -112,27 +103,6 @@ int ConvertOutputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t
return lite::RET_OK;
}

void CheckQuantParams(const PrimitivePtr &prim) {
auto quant_param_holder = prim->GetAttr("quant_params")->cast<lite::QuantParamHolderPtr>();
auto input_quant_params = quant_param_holder->get_input_quant_params();
bool is_quant = false;
for (size_t i = 0; i < input_quant_params.size(); ++i) {
if (!input_quant_params.at(i).empty() && input_quant_params.at(i).at(0).inited) {
is_quant = true;
break;
}
}
auto output_quant_params = quant_param_holder->get_output_quant_params();
for (size_t i = 0; i < output_quant_params.size(); ++i) {
if (!output_quant_params.at(i).empty() && output_quant_params.at(i).at(0).inited) {
is_quant = true;
}
}
if (!is_quant) {
prim->EraseAttr("quant_params");
}
}

int ConvertQuantParam(const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs) {
auto narrow_range = prim->GetAttr("narrow_range");
bool narrow_range_param = false;
@@ -170,7 +140,6 @@ int ConvertQuantParam(const PrimitivePtr &prim, const std::vector<AnfNodePtr> &i
MS_LOG(ERROR) << "compute output quant param failed.";
return status;
}
CheckQuantParams(prim);
return lite::RET_OK;
}
} // namespace
@@ -236,7 +205,8 @@ int MindirAdjustPass::ComputeQuantParams(std::shared_ptr<AnfNode> anf_node) {
auto inputs = cnode->inputs();
inputs.erase(inputs.begin());

auto quant_param_holder = std::make_shared<lite::QuantParamHolder>(inputs.size(), GetCNodeOutputsSize(anf_node));
auto quant_param_holder =
std::make_shared<lite::QuantParamHolder>(inputs.size(), lite::GetCNodeOutputsSize(anf_node, train_flag_));
primitive->AddAttr("quant_params", quant_param_holder);

if (ConvertQuantParam(primitive, inputs) != lite::RET_OK) {


+ 0 - 4 mindspore/lite/tools/optimizer/graph/transpose_strategy.cc

@@ -25,7 +25,6 @@
#include "ops/fusion/slice_fusion.h"
#include "ops/op_utils.h"
#include "ops/strided_slice.h"
#include "tools/converter/quant_param_holder.h"

namespace mindspore {
namespace opt {
@@ -93,9 +92,6 @@ AnfNodePtr TransposeStrategy::TransposeDependOnShape(const FuncGraphPtr &func_gr
std::string trans_name =
before ? cnode->fullname_with_scope() + "_pre" + std::to_string(index - 1) : cnode->fullname_with_scope() + "_post";
auto trans_insert_node = GenTransposeNode(func_graph, trans_input_node, perm, trans_name);
auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(1, 1);
auto trans_insert_prim = GetValueNode<PrimitivePtr>(trans_insert_node->input(0));
trans_insert_prim->AddAttr("quant_params", quant_params_holder);
return trans_insert_node;
}



+ 2 - 0 mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc

@@ -20,6 +20,7 @@
#include "ops/transpose.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/quant_param_holder.h"

using mindspore::lite::converter::FmkType_CAFFE;
using mindspore::lite::converter::FmkType_MS;
@@ -92,6 +93,7 @@ lite::STATUS WeightFormatTransformPass::TransposeInsertForWeightSharing(const Fu
}
auto perm_node = BuildIntVecParameterNode(graph, perm, weight_node->fullname_with_scope() + "_perm");
auto prim = std::make_shared<ops::Transpose>();
prim->AddAttr("quant_params", std::make_shared<lite::QuantParamHolder>(1, 1));
auto transpose_node = graph->NewCNode(prim, {weight_node, perm_node});
if (!weight_node->has_default()) {
MS_LOG(DEBUG) << "Weight parameter should has default parameter.";

