
rewrite aware train converter

tags/v1.1.0
lyvette cjh9368 committed 5 years ago
commit ef86629298
55 changed files with 751 additions and 697 deletions
  1. mindspore/lite/schema/model.fbs (+3 -2)
  2. mindspore/lite/src/common/utils.h (+35 -0)
  3. mindspore/lite/src/lite_session.cc (+3 -2)
  4. mindspore/lite/src/ops/add.cc (+1 -6)
  5. mindspore/lite/src/ops/conv2d.cc (+1 -7)
  6. mindspore/lite/src/ops/deconv2d.cc (+1 -8)
  7. mindspore/lite/src/ops/depthwise_conv2d.cc (+1 -8)
  8. mindspore/lite/src/ops/matmul.cc (+2 -7)
  9. mindspore/lite/src/ops/primitive_c.cc (+63 -20)
  10. mindspore/lite/src/ops/primitive_c.h (+12 -6)
  11. mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc (+6 -4)
  12. mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc (+3 -2)
  13. mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc (+2 -2)
  14. mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc (+2 -2)
  15. mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc (+2 -2)
  16. mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc (+2 -2)
  17. mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc (+6 -4)
  18. mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc (+6 -4)
  19. mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc (+6 -4)
  20. mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc (+1 -0)
  21. mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc (+6 -4)
  22. mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc (+6 -4)
  23. mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc (+2 -3)
  24. mindspore/lite/src/sub_graph_kernel.cc (+1 -31)
  25. mindspore/lite/src/tensor.h (+1 -0)
  26. mindspore/lite/test/models_tflite_awaretraining.cfg (+3 -3)
  27. mindspore/lite/test/run_benchmark_nets.sh (+2 -2)
  28. mindspore/lite/tools/anf_exporter/anf_exporter.cc (+1 -0)
  29. mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc (+23 -23)
  30. mindspore/lite/tools/converter/anf_transform.cc (+20 -25)
  31. mindspore/lite/tools/converter/converter.cc (+0 -1)
  32. mindspore/lite/tools/converter/converter_flags.cc (+35 -14)
  33. mindspore/lite/tools/converter/converter_flags.h (+4 -1)
  34. mindspore/lite/tools/converter/graphdef_transform.cc (+20 -44)
  35. mindspore/lite/tools/converter/graphdef_transform.h (+0 -4)
  36. mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt (+2 -0)
  37. mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc (+8 -73)
  38. mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h (+0 -2)
  39. mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.cc (+87 -0)
  40. mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h (+39 -0)
  41. mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc (+87 -0)
  42. mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.h (+36 -0)
  43. mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc (+1 -1)
  44. mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc (+1 -1)
  45. mindspore/lite/tools/converter/quantizer/aware_quantizer.cc (+50 -254)
  46. mindspore/lite/tools/converter/quantizer/aware_quantizer.h (+0 -18)
  47. mindspore/lite/tools/converter/quantizer/calc_quant_param.cc (+40 -7)
  48. mindspore/lite/tools/converter/quantizer/calc_quant_param.h (+14 -0)
  49. mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc (+13 -19)
  50. mindspore/lite/tools/converter/quantizer/post_training_quantizer.h (+2 -1)
  51. mindspore/lite/tools/converter/quantizer/quantize_util.cc (+1 -1)
  52. mindspore/lite/tools/converter/quantizer/quantize_util.h (+5 -4)
  53. mindspore/lite/tools/converter/quantizer/weight_quantizer.cc (+2 -6)
  54. mindspore/lite/tools/optimizer/common/gllo_utils.cc (+79 -57)
  55. mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc (+2 -2)

mindspore/lite/schema/model.fbs (+3 -2)

@@ -32,8 +32,9 @@ table QuantParam {
narrowRange: bool = true;
numBits: int = 8;
inited: bool = false;
var_corr: double = 1;
mean_corr: double = 0;
varCorr: double = 1;
meanCorr: double = 0;
dstDtype: int = 32;
clusters: [float];
}



mindspore/lite/src/common/utils.h (+35 -0)

@@ -27,6 +27,9 @@
#include "src/common/log_adapter.h"
#include "tools/common/option.h"
#include "include/errorcode.h"
#ifdef ENABLE_ARM64
#include "nnacl/optimized_kernel.h"
#endif

namespace mindspore {
namespace lite {
@@ -186,6 +189,38 @@ inline Option<bool> GenericParseValue(const std::string &value) {

return Option<bool>(None());
}

using Float16CastFunc = void (*)(const void *, void *, int);

class Float16CastUtil {
public:
static Float16CastUtil *GetInstance() {
static Float16CastUtil float16_cast_util;
return &float16_cast_util;
}

private:
Float16CastUtil() {
#ifdef ENABLE_ARM64
void *fp16_op_handler = Float16Module::GetInstance()->float16_op_handler_;
if (fp16_op_handler != nullptr) {
dlerror();
*(reinterpret_cast<void **>(&float16_to_float32_func_)) = dlsym(fp16_op_handler, "Float16ToFloat32_fp16_handler");
*(reinterpret_cast<void **>(&float32_to_float16_func_)) = dlsym(fp16_op_handler, "Float32ToFloat16_fp16_handler");
auto dlopen_error = dlerror();
if (dlopen_error != nullptr) {
MS_LOG(ERROR) << "load float16 cast func failed! " << dlopen_error << ".";
}
}
#endif
}
~Float16CastUtil() = default;

public:
Float16CastFunc float16_to_float32_func_ = nullptr;
Float16CastFunc float32_to_float16_func_ = nullptr;
};

} // namespace lite
} // namespace mindspore
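
For context, the singleton above only wires up its two function pointers when the ARM64 fp16 handler library is present; callers are expected to fetch a pointer and null-check it before use, as the fp16 sub-graph code later in this commit does. A minimal usage sketch (the wrapper function is illustrative, not part of the commit):

#include "src/common/utils.h"

// Hypothetical helper: cast a float16 buffer to float32 via the handler
// loaded by Float16CastUtil; returns -1 if the handler is unavailable.
int CastFp16ToFp32(const void *src, void *dst, int element_num) {
  auto cast_func = mindspore::lite::Float16CastUtil::GetInstance()->float16_to_float32_func_;
  if (cast_func == nullptr) {
    return -1;  // dlsym failed or not built with ENABLE_ARM64
  }
  cast_func(src, dst, element_num);
  return 0;
}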



mindspore/lite/src/lite_session.cc (+3 -2)

@@ -108,8 +108,9 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
QuantArg quant_arg{};
quant_arg.scale = quant_params->Get(j)->scale();
quant_arg.zeroPoint = quant_params->Get(j)->zeroPoint();
quant_arg.var_corr = quant_params->Get(j)->var_corr();
quant_arg.mean_corr = quant_params->Get(j)->mean_corr();
quant_arg.var_corr = quant_params->Get(j)->varCorr();
quant_arg.mean_corr = quant_params->Get(j)->meanCorr();
quant_arg.inited = quant_params->Get(j)->inited();
auto quant_clusters = quant_params->Get(j)->clusters();
if (quant_clusters != nullptr) {
for (size_t k = 0; k < quant_clusters->size(); k++) {


mindspore/lite/src/ops/add.cc (+1 -6)

@@ -49,12 +49,7 @@ int Add::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
return RET_ERROR;
}
}
if (GetQuantType() == schema::QuantType_AwareTraining) {
std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs);
SetOutputQuantParam(vecOutputQuantParam);
}
PopulaterQuantParam(prim, inputs);
return RET_OK;
}



mindspore/lite/src/ops/conv2d.cc (+1 -7)

@@ -277,13 +277,7 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
PopulaterConv2DSingleGroup(prim, this->primitive_, group);
}

if (GetQuantType() == schema::QuantType_AwareTraining) {
std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs);
SetInputQuantParam(vecInputQuantParam);
SetOutputQuantParam(vecOutputQuantParam);
}
PopulaterQuantParam(prim, inputs);
return RET_OK;
}



mindspore/lite/src/ops/deconv2d.cc (+1 -8)

@@ -254,14 +254,7 @@ int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i
} else if (group > 1) {
PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
}

if (GetQuantType() == schema::QuantType_AwareTraining) {
std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs);
SetInputQuantParam(vecInputQuantParam);
SetOutputQuantParam(vecOutputQuantParam);
}
PopulaterQuantParam(prim, inputs);
return RET_OK;
}
#else


mindspore/lite/src/ops/depthwise_conv2d.cc (+1 -8)

@@ -146,14 +146,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode

this->primitive_->value.type = schema::PrimitiveType_DepthwiseConv2D;
this->primitive_->value.value = attr.release();

if (GetQuantType() == schema::QuantType_AwareTraining) {
std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs);
SetInputQuantParam(vecInputQuantParam);
SetOutputQuantParam(vecOutputQuantParam);
}
PopulaterQuantParam(prim, inputs);
return RET_OK;
}



mindspore/lite/src/ops/matmul.cc (+2 -7)

@@ -61,13 +61,8 @@ int MatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
return RET_ERROR;
}
}
if (GetQuantType() == schema::QuantType_AwareTraining) {
std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs);
SetInputQuantParam(vecInputQuantParam);
SetOutputQuantParam(vecOutputQuantParam);
}

PopulaterQuantParam(prim, inputs);
return RET_OK;
}



mindspore/lite/src/ops/primitive_c.cc (+63 -20)

@@ -164,32 +164,29 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
void PrimitiveC::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) {
void PrimitiveC::CalFloatScopeByMeanAndStddev(const double &mean, const double &stdDev, float *mMin, float *mMax) {
const float qmin = 0;
const float qmax = 255;
*mMin = static_cast<float>((qmin - mean) / stdDev);
*mMax = static_cast<float>((qmax - mean) / stdDev);
}

void PrimitiveC::PopulaterQuantParam(const Primitive &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam,
const std::vector<AnfNodePtr> &inputs) {
void PrimitiveC::PopulaterQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
auto narrow_range = prim.GetAttr("narrow_range");
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
bool narrowRangeQuantParam = narrow_range != nullptr ? GetValue<bool>(narrow_range) : false;
auto num_bits = prim.GetAttr("num_bits");
int32_t numbitsRangeQuantParam = GetValue<int32_t>(num_bits);
int32_t numbitsRangeQuantParam = num_bits != nullptr ? GetValue<int32_t>(num_bits) : 8;

std::vector<schema::QuantParamT> quants;
schema::QuantParamT quantParam;
auto mean = prim.GetAttr("mean");
auto std_dev = prim.GetAttr("std_dev");
if (mean != nullptr && std_dev != nullptr) {
auto meanQuantOaram = GetValue<double>(mean);
double stddevQuantOaram = GetValue<double>(std_dev);
auto meanValue = GetValue<double>(mean);
auto stddevValue = GetValue<double>(std_dev);
float mMin = 0.0;
float mMax = 0.0;
CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax);
CalFloatScopeByMeanAndStddev(meanValue, stddevValue, &mMin, &mMax);
quantParam.min = mMin;
quantParam.max = mMax;
} else {
@@ -198,8 +195,8 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim,
if (inputMin != nullptr && inputMax != nullptr) {
auto inputMinPtr = inputMin->cast<TensorPtr>();
auto inputMaxPtr = inputMax->cast<TensorPtr>();
float *minBuf = static_cast<float *>(inputMinPtr->data_c());
float *maxBuf = static_cast<float *>(inputMaxPtr->data_c());
auto *minBuf = static_cast<float *>(inputMinPtr->data_c());
auto *maxBuf = static_cast<float *>(inputMaxPtr->data_c());
quantParam.min = *minBuf;
quantParam.max = *maxBuf;
}
@@ -207,7 +204,7 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim,
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecInputQuantParam->emplace_back(quants);
input_quant_param_.emplace_back(quants);

quants.clear();
auto filterMin = prim.GetAttr("filter_minq");
@@ -227,17 +224,25 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim,
}
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, true, numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecInputQuantParam->emplace_back(quants);
input_quant_param_.emplace_back(quants);
}

if (vecInputQuantParam->size() == kDoubleNum) {
if (input_quant_param_.size() == kDoubleNum) {
quants.clear();
quantParam.min = 0.0;
quantParam.max = 0.0;
quantParam.zeroPoint = 0;
quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(0).scale;
quantParam.scale = input_quant_param_.at(0).at(0).scale * input_quant_param_.at(1).at(0).scale;
quants.emplace_back(quantParam);
vecInputQuantParam->emplace_back(quants);
input_quant_param_.emplace_back(quants);
}

// fill input_quant_param_ by not inited quant_parm
if (input_quant_param_.size() < inputs.size()) {
quants.clear();
schema::QuantParamT tmpQuantParam;
quants.emplace_back(tmpQuantParam);
input_quant_param_.insert(input_quant_param_.end(), inputs.size() - 1 - input_quant_param_.size(), quants);
}

quants.clear();
@@ -253,7 +258,11 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim,
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecOutputQuantParam->emplace_back(quants);
output_quant_param_.emplace_back(quants);
} else {
schema::QuantParamT tmpQuantParam;
quants.emplace_back(tmpQuantParam);
output_quant_param_.emplace_back(quants);
}
}

@@ -279,14 +288,48 @@ schema::PrimitiveT *PrimitiveC::GetPrimitiveT() const { return this->primitive_;

void PrimitiveC::ClearPrimitiveT() { this->primitive_ = nullptr; }

void PrimitiveC::SetInputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param) {
void PrimitiveC::SetInputQuantParams(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param) {
this->input_quant_param_ = input_quant_param;
}

void PrimitiveC::SetOutputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param) {
void PrimitiveC::SetInputQuantParam(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param) {
MS_ASSERT(index < this->input_quant_param_.size());
this->input_quant_param_[index] = input_quant_param;
}

void PrimitiveC::SetOutputQuantParams(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param) {
this->output_quant_param_ = output_quant_param;
}

void PrimitiveC::SetOutputQuantParam(const size_t &index, const std::vector<schema::QuantParamT> &output_quant_param) {
MS_ASSERT(index < this->output_quant_param_.size());
this->output_quant_param_[index] = output_quant_param;
}

bool PrimitiveC::IsInputQuantParamsInited() {
if (this->input_quant_param_.empty()) {
return false;
}
for (auto &quant_param : this->input_quant_param_) {
if (!quant_param.front().inited) {
return false;
}
}
return true;
}

bool PrimitiveC::IsOutputQuantParamsInited() {
if (this->output_quant_param_.empty()) {
return false;
}
for (auto &quant_param : this->output_quant_param_) {
if (!quant_param.front().inited) {
return false;
}
}
return true;
}

void PrimitiveC::ClearInputOutputQuantParam() {
input_quant_param_.clear();
output_quant_param_.clear();


mindspore/lite/src/ops/primitive_c.h (+12 -6)

@@ -88,9 +88,17 @@ class PrimitiveC : public mindspore::Primitive {
}
}

void SetInputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param);
void SetInputQuantParams(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param);

void SetOutputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param);
void SetInputQuantParam(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param);

void SetOutputQuantParams(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param);

void SetOutputQuantParam(const size_t &index, const std::vector<schema::QuantParamT> &output_quant_param);

bool IsInputQuantParamsInited();

bool IsOutputQuantParamsInited();

void ClearInputOutputQuantParam();

@@ -120,10 +128,8 @@ class PrimitiveC : public mindspore::Primitive {

static std::shared_ptr<PrimitiveC> Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
const schema::QuantType &quantType);
void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam,
const std::vector<AnfNodePtr> &inputs);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax);
void PopulaterQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
void CalFloatScopeByMeanAndStddev(const double &mean, const double &stdDev, float *mMin, float *mMax);

protected:
virtual int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { return RET_ERROR; }


mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc (+6 -4)

@@ -42,7 +42,9 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T
auto *weight_tensor = inputs.at(kWeightIndex);
// data of second tensor of fc may be nullptr
auto *restore_data = weight_tensor->data_c();
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
@@ -53,7 +55,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T
auto kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (!kernel) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (!weight_tensor->GetQuantParams().empty()) {
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
@@ -65,13 +67,13 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (!weight_tensor->GetQuantParams().empty()) {
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
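
The same dequant-and-restore pattern recurs in the fp16/fp32 kernel creators below; condensed, it looks like this (names as in the diff, error handling elided, so a sketch rather than the literal code):

// Dequantize the const weight only when its quant params are real
// ("inited") and data is present; after the kernel is created and
// initialized, free the temporary fp32 copy and restore the original
// quantized buffer.
bool dequant_flag = !weight_tensor->GetQuantParams().empty() &&
                    weight_tensor->GetQuantParams().front().inited &&
                    restore_data != nullptr;
if (dequant_flag) {
  auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
  // ... on nullptr: log "dequant data is nullptr." and bail out ...
}
// ... create the kernel and call Init() ...
if (dequant_flag) {
  weight_tensor->FreeData();             // drop the fp32 copy
  weight_tensor->SetData(restore_data);  // restore the int8/int16 data
}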


mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc (+3 -2)

@@ -98,8 +98,9 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
MS_LOG(ERROR) << "QuantDTypeCast need quantization parameters which is not found.";
return RET_ERROR;
}
auto quant_arg = !out_tensors_.front()->GetQuantParams().empty() ? out_tensors_.front()->GetQuantParams().front()
: in_tensors_.front()->GetQuantParams().front();
auto quant_arg = out_tensors_.front()->GetQuantParams().front().inited
? out_tensors_.front()->GetQuantParams().front()
: in_tensors_.front()->GetQuantParams().front();
int ret = RET_OK;
if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeFloat32) {
ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,


mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc (+2 -2)

@@ -140,8 +140,8 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>

auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData();
auto dequant_flag =
(weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false;
bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {


mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc (+2 -2)

@@ -182,8 +182,8 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &

auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData();
auto dequant_flag =
(weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false;
bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {


mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc (+2 -2)

@@ -204,8 +204,8 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor

auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData();
auto dequant_flag =
(weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false;
auto dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {


mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc (+2 -2)

@@ -216,8 +216,8 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>

auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData();
auto dequant_flag =
(weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false;
auto dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {


mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc (+6 -4)

@@ -237,7 +237,9 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
auto *weight_tensor = inputs.at(kWeightIndex);
// data of second tensor of fc may be nullptr
auto *restore_data = weight_tensor->data_c();
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
@@ -250,7 +252,7 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
auto *kernel = new (std::nothrow) FullconnectionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (!weight_tensor->GetQuantParams().empty()) {
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
@@ -262,13 +264,13 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
if (!weight_tensor->GetQuantParams().empty()) {
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (!weight_tensor->GetQuantParams().empty()) {
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}


mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc (+6 -4)

@@ -251,7 +251,9 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
const mindspore::lite::PrimitiveC *primitive) {
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
@@ -263,7 +265,7 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
auto *kernel = new (std::nothrow) MatmulFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
@@ -275,13 +277,13 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}


mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc (+6 -4)

@@ -284,7 +284,9 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &

auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData();
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
@@ -303,7 +305,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &

if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
@@ -315,14 +317,14 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
return nullptr;
}

if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}


mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc (+1 -0)

@@ -124,6 +124,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>

auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData();

if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {


mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc (+6 -4)

@@ -234,7 +234,9 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData();
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
@@ -255,7 +257,7 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>

if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
@@ -267,14 +269,14 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
return nullptr;
}

if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}


mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc (+6 -4)

@@ -196,7 +196,9 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor
MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D);
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData();
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
@@ -209,7 +211,7 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor
new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
@@ -221,13 +223,13 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}


mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc (+2 -3)

@@ -208,9 +208,8 @@ kernel::LiteKernel *CpuMatmulInt8KernelCreator(const std::vector<lite::Tensor *>

auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
auto is_const_quant_weight =
(restore_data != nullptr) &&
((weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16));
bool is_const_quant_weight = !weight_tensor->GetQuantParams().empty() &&
weight_tensor->GetQuantParams().front().inited && restore_data != nullptr;
if (is_const_quant_weight) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {


mindspore/lite/src/sub_graph_kernel.cc (+1 -31)

@@ -25,36 +25,6 @@ using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INFER_ERR;
using mindspore::lite::RET_INFER_INVALID;
using mindspore::lite::RET_OK;
using Float16CastFunc = void (*)(const void *, void *, int);

class Float16CastUtil {
public:
static Float16CastUtil *GetInstance() {
static Float16CastUtil float16_cast_util;
return &float16_cast_util;
}

private:
Float16CastUtil() {
#ifdef ENABLE_ARM64
void *fp16_op_handler = Float16Module::GetInstance()->float16_op_handler_;
if (fp16_op_handler != nullptr) {
dlerror();
*(reinterpret_cast<void **>(&float16_to_float32_func_)) = dlsym(fp16_op_handler, "Float16ToFloat32_fp16_handler");
*(reinterpret_cast<void **>(&float32_to_float16_func_)) = dlsym(fp16_op_handler, "Float32ToFloat16_fp16_handler");
auto dlopen_error = dlerror();
if (dlopen_error != nullptr) {
MS_LOG(ERROR) << "load float16 cast func failed! " << dlopen_error << ".";
}
}
#endif
}
~Float16CastUtil() = default;

public:
Float16CastFunc float16_to_float32_func_ = nullptr;
Float16CastFunc float32_to_float16_func_ = nullptr;
};

int SubGraphKernel::Prepare() {
for (auto node : this->nodes_) {
@@ -208,7 +178,7 @@ int CpuFp16SubGraph::PreProcess() {
}

int CpuFp16SubGraph::PostProcess() {
auto fp16_to_fp32_cast_func = Float16CastUtil::GetInstance()->float16_to_float32_func_;
auto fp16_to_fp32_cast_func = lite::Float16CastUtil::GetInstance()->float16_to_float32_func_;
if (fp16_to_fp32_cast_func == nullptr) {
MS_LOG(ERROR) << "Can not find cast fp16 to fp32 func";
return RET_ERROR;


mindspore/lite/src/tensor.h (+1 -0)

@@ -35,6 +35,7 @@ struct QuantArg {
int32_t zeroPoint;
double var_corr{1};
double mean_corr{0};
bool inited;
std::vector<float> clusters{};
};
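
Note that, unlike var_corr and mean_corr, the new inited field has no in-class initializer, so construction sites are expected to set it explicitly; the lite_session.cc hunk above does exactly that. Condensed (where param stands for quant_params->Get(j) in ConvertTensors, and the accessor names follow the renamed schema):

QuantArg quant_arg{};                     // value-init leaves inited == false
quant_arg.scale = param->scale();
quant_arg.zeroPoint = param->zeroPoint();
quant_arg.var_corr = param->varCorr();    // renamed flatbuffer accessors
quant_arg.mean_corr = param->meanCorr();
quant_arg.inited = param->inited();       // marks the param as real, not placeholder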



mindspore/lite/test/models_tflite_awaretraining.cfg (+3 -3)

@@ -33,6 +33,6 @@ lite-model_on_device_vision_classifier_landmarks_classifier_oceania_antarctica_V
lite-model_on_device_vision_classifier_landmarks_classifier_europe_V1_1.tflite
lite-model_on_device_vision_classifier_landmarks_classifier_south_america_V1_1.tflite
vision_classifier_fungi_mobile_V1_1_default_1.tflite
detect.tflite
ssd_mobilenet_v1_1_default_1.tflite
object_detection_mobile_object_localizer_v1_1_default_1.tflite
#detect.tflite
#ssd_mobilenet_v1_1_default_1.tflite
#object_detection_mobile_object_localizer_v1_1_default_1.tflite

mindspore/lite/test/run_benchmark_nets.sh (+2 -2)

@@ -121,8 +121,8 @@ function Run_Converter() {
continue
fi
echo ${model_name} >> "${run_converter_log_file}"
echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}' --quantType=AwareTraining' >> "${run_converter_log_file}"
./converter_lite --fmk=TFLITE --modelFile=${models_path}/${model_name} --outputFile=${ms_models_path}/${model_name} --quantType=AwareTraining
echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}' --inputDataType=FLOAT --outputDataType=FLOAT' >> "${run_converter_log_file}"
./converter_lite --fmk=TFLITE --modelFile=${models_path}/${model_name} --outputFile=${ms_models_path}/${model_name} --inputDataType=FLOAT --outputDataType=FLOAT
if [ $? = 0 ]; then
converter_result='converter aware_training '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file}
else


mindspore/lite/tools/anf_exporter/anf_exporter.cc (+1 -0)

@@ -544,6 +544,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
} else {
auto ms_tensor = new schema::TensorT();
ms_tensor->nodeType = schema::NodeType_CNode;
ms_tensor->dataType = TypeId::kNumberTypeFloat32;
fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size());
node_id_map_[cnode_name] = meta_graphT->allTensors.size();
meta_graphT->allTensors.emplace_back(ms_tensor);


mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc (+23 -23)

@@ -73,30 +73,30 @@ ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr<s
auto primitiveCValue = PrimitiveC::Create(cNode->primitive.release());
cNode->primitive = nullptr;
// add quant parameter
if (cNode->quantType != schema::QuantType_PostTraining && cNode->quantType != schema::QuantType_WeightQuant) {
primitiveCValue->SetQuantType(cNode->quantType);
for (int index : cNode->inputIndex) {
if (!meta_graph_->allTensors[index]->quantParams.empty()) {
std::vector<schema::QuantParamT> quant_params(meta_graph_->allTensors[index]->quantParams.size());
std::transform(
meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(),
quant_params.begin(),
[](std::unique_ptr<schema::QuantParamT> &quant_param) -> schema::QuantParamT { return *quant_param; });
primitiveCValue->AddInputQuantParam(quant_params);
} else {
std::vector<schema::QuantParamT> empty_quant_params;
primitiveCValue->AddInputQuantParam(empty_quant_params);
}
for (auto index : cNode->inputIndex) {
if (!meta_graph_->allTensors[index]->quantParams.empty()) {
std::vector<schema::QuantParamT> quant_params(meta_graph_->allTensors[index]->quantParams.size());
std::transform(
meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(),
quant_params.begin(),
[](std::unique_ptr<schema::QuantParamT> &quant_param) -> schema::QuantParamT { return *quant_param; });
primitiveCValue->AddInputQuantParam(quant_params);
} else {
std::vector<schema::QuantParamT> notinited_quant_params(1);
primitiveCValue->AddInputQuantParam(notinited_quant_params);
}
for (int index : cNode->outputIndex) {
if (!meta_graph_->allTensors[index]->quantParams.empty()) {
std::vector<schema::QuantParamT> quant_params(meta_graph_->allTensors[index]->quantParams.size());
std::transform(
meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(),
quant_params.begin(),
[](std::unique_ptr<schema::QuantParamT> &quant_param) -> schema::QuantParamT { return *quant_param; });
primitiveCValue->AddOutputQuantParam(quant_params);
}
}
for (auto index : cNode->outputIndex) {
if (!meta_graph_->allTensors[index]->quantParams.empty()) {
std::vector<schema::QuantParamT> quant_params(meta_graph_->allTensors[index]->quantParams.size());
std::transform(
meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(),
quant_params.begin(),
[](std::unique_ptr<schema::QuantParamT> &quant_param) -> schema::QuantParamT { return *quant_param; });
primitiveCValue->AddOutputQuantParam(quant_params);
} else {
std::vector<schema::QuantParamT> notinited_quant_params(1);
primitiveCValue->AddOutputQuantParam(notinited_quant_params);
}
}
auto value_node = NewValueNode(std::shared_ptr<PrimitiveC>(primitiveCValue));


mindspore/lite/tools/converter/anf_transform.cc (+20 -25)

@@ -52,16 +52,13 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);

// for now - trainning is not supporting fuse operations
if (config != nullptr && config->trainModel == false) {
if (config != nullptr && !config->trainModel) {
// remove quantdtype when awaretraining
if (config->fmk == lite::converter::FmkType_ONNX) {
auto remove_identity_pass = std::make_shared<opt::RemoveIdentityOpPass>();
remove_identity_pass->SetFmkType(config->fmk);
pm->AddPass(remove_identity_pass);
}
if (config->quantType == QuantType_AwareTraining) {
pm->AddPass(std::make_shared<opt::QuantDtypeCastFusion>());
}
pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>());
pm->AddPass(std::make_shared<opt::ConvScaleFusion>());
@@ -101,27 +98,25 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
return nullptr;
}
// quant
if (config != nullptr) {
if (config->quantType == schema::QuantType_PostTraining) {
this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, 8);
if (mQuantizer == nullptr) {
MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return nullptr;
}
} else if (config->quantType == schema::QuantType_WeightQuant) {
if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) {
MS_LOG(ERROR) << "weight quant input param error";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->quantWeightSize,
config->quantWeightChannel, config->bitNum);
if (mQuantizer == nullptr) {
MS_LOG(ERROR) << "New WeightQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return nullptr;
}
if (config->quantType == schema::QuantType_PostTraining) {
this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, 8);
if (mQuantizer == nullptr) {
MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return nullptr;
}
} else if (config->quantType == schema::QuantType_WeightQuant) {
if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) {
MS_LOG(ERROR) << "weight quant input param error";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->quantWeightSize,
config->quantWeightChannel, config->bitNum);
if (mQuantizer == nullptr) {
MS_LOG(ERROR) << "New WeightQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return nullptr;
}
}
if (mQuantizer != nullptr) {


mindspore/lite/tools/converter/converter.cc (+0 -1)

@@ -93,7 +93,6 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
}
// transform
transform->SetGraphDef(meta_graph);
transform->CreateQuantizer(flag);
auto status = transform->Transform(*flag);
if (status != RET_OK) {
MS_LOG(ERROR) << "Transform meta graph failed " << status;


mindspore/lite/tools/converter/converter_flags.cc (+35 -14)

@@ -29,9 +29,14 @@ Flags::Flags() {
AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
AddFlag(&Flags::weightFile, "weightFile", "Input model weight file. Needed when fmk is CAFFE. CAFFE: *.caffemodel",
"");
AddFlag(&Flags::inferenceTypeIn, "inferenceType", "Data type of input and output tensors. FLOAT | INT8 | UINT8",
"FLOAT");
AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | PostTraining | WeightQuant", "");
AddFlag(&Flags::inputDataTypeIn, "inputDataType",
"Data type of input tensors, default is same with the type defined in model. FLOAT | INT8 | UINT8 | DEFAULT",
"DEFAULT");
AddFlag(&Flags::outputDataTypeIn, "outputDataType",
"Data type of output and output tensors, default is same with the type defined in model. FLOAT | INT8 | "
"UINT8 | DEFAULT",
"DEFAULT");
AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. PostTraining | WeightQuant", "");
AddFlag(&Flags::bitNum, "bitNum", "Weight quantization bitNum", "8");
AddFlag(&Flags::quantWeightSize, "quantWeightSize", "Weight quantization size threshold", "0");
AddFlag(&Flags::quantWeightChannel, "quantWeightChannel", "Channel threshold for weight quantization", "16");
@@ -78,15 +83,32 @@ int Flags::Init(int argc, const char **argv) {
return RET_INPUT_PARAM_INVALID;
}

if (this->inferenceTypeIn == "FLOAT") {
this->inferenceType = TypeId::kNumberTypeFloat;
} else if (this->inferenceTypeIn == "INT8") {
this->inferenceType = TypeId::kNumberTypeInt8;
} else if (this->inferenceTypeIn == "UINT8") {
this->inferenceType = TypeId::kNumberTypeUInt8;
if (this->inputDataTypeIn == "FLOAT") {
this->inputDataType = TypeId::kNumberTypeFloat;
} else if (this->inputDataTypeIn == "INT8") {
this->inputDataType = TypeId::kNumberTypeInt8;
} else if (this->inputDataTypeIn == "UINT8") {
this->inputDataType = TypeId::kNumberTypeUInt8;
} else if (this->inputDataTypeIn == "DEFAULT") {
this->inputDataType = TypeId::kTypeUnknown;
} else {
std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8 | UINT8",
this->inferenceTypeIn.c_str();
std::cerr << "INPUT INVALID: inputDataType is invalid: %s, supported inputDataType: FLOAT | INT8 | UINT8 | DEFAULT",
this->inputDataTypeIn.c_str();
return RET_INPUT_PARAM_INVALID;
}

if (this->outputDataTypeIn == "FLOAT") {
this->outputDataType = TypeId::kNumberTypeFloat;
} else if (this->outputDataTypeIn == "INT8") {
this->outputDataType = TypeId::kNumberTypeInt8;
} else if (this->outputDataTypeIn == "UINT8") {
this->outputDataType = TypeId::kNumberTypeUInt8;
} else if (this->outputDataTypeIn == "DEFAULT") {
this->outputDataType = TypeId::kTypeUnknown;
} else {
std::cerr
<< "INPUT INVALID: outputDataType is invalid: %s, supported outputDataType: FLOAT | INT8 | UINT8 | DEFAULT",
this->outputDataTypeIn.c_str();
return RET_INPUT_PARAM_INVALID;
}

@@ -107,9 +129,8 @@ int Flags::Init(int argc, const char **argv) {
std::cerr << "INPUT ILLEGAL: weightFile is not a valid flag";
return RET_INPUT_PARAM_INVALID;
}
if (this->quantTypeIn == "AwareTraining") {
this->quantType = QuantType_AwareTraining;
} else if (this->quantTypeIn == "WeightQuant") {

if (this->quantTypeIn == "WeightQuant") {
this->quantType = QuantType_WeightQuant;
} else if (this->quantTypeIn == "PostTraining") {
this->quantType = QuantType_PostTraining;


mindspore/lite/tools/converter/converter_flags.h (+4 -1)

@@ -53,8 +53,11 @@ class Flags : public virtual mindspore::lite::FlagParser {
std::string quantTypeIn;
QuantType quantType;
std::string inferenceTypeIn;
std::string inputDataTypeIn;
std::string outputDataTypeIn;
// used for parse aware trainning
TypeId inferenceType = TypeId::kNumberTypeFloat;
TypeId inputDataType;
TypeId outputDataType;
// used for post-trainning-weight
std::string quantWeightSize;
std::string bitNum;


mindspore/lite/tools/converter/graphdef_transform.cc (+20 -44)

@@ -34,6 +34,8 @@
#include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
#include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h"
#include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h"
#include "tools/converter/quantizer/aware_quantizer.h"

using std::string;
@@ -44,20 +46,6 @@ GraphDefTransform::~GraphDefTransform() = default;

void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _dstDef; }

void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) {
auto type = flags->quantType;
switch (type) {
case QuantType::QuantType_AwareTraining: {
MS_LOG(INFO) << "create AwareTrainingQuantizer!";
fbQuantizer = std::make_unique<quant::AwareQuantizer>(graphDefT, flags->inferenceType);
break;
}
default:
MS_LOG(INFO) << "will support quantizer type " << flags->quantTypeIn << " in the future";
break;
}
}

int GraphDefTransform::Transform(const converter::Flags &ctx) {
STATUS status;
{
@@ -84,26 +72,13 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {

// generate and infer quant parameters
{
if (fbQuantizer != nullptr) {
Optimizer topologicalOptimizer;
topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
status = topologicalOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
if (ctx.quantType == QuantType_AwareTraining) {
status = fbQuantizer->GenerateQuantParam();
if (status != RET_OK) {
MS_LOG(ERROR) << "GenerateQuantParam failed";
return status;
}
status = fbQuantizer->DetermineNodeQuantType();
if (status != RET_OK) {
MS_LOG(ERROR) << "DetermineNodeQuant failed";
return status;
}
}
Optimizer inferQuantParamPass;
inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass());
inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass());
status = inferQuantParamPass.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
}

@@ -146,12 +121,11 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
}
{
Optimizer fusionOptimizer;
fusionOptimizer.AddPass(new (std::nothrow) FormatTransPermuteFusionPass());
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
status = fusionOptimizer.Run(graphDefT);
Optimizer inferQuantParamOtimizer;
inferQuantParamOtimizer.AddPass(new (std::nothrow) InferQuantParamPass());
status = inferQuantParamOtimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
MS_LOG(ERROR) << "Run tensorQuantOptimizer graphPasses Failed";
return status;
}
}
@@ -168,8 +142,10 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}

// do quantization
if (fbQuantizer != nullptr) {
status = fbQuantizer->DoQuantize();
{
Optimizer fusionOptimizer;
fusionOptimizer.AddPass(new (std::nothrow) TensorQuantPass());
status = fusionOptimizer.Run(graphDefT);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantize failed!";
return status;
@@ -177,11 +153,11 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}

// insert quantNode and deQuantNode
if (ctx.quantType == QuantType_AwareTraining) {
{
Optimizer quantNodeOptimizer;
auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
dTypeTransPass->SetInputDataDType(ctx.inferenceType);
dTypeTransPass->SetOutputDataDType(ctx.inferenceType);
dTypeTransPass->SetInputDataDType(ctx.inputDataType);
dTypeTransPass->SetOutputDataDType(ctx.outputDataType);
quantNodeOptimizer.AddPass(dTypeTransPass);
quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());


mindspore/lite/tools/converter/graphdef_transform.h (+0 -4)

@@ -37,14 +37,10 @@ class GraphDefTransform {
virtual int Transform(const converter::Flags &ctx);
void SetGraphDef(schema::MetaGraphT *dstDef);
inline schema::MetaGraphT *GetOutput() { return graphDefT; }
void CreateQuantizer(const converter::Flags *flags);

protected:
schema::MetaGraphT *graphDefT = nullptr;
Optimizer *optimizer = nullptr;

std::unique_ptr<quant::Quantizer> mQuantizer;
std::unique_ptr<quant::FbQuantizer> fbQuantizer;
};
} // namespace lite
} // namespace mindspore


mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt (+2 -0)

@@ -10,6 +10,8 @@ file(GLOB GRAPH_PASS
${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_convert_scale_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/trans_format_remove_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/infershape_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/tensor_quant_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/infer_quant_param_pass.cc
)
set_property(SOURCE ${GRAPH_PASS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
add_library(graph_pass_mid OBJECT ${GRAPH_PASS})

mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc (+8 -73)

@@ -27,9 +27,6 @@ namespace lite {
#define kMinInputNum 1
#define kOutputNum 1

static const std::set<schema::PrimitiveType> NoNeedDtypeTransList = {
PrimitiveType_QuantDTypeCast, PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw};

STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);

@@ -44,12 +41,6 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {
MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status;
return status;
}

status = DoNodeInoutDTypeTrans(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status;
return status;
}
return RET_OK;
}

@@ -57,7 +48,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
auto &graphInIdxes = graph->inputIndex;

if (this->inputDataDType == TypeId::kNumberTypeInt8) {
if (this->inputDataDType == TypeId::kNumberTypeInt8 || this->inputDataDType == TypeId::kTypeUnknown) {
return RET_OK;
}
if (this->inputDataDType != TypeId::kNumberTypeFloat && this->inputDataDType != TypeId::kNumberTypeUInt8) {
@@ -68,7 +59,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
for (auto graphInIdx : graphInIdxes) {
MS_ASSERT(graphInIdx < graph->allTensors.size());
auto &tensor = graph->allTensors.at(graphInIdx);
if (tensor->dims.size() != kNHWCDimNumber || tensor->dataType != kNumberTypeInt8) {
if (tensor->dataType != kNumberTypeInt8 || tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
continue;
}

@@ -98,7 +89,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {

STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
if (outputDataDType == TypeId::kNumberTypeInt8) {
if (outputDataDType == TypeId::kNumberTypeInt8 || outputDataDType == TypeId::kTypeUnknown) {
return RET_OK;
}
if (this->outputDataDType != TypeId::kNumberTypeFloat && this->outputDataDType != TypeId::kNumberTypeUInt8) {
@@ -107,6 +98,11 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
}
auto &graphOutIdxes = graph->outputIndex;
for (auto graphOutIdx : graphOutIdxes) {
MS_ASSERT(graphOutIdx < graph->allTensors.size());
auto &tensor = graph->allTensors.at(graphOutIdx);
if (tensor->dataType != kNumberTypeInt8 || tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
continue;
}
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto nodeName = (*iter)->name;
MS_ASSERT(node != nullptr);
@@ -131,67 +127,6 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
return RET_OK;
}

STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
// insert transNode before and after existNode
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
if (IsContain(GetInt8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) {
continue;
}
auto iterType = GetCNodeTType(**iter);
if (NoNeedDtypeTransList.find(iterType) != NoNeedDtypeTransList.end()) {
continue;
}
bool needInsertPost = true;
if (GetCNodeTType(**iter) == PrimitiveType_Shape) {
needInsertPost = false;
}
auto nodeName = (*iter)->name;
if ((*iter)->inputIndex.size() < kMinInputNum) {
MS_LOG(ERROR) << "Op " << nodeName.c_str() << " should have " << kMinInputNum << " input tensor at least";
return RET_ERROR;
}
STATUS status;
// insert pre
for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) {
MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i));
auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i));
if (preTensor->dataType == TypeId::kNumberTypeInt || preTensor->dataType == TypeId::kNumberTypeInt32) {
continue;
}
auto &graphInIdxes = graph->inputIndex;
if (!preTensor->data.empty() && !IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) {
continue;
}
if ((preTensor->dataType != TypeId::kNumberTypeInt8) && (IsContain(graphInIdxes, (*iter)->inputIndex.at(i)))) {
continue;
}
iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed";
return RET_ERROR;
}
}

if (needInsertPost) {
for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) {
auto &postTensor = graph->allTensors.at((*iter)->outputIndex.at(i));
if (postTensor->dataType == TypeId::kNumberTypeInt || postTensor->dataType == TypeId::kNumberTypeInt32) {
continue;
}
iter = InsertDTypeTransNode(graph, iter, kAfter, i, kFP32ToInt8, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed";
return RET_ERROR;
}
}
}
(*iter)->quantType = QuantType_QUANT_NONE;
}

return RET_OK;
}

NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place,
size_t inoutIdx, DTypeTransNodeType nodeType, STATUS *errorCode) {
MS_ASSERT((*existNodeIter) != nullptr);
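For reference, a standalone sketch of what DoNodeInoutDTypeTrans above amounts to, using simplified stand-in types rather than the real schema classes (the actual pass also rewires tensor indices, omitted here): a node without an int8 kernel gets an Int8ToFP32 cast before each eligible input and an FP32ToInt8 cast after each output, so the rest of the graph can stay int8.

#include <string>
#include <vector>

// Simplified stand-in for schema::CNodeT (illustrative only).
struct Node {
  std::string type;
  std::vector<int> inputs;
  std::vector<int> outputs;
};

// Sketch: wrap a float-only node with dtype-cast nodes.
std::vector<Node> WrapWithDTypeCasts(const Node &node) {
  std::vector<Node> rewritten;
  for (int in : node.inputs) {
    rewritten.push_back({"QuantDTypeCast_Int8ToFP32", {in}, {in}});
  }
  rewritten.push_back(node);
  for (int out : node.outputs) {
    rewritten.push_back({"QuantDTypeCast_FP32ToInt8", {out}, {out}});
  }
  return rewritten;
}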


+0 -2  mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h

@@ -45,8 +45,6 @@ class DTypeTransPass : public GraphPass {

STATUS DoModelOutputDTypeTrans(schema::MetaGraphT *graph);

STATUS DoNodeInoutDTypeTrans(schema::MetaGraphT *graph);

NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx,
DTypeTransNodeType nodeType, STATUS *errorCode);



+87 -0  mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.cc

@@ -0,0 +1,87 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include "src/common/utils.h"
#include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h"
#include "tools/converter/quantizer/calc_quant_param.h"
#include "tools/common/node_util.h"
#include "tools/common/converter_op_utils.h"

namespace mindspore::lite {
STATUS InferQuantParamPass::Run(schema::MetaGraphT *graph) {
auto *quantParamRegister = QuantParamCalcRegister::GetInstance();

for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto &node = *iter;
MS_ASSERT(node != nullptr);
if (node->quantType == schema::QuantType_WeightQuant) {
continue;
}
if (GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMax ||
GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) {
MS_ASSERT(false);
}
auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
if (quantParamCalcer == nullptr) {
MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str()
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
node->quantType = static_cast<schema::QuantType>(schema::QuantType_QUANT_NONE);
} else {
auto status = quantParamCalcer->Calc(graph, *node);
if (status != RET_OK) {
MS_LOG(WARNING) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
node->quantType = schema::QuantType_QUANT_NONE;
} else {
DetermineNodeQuantType(*graph, node.get());
}
}
}
return RET_OK;
}

void InferQuantParamPass::DetermineNodeQuantType(const schema::MetaGraphT &graph, schema::CNodeT *cnode) {
MS_ASSERT(cnode != nullptr);
bool canQuant = true;
for (auto &inputTensorIdx : cnode->inputIndex) {
MS_ASSERT(graph.allTensors.size() > inputTensorIdx);
auto &inTensor = graph.allTensors.at(inputTensorIdx);
MS_ASSERT(inTensor != nullptr);
if (inTensor->quantParams.empty() || inTensor->quantParams.front() == nullptr ||
!inTensor->quantParams.front()->inited) {
canQuant = false;
break;
}
}

for (auto &outTensorIdx : cnode->outputIndex) {
MS_ASSERT(graph.allTensors.size() > outTensorIdx);
auto &outTensor = graph.allTensors.at(outTensorIdx);
MS_ASSERT(outTensor != nullptr);
if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr ||
!outTensor->quantParams.front()->inited) {
canQuant = false;
break;
}
}

if (canQuant && IsContain(GetInt8OpList(), GetCNodeTType(*cnode))) {
cnode->quantType = schema::QuantType_AwareTraining;
} else {
cnode->quantType = schema::QuantType_QUANT_NONE;
}
}
} // namespace mindspore::lite
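The rule InferQuantParamPass applies is: a node is marked QuantType_AwareTraining only when its type is in the int8 op list and every input and output tensor carries an inited quant param. A minimal standalone sketch of that check, with stand-in types (not the schema classes):

#include <algorithm>
#include <vector>

// Stand-ins: a tensor is quant-ready when its first quant param
// exists and has been inited.
struct QuantParam { bool inited = false; };
struct Tensor { std::vector<QuantParam> quantParams; };

// Sketch of the DetermineNodeQuantType criterion.
bool CanQuant(const std::vector<Tensor> &inputs, const std::vector<Tensor> &outputs) {
  auto ready = [](const Tensor &t) {
    return !t.quantParams.empty() && t.quantParams.front().inited;
  };
  return std::all_of(inputs.begin(), inputs.end(), ready) &&
         std::all_of(outputs.begin(), outputs.end(), ready);
}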

+39 -0  mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h

@@ -0,0 +1,39 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_INFER_QUANT_PARAM_PASS_H
#define LITE_INFER_QUANT_PARAM_PASS_H

#include <memory>
#include "tools/converter/optimizer.h"
#include "tools/common/graph_util.h"

namespace mindspore {
namespace lite {
class InferQuantParamPass : public GraphPass {
public:
InferQuantParamPass() {}

~InferQuantParamPass() override = default;

STATUS Run(schema::MetaGraphT *graph) override;

private:
void DetermineNodeQuantType(const schema::MetaGraphT &graph, schema::CNodeT *cnode);
};
} // namespace lite
} // namespace mindspore

#endif // LITE_INFER_QUANT_PARAM_PASS_H

+87 -0  mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc

@@ -0,0 +1,87 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <vector>
#include <cmath>
#include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/common/tensor_util.h"

namespace mindspore::lite {
STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
for (auto &tensor : graph->allTensors) {
if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited || tensor->data.empty()) {
continue;
}
if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat &&
tensor->dataType != TypeId::kNumberTypeUInt8) {
continue;
}
// perlayer
if (tensor->quantParams.size() == 1) {
auto &quantParam = tensor->quantParams.front();
size_t wShapeSize = GetShapeSize(*(tensor.get()));
void *oriWeightData = tensor->data.data();
if (quantParam->dstDtype == TypeId::kNumberTypeInt8) {
std::vector<int8_t> qDatas(wShapeSize);
auto weightQuantParam = GetTensorQuantParam(tensor);
if (tensor->dataType == TypeId::kNumberTypeFloat ||
tensor->dataType == TypeId::kNumberTypeFloat32) {  // normal aware-training quant
auto *weightData = static_cast<float *>(oriWeightData);
for (size_t j = 0; j < wShapeSize; j++) {
qDatas[j] = quant::QuantizeData<int8_t>(weightData[j], weightQuantParam.get());
}
} else {  // tflite aware-training quant
auto *weightData = static_cast<uint8_t *>(oriWeightData);
for (size_t j = 0; j < wShapeSize; j++) {
qDatas[j] = (int32_t)weightData[j] - 128;
}
weightQuantParam->zeroPoint -= 128;
tensor->quantParams.clear();
tensor->quantParams.emplace_back(weightQuantParam.release());
}
tensor->dataType = TypeId::kNumberTypeInt8;
tensor->data.resize(wShapeSize * sizeof(int8_t));  // shrink to the int8 payload size
::memcpy(tensor->data.data(), qDatas.data(), wShapeSize);
} else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) {
// quant bias data
auto bShapeSize = GetShapeSize(*(tensor.get()));
std::unique_ptr<int32_t[]> qDatas(new (std::nothrow) int32_t[bShapeSize]);
if (qDatas == nullptr) {
MS_LOG(ERROR) << "new qDatas failed";
return RET_ERROR;
}
void *biasData = tensor->data.data();
auto *rawDatas = static_cast<float *>(biasData);
for (size_t i = 0; i < bShapeSize; ++i) {
qDatas[i] = (int32_t)std::round(rawDatas[i] / quantParam->scale);
}
tensor->dataType = TypeId::kNumberTypeInt32;
tensor->data.clear();
tensor->data.resize(bShapeSize * sizeof(int32_t));
auto ret =
memcpy_s(tensor->data.data(), bShapeSize * sizeof(int32_t), qDatas.get(), bShapeSize * sizeof(int32_t));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy_s failed: " << ret;
return RET_ERROR;
}
}
} else {  // perchannel: multiple quant params per tensor, not handled here yet
}
}
return RET_OK;
}

} // namespace mindspore::lite
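The uint8-to-int8 rebase in the tflite branch above preserves the represented real value: with real = scale * (q - zeroPoint), subtracting 128 from both the stored value and the zero point cancels out. A small standalone check (helper name is illustrative):

#include <cassert>
#include <cstdint>

// Affine dequantization as used by the pass.
static float Dequantize(int32_t q, float scale, int32_t zeroPoint) {
  return scale * static_cast<float>(q - zeroPoint);
}

int main() {
  const float scale = 0.05f;
  const int32_t zpU8 = 140;         // uint8 zero point
  const uint8_t qU8 = 200;          // stored uint8 value
  const int8_t qI8 = static_cast<int8_t>(static_cast<int32_t>(qU8) - 128);
  const int32_t zpI8 = zpU8 - 128;  // zeroPoint -= 128, as in the pass
  assert(Dequantize(qU8, scale, zpU8) == Dequantize(qI8, scale, zpI8));
  return 0;
}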

+36 -0  mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.h

@@ -0,0 +1,36 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_TENSOR_QUANT_PASS_H
#define LITE_TENSOR_QUANT_PASS_H

#include <memory>
#include "tools/converter/optimizer.h"
#include "tools/common/graph_util.h"

namespace mindspore {
namespace lite {
class TensorQuantPass : public GraphPass {
public:
TensorQuantPass() {}

~TensorQuantPass() override = default;

STATUS Run(schema::MetaGraphT *graph) override;
};
} // namespace lite
} // namespace mindspore

#endif // LITE_TENSOR_QUANT_PASS_H

+1 -1  mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc

@@ -52,7 +52,7 @@ STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info,
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->srcT = GetTfliteDataType(in_tensor->type);
attr->srcT = kNumberTypeInt8;
attr->dstT = GetTfliteDataType(out_tensor->type);
op->primitive->value.value = attr.release();
op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast;


+1 -1  mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc

@@ -52,7 +52,7 @@ STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::u
return RET_NULL_PTR;
}
attr->srcT = GetTfliteDataType(in_tensor->type);
attr->dstT = GetTfliteDataType(out_tensor->type);
attr->dstT = kNumberTypeInt8;
op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast;
op->primitive->value.value = attr.release();
} else {


+50 -254  mindspore/lite/tools/converter/quantizer/aware_quantizer.cc

@@ -35,29 +35,10 @@ using std::string;
using std::vector;

namespace mindspore::lite::quant {
const std::array<schema::PrimitiveType, 7> AwareQuantizer::propagatedOps = {
{schema::PrimitiveType_Concat, schema::PrimitiveType_Resize, schema::PrimitiveType_Reshape,
schema::PrimitiveType_Squeeze, schema::PrimitiveType_RealDiv, schema::PrimitiveType_Activation,
schema::PrimitiveType_DetectionPostProcess}};

AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferType) : FbQuantizer(graph) {}

STATUS AwareQuantizer::RemoveFakeQuant() { return RET_OK; }

STATUS AwareQuantizer::GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph) {
MS_ASSERT(subGraph != nullptr);
for (const auto &tensor : subGraph->allTensors) {
if (!tensor->quantParams.empty()) {
continue;
}
std::unique_ptr<schema::QuantParamT> defaultQuantParam(new QuantParamT());
tensor->quantParams.emplace_back(std::move(defaultQuantParam));
}
return RET_OK;
}

STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { return RET_OK; }

STATUS AwareQuantizer::GenerateQuantParam() {
auto *quantParamRegister = QuantParamCalcRegister::GetInstance();

@@ -70,13 +51,13 @@ STATUS AwareQuantizer::GenerateQuantParam() {
}
auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
if (quantParamCalcer == nullptr) {
MS_LOG(INFO) << "Can not find QuantParamCalcer for " << node->name.c_str()
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str()
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE);
} else {
auto status = quantParamCalcer->Calc(graph, *node);
if (status != RET_OK) {
MS_LOG(INFO) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
MS_LOG(WARNING) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
node->quantType = schema::QuantType_QUANT_NONE;
} else {
node->quantType = schema::QuantType_AwareTraining;
@@ -87,250 +68,65 @@ STATUS AwareQuantizer::GenerateQuantParam() {
}

STATUS AwareQuantizer::DoQuantize() {
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto &node = *iter;
if (!IsContain(GetInt8OpList(), GetCNodeTType(*node))) {
for (auto &tensor : graph->allTensors) {
if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited || tensor->data.empty()) {
continue;
}
if (node->quantType != schema::QuantType_AwareTraining) {
if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat &&
tensor->dataType != TypeId::kNumberTypeUInt8) {
continue;
}
STATUS status;
if (GetCNodeTType(*node) == schema::PrimitiveType_Conv2D ||
GetCNodeTType(*node) == schema::PrimitiveType_DepthwiseConv2D ||
GetCNodeTType(*node) == schema::PrimitiveType_DeConv2D ||
GetCNodeTType(*node) == schema::PrimitiveType_FullConnection ||
GetCNodeTType(*node) == schema::PrimitiveType_MatMul) {
auto inputIndexes = node->inputIndex;
if (inputIndexes.size() < 2) {
MS_LOG(ERROR) << node->name.c_str() << " node input has invalid inputs tensor count";
return RET_ERROR;
}
// quant weight
auto &weightTensor = graph->allTensors.at(node->inputIndex.at(1));
if (!weightTensor->quantParams.empty() && weightTensor->quantParams.at(0)->inited) {
status = QuantConvWeight(graph, node.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantConvWeight failed!";
return RET_ERROR;
}
}
// quant bias
if (inputIndexes.size() == 3) {
auto &biasTensor = graph->allTensors.at(node->inputIndex.at(2));
if (!biasTensor->quantParams.empty() && biasTensor->quantParams.at(0)->inited) {
status = QuantConvBias(graph, node.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantConvBias failed!";
return RET_ERROR;
// perlayer
if (tensor->quantParams.size() == 1) {
auto &quantParam = tensor->quantParams.front();
size_t wShapeSize = GetShapeSize(*(tensor.get()));
void *oriWeightData = tensor->data.data();
if (quantParam->dstDtype == TypeId::kNumberTypeInt8) {
vector<int8_t> qDatas(wShapeSize);
auto weightQuantParam = GetTensorQuantParam(tensor);
if (tensor->dataType == TypeId::kNumberTypeFloat ||
tensor->dataType == TypeId::kNumberTypeFloat32) {  // normal aware-training quant
auto *weightData = static_cast<float *>(oriWeightData);
for (size_t j = 0; j < wShapeSize; j++) {
qDatas[j] = QuantizeData<int8_t>(weightData[j], weightQuantParam.get());
}
} else {  // tflite aware-training quant
auto *weightData = static_cast<uint8_t *>(oriWeightData);
for (size_t j = 0; j < wShapeSize; j++) {
qDatas[j] = (int32_t)weightData[j] - 128;
}
weightQuantParam->zeroPoint -= 128;
tensor->quantParams.clear();
tensor->quantParams.emplace_back(weightQuantParam.release());
}
}
} else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) {
status = QuantDetectionPostProcessConstTensor(graph, node.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!";
return RET_ERROR;
}
} else if (GetCNodeTType(*node) == schema::PrimitiveType_Add ||
GetCNodeTType(*node) == schema::PrimitiveType_Scale ||
GetCNodeTType(*node) == schema::PrimitiveType_Mul) {
status = QuantArithmeticConstTensor(graph, node.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantArithmeticConstTensor failed!";
return RET_ERROR;
}
}
const auto nodeType = GetCNodeTType(*node);
auto find = std::find(propagatedOps.begin(), propagatedOps.end(), nodeType);
if (find != propagatedOps.end()) {
auto inputTensor = graph->allTensors.at(node->inputIndex[0]).get();
auto outputTensor = graph->allTensors.at(node->outputIndex[0]).get();
MS_ASSERT(inputTensor != nullptr);
MS_ASSERT(outputTensor != nullptr);
outputTensor->dataType = inputTensor->dataType;
}
}
return RET_OK;
}

STATUS AwareQuantizer::QuantArithmeticConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(node != nullptr);
for (size_t i = 0; i < node->inputIndex.size(); i++) {
auto inTensorIdx = node->inputIndex.at(i);
MS_ASSERT(graph->allTensors.size() > inTensorIdx);
auto &inTensor = graph->allTensors.at(inTensorIdx);
MS_ASSERT(inTensor != nullptr);
if (!inTensor->data.empty()) {
if (inTensor->dataType == TypeId::kNumberTypeInt8) {
continue;
}
if (inTensor->dataType != TypeId::kNumberTypeFloat32 && inTensor->dataType != TypeId::kNumberTypeFloat &&
inTensor->dataType != TypeId::kNumberTypeUInt8) {
MS_LOG(ERROR) << node->name.c_str() << "'s weight data is not float or uint8";
return RET_ERROR;
}

auto quantParam = GetTensorQuantParam(inTensor);
MS_ASSERT(quantParam != nullptr);
MS_ASSERT(quantParam->inited);
auto constTensorShapeSize = GetShapeSize(*(inTensor.get()));
vector<int8_t> qDatas(constTensorShapeSize);
void *inData = inTensor->data.data();
if (inTensor->dataType == TypeId::kNumberTypeFloat ||
inTensor->dataType == TypeId::kNumberTypeFloat32) {  // normal aware-training quant
auto *weightData = static_cast<float *>(inData);
for (size_t j = 0; j < constTensorShapeSize; j++) {
qDatas[j] = QuantizeData<int8_t>(weightData[j], quantParam.get());
::memcpy(tensor->data.data(), qDatas.data(), wShapeSize);
} else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) {
// quant bias data
auto bShapeSize = GetShapeSize(*(tensor.get()));
std::unique_ptr<int32_t[]> qDatas(new (std::nothrow) int32_t[bShapeSize]);
if (qDatas == nullptr) {
MS_LOG(ERROR) << "new qDatas failed";
return RET_ERROR;
}
} else {  // tflite aware-training quant
auto *weightData = static_cast<uint8_t *>(inData);
for (size_t j = 0; j < constTensorShapeSize; j++) {
qDatas[j] = (int32_t)weightData[j] - 128;
void *biasData = tensor->data.data();
auto *rawDatas = static_cast<float *>(biasData);
for (size_t i = 0; i < bShapeSize; ++i) {
qDatas[i] = (int32_t)std::round(rawDatas[i] / quantParam->scale);
}
tensor->dataType = TypeId::kNumberTypeInt32;
tensor->data.clear();
tensor->data.resize(bShapeSize * sizeof(int32_t));
auto ret =
memcpy_s(tensor->data.data(), bShapeSize * sizeof(int32_t), qDatas.get(), bShapeSize * sizeof(int32_t));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy_s failed: " << ret;
return RET_ERROR;
}
quantParam->zeroPoint -= 128;
inTensor->quantParams.clear();
inTensor->quantParams.emplace_back(quantParam.release());
}

::memcpy(inTensor->data.data(), qDatas.data(), constTensorShapeSize);
inTensor->dataType = TypeId::kNumberTypeInt8;
}
}
return RET_OK;
}

STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
MS_ASSERT(subGraph != nullptr);
MS_ASSERT(node != nullptr);
auto &constTensor = subGraph->allTensors.at(node->inputIndex[2]);
MS_ASSERT(constTensor != nullptr);
const auto *constData = reinterpret_cast<const float *>(constTensor->data.data());

if (!constTensor->data.empty() &&
(constTensor->dataType == TypeId::kNumberTypeFloat || constTensor->dataType == TypeId::kNumberTypeFloat32)) {
size_t constTensorShapeSize = GetShapeSize(*constTensor);
std::unique_ptr<QuantParamT> quantParam = GetTensorQuantParam(constTensor);
if (quantParam == nullptr) {
MS_LOG(ERROR) << "new QuantParamT failed";
return RET_NULL_PTR;
}
vector<int8_t> qDatas(constTensorShapeSize);
for (size_t j = 0; j < constTensorShapeSize; j++) {
float rawData = constData[j];
qDatas[j] = QuantizeData<int8_t>(rawData, quantParam.get());
}
::memcpy(constTensor->data.data(), qDatas.data(), constTensorShapeSize);
constTensor->dataType = TypeId::kNumberTypeInt8;
}
return RET_OK;
}

STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph, mindspore::schema::CNodeT *node) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(node != nullptr);
auto inputIndexes = node->inputIndex;
MS_ASSERT(inputIndexes.size() >= 3);
MS_ASSERT(graph->allTensors.size() > inputIndexes.at(0));
MS_ASSERT(graph->allTensors.size() > inputIndexes.at(1));
MS_ASSERT(graph->allTensors.size() > inputIndexes.at(2));
auto &biasTensor = graph->allTensors.at(inputIndexes.at(2));
MS_ASSERT(biasTensor != nullptr);
if (biasTensor->dataType == TypeId::kNumberTypeInt32) {
return RET_OK;
}
if (biasTensor->dataType != TypeId::kNumberTypeFloat && biasTensor->dataType != TypeId::kNumberTypeFloat32) {
MS_LOG(ERROR) << "conv " << node->name << "'s bias data is not float";
return RET_ERROR;
}
auto &inputTensor = graph->allTensors.at(inputIndexes.at(0));
auto &weightTensor = graph->allTensors.at(inputIndexes.at(1));

MS_ASSERT(inputTensor != nullptr);
MS_ASSERT(weightTensor != nullptr);
auto inputScale = inputTensor->quantParams.front()->scale;
auto weightScale = weightTensor->quantParams.front()->scale;
auto scale = inputScale * weightScale;
// set bias quant param
std::unique_ptr<QuantParamT> biasQuantParam = GetTensorQuantParam(biasTensor);
if (biasQuantParam == nullptr) {
MS_LOG(ERROR) << "new QuantParamT failed";
return RET_ERROR;
}
biasQuantParam->inited = true;
biasQuantParam->scale = scale;
biasQuantParam->zeroPoint = 0;
biasQuantParam->numBits = 8;
biasQuantParam->narrowRange = false;
biasQuantParam->min = 0.0;
biasQuantParam->max = 0.0;

// quant bias data
auto bShapeSize = GetShapeSize(*(biasTensor.get()));
std::unique_ptr<int32_t[]> qDatas(new (std::nothrow) int32_t[bShapeSize]);
if (qDatas == nullptr) {
MS_LOG(ERROR) << "new qDatas failed";
return RET_ERROR;
}
void *biasData = biasTensor->data.data();
auto *rawDatas = static_cast<float *>(biasData);
for (size_t i = 0; i < bShapeSize; ++i) {
qDatas[i] = (int32_t)std::round(rawDatas[i] / scale);
}
biasTensor->dataType = TypeId::kNumberTypeInt32;
biasTensor->data.clear();
biasTensor->data.resize(bShapeSize * sizeof(int32_t));
auto ret =
memcpy_s(biasTensor->data.data(), bShapeSize * sizeof(int32_t), qDatas.get(), bShapeSize * sizeof(int32_t));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy_s failed: " << ret;
return RET_ERROR;
}
return RET_OK;
}

STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schema::CNodeT *node) {
MS_ASSERT(subGraph != nullptr);
MS_ASSERT(node != nullptr);
MS_ASSERT(node->quantParam.size() == node->inputIndex.size() + node->outputIndex.size());
auto inputIndexes = node->inputIndex;
MS_ASSERT(inputIndexes.size() >= 2);
MS_ASSERT(subGraph->allTensors.size() > inputIndexes.at(1));
auto &weightTensor = subGraph->allTensors.at(inputIndexes.at(1));
if (weightTensor->dataType == TypeId::kNumberTypeInt8) {
return RET_OK;
}
if (weightTensor->dataType != TypeId::kNumberTypeFloat32 && weightTensor->dataType != TypeId::kNumberTypeFloat &&
weightTensor->dataType != TypeId::kNumberTypeUInt8) {
MS_LOG(ERROR) << "conv " << node->name.c_str() << "'s weight data is not float or uint8";
return RET_ERROR;
}
size_t wShapeSize = GetShapeSize(*(weightTensor.get()));
void *oriWeightData = weightTensor->data.data();
MS_ASSERT(node->quantParam.at(1)->param.front() != nullptr);
vector<int8_t> qDatas(wShapeSize);
// todo support perchannel
auto weightQuantParam = GetTensorQuantParam(weightTensor);
if (weightTensor->dataType == TypeId::kNumberTypeFloat ||
weightTensor->dataType == TypeId::kNumberTypeFloat32) {  // normal aware-training quant
auto *weightData = static_cast<float *>(oriWeightData);
for (size_t j = 0; j < wShapeSize; j++) {
qDatas[j] = QuantizeData<int8_t>(weightData[j], weightQuantParam.get());
}
} else {  // tflite aware-training quant
auto *weightData = static_cast<uint8_t *>(oriWeightData);
for (size_t j = 0; j < wShapeSize; j++) {
qDatas[j] = (int32_t)weightData[j] - 128;
} else {  // perchannel
}
weightQuantParam->zeroPoint -= 128;
weightTensor->quantParams.clear();
weightTensor->quantParams.emplace_back(weightQuantParam.release());
}

weightTensor->data.resize(wShapeSize * sizeof(uint8_t));
::memcpy(weightTensor->data.data(), qDatas.data(), wShapeSize);
weightTensor->dataType = TypeId::kNumberTypeInt8;
return RET_OK;
}
STATUS AwareQuantizer::DetermineNodeQuantType() {


+0 -18  mindspore/lite/tools/converter/quantizer/aware_quantizer.h

@@ -39,24 +39,6 @@ class AwareQuantizer : public FbQuantizer {
STATUS DetermineNodeQuantType() override;

STATUS DoQuantize() override; // override;

private:
// RemoveFakeQuant
STATUS SetAttrToConvolution(const schema::MetaGraphT *subGraph, schema::CNodeT *node);

STATUS GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph);

STATUS QuantArithmeticConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node);

STATUS QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node);

STATUS QuantConvBias(const schema::MetaGraphT *graph, schema::CNodeT *node);

STATUS QuantConvWeight(const schema::MetaGraphT *subGraph, schema::CNodeT *node);

float inputScale = 0.0f;

static const std::array<schema::PrimitiveType, 7> propagatedOps;
};
} // namespace mindspore::lite::quant
#endif

+40 -7  mindspore/lite/tools/converter/quantizer/calc_quant_param.cc

@@ -26,6 +26,9 @@
#include "tools/converter/quantizer/quantize_util.h"

namespace mindspore::lite {
static constexpr size_t BIAS_SIZE = 3;
static constexpr size_t BIAS_ADD_SIZE = 2;

STATUS QuantParamCalcer::ComputeConstQuantParam(const schema::TensorT &tensor, QuantParamT *quantParam) {
MS_ASSERT(quantParam != nullptr);
// int32 weight no need to quant
@@ -126,6 +129,36 @@ int CommonCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) {
return RET_OK;
}

int ConvCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) {
auto status = CommonCalcer::Calc(subGraph, node);
if (status != RET_OK) {
MS_LOG(WARNING) << "Call CommonCalcer::Calc failed: " << status;
return status;
}
if (node.inputIndex.size() == BIAS_SIZE) {
auto &biasTensor = subGraph->allTensors.at(node.inputIndex.at(BIAS_SIZE - 1));
for (auto &quantParam : biasTensor->quantParams) {
quantParam->dstDtype = TypeId::kNumberTypeInt32;
}
}
return RET_OK;
}

int BiasAddCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) {
auto status = CommonCalcer::Calc(subGraph, node);
if (status != RET_OK) {
MS_LOG(WARNING) << "Call CommonCalcer::Calc failed: " << status;
return status;
}
if (node.inputIndex.size() == BIAS_ADD_SIZE) {
auto &biasTensor = subGraph->allTensors.at(node.inputIndex.at(BIAS_ADD_SIZE - 1));
for (auto &quantParam : biasTensor->quantParams) {
quantParam->dstDtype = TypeId::kNumberTypeInt32;
}
}
return RET_OK;
}

int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
auto status = QuantParamCalcer::Calc(graph, node);
if (status != RET_OK) {
@@ -474,10 +507,10 @@ QuantParamCalcRegister::QuantParamCalcRegister() {
_registerMap[schema::PrimitiveType_Activation] = std::make_shared<CalcActivation>();
_registerMap[schema::PrimitiveType_Add] = std::make_shared<CalcAdd>();
_registerMap[schema::PrimitiveType_Mul] = commonCalcer;
_registerMap[schema::PrimitiveType_Scale] = commonCalcer;
_registerMap[schema::PrimitiveType_Conv2D] = commonCalcer;
_registerMap[schema::PrimitiveType_DeConv2D] = commonCalcer;
_registerMap[schema::PrimitiveType_DepthwiseConv2D] = commonCalcer;
_registerMap[schema::PrimitiveType_Scale] = std::make_shared<ConvCalcer>();
_registerMap[schema::PrimitiveType_Conv2D] = std::make_shared<ConvCalcer>();
_registerMap[schema::PrimitiveType_DeConv2D] = std::make_shared<ConvCalcer>();
_registerMap[schema::PrimitiveType_DepthwiseConv2D] = std::make_shared<ConvCalcer>();
_registerMap[schema::PrimitiveType_Pooling] = linearCalcer;
_registerMap[schema::PrimitiveType_Resize] = linearCalcer;
_registerMap[schema::PrimitiveType_Reshape] = linearCalcer;
@@ -487,11 +520,11 @@ QuantParamCalcRegister::QuantParamCalcRegister() {
_registerMap[schema::PrimitiveType_Squeeze] = linearCalcer;
_registerMap[schema::PrimitiveType_RealDiv] = std::make_shared<CalcRealDiv>();
_registerMap[schema::PrimitiveType_Reduce] = commonCalcer;
_registerMap[schema::PrimitiveType_BiasAdd] = commonCalcer;
_registerMap[schema::PrimitiveType_BiasAdd] = std::make_shared<BiasAddCalcer>();
_registerMap[schema::PrimitiveType_Mean] = linearCalcer;
_registerMap[schema::PrimitiveType_Transpose] = linearCalcer;
_registerMap[schema::PrimitiveType_MatMul] = commonCalcer;
_registerMap[schema::PrimitiveType_FullConnection] = commonCalcer;
_registerMap[schema::PrimitiveType_MatMul] = std::make_shared<ConvCalcer>();
_registerMap[schema::PrimitiveType_FullConnection] = std::make_shared<ConvCalcer>();
_registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer;
_registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer;
// detection_postprocess op's quant param will not infer only fetch from preNode or postNode
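For context on why ConvCalcer and BiasAddCalcer set the bias tensor's dstDtype to kNumberTypeInt32: under the usual convention the bias scale is input_scale * weight_scale with zero point 0, and 8 bits is too narrow for the resulting values, so the bias is stored as int32. A minimal sketch assuming that convention (names are illustrative):

#include <cmath>
#include <cstdint>
#include <vector>

// Sketch: quantize a float bias to int32 with
// biasScale = inputScale * weightScale and zeroPoint = 0.
std::vector<int32_t> QuantizeBias(const std::vector<float> &bias,
                                  float inputScale, float weightScale) {
  const float biasScale = inputScale * weightScale;
  std::vector<int32_t> quantized(bias.size());
  for (size_t i = 0; i < bias.size(); ++i) {
    quantized[i] = static_cast<int32_t>(std::round(bias[i] / biasScale));
  }
  return quantized;
}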


+14 -0  mindspore/lite/tools/converter/quantizer/calc_quant_param.h

@@ -46,6 +46,20 @@ class CommonCalcer : public QuantParamCalcer {
int Calc(schema::MetaGraphT *subGraph, const schema::CNodeT &node) override;
};

class ConvCalcer : public CommonCalcer {
public:
ConvCalcer() = default;
~ConvCalcer() override = default;
int Calc(schema::MetaGraphT *subGraph, const schema::CNodeT &node) override;
};

class BiasAddCalcer : public CommonCalcer {
public:
BiasAddCalcer() = default;
~BiasAddCalcer() override = default;
int Calc(schema::MetaGraphT *subGraph, const schema::CNodeT &node) override;
};

class LinearCalcer : public QuantParamCalcer {
public:
LinearCalcer() = default;


+13 -19  mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc

@@ -564,11 +564,8 @@ PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, in
}
}

STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct MaxMin *max_min,
std::shared_ptr<PrimitiveC> lite_primitive) {
if (!lite_primitive->GetInputQuantParams().empty()) {
MS_LOG(DEBUG) << "input quant params not empty"; // multi-input op: like concat
}
STATUS PostTrainingQuantizer::DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min,
std::shared_ptr<PrimitiveC> lite_primitive, const size_t &index) {
schema::QuantParamT quant_param;
quant_param.scale = scale;
quant_param.zeroPoint = zeropoint;
@@ -577,15 +574,12 @@ STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct M
quant_param.numBits = bit_num;
quant_param.narrowRange = false;
std::vector<schema::QuantParamT> quant_params = {quant_param};
lite_primitive->AddInputQuantParam(quant_params);
lite_primitive->SetInputQuantParam(index, quant_params);
return RET_OK;
}

STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct MaxMin *max_min,
std::shared_ptr<PrimitiveC> lite_primitive) {
if (!lite_primitive->GetOutputQuantParams().empty()) {
MS_LOG(DEBUG) << "output quant params not empty"; // multi-output op: like split
}
schema::QuantParamT quant_param;
quant_param.scale = scale;
quant_param.zeroPoint = zeropoint;
@@ -593,8 +587,9 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct
quant_param.min = max_min->min;
quant_param.numBits = bit_num;
quant_param.narrowRange = false;
quant_param.inited = true;
std::vector<schema::QuantParamT> quant_params = {quant_param};
lite_primitive->AddOutputQuantParam(quant_params);
lite_primitive->SetOutputQuantParam(0, quant_params);
return RET_OK;
}

@@ -647,7 +642,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi
auto bias_param = std::dynamic_pointer_cast<ParamValueLite>(bias_default_param);

auto active_weight_quant_params = primitive_c->GetInputQuantParams();
if (active_weight_quant_params.size() != 2) {
if (active_weight_quant_params.size() != 3) {
MS_LOG(ERROR) << "unexpected active_weight_quant_params size: " << active_weight_quant_params.size();
return RET_ERROR;
}
@@ -714,7 +709,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi
double filter_scale = std::abs(raw_datas[i]) / (activate_scale * quanted_bias_abs_limit);
active_weight_quant_params[1][i].scale = filter_scale;
active_weight_quant_params[1][i].zeroPoint = 0;
primitive_c->SetInputQuantParam(active_weight_quant_params);
primitive_c->SetInputQuantParams(active_weight_quant_params);
bias_scale_tmp = std::abs(raw_datas[i]) / quanted_bias_abs_limit;
quant_params[i].scale = bias_scale_tmp;
MS_LOG(DEBUG) << "new filter scale: " << filter_scale;
@@ -726,7 +721,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi
auto quant_data = (int32_t)std::round(raw_datas[i] / bias_scale_tmp);
quant_datas[i] = quant_data;
}
primitive_c->AddInputQuantParam(quant_params);
primitive_c->SetInputQuantParam(2, quant_params);
auto ret = memcpy_s(bias_param->tensor_addr(), bias_param->tensor_size(), quant_datas, shape_size * sizeof(int32_t));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy_s failed.";
@@ -834,22 +829,21 @@ STATUS PostTrainingQuantizer::QuantNode() {
<< " PrimitiveC is null";
continue;
}
if (!input_cnode_primitive_c->GetOutputQuantParams().empty()) {
for (auto &quant_param : input_cnode_primitive_c->GetOutputQuantParams()) {
primitive_c->AddInputQuantParam(quant_param);
}
if (input_cnode_primitive_c->IsOutputQuantParamsInited()) {
auto quant_param = input_cnode_primitive_c->GetOutputQuantParams().front();
primitive_c->SetInputQuantParam(i - 1, quant_param);
} else {
// do input quant
double scale = input_scale[cnode];
int32_t zp = input_zero_point[cnode];
DoQuantInput(scale, zp, &input_min_max[cnode], primitive_c);
DoQuantInput(scale, zp, &input_min_max[cnode], primitive_c, i - 1);
}
}
} else {
// do input quant
double scale = input_scale[cnode];
int32_t convInputzeropoint = input_zero_point[cnode];
DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitive_c);
DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitive_c, 0);
// do weight quant
auto weight = cnode->input(2);
bool perchannel = per_channel_;
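The indexed setters introduced in this hunk write quant params to a specific input slot instead of appending, so multi-input ops such as Concat no longer depend on traversal order. A simplified model of the new setter's behavior (stand-in types, not the real PrimitiveC):

#include <vector>

struct QuantParam { double scale = 1.0; int zeroPoint = 0; bool inited = false; };

// Sketch: slot-addressed storage instead of append-only.
struct PrimitiveModel {
  std::vector<std::vector<QuantParam>> inputQuantParams;
  void SetInputQuantParam(size_t index, const std::vector<QuantParam> &params) {
    if (inputQuantParams.size() <= index) inputQuantParams.resize(index + 1);
    inputQuantParams[index] = params;
  }
};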


+2 -1  mindspore/lite/tools/converter/quantizer/post_training_quantizer.h

@@ -106,7 +106,8 @@ class PostTrainingQuantizer : public Quantizer {

STATUS QuantNode();

STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>);
STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min,
std::shared_ptr<PrimitiveC> lite_primitive, const size_t &index);
STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>);

STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel);


+1 -1  mindspore/lite/tools/converter/quantizer/quantize_util.cc

@@ -246,7 +246,7 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other";
return RET_ERROR;
}
quantParam->inited = true;
quantParam->inited = false;
quantParam->min = mMin;
quantParam->max = mMax;
quantParam->scale = 0.0f;


+5 -4  mindspore/lite/tools/converter/quantizer/quantize_util.h

@@ -39,6 +39,7 @@ namespace mindspore {
namespace lite {
namespace quant {
static constexpr size_t UINT8_QUANTIZATION = 8;
static constexpr size_t WEIGHT_INDEX = 1;

/**
* 1. when op's weight size > mWeightSize just skip
@@ -225,16 +226,16 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
}
variance_dequant = std::sqrt(variance_dequant / one_filter_size);
variance_raw = std::sqrt(variance_raw / one_filter_size);
quant_param.var_corr = 1;
quant_param.varCorr = 1;
if (variance_raw != 0 && variance_dequant != 0) {
auto temp_var_corr = variance_raw / variance_dequant;
if (temp_var_corr > 0 && temp_var_corr < 10) {
quant_param.var_corr = temp_var_corr;
quant_param.varCorr = temp_var_corr;
} else {
MS_LOG(WARNING) << "unexpected var_corr: " << temp_var_corr;
}
}
quant_param.mean_corr = average_raw - average_dequant * quant_param.var_corr;
quant_param.meanCorr = average_raw - average_dequant * quant_param.varCorr;
}
quant_params.emplace_back(quant_param);
}
@@ -282,7 +283,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
MS_LOG(ERROR) << "quant_params empty";
return RET_ERROR;
}
primitive_c->AddInputQuantParam(quant_params);
primitive_c->SetInputQuantParam(WEIGHT_INDEX, quant_params);
return RET_OK;
}
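The varCorr/meanCorr fields renamed above implement a moment-matching correction for weight quantization: varCorr is the ratio of the raw to the dequantized standard deviation (kept only when it falls in (0, 10)), and meanCorr shifts the mean back, i.e. meanCorr = mean_raw - mean_dequant * varCorr. A standalone sketch of how the two factors would be applied at dequantization time (assuming corrected = dequant * varCorr + meanCorr):

#include <cstdint>

// Sketch: dequantize with the variance/mean correction applied.
float DequantCorrected(int8_t q, float scale, int32_t zeroPoint,
                       float varCorr, float meanCorr) {
  const float dequant = scale * static_cast<float>(q - zeroPoint);
  return dequant * varCorr + meanCorr;
}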



+2 -6  mindspore/lite/tools/converter/quantizer/weight_quantizer.cc

@@ -101,8 +101,6 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
return RET_ERROR;
}

std::vector<schema::QuantParamT> quant_params;
primitive_c->AddInputQuantParam(quant_params);
auto status = RET_ERROR;
if (type_id == kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true);
@@ -143,9 +141,9 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
ParameterPtr param_node = nullptr;
for (size_t i = 1; i < node->size(); i++) {
auto inputNode = node->input(i);
if (inputNode->isa<Parameter>() == true) {
if (inputNode->isa<Parameter>()) {
param_node = inputNode->cast<ParameterPtr>();
if ((param_node != nullptr) && (param_node->has_default() == true)) {
if ((param_node != nullptr) && param_node->has_default()) {
param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
if ((param_value == nullptr) || (param_value->tensor_size() == 0) ||
(param_value->tensor_addr() == nullptr) ||
@@ -169,8 +167,6 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
return RET_ERROR;
}

std::vector<schema::QuantParamT> quant_params;
primitive_c->AddInputQuantParam(quant_params);
auto status = RET_ERROR;
if (type_id == kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true);
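For context on the quant_max / quant_min arguments passed to QuantFilter: signed n-bit quantization spans [-(2^(n-1)), 2^(n-1) - 1], e.g. [-128, 127] for int8. A minimal sketch assuming that convention:

#include <cstdint>

// Sketch: signed quantization range for a given bit width.
void SignedQuantRange(int bitNum, int32_t *quantMin, int32_t *quantMax) {
  *quantMax = (1 << (bitNum - 1)) - 1;
  *quantMin = -(1 << (bitNum - 1));
}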


+79 -57  mindspore/lite/tools/optimizer/common/gllo_utils.cc

@@ -619,7 +619,7 @@ STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int3
}
return RET_OK;
}
template<typename T>
template <typename T>
static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
int32_t filterH, int32_t filterW) {
MS_ASSERT(tensor != nullptr);
@@ -628,7 +628,7 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType
MS_LOG(ERROR) << "Dim size invalid";
return RET_ERROR;
}
std::unique_ptr<T[]> buf(new(std::nothrow) T[count]);
std::unique_ptr<T[]> buf(new (std::nothrow) T[count]);
if (buf == nullptr) {
MS_LOG(ERROR) << "new buf failed";
return RET_ERROR;
@@ -653,18 +653,17 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType
p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k));
if (type == kCHWK2HWCK) {
p2Buff =
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
} else if (type == kCHWK2KHWC) {
p2Buff =
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
}
*p2Buff = *p1Buff;
}
}
}
}
}
break;
} break;
case kKHWC2HWCK: {
for (int k = 0; k < filterK; ++k) {
for (int h = 0; h < filterH; ++h) {
@@ -677,8 +676,7 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType
}
}
}
}
break;
} break;
case kKCHW2HWCK:
case kKCHW2CKHW:
case kKCHW2KHWC:
@@ -690,24 +688,23 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType
p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
if (type == kKCHW2HWCK) {
p2Buff =
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
} else if (type == kKCHW2KHWC) {
p2Buff =
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
} else if (type == kKCHW2CKHW) {
p2Buff =
buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
} else {
p2Buff =
buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
}
*p2Buff = *p1Buff;
}
}
}
}
}
break;
} break;
case kCKHW2HWCK:
case kCKHW2KHWC:
case kCKHW2HWKC: {
@@ -718,21 +715,20 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType
p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
if (type == kCKHW2HWCK) {
p2Buff =
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
} else if (type == kCKHW2KHWC) {
p2Buff =
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
} else {
p2Buff =
buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
}
*p2Buff = *p1Buff;
}
}
}
}
}
break;
} break;
case kHWCK2KCHW:
case kHWCK2CKHW: {
for (int h = 0; h < filterH; ++h) {
@@ -742,18 +738,17 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType
p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
if (type == kHWCK2KCHW) {
p2Buff =
buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
} else {
p2Buff =
buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
}
*p2Buff = *p1Buff;
}
}
}
}
}
break;
} break;
case kHWKC2KCHW:
case kHWKC2CKHW: {
for (int h = 0; h < filterH; ++h) {
@@ -763,18 +758,17 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType
p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
if (type == kHWKC2KCHW) {
p2Buff =
buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
} else {
p2Buff =
buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
}
*p2Buff = *p1Buff;
}
}
}
}
}
break;
} break;
case kNHWC2HWCK:
case kNHWC2KCHW:
case kNHWC2CKHW: {
@@ -785,21 +779,20 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType
p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
if (type == kNHWC2HWCK) {
p2Buff =
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
} else if (type == kNHWC2CKHW) {
p2Buff =
buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
} else {
p2Buff =
buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
}
*p2Buff = *p1Buff;
}
}
}
}
}
break;
} break;
case kKHWC2CHWK: {
for (int k = 0; k < filterK; ++k) {
for (int h = 0; h < filterH; ++h) {
@@ -812,8 +805,7 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType
}
}
}
}
break;
} break;
default: {
MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
return RET_ERROR;
@@ -828,7 +820,7 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType
return RET_OK;
}

template<typename T>
template <typename T>
static STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type) {
MS_ASSERT(tensor != nullptr);
auto oriDims = tensor->tensor_shape();
@@ -882,6 +874,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for
status = TransFilterFormat<uint8_t>(tensor, kKCHW2KHWC);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kKCHW2KHWC);
} else if (data_type == kNumberTypeFloat16) {
status = TransFilterFormat<float16>(tensor, kKCHW2KHWC);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
@@ -894,6 +888,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for
status = TransFilterFormat<uint8_t>(tensor, kCKHW2KHWC);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kCKHW2KHWC);
} else if (data_type == kNumberTypeFloat16) {
status = TransFilterFormat<float16>(tensor, kCKHW2KHWC);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
@@ -906,18 +902,20 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for
status = TransFilterFormat<uint8_t>(tensor, kCHWK2KHWC);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kCHWK2KHWC);
} else if (data_type == kNumberTypeFloat16) {
status = TransFilterFormat<float16>(tensor, kCHWK2KHWC);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
}
break;
case schema::Format::Format_KHWC:return RET_OK;
default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to "
<< EnumNameFormat(dst_format);
case schema::Format::Format_KHWC:
return RET_OK;
default:
MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
return RET_ERROR;
}
}
break;
} break;
case schema::Format::Format_HWCK: {
switch (src_format) {
case schema::Format::Format_KCHW:
@@ -927,6 +925,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for
status = TransFilterFormat<uint8_t>(tensor, kKCHW2HWCK);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kKCHW2HWCK);
} else if (data_type == kNumberTypeFloat16) {
status = TransFilterFormat<float16>(tensor, kKCHW2HWCK);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
@@ -939,6 +939,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for
status = TransFilterFormat<uint8_t>(tensor, kKHWC2HWCK);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kKHWC2HWCK);
} else if (data_type == kNumberTypeFloat16) {
status = TransFilterFormat<float16>(tensor, kKHWC2HWCK);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
@@ -951,6 +953,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for
status = TransFilterFormat<uint8_t>(tensor, kCKHW2HWCK);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kCKHW2HWCK);
} else if (data_type == kNumberTypeFloat16) {
status = TransFilterFormat<float16>(tensor, kCKHW2HWCK);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
@@ -963,21 +967,24 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for
status = TransFilterFormat<uint8_t>(tensor, kCHWK2HWCK);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kCHWK2HWCK);
} else if (data_type == kNumberTypeFloat16) {
status = TransFilterFormat<float16>(tensor, kCHWK2HWCK);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return lite::RET_ERROR;
}
break;
case schema::Format::Format_HWCK:return RET_OK;
default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to "
<< EnumNameFormat(dst_format);
case schema::Format::Format_HWCK:
return RET_OK;
default:
MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
return RET_ERROR;
}
}
break;
} break;
case schema::Format::Format_KCHW: {
switch (src_format) {
case schema::Format::Format_KCHW:return RET_OK;
case schema::Format::Format_KCHW:
return RET_OK;
case schema::Format::Format_HWCK:
if (data_type == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kHWCK2KCHW);
@@ -985,6 +992,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for
status = TransFilterFormat<uint8_t>(tensor, kHWCK2KCHW);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kHWCK2KCHW);
} else if (data_type == kNumberTypeFloat16) {
status = TransFilterFormat<float16>(tensor, kHWCK2KCHW);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
@@ -997,6 +1006,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for
status = TransFilterFormat<uint8_t>(tensor, kHWKC2KCHW);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kHWKC2KCHW);
} else if (data_type == kNumberTypeFloat16) {
status = TransFilterFormat<float16>(tensor, kHWKC2KCHW);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
@@ -1009,6 +1020,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for
status = TransFilterFormat<uint8_t>(tensor, kKHWC2KCHW);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kKHWC2KCHW);
} else if (data_type == kNumberTypeFloat16) {
status = TransFilterFormat<float16>(tensor, kKHWC2KCHW);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
@@ -1021,6 +1034,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for
status = TransFilterFormat<uint8_t>(tensor, kCKHW2KCHW);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kCKHW2KCHW);
} else if (data_type == kNumberTypeFloat16) {
status = TransFilterFormat<float16>(tensor, kCKHW2KCHW);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
@@ -1033,17 +1048,18 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for
status = TransFilterFormat<uint8_t>(tensor, kCHWK2KCHW);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kCHWK2KCHW);
} else if (data_type == kNumberTypeFloat16) {
status = TransFilterFormat<float16>(tensor, kCHWK2KCHW);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
}
break;
default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to "
<< EnumNameFormat(dst_format);
default:
MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
return RET_ERROR;
}
}
break;
} break;
case schema::Format::Format_CKHW: {
switch (src_format) {
case schema::Format::Format_HWCK:
@@ -1053,6 +1069,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for
status = TransFilterFormat<uint8_t>(tensor, kHWCK2CKHW);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kHWCK2CKHW);
} else if (data_type == kNumberTypeFloat16) {
status = TransFilterFormat<float16>(tensor, kHWCK2CKHW);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
@@ -1065,6 +1083,8 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for
status = TransFilterFormat<uint8_t>(tensor, kHWKC2CKHW);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kHWKC2CKHW);
} else if (data_type == kNumberTypeFloat16) {
status = TransFilterFormat<float16>(tensor, kHWKC2CKHW);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
@@ -1077,20 +1097,22 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for
status = TransFilterFormat<uint8_t>(tensor, kKCHW2CKHW);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kKCHW2CKHW);
} else if (data_type == kNumberTypeFloat16) {
status = TransFilterFormat<float16>(tensor, kKCHW2CKHW);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
}
break;
case schema::Format::Format_CKHW:return RET_OK;
default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to "
<< EnumNameFormat(dst_format);
case schema::Format::Format_CKHW:
return RET_OK;
default:
MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
return RET_ERROR;
}
}
break;
default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to "
<< EnumNameFormat(dst_format);
} break;
default:
MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
return RET_ERROR;
}
if (status != RET_OK) {
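Every TransFilterData branch above is the same computation: flatten the (k, c, h, w) index under the source layout and re-flatten it under the destination layout. A standalone sketch for the KCHW-to-HWCK case, mirroring the p1Buff/p2Buff arithmetic (helper names are illustrative):

#include <cstddef>

// Offset of element (k, c, h, w) in a KCHW-laid-out filter.
size_t OffsetKCHW(int k, int c, int h, int w, int C, int H, int W) {
  return ((static_cast<size_t>(k) * C + c) * H + h) * W + w;
}

// Offset of the same element in an HWCK-laid-out filter.
size_t OffsetHWCK(int k, int c, int h, int w, int C, int W, int K) {
  return ((static_cast<size_t>(h) * W + w) * C + c) * K + k;
}

Copying src[OffsetKCHW(...)] to dst[OffsetHWCK(...)] over all four loop variables is exactly the kKCHW2HWCK branch.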


+2 -2  mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc

@@ -155,8 +155,8 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
rmatmul_quant_params.pop_back();
// no bias quantParams
rmatmul_quant_params.emplace_back(jointed_quant_params);
matmul_cvalue->SetInputQuantParam(rmatmul_quant_params);
matmul_cvalue->SetOutputQuantParam(fc_prim->GetOutputQuantParams());
matmul_cvalue->SetInputQuantParams(rmatmul_quant_params);
matmul_cvalue->SetOutputQuantParams(fc_prim->GetOutputQuantParams());
auto matmul_value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(matmul_cvalue));
std::vector<AnfNodePtr> matmul_inputs = {matmul_value_node, left_matmul_input};


