Merge pull request !4434 from hangq/mastertags/v0.7.0-beta
| @@ -92,6 +92,16 @@ void LiteSession::InitGraphInputTensors(const lite::Model *model) { | |||
| } | |||
| } | |||
| void LiteSession::InitGraphInputMSTensors(const lite::Model *model) { | |||
| auto meta_graph = model->GetMetaGraph(); | |||
| MS_ASSERT(this->input_vec_.empty()); | |||
| MS_ASSERT(meta_graph != nullptr); | |||
| for (auto &input_tensor : this->inputs_) { | |||
| MS_ASSERT(input_tensor != nullptr); | |||
| this->input_vec_.emplace_back(new lite::tensor::LiteTensor(input_tensor)); | |||
| } | |||
| } | |||
| void LiteSession::InitGraphOutputTensors(const lite::Model *model) { | |||
| auto meta_graph = model->GetMetaGraph(); | |||
| MS_ASSERT(this->outputs_.empty()); | |||
| @@ -169,6 +179,7 @@ void LiteSession::InitGraphOutputMap(const lite::Model *model) { | |||
| void LiteSession::InitGraphInOutTensors(const lite::Model *model) { | |||
| InitGraphInputTensors(model); | |||
| InitGraphInputMSTensors(model); | |||
| InitGraphOutputTensors(model); | |||
| InitGraphInputMap(model); | |||
| InitGraphOutputMap(model); | |||
| @@ -201,16 +212,7 @@ int LiteSession::CompileGraph(Model *model) { | |||
| } | |||
| std::vector<mindspore::tensor::MSTensor *> LiteSession::GetInputs() const { | |||
| std::vector<mindspore::tensor::MSTensor *> ret; | |||
| for (auto &iter : this->input_map_) { | |||
| auto &node_input_tensors = iter.second; | |||
| for (auto tensor : node_input_tensors) { | |||
| if (!IsContain(ret, tensor)) { | |||
| ret.emplace_back(tensor); | |||
| } | |||
| } | |||
| } | |||
| return ret; | |||
| return this->input_vec_; | |||
| } | |||
| int LiteSession::RunGraph(const session::KernelCallBack &before, const session::KernelCallBack &after) { | |||
| @@ -57,13 +57,15 @@ class LiteSession : public session::LiteSession { | |||
| int ConvertTensors(const lite::Model *model); | |||
| void InitGraphInOutTensors(const lite::Model *model); | |||
| // init this->inputs_ | |||
| void InitGraphInputTensors(const lite::Model *model); | |||
| // init this->input_vec_ | |||
| void InitGraphInputMSTensors(const lite::Model *model); | |||
| // init this->outputs_ | |||
| void InitGraphOutputTensors(const lite::Model *model); | |||
| // init this->input_map_ | |||
| void InitGraphInputMap(const lite::Model *model); | |||
| // init this->output_map_ | |||
| void InitGraphOutputMap(const lite::Model *model); | |||
| protected: | |||
| @@ -74,6 +76,8 @@ class LiteSession : public session::LiteSession { | |||
| std::vector<tensor::Tensor *> inputs_; | |||
| // graph output tensors | |||
| std::vector<tensor::Tensor *> outputs_; | |||
| // graph input MSTensors | |||
| std::vector<mindspore::tensor::MSTensor *> input_vec_; | |||
| // graph input node name -- input tensors | |||
| std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> input_map_; | |||
| // graph output node name -- output tensors | |||
| @@ -49,7 +49,8 @@ int Benchmark::GenerateInputData() { | |||
| auto tensorByteSize = tensor->Size(); | |||
| auto status = GenerateRandomData(tensorByteSize, inputData); | |||
| if (status != 0) { | |||
| MS_LOG(ERROR) << "GenerateRandomData for inTensor failed %d" << status; | |||
| std::cerr << "GenerateRandomData for inTensor failed: " << status << std::endl; | |||
| MS_LOG(ERROR) << "GenerateRandomData for inTensor failed:" << status; | |||
| return status; | |||
| } | |||
| } | |||
| @@ -60,12 +61,14 @@ int Benchmark::LoadInput() { | |||
| if (_flags->inDataPath.empty()) { | |||
| auto status = GenerateInputData(); | |||
| if (status != 0) { | |||
| std::cerr << "Generate input data error " << status << std::endl; | |||
| MS_LOG(ERROR) << "Generate input data error " << status; | |||
| return status; | |||
| } | |||
| } else { | |||
| auto status = ReadInputFile(); | |||
| if (status != 0) { | |||
| std::cerr << "ReadInputFile error, " << status << std::endl; | |||
| MS_LOG(ERROR) << "ReadInputFile error, " << status; | |||
| return status; | |||
| } | |||
| @@ -97,6 +100,7 @@ int Benchmark::ReadInputFile() { | |||
| char *binBuf = ReadFile(_flags->input_data_list[i].c_str(), &size); | |||
| auto tensorDataSize = cur_tensor->Size(); | |||
| if (size != tensorDataSize) { | |||
| std::cerr << "Input binary file size error, required: %zu, in fact: %zu" << tensorDataSize << size << std::endl; | |||
| MS_LOG(ERROR) << "Input binary file size error, required: %zu, in fact: %zu" << tensorDataSize << size; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -113,11 +117,13 @@ int Benchmark::ReadCalibData() { | |||
| // read calib data | |||
| std::ifstream inFile(calibDataPath); | |||
| if (!inFile.good()) { | |||
| std::cerr << "file: " << calibDataPath << " is not exist" << std::endl; | |||
| MS_LOG(ERROR) << "file: " << calibDataPath << " is not exist"; | |||
| return RET_ERROR; | |||
| } | |||
| if (!inFile.is_open()) { | |||
| std::cerr << "file: " << calibDataPath << " open failed" << std::endl; | |||
| MS_LOG(ERROR) << "file: " << calibDataPath << " open failed"; | |||
| inFile.close(); | |||
| return RET_ERROR; | |||
| @@ -181,6 +187,7 @@ float Benchmark::CompareData(const std::string &nodeName, std::vector<int> msSha | |||
| oss << dim << ","; | |||
| } | |||
| oss << ") are different"; | |||
| std::cerr << oss.str() << std::endl; | |||
| MS_LOG(ERROR) << "%s", oss.str().c_str(); | |||
| return RET_ERROR; | |||
| } | |||
| @@ -193,6 +200,7 @@ float Benchmark::CompareData(const std::string &nodeName, std::vector<int> msSha | |||
| } | |||
| if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) { | |||
| std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl; | |||
| MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail"; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -524,6 +532,13 @@ int Benchmark::Init() { | |||
| return RET_OK; | |||
| } | |||
| Benchmark::~Benchmark() { | |||
| for (auto iter : this->calibData) { | |||
| delete (iter.second); | |||
| } | |||
| this->calibData.clear(); | |||
| } | |||
| int RunBenchmark(int argc, const char **argv) { | |||
| BenchmarkFlags flags; | |||
| Option<std::string> err = flags.ParseFlags(argc, argv); | |||
| @@ -104,7 +104,7 @@ class MS_API Benchmark { | |||
| public: | |||
| explicit Benchmark(BenchmarkFlags *flags) : _flags(flags) {} | |||
| virtual ~Benchmark() = default; | |||
| virtual ~Benchmark(); | |||
| int Init(); | |||
| int RunBenchmark(const std::string &deviceType = "NPU"); | |||