Browse Source

add quant aware compile success

tags/v0.7.0-beta
cjh9368 5 years ago
parent
commit
18c6ac9988
32 changed files with 2211 additions and 357 deletions
  1. +1
    -1
      mindspore/lite/src/common/anf_exporter/anf_exporter.cc
  2. +21
    -0
      mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc
  3. +21
    -5
      mindspore/lite/tools/common/graph_util.cc
  4. +1
    -1
      mindspore/lite/tools/common/graph_util.h
  5. +27
    -5
      mindspore/lite/tools/common/tensor_util.cc
  6. +5
    -0
      mindspore/lite/tools/common/tensor_util.h
  7. +1
    -0
      mindspore/lite/tools/converter/CMakeLists.txt
  8. +15
    -12
      mindspore/lite/tools/converter/converter.cc
  9. +9
    -3
      mindspore/lite/tools/converter/converter_flags.cc
  10. +2
    -1
      mindspore/lite/tools/converter/converter_flags.h
  11. +104
    -13
      mindspore/lite/tools/converter/graphdef_transform.cc
  12. +4
    -2
      mindspore/lite/tools/converter/graphdef_transform.h
  13. +1
    -1
      mindspore/lite/tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h
  14. +1
    -0
      mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt
  15. +235
    -0
      mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc
  16. +81
    -0
      mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h
  17. +38
    -43
      mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc
  18. +2
    -1
      mindspore/lite/tools/converter/model_parser.h
  19. +3
    -2
      mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc
  20. +2
    -1
      mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h
  21. +2
    -1
      mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h
  22. +81
    -56
      mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
  23. +11
    -8
      mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h
  24. +2
    -0
      mindspore/lite/tools/converter/quantizer/CMakeLists.txt
  25. +594
    -0
      mindspore/lite/tools/converter/quantizer/aware_quantizer.cc
  26. +65
    -0
      mindspore/lite/tools/converter/quantizer/aware_quantizer.h
  27. +504
    -0
      mindspore/lite/tools/converter/quantizer/calc_quant_param.cc
  28. +69
    -0
      mindspore/lite/tools/converter/quantizer/calc_quant_param.h
  29. +230
    -169
      mindspore/lite/tools/converter/quantizer/quantize_util.cc
  30. +35
    -0
      mindspore/lite/tools/converter/quantizer/quantize_util.h
  31. +8
    -11
      mindspore/lite/tools/converter/quantizer/quantizer.cc
  32. +36
    -21
      mindspore/lite/tools/converter/quantizer/quantizer.h

+ 1
- 1
mindspore/lite/src/common/anf_exporter/anf_exporter.cc View File

@@ -188,7 +188,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {


// add quant param // add quant param
node->quantType = primitiveT_value->GetQuantType(); node->quantType = primitiveT_value->GetQuantType();
if (node->quantType == schema::QuantType_PostTraining) {
if (node->quantType == schema::QuantType_PostTraining || node->quantType == schema::QuantType_AwareTrainning) {
MS_LOG(INFO) << "node: " << node->name << " add QuantParam"; MS_LOG(INFO) << "node: " << node->name << " add QuantParam";
// activation // activation
auto input_quant_params = primitiveT_value->GetInputQuantParams(); auto input_quant_params = primitiveT_value->GetInputQuantParams();


+ 21
- 0
mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc View File

@@ -60,6 +60,17 @@ void AnfImporterFromMetaGraphT::ConverterConstTensor() {
param_value->set_tensor_addr(tensor_data); param_value->set_tensor_addr(tensor_data);
param_value->set_tensor_size(size); param_value->set_tensor_size(size);
} }
if (tensor->quantParams.size() > 0) {
std::unique_ptr<AnfQuantParam> quantParam = std::make_unique<AnfQuantParam>();
quantParam->scale = tensor->quantParams[0]->scale;
quantParam->zeroPoint = tensor->quantParams[0]->zeroPoint;
quantParam->min = tensor->quantParams[0]->min;
quantParam->max = tensor->quantParams[0]->max;
quantParam->narrowRange = tensor->quantParams[0]->narrowRange;
quantParam->numBits = tensor->quantParams[0]->numBits;
quantParam->inited = tensor->quantParams[0]->inited;
param_value->set_quant_param(quantParam);
}
parameter->set_default_param(param_value); parameter->set_default_param(param_value);
AddNode(i, parameter); AddNode(i, parameter);
} }
@@ -77,6 +88,16 @@ int AnfImporterFromMetaGraphT::ConverterCNode() {
flag = true; flag = true;
} }
auto primTValue = std::make_shared<PrimitiveTValue>(cNode->primitive.release()); auto primTValue = std::make_shared<PrimitiveTValue>(cNode->primitive.release());
// add quant parameter
if (cNode->quantType == schema::QuantType_AwareTrainning || cNode->quantType == schema::QuantType_PostTraining) {
primTValue->SetQuantType(cNode->quantType);
for (int index : cNode->inputIndex) {
primTValue->AddInputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0]));
}
for (int index : cNode->outputIndex) {
primTValue->AddOutputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0]));
}
}
cNode->primitive = nullptr; cNode->primitive = nullptr;
auto value_node = NewValueNode(primTValue); auto value_node = NewValueNode(primTValue);




+ 21
- 5
mindspore/lite/tools/common/graph_util.cc View File

@@ -28,7 +28,7 @@
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
OpDefCopyer GetSimpleOpCopyer() { OpDefCopyer GetSimpleOpCopyer() {
return [](std::unique_ptr<CNodeT> &inCNode) -> std::unique_ptr<CNodeT> {
return [](CNodeT *inCNode) -> std::unique_ptr<CNodeT> {
std::unique_ptr<CNodeT> newCNode(new CNodeT); std::unique_ptr<CNodeT> newCNode(new CNodeT);


newCNode->name = inCNode->name; newCNode->name = inCNode->name;
@@ -421,9 +421,13 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
} }
preTensor->refCount = 0; preTensor->refCount = 0;
preTensor->data.clear(); preTensor->data.clear();
if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
preTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
}
graphT->allTensors.emplace_back(std::move(toAddTensor)); graphT->allTensors.emplace_back(std::move(toAddTensor));
size_t toAddTensorIdx = graphT->allTensors.size() - 1; size_t toAddTensorIdx = graphT->allTensors.size() - 1;
auto toAddNode = opDefCopyer(toAddNodeIn);
auto toAddNode = opDefCopyer(toAddNodeIn.get());
if (toAddNode == nullptr) { if (toAddNode == nullptr) {
MS_LOG(ERROR) << "copy toAddNodeIn failed"; MS_LOG(ERROR) << "copy toAddNodeIn failed";
*errorCode = RET_NULL_PTR; *errorCode = RET_NULL_PTR;
@@ -456,9 +460,13 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
// MS_LOG(ERROR)("Copy TensorT failed"); // MS_LOG(ERROR)("Copy TensorT failed");
return graphT->nodes.end(); return graphT->nodes.end();
} }
if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
preTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
}
graphT->allTensors.emplace_back(std::move(toAddTensor)); graphT->allTensors.emplace_back(std::move(toAddTensor));
size_t toAddTensorIdx = graphT->allTensors.size() - 1; size_t toAddTensorIdx = graphT->allTensors.size() - 1;
auto toAddNode = opDefCopyer(toAddNodeIn);
auto toAddNode = opDefCopyer(toAddNodeIn.get());
if (toAddNode == nullptr) { if (toAddNode == nullptr) {
// MS_LOG(ERROR)("copy toAddNodeIn failed"); // MS_LOG(ERROR)("copy toAddNodeIn failed");
*errorCode = RET_NULL_PTR; *errorCode = RET_NULL_PTR;
@@ -505,9 +513,13 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
*errorCode = RET_NULL_PTR; *errorCode = RET_NULL_PTR;
return graphT->nodes.end(); return graphT->nodes.end();
} }
if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
postTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
}
graphT->allTensors.emplace_back(std::move(toAddTensor)); graphT->allTensors.emplace_back(std::move(toAddTensor));
size_t toAddTensorIdx = graphT->allTensors.size() - 1; size_t toAddTensorIdx = graphT->allTensors.size() - 1;
auto toAddNode = opDefCopyer(toAddNodeIn);
auto toAddNode = opDefCopyer(toAddNodeIn.get());
if (toAddNode == nullptr) { if (toAddNode == nullptr) {
// MS_LOG(ERROR)("copy toAddNodeIn failed"); // MS_LOG(ERROR)("copy toAddNodeIn failed");
*errorCode = RET_NULL_PTR; *errorCode = RET_NULL_PTR;
@@ -540,9 +552,13 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
*errorCode = RET_NULL_PTR; *errorCode = RET_NULL_PTR;
return graphT->nodes.end(); return graphT->nodes.end();
} }
if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
postTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
}
graphT->allTensors.emplace_back(std::move(toAddTensor)); graphT->allTensors.emplace_back(std::move(toAddTensor));
size_t toAddTensorIdx = graphT->allTensors.size() - 1; size_t toAddTensorIdx = graphT->allTensors.size() - 1;
auto toAddNode = opDefCopyer(toAddNodeIn);
auto toAddNode = opDefCopyer(toAddNodeIn.get());
if (toAddNode == nullptr) { if (toAddNode == nullptr) {
// MS_LOG(ERROR)("copy toAddNodeIn failed"); // MS_LOG(ERROR)("copy toAddNodeIn failed");
*errorCode = RET_NULL_PTR; *errorCode = RET_NULL_PTR;


+ 1
- 1
mindspore/lite/tools/common/graph_util.h View File

@@ -36,7 +36,7 @@ enum InsertPlace { kBefore, kAfter };


using NodeIter = std::vector<std::unique_ptr<schema::CNodeT>>::iterator; using NodeIter = std::vector<std::unique_ptr<schema::CNodeT>>::iterator;


using OpDefCopyer = std::function<std::unique_ptr<schema::CNodeT>(std::unique_ptr<schema::CNodeT> &)>;
using OpDefCopyer = std::function<std::unique_ptr<schema::CNodeT> (schema::CNodeT *)>;


OpDefCopyer GetSimpleOpCopyer(); OpDefCopyer GetSimpleOpCopyer();




+ 27
- 5
mindspore/lite/tools/common/tensor_util.cc View File

@@ -19,8 +19,29 @@
#include "tools/common/tensor_util.h" #include "tools/common/tensor_util.h"
#include "tools/common/graph_util.h" #include "tools/common/graph_util.h"


namespace mindspore {
namespace lite {
namespace mindspore::lite {
std::unique_ptr<QuantParamT> GetTensorQuantParam(const std::unique_ptr<TensorT> &tensor) {
MS_ASSERT(tensor != nullptr);
auto &quantParams = tensor->quantParams;
if (!quantParams.empty()) {
return std::move(CopyQuantParamT(quantParams.front()));
} else {
return nullptr;
}
}
std::unique_ptr<schema::QuantParamT> CopyQuantParamT(const std::unique_ptr<schema::QuantParamT> &srcQuantParam) {
MS_ASSERT(srcQuantParam != nullptr);
std::unique_ptr<schema::QuantParamT> dstQuantParam = std::make_unique<schema::QuantParamT>();
dstQuantParam->inited = srcQuantParam->inited;
dstQuantParam->scale = srcQuantParam->scale;
dstQuantParam->zeroPoint = srcQuantParam->zeroPoint;
dstQuantParam->min = srcQuantParam->min;
dstQuantParam->max = srcQuantParam->max;
dstQuantParam->narrowRange = srcQuantParam->narrowRange;
dstQuantParam->numBits = srcQuantParam->numBits;
return std::move(dstQuantParam);
}

std::unique_ptr<QuantParamT> CopyQuantParamArrayT(const std::unique_ptr<QuantParamT> &srcQuantParamArray) { std::unique_ptr<QuantParamT> CopyQuantParamArrayT(const std::unique_ptr<QuantParamT> &srcQuantParamArray) {
MS_ASSERT(srcQuantParamArray != nullptr); MS_ASSERT(srcQuantParamArray != nullptr);
auto dstQuantParamArrayT = std::unique_ptr<QuantParamT>(new (std::nothrow) QuantParamT()); auto dstQuantParamArrayT = std::unique_ptr<QuantParamT>(new (std::nothrow) QuantParamT());
@@ -164,6 +185,9 @@ std::unique_ptr<TensorT> CopyTensorDefT(const std::unique_ptr<TensorT> &oldTenso
newTensor->refCount = oldTensor->refCount; newTensor->refCount = oldTensor->refCount;
newTensor->nodeType = oldTensor->nodeType; newTensor->nodeType = oldTensor->nodeType;
newTensor->data = oldTensor->data; newTensor->data = oldTensor->data;
if (!oldTensor->quantParams.empty()) {
newTensor->quantParams.emplace_back(std::move(GetTensorQuantParam(oldTensor)));
}
return std::move(newTensor); return std::move(newTensor);
} }


@@ -186,6 +210,4 @@ size_t GetShapeSize(const std::vector<int32_t> &shape) {
} }
return shapeSize; return shapeSize;
} }
} // namespace lite
} // namespace mindspore

} // namespace mindspore::lite

+ 5
- 0
mindspore/lite/tools/common/tensor_util.h View File

@@ -38,6 +38,9 @@ using schema::FusedBatchNormT;
using schema::Format_NCHW; using schema::Format_NCHW;
using schema::Format_NHWC; using schema::Format_NHWC;
using STATUS = int; using STATUS = int;

std::unique_ptr<QuantParamT> GetTensorQuantParam(const std::unique_ptr<TensorT> &tensor);

size_t GetElementSize(const TensorT &tensor); size_t GetElementSize(const TensorT &tensor);


size_t GetElementSize(const TypeId &dataType); size_t GetElementSize(const TypeId &dataType);
@@ -50,6 +53,8 @@ std::unique_ptr<TensorT> CopyTensorDefT(const std::unique_ptr<TensorT> &);


size_t GetRefCount(schema::MetaGraphT *graphT, uint32_t tensorIdx); size_t GetRefCount(schema::MetaGraphT *graphT, uint32_t tensorIdx);


std::unique_ptr<schema::QuantParamT> CopyQuantParamT(const std::unique_ptr<schema::QuantParamT> &srcQuantParam);

std::unique_ptr<schema::QuantParamT> \ std::unique_ptr<schema::QuantParamT> \
CopyQuantParamArrayT(const std::unique_ptr<schema::QuantParamT> &srcQuantParamArray); CopyQuantParamArrayT(const std::unique_ptr<schema::QuantParamT> &srcQuantParamArray);




+ 1
- 0
mindspore/lite/tools/converter/CMakeLists.txt View File

