|
|
@@ -28,7 +28,7 @@ |
|
|
#include "schema/inner/model_generated.h" |
|
|
#include "schema/inner/model_generated.h" |
|
|
#include "src/ir/tensor.h" |
|
|
#include "src/ir/tensor.h" |
|
|
#include "src/common/anf_exporter/anf_exporter.h" |
|
|
#include "src/common/anf_exporter/anf_exporter.h" |
|
|
#include "tools/converter/quantizer/post_training.h" |
|
|
|
|
|
|
|
|
#include "tools/converter/quantizer/post_training_quantizer.h" |
|
|
#include "tools/converter/quantizer/quantize_util.h" |
|
|
#include "tools/converter/quantizer/quantize_util.h" |
|
|
#include "src/common/common.h" |
|
|
#include "src/common/common.h" |
|
|
#include "utils/log_adapter.h" |
|
|
#include "utils/log_adapter.h" |
|
|
@@ -54,7 +54,10 @@ struct DivergInfo { |
|
|
size_t bit_num; |
|
|
size_t bit_num; |
|
|
int quant_max = 255; |
|
|
int quant_max = 255; |
|
|
int quant_min = 0; |
|
|
int quant_min = 0; |
|
|
DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min) { |
|
|
|
|
|
|
|
|
std::string method_x = kMethodKL; |
|
|
|
|
|
|
|
|
|
|
|
DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min, const std::string &method_x) { |
|
|
|
|
|
this->method_x = method_x; |
|
|
this->cnode = cnode; |
|
|
this->cnode = cnode; |
|
|
this->bin_num = bins; |
|
|
this->bin_num = bins; |
|
|
this->bit_num = bits; |
|
|
this->bit_num = bits; |
|
|
@@ -99,6 +102,12 @@ struct DivergInfo { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
STATUS ComputeThreshold() { |
|
|
STATUS ComputeThreshold() { |
|
|
|
|
|
if (method_x == kMethodMaxMin) { |
|
|
|
|
|
this->best_T = std::max(fabs(this->max), fabs(this->min)); |
|
|
|
|
|
MS_LOG(DEBUG) << "using MAX_MIN, T: " << this->best_T; |
|
|
|
|
|
return RET_OK; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
constexpr int quant_bint_nums = 128; |
|
|
constexpr int quant_bint_nums = 128; |
|
|
int threshold = quant_bint_nums; |
|
|
int threshold = quant_bint_nums; |
|
|
float min_kl = FLT_MAX; |
|
|
float min_kl = FLT_MAX; |
|
|
@@ -200,46 +209,32 @@ struct DivergInfo { |
|
|
threshold = i; |
|
|
threshold = i; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
MS_LOG(DEBUG) << "Best threshold bin index: " << threshold; |
|
|
|
|
|
this->best_T = (static_cast<float>(threshold) + 0.5f) * this->interval; |
|
|
this->best_T = (static_cast<float>(threshold) + 0.5f) * this->interval; |
|
|
|
|
|
MS_LOG(DEBUG) << cnode->fullname_with_scope() << " Best threshold bin index: " << threshold |
|
|
|
|
|
<< " T: " << best_T |
|
|
|
|
|
<< " max: " << std::max(fabs(this->max), fabs(this->min)); |
|
|
return RET_OK; |
|
|
return RET_OK; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
std::pair<CNodePtr, float> GetScale() { |
|
|
std::pair<CNodePtr, float> GetScale() { |
|
|
float max_value = this->best_T; |
|
|
float max_value = this->best_T; |
|
|
float min_value = -max_value; |
|
|
float min_value = -max_value; |
|
|
|
|
|
|
|
|
MS_ASSERT(quant_max - quant_min != 0); |
|
|
MS_ASSERT(quant_max - quant_min != 0); |
|
|
double scale = (max_value - min_value) / (quant_max - quant_min); |
|
|
|
|
|
|
|
|
float scale = (max_value - min_value) / (quant_max - quant_min); |
|
|
MS_ASSERT(scale != 0); |
|
|
MS_ASSERT(scale != 0); |
|
|
return std::make_pair(this->cnode, scale); |
|
|
return std::make_pair(this->cnode, scale); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
std::pair<CNodePtr, int32_t> GetZeropoint() { |
|
|
std::pair<CNodePtr, int32_t> GetZeropoint() { |
|
|
float max_value = this->best_T; |
|
|
|
|
|
float min_value = -max_value; |
|
|
|
|
|
MS_ASSERT(quant_max - quant_min != 0); |
|
|
|
|
|
float scale = (max_value - min_value) / (quant_max - quant_min); |
|
|
|
|
|
|
|
|
|
|
|
auto quant_min_float = static_cast<float>(quant_min); |
|
|
|
|
|
auto quant_max_float = static_cast<float>(quant_max); |
|
|
|
|
|
MS_ASSERT(scale != 0); |
|
|
|
|
|
const float zero_point_from_min = quant_min_float - min_value / scale; |
|
|
|
|
|
// const float zero_point_from_max = quant_max_float - max_value / scale; |
|
|
|
|
|
int zero_point; |
|
|
|
|
|
if (zero_point_from_min < quant_min_float) { |
|
|
|
|
|
zero_point = quant_min; |
|
|
|
|
|
} else if (zero_point_from_min > quant_max_float) { |
|
|
|
|
|
zero_point = quant_max; |
|
|
|
|
|
} else { |
|
|
|
|
|
zero_point = static_cast<int>(std::round(zero_point_from_min)); |
|
|
|
|
|
} |
|
|
|
|
|
MS_LOG(DEBUG) << "zero point:" << zero_point; |
|
|
|
|
|
|
|
|
int zero_point = 0; |
|
|
if (quant_min == 0 && quant_max == 255) { |
|
|
if (quant_min == 0 && quant_max == 255) { |
|
|
zero_point = 128; |
|
|
zero_point = 128; |
|
|
} else if (quant_min == -128 && quant_max == 127) { |
|
|
} else if (quant_min == -128 && quant_max == 127) { |
|
|
zero_point = 0; |
|
|
zero_point = 0; |
|
|
|
|
|
} else { |
|
|
|
|
|
MS_LOG(ERROR) << "unexpectd quant range, quant_min: " << quant_min << " quant_max: " << quant_max; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
return std::make_pair(this->cnode, zero_point); |
|
|
return std::make_pair(this->cnode, zero_point); |
|
|
} |
|
|
} |
|
|
}; |
|
|
}; |
|
|
@@ -356,9 +351,9 @@ STATUS Calibrator::AddQuantizedOp(CNodePtr node) { |
|
|
} |
|
|
} |
|
|
string node_name = node->fullname_with_scope(); |
|
|
string node_name = node->fullname_with_scope(); |
|
|
std::unique_ptr<DivergInfo> input_diverg = |
|
|
std::unique_ptr<DivergInfo> input_diverg = |
|
|
std::unique_ptr<DivergInfo>(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_)); |
|
|
|
|
|
|
|
|
std::unique_ptr<DivergInfo>(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_, config_param_.method_x)); |
|
|
std::unique_ptr<DivergInfo> output_diverg = |
|
|
std::unique_ptr<DivergInfo> output_diverg = |
|
|
std::unique_ptr<DivergInfo>(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_)); |
|
|
|
|
|
|
|
|
std::unique_ptr<DivergInfo>(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_, config_param_.method_x)); |
|
|
|
|
|
|
|
|
input_diverg_info_.insert(std::make_pair(string(node_name), std::move(input_diverg))); |
|
|
input_diverg_info_.insert(std::make_pair(string(node_name), std::move(input_diverg))); |
|
|
output_diverg_info_.insert(std::make_pair(string(node_name), std::move(output_diverg))); |
|
|
output_diverg_info_.insert(std::make_pair(string(node_name), std::move(output_diverg))); |
|
|
@@ -383,13 +378,13 @@ STATUS Calibrator::GenerateInputData(const int index, mindspore::tensor::MSTenso |
|
|
MS_LOG(INFO) << "read image: " << path; |
|
|
MS_LOG(INFO) << "read image: " << path; |
|
|
size_t size; |
|
|
size_t size; |
|
|
char *binBuf = ReadFile(path.c_str(), &size); |
|
|
char *binBuf = ReadFile(path.c_str(), &size); |
|
|
|
|
|
|
|
|
// auto *rawinputDatas = reinterpret_cast<const float *>(binBuf); |
|
|
|
|
|
// auto mobilenet_input = const_cast<float *>(rawinputDatas); |
|
|
|
|
|
auto data = tensor->MutableData(); |
|
|
auto data = tensor->MutableData(); |
|
|
|
|
|
if (size != tensor->Size()) { |
|
|
|
|
|
MS_LOG(ERROR) << "the input data is not consistent with model input, file_size: " << size |
|
|
|
|
|
<< " input tensor size: " << tensor->Size(); |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
memcpy(data, binBuf, size); |
|
|
memcpy(data, binBuf, size); |
|
|
|
|
|
|
|
|
// tensor->SetData(mobilenet_input); |
|
|
|
|
|
return RET_OK; |
|
|
return RET_OK; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -457,13 +452,20 @@ STATUS Calibrator::ReadConfig() { |
|
|
config_param_.batch_count = std::stoul(value); |
|
|
config_param_.batch_count = std::stoul(value); |
|
|
} else if (key == "thread_num") { |
|
|
} else if (key == "thread_num") { |
|
|
config_param_.thread_num = std::stoul(value); |
|
|
config_param_.thread_num = std::stoul(value); |
|
|
|
|
|
} else if (key == "method_x") { |
|
|
|
|
|
if (value != kMethodKL && value != kMethodMaxMin) { |
|
|
|
|
|
MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value."; |
|
|
|
|
|
} else { |
|
|
|
|
|
config_param_.method_x = value; |
|
|
|
|
|
} |
|
|
} else { |
|
|
} else { |
|
|
MS_LOG(WARNING) << "unsupported parameter"; |
|
|
MS_LOG(WARNING) << "unsupported parameter"; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
MS_LOG(INFO) << "image_path: " << config_param_.image_path << " " |
|
|
|
|
|
<< "batch_count: " << config_param_.batch_count << " " |
|
|
|
|
|
<< "thread_num: " << config_param_.thread_num; |
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "image_path: " << config_param_.image_path << " " |
|
|
|
|
|
<< "batch_count: " << config_param_.batch_count << " " |
|
|
|
|
|
<< "mothod_x: " << config_param_.method_x << " " |
|
|
|
|
|
<< "thread_num: " << config_param_.thread_num; |
|
|
|
|
|
|
|
|
delete[] resolved_path; |
|
|
delete[] resolved_path; |
|
|
fs.close(); |
|
|
fs.close(); |
|
|
@@ -615,7 +617,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(std::shared_ptr<PrimitiveTValue> input |
|
|
quant_datas[i] = quant_data; |
|
|
quant_datas[i] = quant_data; |
|
|
} |
|
|
} |
|
|
auto ret = |
|
|
auto ret = |
|
|
memcpy_s(bias_param->tensor_addr(), shape_size * sizeof(int32_t), quant_datas, shape_size * sizeof(int32_t)); |
|
|
|
|
|
|
|
|
memcpy_s(bias_param->tensor_addr(), bias_param->tensor_size(), quant_datas, shape_size * sizeof(int32_t)); |
|
|
if (ret != EOK) { |
|
|
if (ret != EOK) { |
|
|
MS_LOG(ERROR) << "memcpy_s failed."; |
|
|
MS_LOG(ERROR) << "memcpy_s failed."; |
|
|
delete[] quant_datas; |
|
|
delete[] quant_datas; |
|
|
@@ -805,14 +807,6 @@ STATUS PostTrainingQuantizer::DoInference() { |
|
|
MS_LOG(ERROR) << "generate input data from images failed!"; |
|
|
MS_LOG(ERROR) << "generate input data from images failed!"; |
|
|
return RET_ERROR; |
|
|
return RET_ERROR; |
|
|
} |
|
|
} |
|
|
/** |
|
|
|
|
|
* struct CallBackParam { |
|
|
|
|
|
std::string nodeType; |
|
|
|
|
|
NODE_ID nodeName; |
|
|
|
|
|
std::unordered_set<NODE_ID> depends; |
|
|
|
|
|
int opExecResult; |
|
|
|
|
|
}; |
|
|
|
|
|
*/ |
|
|
|
|
|
mindspore::session::KernelCallBack beforeCallBack = |
|
|
mindspore::session::KernelCallBack beforeCallBack = |
|
|
[&](const std::vector<mindspore::tensor::MSTensor *> &beforeInputs, |
|
|
[&](const std::vector<mindspore::tensor::MSTensor *> &beforeInputs, |
|
|
const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs, |
|
|
const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs, |
|
|
@@ -916,9 +910,26 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr funcGraph) { |
|
|
MS_LOG(ERROR) << "do pre process failed!"; |
|
|
MS_LOG(ERROR) << "do pre process failed!"; |
|
|
return status; |
|
|
return status; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// anf -- fb |
|
|
|
|
|
auto meta_graph = Export(funcGraph); |
|
|
|
|
|
if (meta_graph == nullptr) { |
|
|
|
|
|
MS_LOG(ERROR) << "Export to meta_graph return nullptr"; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// transform |
|
|
|
|
|
GraphDefTransform transform; |
|
|
|
|
|
transform.SetGraphDef(meta_graph); |
|
|
|
|
|
flags.quantType = schema::QuantType_QUANT_NONE; |
|
|
|
|
|
status = transform.Transform(flags); |
|
|
|
|
|
if (status != RET_OK) { |
|
|
|
|
|
MS_LOG(ERROR) << "FBTransform model failed " << status; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
MS_LOG(INFO) << "start create session"; |
|
|
MS_LOG(INFO) << "start create session"; |
|
|
flatbuffers::FlatBufferBuilder builder(1024); |
|
|
flatbuffers::FlatBufferBuilder builder(1024); |
|
|
auto offset = schema::MetaGraph::Pack(builder, Export(funcGraph)); |
|
|
|
|
|
|
|
|
auto offset = schema::MetaGraph::Pack(builder, meta_graph); |
|
|
builder.Finish(offset); |
|
|
builder.Finish(offset); |
|
|
size_t size = builder.GetSize(); |
|
|
size_t size = builder.GetSize(); |
|
|
auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer()); |
|
|
auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer()); |