From a0334145d224f3a7afe85a0a2f34e4f6d6f99ff2 Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Sat, 24 Oct 2020 14:38:58 +0800 Subject: [PATCH] bug fix --- .../lite/nnacl/int8/quant_dtype_cast_int8.c | 39 ++++- .../lite/nnacl/int8/quant_dtype_cast_int8.h | 10 +- mindspore/lite/src/common/utils.h | 34 ---- .../kernel/arm/base/quant_dtype_cast.cc | 36 +++- mindspore/lite/src/sub_graph_kernel.cc | 2 +- mindspore/lite/src/sub_graph_kernel.h | 34 ++++ .../lite/tools/converter/anf_transform.cc | 4 + .../lite/tools/converter/converter_flags.cc | 4 +- .../tools/converter/graphdef_transform.cc | 13 +- .../graph/dtype_trans_pass.cc | 54 +++--- .../legacy_optimizer/graph/dtype_trans_pass.h | 2 +- .../graph/tensor_quant_pass.cc | 1 + .../converter/quantizer/aware_quantizer.cc | 161 ------------------ .../converter/quantizer/aware_quantizer.h | 44 ----- .../converter/quantizer/calc_quant_param.cc | 2 +- 15 files changed, 140 insertions(+), 300 deletions(-) delete mode 100644 mindspore/lite/tools/converter/quantizer/aware_quantizer.cc delete mode 100644 mindspore/lite/tools/converter/quantizer/aware_quantizer.h diff --git a/mindspore/lite/nnacl/int8/quant_dtype_cast_int8.c b/mindspore/lite/nnacl/int8/quant_dtype_cast_int8.c index ad532d57fa..14fb8db68b 100644 --- a/mindspore/lite/nnacl/int8/quant_dtype_cast_int8.c +++ b/mindspore/lite/nnacl/int8/quant_dtype_cast_int8.c @@ -18,7 +18,7 @@ #include "nnacl/int8/quant_dtype_cast_int8.h" #include "nnacl/errorcode.h" -int DoDequantizeInt8ToFp32(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size) { +int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float scale, int32_t zp, int size) { if (quant_values == NULL || real_values == NULL) { return NNACL_PARAM_INVALID; } @@ -29,13 +29,13 @@ int DoDequantizeInt8ToFp32(int8_t *quant_values, float *real_values, float scale return NNACL_OK; } -int DoQuantizeToInt8FromFp32(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size) { +int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size) { if (quant_values == NULL || real_values == NULL) { return NNACL_PARAM_INVALID; } for (int i = 0; i < size; ++i) { - float temp = round(real_values[i] * 1.0 / scale + zp); + float temp = (float)round(real_values[i] * 1.0 / scale + zp); if (temp > 127) { quant_values[i] = 127; } else if (temp < -128) { @@ -47,7 +47,36 @@ int DoQuantizeToInt8FromFp32(float *real_values, int8_t *quant_values, float sca return NNACL_OK; } -int DoDequantizeInt8ToUInt8(int8_t *quant_values, uint8_t *real_values, int size) { +int DoDequantizeUInt8ToFp32(const uint8_t *quant_values, float *real_values, float scale, int32_t zp, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + + for (int i = 0; i < size; ++i) { + real_values[i] = (float)((int)quant_values[i] - zp) * scale; + } + return NNACL_OK; +} + +int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float scale, int32_t zp, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + + for (int i = 0; i < size; ++i) { + float temp = (float)round(real_values[i] * 1.0 / scale + zp); + if (temp > 255) { + quant_values[i] = 255; + } else if (temp < 0) { + quant_values[i] = 0; + } else { + quant_values[i] = (uint8_t)temp; + } + } + return NNACL_OK; +} + +int Int8ToUInt8(const int8_t *quant_values, uint8_t *real_values, int size) { if (quant_values == NULL || real_values == NULL) { return NNACL_PARAM_INVALID; } @@ -65,7 +94,7 @@ int DoDequantizeInt8ToUInt8(int8_t *quant_values, uint8_t *real_values, int size return NNACL_OK; } -int DoQuantizeToInt8FromUint8(uint8_t *real_values, int8_t *quant_values, int size) { +int UInt8ToInt8(const uint8_t *real_values, int8_t *quant_values, int size) { if (quant_values == NULL || real_values == NULL) { return NNACL_PARAM_INVALID; } diff --git a/mindspore/lite/nnacl/int8/quant_dtype_cast_int8.h b/mindspore/lite/nnacl/int8/quant_dtype_cast_int8.h index 941d1952e6..e5e843f9ec 100644 --- a/mindspore/lite/nnacl/int8/quant_dtype_cast_int8.h +++ b/mindspore/lite/nnacl/int8/quant_dtype_cast_int8.h @@ -28,10 +28,12 @@ typedef struct QuantDTypeCastParameter { #ifdef __cplusplus extern "C" { #endif -int DoDequantizeInt8ToFp32(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size); -int DoQuantizeToInt8FromFp32(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size); -int DoDequantizeInt8ToUInt8(int8_t *quant_values, uint8_t *real_values, int size); -int DoQuantizeToInt8FromUint8(uint8_t *real_values, int8_t *quant_values, int size); +int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float scale, int32_t zp, int size); +int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size); +int DoDequantizeUInt8ToFp32(const uint8_t *quant_values, float *real_values, float scale, int32_t zp, int size); +int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float scale, int32_t zp, int size); +int Int8ToUInt8(const int8_t *quant_values, uint8_t *real_values, int size); +int UInt8ToInt8(const uint8_t *real_values, int8_t *quant_values, int size); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/src/common/utils.h b/mindspore/lite/src/common/utils.h index fbfb9ab04b..245b32ece3 100644 --- a/mindspore/lite/src/common/utils.h +++ b/mindspore/lite/src/common/utils.h @@ -27,9 +27,6 @@ #include "src/common/log_adapter.h" #include "tools/common/option.h" #include "include/errorcode.h" -#ifdef ENABLE_ARM64 -#include "nnacl/optimized_kernel.h" -#endif namespace mindspore { namespace lite { @@ -190,37 +187,6 @@ inline Option GenericParseValue(const std::string &value) { return Option(None()); } -using Float16CastFunc = void (*)(const void *, void *, int); - -class Float16CastUtil { - public: - static Float16CastUtil *GetInstance() { - static Float16CastUtil float16_cast_util; - return &float16_cast_util; - } - - private: - Float16CastUtil() { -#ifdef ENABLE_ARM64 - void *fp16_op_handler = Float16Module::GetInstance()->float16_op_handler_; - if (fp16_op_handler != nullptr) { - dlerror(); - *(reinterpret_cast(&float16_to_float32_func_)) = dlsym(fp16_op_handler, "Float16ToFloat32_fp16_handler"); - *(reinterpret_cast(&float32_to_float16_func_)) = dlsym(fp16_op_handler, "Float32ToFloat16_fp16_handler"); - auto dlopen_error = dlerror(); - if (dlopen_error != nullptr) { - MS_LOG(ERROR) << "load float16 cast func failed! " << dlopen_error << "."; - } - } -#endif - } - ~Float16CastUtil() = default; - - public: - Float16CastFunc float16_to_float32_func_ = nullptr; - Float16CastFunc float32_to_float16_func_ = nullptr; -}; - } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc b/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc index d6fb0f919c..bbfd9f9409 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc @@ -66,6 +66,16 @@ int QuantDTypeCastCPUKernel::Init() { MS_LOG(ERROR) << "param data type and tensor data type do not match."; return RET_ERROR; } + } else if (param->srcT == kNumberTypeUInt8 && param->dstT == kNumberTypeFloat32) { + if (in_tensor->data_type() != kNumberTypeUInt8 || out_tensor->data_type() != kNumberTypeFloat32) { + MS_LOG(ERROR) << "param data type and tensor data type do not match."; + return RET_ERROR; + } + } else if (param->srcT == kNumberTypeFloat32 && param->dstT == kNumberTypeUInt8) { + if (in_tensor->data_type() != kNumberTypeFloat32 || out_tensor->data_type() != kNumberTypeUInt8) { + MS_LOG(ERROR) << "param data type and tensor data type do not match."; + return RET_ERROR; + } } else { MS_LOG(ERROR) << "param data type not supported:" << " src: " << param->srcT << " dst: " << param->dstT; @@ -106,20 +116,26 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) { ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale, quant_arg.zeroPoint, num_unit_thread); } else if (src_dtype == TypeId::kNumberTypeFloat32 && dst_dtype == TypeId::kNumberTypeInt8) { - ret = DoQuantizeToInt8FromFp32(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale, - quant_arg.zeroPoint, num_unit_thread); + ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale, + quant_arg.zeroPoint, num_unit_thread); } else if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeUInt8) { - ret = DoDequantizeInt8ToUInt8(int8_ptr_ + thread_offset, uint8_ptr_ + thread_offset, num_unit_thread); + ret = Int8ToUInt8(int8_ptr_ + thread_offset, uint8_ptr_ + thread_offset, num_unit_thread); + } else if (src_dtype == TypeId::kNumberTypeUInt8 && dst_dtype == TypeId::kNumberTypeFloat32) { + ret = DoDequantizeUInt8ToFp32(uint8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale, + quant_arg.zeroPoint, num_unit_thread); + } else if (src_dtype == TypeId::kNumberTypeFloat32 && dst_dtype == TypeId::kNumberTypeUInt8) { + ret = DoQuantizeFp32ToUInt8(float32_ptr_ + thread_offset, uint8_ptr_ + thread_offset, quant_arg.scale, + quant_arg.zeroPoint, num_unit_thread); } else if (src_dtype == TypeId::kNumberTypeUInt8 && dst_dtype == TypeId::kNumberTypeInt8) { - ret = DoQuantizeToInt8FromUint8(uint8_ptr_ + thread_offset, int8_ptr_ + thread_offset, num_unit_thread); + ret = UInt8ToInt8(uint8_ptr_ + thread_offset, int8_ptr_ + thread_offset, num_unit_thread); } else if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeInt8) { auto input_quant_arg = in_tensors_.front()->GetQuantParams().front(); ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, num_unit_thread, input_quant_arg.scale, input_quant_arg.zeroPoint); if (ret) { auto output_quant_arg = out_tensors_.front()->GetQuantParams().front(); - ret = DoQuantizeToInt8FromFp32(float32_ptr_ + thread_offset, int8_out_ptr_ + thread_offset, - output_quant_arg.scale, output_quant_arg.zeroPoint, num_unit_thread); + ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_out_ptr_ + thread_offset, output_quant_arg.scale, + output_quant_arg.zeroPoint, num_unit_thread); } } @@ -162,6 +178,14 @@ int QuantDTypeCastCPUKernel::Run() { int8_ptr_ = reinterpret_cast(in_tensors_[0]->data_c()); int8_out_ptr_ = reinterpret_cast(out_tensors_[0]->data_c()); float32_ptr_ = new float[in_tensors_[0]->ElementsNum()]; + } else if (in_tensors_[0]->data_type() == TypeId::kNumberTypeUInt8 && + out_tensors_[0]->data_type() == TypeId::kNumberTypeFloat32) { + uint8_ptr_ = reinterpret_cast(in_tensors_[0]->data_c()); + float32_ptr_ = reinterpret_cast(out_tensors_[0]->data_c()); + } else if (in_tensors_[0]->data_type() == TypeId::kNumberTypeFloat32 && + out_tensors_[0]->data_type() == TypeId::kNumberTypeUInt8) { + float32_ptr_ = reinterpret_cast(in_tensors_[0]->data_c()); + uint8_ptr_ = reinterpret_cast(out_tensors_[0]->data_c()); } auto ret = ParallelLaunch(this->context_->thread_pool_, QuantDTypeCastRun, this, thread_n_num_); diff --git a/mindspore/lite/src/sub_graph_kernel.cc b/mindspore/lite/src/sub_graph_kernel.cc index 2018e6d65a..b1070e33d1 100644 --- a/mindspore/lite/src/sub_graph_kernel.cc +++ b/mindspore/lite/src/sub_graph_kernel.cc @@ -178,7 +178,7 @@ int CpuFp16SubGraph::PreProcess() { } int CpuFp16SubGraph::PostProcess() { - auto fp16_to_fp32_cast_func = lite::Float16CastUtil::GetInstance()->float16_to_float32_func_; + auto fp16_to_fp32_cast_func = kernel::Float16CastUtil::GetInstance()->float16_to_float32_func_; if (fp16_to_fp32_cast_func == nullptr) { MS_LOG(ERROR) << "Can not find cast fp16 to fp32 func"; return RET_ERROR; diff --git a/mindspore/lite/src/sub_graph_kernel.h b/mindspore/lite/src/sub_graph_kernel.h index 92ec2c7d49..7e03586cbc 100644 --- a/mindspore/lite/src/sub_graph_kernel.h +++ b/mindspore/lite/src/sub_graph_kernel.h @@ -22,8 +22,42 @@ #include #include "src/lite_kernel.h" #include "src/executor.h" +#ifdef ENABLE_ARM64 +#include "nnacl/optimized_kernel.h" +#endif namespace mindspore::kernel { +using Float16CastFunc = void (*)(const void *, void *, int); + +class Float16CastUtil { + public: + static Float16CastUtil *GetInstance() { + static Float16CastUtil float16_cast_util; + return &float16_cast_util; + } + + private: + Float16CastUtil() { +#ifdef ENABLE_ARM64 + void *fp16_op_handler = Float16Module::GetInstance()->float16_op_handler_; + if (fp16_op_handler != nullptr) { + dlerror(); + *(reinterpret_cast(&float16_to_float32_func_)) = dlsym(fp16_op_handler, "Float16ToFloat32_fp16_handler"); + *(reinterpret_cast(&float32_to_float16_func_)) = dlsym(fp16_op_handler, "Float32ToFloat16_fp16_handler"); + auto dlopen_error = dlerror(); + if (dlopen_error != nullptr) { + MS_LOG(ERROR) << "load float16 cast func failed! " << dlopen_error << "."; + } + } +#endif + } + ~Float16CastUtil() = default; + + public: + Float16CastFunc float16_to_float32_func_ = nullptr; + Float16CastFunc float32_to_float16_func_ = nullptr; +}; + class SubGraphKernel : public LiteKernel { public: explicit SubGraphKernel(const std::vector &inputs, const std::vector &outputs, diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 848020897a..5c0cf91b0c 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -46,6 +46,10 @@ AnfTransform::~AnfTransform() = default; FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const converter::Flags *config) { MS_ASSERT(nullptr != old_graph); + if (config == nullptr) { + MS_LOG(ERROR) << "config shoud be specified"; + return nullptr; + } // fusion const_fold auto optimizer = std::make_shared(); auto pm = std::make_shared("anf fusion pass manager", false); diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc index e2fc3f1eb7..743264d056 100644 --- a/mindspore/lite/tools/converter/converter_flags.cc +++ b/mindspore/lite/tools/converter/converter_flags.cc @@ -84,7 +84,7 @@ int Flags::Init(int argc, const char **argv) { } if (this->inputDataTypeIn == "FLOAT") { - this->inputDataType = TypeId::kNumberTypeFloat; + this->inputDataType = TypeId::kNumberTypeFloat32; } else if (this->inputDataTypeIn == "INT8") { this->inputDataType = TypeId::kNumberTypeInt8; } else if (this->inputDataTypeIn == "UINT8") { @@ -98,7 +98,7 @@ int Flags::Init(int argc, const char **argv) { } if (this->outputDataTypeIn == "FLOAT") { - this->outputDataType = TypeId::kNumberTypeFloat; + this->outputDataType = TypeId::kNumberTypeFloat32; } else if (this->outputDataTypeIn == "INT8") { this->outputDataType = TypeId::kNumberTypeInt8; } else if (this->outputDataTypeIn == "UINT8") { diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index f53b7fdd7c..74371e900d 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -22,7 +22,6 @@ #include "tools/converter/converter_flags.h" #include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h" #include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h" -#include "tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h" #include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h" #include "tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.h" #include "tools/converter/legacy_optimizer/graph/trans_format_remove_pass.h" @@ -36,7 +35,6 @@ #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h" #include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h" #include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h" -#include "tools/converter/quantizer/aware_quantizer.h" using std::string; namespace mindspore::lite { @@ -120,15 +118,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { return status; } } - { - Optimizer inferQuantParamOtimizer; - inferQuantParamOtimizer.AddPass(new (std::nothrow) InferQuantParamPass()); - status = inferQuantParamOtimizer.Run(graphDefT); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run tensorQuantOptimizer graphPasses Failed"; - return status; - } - } { Optimizer fusionOptimizer; @@ -158,6 +147,8 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); dTypeTransPass->SetInputDataDType(ctx.inputDataType); dTypeTransPass->SetOutputDataDType(ctx.outputDataType); + quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); + quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass()); quantNodeOptimizer.AddPass(dTypeTransPass); quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc index e5adca31c8..5c84694b24 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc @@ -53,18 +53,18 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { MS_ASSERT(graph != nullptr); auto &graphInIdxes = graph->inputIndex; - if (this->inputDataDType == TypeId::kNumberTypeInt8 || this->inputDataDType == TypeId::kTypeUnknown) { + if (this->inputDataDType == TypeId::kTypeUnknown) { return RET_OK; } - if (this->inputDataDType != TypeId::kNumberTypeFloat && this->inputDataDType != TypeId::kNumberTypeUInt8) { + if (this->inputDataDType != TypeId::kNumberTypeFloat32 && this->inputDataDType != TypeId::kNumberTypeUInt8 && + this->inputDataDType != TypeId::kNumberTypeInt8) { MS_LOG(ERROR) << "Invalid inputDataType: " << this->inputDataDType; return RET_ERROR; } - // insert fp2int8 node for (auto graphInIdx : graphInIdxes) { MS_ASSERT(graphInIdx < graph->allTensors.size()); auto &tensor = graph->allTensors.at(graphInIdx); - if (tensor->dataType != kNumberTypeInt8 || tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { + if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { continue; } @@ -75,10 +75,9 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { STATUS status = RET_OK; // insert dtype cast node between input tensor and input node - if (inputDataDType == TypeId::kNumberTypeFloat) { - iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, kFP32ToInt8, &status); - } else { - iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, kUInt8ToInt8, &status); + if (this->inputDataDType != tensor->dataType) { + iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, this->inputDataDType, tensor->dataType, + &status); } if (status != RET_OK) { @@ -94,10 +93,11 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { MS_ASSERT(graph != nullptr); - if (outputDataDType == TypeId::kNumberTypeInt8 || outputDataDType == TypeId::kTypeUnknown) { + if (outputDataDType == TypeId::kTypeUnknown) { return RET_OK; } - if (this->outputDataDType != TypeId::kNumberTypeFloat && this->outputDataDType != TypeId::kNumberTypeUInt8) { + if (this->outputDataDType != TypeId::kNumberTypeFloat32 && this->outputDataDType != TypeId::kNumberTypeUInt8 && + this->outputDataDType != TypeId::kNumberTypeInt8) { MS_LOG(ERROR) << "Invalid outputDataType: " << this->outputDataDType; return RET_ERROR; } @@ -105,7 +105,7 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { for (auto graphOutIdx : graphOutIdxes) { MS_ASSERT(graphOutIdx < graph->allTensors.size()); auto &tensor = graph->allTensors.at(graphOutIdx); - if (tensor->dataType != kNumberTypeInt8 || tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { + if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { continue; } for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { @@ -115,10 +115,9 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { if ((*iter)->outputIndex.at(outputIndexIdx) == graphOutIdx) { // insert transNode STATUS status = RET_OK; - if (inputDataDType == TypeId::kNumberTypeFloat) { - iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToFP32, &status); - } else { - iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToUInt8, &status); + if (this->outputDataDType != tensor->dataType) { + iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, tensor->dataType, this->outputDataDType, + &status); } if (status != RET_OK) { MS_LOG(ERROR) << "InsertDTypeTransNode after " << nodeName.c_str() << " failed"; @@ -152,7 +151,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { if (preTensor->dataType != TypeId::kNumberTypeInt8) { continue; } - iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status); + iter = InsertDTypeTransNode(graph, iter, kBefore, i, kNumberTypeInt8, kNumberTypeFloat32, &status); if (status != RET_OK) { MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed"; return RET_ERROR; @@ -165,7 +164,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { if (postTensor->dataType != TypeId::kNumberTypeInt8) { continue; } - iter = InsertDTypeTransNode(graph, iter, kAfter, i, kFP32ToInt8, &status); + iter = InsertDTypeTransNode(graph, iter, kAfter, i, kNumberTypeFloat32, kNumberTypeInt8, &status); if (status != RET_OK) { MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed"; return RET_ERROR; @@ -178,7 +177,8 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { } NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, - size_t inoutIdx, DTypeTransNodeType nodeType, STATUS *errorCode) { + size_t inoutIdx, int32_t inputDataType, int32_t outputDataType, + STATUS *errorCode) { MS_ASSERT((*existNodeIter) != nullptr); auto existNodeName = (*existNodeIter)->name; std::string tileName; @@ -203,21 +203,15 @@ NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIte transNode->primitive->value.value = quantDTypeCastParam; transNode->primitive->value.type = PrimitiveType_QuantDTypeCast; transNode->quantType = QuantType_AwareTraining; - if (nodeType == kInt8ToFP32) { - quantDTypeCastParam->srcT = TypeId::kNumberTypeInt8; - quantDTypeCastParam->dstT = TypeId::kNumberTypeFloat32; + quantDTypeCastParam->srcT = inputDataType; + quantDTypeCastParam->dstT = outputDataType; + if (inputDataType == TypeId::kNumberTypeInt8 && outputDataType == TypeId::kNumberTypeFloat32) { transNode->name = "int8toft32_" + tileName + std::to_string(id++); - } else if (nodeType == kFP32ToInt8) { - quantDTypeCastParam->srcT = TypeId::kNumberTypeFloat32; - quantDTypeCastParam->dstT = TypeId::kNumberTypeInt8; + } else if (inputDataType == TypeId::kNumberTypeFloat32 && outputDataType == TypeId::kNumberTypeInt8) { transNode->name = "ft32toint8_" + tileName + std::to_string(id++); - } else if (nodeType == kUInt8ToInt8) { - quantDTypeCastParam->srcT = TypeId::kNumberTypeUInt8; - quantDTypeCastParam->dstT = TypeId::kNumberTypeInt8; + } else if (inputDataType == TypeId::kNumberTypeUInt8 && outputDataType == TypeId::kNumberTypeInt8) { transNode->name = "uint8toint8_" + tileName + std::to_string(id++); - } else if (nodeType == kInt8ToUInt8) { - quantDTypeCastParam->srcT = TypeId::kNumberTypeInt8; - quantDTypeCastParam->dstT = TypeId::kNumberTypeUInt8; + } else if (inputDataType == TypeId::kNumberTypeInt8 && outputDataType == TypeId::kNumberTypeUInt8) { transNode->name = "int8touint8_" + tileName + std::to_string(id++); } transNode->primitive->value.value = quantDTypeCastParam; diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h index f38ee93fed..e04bef1cb0 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h @@ -47,7 +47,7 @@ class DTypeTransPass : public GraphPass { STATUS DoNodeInoutDTypeTrans(schema::MetaGraphT *graph); NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, - DTypeTransNodeType nodeType, STATUS *errorCode); + int32_t inputDataType, int32_t outputDataType, STATUS *errorCode); private: size_t id; diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc index ba31e00e05..8170c9da40 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc @@ -87,6 +87,7 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) { } } else { // perchannel MS_LOG(ERROR) << "perchannel doquant is not supported yet"; + return RET_ERROR; } } return RET_OK; diff --git a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc deleted file mode 100644 index 90377bdac4..0000000000 --- a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc +++ /dev/null @@ -1,161 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/converter/quantizer/aware_quantizer.h" - -#include -#include -#include -#include -#include - -#include "schema/inner/model_generated.h" -#include "securec/include/securec.h" -#include "src/common/utils.h" -#include "tools/common/node_util.h" -#include "tools/common/tensor_util.h" -#include "tools/converter/quantizer/calc_quant_param.h" -#include "src/common/log_adapter.h" - -using std::string; -using std::vector; - -namespace mindspore::lite::quant { -AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferType) : FbQuantizer(graph) {} - -STATUS AwareQuantizer::RemoveFakeQuant() { return RET_OK; } - -STATUS AwareQuantizer::GenerateQuantParam() { - auto *quantParamRegister = QuantParamCalcRegister::GetInstance(); - - for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { - auto &node = *iter; - MS_ASSERT(node != nullptr); - if (GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMax || - GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) { - MS_ASSERT(false); - } - auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node)); - if (quantParamCalcer == nullptr) { - MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str() - << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; - node->quantType = static_cast(QuantType_QUANT_NONE); - } else { - auto status = quantParamCalcer->Calc(graph, *node); - if (status != RET_OK) { - MS_LOG(WARNING) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); - node->quantType = schema::QuantType_QUANT_NONE; - } else { - node->quantType = schema::QuantType_AwareTraining; - } - } - } - return RET_OK; -} - -STATUS AwareQuantizer::DoQuantize() { - for (auto &tensor : graph->allTensors) { - if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited || tensor->data.empty()) { - continue; - } - if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat && - tensor->dataType != TypeId::kNumberTypeUInt8) { - continue; - } - // perlayer - if (tensor->quantParams.size() == 1) { - auto &quantParam = tensor->quantParams.front(); - size_t wShapeSize = GetShapeSize(*(tensor.get())); - void *oriWeightData = tensor->data.data(); - if (quantParam->dstDtype == TypeId::kNumberTypeInt8) { - vector qDatas(wShapeSize); - auto weightQauntParam = GetTensorQuantParam(tensor); - if (tensor->dataType == TypeId::kNumberTypeFloat || - tensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant - auto *weightData = static_cast(oriWeightData); - for (size_t j = 0; j < wShapeSize; j++) { - qDatas[j] = QuantizeData(weightData[j], weightQauntParam.get()); - } - } else { // tflite awareing quant - auto *weightData = static_cast(oriWeightData); - for (size_t j = 0; j < wShapeSize; j++) { - qDatas[j] = (int32_t)weightData[j] - 128; - } - weightQauntParam->zeroPoint -= 128; - tensor->quantParams.clear(); - tensor->quantParams.emplace_back(weightQauntParam.release()); - } - tensor->data.clear(); - tensor->data.resize(wShapeSize * sizeof(int8_t)); - auto ret = - memcpy_s(tensor->data.data(), wShapeSize * sizeof(int8_t), qDatas.data(), wShapeSize * sizeof(int8_t)); - if (ret != EOK) { - MS_LOG(ERROR) << "memcpy_s failed: " << ret; - return RET_ERROR; - } - } else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) { - // quant bias data - auto bShapeSize = GetShapeSize(*(tensor.get())); - std::unique_ptr qDatas(new (std::nothrow) int32_t[bShapeSize]); - if (qDatas == nullptr) { - MS_LOG(ERROR) << "new qDatas failed"; - return RET_ERROR; - } - void *biasData = tensor->data.data(); - auto *rawDatas = static_cast(biasData); - for (size_t i = 0; i < bShapeSize; ++i) { - qDatas[i] = (int32_t)std::round(rawDatas[i] / quantParam->scale); - } - tensor->dataType = TypeId::kNumberTypeInt32; - tensor->data.clear(); - tensor->data.resize(bShapeSize * sizeof(int32_t)); - auto ret = - memcpy_s(tensor->data.data(), bShapeSize * sizeof(int32_t), qDatas.get(), bShapeSize * sizeof(int32_t)); - if (ret != EOK) { - MS_LOG(ERROR) << "memcpy_s failed: " << ret; - return RET_ERROR; - } - } - } else { // pertensor - } - } - return RET_OK; -} -STATUS AwareQuantizer::DetermineNodeQuantType() { - MS_ASSERT(graph != nullptr); - for (auto &node : graph->nodes) { - MS_ASSERT(node != nullptr); - bool canQuant = true; - for (auto &outTensorIdx : node->outputIndex) { - MS_ASSERT(graph->allTensors.size() > outTensorIdx); - auto &outTensor = graph->allTensors.at(outTensorIdx); - MS_ASSERT(outTensor != nullptr); - if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr || - !outTensor->quantParams.front()->inited) { - canQuant = false; - break; - } - } - - if (canQuant && IsContain(GetInt8OpList(), GetCNodeTType(*node))) { - node->quantType = schema::QuantType_AwareTraining; - } else { - node->quantType = schema::QuantType_QUANT_NONE; - } - } - return RET_OK; -} -} // namespace mindspore::lite::quant diff --git a/mindspore/lite/tools/converter/quantizer/aware_quantizer.h b/mindspore/lite/tools/converter/quantizer/aware_quantizer.h deleted file mode 100644 index e28a220e4c..0000000000 --- a/mindspore/lite/tools/converter/quantizer/aware_quantizer.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_AWARE_QUANTIZER_H -#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_AWARE_QUANTIZER_H - -#include -#include -#include -#include "tools/converter/quantizer/quantizer.h" -#include "schema/inner/model_generated.h" -#include "include/errorcode.h" -#include "tools/converter/quantizer/quantize_util.h" - -namespace mindspore::lite::quant { -class AwareQuantizer : public FbQuantizer { - public: - AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferType); - - ~AwareQuantizer() override = default; - - STATUS RemoveFakeQuant() override; - - STATUS GenerateQuantParam() override; - - STATUS DetermineNodeQuantType() override; - - STATUS DoQuantize() override; // override; -}; -} // namespace mindspore::lite::quant -#endif diff --git a/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc b/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc index 734e794d40..3733cd87c1 100644 --- a/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc +++ b/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc @@ -86,7 +86,7 @@ int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { if (!tensor->data.empty() && !IsContain(graph->inputIndex, node.inputIndex.at(i))) { auto status = ComputeConstQuantParam((*tensor), quantParam.get()); if (status != RET_OK) { - MS_LOG(INFO) << "ComputeConstQuantParam failed: " << status; + MS_LOG(DEBUG) << "ComputeConstQuantParam failed: " << status; return status; } tensor->quantParams.front() = std::move(quantParam);