@@ -101,6 +101,7 @@ target_link_libraries(converter_lite PRIVATE
node_mid node_mid
graph_pass_mid graph_pass_mid
fusion_mid fusion_mid
quantizer_mid
protobuf protobuf
quantizer_mid quantizer_mid
pthread pthread


+ 15
- 12
mindspore/lite/tools/converter/converter.cc View File

@@ -77,7 +77,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
MS_ASSERT(nullptr != modelParser); MS_ASSERT(nullptr != modelParser);
const std::string modelFile = flag->modelFile; const std::string modelFile = flag->modelFile;
const std::string weightFile = flag->weightFile; const std::string weightFile = flag->weightFile;
auto meta_graph = modelParser->Parse(modelFile, weightFile);
auto meta_graph = modelParser->Parse(modelFile, weightFile, flag->quantType);
if (meta_graph == nullptr) { if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Parse to metaGraph return nullptr"; MS_LOG(ERROR) << "Parse to metaGraph return nullptr";
return nullptr; return nullptr;
@@ -118,6 +118,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {


// transform // transform
transform->SetGraphDef(meta_graph); transform->SetGraphDef(meta_graph);
transform->CreateQuantizer(flag);
auto status = transform->Transform(*flag); auto status = transform->Transform(*flag);
if (status != 0) { if (status != 0) {
MS_LOG(ERROR) << "FBTransform model failed " << status; MS_LOG(ERROR) << "FBTransform model failed " << status;
@@ -125,6 +126,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
} }
return meta_graph; return meta_graph;
} }

void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *flags) { void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *flags) {
auto type = flags->quantType; auto type = flags->quantType;
switch (type) { switch (type) {
@@ -132,17 +134,18 @@ void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *
// mQuantizer.reset(new AwareQuantizer(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean)); // mQuantizer.reset(new AwareQuantizer(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean));
break; break;
} }
case mindspore::schema::QuantType_WeightQuant: {
MS_LOG(INFO) << "create WeightQuantizer!";
mQuantizer.reset(
new quant::WeightQuantizer(funcGraph, flags->quantSize, flags->convWeightQuantChannelThreshold, flags->bitNum));
break;
}
case mindspore::schema::QuantType_PostTraining: {
MS_LOG(INFO) << "create PostTrainningQuantizer!";
mQuantizer.reset(new quant::PostTrainingQuantizer(funcGraph, flags->configFile, 8));
break;
}
// case mindspore::schema::QuantType_WeightQuant: {
// MS_LOG(INFO) << "create WeightQuantizer!";
// mQuantizer.reset(
// new quant::WeightQuantizer(funcGraph, flags->quantSize, flags->convWeightQuantChannelThreshold,
// flags->bitNum));
// break;
// }
// case mindspore::schema::QuantType_PostTraining: {
// MS_LOG(INFO) << "create PostTrainningQuantizer!";
// mQuantizer.reset(new quant::PostTrainingQuantizer(funcGraph, flags->configFile, 8));
// break;
// }
case mindspore::schema::QuantType_QUANT_NONE: case mindspore::schema::QuantType_QUANT_NONE:
MS_LOG(INFO) << "Not do quantization for model!"; MS_LOG(INFO) << "Not do quantization for model!";
break; break;


+ 9
- 3
mindspore/lite/tools/converter/converter_flags.cc View File

@@ -14,8 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */


#include <string>
#include "tools/converter/converter_flags.h" #include "tools/converter/converter_flags.h"
#include <regex>
#include <string>
#include "ir/dtype/type_id.h"



namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@@ -70,9 +74,11 @@ int Flags::Init(int argc, const char **argv) {
return 1; return 1;
} }
if (this->inputInferenceTypeIn == "FLOAT") { if (this->inputInferenceTypeIn == "FLOAT") {
this->inputInferenceType = 0;
this->inputInferenceType = TypeId::kNumberTypeFloat;
} else if (this->inputInferenceTypeIn == "UINT8") { } else if (this->inputInferenceTypeIn == "UINT8") {
this->inputInferenceType = 1;
this->inputInferenceType = TypeId::kNumberTypeUInt8;
} else if (this->inputInferenceTypeIn == "INT8") {
this->inputInferenceType = TypeId::kNumberTypeInt8;
} else { } else {
std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s", this->inputInferenceTypeIn.c_str(); std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s", this->inputInferenceTypeIn.c_str();
return 1; return 1;


+ 2
- 1
mindspore/lite/tools/converter/converter_flags.h View File

@@ -19,6 +19,7 @@


#include <string> #include <string>
#include "tools/common/flag_parser.h" #include "tools/common/flag_parser.h"
#include "ir/dtype/type_id.h"
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"


namespace mindspore { namespace mindspore {
@@ -66,7 +67,7 @@ class Flags : public virtual mindspore::lite::FlagParser {
// used for parse aware trainning // used for parse aware trainning
std::string inputInferenceTypeIn; std::string inputInferenceTypeIn;
// mindspore::predict::DataType inputInferenceType = DataType_DT_FLOAT; // mindspore::predict::DataType inputInferenceType = DataType_DT_FLOAT;
int inputInferenceType = 0;
TypeId inputInferenceType = TypeId::kNumberTypeFloat;
std::string stdDev; std::string stdDev;
std::string mean; std::string mean;
// used for post-trainning-weight // used for post-trainning-weight


+ 104
- 13
mindspore/lite/tools/converter/graphdef_transform.cc View File

@@ -16,11 +16,13 @@


#include "tools/converter/graphdef_transform.h" #include "tools/converter/graphdef_transform.h"
#include <iostream> #include <iostream>
#include <memory>
#include <string> #include <string>
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "src/common/op_utils.h" #include "src/common/op_utils.h"
#include "tools/converter/converter_flags.h" #include "tools/converter/converter_flags.h"
#include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h"
#include "tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h" #include "tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.h" #include "tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.h" #include "tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.h"
@@ -28,7 +30,7 @@
#include "tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h" #include "tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h"
// #include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h" // #include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h" #include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h"
// #include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
// #include "tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h" // #include "tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h"
// //
// #include "tools/converter/legacy_optimizer/const_fold/add_const_fold_pass.h" // #include "tools/converter/legacy_optimizer/const_fold/add_const_fold_pass.h"
@@ -52,18 +54,45 @@
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h" #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h" #include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h" #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
#include "tools/converter/quantizer/aware_quantizer.h"
#include "tools/converter/converter.h" #include "tools/converter/converter.h"


using std::string; using std::string;
namespace mindspore {
namespace lite {
namespace mindspore::lite {
GraphDefTransform::GraphDefTransform() = default; GraphDefTransform::GraphDefTransform() = default;


GraphDefTransform::~GraphDefTransform() = default; GraphDefTransform::~GraphDefTransform() = default;


void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _dstDef; } void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _dstDef; }


void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) {
auto type = flags->quantType;
switch (type) {
case QuantType::QuantType_AwareTrainning: {
MS_LOG(INFO) << "create AwareTrainningQuantizer!";
fbQuantizer =
std::make_unique<quant::AwareQuantizer>(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean);
break;
}
// case QuantType::QuantType_WeightQuant: {
// MS_LOGI("create WeightQuantizer!");
// mQuantizer.reset(new WeightQuantizer(graphDefT, flags->quantSize));
// break;
// }
// case QuantType_PostTraining: {
// MS_LOGI("create PostTrainningQuantizer!");
// mQuantizer.reset(new PostTrainingQuantizer(graphDefT, flags->configFile));
// break;
// }
// case QuantType::QuantType_QUANT_NONE:
// MS_LOGD("Not do quantization for model!");
// break;
default:
// MS_LOGI("will support quantizer type %s in the future!", flags->quantTypeIn.c_str());
break;
}
}

int GraphDefTransform::Transform(const converter::Flags &ctx) { int GraphDefTransform::Transform(const converter::Flags &ctx) {
STATUS status; STATUS status;
// // constant folding // // constant folding
@@ -133,6 +162,53 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
} }
} }



{
Optimizer unusedOpRemoveOptimizer;
unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass());
unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass());
status = unusedOpRemoveOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed";
return status;
}
}
// topological sorting
{
Optimizer topologicalOptimizer;
topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
status = topologicalOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
}

// generate and infer quant parameters
{
if (mQuantizer != nullptr) {
Optimizer topologicalOptimizer;
topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
status = topologicalOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
if (!(this->graphDefT->fmkType == converter::FmkType_TF &&
this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTrainning)) {
status = mQuantizer->GenerateQuantParam();
if (status != RET_OK) {
MS_LOG(ERROR) << "GenerateQuantParam failed";
return status;
}
status = mQuantizer->DetermineNodeQuantType();
if (status != RET_OK) {
MS_LOG(ERROR) << "DetermineNodeQuant failed";
}
}
}
}

// format transform // format transform
if (ctx.formatTrans) { if (ctx.formatTrans) {
Optimizer formatTransOptimizer; Optimizer formatTransOptimizer;
@@ -156,13 +232,30 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
} }
} }


{
Optimizer unusedOpRemoveOptimizer;
unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass());
unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass());
status = unusedOpRemoveOptimizer.Run(graphDefT);
// do quantization
if (fbQuantizer != nullptr) {
status = fbQuantizer->DoQuantize();
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantize failed!";
return status;
}
}

// insert quantNode and deQuantNode
if (ctx.quantType == QuantType_AwareTrainning) {
Optimizer quantNodeOptimizer;
auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
if (dTypeTransPass == nullptr) {
MS_LOG(ERROR) << "new dTypeTransPass failed";
return RET_ERROR;
}
dTypeTransPass->SetInputDataDType(ctx.inputInferenceType);
quantNodeOptimizer.AddPass(dTypeTransPass);
quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
status = quantNodeOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) { if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed";
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed";
return status; return status;
} }
} }
@@ -178,6 +271,4 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
} }
return RET_OK; return RET_OK;
} }
} // namespace lite
} // namespace mindspore

} // namespace mindspore::lite

+ 4
- 2
mindspore/lite/tools/converter/graphdef_transform.h View File

@@ -17,8 +17,9 @@
#ifndef MS_GRAPHDEF_TRANSFORM_H #ifndef MS_GRAPHDEF_TRANSFORM_H
#define MS_GRAPHDEF_TRANSFORM_H #define MS_GRAPHDEF_TRANSFORM_H


#include <memory>
#include "tools/converter/optimizer.h" #include "tools/converter/optimizer.h"
// #include "quantizer/quantizer.h"
#include "tools/converter/quantizer/quantizer.h"
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"
#include "tools/common/storage.h" #include "tools/common/storage.h"
#include "tools/converter/converter_flags.h" #include "tools/converter/converter_flags.h"
@@ -42,7 +43,8 @@ class GraphDefTransform {
schema::MetaGraphT *graphDefT = nullptr; schema::MetaGraphT *graphDefT = nullptr;
Optimizer *optimizer = nullptr; Optimizer *optimizer = nullptr;


// std::unique_ptr<Quantizer> mQuantizer;
std::unique_ptr<quant::Quantizer> mQuantizer;
std::unique_ptr<quant::FbQuantizer> fbQuantizer;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


+ 1
- 1
mindspore/lite/tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h View File

@@ -53,7 +53,7 @@ class MatMulBiasAddFusionPass : public FusionPass {
bool transB = false; bool transB = false;
size_t id = 0; size_t id = 0;


OpDefCopyer TransposeOpCopyer = [](const std::unique_ptr<CNodeT> &inOpDef) -> std::unique_ptr<CNodeT> {
OpDefCopyer TransposeOpCopyer = [](CNodeT *inOpDef) -> std::unique_ptr<CNodeT> {
std::unique_ptr<CNodeT> newOpDef(new (std::nothrow) CNodeT); std::unique_ptr<CNodeT> newOpDef(new (std::nothrow) CNodeT);
if (newOpDef == nullptr) { if (newOpDef == nullptr) {
MS_LOG(ERROR) << "new OpDefT failed"; MS_LOG(ERROR) << "new OpDefT failed";


+ 1
- 0
mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt View File

@@ -1,5 +1,6 @@
add_library(graph_pass_mid OBJECT add_library(graph_pass_mid OBJECT
${CMAKE_CURRENT_SOURCE_DIR}/format_trans_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/dtype_trans_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/isolated_node_remove_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/isolated_node_remove_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/model_input_format_preprocess_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/model_input_format_preprocess_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/topological_sort_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/topological_sort_pass.cc


+ 235
- 0
mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc View File

@@ -0,0 +1,235 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h"
#include <string>
#include "tools/common/converter_op_utils.h"
#include "tools/common/node_util.h"
#include "src/common/common.h"
#include "src/common/utils.h"

namespace mindspore {
namespace lite {
#define kMinInputNum 1
#define kOutputNum 1

STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);

auto status = DoModelInputDTypeTrans(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoModelInputDTypeTrans error: " << status;
return status;
}

status = DoModelOutputDTypeTrans(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status;
return status;
}

status = DoNodeInoutDTypeTrans(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status;
return status;
}
return RET_OK;
}

// Forces every graph-input tensor to uint8 in the stored model, then — unless
// the caller will feed int8 directly — inserts a cast node (fp32->int8 or
// uint8->int8) between each 4-D graph input and every node consuming it, so
// downstream ops see int8 data.
STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
  // Step 1: mark all graph-input tensors as uint8 (the external/serialized dtype).
  auto &graphInIdxes = graph->inputIndex;
  for (auto graphInIdx : graphInIdxes) {
    MS_ASSERT(graph->allTensors.size() > graphInIdx);
    auto &graphInTensor = graph->allTensors.at(graphInIdx);
    graphInTensor->dataType = TypeId::kNumberTypeUInt8;
  }

  // Caller feeds int8 directly: no input-side cast is needed.
  if (this->inputDataDType == TypeId::kNumberTypeInt8) {
    return RET_OK;
  }
  // Only fp32, uint8 and int8 (handled above) are supported input dtypes.
  if (this->inputDataDType != TypeId::kNumberTypeFloat && this->inputDataDType != TypeId::kNumberTypeUInt8) {
    MS_LOG(ERROR) << "Invalid inputDataType: " << this->inputDataDType;
    return RET_ERROR;
  }
  // Step 2: insert a cast-to-int8 node in front of every consumer of each graph input.
  for (auto graphInIdx : graphInIdxes) {
    MS_ASSERT(graphInIdx < graph->allTensors.size());
    auto &tensor = graph->allTensors.at(graphInIdx);
    // NOTE(review): only 4-D (NHWC) inputs are cast; other ranks are skipped —
    // presumably non-image inputs are left untouched on purpose. TODO confirm.
    if (tensor->dims.size() != kNHWCDimNumber) {
      continue;
    }

    for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
      auto &node = *iter;
      auto nodeName = node->name;
      for (size_t inputIndexIdx = 0; inputIndexIdx < node->inputIndex.size(); inputIndexIdx++) {
        if (node->inputIndex.at(inputIndexIdx) == graphInIdx) {
          STATUS status = RET_OK;

          // insert dtype cast node between input tensor and input node;
          // InsertDTypeTransNode returns the (possibly shifted) iterator so
          // the surrounding loop stays valid after the in-place insertion.
          if (inputDataDType == TypeId::kNumberTypeFloat) {
            iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, kFP32ToInt8, &status);
          } else {
            iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, kUInt8ToInt8, &status);
          }

          if (status != RET_OK) {
            MS_LOG(ERROR) << "InsertDTypeTransNode before " << nodeName.c_str() << " failed";
            return status;
          }
        }
      }
    }
  }
  return RET_OK;
}

// Inserts a dtype-cast node after every node producing a graph output, so the
// graph emits data in the caller-facing dtype (fp32 or uint8) instead of the
// internal int8.
STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
  // int8 output requested: the graph already produces int8, nothing to insert.
  if (inputDataDType == TypeId::kNumberTypeInt8) {
    return RET_OK;
  }
  // Fix: the original assert allowed only kNumberTypeFloat, contradicting both
  // DoModelInputDTypeTrans (which accepts uint8 too) and the else-branch below
  // that explicitly inserts an int8->uint8 cast. Accept both supported dtypes.
  MS_ASSERT(inputDataDType == TypeId::kNumberTypeFloat || inputDataDType == TypeId::kNumberTypeUInt8);
  auto &graphOutIdxes = graph->outputIndex;
  for (auto graphOutIdx : graphOutIdxes) {
    for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
      auto &node = *iter;
      auto nodeName = node->name;
      MS_ASSERT(node != nullptr);
      for (size_t outputIndexIdx = 0; outputIndexIdx < node->outputIndex.size(); outputIndexIdx++) {
        if (node->outputIndex.at(outputIndexIdx) == graphOutIdx) {
          // Insert the cast after the producing node; the returned iterator
          // accounts for the in-place insertion into graph->nodes.
          STATUS status = RET_OK;
          if (inputDataDType == TypeId::kNumberTypeFloat) {
            iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToFP32, &status);
          } else {
            iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToUInt8, &status);
          }
          if (status != RET_OK) {
            MS_LOG(ERROR) << "InsertDTypeTransNode after " << nodeName.c_str() << " failed";
            return status;
          }
          // Each node produces a given graph output at most once.
          break;
        }
      }
    }
  }
  return RET_OK;
}

// Wraps every node that cannot execute natively in int8 with int8->fp32 casts
// on its inputs and fp32->int8 casts on its outputs, then clears the node's
// quant type since it now runs in float.
STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
  // insert transNode before and after existNode
  for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
    // Aware-trained ops on the uint8-capable list run natively in int8: skip.
    if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTrainning) {
      continue;
    }
    auto &node = *iter;
    // Never wrap a cast with more casts.
    if (GetCNodeTType(**iter) == PrimitiveType_QuantDTypeCast) {
      continue;
    }
    bool needInsertPost = true;
    // Shape outputs index data, not quantized activations: no output cast.
    if (GetCNodeTType(**iter) == PrimitiveType_Shape) {
      needInsertPost = false;
    }
    auto nodeName = node->name;
    if (node->inputIndex.size() < kMinInputNum) {
      MS_LOG(ERROR) << "Op " << nodeName.c_str() << " should have " << kMinInputNum << " input tensor at least";
      return RET_ERROR;
    }
    // Fix: status was declared uninitialized; it is read after each insert
    // call, so a callee that stores nothing on success would leave it
    // indeterminate (UB). Initialize to RET_OK.
    STATUS status = RET_OK;
    // insert pre: cast each non-constant input from int8 to fp32.
    for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) {
      MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i));
      auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i));
      auto &graphInIdxes = graph->inputIndex;
      // Constant (value-node) tensors that are not graph inputs need no cast.
      if (preTensor->nodeType == NodeType_ValueNode && !IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) {
        continue;
      }
      iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed";
        return RET_ERROR;
      }
    }

    // insert post: cast each output back from fp32 to int8.
    if (needInsertPost) {
      for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) {
        iter = InsertDTypeTransNode(graph, iter, kAfter, i, kFP32ToInt8, &status);
        if (status != RET_OK) {
          MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed";
          return RET_ERROR;
        }
      }
    }
    // The wrapped node now runs in float; drop its quant marking.
    (*iter)->quantType = QuantType_QUANT_NONE;
  }

  return RET_OK;
}

// Builds a QuantDTypeCast node for the requested direction (named with a
// unique id) and splices it into the graph before/after *existNodeIter via
// InsertNode. On failure *errorCode is set and graph->nodes.end() is returned;
// otherwise the returned iterator accounts for the insertion.
NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place,
                                              size_t inoutIdx, DTypeTransNodeType nodeType, STATUS *errorCode) {
  MS_ASSERT((*existNodeIter) != nullptr);
  auto existNodeName = (*existNodeIter)->name;
  std::string tileName;
  if (place == kBefore) {
    tileName = existNodeName + "_pre";
  } else {
    tileName = existNodeName + "_post";
  }
  auto transNode = std::unique_ptr<CNodeT>(new (std::nothrow) CNodeT);
  if (transNode == nullptr) {
    MS_LOG(ERROR) << "new TransNode failed";
    *errorCode = RET_ERROR;
    return graph->nodes.end();
  }
  auto quantDTypeCastParam = new (std::nothrow) QuantDTypeCastT;
  if (quantDTypeCastParam == nullptr) {
    MS_LOG(ERROR) << "new quantDTypeCastParam failed";
    *errorCode = RET_ERROR;
    return graph->nodes.end();
  }
  transNode->primitive = std::make_unique<schema::PrimitiveT>();
  transNode->primitive->value.type = PrimitiveType_QuantDTypeCast;
  transNode->quantType = QuantType_AwareTrainning;
  // Configure src/dst dtypes per direction; `id` keeps generated names unique.
  if (nodeType == kInt8ToFP32) {
    quantDTypeCastParam->srcT = TypeId::kNumberTypeInt8;
    quantDTypeCastParam->dstT = TypeId::kNumberTypeFloat32;
    transNode->name = "int8toft32_" + tileName + std::to_string(id++);
  } else if (nodeType == kFP32ToInt8) {
    quantDTypeCastParam->srcT = TypeId::kNumberTypeFloat32;
    quantDTypeCastParam->dstT = TypeId::kNumberTypeInt8;
    transNode->name = "ft32toint8_" + tileName + std::to_string(id++);
  } else if (nodeType == kUInt8ToInt8) {
    quantDTypeCastParam->srcT = TypeId::kNumberTypeUInt8;
    quantDTypeCastParam->dstT = TypeId::kNumberTypeInt8;
    transNode->name = "uint8toint8_" + tileName + std::to_string(id++);
  } else if (nodeType == kInt8ToUInt8) {
    quantDTypeCastParam->srcT = TypeId::kNumberTypeInt8;
    quantDTypeCastParam->dstT = TypeId::kNumberTypeUInt8;
    transNode->name = "int8touint8_" + tileName + std::to_string(id++);
  }
  // Fix: the param was assigned to value.value twice (before and after being
  // configured); attach it exactly once, after configuration. The union takes
  // ownership of the raw pointer.
  transNode->primitive->value.value = quantDTypeCastParam;
  return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode, castOpCopyer);
}

// Records the dtype the caller will feed to / expects from the model
// (fp32, uint8 or int8); consulted by the Do*DTypeTrans stages.
void DTypeTransPass::SetInputDataDType(TypeId dataType) {
  inputDataDType = dataType;
}
} // namespace lite
} // namespace mindspore

+ 81
- 0
mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h View File

