@@ -18,7 +18,7 @@
 #include "nnacl/int8/quant_dtype_cast_int8.h"
 #include "nnacl/errorcode.h"
-int DoDequantizeInt8ToFp32(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size) {
+int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float scale, int32_t zp, int size) {
   if (quant_values == NULL || real_values == NULL) {
     return NNACL_PARAM_INVALID;
   }
@@ -29,13 +29,13 @@ int DoDequantizeInt8ToFp32(int8_t *quant_values, float *real_values, float scale
   return NNACL_OK;
 }
-int DoQuantizeToInt8FromFp32(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
+int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
   if (quant_values == NULL || real_values == NULL) {
     return NNACL_PARAM_INVALID;
   }
   for (int i = 0; i < size; ++i) {
-    float temp = round(real_values[i] * 1.0 / scale + zp);
+    float temp = (float)round(real_values[i] * 1.0 / scale + zp);
     if (temp > 127) {
       quant_values[i] = 127;
     } else if (temp < -128) {
@@ -47,7 +47,36 @@ int DoQuantizeToInt8FromFp32(float *real_values, int8_t *quant_values, float sca
   return NNACL_OK;
 }
-int DoDequantizeInt8ToUInt8(int8_t *quant_values, uint8_t *real_values, int size) {
+int DoDequantizeUInt8ToFp32(const uint8_t *quant_values, float *real_values, float scale, int32_t zp, int size) {
+  if (quant_values == NULL || real_values == NULL) {
+    return NNACL_PARAM_INVALID;
+  }
+  for (int i = 0; i < size; ++i) {
+    real_values[i] = (float)((int)quant_values[i] - zp) * scale;
+  }
+  return NNACL_OK;
+}
+int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float scale, int32_t zp, int size) {
+  if (quant_values == NULL || real_values == NULL) {
+    return NNACL_PARAM_INVALID;
+  }
+  for (int i = 0; i < size; ++i) {
+    float temp = (float)round(real_values[i] * 1.0 / scale + zp);
+    if (temp > 255) {
+      quant_values[i] = 255;
+    } else if (temp < 0) {
+      quant_values[i] = 0;
+    } else {
+      quant_values[i] = (uint8_t)temp;
+    }
+  }
+  return NNACL_OK;
+}
+int Int8ToUInt8(const int8_t *quant_values, uint8_t *real_values, int size) {
   if (quant_values == NULL || real_values == NULL) {
     return NNACL_PARAM_INVALID;
   }
@@ -65,7 +94,7 @@ int DoDequantizeInt8ToUInt8(int8_t *quant_values, uint8_t *real_values, int size
   return NNACL_OK;
 }
-int DoQuantizeToInt8FromUint8(uint8_t *real_values, int8_t *quant_values, int size) {
+int UInt8ToInt8(const uint8_t *real_values, int8_t *quant_values, int size) {
   if (quant_values == NULL || real_values == NULL) {
     return NNACL_PARAM_INVALID;
   }
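Illustration only, not part of the patch: the new uint8 helpers implement the usual asymmetric scheme, q = clamp(round(r / scale + zp), 0, 255) on quantize and r' = (q - zp) * scale on dequantize, so in-range values round-trip to within one quantization step and out-of-range values saturate. A minimal sketch (the Example function and its quant params are made up):

  // Sketch only; assumes the nnacl headers above are on the include path.
  #include "nnacl/errorcode.h"
  #include "nnacl/int8/quant_dtype_cast_int8.h"

  void Example(void) {
    const float real[4] = {-1.0f, 0.0f, 0.5f, 1.0f};
    uint8_t quant[4];
    float restored[4];
    const float scale = 1.0f / 128.0f;  /* made-up quant param */
    const int32_t zp = 128;             /* made-up zero point */
    if (DoQuantizeFp32ToUInt8(real, quant, scale, zp, 4) != NNACL_OK) {
      return;
    }
    /* restored[i] matches real[i] to within one quantization step; 1.0f saturates to 255. */
    (void)DoDequantizeUInt8ToFp32(quant, restored, scale, zp, 4);
  }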
@@ -28,10 +28,12 @@ typedef struct QuantDTypeCastParameter {
 #ifdef __cplusplus
 extern "C" {
 #endif
-int DoDequantizeInt8ToFp32(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
-int DoQuantizeToInt8FromFp32(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size);
-int DoDequantizeInt8ToUInt8(int8_t *quant_values, uint8_t *real_values, int size);
-int DoQuantizeToInt8FromUint8(uint8_t *real_values, int8_t *quant_values, int size);
+int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
+int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size);
+int DoDequantizeUInt8ToFp32(const uint8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
+int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float scale, int32_t zp, int size);
+int Int8ToUInt8(const int8_t *quant_values, uint8_t *real_values, int size);
+int UInt8ToInt8(const uint8_t *real_values, int8_t *quant_values, int size);
 #ifdef __cplusplus
 }
 #endif
@@ -27,9 +27,6 @@
 #include "src/common/log_adapter.h"
 #include "tools/common/option.h"
 #include "include/errorcode.h"
-#ifdef ENABLE_ARM64
-#include "nnacl/optimized_kernel.h"
-#endif
 namespace mindspore {
 namespace lite {
@@ -190,37 +187,6 @@ inline Option<bool> GenericParseValue(const std::string &value) {
   return Option<bool>(None());
 }
-using Float16CastFunc = void (*)(const void *, void *, int);
-class Float16CastUtil {
- public:
-  static Float16CastUtil *GetInstance() {
-    static Float16CastUtil float16_cast_util;
-    return &float16_cast_util;
-  }
- private:
-  Float16CastUtil() {
-#ifdef ENABLE_ARM64
-    void *fp16_op_handler = Float16Module::GetInstance()->float16_op_handler_;
-    if (fp16_op_handler != nullptr) {
-      dlerror();
-      *(reinterpret_cast<void **>(&float16_to_float32_func_)) = dlsym(fp16_op_handler, "Float16ToFloat32_fp16_handler");
-      *(reinterpret_cast<void **>(&float32_to_float16_func_)) = dlsym(fp16_op_handler, "Float32ToFloat16_fp16_handler");
-      auto dlopen_error = dlerror();
-      if (dlopen_error != nullptr) {
-        MS_LOG(ERROR) << "load float16 cast func failed! " << dlopen_error << ".";
-      }
-    }
-#endif
-  }
-  ~Float16CastUtil() = default;
- public:
-  Float16CastFunc float16_to_float32_func_ = nullptr;
-  Float16CastFunc float32_to_float16_func_ = nullptr;
-};
 }  // namespace lite
 }  // namespace mindspore
@@ -66,6 +66,16 @@ int QuantDTypeCastCPUKernel::Init() {
       MS_LOG(ERROR) << "param data type and tensor data type do not match.";
       return RET_ERROR;
     }
+  } else if (param->srcT == kNumberTypeUInt8 && param->dstT == kNumberTypeFloat32) {
+    if (in_tensor->data_type() != kNumberTypeUInt8 || out_tensor->data_type() != kNumberTypeFloat32) {
+      MS_LOG(ERROR) << "param data type and tensor data type do not match.";
+      return RET_ERROR;
+    }
+  } else if (param->srcT == kNumberTypeFloat32 && param->dstT == kNumberTypeUInt8) {
+    if (in_tensor->data_type() != kNumberTypeFloat32 || out_tensor->data_type() != kNumberTypeUInt8) {
+      MS_LOG(ERROR) << "param data type and tensor data type do not match.";
+      return RET_ERROR;
+    }
   } else {
     MS_LOG(ERROR) << "param data type not supported:"
                   << " src: " << param->srcT << " dst: " << param->dstT;
@@ -106,20 +116,26 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
     ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,
                                  quant_arg.zeroPoint, num_unit_thread);
   } else if (src_dtype == TypeId::kNumberTypeFloat32 && dst_dtype == TypeId::kNumberTypeInt8) {
-    ret = DoQuantizeToInt8FromFp32(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
-                                   quant_arg.zeroPoint, num_unit_thread);
+    ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
+                               quant_arg.zeroPoint, num_unit_thread);
   } else if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeUInt8) {
-    ret = DoDequantizeInt8ToUInt8(int8_ptr_ + thread_offset, uint8_ptr_ + thread_offset, num_unit_thread);
+    ret = Int8ToUInt8(int8_ptr_ + thread_offset, uint8_ptr_ + thread_offset, num_unit_thread);
+  } else if (src_dtype == TypeId::kNumberTypeUInt8 && dst_dtype == TypeId::kNumberTypeFloat32) {
+    ret = DoDequantizeUInt8ToFp32(uint8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,
+                                  quant_arg.zeroPoint, num_unit_thread);
+  } else if (src_dtype == TypeId::kNumberTypeFloat32 && dst_dtype == TypeId::kNumberTypeUInt8) {
+    ret = DoQuantizeFp32ToUInt8(float32_ptr_ + thread_offset, uint8_ptr_ + thread_offset, quant_arg.scale,
+                                quant_arg.zeroPoint, num_unit_thread);
   } else if (src_dtype == TypeId::kNumberTypeUInt8 && dst_dtype == TypeId::kNumberTypeInt8) {
-    ret = DoQuantizeToInt8FromUint8(uint8_ptr_ + thread_offset, int8_ptr_ + thread_offset, num_unit_thread);
+    ret = UInt8ToInt8(uint8_ptr_ + thread_offset, int8_ptr_ + thread_offset, num_unit_thread);
   } else if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeInt8) {
     auto input_quant_arg = in_tensors_.front()->GetQuantParams().front();
     ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, num_unit_thread,
                                  input_quant_arg.scale, input_quant_arg.zeroPoint);
     if (ret) {
       auto output_quant_arg = out_tensors_.front()->GetQuantParams().front();
-      ret = DoQuantizeToInt8FromFp32(float32_ptr_ + thread_offset, int8_out_ptr_ + thread_offset,
-                                     output_quant_arg.scale, output_quant_arg.zeroPoint, num_unit_thread);
+      ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_out_ptr_ + thread_offset, output_quant_arg.scale,
+                                 output_quant_arg.zeroPoint, num_unit_thread);
     }
   }
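Illustration only, not part of the patch: the int8-to-int8 branch above is a requantization, i.e. the slice is dequantized into the temporary float32_ptr_ buffer with the input tensor's quant params and then re-quantized with the output tensor's params. The same idea as a standalone sketch (the Requantize helper and its error handling are hypothetical):

  // Sketch only; requantizes int8 data from (in_scale, in_zp) to (out_scale, out_zp).
  #include <vector>
  #include "nnacl/errorcode.h"
  #include "nnacl/int8/quant_dtype_cast_int8.h"

  int Requantize(const int8_t *src, int8_t *dst, int size, float in_scale, int32_t in_zp, float out_scale,
                 int32_t out_zp) {
    std::vector<float> tmp(size);  // scratch buffer, analogous to float32_ptr_
    int ret = DoDequantizeInt8ToFp32(src, tmp.data(), in_scale, in_zp, size);
    if (ret != NNACL_OK) {
      return ret;
    }
    return DoQuantizeFp32ToInt8(tmp.data(), dst, out_scale, out_zp, size);
  }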
@@ -162,6 +178,14 @@ int QuantDTypeCastCPUKernel::Run() {
     int8_ptr_ = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c());
     int8_out_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
     float32_ptr_ = new float[in_tensors_[0]->ElementsNum()];
+  } else if (in_tensors_[0]->data_type() == TypeId::kNumberTypeUInt8 &&
+             out_tensors_[0]->data_type() == TypeId::kNumberTypeFloat32) {
+    uint8_ptr_ = reinterpret_cast<uint8_t *>(in_tensors_[0]->data_c());
+    float32_ptr_ = reinterpret_cast<float *>(out_tensors_[0]->data_c());
+  } else if (in_tensors_[0]->data_type() == TypeId::kNumberTypeFloat32 &&
+             out_tensors_[0]->data_type() == TypeId::kNumberTypeUInt8) {
+    float32_ptr_ = reinterpret_cast<float *>(in_tensors_[0]->data_c());
+    uint8_ptr_ = reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c());
   }
   auto ret = ParallelLaunch(this->context_->thread_pool_, QuantDTypeCastRun, this, thread_n_num_);
@@ -178,7 +178,7 @@ int CpuFp16SubGraph::PreProcess() {
 }
 int CpuFp16SubGraph::PostProcess() {
-  auto fp16_to_fp32_cast_func = lite::Float16CastUtil::GetInstance()->float16_to_float32_func_;
+  auto fp16_to_fp32_cast_func = kernel::Float16CastUtil::GetInstance()->float16_to_float32_func_;
   if (fp16_to_fp32_cast_func == nullptr) {
     MS_LOG(ERROR) << "Can not find cast fp16 to fp32 func";
     return RET_ERROR;
@@ -22,8 +22,42 @@
 #include <vector>
 #include "src/lite_kernel.h"
 #include "src/executor.h"
+#ifdef ENABLE_ARM64
+#include "nnacl/optimized_kernel.h"
+#endif
 namespace mindspore::kernel {
+using Float16CastFunc = void (*)(const void *, void *, int);
+class Float16CastUtil {
+ public:
+  static Float16CastUtil *GetInstance() {
+    static Float16CastUtil float16_cast_util;
+    return &float16_cast_util;
+  }
+ private:
+  Float16CastUtil() {
+#ifdef ENABLE_ARM64
+    void *fp16_op_handler = Float16Module::GetInstance()->float16_op_handler_;
+    if (fp16_op_handler != nullptr) {
+      dlerror();
+      *(reinterpret_cast<void **>(&float16_to_float32_func_)) = dlsym(fp16_op_handler, "Float16ToFloat32_fp16_handler");
+      *(reinterpret_cast<void **>(&float32_to_float16_func_)) = dlsym(fp16_op_handler, "Float32ToFloat16_fp16_handler");
+      auto dlopen_error = dlerror();
+      if (dlopen_error != nullptr) {
+        MS_LOG(ERROR) << "load float16 cast func failed! " << dlopen_error << ".";
+      }
+    }
+#endif
+  }
+  ~Float16CastUtil() = default;
+ public:
+  Float16CastFunc float16_to_float32_func_ = nullptr;
+  Float16CastFunc float32_to_float16_func_ = nullptr;
+};
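Illustration only, not part of the patch: Float16CastUtil resolves the fp16/fp32 converters lazily via dlsym from the optional fp16 kernel library, so the function pointers stay nullptr on platforms without that library and callers must check them, as CpuFp16SubGraph::PostProcess() does in the .cc hunk above. A hedged sketch of the calling pattern (the helper and buffer names are illustrative and assumed to live in namespace mindspore::kernel):

  // Sketch only: convert an fp16 buffer to fp32 when the handler was resolved.
  inline bool CastFp16ToFp32IfAvailable(const void *fp16_src, void *fp32_dst, int element_num) {
    auto cast_func = Float16CastUtil::GetInstance()->float16_to_float32_func_;
    if (cast_func == nullptr) {
      return false;  // fp16 library not loaded; caller should fall back to the fp32 path
    }
    cast_func(fp16_src, fp32_dst, element_num);
    return true;
  }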
 class SubGraphKernel : public LiteKernel {
  public:
   explicit SubGraphKernel(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
@@ -46,6 +46,10 @@ AnfTransform::~AnfTransform() = default;
 FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const converter::Flags *config) {
   MS_ASSERT(nullptr != old_graph);
+  if (config == nullptr) {
+    MS_LOG(ERROR) << "config should be specified";
+    return nullptr;
+  }
   // fusion const_fold
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
@@ -84,7 +84,7 @@ int Flags::Init(int argc, const char **argv) {
   }
   if (this->inputDataTypeIn == "FLOAT") {
-    this->inputDataType = TypeId::kNumberTypeFloat;
+    this->inputDataType = TypeId::kNumberTypeFloat32;
   } else if (this->inputDataTypeIn == "INT8") {
     this->inputDataType = TypeId::kNumberTypeInt8;
   } else if (this->inputDataTypeIn == "UINT8") {
@@ -98,7 +98,7 @@ int Flags::Init(int argc, const char **argv) {
   }
   if (this->outputDataTypeIn == "FLOAT") {
-    this->outputDataType = TypeId::kNumberTypeFloat;
+    this->outputDataType = TypeId::kNumberTypeFloat32;
   } else if (this->outputDataTypeIn == "INT8") {
     this->outputDataType = TypeId::kNumberTypeInt8;
   } else if (this->outputDataTypeIn == "UINT8") {
@@ -22,7 +22,6 @@
 #include "tools/converter/converter_flags.h"
 #include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h"
 #include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h"
-#include "tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h"
 #include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
 #include "tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.h"
 #include "tools/converter/legacy_optimizer/graph/trans_format_remove_pass.h"
@@ -36,7 +35,6 @@
 #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
 #include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h"
 #include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h"
-#include "tools/converter/quantizer/aware_quantizer.h"
 using std::string;
 namespace mindspore::lite {
@@ -120,15 +118,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
       return status;
     }
   }
-  {
-    Optimizer inferQuantParamOtimizer;
-    inferQuantParamOtimizer.AddPass(new (std::nothrow) InferQuantParamPass());
-    status = inferQuantParamOtimizer.Run(graphDefT);
-    if (status != RET_OK && status != RET_NO_CHANGE) {
-      MS_LOG(ERROR) << "Run tensorQuantOptimizer graphPasses Failed";
-      return status;
-    }
-  }
   {
     Optimizer fusionOptimizer;
@@ -158,6 +147,8 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
     auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
     dTypeTransPass->SetInputDataDType(ctx.inputDataType);
     dTypeTransPass->SetOutputDataDType(ctx.outputDataType);
+    quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
+    quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass());
     quantNodeOptimizer.AddPass(dTypeTransPass);
     quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
     quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
@@ -53,18 +53,18 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
   MS_ASSERT(graph != nullptr);
   auto &graphInIdxes = graph->inputIndex;
-  if (this->inputDataDType == TypeId::kNumberTypeInt8 || this->inputDataDType == TypeId::kTypeUnknown) {
+  if (this->inputDataDType == TypeId::kTypeUnknown) {
     return RET_OK;
   }
-  if (this->inputDataDType != TypeId::kNumberTypeFloat && this->inputDataDType != TypeId::kNumberTypeUInt8) {
+  if (this->inputDataDType != TypeId::kNumberTypeFloat32 && this->inputDataDType != TypeId::kNumberTypeUInt8 &&
+      this->inputDataDType != TypeId::kNumberTypeInt8) {
     MS_LOG(ERROR) << "Invalid inputDataType: " << this->inputDataDType;
     return RET_ERROR;
   }
   // insert fp2int8 node
   for (auto graphInIdx : graphInIdxes) {
     MS_ASSERT(graphInIdx < graph->allTensors.size());
     auto &tensor = graph->allTensors.at(graphInIdx);
-    if (tensor->dataType != kNumberTypeInt8 || tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
+    if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
       continue;
     }
@@ -75,10 +75,9 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
       STATUS status = RET_OK;
       // insert dtype cast node between input tensor and input node
-      if (inputDataDType == TypeId::kNumberTypeFloat) {
-        iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, kFP32ToInt8, &status);
-      } else {
-        iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, kUInt8ToInt8, &status);
+      if (this->inputDataDType != tensor->dataType) {
+        iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, this->inputDataDType, tensor->dataType,
+                                    &status);
       }
       if (status != RET_OK) {
@@ -94,10 +93,11 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
 STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
   MS_ASSERT(graph != nullptr);
-  if (outputDataDType == TypeId::kNumberTypeInt8 || outputDataDType == TypeId::kTypeUnknown) {
+  if (outputDataDType == TypeId::kTypeUnknown) {
     return RET_OK;
   }
-  if (this->outputDataDType != TypeId::kNumberTypeFloat && this->outputDataDType != TypeId::kNumberTypeUInt8) {
+  if (this->outputDataDType != TypeId::kNumberTypeFloat32 && this->outputDataDType != TypeId::kNumberTypeUInt8 &&
+      this->outputDataDType != TypeId::kNumberTypeInt8) {
     MS_LOG(ERROR) << "Invalid outputDataType: " << this->outputDataDType;
     return RET_ERROR;
   }
@@ -105,7 +105,7 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
   for (auto graphOutIdx : graphOutIdxes) {
     MS_ASSERT(graphOutIdx < graph->allTensors.size());
     auto &tensor = graph->allTensors.at(graphOutIdx);
-    if (tensor->dataType != kNumberTypeInt8 || tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
+    if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
       continue;
     }
     for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
@@ -115,10 +115,9 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
       if ((*iter)->outputIndex.at(outputIndexIdx) == graphOutIdx) {
         // insert transNode
         STATUS status = RET_OK;
-        if (inputDataDType == TypeId::kNumberTypeFloat) {
-          iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToFP32, &status);
-        } else {
-          iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToUInt8, &status);
+        if (this->outputDataDType != tensor->dataType) {
+          iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, tensor->dataType, this->outputDataDType,
+                                      &status);
         }
         if (status != RET_OK) {
           MS_LOG(ERROR) << "InsertDTypeTransNode after " << nodeName.c_str() << " failed";
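Illustration only, not part of the patch: with the generalized signature, a cast node is inserted only when the dtype requested for a graph input or output differs from the tensor's stored dtype, and the QuantDTypeCast node's (srcT, dstT) pair is taken directly from those two types instead of a fixed enum. A hypothetical converter configuration and the calls it would lead to for an int8-quantized model:

  // Sketch only; SetInputDataDType/SetOutputDataDType are the setters used in graphdef_transform above.
  DTypeTransPass pass;
  pass.SetInputDataDType(TypeId::kNumberTypeUInt8);     // user feeds uint8 data
  pass.SetOutputDataDType(TypeId::kNumberTypeFloat32);  // user wants float32 outputs
  // For an int8 graph input tensor, DoModelInputDTypeTrans would then call
  //   InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, kNumberTypeUInt8, kNumberTypeInt8, &status);
  // and for an int8 graph output tensor, DoModelOutputDTypeTrans would call
  //   InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kNumberTypeInt8, kNumberTypeFloat32, &status);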
@@ -152,7 +151,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
       if (preTensor->dataType != TypeId::kNumberTypeInt8) {
         continue;
       }
-      iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status);
+      iter = InsertDTypeTransNode(graph, iter, kBefore, i, kNumberTypeInt8, kNumberTypeFloat32, &status);
       if (status != RET_OK) {
         MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed";
         return RET_ERROR;
@@ -165,7 +164,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
       if (postTensor->dataType != TypeId::kNumberTypeInt8) {
        continue;
       }
-      iter = InsertDTypeTransNode(graph, iter, kAfter, i, kFP32ToInt8, &status);
+      iter = InsertDTypeTransNode(graph, iter, kAfter, i, kNumberTypeFloat32, kNumberTypeInt8, &status);
       if (status != RET_OK) {
         MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed";
         return RET_ERROR;
@@ -178,7 +177,8 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
 }
 NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place,
-                                              size_t inoutIdx, DTypeTransNodeType nodeType, STATUS *errorCode) {
+                                              size_t inoutIdx, int32_t inputDataType, int32_t outputDataType,
+                                              STATUS *errorCode) {
   MS_ASSERT((*existNodeIter) != nullptr);
   auto existNodeName = (*existNodeIter)->name;
   std::string tileName;
@@ -203,21 +203,15 @@ NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIte
   transNode->primitive->value.value = quantDTypeCastParam;
   transNode->primitive->value.type = PrimitiveType_QuantDTypeCast;
   transNode->quantType = QuantType_AwareTraining;
-  if (nodeType == kInt8ToFP32) {
-    quantDTypeCastParam->srcT = TypeId::kNumberTypeInt8;
-    quantDTypeCastParam->dstT = TypeId::kNumberTypeFloat32;
+  quantDTypeCastParam->srcT = inputDataType;
+  quantDTypeCastParam->dstT = outputDataType;
+  if (inputDataType == TypeId::kNumberTypeInt8 && outputDataType == TypeId::kNumberTypeFloat32) {
     transNode->name = "int8toft32_" + tileName + std::to_string(id++);
-  } else if (nodeType == kFP32ToInt8) {
-    quantDTypeCastParam->srcT = TypeId::kNumberTypeFloat32;
-    quantDTypeCastParam->dstT = TypeId::kNumberTypeInt8;
+  } else if (inputDataType == TypeId::kNumberTypeFloat32 && outputDataType == TypeId::kNumberTypeInt8) {
     transNode->name = "ft32toint8_" + tileName + std::to_string(id++);
-  } else if (nodeType == kUInt8ToInt8) {
-    quantDTypeCastParam->srcT = TypeId::kNumberTypeUInt8;
-    quantDTypeCastParam->dstT = TypeId::kNumberTypeInt8;
+  } else if (inputDataType == TypeId::kNumberTypeUInt8 && outputDataType == TypeId::kNumberTypeInt8) {
     transNode->name = "uint8toint8_" + tileName + std::to_string(id++);
-  } else if (nodeType == kInt8ToUInt8) {
-    quantDTypeCastParam->srcT = TypeId::kNumberTypeInt8;
-    quantDTypeCastParam->dstT = TypeId::kNumberTypeUInt8;
+  } else if (inputDataType == TypeId::kNumberTypeInt8 && outputDataType == TypeId::kNumberTypeUInt8) {
     transNode->name = "int8touint8_" + tileName + std::to_string(id++);
   }
   transNode->primitive->value.value = quantDTypeCastParam;
@@ -47,7 +47,7 @@ class DTypeTransPass : public GraphPass {
   STATUS DoNodeInoutDTypeTrans(schema::MetaGraphT *graph);
   NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx,
-                                DTypeTransNodeType nodeType, STATUS *errorCode);
+                                int32_t inputDataType, int32_t outputDataType, STATUS *errorCode);
  private:
   size_t id;
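Illustration only, not part of the patch: DoNodeInoutDTypeTrans uses the same helper to wrap a node that has to run in fp32 while its surrounding tensors stay int8 (which nodes qualify is decided outside these hunks). A sketch of the effect for a node N with int8 input and output tensors:

  // before:  ... --int8--> [N] --int8--> ...
  // after :  ... --int8--> [QuantDTypeCast int8->fp32] --fp32--> [N] --fp32--> [QuantDTypeCast fp32->int8] --int8--> ...
  //
  // realized by the two calls in the .cc hunks above:
  //   InsertDTypeTransNode(graph, iter, kBefore, i, kNumberTypeInt8, kNumberTypeFloat32, &status);
  //   InsertDTypeTransNode(graph, iter, kAfter, i, kNumberTypeFloat32, kNumberTypeInt8, &status);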
@@ -87,6 +87,7 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
       }
     } else {  // perchannel
       MS_LOG(ERROR) << "perchannel doquant is not supported yet";
+      return RET_ERROR;
     }
   }
   return RET_OK;
@@ -1,161 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "tools/converter/quantizer/aware_quantizer.h"
-#include <cmath>
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-#include "schema/inner/model_generated.h"
-#include "securec/include/securec.h"
-#include "src/common/utils.h"
-#include "tools/common/node_util.h"
-#include "tools/common/tensor_util.h"
-#include "tools/converter/quantizer/calc_quant_param.h"
-#include "src/common/log_adapter.h"
-using std::string;
-using std::vector;
-namespace mindspore::lite::quant {
-AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferType) : FbQuantizer(graph) {}
-STATUS AwareQuantizer::RemoveFakeQuant() { return RET_OK; }
-STATUS AwareQuantizer::GenerateQuantParam() {
-  auto *quantParamRegister = QuantParamCalcRegister::GetInstance();
-  for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
-    auto &node = *iter;
-    MS_ASSERT(node != nullptr);
-    if (GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMax ||
-        GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) {
-      MS_ASSERT(false);
-    }
-    auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
-    if (quantParamCalcer == nullptr) {
-      MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str()
-                      << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
-      node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE);
-    } else {
-      auto status = quantParamCalcer->Calc(graph, *node);
-      if (status != RET_OK) {
-        MS_LOG(WARNING) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
-        node->quantType = schema::QuantType_QUANT_NONE;
-      } else {
-        node->quantType = schema::QuantType_AwareTraining;
-      }
-    }
-  }
-  return RET_OK;
-}
-STATUS AwareQuantizer::DoQuantize() {
-  for (auto &tensor : graph->allTensors) {
-    if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited || tensor->data.empty()) {
-      continue;
-    }
-    if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat &&
-        tensor->dataType != TypeId::kNumberTypeUInt8) {
-      continue;
-    }
-    // perlayer
-    if (tensor->quantParams.size() == 1) {
-      auto &quantParam = tensor->quantParams.front();
-      size_t wShapeSize = GetShapeSize(*(tensor.get()));
-      void *oriWeightData = tensor->data.data();
-      if (quantParam->dstDtype == TypeId::kNumberTypeInt8) {
-        vector<int8_t> qDatas(wShapeSize);
-        auto weightQauntParam = GetTensorQuantParam(tensor);
-        if (tensor->dataType == TypeId::kNumberTypeFloat ||
-            tensor->dataType == TypeId::kNumberTypeFloat32) {  // normal awareing quant
-          auto *weightData = static_cast<float *>(oriWeightData);
-          for (size_t j = 0; j < wShapeSize; j++) {
-            qDatas[j] = QuantizeData<int8_t>(weightData[j], weightQauntParam.get());
-          }
-        } else {  // tflite awareing quant
-          auto *weightData = static_cast<uint8_t *>(oriWeightData);
-          for (size_t j = 0; j < wShapeSize; j++) {
-            qDatas[j] = (int32_t)weightData[j] - 128;
-          }
-          weightQauntParam->zeroPoint -= 128;
-          tensor->quantParams.clear();
-          tensor->quantParams.emplace_back(weightQauntParam.release());
-        }
-        tensor->data.clear();
-        tensor->data.resize(wShapeSize * sizeof(int8_t));
-        auto ret =
-          memcpy_s(tensor->data.data(), wShapeSize * sizeof(int8_t), qDatas.data(), wShapeSize * sizeof(int8_t));
-        if (ret != EOK) {
-          MS_LOG(ERROR) << "memcpy_s failed: " << ret;
-          return RET_ERROR;
-        }
-      } else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) {
-        // quant bias data
-        auto bShapeSize = GetShapeSize(*(tensor.get()));
-        std::unique_ptr<int32_t[]> qDatas(new (std::nothrow) int32_t[bShapeSize]);
-        if (qDatas == nullptr) {
-          MS_LOG(ERROR) << "new qDatas failed";
-          return RET_ERROR;
-        }
-        void *biasData = tensor->data.data();
-        auto *rawDatas = static_cast<float *>(biasData);
-        for (size_t i = 0; i < bShapeSize; ++i) {
-          qDatas[i] = (int32_t)std::round(rawDatas[i] / quantParam->scale);
-        }
-        tensor->dataType = TypeId::kNumberTypeInt32;
-        tensor->data.clear();
-        tensor->data.resize(bShapeSize * sizeof(int32_t));
-        auto ret =
-          memcpy_s(tensor->data.data(), bShapeSize * sizeof(int32_t), qDatas.get(), bShapeSize * sizeof(int32_t));
-        if (ret != EOK) {
-          MS_LOG(ERROR) << "memcpy_s failed: " << ret;
-          return RET_ERROR;
-        }
-      }
-    } else {  // pertensor
-    }
-  }
-  return RET_OK;
-}
-STATUS AwareQuantizer::DetermineNodeQuantType() {
-  MS_ASSERT(graph != nullptr);
-  for (auto &node : graph->nodes) {
-    MS_ASSERT(node != nullptr);
-    bool canQuant = true;
-    for (auto &outTensorIdx : node->outputIndex) {
-      MS_ASSERT(graph->allTensors.size() > outTensorIdx);
-      auto &outTensor = graph->allTensors.at(outTensorIdx);
-      MS_ASSERT(outTensor != nullptr);
-      if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr ||
-          !outTensor->quantParams.front()->inited) {
-        canQuant = false;
-        break;
-      }
-    }
-    if (canQuant && IsContain(GetInt8OpList(), GetCNodeTType(*node))) {
-      node->quantType = schema::QuantType_AwareTraining;
-    } else {
-      node->quantType = schema::QuantType_QUANT_NONE;
-    }
-  }
-  return RET_OK;
-}
-}  // namespace mindspore::lite::quant
@@ -1,44 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_AWARE_QUANTIZER_H
-#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_AWARE_QUANTIZER_H
-#include <array>
-#include <string>
-#include <memory>
-#include "tools/converter/quantizer/quantizer.h"
-#include "schema/inner/model_generated.h"
-#include "include/errorcode.h"
-#include "tools/converter/quantizer/quantize_util.h"
-namespace mindspore::lite::quant {
-class AwareQuantizer : public FbQuantizer {
- public:
-  AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferType);
-  ~AwareQuantizer() override = default;
-  STATUS RemoveFakeQuant() override;
-  STATUS GenerateQuantParam() override;
-  STATUS DetermineNodeQuantType() override;
-  STATUS DoQuantize() override;  // override;
-};
-}  // namespace mindspore::lite::quant
-#endif
@@ -86,7 +86,7 @@ int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
       if (!tensor->data.empty() && !IsContain(graph->inputIndex, node.inputIndex.at(i))) {
         auto status = ComputeConstQuantParam((*tensor), quantParam.get());
         if (status != RET_OK) {
-          MS_LOG(INFO) << "ComputeConstQuantParam failed: " << status;
+          MS_LOG(DEBUG) << "ComputeConstQuantParam failed: " << status;
           return status;
         }
         tensor->quantParams.front() = std::move(quantParam);