Browse Source

!7196 tflite support perchannel

Merge pull request !7196 from cjh9368/tflite_support_perchannel
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
93bf293501
2 changed files with 35 additions and 24 deletions
  1. +13
    -4
      mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc
  2. +22
    -20
      mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc

+ 13
- 4
mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc View File

@@ -22,6 +22,7 @@
#include "tools/anf_importer/import_from_meta_graphT.h"
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
#include "tools/common/tensor_util.h"

namespace mindspore::lite {
int AnfImporterFromMetaGraphT::ConverterConstTensor() {
@@ -75,8 +76,12 @@ ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr<s
if (cNode->quantType != schema::QuantType_PostTraining && cNode->quantType != schema::QuantType_WeightQuant) {
primitiveCValue->SetQuantType(cNode->quantType);
for (int index : cNode->inputIndex) {
if (meta_graph_->allTensors[index]->quantParams.size() > 0) {
std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])};
if (!meta_graph_->allTensors[index]->quantParams.empty()) {
std::vector<schema::QuantParamT> quant_params(meta_graph_->allTensors[index]->quantParams.size());
std::transform(
meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(),
quant_params.begin(),
[](std::unique_ptr<schema::QuantParamT> &quant_param) -> schema::QuantParamT { return *quant_param; });
primitiveCValue->AddInputQuantParam(quant_params);
} else {
std::vector<schema::QuantParamT> empty_quant_params;
@@ -84,8 +89,12 @@ ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr<s
}
}
for (int index : cNode->outputIndex) {
if (meta_graph_->allTensors[index]->quantParams.size() > 0) {
std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])};
if (!meta_graph_->allTensors[index]->quantParams.empty()) {
std::vector<schema::QuantParamT> quant_params(meta_graph_->allTensors[index]->quantParams.size());
std::transform(
meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(),
quant_params.begin(),
[](std::unique_ptr<schema::QuantParamT> &quant_param) -> schema::QuantParamT { return *quant_param; });
primitiveCValue->AddOutputQuantParam(quant_params);
}
}


+ 22
- 20
mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc View File

@@ -64,31 +64,33 @@ STATUS TfliteModelParser::CopyConstTensorData(const std::vector<std::unique_ptr<

void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::TensorT> &tflite_tensor,
schema::TensorT *tensor) {
std::unique_ptr<schema::QuantParamT> quant_param = std::make_unique<QuantParamT>();
if (!tflite_tensor->quantization->scale.empty()) {
quant_param->scale = tflite_tensor->quantization->scale[0];
}
tensor->quantParams.clear();
for (size_t i = 0; i < tflite_tensor->quantization->scale.size(); i++) {
std::unique_ptr<schema::QuantParamT> quant_param = std::make_unique<QuantParamT>();
if (!tflite_tensor->quantization->scale.empty()) {
quant_param->scale = tflite_tensor->quantization->scale[i];
}

if (!tflite_tensor->quantization->zero_point.empty()) {
quant_param->zeroPoint = tflite_tensor->quantization->zero_point[0];
}
if (!tflite_tensor->quantization->zero_point.empty()) {
quant_param->zeroPoint = tflite_tensor->quantization->zero_point[i];
}

// change quant param min to 0 to fit ms-lite ops
if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) {
quant_param->zeroPoint = quant_param->zeroPoint - 128;
tensor->dataType = TypeId::kNumberTypeInt8;
}
// change quant param min to 0 to fit ms-lite ops
if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) {
quant_param->zeroPoint = quant_param->zeroPoint - 128;
tensor->dataType = TypeId::kNumberTypeInt8;
}

if (!tflite_tensor->quantization->min.empty()) {
quant_param->min = tflite_tensor->quantization->min[0];
}
if (!tflite_tensor->quantization->min.empty()) {
quant_param->min = tflite_tensor->quantization->min[i];
}

if (!tflite_tensor->quantization->max.empty()) {
quant_param->max = tflite_tensor->quantization->max[0];
if (!tflite_tensor->quantization->max.empty()) {
quant_param->max = tflite_tensor->quantization->max[i];
}
quant_param->inited = true;
tensor->quantParams.emplace_back(std::move(quant_param));
}
quant_param->inited = true;
tensor->quantParams.clear();
tensor->quantParams.emplace_back(std::move(quant_param));
}

STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflite_model,


Loading…
Cancel
Save