@@ -0,0 +1,81 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_PREDICT_DTYPE_TRANS_PASS_H
#define MINDSPORE_PREDICT_DTYPE_TRANS_PASS_H

#include <memory>
#include <utility>
#include "tools/converter/optimizer.h"
#include "tools/common/graph_util.h"
#include "tools/converter/converter_flags.h"
#include "tools/common/tensor_util.h"

namespace mindspore {
namespace lite {
// Direction of the QuantDTypeCast node a DTypeTransPass insertion creates.
enum DTypeTransNodeType { kInt8ToFP32, kFP32ToInt8, kUInt8ToInt8, kInt8ToUInt8 };

// Graph pass that makes an aware-quantized model executable: it inserts
// QuantDTypeCast nodes around graph inputs/outputs and around ops that cannot
// run natively in int8. The caller-facing dtype is set via SetInputDataDType.
class DTypeTransPass : public GraphPass {
 public:
  DTypeTransPass() : id(0) {}

  ~DTypeTransPass() override = default;

  STATUS Run(schema::MetaGraphT *graph) override;

  // dataType: the dtype the user feeds/reads (fp32, uint8 or int8).
  void SetInputDataDType(TypeId dataType);

 private:
  // Casts graph inputs into int8 (fp32->int8 or uint8->int8).
  STATUS DoModelInputDTypeTrans(schema::MetaGraphT *graph);

  // Casts graph outputs back to the caller-facing dtype.
  STATUS DoModelOutputDTypeTrans(schema::MetaGraphT *graph);

  // Wraps non-int8-capable nodes with int8<->fp32 casts.
  STATUS DoNodeInoutDTypeTrans(schema::MetaGraphT *graph);

  // Builds one cast node of the given direction and splices it in at `place`.
  NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx,
                                DTypeTransNodeType nodeType, STATUS *errorCode);

 private:
  size_t id;  // monotonically increasing suffix for unique cast-node names
  TypeId inputDataDType = TypeId::kNumberTypeFloat;

  // Deep-copies a QuantDTypeCast CNode for InsertNode (needed when one output
  // feeds several consumers). Returns nullptr on allocation failure or when
  // the source node does not actually hold a QuantDTypeCast primitive.
  OpDefCopyer castOpCopyer = [](schema::CNodeT *inCNode) -> std::unique_ptr<schema::CNodeT> {
    std::unique_ptr<schema::CNodeT> newCNode(new (std::nothrow) schema::CNodeT);
    if (newCNode == nullptr) {
      MS_LOG(ERROR) << "new CNodeT failed";
      return nullptr;
    }
    newCNode->name = inCNode->name;
    newCNode->quantType = inCNode->quantType;
    newCNode->primitive = std::make_unique<schema::PrimitiveT>();
    newCNode->primitive->value.type = inCNode->primitive->value.type;

    auto oldQuantDTypeCastParam = inCNode->primitive->value.AsQuantDTypeCast();
    // Fix: AsQuantDTypeCast() returns nullptr when the union holds a different
    // primitive; the original dereferenced it unconditionally and would crash.
    if (oldQuantDTypeCastParam == nullptr) {
      MS_LOG(ERROR) << "inCNode is not a QuantDTypeCast node";
      return nullptr;
    }
    auto castParam = new (std::nothrow) QuantDTypeCastT;
    if (castParam == nullptr) {
      MS_LOG(ERROR) << "new QuantDTypeCast failed";
      return nullptr;
    }
    castParam->srcT = oldQuantDTypeCastParam->srcT;
    castParam->dstT = oldQuantDTypeCastParam->dstT;
    newCNode->primitive->value.value = castParam;
    // Plain return (not std::move) keeps NRVO available.
    return newCNode;
  };
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_PREDICT_DTYPE_TRANS_PASS_H

+ 38
- 43
mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc View File

@@ -209,6 +209,9 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
return 0; return 0;
} }


// inference needed filterFormat:
// conv deconv depth dedepth
// uint8 KHWC KHWC KHWC KHWC
int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
MS_ASSERT(graphNode != nullptr); MS_ASSERT(graphNode != nullptr);
auto &subGraph = graphNode->subGraph; auto &subGraph = graphNode->subGraph;
@@ -227,7 +230,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
auto &weightTensor = subGraph->allTensors[weightIndex]; auto &weightTensor = subGraph->allTensors[weightIndex];
MS_ASSERT(weightTensor->dataType == kNumberTypeInt8); // DataType_DT_FLOAT MS_ASSERT(weightTensor->dataType == kNumberTypeInt8); // DataType_DT_FLOAT
STATUS status = RET_OK; STATUS status = RET_OK;
if (opType == schema::PrimitiveType_Conv2D) { // weight should be HWCK
if (opType == schema::PrimitiveType_Conv2D) { // weight should be KHWC
if (weightTensor->format == schema::Format_KCHW) { // from caffe if (weightTensor->format == schema::Format_KCHW) { // from caffe
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format
@@ -236,58 +239,51 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
} else { } else {
MS_LOG(DEBUG) << "--weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format MS_LOG(DEBUG) << "--weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format
<< weightTensor->dataType; << weightTensor->dataType;
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK);
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
} }
} else if (weightTensor->format == schema::Format_KHWC) { // from onnx
return RET_OK;
// if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
// status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK);
// } else {
// status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK);
// }
} else if (weightTensor->format == schema::Format_HWCK) { // from tf
return 0;
} else {
} else if (weightTensor->format != schema::Format_KHWC) {
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
return -1; return -1;
} }
if (status == 0) { if (status == 0) {
node->primitive->value.AsConv2D()->format = schema::Format_NHWC; node->primitive->value.AsConv2D()->format = schema::Format_NHWC;
weightTensor->format = schema::Format_HWCK;
weightTensor->format = schema::Format_KHWC;
} else { } else {
MS_LOG(WARNING) << "TransFilter %sToHWCK failed, node : "
<< (weightTensor->format == schema::Format_KCHW ? "KCHW" : "KHWC"),
node->name.c_str();
MS_LOG(WARNING) << "TransFilter %sToKHWC failed, node : "
<< (weightTensor->format == schema::Format_KHWC ? "KHWC" : "KCHW") << node->name.c_str();
// todo(00445839): consider varible weight condition // todo(00445839): consider varible weight condition
} }
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be HWCK
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be KHWC
if (weightTensor->format == schema::Format_CKHW) { // from caffe if (weightTensor->format == schema::Format_CKHW) { // from caffe
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex, weightTensor->format,
weightTensor->dataType;
status = TransFilterFormat<uint8_t>(weightTensor.get(), kCKHW2HWCK);
MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
<< "datatype: " << weightTensor->dataType;
status = TransFilterFormat<int8_t>(weightTensor.get(), kCKHW2KHWC);
} else if (weightTensor->dataType == kNumberTypeUInt8) {
MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
<< "datatype: " << weightTensor->dataType;
status = TransFilterFormat<uint8_t>(weightTensor.get(), kCKHW2KHWC);
} else { } else {
MS_LOG(DEBUG) << "--weight tensor index: %d, format: %d, datatype: " << weightIndex, weightTensor->format,
weightTensor->dataType;
status = TransFilterFormat<float>(weightTensor.get(), kCKHW2HWCK);
MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
<< "datatype: " << weightTensor->dataType;
status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC);
} }


} else if (weightTensor->format == schema::Format_HWCK) { // from tf
return 0;
} else if (weightTensor->format == schema::Format_CHWK) { // from onnx } else if (weightTensor->format == schema::Format_CHWK) { // from onnx
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
if (weightTensor->dataType == kNumberTypeInt8) {
MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
<< "datatype: " << weightTensor->dataType;
status = TransFilterFormat<int8_t>(weightTensor.get(), kCHWK2KHWC); status = TransFilterFormat<int8_t>(weightTensor.get(), kCHWK2KHWC);
MS_LOG(DEBUG) << node->name << " weight trans format: CHWK->KHWC";
} else if (weightTensor->dataType == kNumberTypeUInt8) {
MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
<< "datatype: " << weightTensor->dataType;
status = TransFilterFormat<uint8_t>(weightTensor.get(), kCHWK2KHWC);
} else { } else {
MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
<< "datatype: " << weightTensor->dataType;
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC); status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
} }
} else if (weightTensor->format == schema::Format_KCHW) {
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
status = TransFilterFormat<uint8_t>(weightTensor.get(), kKCHW2HWCK);
} else {
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK);
}
} else {
} else if (weightTensor->format != schema::Format_KHWC) {
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
return -1; return -1;
} }
@@ -295,14 +291,13 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC; node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC;
weightTensor->format = schema::Format_KHWC; weightTensor->format = schema::Format_KHWC;
} else { } else {
MS_LOG(WARNING) << "TransFilter %ToHWCK failed, node : "
<< (weightTensor->format == schema::Format_CHWK ? "CHWK" : "CKHW"),
node->name.c_str();
MS_LOG(WARNING) << "TransFilter" << (weightTensor->format == schema::Format_KHWC ? "KHWC" : "CKHW")
<< "To KHWC failed, node : " << node->name.c_str();
// todo(00445839): consider varible weight condition // todo(00445839): consider varible weight condition
} }
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be HWCK
node->primitive->value.AsDeConv2D()->format = schema::Format_NCHW;
weightTensor->format = schema::Format_CKHW;
} else { // weight should be HWCK
node->primitive->value.AsDeConv2D()->format = schema::Format_NHWC;
weightTensor->format = schema::Format_KHWC;
} }
return 0; return 0;
} }
@@ -354,7 +349,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
if (graphNode->subGraph->fmkType == converter::FmkType_MS) { if (graphNode->subGraph->fmkType == converter::FmkType_MS) {
weightTensor->format = schema::Format_CKHW; weightTensor->format = schema::Format_CKHW;
} }
if (weightTensor->format == schema::Format_CKHW) { // from caffe or onnx or ms
if (weightTensor->format == schema::Format_CKHW) { // from caffe or onnx or ms
status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC); status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC);
} else if (weightTensor->format == schema::Format_KCHW) { } else if (weightTensor->format == schema::Format_KCHW) {
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC); status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
@@ -374,8 +369,8 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC); status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
} else if (weightTensor->format == schema::Format_CHWK) { // from tf
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
} else if (weightTensor->format == schema::Format_KHWC) { // from tf
status = RET_OK;
} else { } else {
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
return -1; return -1;


+ 2
- 1
mindspore/lite/tools/converter/model_parser.h View File

@@ -40,7 +40,8 @@ class ModelParser {
} }
return Fb2Anf(Parse(modelFile, weightFile)); return Fb2Anf(Parse(modelFile, weightFile));
} }
virtual schema::MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile) = 0;
virtual schema::MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) = 0;


public: public:
static FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph) { static FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph) {


+ 3
- 2
mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc View File

@@ -31,7 +31,8 @@ CaffeModelParser::~CaffeModelParser() {}


const std::set<std::string> CaffeModelParser::skipedLayerType = {"Dropout"}; const std::set<std::string> CaffeModelParser::skipedLayerType = {"Dropout"};


schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, const std::string &weightFile) {
schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) {
std::unique_ptr<schema::MetaGraphT> graph(new schema::MetaGraphT()); std::unique_ptr<schema::MetaGraphT> graph(new schema::MetaGraphT());


if (ValidateFileStr(modelFile, ".prototxt") != RET_OK) { if (ValidateFileStr(modelFile, ".prototxt") != RET_OK) {
@@ -91,7 +92,7 @@ schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, const
// ConvertCaffeBatchNorm(graph.get()); // ConvertCaffeBatchNorm(graph.get());


return graph.release(); return graph.release();
// return Fb2Anf(graph.release());
// return Fb2Anf(graph.release());
} }


STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op,


+ 2
- 1
mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h View File

@@ -33,7 +33,8 @@ class CaffeModelParser : public ModelParser {


virtual ~CaffeModelParser(); virtual ~CaffeModelParser();


MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile) override;
MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override;


private: private:
void ConvertCaffeBatchNorm(MetaGraphT *meta_graphT); void ConvertCaffeBatchNorm(MetaGraphT *meta_graphT);


+ 2
- 1
mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h View File

@@ -37,7 +37,8 @@ class OnnxModelParser : public ModelParser {
public: public:
OnnxModelParser(); OnnxModelParser();
virtual ~OnnxModelParser(); virtual ~OnnxModelParser();
MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile) override;
MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override;


private: private:
TypeId GetDateTypeFromOnnx(onnx::TensorProto_DataType onnx_type); TypeId GetDateTypeFromOnnx(onnx::TensorProto_DataType onnx_type);


+ 81
- 56
mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc View File

@@ -20,7 +20,6 @@
#include "tools/common/graph_util.h" #include "tools/common/graph_util.h"
#include "tools/common/storage.h" #include "tools/common/storage.h"
#include "flatbuffers/flatbuffers.h" #include "flatbuffers/flatbuffers.h"
#include "utils/log_adapter.h"
#include "src/common/file_utils.h" #include "src/common/file_utils.h"


namespace mindspore { namespace mindspore {
@@ -60,42 +59,64 @@ STATUS TfliteModelParser::SetAllTensors(const TensorCache &tensor_cache, schema:
} }
return RET_OK; return RET_OK;
} }
void TfliteModelParser::SetMsTensorFromTflite(const std::unique_ptr<tflite::TensorT> &tflite_tensor,
schema::TensorT *tensor) {
std::unique_ptr<schema::QuantParamT> quant_param(new QuantParamT());
if (!tflite_tensor->quantization->scale.empty()) {
quant_param->scale = tflite_tensor->quantization->scale[0];
}


STATUS TfliteModelParser::ParseTfliteQuantParams(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::OperatorT> &tflite_op) {
auto dst_op = tfliteOpMap.at(tflite_op.get());
if (!tflite_tensor->quantization->zero_point.empty()) {
quant_param->zeroPoint = tflite_tensor->quantization->zero_point[0];
}


std::vector<uint32_t> quant_params_index;
quant_params_index.insert(quant_params_index.end(), tflite_op->inputs.begin(), tflite_op->inputs.end());
quant_params_index.insert(quant_params_index.end(), tflite_op->outputs.begin(), tflite_op->outputs.end());
for (const auto &index : quant_params_index) {
const auto &tflite_tensor = tflite_subgraph->tensors[index];
if (tflite_tensor == nullptr) {
MS_LOG(ERROR) << "tensor with id = " << index <<" is null";
return RET_ERROR;
}
// change quant param min to 0 to fit ms-lite ops
if (tensor->dataType == TypeId::kNumberTypeInt8) {
quant_param->zeroPoint = quant_param->zeroPoint - 128;
}

if (!tflite_tensor->quantization->min.empty()) {
quant_param->min = tflite_tensor->quantization->min[0];
}

if (!tflite_tensor->quantization->max.empty()) {
quant_param->max = tflite_tensor->quantization->max[0];
}
quant_param->inited = true;
tensor->quantParams.clear();
tensor->quantParams.emplace_back(std::move(quant_param));
}

STATUS TfliteModelParser::ParseTfliteQuantParams(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
schema::CNodeT *op, TensorCache *tensor_cache) {
MS_ASSERT(op->outputIndex.size() == tflite_op->outputs.size());
for (size_t i = 0; i < tflite_op->inputs.size() && i < op->inputIndex.size(); i++) {
const auto &tflite_tensor = tflite_subgraph->tensors[tflite_op->inputs.at(i)];
if (tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() && if (tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() &&
tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty()) { tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty()) {
continue; continue;
} }
std::unique_ptr<schema::QuantParamT> quant_param(new schema::QuantParamT());
if (!tflite_tensor->quantization->scale.empty()) {
quant_param->scale = tflite_tensor->quantization->scale[0];
}

if (!tflite_tensor->quantization->zero_point.empty()) {
quant_param->zeroPoint = tflite_tensor->quantization->zero_point[0];
auto &inTensor = tensor_cache->GetCachedTensor().at(op->inputIndex.at(i));
if (inTensor == nullptr) {
MS_LOG(ERROR) << "Parse tflite quant params inTensor is null";
return RET_NULL_PTR;
} }

if (!tflite_tensor->quantization->min.empty()) {
quant_param->min = tflite_tensor->quantization->min[0];
SetMsTensorFromTflite(tflite_tensor, inTensor);
}
for (size_t i = 0; i < tflite_op->outputs.size() && i < op->outputIndex.size(); i++) {
const auto &tflite_tensor = tflite_subgraph->tensors[tflite_op->outputs.at(i)];
if (tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() &&
tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty()) {
continue;
} }

if (!tflite_tensor->quantization->max.empty()) {
quant_param->max = tflite_tensor->quantization->max[0];
auto &outTensor = tensor_cache->GetCachedTensor().at(op->outputIndex.at(i));
if (outTensor == nullptr) {
MS_LOG(ERROR) << "Parse tflite quant params outTensor is null";
return RET_NULL_PTR;
} }
SetMsTensorFromTflite(tflite_tensor, outTensor);
} }
dst_op->quantType = schema::QuantType_AwareTrainning;
return RET_OK; return RET_OK;
} }


