From 0ac7c7884d31a7ee30440164e6ddb4ef7215b0a8 Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Sat, 19 Sep 2020 14:46:01 +0800 Subject: [PATCH] fuse benchmark and timeprofiler --- mindspore/lite/src/model.cc | 6 +- mindspore/lite/tools/benchmark/benchmark.cc | 173 +++++++++++++++--- mindspore/lite/tools/benchmark/benchmark.h | 19 +- mindspore/lite/tools/common/flag_parser.cc | 4 +- .../converter/quantizer/aware_quantizer.cc | 11 +- 5 files changed, 169 insertions(+), 44 deletions(-) diff --git a/mindspore/lite/src/model.cc b/mindspore/lite/src/model.cc index 8f5a896820..b8c2623fb1 100644 --- a/mindspore/lite/src/model.cc +++ b/mindspore/lite/src/model.cc @@ -78,7 +78,7 @@ Model *Model::Import(const char *model_buf, size_t size) { MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; return nullptr; } - Model *model = new (std::nothrow) Model(); + auto *model = new (std::nothrow) Model(); if (model == nullptr) { MS_LOG(ERROR) << "new model fail!"; return nullptr; @@ -86,14 +86,14 @@ Model *Model::Import(const char *model_buf, size_t size) { model->buf = reinterpret_cast(malloc(size)); if (model->buf == nullptr) { MS_LOG(ERROR) << "new inner model buf fail!"; - delete(model); + delete (model); return nullptr; } memcpy(model->buf, model_buf, size); auto meta_graph = schema::GetMetaGraph(model->buf); if (meta_graph == nullptr) { MS_LOG(ERROR) << "meta_graph is nullptr!"; - delete(model); + delete (model); return nullptr; } diff --git a/mindspore/lite/tools/benchmark/benchmark.cc b/mindspore/lite/tools/benchmark/benchmark.cc index b5cf34da5b..1046e150a3 100644 --- a/mindspore/lite/tools/benchmark/benchmark.cc +++ b/mindspore/lite/tools/benchmark/benchmark.cc @@ -265,7 +265,8 @@ int Benchmark::MarkPerformance() { for (int i = 0; i < _flags->loopCount; i++) { session->BindThread(true); auto start = GetTimeUs(); - auto status = session->RunGraph(); + auto status = + _flags->runTimeProfiler ? session->RunGraph(before_call_back_, after_call_back_) : session->RunGraph(); if (status != 0) { MS_LOG(ERROR) << "Inference error " << status; std::cerr << "Inference error " << status; @@ -280,6 +281,14 @@ int Benchmark::MarkPerformance() { session->BindThread(false); } + + if (_flags->runTimeProfiler) { + const std::vector per_op_name = {"opName", "avg(ms)", "percent", "calledTimes", "opTotalTime"}; + const std::vector per_op_type = {"opType", "avg(ms)", "percent", "calledTimes", "opTotalTime"}; + PrintResult(per_op_name, op_times_by_name_); + PrintResult(per_op_type, op_times_by_type_); + } + if (_flags->loopCount > 0) { timeAvg /= _flags->loopCount; MS_LOG(INFO) << "Model = " << _flags->modelPath.substr(_flags->modelPath.find_last_of(DELIM_SLASH) + 1).c_str() @@ -295,25 +304,25 @@ int Benchmark::MarkPerformance() { int Benchmark::MarkAccuracy() { MS_LOG(INFO) << "MarkAccuracy"; std::cout << "MarkAccuracy" << std::endl; - for (size_t i = 0; i < msInputs.size(); i++) { - switch (msInputs.at(i)->data_type()) { + for (auto &msInput : msInputs) { + switch (msInput->data_type()) { case TypeId::kNumberTypeFloat: - PrintInputData(msInputs.at(i)); + PrintInputData(msInput); break; case TypeId::kNumberTypeFloat32: - PrintInputData(msInputs.at(i)); + PrintInputData(msInput); break; case TypeId::kNumberTypeInt8: - PrintInputData(msInputs.at(i)); + PrintInputData(msInput); break; case TypeId::kNumberTypeUInt8: - PrintInputData(msInputs.at(i)); + PrintInputData(msInput); break; case TypeId::kNumberTypeInt32: - PrintInputData(msInputs.at(i)); + PrintInputData(msInput); break; default: - MS_LOG(ERROR) << "Datatype " << msInputs.at(i)->data_type() << " is not supported."; + MS_LOG(ERROR) << "Datatype " << msInput->data_type() << " is not supported."; return RET_ERROR; } } @@ -340,7 +349,7 @@ int Benchmark::MarkAccuracy() { return RET_OK; } -int Benchmark::RunBenchmark(const std::string &deviceType) { +int Benchmark::RunBenchmark() { auto startPrepareTime = GetTimeUs(); // Load graph std::string modelName = _flags->modelPath.substr(_flags->modelPath.find_last_of(DELIM_SLASH) + 1); @@ -355,13 +364,12 @@ int Benchmark::RunBenchmark(const std::string &deviceType) { return RET_ERROR; } auto model = lite::Model::Import(graphBuf, size); + delete[](graphBuf); if (model == nullptr) { MS_LOG(ERROR) << "Import model file failed while running " << modelName.c_str(); std::cerr << "Import model file failed while running " << modelName.c_str() << std::endl; - delete[](graphBuf); return RET_ERROR; } - delete[](graphBuf); auto context = new (std::nothrow) lite::Context; if (context == nullptr) { MS_LOG(ERROR) << "New context failed while running " << modelName.c_str(); @@ -372,8 +380,6 @@ int Benchmark::RunBenchmark(const std::string &deviceType) { context->device_type_ = lite::DT_CPU; } else if (_flags->device == "GPU") { context->device_type_ = lite::DT_GPU; - } else { - context->device_type_ = lite::DT_NPU; } if (_flags->cpuBindMode == -1) { @@ -403,13 +409,8 @@ int Benchmark::RunBenchmark(const std::string &deviceType) { model->Free(); msInputs = session->GetInputs(); auto endPrepareTime = GetTimeUs(); -#if defined(__arm__) MS_LOG(INFO) << "PrepareTime = " << (endPrepareTime - startPrepareTime) / 1000 << " ms"; - printf("PrepareTime = %lld ms, ", (endPrepareTime - startPrepareTime) / 1000); -#else - MS_LOG(INFO) << "PrepareTime = " << (endPrepareTime - startPrepareTime) / 1000 << " ms "; - printf("PrepareTime = %ld ms, ", (endPrepareTime - startPrepareTime) / 1000); -#endif + std::cout << "PrepareTime = " << (endPrepareTime - startPrepareTime) / 1000 << " ms" << std::endl; // Load input MS_LOG(INFO) << "start generate input data"; @@ -481,6 +482,54 @@ void BenchmarkFlags::InitResizeDimsList() { } } +int Benchmark::InitCallbackParameter() { + // before callback + before_call_back_ = [&](const std::vector &before_inputs, + const std::vector &before_outputs, + const session::CallBackParam &callParam) { + if (before_inputs.empty()) { + MS_LOG(INFO) << "The num of beforeInputs is empty"; + } + if (before_outputs.empty()) { + MS_LOG(INFO) << "The num of beforeOutputs is empty"; + } + if (op_times_by_type_.find(callParam.type_callback_param) == op_times_by_type_.end()) { + op_times_by_type_.insert(std::make_pair(callParam.type_callback_param, std::make_pair(0, 0.0f))); + } + if (op_times_by_name_.find(callParam.name_callback_param) == op_times_by_name_.end()) { + op_times_by_name_.insert(std::make_pair(callParam.name_callback_param, std::make_pair(0, 0.0f))); + } + + op_call_times_total_++; + op_begin_ = GetTimeUs(); + return true; + }; + + // after callback + after_call_back_ = [&](const std::vector &after_inputs, + const std::vector &after_outputs, + const session::CallBackParam &call_param) { + uint64_t opEnd = GetTimeUs(); + + if (after_inputs.empty()) { + MS_LOG(INFO) << "The num of after inputs is empty"; + } + if (after_outputs.empty()) { + MS_LOG(INFO) << "The num of after outputs is empty"; + } + + float cost = static_cast(opEnd - op_begin_) / 1000.0f; + op_cost_total_ += cost; + op_times_by_type_[call_param.type_callback_param].first++; + op_times_by_type_[call_param.type_callback_param].second += cost; + op_times_by_name_[call_param.name_callback_param].first++; + op_times_by_name_[call_param.name_callback_param].second += cost; + return true; + }; + + return RET_OK; +} + int Benchmark::Init() { if (this->_flags == nullptr) { return 1; @@ -550,6 +599,79 @@ int Benchmark::Init() { return RET_ERROR; } + if (_flags->runTimeProfiler) { + auto status = InitCallbackParameter(); + if (status != RET_OK) { + MS_LOG(ERROR) << "Init callback Parameter failed."; + std::cerr << "Init callback Parameter failed." << std::endl; + return RET_ERROR; + } + } + + return RET_OK; +} + +int Benchmark::PrintResult(const std::vector &title, + const std::map> &result) { + std::vector columnLenMax(5); + std::vector> rows; + + for (auto &iter : result) { + char stringBuf[5][100] = {}; + std::vector columns; + size_t len; + + len = iter.first.size(); + if (len > columnLenMax.at(0)) { + columnLenMax.at(0) = len + 4; + } + columns.push_back(iter.first); + + len = snprintf(stringBuf[1], sizeof(stringBuf[1]), "%f", iter.second.second / _flags->loopCount); + if (len > columnLenMax.at(1)) { + columnLenMax.at(1) = len + 4; + } + columns.emplace_back(stringBuf[1]); + + len = snprintf(stringBuf[2], sizeof(stringBuf[2]), "%f", iter.second.second / op_cost_total_); + if (len > columnLenMax.at(2)) { + columnLenMax.at(2) = len + 4; + } + columns.emplace_back(stringBuf[2]); + + len = snprintf(stringBuf[3], sizeof(stringBuf[3]), "%d", iter.second.first); + if (len > columnLenMax.at(3)) { + columnLenMax.at(3) = len + 4; + } + columns.emplace_back(stringBuf[3]); + + len = snprintf(stringBuf[4], sizeof(stringBuf[4]), "%f", iter.second.second); + if (len > columnLenMax.at(4)) { + columnLenMax.at(4) = len + 4; + } + columns.emplace_back(stringBuf[4]); + + rows.push_back(columns); + } + + printf("-------------------------------------------------------------------------\n"); + for (int i = 0; i < 5; i++) { + auto printBuf = title[i]; + if (printBuf.size() > columnLenMax.at(i)) { + columnLenMax.at(i) = printBuf.size(); + } + printBuf.resize(columnLenMax.at(i), ' '); + printf("%s\t", printBuf.c_str()); + } + printf("\n"); + for (size_t i = 0; i < rows.size(); i++) { + for (int j = 0; j < 5; j++) { + auto printBuf = rows[i][j]; + printBuf.resize(columnLenMax.at(j), ' '); + printf("%s\t", printBuf.c_str()); + } + printf("\n"); + } return RET_OK; } @@ -583,16 +705,7 @@ int RunBenchmark(int argc, const char **argv) { return RET_ERROR; } - if (flags.device == "GPU") { - status = mBenchmark.RunBenchmark("GPU"); - } else if (flags.device == "CPU") { - status = mBenchmark.RunBenchmark("CPU"); - } else { - MS_LOG(ERROR) << "Device type" << flags.device << " not support."; - std::cerr << "Device type" << flags.device << " not support." << std::endl; - return RET_ERROR; - } - + status = mBenchmark.RunBenchmark(); if (status != 0) { MS_LOG(ERROR) << "Run Benchmark " << flags.modelPath.substr(flags.modelPath.find_last_of(DELIM_SLASH) + 1).c_str() << " Failed : " << status; diff --git a/mindspore/lite/tools/benchmark/benchmark.h b/mindspore/lite/tools/benchmark/benchmark.h index e38790d9d2..8878846c93 100644 --- a/mindspore/lite/tools/benchmark/benchmark.h +++ b/mindspore/lite/tools/benchmark/benchmark.h @@ -28,6 +28,7 @@ #include #include #include +#include #include "include/model.h" #include "tools/common/flag_parser.h" #include "src/common/file_utils.h" @@ -64,6 +65,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser { AddFlag(&BenchmarkFlags::numThreads, "numThreads", "Run threads number", 2); AddFlag(&BenchmarkFlags::fp16Priority, "fp16Priority", "Priority float16", false); AddFlag(&BenchmarkFlags::warmUpLoopCount, "warmUpLoopCount", "Run warm up loop", 3); + AddFlag(&BenchmarkFlags::runTimeProfiler, "runTimeProfiler", "Run time profiler", false); // MarkAccuracy AddFlag(&BenchmarkFlags::calibDataPath, "calibDataPath", "Calibration data file path", ""); AddFlag(&BenchmarkFlags::calibDataType, "calibDataType", "Calibration data type. FLOAT | INT32 | INT8 | UINT8", @@ -90,6 +92,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser { int numThreads; bool fp16Priority; int warmUpLoopCount; + bool runTimeProfiler; // MarkAccuracy std::string calibDataPath; std::string calibDataType; @@ -108,7 +111,7 @@ class MS_API Benchmark { virtual ~Benchmark(); int Init(); - int RunBenchmark(const std::string &deviceType = "NPU"); + int RunBenchmark(); private: // call GenerateInputData or ReadInputFile to init inputTensors @@ -125,6 +128,10 @@ class MS_API Benchmark { int CompareOutput(); + int InitCallbackParameter(); + + int PrintResult(const std::vector &title, const std::map> &result); + template void PrintInputData(tensor::MSTensor *input) { MS_ASSERT(input != nullptr); @@ -228,6 +235,16 @@ class MS_API Benchmark { {"INT32", TypeId::kNumberTypeInt32}, {"UINT8", TypeId::kNumberTypeUInt8}}; TypeId msCalibDataType = TypeId::kNumberTypeFloat; + + // callback parameters + uint64_t op_begin_ = 0; + int op_call_times_total_ = 0; + float op_cost_total_ = 0.0f; + std::map> op_times_by_type_; + std::map> op_times_by_name_; + + session::KernelCallBack before_call_back_; + session::KernelCallBack after_call_back_; }; int MS_API RunBenchmark(int argc, const char **argv); diff --git a/mindspore/lite/tools/common/flag_parser.cc b/mindspore/lite/tools/common/flag_parser.cc index ebc87a61a8..37e89dcc54 100644 --- a/mindspore/lite/tools/common/flag_parser.cc +++ b/mindspore/lite/tools/common/flag_parser.cc @@ -38,13 +38,13 @@ Option FlagParser::ParseFlags(int argc, const char *const *argv, bo } if (flagItem.find("--") == std::string::npos) { - continue; + return Option("Failed: flag " + flagItem + " is not valid."); } std::string key; Option value = Option(None()); - size_t pos = flagItem.find_first_of("="); + size_t pos = flagItem.find_first_of('='); if (pos == std::string::npos && flagItem.find("--no-") != std::string::npos) { key = flagItem.substr(FLAG_PREFIX_LEN); } else if (pos == std::string::npos) { diff --git a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc index 0e40fba131..745e9acf75 100644 --- a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc @@ -79,13 +79,8 @@ AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferTyp const float stdValue = std::stof(stdValues, &sz); sz = 0; const float mean = std::stof(meanValues, &sz); - std::unique_ptr inArr = nullptr; - if (inferType == kNumberTypeFloat) { - inArr.reset(new (std::nothrow) InputArray(mean, stdValue)); - } else { - inArr.reset(new (std::nothrow) InputArray(mean, stdValue, TypeId::kNumberTypeInt8)); - } - mInputArray = inArr.get(); + mInputArray = new (std::nothrow) InputArray(mean, stdValue); + mInputArray->dataType = inferType; mInputArray->InitQuantParam(); } @@ -132,7 +127,7 @@ STATUS AwareQuantizer::GenerateQuantParam() { } else { auto status = quantParamCalcer->Calc(graph, *node); if (status != RET_OK) { - MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); + MS_LOG(WARNING) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); node->quantType = schema::QuantType_QUANT_NONE; } else { node->quantType = schema::QuantType_AwareTraining;