
!7923 review fix

Merge pull request !7923 from cjh9368/rewrite_aware_quant
tags/v1.1.0
mindspore-ci-bot (Gitee) committed 5 years ago
commit 4ec14e3368
15 files changed, 140 insertions(+), 300 deletions(-)
 1. +34  -5    mindspore/lite/nnacl/int8/quant_dtype_cast_int8.c
 2. +6   -4    mindspore/lite/nnacl/int8/quant_dtype_cast_int8.h
 3. +0   -34   mindspore/lite/src/common/utils.h
 4. +30  -6    mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc
 5. +1   -1    mindspore/lite/src/sub_graph_kernel.cc
 6. +34  -0    mindspore/lite/src/sub_graph_kernel.h
 7. +4   -0    mindspore/lite/tools/converter/anf_transform.cc
 8. +2   -2    mindspore/lite/tools/converter/converter_flags.cc
 9. +2   -11   mindspore/lite/tools/converter/graphdef_transform.cc
10. +24  -30   mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc
11. +1   -1    mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h
12. +1   -0    mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc
13. +0   -161  mindspore/lite/tools/converter/quantizer/aware_quantizer.cc
14. +0   -44   mindspore/lite/tools/converter/quantizer/aware_quantizer.h
15. +1   -1    mindspore/lite/tools/converter/quantizer/calc_quant_param.cc

+34 -5  mindspore/lite/nnacl/int8/quant_dtype_cast_int8.c

@@ -18,7 +18,7 @@
 #include "nnacl/int8/quant_dtype_cast_int8.h"
 #include "nnacl/errorcode.h"
 
-int DoDequantizeInt8ToFp32(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size) {
+int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float scale, int32_t zp, int size) {
   if (quant_values == NULL || real_values == NULL) {
     return NNACL_PARAM_INVALID;
   }
@@ -29,13 +29,13 @@ int DoDequantizeInt8ToFp32(int8_t *quant_values, float *real_values, float scale
   return NNACL_OK;
 }
 
-int DoQuantizeToInt8FromFp32(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
+int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
   if (quant_values == NULL || real_values == NULL) {
     return NNACL_PARAM_INVALID;
   }
 
   for (int i = 0; i < size; ++i) {
-    float temp = round(real_values[i] * 1.0 / scale + zp);
+    float temp = (float)round(real_values[i] * 1.0 / scale + zp);
     if (temp > 127) {
       quant_values[i] = 127;
     } else if (temp < -128) {
@@ -47,7 +47,36 @@ int DoQuantizeToInt8FromFp32(float *real_values, int8_t *quant_values, float sca
   return NNACL_OK;
 }
 
-int DoDequantizeInt8ToUInt8(int8_t *quant_values, uint8_t *real_values, int size) {
+int DoDequantizeUInt8ToFp32(const uint8_t *quant_values, float *real_values, float scale, int32_t zp, int size) {
+  if (quant_values == NULL || real_values == NULL) {
+    return NNACL_PARAM_INVALID;
+  }
+
+  for (int i = 0; i < size; ++i) {
+    real_values[i] = (float)((int)quant_values[i] - zp) * scale;
+  }
+  return NNACL_OK;
+}
+
+int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float scale, int32_t zp, int size) {
+  if (quant_values == NULL || real_values == NULL) {
+    return NNACL_PARAM_INVALID;
+  }
+
+  for (int i = 0; i < size; ++i) {
+    float temp = (float)round(real_values[i] * 1.0 / scale + zp);
+    if (temp > 255) {
+      quant_values[i] = 255;
+    } else if (temp < 0) {
+      quant_values[i] = 0;
+    } else {
+      quant_values[i] = (uint8_t)temp;
+    }
+  }
+  return NNACL_OK;
+}
+
+int Int8ToUInt8(const int8_t *quant_values, uint8_t *real_values, int size) {
   if (quant_values == NULL || real_values == NULL) {
     return NNACL_PARAM_INVALID;
   }
@@ -65,7 +94,7 @@ int DoDequantizeInt8ToUInt8(int8_t *quant_values, uint8_t *real_values, int size
   return NNACL_OK;
 }
 
-int DoQuantizeToInt8FromUint8(uint8_t *real_values, int8_t *quant_values, int size) {
+int UInt8ToInt8(const uint8_t *real_values, int8_t *quant_values, int size) {
   if (quant_values == NULL || real_values == NULL) {
     return NNACL_PARAM_INVALID;
   }
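
For reference, the renamed nnacl entry points keep the same affine contract: quantize is round(x / scale) + zp with saturation, dequantize is (q - zp) * scale. A minimal round-trip sketch against the new names (the scale and zero point below are made-up illustration values, not from this commit):

    #include "nnacl/int8/quant_dtype_cast_int8.h"
    #include "nnacl/errorcode.h"

    int RoundTripExample(void) {
      float real[4] = {-1.0f, 0.0f, 0.5f, 1.0f};  /* arbitrary sample data */
      int8_t quant[4];
      float restored[4];
      const float scale = 0.0078125f;  /* assumed per-layer scale (1/128) */
      const int32_t zp = 0;            /* assumed zero point */
      if (DoQuantizeFp32ToInt8(real, quant, scale, zp, 4) != NNACL_OK) {
        return -1;
      }
      if (DoDequantizeInt8ToFp32(quant, restored, scale, zp, 4) != NNACL_OK) {
        return -1;
      }
      return 0;  /* restored[i] matches real[i] up to one quantization step */
    }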


+6 -4  mindspore/lite/nnacl/int8/quant_dtype_cast_int8.h

@@ -28,10 +28,12 @@ typedef struct QuantDTypeCastParameter {
 #ifdef __cplusplus
 extern "C" {
 #endif
-int DoDequantizeInt8ToFp32(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
-int DoQuantizeToInt8FromFp32(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size);
-int DoDequantizeInt8ToUInt8(int8_t *quant_values, uint8_t *real_values, int size);
-int DoQuantizeToInt8FromUint8(uint8_t *real_values, int8_t *quant_values, int size);
+int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
+int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size);
+int DoDequantizeUInt8ToFp32(const uint8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
+int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float scale, int32_t zp, int size);
+int Int8ToUInt8(const int8_t *quant_values, uint8_t *real_values, int size);
+int UInt8ToInt8(const uint8_t *real_values, int8_t *quant_values, int size);
 #ifdef __cplusplus
 }
 #endif


+0 -34  mindspore/lite/src/common/utils.h

@@ -27,9 +27,6 @@
 #include "src/common/log_adapter.h"
 #include "tools/common/option.h"
 #include "include/errorcode.h"
-#ifdef ENABLE_ARM64
-#include "nnacl/optimized_kernel.h"
-#endif
 
 namespace mindspore {
 namespace lite {
@@ -190,37 +187,6 @@ inline Option<bool> GenericParseValue(const std::string &value) {
   return Option<bool>(None());
 }
 
-using Float16CastFunc = void (*)(const void *, void *, int);
-
-class Float16CastUtil {
- public:
-  static Float16CastUtil *GetInstance() {
-    static Float16CastUtil float16_cast_util;
-    return &float16_cast_util;
-  }
-
- private:
-  Float16CastUtil() {
-#ifdef ENABLE_ARM64
-    void *fp16_op_handler = Float16Module::GetInstance()->float16_op_handler_;
-    if (fp16_op_handler != nullptr) {
-      dlerror();
-      *(reinterpret_cast<void **>(&float16_to_float32_func_)) = dlsym(fp16_op_handler, "Float16ToFloat32_fp16_handler");
-      *(reinterpret_cast<void **>(&float32_to_float16_func_)) = dlsym(fp16_op_handler, "Float32ToFloat16_fp16_handler");
-      auto dlopen_error = dlerror();
-      if (dlopen_error != nullptr) {
-        MS_LOG(ERROR) << "load float16 cast func failed! " << dlopen_error << ".";
-      }
-    }
-#endif
-  }
-  ~Float16CastUtil() = default;
-
- public:
-  Float16CastFunc float16_to_float32_func_ = nullptr;
-  Float16CastFunc float32_to_float16_func_ = nullptr;
-};
-
 }  // namespace lite
 }  // namespace mindspore



+30 -6  mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc

@@ -66,6 +66,16 @@ int QuantDTypeCastCPUKernel::Init() {
       MS_LOG(ERROR) << "param data type and tensor data type do not match.";
       return RET_ERROR;
     }
+  } else if (param->srcT == kNumberTypeUInt8 && param->dstT == kNumberTypeFloat32) {
+    if (in_tensor->data_type() != kNumberTypeUInt8 || out_tensor->data_type() != kNumberTypeFloat32) {
+      MS_LOG(ERROR) << "param data type and tensor data type do not match.";
+      return RET_ERROR;
+    }
+  } else if (param->srcT == kNumberTypeFloat32 && param->dstT == kNumberTypeUInt8) {
+    if (in_tensor->data_type() != kNumberTypeFloat32 || out_tensor->data_type() != kNumberTypeUInt8) {
+      MS_LOG(ERROR) << "param data type and tensor data type do not match.";
+      return RET_ERROR;
+    }
   } else {
     MS_LOG(ERROR) << "param data type not supported:"
                   << " src: " << param->srcT << " dst: " << param->dstT;
@@ -106,20 +116,26 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
     ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,
                                  quant_arg.zeroPoint, num_unit_thread);
   } else if (src_dtype == TypeId::kNumberTypeFloat32 && dst_dtype == TypeId::kNumberTypeInt8) {
-    ret = DoQuantizeToInt8FromFp32(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
-                                   quant_arg.zeroPoint, num_unit_thread);
+    ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
+                               quant_arg.zeroPoint, num_unit_thread);
   } else if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeUInt8) {
-    ret = DoDequantizeInt8ToUInt8(int8_ptr_ + thread_offset, uint8_ptr_ + thread_offset, num_unit_thread);
+    ret = Int8ToUInt8(int8_ptr_ + thread_offset, uint8_ptr_ + thread_offset, num_unit_thread);
+  } else if (src_dtype == TypeId::kNumberTypeUInt8 && dst_dtype == TypeId::kNumberTypeFloat32) {
+    ret = DoDequantizeUInt8ToFp32(uint8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,
+                                  quant_arg.zeroPoint, num_unit_thread);
+  } else if (src_dtype == TypeId::kNumberTypeFloat32 && dst_dtype == TypeId::kNumberTypeUInt8) {
+    ret = DoQuantizeFp32ToUInt8(float32_ptr_ + thread_offset, uint8_ptr_ + thread_offset, quant_arg.scale,
+                                quant_arg.zeroPoint, num_unit_thread);
   } else if (src_dtype == TypeId::kNumberTypeUInt8 && dst_dtype == TypeId::kNumberTypeInt8) {
-    ret = DoQuantizeToInt8FromUint8(uint8_ptr_ + thread_offset, int8_ptr_ + thread_offset, num_unit_thread);
+    ret = UInt8ToInt8(uint8_ptr_ + thread_offset, int8_ptr_ + thread_offset, num_unit_thread);
   } else if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeInt8) {
     auto input_quant_arg = in_tensors_.front()->GetQuantParams().front();
     ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, num_unit_thread,
                                  input_quant_arg.scale, input_quant_arg.zeroPoint);
     if (ret) {
       auto output_quant_arg = out_tensors_.front()->GetQuantParams().front();
-      ret = DoQuantizeToInt8FromFp32(float32_ptr_ + thread_offset, int8_out_ptr_ + thread_offset,
-                                     output_quant_arg.scale, output_quant_arg.zeroPoint, num_unit_thread);
+      ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_out_ptr_ + thread_offset, output_quant_arg.scale,
+                                 output_quant_arg.zeroPoint, num_unit_thread);
     }
   }
 
@@ -162,6 +178,14 @@ int QuantDTypeCastCPUKernel::Run() {
     int8_ptr_ = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c());
     int8_out_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
     float32_ptr_ = new float[in_tensors_[0]->ElementsNum()];
+  } else if (in_tensors_[0]->data_type() == TypeId::kNumberTypeUInt8 &&
+             out_tensors_[0]->data_type() == TypeId::kNumberTypeFloat32) {
+    uint8_ptr_ = reinterpret_cast<uint8_t *>(in_tensors_[0]->data_c());
+    float32_ptr_ = reinterpret_cast<float *>(out_tensors_[0]->data_c());
+  } else if (in_tensors_[0]->data_type() == TypeId::kNumberTypeFloat32 &&
+             out_tensors_[0]->data_type() == TypeId::kNumberTypeUInt8) {
+    float32_ptr_ = reinterpret_cast<float *>(in_tensors_[0]->data_c());
+    uint8_ptr_ = reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c());
+  }
 
   auto ret = ParallelLaunch(this->context_->thread_pool_, QuantDTypeCastRun, this, thread_n_num_);
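
Taken together, the Init() checks and the QuantDTypeCast() dispatch above now cover these conversions (summary mine, derived from the branches shown):

    // (srcT -> dstT) pairs handled by QuantDTypeCastCPUKernel after this change:
    //   int8  -> fp32   DoDequantizeInt8ToFp32
    //   fp32  -> int8   DoQuantizeFp32ToInt8
    //   int8  -> uint8  Int8ToUInt8
    //   uint8 -> fp32   DoDequantizeUInt8ToFp32   (new in this PR)
    //   fp32  -> uint8  DoQuantizeFp32ToUInt8     (new in this PR)
    //   uint8 -> int8   UInt8ToInt8
    //   int8  -> int8   requantize via a temporary fp32 buffer (dequantize, then quantize)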


+1 -1  mindspore/lite/src/sub_graph_kernel.cc

@@ -178,7 +178,7 @@ int CpuFp16SubGraph::PreProcess() {
 }
 
 int CpuFp16SubGraph::PostProcess() {
-  auto fp16_to_fp32_cast_func = lite::Float16CastUtil::GetInstance()->float16_to_float32_func_;
+  auto fp16_to_fp32_cast_func = kernel::Float16CastUtil::GetInstance()->float16_to_float32_func_;
   if (fp16_to_fp32_cast_func == nullptr) {
     MS_LOG(ERROR) << "Can not find cast fp16 to fp32 func";
     return RET_ERROR;


+34 -0  mindspore/lite/src/sub_graph_kernel.h

@@ -22,8 +22,42 @@
 #include <vector>
 #include "src/lite_kernel.h"
 #include "src/executor.h"
+#ifdef ENABLE_ARM64
+#include "nnacl/optimized_kernel.h"
+#endif
 
 namespace mindspore::kernel {
+using Float16CastFunc = void (*)(const void *, void *, int);
+
+class Float16CastUtil {
+ public:
+  static Float16CastUtil *GetInstance() {
+    static Float16CastUtil float16_cast_util;
+    return &float16_cast_util;
+  }
+
+ private:
+  Float16CastUtil() {
+#ifdef ENABLE_ARM64
+    void *fp16_op_handler = Float16Module::GetInstance()->float16_op_handler_;
+    if (fp16_op_handler != nullptr) {
+      dlerror();
+      *(reinterpret_cast<void **>(&float16_to_float32_func_)) = dlsym(fp16_op_handler, "Float16ToFloat32_fp16_handler");
+      *(reinterpret_cast<void **>(&float32_to_float16_func_)) = dlsym(fp16_op_handler, "Float32ToFloat16_fp16_handler");
+      auto dlopen_error = dlerror();
+      if (dlopen_error != nullptr) {
+        MS_LOG(ERROR) << "load float16 cast func failed! " << dlopen_error << ".";
+      }
+    }
+#endif
+  }
+  ~Float16CastUtil() = default;
+
+ public:
+  Float16CastFunc float16_to_float32_func_ = nullptr;
+  Float16CastFunc float32_to_float16_func_ = nullptr;
+};
+
 class SubGraphKernel : public LiteKernel {
  public:
   explicit SubGraphKernel(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
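
With the singleton moved next to its only users, resolving the cast function now goes through kernel::Float16CastUtil. A hedged usage sketch (the helper name and buffers are hypothetical; on non-ARM64 builds the pointers stay nullptr, so callers must check):

    #include "src/sub_graph_kernel.h"

    // Convert a raw fp16 buffer to fp32 via the dlsym-resolved handler, if present.
    int CastFp16ToFp32(const void *fp16_src, void *fp32_dst, int element_num) {
      auto cast_func = mindspore::kernel::Float16CastUtil::GetInstance()->float16_to_float32_func_;
      if (cast_func == nullptr) {
        return -1;  // fp16 handler library not loaded in this build
      }
      cast_func(fp16_src, fp32_dst, element_num);
      return 0;
    }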


+4 -0  mindspore/lite/tools/converter/anf_transform.cc

@@ -46,6 +46,10 @@ AnfTransform::~AnfTransform() = default;
 
 FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const converter::Flags *config) {
   MS_ASSERT(nullptr != old_graph);
+  if (config == nullptr) {
+    MS_LOG(ERROR) << "config should be specified";
+    return nullptr;
+  }
   // fusion const_fold
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);


+2 -2  mindspore/lite/tools/converter/converter_flags.cc

@@ -84,7 +84,7 @@ int Flags::Init(int argc, const char **argv) {
   }
 
   if (this->inputDataTypeIn == "FLOAT") {
-    this->inputDataType = TypeId::kNumberTypeFloat;
+    this->inputDataType = TypeId::kNumberTypeFloat32;
   } else if (this->inputDataTypeIn == "INT8") {
     this->inputDataType = TypeId::kNumberTypeInt8;
   } else if (this->inputDataTypeIn == "UINT8") {
@@ -98,7 +98,7 @@ int Flags::Init(int argc, const char **argv) {
   }
 
   if (this->outputDataTypeIn == "FLOAT") {
-    this->outputDataType = TypeId::kNumberTypeFloat;
+    this->outputDataType = TypeId::kNumberTypeFloat32;
   } else if (this->outputDataTypeIn == "INT8") {
     this->outputDataType = TypeId::kNumberTypeInt8;
   } else if (this->outputDataTypeIn == "UINT8") {
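
The FLOAT mapping fix matters because the rewritten dtype-trans pass compares these flag values directly against tensor dataType fields, which store kNumberTypeFloat32; the generic kNumberTypeFloat would never compare equal. My reading of the change, illustrated (the header path is an assumption):

    #include "ir/dtype/type_id.h"  // assumed location of the TypeId enum

    // Distinct enumerators: parsing "FLOAT" to the generic float TypeId broke
    // equality checks such as `this->inputDataDType != tensor->dataType`.
    static_assert(mindspore::TypeId::kNumberTypeFloat != mindspore::TypeId::kNumberTypeFloat32,
                  "generic float and float32 are different TypeIds");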


+2 -11  mindspore/lite/tools/converter/graphdef_transform.cc

@@ -22,7 +22,6 @@
 #include "tools/converter/converter_flags.h"
 #include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h"
 #include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h"
-#include "tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h"
 #include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
 #include "tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.h"
 #include "tools/converter/legacy_optimizer/graph/trans_format_remove_pass.h"
@@ -36,7 +35,6 @@
 #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
 #include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h"
 #include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h"
-#include "tools/converter/quantizer/aware_quantizer.h"
 
 using std::string;
 namespace mindspore::lite {
@@ -120,15 +118,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
       return status;
     }
   }
-  {
-    Optimizer inferQuantParamOtimizer;
-    inferQuantParamOtimizer.AddPass(new (std::nothrow) InferQuantParamPass());
-    status = inferQuantParamOtimizer.Run(graphDefT);
-    if (status != RET_OK && status != RET_NO_CHANGE) {
-      MS_LOG(ERROR) << "Run tensorQuantOptimizer graphPasses Failed";
-      return status;
-    }
-  }
 
   {
     Optimizer fusionOptimizer;
@@ -158,6 +147,8 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
     auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
     dTypeTransPass->SetInputDataDType(ctx.inputDataType);
    dTypeTransPass->SetOutputDataDType(ctx.outputDataType);
+    quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
+    quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass());
     quantNodeOptimizer.AddPass(dTypeTransPass);
     quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
     quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());


+24 -30  mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc

@@ -53,18 +53,18 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
   MS_ASSERT(graph != nullptr);
   auto &graphInIdxes = graph->inputIndex;
 
-  if (this->inputDataDType == TypeId::kNumberTypeInt8 || this->inputDataDType == TypeId::kTypeUnknown) {
+  if (this->inputDataDType == TypeId::kTypeUnknown) {
     return RET_OK;
   }
-  if (this->inputDataDType != TypeId::kNumberTypeFloat && this->inputDataDType != TypeId::kNumberTypeUInt8) {
+  if (this->inputDataDType != TypeId::kNumberTypeFloat32 && this->inputDataDType != TypeId::kNumberTypeUInt8 &&
+      this->inputDataDType != TypeId::kNumberTypeInt8) {
     MS_LOG(ERROR) << "Invalid inputDataType: " << this->inputDataDType;
     return RET_ERROR;
   }
   // insert fp2int8 node
   for (auto graphInIdx : graphInIdxes) {
     MS_ASSERT(graphInIdx < graph->allTensors.size());
     auto &tensor = graph->allTensors.at(graphInIdx);
-    if (tensor->dataType != kNumberTypeInt8 || tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
+    if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
       continue;
     }
 
@@ -75,10 +75,9 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
         STATUS status = RET_OK;
 
         // insert dtype cast node between input tensor and input node
-        if (inputDataDType == TypeId::kNumberTypeFloat) {
-          iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, kFP32ToInt8, &status);
-        } else {
-          iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, kUInt8ToInt8, &status);
+        if (this->inputDataDType != tensor->dataType) {
+          iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, this->inputDataDType, tensor->dataType,
+                                      &status);
         }
 
        if (status != RET_OK) {
@@ -94,10 +93,11 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
 
 STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
   MS_ASSERT(graph != nullptr);
-  if (outputDataDType == TypeId::kNumberTypeInt8 || outputDataDType == TypeId::kTypeUnknown) {
+  if (outputDataDType == TypeId::kTypeUnknown) {
     return RET_OK;
   }
-  if (this->outputDataDType != TypeId::kNumberTypeFloat && this->outputDataDType != TypeId::kNumberTypeUInt8) {
+  if (this->outputDataDType != TypeId::kNumberTypeFloat32 && this->outputDataDType != TypeId::kNumberTypeUInt8 &&
+      this->outputDataDType != TypeId::kNumberTypeInt8) {
     MS_LOG(ERROR) << "Invalid outputDataType: " << this->outputDataDType;
     return RET_ERROR;
   }
@@ -105,7 +105,7 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
   for (auto graphOutIdx : graphOutIdxes) {
     MS_ASSERT(graphOutIdx < graph->allTensors.size());
     auto &tensor = graph->allTensors.at(graphOutIdx);
-    if (tensor->dataType != kNumberTypeInt8 || tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
+    if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
      continue;
    }
     for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
@@ -115,10 +115,9 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
         if ((*iter)->outputIndex.at(outputIndexIdx) == graphOutIdx) {
           // insert transNode
           STATUS status = RET_OK;
-          if (inputDataDType == TypeId::kNumberTypeFloat) {
-            iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToFP32, &status);
-          } else {
-            iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToUInt8, &status);
+          if (this->outputDataDType != tensor->dataType) {
+            iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, tensor->dataType, this->outputDataDType,
+                                        &status);
           }
           if (status != RET_OK) {
             MS_LOG(ERROR) << "InsertDTypeTransNode after " << nodeName.c_str() << " failed";
@@ -152,7 +151,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
       if (preTensor->dataType != TypeId::kNumberTypeInt8) {
        continue;
      }
-      iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status);
+      iter = InsertDTypeTransNode(graph, iter, kBefore, i, kNumberTypeInt8, kNumberTypeFloat32, &status);
       if (status != RET_OK) {
         MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed";
         return RET_ERROR;
@@ -165,7 +164,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
       if (postTensor->dataType != TypeId::kNumberTypeInt8) {
        continue;
      }
-      iter = InsertDTypeTransNode(graph, iter, kAfter, i, kFP32ToInt8, &status);
+      iter = InsertDTypeTransNode(graph, iter, kAfter, i, kNumberTypeFloat32, kNumberTypeInt8, &status);
       if (status != RET_OK) {
         MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed";
         return RET_ERROR;
@@ -178,7 +177,8 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
 }
 
 NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place,
-                                              size_t inoutIdx, DTypeTransNodeType nodeType, STATUS *errorCode) {
+                                              size_t inoutIdx, int32_t inputDataType, int32_t outputDataType,
+                                              STATUS *errorCode) {
   MS_ASSERT((*existNodeIter) != nullptr);
   auto existNodeName = (*existNodeIter)->name;
   std::string tileName;
@@ -203,21 +203,15 @@ NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIte
   transNode->primitive->value.value = quantDTypeCastParam;
   transNode->primitive->value.type = PrimitiveType_QuantDTypeCast;
   transNode->quantType = QuantType_AwareTraining;
-  if (nodeType == kInt8ToFP32) {
-    quantDTypeCastParam->srcT = TypeId::kNumberTypeInt8;
-    quantDTypeCastParam->dstT = TypeId::kNumberTypeFloat32;
+  quantDTypeCastParam->srcT = inputDataType;
+  quantDTypeCastParam->dstT = outputDataType;
+  if (inputDataType == TypeId::kNumberTypeInt8 && outputDataType == TypeId::kNumberTypeFloat32) {
     transNode->name = "int8toft32_" + tileName + std::to_string(id++);
-  } else if (nodeType == kFP32ToInt8) {
-    quantDTypeCastParam->srcT = TypeId::kNumberTypeFloat32;
-    quantDTypeCastParam->dstT = TypeId::kNumberTypeInt8;
+  } else if (inputDataType == TypeId::kNumberTypeFloat32 && outputDataType == TypeId::kNumberTypeInt8) {
     transNode->name = "ft32toint8_" + tileName + std::to_string(id++);
-  } else if (nodeType == kUInt8ToInt8) {
-    quantDTypeCastParam->srcT = TypeId::kNumberTypeUInt8;
-    quantDTypeCastParam->dstT = TypeId::kNumberTypeInt8;
+  } else if (inputDataType == TypeId::kNumberTypeUInt8 && outputDataType == TypeId::kNumberTypeInt8) {
     transNode->name = "uint8toint8_" + tileName + std::to_string(id++);
-  } else if (nodeType == kInt8ToUInt8) {
-    quantDTypeCastParam->srcT = TypeId::kNumberTypeInt8;
-    quantDTypeCastParam->dstT = TypeId::kNumberTypeUInt8;
+  } else if (inputDataType == TypeId::kNumberTypeInt8 && outputDataType == TypeId::kNumberTypeUInt8) {
     transNode->name = "int8touint8_" + tileName + std::to_string(id++);
   }
   transNode->primitive->value.value = quantDTypeCastParam;
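
The enum-based four-way switch is gone: a cast is inserted only when the requested dtype differs from the tensor's stored dtype, and the (srcT, dstT) pair is passed straight through to the QuantDTypeCast primitive. A distilled restatement of that rule (simplified, hypothetical signatures; not code from the commit):

    #include <cstdint>

    struct CastSpec {
      int32_t srcT;  // dtype the data arrives in
      int32_t dstT;  // dtype the consumer expects
    };

    // For graph inputs: user-requested dtype -> stored tensor dtype.
    // For graph outputs the pair is reversed: stored dtype -> requested dtype.
    bool MakeInputCastSpec(int32_t requestedDType, int32_t tensorDType, CastSpec *spec) {
      if (requestedDType == tensorDType) {
        return false;  // types agree, no QuantDTypeCast node needed
      }
      spec->srcT = requestedDType;
      spec->dstT = tensorDType;
      return true;
    }

One consequence worth noting: the old code special-cased kNumberTypeInt8 as "do nothing", so an INT8 flag on a non-int8 tensor could never produce a cast; the new form handles any supported pair uniformly.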


+1 -1  mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h

@@ -47,7 +47,7 @@ class DTypeTransPass : public GraphPass {
 
   STATUS DoNodeInoutDTypeTrans(schema::MetaGraphT *graph);
   NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx,
-                                DTypeTransNodeType nodeType, STATUS *errorCode);
+                                int32_t inputDataType, int32_t outputDataType, STATUS *errorCode);
 
  private:
  size_t id;


+1 -0  mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc

@@ -87,6 +87,7 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
       }
     } else {  // perchannel
       MS_LOG(ERROR) << "perchannel doquant is not supported yet";
+      return RET_ERROR;
     }
   }
   return RET_OK;


+0 -161  mindspore/lite/tools/converter/quantizer/aware_quantizer.cc

@@ -1,161 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "tools/converter/quantizer/aware_quantizer.h"
-
-#include <cmath>
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "schema/inner/model_generated.h"
-#include "securec/include/securec.h"
-#include "src/common/utils.h"
-#include "tools/common/node_util.h"
-#include "tools/common/tensor_util.h"
-#include "tools/converter/quantizer/calc_quant_param.h"
-#include "src/common/log_adapter.h"
-
-using std::string;
-using std::vector;
-
-namespace mindspore::lite::quant {
-AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferType) : FbQuantizer(graph) {}
-
-STATUS AwareQuantizer::RemoveFakeQuant() { return RET_OK; }
-
-STATUS AwareQuantizer::GenerateQuantParam() {
-  auto *quantParamRegister = QuantParamCalcRegister::GetInstance();
-
-  for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
-    auto &node = *iter;
-    MS_ASSERT(node != nullptr);
-    if (GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMax ||
-        GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) {
-      MS_ASSERT(false);
-    }
-    auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
-    if (quantParamCalcer == nullptr) {
-      MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str()
-                      << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
-      node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE);
-    } else {
-      auto status = quantParamCalcer->Calc(graph, *node);
-      if (status != RET_OK) {
-        MS_LOG(WARNING) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
-        node->quantType = schema::QuantType_QUANT_NONE;
-      } else {
-        node->quantType = schema::QuantType_AwareTraining;
-      }
-    }
-  }
-  return RET_OK;
-}
-
-STATUS AwareQuantizer::DoQuantize() {
-  for (auto &tensor : graph->allTensors) {
-    if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited || tensor->data.empty()) {
-      continue;
-    }
-    if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat &&
-        tensor->dataType != TypeId::kNumberTypeUInt8) {
-      continue;
-    }
-    // perlayer
-    if (tensor->quantParams.size() == 1) {
-      auto &quantParam = tensor->quantParams.front();
-      size_t wShapeSize = GetShapeSize(*(tensor.get()));
-      void *oriWeightData = tensor->data.data();
-      if (quantParam->dstDtype == TypeId::kNumberTypeInt8) {
-        vector<int8_t> qDatas(wShapeSize);
-        auto weightQauntParam = GetTensorQuantParam(tensor);
-        if (tensor->dataType == TypeId::kNumberTypeFloat ||
-            tensor->dataType == TypeId::kNumberTypeFloat32) {  // normal awareing quant
-          auto *weightData = static_cast<float *>(oriWeightData);
-          for (size_t j = 0; j < wShapeSize; j++) {
-            qDatas[j] = QuantizeData<int8_t>(weightData[j], weightQauntParam.get());
-          }
-        } else {  // tflite awareing quant
-          auto *weightData = static_cast<uint8_t *>(oriWeightData);
-          for (size_t j = 0; j < wShapeSize; j++) {
-            qDatas[j] = (int32_t)weightData[j] - 128;
-          }
-          weightQauntParam->zeroPoint -= 128;
-          tensor->quantParams.clear();
-          tensor->quantParams.emplace_back(weightQauntParam.release());
-        }
-        tensor->data.clear();
-        tensor->data.resize(wShapeSize * sizeof(int8_t));
-        auto ret =
-          memcpy_s(tensor->data.data(), wShapeSize * sizeof(int8_t), qDatas.data(), wShapeSize * sizeof(int8_t));
-        if (ret != EOK) {
-          MS_LOG(ERROR) << "memcpy_s failed: " << ret;
-          return RET_ERROR;
-        }
-      } else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) {
-        // quant bias data
-        auto bShapeSize = GetShapeSize(*(tensor.get()));
-        std::unique_ptr<int32_t[]> qDatas(new (std::nothrow) int32_t[bShapeSize]);
-        if (qDatas == nullptr) {
-          MS_LOG(ERROR) << "new qDatas failed";
-          return RET_ERROR;
-        }
-        void *biasData = tensor->data.data();
-        auto *rawDatas = static_cast<float *>(biasData);
-        for (size_t i = 0; i < bShapeSize; ++i) {
-          qDatas[i] = (int32_t)std::round(rawDatas[i] / quantParam->scale);
-        }
-        tensor->dataType = TypeId::kNumberTypeInt32;
-        tensor->data.clear();
-        tensor->data.resize(bShapeSize * sizeof(int32_t));
-        auto ret =
-          memcpy_s(tensor->data.data(), bShapeSize * sizeof(int32_t), qDatas.get(), bShapeSize * sizeof(int32_t));
-        if (ret != EOK) {
-          MS_LOG(ERROR) << "memcpy_s failed: " << ret;
-          return RET_ERROR;
-        }
-      }
-    } else {  // pertensor
-    }
-  }
-  return RET_OK;
-}
-STATUS AwareQuantizer::DetermineNodeQuantType() {
-  MS_ASSERT(graph != nullptr);
-  for (auto &node : graph->nodes) {
-    MS_ASSERT(node != nullptr);
-    bool canQuant = true;
-    for (auto &outTensorIdx : node->outputIndex) {
-      MS_ASSERT(graph->allTensors.size() > outTensorIdx);
-      auto &outTensor = graph->allTensors.at(outTensorIdx);
-      MS_ASSERT(outTensor != nullptr);
-      if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr ||
-          !outTensor->quantParams.front()->inited) {
-        canQuant = false;
-        break;
-      }
-    }
-
-    if (canQuant && IsContain(GetInt8OpList(), GetCNodeTType(*node))) {
-      node->quantType = schema::QuantType_AwareTraining;
-    } else {
-      node->quantType = schema::QuantType_QUANT_NONE;
-    }
-  }
-  return RET_OK;
-}
-}  // namespace mindspore::lite::quant

+0 -44  mindspore/lite/tools/converter/quantizer/aware_quantizer.h

@@ -1,44 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_AWARE_QUANTIZER_H
-#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_AWARE_QUANTIZER_H
-
-#include <array>
-#include <string>
-#include <memory>
-#include "tools/converter/quantizer/quantizer.h"
-#include "schema/inner/model_generated.h"
-#include "include/errorcode.h"
-#include "tools/converter/quantizer/quantize_util.h"
-
-namespace mindspore::lite::quant {
-class AwareQuantizer : public FbQuantizer {
- public:
-  AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferType);
-
-  ~AwareQuantizer() override = default;
-
-  STATUS RemoveFakeQuant() override;
-
-  STATUS GenerateQuantParam() override;
-
-  STATUS DetermineNodeQuantType() override;
-
-  STATUS DoQuantize() override;  // override;
-};
-}  // namespace mindspore::lite::quant
-#endif

+1 -1  mindspore/lite/tools/converter/quantizer/calc_quant_param.cc

@@ -86,7 +86,7 @@ int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
   if (!tensor->data.empty() && !IsContain(graph->inputIndex, node.inputIndex.at(i))) {
     auto status = ComputeConstQuantParam((*tensor), quantParam.get());
     if (status != RET_OK) {
-      MS_LOG(INFO) << "ComputeConstQuantParam failed: " << status;
+      MS_LOG(DEBUG) << "ComputeConstQuantParam failed: " << status;
       return status;
     }
     tensor->quantParams.front() = std::move(quantParam);