@@ -105,11 +126,15 @@ STATUS TfliteModelParser::SetOpOutputIdx(const std::unique_ptr<tflite::SubGraphT
for (const auto &index : tflite_op->outputs) { for (const auto &index : tflite_op->outputs) {
const auto &tflite_tensor = tflite_subgraph->tensors[index]; const auto &tflite_tensor = tflite_subgraph->tensors[index];
if (tflite_tensor == nullptr) { if (tflite_tensor == nullptr) {
MS_LOG(ERROR) << "tensor with id = " << index <<" is null";
MS_LOG(ERROR) << "tensor with id = " << index << " is null";
return RET_ERROR; return RET_ERROR;
} }
std::unique_ptr<schema::TensorT> tensor(new schema::TensorT()); std::unique_ptr<schema::TensorT> tensor(new schema::TensorT());
tensor->dataType = GetTfliteDataType(tflite_tensor->type); tensor->dataType = GetTfliteDataType(tflite_tensor->type);
// change dataType to int8 to fit ms-lite op
if (tensor->dataType == TypeId::kNumberTypeUInt8) {
tensor->dataType = TypeId::kNumberTypeInt8;
}
tensor->dims = tflite_tensor->shape; tensor->dims = tflite_tensor->shape;
tensor->nodeType = schema::NodeType_Parameter; tensor->nodeType = schema::NodeType_Parameter;
auto opOutputIndex = tensorCache->AddTensor(tflite_tensor->name, tensor.release(), OP_OUTPUT); auto opOutputIndex = tensorCache->AddTensor(tflite_tensor->name, tensor.release(), OP_OUTPUT);
@@ -120,7 +145,8 @@ STATUS TfliteModelParser::SetOpOutputIdx(const std::unique_ptr<tflite::SubGraphT


STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &tflite_model, STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::OperatorT> &tflite_op, TensorCache *tensorCache) {
const std::unique_ptr<tflite::OperatorT> &tflite_op, schema::CNodeT *op,
TensorCache *tensor_cache) {
auto op_type = GetTfliteNodeType(tflite_op, tflite_model); auto op_type = GetTfliteNodeType(tflite_op, tflite_model);
std::vector<int32_t> op_inputs(tflite_op->inputs); std::vector<int32_t> op_inputs(tflite_op->inputs);
if (op_type == "DeConv2D") { if (op_type == "DeConv2D") {
@@ -130,12 +156,11 @@ STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &t
for (const auto &tflite_index : op_inputs) { for (const auto &tflite_index : op_inputs) {
const auto &tflite_tensor = tflite_subgraph->tensors[tflite_index]; const auto &tflite_tensor = tflite_subgraph->tensors[tflite_index];
if (tflite_tensor == nullptr) { if (tflite_tensor == nullptr) {
MS_LOG(ERROR) << "tensor with id = " << tflite_index <<" is null";
MS_LOG(ERROR) << "tensor with id = " << tflite_index << " is null";
return RET_ERROR; return RET_ERROR;
} }
auto tensor_name = tflite_tensor->name; auto tensor_name = tflite_tensor->name;
auto op = tfliteOpMap[tflite_op.get()];
unsigned int index = tensorCache->FindTensor(tensor_name);
unsigned int index = tensor_cache->FindTensor(tensor_name);
if (index != -1) { if (index != -1) {
op->inputIndex.push_back(index); op->inputIndex.push_back(index);
} }
@@ -146,19 +171,20 @@ STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &t


STATUS TfliteModelParser::ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_model, STATUS TfliteModelParser::ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
schema::MetaGraphT *subGraph,
mindspore::lite::TensorCache *tensorCache) {
schema::MetaGraphT *subGraph, mindspore::lite::TensorCache *tensorCache,
const QuantType &quantType) {
auto i = 0; auto i = 0;
for (const auto &tflite_op : tflite_subgraph->operators) { for (const auto &tflite_op : tflite_subgraph->operators) {
auto opType = GetTfliteNodeType(tflite_op, tflite_model); auto opType = GetTfliteNodeType(tflite_op, tflite_model);


std::unique_ptr<schema::CNodeT> op(new schema::CNodeT); std::unique_ptr<schema::CNodeT> op(new schema::CNodeT);
op->name = opType + "-" + std::to_string(i++); op->name = opType + "-" + std::to_string(i++);
op->quantType = quantType;
MS_LOG(INFO) << "parse op: " << op->name.c_str(); MS_LOG(INFO) << "parse op: " << op->name.c_str();


auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(opType); auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(opType);
if (node_parser == nullptr) { if (node_parser == nullptr) {
MS_LOG(ERROR) << "cannot find node parser, opType: "<< opType.c_str();
MS_LOG(ERROR) << "cannot find node parser, opType: " << opType.c_str();
continue; continue;
// return RET_NULL_PTR; // return RET_NULL_PTR;
} }
@@ -172,7 +198,19 @@ STATUS TfliteModelParser::ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_


status = SetOpOutputIdx(tflite_subgraph, tflite_op, op.get(), tensorCache); status = SetOpOutputIdx(tflite_subgraph, tflite_op, op.get(), tensorCache);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Set Op "<< op->name.c_str() << " Output Index Failed!";
MS_LOG(ERROR) << "set op " << opType.c_str() << " output index failed";
return RET_ERROR;
}

status = SetOpInputIdx(tflite_model, tflite_subgraph, tflite_op, op.get(), tensorCache);
if (status != RET_OK) {
MS_LOG(ERROR) << "set op " << opType.c_str() << " input index failed";
return RET_ERROR;
}

status = ParseTfliteQuantParams(tflite_subgraph, tflite_op, op.get(), tensorCache);
if (status != RET_OK) {
MS_LOG(ERROR) << "parse op " << opType.c_str() << " quant parameters failed";
return RET_ERROR; return RET_ERROR;
} }


@@ -189,8 +227,10 @@ void TfliteModelParser::SetInputTensor(const std::unique_ptr<tflite::SubGraphT>
const auto &tflite_tensor = tflite_subgraph->tensors[index]; const auto &tflite_tensor = tflite_subgraph->tensors[index];
std::unique_ptr<schema::TensorT> tensor(new schema::TensorT()); std::unique_ptr<schema::TensorT> tensor(new schema::TensorT());
tensor->format = schema::Format_NHWC; tensor->format = schema::Format_NHWC;
tensor->dataType = GetTfliteDataType(tflite_tensor->type);
tensor->nodeType = schema::NodeType_ValueNode;
tensor->dataType = GetTfliteDataType(tflite_tensor->type) != TypeId::kNumberTypeUInt8
? GetTfliteDataType(tflite_tensor->type)
: TypeId::kNumberTypeInt8;
tensor->nodeType = schema::NodeType_Parameter;
tensor->dims = tflite_tensor->shape; tensor->dims = tflite_tensor->shape;
tensor_cache->AddTensor(tflite_tensor->name, tensor.release(), GRAPH_INPUT); tensor_cache->AddTensor(tflite_tensor->name, tensor.release(), GRAPH_INPUT);
} }
@@ -212,7 +252,8 @@ void TfliteModelParser::SetGraphTensorIndex(const mindspore::lite::TensorCache &
} }
} }


MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::string &weightFile) {
MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) {
if (ValidateFileStr(modelFile, ".tflite") != RET_OK) { if (ValidateFileStr(modelFile, ".tflite") != RET_OK) {
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.tflite"; MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.tflite";
return nullptr; return nullptr;
@@ -224,7 +265,6 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st
MS_LOG(ERROR) << "read tflite model failed"; MS_LOG(ERROR) << "read tflite model failed";
return nullptr; return nullptr;
} }

if (tflite_model->subgraphs.size() != 1) { if (tflite_model->subgraphs.size() != 1) {
MS_LOG(ERROR) << "read tflite model subgraphs failed"; MS_LOG(ERROR) << "read tflite model subgraphs failed";
return nullptr; return nullptr;
@@ -238,30 +278,15 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st
// set dst subGraph op attr and tensor_cache. // set dst subGraph op attr and tensor_cache.
std::unique_ptr<schema::MetaGraphT> subGraph(new schema::MetaGraphT); std::unique_ptr<schema::MetaGraphT> subGraph(new schema::MetaGraphT);
subGraph->name = "MS_model converted by TF-Lite"; subGraph->name = "MS_model converted by TF-Lite";
auto status = ParseOp(tflite_model, tflite_subgraph, subGraph.get(), &tensorCache);
auto status = ParseOp(tflite_model, tflite_subgraph, subGraph.get(), &tensorCache, quantType);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "ParseOp failed."; MS_LOG(ERROR) << "ParseOp failed.";
return nullptr; return nullptr;
} }


for (const auto &tflite_op : tflite_subgraph->operators) {
auto status_tmp = SetOpInputIdx(tflite_model, tflite_subgraph, tflite_op, &tensorCache);
if (status_tmp != RET_OK) {
MS_LOG(ERROR) << "Set Op " << tfliteOpMap.at(tflite_op.get())->name.c_str() << " Input Index Failed!";
}
}

for (const auto &tflite_op : tflite_subgraph->operators) {
auto statusTmp = ParseTfliteQuantParams(tflite_subgraph, tflite_op);
if (statusTmp != RET_OK) {
MS_LOG(ERROR) << "ParseTfliteQuantParams " << tfliteOpMap.at(tflite_op.get())->name.c_str() << " Failed!";
}
}

SetGraphTensorIndex(tensorCache, subGraph.get()); SetGraphTensorIndex(tensorCache, subGraph.get());
SetAllTensors(tensorCache, subGraph.get()); SetAllTensors(tensorCache, subGraph.get());
return subGraph.release(); return subGraph.release();
} }
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


+ 11
- 8
mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h View File

@@ -40,22 +40,25 @@ class TfliteModelParser : public ModelParser {


virtual ~TfliteModelParser(); virtual ~TfliteModelParser();


MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile);
MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override;


private: private:
std::unique_ptr<tflite::ModelT> ReadTfliteModelFromFlat(const char *buf); std::unique_ptr<tflite::ModelT> ReadTfliteModelFromFlat(const char *buf);


void SetMsTensorFromTflite(const std::unique_ptr<tflite::TensorT> &tflite_tensor, schema::TensorT *tensor);

void SetInputTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, TensorCache *tensor_cache); void SetInputTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, TensorCache *tensor_cache);


void SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache,
schema::MetaGraphT *subGraphDef);
void SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, schema::MetaGraphT *subGraphDef);


STATUS ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_model, STATUS ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::MetaGraphT *sub_graph, const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::MetaGraphT *sub_graph,
TensorCache *tensor_cache);
TensorCache *tensor_cache, const QuantType &quantType);


STATUS ParseTfliteQuantParams(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, STATUS ParseTfliteQuantParams(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::OperatorT> &tflite_op);
const std::unique_ptr<tflite::OperatorT> &tflite_op, schema::CNodeT *op,
TensorCache *tensor_cache);


std::string GetTfliteNodeType(const std::unique_ptr<tflite::OperatorT> &tflite_op, std::string GetTfliteNodeType(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model); const std::unique_ptr<tflite::ModelT> &tflite_model);
@@ -63,13 +66,13 @@ class TfliteModelParser : public ModelParser {
STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *sub_graph); STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *sub_graph);


STATUS SetOpOutputIdx(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, STATUS SetOpOutputIdx(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op, schema::CNodeT *op,
TensorCache *tensorCache); TensorCache *tensorCache);


STATUS SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &tflite_model, STATUS SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::OperatorT> &tflite_op, TensorCache *tensorCache);
const std::unique_ptr<tflite::OperatorT> &tflite_op, schema::CNodeT *op,
TensorCache *tensor_cache);


std::map<std::string, schema::CNodeT *> opMap; std::map<std::string, schema::CNodeT *> opMap;
std::map<const tflite::OperatorT *, schema::CNodeT *> tfliteOpMap; std::map<const tflite::OperatorT *, schema::CNodeT *> tfliteOpMap;


+ 2
- 0
mindspore/lite/tools/converter/quantizer/CMakeLists.txt View File

@@ -4,7 +4,9 @@ include_directories(${3RD_DIR}/flatbuffers/include)
include_directories(${3RD_DIR}/opencv/build/include/opencv4) include_directories(${3RD_DIR}/opencv/build/include/opencv4)


add_library(quantizer_mid OBJECT add_library(quantizer_mid OBJECT
${CMAKE_CURRENT_SOURCE_DIR}/calc_quant_param.cc
${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc ${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/aware_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc ${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc ${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc ${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc


+ 594
- 0
mindspore/lite/tools/converter/quantizer/aware_quantizer.cc View File

@@ -0,0 +1,594 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/converter/quantizer/aware_quantizer.h"
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "schema/inner/model_generated.h"
#include "utils/log_adapter.h"
#include "securec/include/securec.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "src/common/utils.h"
#include "tools/converter/quantizer/calc_quant_param.h"
#include "tools/common/tensor_util.h"
#include "tools/common/converter_op_utils.h"
#include "tools/common/node_util.h"

using std::string;
using std::vector;

namespace mindspore::lite::quant {
// Holds the real-valued input range implied by the user-supplied mean/std
// normalization, and the quant param computed from that range. Used to seed
// the quant params of the graph-input tensors.
struct InputArray {
  std::unique_ptr<QuantParamT> quantParam;
  float mMin = 0.0f;
  float mMax = 0.0f;
  bool narrowRange = false;
  int numBits = 8;
  TypeId dataType = TypeId::kTypeUnknown;

  // Map the uint8 domain [0, 255] back through (q - mean) / std to get the
  // float range the input data actually covers.
  InputArray(float mean, float stdDev, TypeId dataType = TypeId::kNumberTypeFloat) {
    this->dataType = dataType;
    constexpr float qmin = 0;
    constexpr float qmax = 255;
    mMin = (qmin - mean) / stdDev;
    mMax = (qmax - mean) / stdDev;
  }

  // Compute scale/zeroPoint for [mMin, mMax] into this->quantParam.
  STATUS InitQuantParam() {
    quantParam = std::make_unique<schema::QuantParamT>();
    return CalQuantizationParams(quantParam.get(), mMin, mMax, narrowRange, numBits);
  }

  // Attach a copy of the computed quant param to the tensor at inputTensorIdx,
  // unless that tensor already carries an inited quant param.
  STATUS SetInputArrayQP(schema::MetaGraphT *graph, size_t inputTensorIdx) {
    MS_ASSERT(graph != nullptr);
    auto &tensor = graph->allTensors.at(inputTensorIdx);
    MS_ASSERT(tensor != nullptr);
    if (!tensor->quantParams.empty()) {
      auto existing = GetTensorQuantParam(tensor);
      if (existing != nullptr && existing->inited) {
        MS_LOG(DEBUG) << "tensor " << inputTensorIdx << " already has quantParam";
        return RET_OK;
      }
      tensor->quantParams.clear();
    }
    // NOTE: only inited/scale/zeroPoint/min/max are copied, matching the
    // original behavior; numBits/narrowRange keep their schema defaults.
    auto copied = std::make_unique<QuantParamT>();
    copied->inited = this->quantParam->inited;
    copied->scale = this->quantParam->scale;
    copied->zeroPoint = this->quantParam->zeroPoint;
    copied->min = this->quantParam->min;
    copied->max = this->quantParam->max;
    tensor->quantParams.push_back(std::move(copied));
    return RET_OK;
  }
};

// Ops whose output dataType is simply copied from their first input during
// DoQuantize (they rearrange or pass values through rather than compute new ones).
const std::array<schema::PrimitiveType, 7> AwareQuantizer::propagatedOps = {
  {schema::PrimitiveType_Concat, schema::PrimitiveType_Resize, schema::PrimitiveType_Reshape,
   schema::PrimitiveType_Squeeze, schema::PrimitiveType_RealDiv, schema::PrimitiveType_Activation,
   schema::PrimitiveType_DetectionPostProcess}};

// Build the quantizer: parse the user-supplied std/mean strings and derive the
// graph-input quant param. inputInferType selects whether the graph input stays
// float or is treated as uint8 data.
// NOTE(review): std::stof throws on malformed input — assumed validated by the
// converter flags before reaching here; TODO confirm.
AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const string &inputInferType, const string &stdValues,
                               const string &meanValues)
    : FbQuantizer(graph) {
  MS_ASSERT(graph != nullptr);
  string::size_type sz;
  const float stdValue = std::stof(stdValues, &sz);
  sz = 0;
  const float mean = std::stof(meanValues, &sz);
  if (inputInferType == "FLOAT") {
    mInputArray = new (std::nothrow) InputArray(mean, stdValue);
  } else {
    mInputArray = new (std::nothrow) InputArray(mean, stdValue, TypeId::kNumberTypeUInt8);
  }
  if (mInputArray == nullptr) {
    MS_LOG(ERROR) << "new InputArray failed";
    return;
  }
  // The original code dropped this status on the floor; surface failures so a
  // bad std/mean range is visible instead of silently producing bogus params.
  auto status = mInputArray->InitQuantParam();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "InitQuantParam failed: " << status;
  }
}

// Currently a no-op for the MetaGraphT schema: always returns RET_OK.
// The commented-out body below is the legacy GraphDefT implementation kept as a
// reference for the intended behavior: derive quant params from the min/max
// inputs of FakeQuantWithMinMax(Vars) nodes, broadcast them, remove the
// fake-quant nodes, push filter dims into conv attrs, and seed graph-input
// quant params from mInputArray. TODO port to MetaGraphT.
STATUS AwareQuantizer::RemoveFakeQuant() {
  // for (auto &subGraph : graphDefT->subgraphs) {
  //   auto status = GenerateDefaultQuantParam(subGraph.get());
  //   if (status != RET_OK) {
  //     MS_LOGE("GenerateDefaultQuantParam failed: %d", status);
  //     return RET_ERROR;
  //   }
  //   for (auto iter = subGraph->nodes.begin(); iter != subGraph->nodes.end(); iter++) {
  //     auto *node = (*iter).get();
  //     if (GetCNodeTType(*node) != OpT_FakeQuantWithMinMaxVars && GetCNodeTType(*node) != OpT_FakeQuantWithMinMax) {
  //       continue;
  //     }
  //     auto inputIndexes = node->inputIndex;
  //     if (inputIndexes.size() != 3) {
  //       MS_LOGE("invalid fakequant node's input tensors count!");
  //       return RET_ERROR;
  //     }
  //     bool narrorRange;
  //     int numBits;
  //     if (GetCNodeTType(*node) == OpT_FakeQuantWithMinMaxVars) {
  //       narrorRange = node->attr.AsFakeQuantWithMinMaxVars()->narrowRange;
  //       numBits = node->attr.AsFakeQuantWithMinMaxVars()->numBits;
  //     }
  //     if (GetCNodeTType(*node) == OpT_FakeQuantWithMinMax) {
  //       narrorRange = false;
  //       numBits = 8;
  //     }
  //
  //     TensorDefT *tensor0 = subGraph->allTensors.at(inputIndexes[0]).get();
  //     TensorDefT *tensor1 = subGraph->allTensors.at(inputIndexes[1]).get();
  //     TensorDefT *tensor2 = subGraph->allTensors.at(inputIndexes[2]).get();
  //     MS_ASSERT(tensor0 != nullptr);
  //     MS_ASSERT(tensor1 != nullptr);
  //     MS_ASSERT(tensor2 != nullptr);
  //     // calculate quant param
  //     MS_ASSERT(tensor1->dataType == DataType_DT_FLOAT);
  //     MS_ASSERT(tensor2->dataType == DataType_DT_FLOAT);
  //     auto *minData = reinterpret_cast<const float *>(tensor1->data.data());
  //     auto *maxData = reinterpret_cast<const float *>(tensor2->data.data());
  //     MS_ASSERT(minData != nullptr);
  //     MS_ASSERT(maxData != nullptr);
  //     std::unique_ptr<QuantParamT> quantParam(new (std::nothrow) QuantParamT());
  //     if (quantParam == nullptr) {
  //       MS_LOGE("new quantParam failed");
  //       return RET_ERROR;
  //     }
  //     auto realMin = (double)minData[0];
  //     auto realMax = (double)maxData[0];
  //     status = CalQuantizationParams(quantParam.get(), realMin, realMax, narrorRange, numBits);
  //     if (status != RET_OK) {
  //       MS_LOGE("in aware quantization run CalQuantizationParams failed, node: %s", node->name.c_str());
  //       return RET_ERROR;
  //     }
  //     if (tensor0->refCount == MSCONST_WEIGHT_REFCOUNT) {
  //       CalFakeNode(tensor0, quantParam.get());
  //     }
  //     std::unique_ptr<QuantParamArrayT> quantParamArray(new (std::nothrow) QuantParamArrayT());
  //     if (quantParamArray == nullptr) {
  //       MS_LOGE("new quantParamArray failed");
  //       return RET_ERROR;
  //     }
  //     quantParamArray->param.push_back(std::move(quantParam));
  //     auto quantParamArrayCopy = CopyQuantParamArrayT(quantParamArray);
  //     if (quantParamArrayCopy == nullptr) {
  //       MS_LOGE("CopyQuantParamArray %s return nullptr", iter->get()->name.c_str());
  //       return RET_ERROR;
  //     }
  //     node->quantParam.emplace_back(std::move(quantParamArrayCopy));
  //     node->quantParam.emplace_back(nullptr);  // secondInTensor and thirdInTensor are weightTensors who have no
  //     preNode node->quantParam.emplace_back(nullptr); node->quantParam.emplace_back(std::move(quantParamArray));
  //
  //     // BroadCast fakeQuantNode QuantParam
  //     status = BroadCastQuantParam(subGraph, *iter);
  //     if (status != RET_OK) {
  //       MS_LOGE("BroadCastQuantParam %s failed: %d", iter->get()->name.c_str(), status);
  //       return status;
  //     }
  //     // save post node index for SetAttrToConvolution
  //     auto postNodeIdxes = GetOutputNodeIdx(*subGraph, *node);
  //     // remove fakequantwithminmax node
  //     status = IsolateNode(subGraph.get(), node);
  //     if (status != RET_OK) {
  //       MS_LOGE("in aware quant IsolateNode failed!");
  //       return RET_ERROR;
  //     }
  //     // set filter param to node
  //     if (tensor0->refCount == MSCONST_WEIGHT_REFCOUNT && !postNodeIdxes.empty()) {
  //       auto postNode = subGraph->nodes.at(postNodeIdxes.front()).get();
  //       if (GetCNodeTType(*postNode) == OpT_Conv2D || GetCNodeTType(*postNode) == OpT_DepthwiseConv2D ||
  //           GetCNodeTType(*postNode) == OpT_DeConv2D || GetCNodeTType(*postNode) == OpT_DeDepthwiseConv2D) {
  //         auto status = SetAttrToConvolution(subGraph.get(), postNode);
  //         if (status != RET_OK) {
  //           MS_LOGE("in aware quant SetAttrToConvolution failed!");
  //           return RET_ERROR;
  //         }
  //       }
  //     }
  //   }
  //
  //   // remove IsolatedNode
  //   for (auto iter = subGraph->nodes.begin(); iter != subGraph->nodes.end();) {
  //     if ((*iter)->inputIndex.empty() && (*iter)->outputIndex.empty()) {
  //       iter = subGraph->nodes.erase(iter);
  //     } else {
  //       iter++;
  //     }
  //   }
  //   // set graphInputNode inputTensor quantParams
  //   MS_ASSERT(subGraph->inputIndex.size() == 1);
  //   for (auto graphInputIndex : subGraph->inputIndex) {
  //     auto linkedPostIdx = GetLinkedPostIdx(*(subGraph.get()), graphInputIndex);
  //     for (auto nodeIdx : linkedPostIdx) {
  //       MS_ASSERT(subGraph->nodes.size() > nodeIdx);
  //       mInputArray->SetInputArrayQP(subGraph->nodes.at(nodeIdx).get());
  //     }
  //   }
  // }
  return RET_OK;
}

// Ensure every tensor owns at least one quant param entry so later passes can
// index quantParams.front() unconditionally; tensors that already have params
// are left untouched. The added param is default-constructed (not inited).
STATUS AwareQuantizer::GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph) {
  MS_ASSERT(subGraph != nullptr);
  for (const auto &tensor : subGraph->allTensors) {
    if (tensor->quantParams.empty()) {
      tensor->quantParams.emplace_back(std::make_unique<QuantParamT>());
    }
  }
  return RET_OK;
}

// Currently a no-op: always returns RET_OK. The commented-out body is the
// legacy implementation that copied filter-tensor dims into the conv node's
// kernel/channel attributes, with the dim order depending on the source
// framework (MS vs TF layout). TODO port to MetaGraphT.
STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
  // MS_ASSERT(subGraph != nullptr);
  // MS_ASSERT(node != nullptr);
  // auto inputIndexes = node->inputIndex;
  // MS_ASSERT(GetCNodeTType(*node) == OpT_Conv2D || GetCNodeTType(*node) == OpT_DepthwiseConv2D ||
  //           GetCNodeTType(*node) == OpT_DeConv2D || GetCNodeTType(*node) == OpT_DeDepthwiseConv2D);
  // if (inputIndexes.size() < 2) {
  //   MS_LOGE("in aware quant %s node's input tensors is invalid(%zu)!", node->name.c_str(), inputIndexes.size());
  //   return RET_ERROR;
  // }
  // TensorDefT *filterTensor = subGraph->allTensors.at(inputIndexes[1]).get();
  // MS_ASSERT(filterTensor != nullptr);
  // auto filterDims = filterTensor->dims;
  // MS_ASSERT(filterDims.size() == 4);
  // if (GetCNodeTType(*node) == OpT_Conv2D) {
  //   if (node->fmkType == FmkType_MS) {
  //     node->attr.AsConv2D()->channelOut = (int32_t)filterDims[0];
  //     node->attr.AsConv2D()->channelIn = (int32_t)filterDims[1];
  //     node->attr.AsConv2D()->kernelH = (int32_t)filterDims[2];
  //     node->attr.AsConv2D()->kernelW = (int32_t)filterDims[3];
  //   } else if (node->fmkType == FmkType_TF) {
  //     node->attr.AsConv2D()->kernelH = (int32_t)filterDims[0];
  //     node->attr.AsConv2D()->kernelW = (int32_t)filterDims[1];
  //     node->attr.AsConv2D()->channelIn = (int32_t)filterDims[2];
  //     node->attr.AsConv2D()->channelOut = (int32_t)filterDims[3];
  //   } else {
  //     MS_LOGE("Unsupport");
  //   }
  // }
  // if (GetCNodeTType(*node) == OpT_DepthwiseConv2D) {
  //   if (node->fmkType == FmkType_MS) {
  //     node->attr.AsDepthwiseConv2D()->channelIn = (int32_t)filterDims[0];
  //     node->attr.AsDepthwiseConv2D()->channelMultiplier = (int32_t)filterDims[1];
  //     node->attr.AsDepthwiseConv2D()->kernelH = (int32_t)filterDims[2];
  //     node->attr.AsDepthwiseConv2D()->kernelW = (int32_t)filterDims[3];
  //   } else if (node->fmkType == FmkType_TF) {
  //     node->attr.AsDepthwiseConv2D()->kernelH = (int32_t)filterDims[0];
  //     node->attr.AsDepthwiseConv2D()->kernelW = (int32_t)filterDims[1];
  //     node->attr.AsDepthwiseConv2D()->channelIn = (int32_t)filterDims[2];
  //     node->attr.AsDepthwiseConv2D()->channelMultiplier = (int32_t)filterDims[3];
  //   } else {
  //     MS_LOGE("Unsupport");
  //   }
  // }
  // if (GetCNodeTType(*node) == OpT_DeConv2D) {
  //   MS_ASSERT(false);
  // }
  // if (GetCNodeTType(*node) == OpT_DeDepthwiseConv2D) {
  //   MS_ASSERT(false);
  // }
  return RET_OK;
}

// Populate quant params for the whole graph: seed the graph inputs from the
// user-provided mean/std, give every remaining tensor a default placeholder,
// then run the per-op QuantParamCalcer for each node. Nodes whose calcer is
// missing or fails are downgraded to QUANT_NONE instead of aborting.
STATUS AwareQuantizer::GenerateQuantParam() {
  // todo why?
  MS_ASSERT(graph->inputIndex.size() == 1);
  // set graphInputNode input
  for (auto graphInputIndex : graph->inputIndex) {
    auto status = mInputArray->SetInputArrayQP(graph.get(), graphInputIndex);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "SetInputArrayQP failed";
      return status;
    }
  }
  auto status = GenerateDefaultQuantParam(graph.get());
  if (status != RET_OK) {
    MS_LOG(ERROR) << "GenerateDefaultQuantParam failed";
    return status;
  }
  auto *quantParamRegister = QuantParamCalcRegister::GetInstance();

  for (auto &node : graph->nodes) {
    MS_ASSERT(node != nullptr);
    const auto nodeType = GetCNodeTType(*node);
    // fake-quant nodes are expected to be gone before this pass runs
    if (nodeType == schema::PrimitiveType_FakeQuantWithMinMax ||
        nodeType == schema::PrimitiveType_FakeQuantWithMinMaxVars) {
      MS_ASSERT(false);
    }
    auto *quantParamCalcer = quantParamRegister->GetQuantParamCalcer(nodeType);
    if (quantParamCalcer == nullptr) {
      MS_LOG(ERROR) << "Can not find QuantParamCalcer for " << node->name.c_str()
                    << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
      node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE);
      continue;
    }
    status = quantParamCalcer->Calc(graph.get(), *node);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
      node->quantType = schema::QuantType_QUANT_NONE;
    } else {
      node->quantType = schema::QuantType_AwareTrainning;
    }
  }
  return RET_OK;
}

// Quantize const inputs of aware-trained nodes: conv weights (and bias when
// present), DetectionPostProcess const tensors and Add const operands. For ops
// in propagatedOps, copy the (possibly new) input dataType to the output tensor.
STATUS AwareQuantizer::DoQuantize() {
  for (auto &node : graph->nodes) {
    const auto nodeType = GetCNodeTType(*node);
    if (!IsContain(GetUint8OpList(), nodeType)) {
      continue;
    }
    if (node->quantType != schema::QuantType_AwareTrainning) {
      continue;
    }
    STATUS status = RET_OK;
    if (nodeType == schema::PrimitiveType_Conv2D || nodeType == schema::PrimitiveType_DepthwiseConv2D) {
      const auto &inputIndexes = node->inputIndex;
      if (inputIndexes.size() < 2) {
        MS_LOG(ERROR) << node->name.c_str() << " node input has invalid inputs tensor count";
        return RET_ERROR;
      }
      // quant weight
      status = QuantConvWeight(graph.get(), node.get());
      if (status != RET_OK) {
        MS_LOG(ERROR) << "QuantConvWeight failed!";
        return RET_ERROR;
      }
      // quant bias (third input, when present)
      if (inputIndexes.size() == 3) {
        status = QuantConvBias(graph.get(), node.get());
        if (status != RET_OK) {
          MS_LOG(ERROR) << "QuantConvBias failed!";
          return RET_ERROR;
        }
      }
    } else if (nodeType == schema::PrimitiveType_DetectionPostProcess) {
      status = QuantDetectionPostProcessConstTensor(graph.get(), node.get());
      if (status != RET_OK) {
        MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!";
        return RET_ERROR;
      }
    } else if (nodeType == schema::PrimitiveType_Add) {
      status = QuantAddConstTensor(graph.get(), node.get());
      if (status != RET_OK) {
        MS_LOG(ERROR) << "QuantAddConstTensor failed!";
        return RET_ERROR;
      }
    }
    // dtype-propagating ops: output follows the first input's dataType
    if (std::find(propagatedOps.begin(), propagatedOps.end(), nodeType) != propagatedOps.end()) {
      auto *inputTensor = graph->allTensors.at(node->inputIndex[0]).get();
      auto *outputTensor = graph->allTensors.at(node->outputIndex[0]).get();
      MS_ASSERT(inputTensor != nullptr);
      MS_ASSERT(outputTensor != nullptr);
      outputTensor->dataType = inputTensor->dataType;
    }
  }
  return RET_OK;
}

// Quantize any const float operand of an Add node to uint8 using the operand's
// own (already computed) quant param. Non-const operands and uint8 operands are
// left alone; any other const dtype is an error.
STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(node != nullptr);
  for (auto inTensorIdx : node->inputIndex) {
    MS_ASSERT(graph->allTensors.size() > inTensorIdx);
    auto &inTensor = graph->allTensors.at(inTensorIdx);
    MS_ASSERT(inTensor != nullptr);
    if (inTensor->refCount != 999) {  // 999: presumably the const/weight refCount sentinel — TODO confirm
      continue;
    }
    if (inTensor->dataType == kNumberTypeUInt8) {
      continue;  // already quantized
    }
    if (inTensor->dataType != TypeId::kNumberTypeFloat) {
      // MS_LOGE("Unsupported dataType: %d", inTensor->dataType);
      return RET_ERROR;
    }
    auto quantParam = GetTensorQuantParam(inTensor);
    MS_ASSERT(quantParam != nullptr);
    MS_ASSERT(quantParam->inited);
    const auto elemCount = GetShapeSize(*(inTensor.get()));
    auto *floatData = static_cast<float *>(static_cast<void *>(inTensor->data.data()));
    vector<uint8_t> qDatas(elemCount);
    for (size_t j = 0; j < elemCount; j++) {
      qDatas[j] = QuantizeData<uint8_t>(floatData[j], quantParam.get());
    }
    inTensor->data = std::move(qDatas);
    inTensor->dataType = kNumberTypeUInt8;
  }
  return RET_OK;
}

// Quantize the third input of a DetectionPostProcess node (a const float
// tensor, e.g. anchors) to uint8 in place. Does nothing if the tensor is not a
// const (refCount != 999) or is not float.
STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
  MS_ASSERT(subGraph != nullptr);
  MS_ASSERT(node != nullptr);
  auto &constTensor = subGraph->allTensors.at(node->inputIndex[2]);
  MS_ASSERT(constTensor != nullptr);
  if (constTensor->refCount != 999 || constTensor->dataType != TypeId::kNumberTypeFloat) {
    return RET_OK;
  }
  std::unique_ptr<QuantParamT> quantParam = GetTensorQuantParam(constTensor);
  if (quantParam == nullptr) {
    // MS_LOGE("new QuantParamT failed");
    return RET_NULL_PTR;
  }
  const auto *constData = reinterpret_cast<const float *>(constTensor->data.data());
  const size_t elemCount = GetShapeSize(*constTensor);
  vector<uint8_t> qDatas(elemCount);
  for (size_t j = 0; j < elemCount; j++) {
    qDatas[j] = QuantizeData<uint8_t>(constData[j], quantParam.get());
  }
  constTensor->data = std::move(qDatas);
  constTensor->dataType = TypeId::kNumberTypeUInt8;
  return RET_OK;
}

// Quantize a conv/depthwise-conv bias (third input) from float to int32 using
// scale = inputScale * weightScale, as required by int8 conv kernels.
// Non-float biases (e.g. already-int32) are left untouched.
//
// Fixes vs. original: the int32/float re-checks after the first early return
// were unreachable dead code, and the `new int32_t[]` buffer leaked on the
// memcpy_s error path — replaced by an RAII std::vector.
STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph, mindspore::schema::CNodeT *node) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(node != nullptr);
  auto inputIndexes = node->inputIndex;
  MS_ASSERT(inputIndexes.size() >= 3);
  MS_ASSERT(graph->allTensors.size() > inputIndexes.at(0));
  MS_ASSERT(graph->allTensors.size() > inputIndexes.at(1));
  MS_ASSERT(graph->allTensors.size() > inputIndexes.at(2));
  auto &biasTensor = graph->allTensors.at(inputIndexes.at(2));
  MS_ASSERT(biasTensor != nullptr);
  if (biasTensor->dataType != TypeId::kNumberTypeFloat) {
    // already-quantized (e.g. int32) bias needs no processing
    return RET_OK;
  }

  auto &inputTensor = graph->allTensors.at(inputIndexes.at(0));
  auto &weightTensor = graph->allTensors.at(inputIndexes.at(1));
  MS_ASSERT(inputTensor != nullptr);
  MS_ASSERT(weightTensor != nullptr);
  auto inputScale = inputTensor->quantParams.front()->scale;
  auto weightScale = weightTensor->quantParams.front()->scale;
  auto scale = inputScale * weightScale;
  // set bias quant param
  std::unique_ptr<QuantParamT> biasQuantParam = GetTensorQuantParam(biasTensor);
  if (biasQuantParam == nullptr) {
    MS_LOG(ERROR) << "GetTensorQuantParam failed";
    return RET_ERROR;
  }
  biasQuantParam->inited = true;
  biasQuantParam->scale = scale;
  biasQuantParam->zeroPoint = 0;
  biasQuantParam->numBits = 8;
  biasQuantParam->narrowRange = false;
  biasQuantParam->min = 0.0;
  biasQuantParam->max = 0.0;
  // NOTE(review): biasQuantParam may be a copy; the original code never wrote
  // it back to biasTensor->quantParams either — TODO confirm intended.

  // quant bias data (RAII buffer — no leak on error paths)
  auto bShapeSize = GetShapeSize(*(biasTensor.get()));
  std::vector<int32_t> qDatas(bShapeSize);
  const auto *rawDatas = reinterpret_cast<const float *>(biasTensor->data.data());
  for (size_t i = 0; i < bShapeSize; ++i) {
    qDatas[i] = static_cast<int32_t>(std::round(rawDatas[i] / scale));
  }
  biasTensor->dataType = TypeId::kNumberTypeInt32;
  biasTensor->data.clear();
  biasTensor->data.resize(bShapeSize * sizeof(int32_t));
  auto ret =
    memcpy_s(biasTensor->data.data(), biasTensor->data.size(), qDatas.data(), bShapeSize * sizeof(int32_t));
  if (ret != EOK) {
    MS_LOG(ERROR) << "memcpy_s failed: " << ret;
    return RET_ERROR;
  }
  return RET_OK;
}

// Quantize a conv/depthwise-conv weight (second input) to int8.
// - float weights (normal aware-training): quantize via the tensor's quant param;
// - uint8 weights (tflite aware-training): shift by -128 and adjust zeroPoint.
//
// Fix vs. original: after a float->int8 conversion the tensor's byte buffer was
// left at its float size (4x too large) with the stale tail kept — the buffer is
// now resized to exactly one byte per element before the dtype is flipped; the
// unchecked ::memcpy is replaced by memcpy_s.
STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
  MS_ASSERT(subGraph != nullptr);
  MS_ASSERT(node != nullptr);
  MS_ASSERT(node->quantParam.size() == node->inputIndex.size() + node->outputIndex.size());
  auto inputIndexes = node->inputIndex;
  MS_ASSERT(inputIndexes.size() >= 2);
  MS_ASSERT(subGraph->allTensors.size() > inputIndexes.at(1));
  auto &weightTensor = subGraph->allTensors.at(inputIndexes.at(1));
  if (weightTensor->dataType == TypeId::kNumberTypeInt8) {
    return RET_OK;  // already quantized
  }
  if (weightTensor->dataType != TypeId::kNumberTypeFloat && weightTensor->dataType != TypeId::kNumberTypeUInt8) {
    MS_LOG(ERROR) << "conv " << node->name.c_str() << "'s weight data is not float or uint8";
    return RET_ERROR;
  }
  size_t wShapeSize = GetShapeSize(*(weightTensor.get()));
  void *oriWeightData = weightTensor->data.data();
  MS_ASSERT(node->quantParam.at(1)->param.front() != nullptr);
  vector<int8_t> qDatas(wShapeSize);
  auto weightQuantParam = GetTensorQuantParam(weightTensor);
  if (weightTensor->dataType == TypeId::kNumberTypeFloat) {  // normal awareing quant
    auto *weightData = static_cast<float *>(oriWeightData);
    for (size_t j = 0; j < wShapeSize; j++) {
      qDatas[j] = QuantizeData<int8_t>(weightData[j], weightQuantParam.get());
    }
  } else {  // tflite awareing quant: uint8 -> int8 by shifting the domain by -128
    auto *weightData = static_cast<uint8_t *>(oriWeightData);
    for (size_t j = 0; j < wShapeSize; j++) {
      qDatas[j] = static_cast<int8_t>(static_cast<int32_t>(weightData[j]) - 128);
    }
    weightQuantParam->zeroPoint -= 128;
    weightTensor->quantParams.clear();
    weightTensor->quantParams.emplace_back(weightQuantParam.release());
  }

  // shrink the buffer to one byte per element (float source was 4x larger)
  weightTensor->data.resize(wShapeSize);
  if (memcpy_s(weightTensor->data.data(), weightTensor->data.size(), qDatas.data(), wShapeSize) != EOK) {
    MS_LOG(ERROR) << "memcpy_s failed";
    return RET_ERROR;
  }
  weightTensor->dataType = TypeId::kNumberTypeInt8;
  return RET_OK;
}
// Tag each node AwareTrainning iff every input AND output tensor carries an
// inited quant param and the op type is in the uint8-capable op list;
// otherwise tag it QUANT_NONE.
STATUS AwareQuantizer::DetermineNodeQuantType() {
  MS_ASSERT(graph != nullptr);
  // true iff every listed tensor has a non-null, inited front quant param
  auto allQuantParamsInited = [this](const std::vector<uint32_t> &tensorIndices) {
    for (auto tensorIdx : tensorIndices) {
      MS_ASSERT(graph->allTensors.size() > tensorIdx);
      auto &tensor = graph->allTensors.at(tensorIdx);
      MS_ASSERT(tensor != nullptr);
      if (tensor->quantParams.empty() || tensor->quantParams.front() == nullptr ||
          !tensor->quantParams.front()->inited) {
        return false;
      }
    }
    return true;
  };
  for (auto &node : graph->nodes) {
    MS_ASSERT(node != nullptr);
    const bool canQuant = allQuantParamsInited(node->inputIndex) && allQuantParamsInited(node->outputIndex);
    if (canQuant && IsContain(GetUint8OpList(), GetCNodeTType(*node))) {
      node->quantType = schema::QuantType_AwareTrainning;
    } else {
      node->quantType = schema::QuantType_QUANT_NONE;
    }
  }
  return RET_OK;
}
} // namespace mindspore::lite::quant

+ 65
- 0
mindspore/lite/tools/converter/quantizer/aware_quantizer.h View File

@@ -0,0 +1,65 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MS_AWARE_QUANTIZER_H
#define MS_AWARE_QUANTIZER_H

#include <array>
#include <string>
#include "tools/converter/quantizer/quantizer.h"
#include "schema/inner/model_generated.h"
#include "include/errorcode.h"

namespace mindspore::lite::quant {
struct InputArray;

// Aware-training quantizer over the flatbuffer MetaGraphT: seeds graph-input
// quant params from user mean/std values, computes per-tensor quant params via
// registered calcers, tags quantizable nodes and converts const weights/bias
// from float to int8/int32.
class AwareQuantizer : public FbQuantizer {
public:
// inputInferType: "FLOAT" keeps float graph inputs, otherwise uint8 is assumed;
// stdValues/meanValues: normalization parameters parsed with std::stof.
AwareQuantizer(schema::MetaGraphT *graph, const std::string &inputInferType, const std::string &stdValues,
const std::string &meanValues);

// NOTE(review): raw owning pointer with default copy ops — copying an
// AwareQuantizer would double-delete mInputArray; consider deleting copies.
~AwareQuantizer() { delete (mInputArray); }

// Currently a no-op (legacy implementation kept commented in the .cc).
STATUS RemoveFakeQuant() override;

// Seed input quant params and run per-op quant param calcers.
STATUS GenerateQuantParam() override;

// Tag nodes AwareTrainning/QUANT_NONE based on quant param readiness.
STATUS DetermineNodeQuantType() override;

// Quantize const weights/bias and propagate dtypes through pass-through ops.
STATUS DoQuantize() override;  // override;

private:
// RemoveFakeQuant
STATUS SetAttrToConvolution(const schema::MetaGraphT *subGraph, schema::CNodeT *node);

STATUS GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph);

STATUS QuantAddConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node);

STATUS QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node);

STATUS QuantConvBias(const schema::MetaGraphT *graph, schema::CNodeT *node);

STATUS QuantConvWeight(const schema::MetaGraphT *subGraph, schema::CNodeT *node);

float inputScale = 0.0f;

// owned; allocated in the constructor, released in the destructor
InputArray *mInputArray;

// ops whose output dataType follows their input during DoQuantize
static const std::array<schema::PrimitiveType, 7> propagatedOps;
};
}  // namespace mindspore::lite::quant
#endif

+ 504
- 0
mindspore/lite/tools/converter/quantizer/calc_quant_param.cc View File

@@ -0,0 +1,504 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/converter/quantizer/calc_quant_param.h"
#include <cfloat>
#include <memory>
#include <algorithm>
#include <utility>
#include "tools/common/graph_util.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "schema/inner/ops_generated.h"
#include "src/common/utils.h"

namespace mindspore::lite {
// Compute a quant param for a const tensor from its float data range.
// int32 (bias) and uint8 (already quantized) tensors need no computation and
// return RET_OK untouched; any other non-float dtype is an error.
//
// Fix vs. original: guard against a null data buffer when the shape reports a
// non-zero element count (previously dereferenced nullptr).
STATUS QuantParamCalcer::ComputeConstQuantParam(const schema::TensorT &tensor, QuantParamT *quantParam) {
  MS_ASSERT(quantParam != nullptr);
  // int32 weight no need to quant
  if (tensor.dataType == TypeId::kNumberTypeInt32 || tensor.dataType == TypeId::kNumberTypeUInt8) {
    return RET_OK;
  }
  if (tensor.dataType != TypeId::kNumberTypeFloat) {
    // MS_LOGW("Const Tensor without quantParam should has float dataType, in fact: %d", tensor.dataType);
    return RET_ERROR;
  }
  const auto *constData = reinterpret_cast<const float *>(tensor.data.data());
  size_t constTensorShapeSize = GetShapeSize(tensor);
  if (constData == nullptr && constTensorShapeSize != 0) {
    // shape promises data but the buffer is empty: avoid nullptr dereference
    return RET_ERROR;
  }
  // the range always includes 0 so that zero maps exactly to a quantized value
  float min = 0.0f;
  float max = 0.0f;
  for (size_t i = 0; i < constTensorShapeSize; i++) {
    min = std::min(min, constData[i]);
    max = std::max(max, constData[i]);
  }
  if (min == 0.0f && max == 0.0f) {
    max = 1.0f;  // degenerate all-zero tensor: avoid a zero-width range
  }
  // exact quantization is only possible when every value sits on min or max
  bool isQuantExact = true;
  for (size_t i = 0; i < constTensorShapeSize; i++) {
    isQuantExact &= (constData[i] == min || constData[i] == max);
  }
  if (!isQuantExact) {
    // MS_LOGD("compute quantParam for const tensor may be a cause of poor inference accuracy");
  }
  return quant::CalQuantizationParams(quantParam, min, max);
}

// Base calculation shared by all calcers.
// Counts how many input/output tensors of `node` already carry an inited quant
// param (into inputParamDone / outputParamDone) and computes quant params for
// const (weight-like) float input tensors directly from their raw data.
// init inTensor quantParam from preNode if possible
// init outTensor quantParam from postNode if possible
int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
  MS_ASSERT(node.inputIndex.size() > 0);
  MS_ASSERT(node.quantParam.size() == node.inputIndex.size() + node.outputIndex.size());
  inputParamDone = 0;
  auto inputTensorSize = node.inputIndex.size();
  for (size_t i = 0; i < inputTensorSize; i++) {
    MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(i));
    auto &tensor = graph->allTensors.at(node.inputIndex.at(i));
    MS_ASSERT(tensor != nullptr);
    auto quantParam = GetTensorQuantParam(tensor);
    if (quantParam->inited) {  // quant param already known for this input
      inputParamDone++;
      continue;
    }
    MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(i));

    MS_ASSERT(tensor != nullptr);
    // NOTE(review): refCount is compared against NodeType_ValueNode -- it looks
    // like refCount is used as a const-tensor tag here; confirm against the
    // importer. Graph inputs are excluded since their params come at runtime.
    if (tensor->refCount == schema::NodeType_ValueNode && !IsContain(graph->inputIndex, node.inputIndex.at(i))) {
      auto status = ComputeConstQuantParam((*tensor), quantParam.get());
      if (status != RET_OK) {
        // MS_LOGW("ComputeConstQuantParam failed: %d", status);
        return status;
      }
      // Store the freshly computed param back into the tensor.
      tensor->quantParams.front() = std::move(quantParam);
      inputParamDone++;
      continue;
    }
  }
  outputParamDone = 0;
  for (unsigned int i : node.outputIndex) {
    MS_ASSERT(graph->allTensors.size() > i);
    auto &tensor = graph->allTensors.at(i);
    MS_ASSERT(tensor != nullptr);
    auto quantParam = GetTensorQuantParam(tensor);
    MS_ASSERT(quantParam != nullptr);
    if (quantParam->inited) {  // quant param already known for this output
      outputParamDone++;
      continue;
    }

    // refCount == 999 marks const tensors elsewhere in this file; a const
    // output tensor is unexpected, hence the hard assert.
    if (tensor->refCount == 999) {
      MS_ASSERT(false);
    }
  }
  return RET_OK;
}

// Calcer for ops whose input and output quant params must all be determined
// up front: runs the base collection, then fails if anything is still missing.
int CommonCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) {
  auto ret = QuantParamCalcer::Calc(subGraph, node);
  if (ret != RET_OK) {
    // base quant-param collection failed; propagate the status
    return ret;
  }
  if (node.inputIndex.size() != inputParamDone) {
    MS_LOG(ERROR) << "Can not determine inputTensor quantParam, node " << node.name.c_str();
    return RET_ERROR;
  }
  if (node.outputIndex.size() != outputParamDone) {
    MS_LOG(ERROR) << "Can not determine outputTensor quantParam, node " << node.name.c_str();
    return RET_ERROR;
  }
  return RET_OK;
}

// Calcer for range-preserving ("linear") ops such as pooling/reshape/transpose:
// the output value range equals the input range, so a quant param missing on
// one side of the node is copied from the other side.
int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
  auto status = QuantParamCalcer::Calc(graph, node);
  if (status != RET_OK) {
    // MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status);
    return status;
  }
  if (inputParamDone != node.inputIndex.size()) {
    // Some input params are missing: derive them from the first output tensor.
    MS_ASSERT(graph->allTensors.size() > node.outputIndex.at(0));
    auto &outTensor = graph->allTensors.at(node.outputIndex.at(0));
    MS_ASSERT(outTensor != nullptr);
    auto outputQuantParam = GetTensorQuantParam(outTensor);
    MS_ASSERT(outputQuantParam != nullptr);
    if (!outputQuantParam->inited) {
      // MS_LOGW("Can not determine inputTensor quantParam from outputTensor for node %s", node.name.c_str());
      return RET_ERROR;
    }
    for (unsigned int i : node.inputIndex) {
      // fix: `i` already is the tensor index; indexing inputIndex with it was wrong
      MS_ASSERT(graph->allTensors.size() > i);
      auto &inTensor = graph->allTensors.at(i);
      MS_ASSERT(inTensor != nullptr);
      auto inQuantParam = GetTensorQuantParam(inTensor);
      if (inQuantParam->inited) {
        continue;
      }
      // fix: copy the known output quant param into the input tensor instead of
      // moving the input's own uninited param back into itself.
      inTensor->quantParams.front() = std::make_unique<QuantParamT>(*outputQuantParam);
    }
  }
  if (outputParamDone != node.outputIndex.size()) {
    // Some output params are missing: derive them from the first input tensor.
    MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0));
    auto &inTensor = graph->allTensors.at(node.inputIndex.at(0));
    MS_ASSERT(inTensor != nullptr);
    auto inQuantParam = GetTensorQuantParam(inTensor);
    if (!inQuantParam->inited) {
      // MS_LOGW("Can not determine outputTensor quantParam from inputTensor for node %s", node.name.c_str());
      return RET_ERROR;
    }
    for (size_t i = 0; i < node.outputIndex.size(); i++) {
      MS_ASSERT(graph->allTensors.size() > node.outputIndex.at(i));
      auto &outTensor = graph->allTensors.at(node.outputIndex.at(i));
      MS_ASSERT(outTensor != nullptr);
      auto outQuantParam = GetTensorQuantParam(outTensor);
      if (outQuantParam->inited) {
        continue;
      }
      // fix (resolves "todo copy quant params"): copy the input's quant param
      // to the output instead of re-storing the output's own uninited param.
      outTensor->quantParams.front() = std::make_unique<QuantParamT>(*inQuantParam);
    }
  }
  return RET_OK;
}

// Calcer for Concat: every input quant param must already be known; the output
// quant param is computed over the union of all input value ranges.
class CalcConcat : public QuantParamCalcer {
 public:
  CalcConcat() = default;

  int Calc(MetaGraphT *graph, const CNodeT &node) override {
    MS_ASSERT(node.outputIndex.size() == 1);
    auto status = QuantParamCalcer::Calc(graph, node);
    if (status != RET_OK) {
      // MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status);
      return status;
    }

    if (inputParamDone != node.inputIndex.size()) {
      // MS_LOGW("Can not determine concat inputTensor quantParam, node %s", node.name.c_str());
      return RET_ERROR;
    }

    if (outputParamDone != 1) {
      MS_ASSERT(outputParamDone == 0);
      float minMin = FLT_MAX;
      // fix: FLT_MIN is the smallest *positive* float; the running max must
      // start from the most negative representable value.
      float maxMax = -FLT_MAX;
      bool narrowRange = false;
      int numBits = -1;
      for (size_t i = 0; i < node.inputIndex.size(); i++) {
        MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(i));
        // fix: index allTensors with the tensor index, not the loop counter
        auto &inTensor = graph->allTensors.at(node.inputIndex.at(i));
        MS_ASSERT(inTensor != nullptr);
        auto inQuantParam = GetTensorQuantParam(inTensor);
        MS_ASSERT(inQuantParam != nullptr);
        if (!inQuantParam->inited) {
          return RET_ERROR;
        }
        if (numBits == -1) {
          // Take narrowRange/numBits from the first input...
          narrowRange = inQuantParam->narrowRange;
          numBits = inQuantParam->numBits;
        } else {
          // ...and require the remaining inputs to agree.
          // fix: these asserts referenced a nonexistent `quantParam` variable.
          MS_ASSERT(narrowRange == inQuantParam->narrowRange);
          MS_ASSERT(numBits == inQuantParam->numBits);
        }
        if (minMin > inQuantParam->min) {
          minMin = inQuantParam->min;
        }
        if (maxMax < inQuantParam->max) {
          maxMax = inQuantParam->max;
        }
      }

      MS_ASSERT(graph->allTensors.size() > node.outputIndex.front());
      auto &outTensor = graph->allTensors.at(node.outputIndex.front());
      MS_ASSERT(outTensor != nullptr);
      auto outQuantParam = GetTensorQuantParam(outTensor);

      status = quant::CalQuantizationParams(outQuantParam.get(), minMin, maxMax, narrowRange, numBits);
      if (status != RET_OK) {
        // MS_LOGW("in aware quantization run CalQuantizationParams failed!");
        return RET_ERROR;
      }
      // fix: persist the computed quant param back into the output tensor;
      // the computed result was previously dropped.
      outTensor->quantParams.front() = std::move(outQuantParam);
      outputParamDone++;
    }

    return RET_OK;
  }
};

// Calcer for Add in the "tensor + scalar-const bias" pattern: the output range
// is taken as the non-const parameter tensor's range shifted by the scalar
// bias value. Fails when neither input looks like a scalar const.
class CalcAdd : public QuantParamCalcer {
 public:
  CalcAdd() = default;

  int Calc(MetaGraphT *graph, const CNodeT &node) override {
    MS_ASSERT(node.inputIndex.size() == 2);
    MS_ASSERT(node.outputIndex.size() == 1);
    auto status = QuantParamCalcer::Calc(graph, node);
    if (status != RET_OK) {
      // MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status);
      return status;
    }

    if (inputParamDone != 2) {
      // MS_LOGW("Can not determine add inputTensor quantParam, node %s", node.name.c_str());
      return RET_ERROR;
    }
    if (outputParamDone != 1) {
      MS_ASSERT(outputParamDone == 0);
      MS_ASSERT(graph->allTensors.size() > node.outputIndex.front());
      auto &outTensor = graph->allTensors.at(node.outputIndex.front());
      MS_ASSERT(outTensor != nullptr);
      auto outQuantParam = GetTensorQuantParam(outTensor);

      MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0));
      auto &tensor0 = graph->allTensors.at(node.inputIndex.at(0));
      MS_ASSERT(tensor0 != nullptr);
      MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(1));
      auto &tensor1 = graph->allTensors.at(node.inputIndex.at(1));
      MS_ASSERT(tensor1 != nullptr);
      // Decide which input is the scalar bias: refCount == 999 marks a const
      // tensor, and an empty or 1-d shape marks it as (near-)scalar.
      auto biasTensor = &tensor0;
      auto paramTensor = &tensor1;
      if (tensor0->refCount == 999 && (tensor0->dims.empty() || tensor0->dims.size() == 1)) {
        biasTensor = &tensor0;
        paramTensor = &tensor1;
      } else if (tensor1->refCount == 999 && (tensor1->dims.empty() || tensor1->dims.size() == 1)) {
        biasTensor = &tensor1;
        paramTensor = &tensor0;
      } else {
        // MS_LOGW("Can not determine add outputTensor quantParam, node %s", node.name.c_str());
        return RET_ERROR;
      }
      auto quantParam = GetTensorQuantParam(*paramTensor);
      MS_ASSERT(quantParam != nullptr);
      MS_ASSERT(quantParam->inited);
      auto min = quantParam->min;
      auto max = quantParam->max;
      {
        if ((*biasTensor)->dataType == TypeId::kNumberTypeFloat) {
          MS_ASSERT((*biasTensor)->data.size() == sizeof(float) / sizeof(uint8_t));
          void *oriTensorData = (*biasTensor)->data.data();
          auto *bias = static_cast<float *>(oriTensorData);
          // Shift the parameter tensor's range by the float bias value.
          status = quant::CalQuantizationParams(outQuantParam.get(), min + (*bias), max + (*bias));
          if (status != RET_OK) {
            // MS_LOGW("in aware quantization run CalQuantizationParams failed!");
            return RET_ERROR;
          }
        } else if ((*biasTensor)->dataType == TypeId::kNumberTypeUInt8) {
          MS_ASSERT((*biasTensor)->data.size() == 1);
          void *oriTensorData = (*biasTensor)->data.data();
          auto *bias = static_cast<uint8_t *>(oriTensorData);
          // NOTE(review): the raw uint8 value is added without dequantizing it
          // first -- confirm the bias is expected to already be a real value.
          status = quant::CalQuantizationParams(outQuantParam.get(), min + (*bias), max + (*bias));
          if (status != RET_OK) {
            // MS_LOGW("in aware quantization run CalQuantizationParams failed!");
            return RET_ERROR;
          }
        } else {
          // MS_LOGW("Unsupported tensor dataType: %d", (*biasTensor)->dataType);
          return RET_ERROR;
        }
      }
      // NOTE(review): outQuantParam is computed but never stored back into
      // outTensor->quantParams -- verify whether the result is dropped here.
    }
    return RET_OK;
  }
};

// Calcer for RealDiv in the "tensor / scalar-const" pattern: the output range
// is the divisor tensor's range divided by the scalar constant.
class CalcRealDiv : public QuantParamCalcer {
 public:
  CalcRealDiv() = default;

  int Calc(MetaGraphT *graph, const CNodeT &node) override {
    MS_ASSERT(node.inputIndex.size() == 2);
    MS_ASSERT(node.outputIndex.size() == 1);
    auto status = QuantParamCalcer::Calc(graph, node);
    if (status != RET_OK) {
      // MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status);
      return status;
    }

    if (inputParamDone != 2) {
      // MS_LOGW("Can not determine realdiv inputTensor quantParam, node %s", node.name.c_str());
      return RET_ERROR;
    }
    if (outputParamDone != 1) {
      MS_ASSERT(outputParamDone == 0);
      MS_ASSERT(graph->allTensors.size() > node.outputIndex.front());
      auto &outTensor = graph->allTensors.at(node.outputIndex.front());
      MS_ASSERT(outTensor != nullptr);
      auto outQuantParam = GetTensorQuantParam(outTensor);

      MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0));
      auto &tensor0 = graph->allTensors.at(node.inputIndex.at(0));
      MS_ASSERT(tensor0 != nullptr);
      MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(1));
      auto &tensor1 = graph->allTensors.at(node.inputIndex.at(1));
      MS_ASSERT(tensor1 != nullptr);
      // refCount == 999 marks a const tensor; only a scalar const divisor
      // (empty or 1-d shape) is supported.
      if (tensor1->refCount == 999 && (tensor1->dims.empty() || tensor1->dims.size() == 1)) {
        auto quantParam = GetTensorQuantParam(tensor1);
        auto min = quantParam->min;
        auto max = quantParam->max;
        if (tensor1->dataType == TypeId::kNumberTypeFloat) {
          MS_ASSERT(tensor1->data.size() == sizeof(float) / sizeof(uint8_t));
          void *oriTensorData = tensor1->data.data();
          auto *div = static_cast<float *>(oriTensorData);
          MS_ASSERT(*div != 0);
          status = quant::CalQuantizationParams(outQuantParam.get(), min / (*div), max / (*div));
          if (status != RET_OK) {
            // MS_LOGW("in aware quantization run CalQuantizationParams failed!");
            return RET_ERROR;
          }
        } else if (tensor1->dataType == TypeId::kNumberTypeUInt8) {
          MS_ASSERT(tensor1->data.size() == 1);
          void *oriTensorData = tensor1->data.data();
          auto *div = static_cast<uint8_t *>(oriTensorData);
          // fix: guard the uint8 divisor against zero, same as the float branch
          MS_ASSERT(*div != 0);
          // fix: copy-paste error -- the max bound must be divided, not incremented
          status = quant::CalQuantizationParams(outQuantParam.get(), min / (*div), max / (*div));
          if (status != RET_OK) {
            // MS_LOGW("in aware quantization run CalQuantizationParams failed!");
            return RET_ERROR;
          }
        } else {
          // MS_LOGW("Unsupported tensor dataType: %d", tensor1->dataType);
          return RET_ERROR;
        }
        // fix: persist the computed quant param back into the output tensor;
        // the computed result was previously dropped.
        outTensor->quantParams.front() = std::move(outQuantParam);
      } else {
        // MS_LOGW("Can not determine realDiv outputTensor quantParam, node %s", node.name.c_str());
        return RET_ERROR;
      }
    }
    return RET_OK;
  }
};

// Calcer that pins the output of an op to a fixed [min, max] window
// (e.g. Sigmoid / SoftMax always produce values in [0, 1]).
class CalcToSet : public QuantParamCalcer {
 public:
  CalcToSet(float min, float max) : min(min), max(max) {}

  int Calc(MetaGraphT *graph, const CNodeT &node) override {
    MS_ASSERT(node.inputIndex.size() == 1);
    MS_ASSERT(node.outputIndex.size() == 1);
    auto ret = QuantParamCalcer::Calc(graph, node);
    if (ret != RET_OK) {
      // base quant-param collection failed
      return ret;
    }
    // All input quant params must already be determined.
    if (node.inputIndex.size() != inputParamDone) {
      return RET_ERROR;
    }
    // Build the output quant param straight from the fixed range.
    std::unique_ptr<QuantParamT> param(new (std::nothrow) QuantParamT());
    if (param == nullptr) {
      return RET_ERROR;
    }
    param->scale = (max - min) / 256;
    MS_ASSERT(param->scale != 0);
    // NOTE(review): zero point derived as 256 - max/scale -- confirm this
    // matches the uint8 quantization convention used elsewhere in the file.
    param->zeroPoint = int32_t(std::round(256 - max / param->scale));
    param->min = min;
    param->max = max;
    param->inited = true;
    MS_ASSERT(graph->allTensors.size() > node.outputIndex.front());
    auto &outTensor = graph->allTensors.at(node.outputIndex.front());
    MS_ASSERT(outTensor != nullptr);
    outTensor->quantParams.front() = std::move(param);
    return RET_OK;
  }

 protected:
  float min;
  float max;
};

// Calcer for Activation: SIGMOID output is always in [0, 1] so its quant param
// can be set directly; every other activation behaves like a common op.
class CalcActivation : public QuantParamCalcer {
 public:
  CalcActivation() = default;

  int Calc(MetaGraphT *subGraph, const CNodeT &node) override {
    MS_ASSERT(node.inputIndex.size() == 1);
    MS_ASSERT(node.outputIndex.size() == 1);
    // fix: the attribute actually lives in node.primitive->value (dereferenced
    // below), not in node.attr -- assert on what is really accessed.
    MS_ASSERT(node.primitive != nullptr);
    MS_ASSERT(node.primitive->value.AsActivation() != nullptr);
    if (node.primitive->value.AsActivation()->type == schema::ActivationType_SIGMOID) {
      auto calcToSet = CalcToSet(0, 1);
      return calcToSet.Calc(subGraph, node);
    } else {
      auto calCommon = CommonCalcer();
      return calCommon.Calc(subGraph, node);
    }
  }
};

// Builds the op-type -> calcer table. Shared calcer instances are reused
// across several op types; the registry owns all calcers for the process
// lifetime (it is a singleton, see GetInstance()).
QuantParamCalcRegister::QuantParamCalcRegister() {
  bool hasError = false;
  auto baseCalcer = new (std::nothrow) QuantParamCalcer();
  if (baseCalcer == nullptr) {
    // MS_LOGW("new QuantParamCalcer failed");
    hasError = true;
  }
  auto commonCalcer = new (std::nothrow) CommonCalcer();
  if (commonCalcer == nullptr) {
    // MS_LOGW("new commonCalcer failed");
    hasError = true;
  }
  auto linearCalcer = new (std::nothrow) LinearCalcer();
  if (linearCalcer == nullptr) {
    // MS_LOGW("new linearCalcer failed");
    hasError = true;
  }
  if (hasError) {
    // fix: release whichever calcers were successfully allocated so a
    // partially-constructed registry does not leak them.
    delete baseCalcer;
    delete commonCalcer;
    delete linearCalcer;
    return;
  }
  _registerMap[schema::PrimitiveType_Concat] = new CalcConcat();
  _registerMap[schema::PrimitiveType_Activation] = new CalcActivation();
  _registerMap[schema::PrimitiveType_Add] = new CalcAdd();
  _registerMap[schema::PrimitiveType_Mul] = commonCalcer;
  _registerMap[schema::PrimitiveType_Conv2D] = commonCalcer;
  _registerMap[schema::PrimitiveType_DepthwiseConv2D] = commonCalcer;
  _registerMap[schema::PrimitiveType_Pooling] = linearCalcer;
  _registerMap[schema::PrimitiveType_Resize] = linearCalcer;
  _registerMap[schema::PrimitiveType_Reshape] = linearCalcer;
  _registerMap[schema::PrimitiveType_Shape] = linearCalcer;  // todo if shape influence postNode's output quantParam
  _registerMap[schema::PrimitiveType_SoftMax] = new CalcToSet(0, 1);
  _registerMap[schema::PrimitiveType_Squeeze] = linearCalcer;
  _registerMap[schema::PrimitiveType_RealDiv] = new CalcRealDiv();
  _registerMap[schema::PrimitiveType_Reduce] = commonCalcer;
  _registerMap[schema::PrimitiveType_BiasAdd] = commonCalcer;
  _registerMap[schema::PrimitiveType_Mean] = linearCalcer;
  _registerMap[schema::PrimitiveType_Transpose] = linearCalcer;
  _registerMap[schema::PrimitiveType_MatMul] = commonCalcer;
  _registerMap[schema::PrimitiveType_FullConnection] = commonCalcer;
  _registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer;
  _registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer;
  // todo
  // detection_postprocess op's quant param will not infer only fetch from preNode or postNode
  // because we will not insert quantTransNode after this node in tflite_graph_8bit model if input data is float.
  // if quantTransNode is inserted after detection_postprocess node, there will be some errors
  _registerMap[schema::PrimitiveType_DetectionPostProcess] = baseCalcer;
}

// Returns the process-wide singleton registry. The function-local static is
// constructed lazily and is thread-safe under C++11 magic-statics rules.
QuantParamCalcRegister *QuantParamCalcRegister::GetInstance() {
  static QuantParamCalcRegister instance;
  return &instance;
}

// Looks up the calcer registered for `opType`.
// Returns nullptr when the op type has no registered calcer.
QuantParamCalcer *QuantParamCalcRegister::GetQuantParamCalcer(schema::PrimitiveType opType) {
  const auto iter = _registerMap.find(opType);
  return (iter == _registerMap.end()) ? nullptr : iter->second;
}
} // namespace mindspore::lite

+ 69
- 0
mindspore/lite/tools/converter/quantizer/calc_quant_param.h View File

@@ -0,0 +1,69 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef CALC_QUANT_PARAM_H
#define CALC_QUANT_PARAM_H

#include <unordered_map>
#include <memory>
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"

namespace mindspore {
namespace lite {
static constexpr int CONVLUTION_INPUT_NUM = 3;

// Base quant-param calculator: collects already-inited quant params of a
// node's input/output tensors and computes params for const float inputs.
// Subclasses implement op-specific propagation in Calc().
class QuantParamCalcer {
 public:
  virtual ~QuantParamCalcer() = default;
  // Runs the base collection; updates inputParamDone/outputParamDone.
  // Returns RET_OK on success, an error status otherwise.
  virtual int Calc(schema::MetaGraphT *graph, const schema::CNodeT &node);

 protected:
  // Derives a quant param from a const float tensor's min/max value range.
  STATUS ComputeConstQuantParam(const schema::TensorT &tensor, schema::QuantParamT *quantParam);

 protected:
  // Number of input/output tensors whose quant param was already determined
  // during the most recent Calc() call.
  size_t inputParamDone = 0;
  size_t outputParamDone = 0;
};

// Calcer for ops that require every input and output quant param to be
// determined up front; Calc() fails if any is still missing.
class CommonCalcer : public QuantParamCalcer {
 public:
  CommonCalcer() = default;
  ~CommonCalcer() override = default;
  int Calc(schema::MetaGraphT *subGraph, const schema::CNodeT &node) override;
};

// Calcer for range-preserving ("linear") ops: a quant param missing on one
// side of the node is filled in from the other side.
class LinearCalcer : public QuantParamCalcer {
 public:
  LinearCalcer() = default;
  ~LinearCalcer() override = default;
  int Calc(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
};

// Singleton registry mapping a primitive (op) type to its quant-param calcer.
// Owns the registered calcer instances for the process lifetime.
class QuantParamCalcRegister {
 public:
  virtual ~QuantParamCalcRegister() = default;
  // Returns the calcer registered for opType, or nullptr if none exists.
  QuantParamCalcer *GetQuantParamCalcer(schema::PrimitiveType opType);
  static QuantParamCalcRegister *GetInstance();

 private:
  QuantParamCalcRegister();
  std::unordered_map<schema::PrimitiveType, QuantParamCalcer *> _registerMap;
};
} // namespace lite
} // namespace mindspore

#endif

+ 230
- 169
mindspore/lite/tools/converter/quantizer/quantize_util.cc View File

@@ -39,126 +39,127 @@ QuantStrategy::QuantStrategy(size_t weightSize, size_t convWeightQuantChannelThr
: mWeightSize(weightSize), mConvWeightQuantChannelThreshold(convWeightQuantChannelThreshold) {} : mWeightSize(weightSize), mConvWeightQuantChannelThreshold(convWeightQuantChannelThreshold) {}


bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const { bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const {
size_t i = 0;
for (i = 0; i < mConvTypes.size(); i++) {
if (node->fullname_with_scope().find(mConvTypes[i]) == 0) {
break;
}
size_t i = 0;
for (i = 0; i < mConvTypes.size(); i++) {
if (node->fullname_with_scope().find(mConvTypes[i]) == 0) {
break;
} }
}


if ((i == mConvTypes.size()) || (node->size() < 3)) {
return false;
}
if ((i == mConvTypes.size()) || (node->size() < 3)) {
return false;
}


auto inputNode = node->input(2);
if (!inputNode->isa<Parameter>()) {
return false;
}
auto paramNode = inputNode->cast<ParameterPtr>();
auto abstract_base = paramNode->abstract();
if (abstract_base == nullptr) {
return false;
}
auto inputNode = node->input(2);
if (!inputNode->isa<Parameter>()) {
return false;
}
auto paramNode = inputNode->cast<ParameterPtr>();
auto abstract_base = paramNode->abstract();
if (abstract_base == nullptr) {
return false;
}


if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name();
return false;
}
auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
size_t shapeSize = 1;
for (auto dim : weight_shape) {
shapeSize = shapeSize * dim;
}
if (shapeSize < mWeightSize) {
MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize;
return false;
}
if (weight_shape[0] <= mConvWeightQuantChannelThreshold) {
MS_LOG(INFO) << "channel less mConvWeightQuantChannelThreshold!" << weight_shape[0];
return false;
}
if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name();
return false;
}
auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
size_t shapeSize = 1;
for (auto dim : weight_shape) {
shapeSize = shapeSize * dim;
}
if (shapeSize < mWeightSize) {
MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize;
return false;
}
if (weight_shape[0] <= mConvWeightQuantChannelThreshold) {
MS_LOG(INFO) << "channel less mConvWeightQuantChannelThreshold!" << weight_shape[0];
return false;
}


return true;
return true;
} }


bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const { bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
if (!node->isa<CNode>()) {
return false;
}
auto cnode = std::dynamic_pointer_cast<CNode>(node);
if (!node->isa<CNode>()) {
return false;
}
auto cnode = std::dynamic_pointer_cast<CNode>(node);


auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
if (primitiveT_value == nullptr) {
MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope();
return false;
}
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
if (primitiveT_value == nullptr) {
MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope();
return false;
}


auto type = primitiveT_value->GetPrimitiveT()->value.type;
MS_LOG(INFO) << "Primitive type: " << type;
static const std::vector<schema::PrimitiveType> uint8OpList = {
schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D,
schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
schema::PrimitiveType_Concat, /*schema::PrimitiveType_SoftMax,*/ schema::PrimitiveType_Reshape,
schema::PrimitiveType_Activation};
return IsContain(uint8OpList, type);
auto type = primitiveT_value->GetPrimitiveT()->value.type;
MS_LOG(INFO) << "Primitive type: " << type;
static const std::vector<schema::PrimitiveType> uint8OpList = {
schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw,
schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D,
schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
schema::PrimitiveType_Concat, /*schema::PrimitiveType_SoftMax,*/ schema::PrimitiveType_Reshape,
schema::PrimitiveType_Activation};
return IsContain(uint8OpList, type);
} }


bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const { bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const {
size_t i = 0;
for (i = 0; i < mMulTypes.size(); i++) {
if (node->fullname_with_scope().find(mMulTypes[i]) == 0) {
break;
}
}
if (i == mMulTypes.size()) {
return false;
size_t i = 0;
for (i = 0; i < mMulTypes.size(); i++) {
if (node->fullname_with_scope().find(mMulTypes[i]) == 0) {
break;
} }
}
if (i == mMulTypes.size()) {
return false;
}


if (node->size() < 3) {
MS_LOG(INFO) << "input size less!";
return false;
}
if (node->size() < 3) {
MS_LOG(INFO) << "input size less!";
return false;
}


auto inputNode1 = node->input(1);
auto inputNode2 = node->input(2);
if (inputNode1 == nullptr || inputNode2 == nullptr) {
MS_LOG(INFO) << "mul input is nullptr!";
return false;
}
auto inputNode1 = node->input(1);
auto inputNode2 = node->input(2);
if (inputNode1 == nullptr || inputNode2 == nullptr) {
MS_LOG(INFO) << "mul input is nullptr!";
return false;
}


ParameterPtr paramNode = nullptr;
if (inputNode1->isa<Parameter>()) {
paramNode = inputNode1->cast<ParameterPtr>();
} else if (inputNode2->isa<Parameter>()) {
paramNode = inputNode2->cast<ParameterPtr>();
}
ParameterPtr paramNode = nullptr;
if (inputNode1->isa<Parameter>()) {
paramNode = inputNode1->cast<ParameterPtr>();
} else if (inputNode2->isa<Parameter>()) {
paramNode = inputNode2->cast<ParameterPtr>();
}


if (paramNode == nullptr) {
MS_LOG(INFO) << "invalid paramNode!";
return false;
}
if (paramNode == nullptr) {
MS_LOG(INFO) << "invalid paramNode!";
return false;
}


auto abstract_base = paramNode->abstract();
if (abstract_base == nullptr) {
MS_LOG(INFO) << "abstract is nullptr";
return false;
}
auto abstract_base = paramNode->abstract();
if (abstract_base == nullptr) {
MS_LOG(INFO) << "abstract is nullptr";
return false;
}


if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name();
return false;
}
auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
size_t shapeSize = 1;
for (auto dim : weight_shape) {
shapeSize = shapeSize * dim;
}
if (shapeSize < mWeightSize) {
MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize;
return false;
}
if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name();
return false;
}
auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
size_t shapeSize = 1;
for (auto dim : weight_shape) {
shapeSize = shapeSize * dim;
}
if (shapeSize < mWeightSize) {
MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize;
return false;
}


return true;
return true;
} }


void CalFakeNode(const AnfNodePtr &inTensor) { void CalFakeNode(const AnfNodePtr &inTensor) {
@@ -190,56 +191,119 @@ void CalFakeNode(const AnfNodePtr &inTensor) {
// } // }
} }


STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double mMin,
double mMax, bool narrowRange, int quant_max, int quant_min, int num_bits) {
MS_ASSERT(quantParam != nullptr);
if (mMin > 0.0f) {
MS_LOG(ERROR) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision";
mMin = 0.0f;
}
if (mMax < 0.0f) {
MS_LOG(ERROR) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision";
mMax = 0.0f;
}
if (mMin > mMax) {
MS_LOG(ERROR) << "cal error while min" << mMin << ">" << mMax;
return RET_PARAM_INVALID;
}
if (mMin == mMax) {
if (mMin != 0.0f) {
MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other";
return RET_ERROR;
}
quantParam->inited = true;
quantParam->min = mMin;
quantParam->max = mMax;
quantParam->scale = 0.0f;
quantParam->zeroPoint = 0;
quantParam->narrowRange = narrowRange;
quantParam->numBits = num_bits;
return RET_OK;
STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double mMin, double mMax, bool narrowRange,
int quant_max, int quant_min, int num_bits) {
MS_ASSERT(quantParam != nullptr);
if (mMin > 0.0f) {
MS_LOG(ERROR) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision";
mMin = 0.0f;
}
if (mMax < 0.0f) {
MS_LOG(ERROR) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision";
mMax = 0.0f;
}
if (mMin > mMax) {
MS_LOG(ERROR) << "cal error while min" << mMin << ">" << mMax;
return RET_PARAM_INVALID;
}
if (mMin == mMax) {
if (mMin != 0.0f) {
MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other";
return RET_ERROR;
} }

auto quantMinFloat = static_cast<double>(quant_min);
auto quantMaxFloat = static_cast<double>(quant_max);
double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
const double zeroPointFromMin = quantMinFloat - mMin / scale;
// const double zeroPointFromMax = quantMaxFloat - mMax / scale;
int zeroPoint = static_cast<int32_t>(std::round(zeroPointFromMin));

// The zero point should always be in the range of quantized value,
// [qmin, qmax].
MS_ASSERT(zeroPoint >= quantMin);
MS_ASSERT(zeroPoint <= quantMax);
quantParam->inited = true; quantParam->inited = true;
quantParam->min = mMin; quantParam->min = mMin;
quantParam->max = mMax; quantParam->max = mMax;
quantParam->scale = scale;
quantParam->zeroPoint = zeroPoint;
quantParam->scale = 0.0f;
quantParam->zeroPoint = 0;
quantParam->narrowRange = narrowRange; quantParam->narrowRange = narrowRange;
quantParam->numBits = num_bits; quantParam->numBits = num_bits;
return RET_OK;
}

auto quantMinFloat = static_cast<double>(quant_min);
auto quantMaxFloat = static_cast<double>(quant_max);
double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
const double zeroPointFromMin = quantMinFloat - mMin / scale;
// const double zeroPointFromMax = quantMaxFloat - mMax / scale;
int zeroPoint = static_cast<int32_t>(std::round(zeroPointFromMin));

// The zero point should always be in the range of quantized value,
// [qmin, qmax].
MS_ASSERT(zeroPoint >= quantMin);
MS_ASSERT(zeroPoint <= quantMax);
quantParam->inited = true;
quantParam->min = mMin;
quantParam->max = mMax;
quantParam->scale = scale;
quantParam->zeroPoint = zeroPoint;
quantParam->narrowRange = narrowRange;
quantParam->numBits = num_bits;

return RET_OK;
}


STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax,
bool narrowRange, int numBits) {
MS_ASSERT(quantParam != nullptr);
if (mMin > 0.0f) {
MS_LOG(ERROR) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision";
mMin = 0.0f;
}
if (mMax < 0.0f) {
MS_LOG(ERROR) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision";
mMax = 0.0f;
}
if (mMin > mMax) {
MS_LOG(ERROR) << "cal error while min" << mMin << ">" << mMax;
return RET_PARAM_INVALID;
}
if (mMin == mMax) {
if (mMin != 0.0f) {
MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other";
return RET_ERROR;
}
quantParam->inited = true;
quantParam->min = mMin;
quantParam->max = mMax;
quantParam->scale = 0.0f;
quantParam->zeroPoint = 0;
quantParam->narrowRange = narrowRange;
quantParam->numBits = numBits;
return RET_OK; return RET_OK;
}

int quantMin = narrowRange ? 1 : 0;
int quantMax = (1 << (unsigned int)numBits) - 1;
auto quantMinFloat = static_cast<double>(quantMin);
auto quantMaxFloat = static_cast<double>(quantMax);
double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
const double zeroPointFromMin = quantMinFloat - mMin / scale;
const double zeroPointFromMax = quantMaxFloat - mMax / scale;
const double zpFromMinError = std::abs(quantMinFloat) + std::abs(mMin / scale);
const double zpFromMaxError = std::abs(quantMaxFloat) + std::abs(mMax / scale);
const double zpDouble = zpFromMinError < zpFromMaxError ? zeroPointFromMin : zeroPointFromMax;
int zeroPoint;
if (zpDouble < quantMinFloat) {
zeroPoint = quantMin;
} else if (zpDouble > quantMaxFloat) {
zeroPoint = quantMax;
} else {
zeroPoint = static_cast<int32_t>(std::round(zpDouble));
}
// The zero point should always be in the range of quantized value,
// [qmin, qmax].
MS_ASSERT(zeroPoint >= quantMin);
MS_ASSERT(zeroPoint <= quantMax);
quantParam->inited = true;
quantParam->min = mMin;
quantParam->max = mMax;
quantParam->scale = scale;
quantParam->zeroPoint = zeroPoint;
quantParam->narrowRange = narrowRange;
quantParam->numBits = numBits;

return RET_OK;
} }


STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum, STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum,
@@ -292,14 +356,14 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_


weightPtr->set_quant_param(quantParam); weightPtr->set_quant_param(quantParam);
} }
auto ret = memcpy_s(const_cast<float*>(rawDatas), weightPtr->tensor_size(),
qDatas.data(), shapeSize * sizeof(int8_t));
auto ret =
memcpy_s(const_cast<float *>(rawDatas), weightPtr->tensor_size(), qDatas.data(), shapeSize * sizeof(int8_t));
if (ret != EOK) { if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret; MS_LOG(ERROR) << "memcpy error: " << ret;
return RET_ERROR; return RET_ERROR;
} }
if (quantType == QuantType_WeightQuant) { if (quantType == QuantType_WeightQuant) {
PostBitPack(const_cast<float*>(rawDatas), shapeSize, bitNum);
PostBitPack(const_cast<float *>(rawDatas), shapeSize, bitNum);
} }


weightPtr->set_tensor_type(kNumberTypeInt8); weightPtr->set_tensor_type(kNumberTypeInt8);
@@ -338,14 +402,13 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
qDatas[i] = quant_max; qDatas[i] = quant_max;
} else if (quant_data < quant_min) { } else if (quant_data < quant_min) {
qDatas[i] = quant_min; qDatas[i] = quant_min;
} else {
} else {
qDatas[i] = static_cast<int8_t>(quant_data); qDatas[i] = static_cast<int8_t>(quant_data);
} }
} }


weightPtr->set_quant_param(quantParam); weightPtr->set_quant_param(quantParam);
auto ret = memcpy_s(rawDatas, weightPtr->tensor_size(),
qDatas.data(), shapeSize * sizeof(int8_t));
auto ret = memcpy_s(rawDatas, weightPtr->tensor_size(), qDatas.data(), shapeSize * sizeof(int8_t));
if (ret != EOK) { if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret; MS_LOG(ERROR) << "memcpy error: " << ret;
return RET_ERROR; return RET_ERROR;
@@ -358,34 +421,32 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
weightPtr->set_tensor_size(shapeSize * sizeof(int8_t)); weightPtr->set_tensor_size(shapeSize * sizeof(int8_t));
} }



return RET_OK;
return RET_OK;
} }


STATUS PostBitPack(float *weight, size_t shapeSize, size_t bitNum) { STATUS PostBitPack(float *weight, size_t shapeSize, size_t bitNum) {
auto *rawDatas = reinterpret_cast<uint8_t *>(weight);
vector<uint8_t> qDatas(rawDatas, rawDatas + shapeSize);
vector<uint8_t> qDatas_packed;
if (bitNum < 8 && bitNum > 1) {
BitPack weight_bitpack(bitNum);
weight_bitpack.BitPacking(qDatas, qDatas_packed);
if (EOK != memcpy_s(rawDatas, shapeSize, &qDatas_packed[0], shapeSize)) {
MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas_packed failed";
return RET_ERROR;
}
} else if (bitNum == 8) {
if (EOK != memcpy_s(rawDatas, shapeSize, &qDatas[0], shapeSize)) {
MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas failed";
return RET_ERROR;
}
} else {
MS_LOG(ERROR) << "bitNum must be between 0 and 8 : " << bitNum;
return RET_ERROR;
auto *rawDatas = reinterpret_cast<uint8_t *>(weight);
vector<uint8_t> qDatas(rawDatas, rawDatas + shapeSize);
vector<uint8_t> qDatas_packed;
if (bitNum < 8 && bitNum > 1) {
BitPack weight_bitpack(bitNum);
weight_bitpack.BitPacking(qDatas, qDatas_packed);
if (EOK != memcpy_s(rawDatas, shapeSize, &qDatas_packed[0], shapeSize)) {
MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas_packed failed";
return RET_ERROR;
}
} else if (bitNum == 8) {
if (EOK != memcpy_s(rawDatas, shapeSize, &qDatas[0], shapeSize)) {
MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas failed";
return RET_ERROR;
} }
} else {
MS_LOG(ERROR) << "bitNum must be between 0 and 8 : " << bitNum;
return RET_ERROR;
}


return RET_OK;
return RET_OK;
} }
} // namespace quant } // namespace quant
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


+ 35
- 0
mindspore/lite/tools/converter/quantizer/quantize_util.h View File

@@ -62,6 +62,41 @@ class QuantStrategy {
STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double mMin, double mMax, STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double mMin, double mMax,
bool narrowRange, int quant_max, int quant_min, int num_bits); bool narrowRange, int quant_max, int quant_min, int num_bits);


STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax,
bool narrowRange = false, int numBits = UINT8_QUANTIZATION);

template <typename T>
T QuantizeData(const float originData, const schema::QuantParamT *quantParam) {
MS_ASSERT(quantParam != nullptr);
MS_ASSERT(quantParam->inited);
const auto scale = quantParam->scale;
const auto zeroPoint = quantParam->zeroPoint;
const auto numBit = quantParam->numBits;
const auto narrowRange = quantParam->narrowRange;
const double maxLimit = static_cast<float>((1 << (unsigned int)numBit) - 1 - zeroPoint) * scale;
double minLimit;
if (narrowRange) {
minLimit = static_cast<float>(1 - zeroPoint) * scale;
} else {
minLimit = static_cast<float>(0 - zeroPoint) * scale;
}
return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] {
double tmp = 0.0f;
if (originData > maxLimit) {
tmp = maxLimit;
} else if (originData < minLimit) {
tmp = minLimit;
} else {
tmp = originData;
}
auto quantData = static_cast<T>(std::round(tmp / scale + zeroPoint));
if (quantData == 0 && narrowRange) {
quantData++;
}
return quantData;
}();
}

template <typename T> template <typename T>
T QuantizeData(float originData, const AnfQuantParam *quantParam, int quant_max, int quant_min) { T QuantizeData(float originData, const AnfQuantParam *quantParam, int quant_max, int quant_min) {
MS_ASSERT(quantParam != nullptr); MS_ASSERT(quantParam != nullptr);


+ 8
- 11
mindspore/lite/tools/converter/quantizer/quantizer.cc View File

@@ -15,22 +15,19 @@
*/ */


#include "mindspore/lite/tools/converter/quantizer/quantizer.h" #include "mindspore/lite/tools/converter/quantizer/quantizer.h"
#include "schema/inner/model_generated.h"


namespace mindspore {
namespace lite {
namespace quant {
Quantizer::Quantizer(FuncGraphPtr graph) : funcGraph(graph) {
if (funcGraph == nullptr) {
return;
}
}
namespace mindspore::lite::quant {


STATUS Quantizer::GenerateQuantParam() { return RET_OK; } STATUS Quantizer::GenerateQuantParam() { return RET_OK; }


STATUS Quantizer::RemoveFakeQuant() { return RET_OK; } STATUS Quantizer::RemoveFakeQuant() { return RET_OK; }


STATUS Quantizer::DetermineNodeQuantType() { return RET_OK; } STATUS Quantizer::DetermineNodeQuantType() { return RET_OK; }
} // namespace quant
} // namespace lite
} // namespace mindspore


STATUS FbQuantizer::GenerateQuantParam() { return RET_OK; }

STATUS FbQuantizer::RemoveFakeQuant() { return RET_OK; }

STATUS FbQuantizer::DetermineNodeQuantType() { return RET_OK; }
} // namespace mindspore::lite::quant

+ 36
- 21
mindspore/lite/tools/converter/quantizer/quantizer.h View File

@@ -18,48 +18,63 @@
#define MS_QUANTIZER_H #define MS_QUANTIZER_H


#include <unordered_map> #include <unordered_map>
#include <utility>
#include <memory>
#include "include/errorcode.h" #include "include/errorcode.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/anf.h" #include "ir/anf.h"
#include "include/model.h"
#include "base/base.h" #include "base/base.h"
#include "src/param_value_lite.h" #include "src/param_value_lite.h"
#include "schema/inner/model_generated.h"
#include "tools/converter/converter_flags.h" #include "tools/converter/converter_flags.h"


namespace mindspore {
namespace lite {
namespace quant {
namespace mindspore::lite::quant {
using STATUS = int; using STATUS = int;
enum QuantType { enum QuantType {
QuantType_QUANT_NONE = 0,
QuantType_AwareTraining = 1,
QuantType_WeightQuant = 2,
QuantType_PostTraining = 3,
QuantType_MIN = QuantType_QUANT_NONE,
QuantType_MAX = QuantType_PostTraining
QuantType_QUANT_NONE = 0,
QuantType_AwareTraining = 1,
QuantType_WeightQuant = 2,
QuantType_PostTraining = 3,
QuantType_MIN = QuantType_QUANT_NONE,
QuantType_MAX = QuantType_PostTraining
}; };


class Quantizer { class Quantizer {
public: public:
explicit Quantizer(FuncGraphPtr graph);
explicit Quantizer(FuncGraphPtr graph) : funcGraph(std::move(graph)) {}


~Quantizer() = default;
~Quantizer() = default;


virtual STATUS RemoveFakeQuant();
virtual STATUS RemoveFakeQuant();


virtual STATUS GenerateQuantParam();
virtual STATUS GenerateQuantParam();


virtual STATUS DetermineNodeQuantType();
virtual STATUS DetermineNodeQuantType();


virtual STATUS DoQuantize(FuncGraphPtr funcGraph) = 0;
virtual STATUS DoQuantize(FuncGraphPtr funcGraph) = 0;


mindspore::lite::converter::Flags flags; mindspore::lite::converter::Flags flags;
protected: protected:
FuncGraphPtr funcGraph = nullptr;
FuncGraphPtr funcGraph = nullptr;
}; };
} // namespace quant
} // namespace lite
} // namespace mindspore


#endif
class FbQuantizer {
public:
explicit FbQuantizer(schema::MetaGraphT *graph) : graph(graph) {}

~FbQuantizer() = default;

virtual STATUS RemoveFakeQuant();

virtual STATUS GenerateQuantParam();

virtual STATUS DetermineNodeQuantType();


virtual STATUS DoQuantize() = 0;

protected:
std::shared_ptr<schema::MetaGraphT> graph = nullptr;
};
} // namespace mindspore::lite::quant

#endif

Loading…
Cancel
Save