Browse Source

fix time_profile and converter bug

tags/v0.7.0-beta
yeyunpeng 5 years ago
parent
commit
00ecfc3137
6 changed files with 49 additions and 28 deletions
  1. +1
    -1
      mindspore/lite/src/common/file_utils.cc
  2. +8
    -0
      mindspore/lite/src/model.cc
  3. +4
    -1
      mindspore/lite/src/model_impl.cc
  4. +3
    -0
      mindspore/lite/tools/converter/converter.cc
  5. +11
    -12
      mindspore/lite/tools/converter/converter_flags.cc
  6. +22
    -14
      mindspore/lite/tools/time_profile/time_profile.cc

+ 1
- 1
mindspore/lite/src/common/file_utils.cc View File

@@ -67,7 +67,7 @@ std::string RealPath(const char *path) {
MS_LOG(ERROR) << "path is too long"; MS_LOG(ERROR) << "path is too long";
return ""; return "";
} }
std::shared_ptr<char> resolvedPath(new (std::nothrow) char[PATH_MAX]{0});
auto resolvedPath = std::make_unique<char[]>(PATH_MAX);
if (resolvedPath == nullptr) { if (resolvedPath == nullptr) {
MS_LOG(ERROR) << "new resolvedPath failed"; MS_LOG(ERROR) << "new resolvedPath failed";
return ""; return "";


+ 8
- 0
mindspore/lite/src/model.cc View File

@@ -26,7 +26,15 @@ namespace mindspore::lite {


Model *Model::Import(const char *model_buf, size_t size) { Model *Model::Import(const char *model_buf, size_t size) {
auto model = new Model(); auto model = new Model();
if (model_buf == nullptr) {
MS_LOG(ERROR) << "model buf is null";
return nullptr;
}
model->model_impl_ = ModelImpl::Import(model_buf, size); model->model_impl_ = ModelImpl::Import(model_buf, size);
if (model->model_impl_ == nullptr) {
MS_LOG(ERROR) << "model impl is null";
return nullptr;
}
return model; return model;
} }




+ 4
- 1
mindspore/lite/src/model_impl.cc View File

@@ -21,7 +21,10 @@


namespace mindspore::lite { namespace mindspore::lite {
ModelImpl *ModelImpl::Import(const char *model_buf, size_t size) { ModelImpl *ModelImpl::Import(const char *model_buf, size_t size) {
MS_EXCEPTION_IF_NULL(model_buf);
if (model_buf == nullptr) {
MS_LOG(ERROR) << "The model buf is nullptr";
return nullptr;
}
flatbuffers::Verifier verify((const uint8_t *)model_buf, size); flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
if (!schema::VerifyMetaGraphBuffer(verify)) { if (!schema::VerifyMetaGraphBuffer(verify)) {
MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; MS_LOG(ERROR) << "The buffer is invalid and fail to create graph.";


+ 3
- 0
mindspore/lite/tools/converter/converter.cc View File

@@ -153,6 +153,9 @@ void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *
int RunConverter(int argc, const char **argv) { int RunConverter(int argc, const char **argv) {
auto flags = new converter::Flags; auto flags = new converter::Flags;
auto status = flags->Init(argc, argv); auto status = flags->Init(argc, argv);
if (status == RET_SUCCESS_EXIT) {
return 0;
}
if (status != 0) { if (status != 0) {
MS_LOG(ERROR) << "converter::Flags Init failed: " << status; MS_LOG(ERROR) << "converter::Flags Init failed: " << status;
return 1; return 1;


+ 11
- 12
mindspore/lite/tools/converter/converter_flags.cc View File

@@ -14,7 +14,6 @@
* limitations under the License. * limitations under the License.
*/ */


#include <regex>
#include <string> #include <string>
#include "tools/converter/converter_flags.h" #include "tools/converter/converter_flags.h"


@@ -43,31 +42,31 @@ int Flags::Init(int argc, const char **argv) {
Option<std::string> err = this->ParseFlags(argc, argv); Option<std::string> err = this->ParseFlags(argc, argv);


if (err.IsSome()) { if (err.IsSome()) {
MS_LOG(ERROR) << err.Get();
std::cerr << err.Get();
std::cerr << this->Usage() << std::endl; std::cerr << this->Usage() << std::endl;
return 1; return 1;
} }


if (this->help) { if (this->help) {
std::cerr << this->Usage() << std::endl;
return 0;
std::cout << this->Usage() << std::endl;
return RET_SUCCESS_EXIT;
} }
if (this->modelFile.empty()) { if (this->modelFile.empty()) {
MS_LOG(ERROR) << "INPUT MISSING: model file path is necessary";
std::cerr << "INPUT MISSING: model file path is necessary";
return 1; return 1;
} }
if (this->outputFile.empty()) { if (this->outputFile.empty()) {
MS_LOG(ERROR) << "INPUT MISSING: output file path is necessary";
std::cerr << "INPUT MISSING: output file path is necessary";
return 1; return 1;
} }


if (this->outputFile.rfind('/') == this->outputFile.length() - 1) { if (this->outputFile.rfind('/') == this->outputFile.length() - 1) {
MS_LOG(ERROR) << "INPUT ILLEGAL: outputFile must be a valid file path";
std::cerr << "INPUT ILLEGAL: outputFile must be a valid file path";
return 1; return 1;
} }


if (this->fmkIn.empty()) { if (this->fmkIn.empty()) {
MS_LOG(ERROR) << "INPUT MISSING: fmk is necessary";
std::cerr << "INPUT MISSING: fmk is necessary";
return 1; return 1;
} }
if (this->inputInferenceTypeIn == "FLOAT") { if (this->inputInferenceTypeIn == "FLOAT") {
@@ -75,7 +74,7 @@ int Flags::Init(int argc, const char **argv) {
} else if (this->inputInferenceTypeIn == "UINT8") { } else if (this->inputInferenceTypeIn == "UINT8") {
this->inputInferenceType = 1; this->inputInferenceType = 1;
} else { } else {
MS_LOG(ERROR) << "INPUT INVALID: inputInferenceType is invalid: %s", this->inputInferenceTypeIn.c_str();
std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s", this->inputInferenceTypeIn.c_str();
return 1; return 1;
} }
if (this->fmkIn == "CAFFE") { if (this->fmkIn == "CAFFE") {
@@ -85,12 +84,12 @@ int Flags::Init(int argc, const char **argv) {
} else if (this->fmkIn == "TFLITE") { } else if (this->fmkIn == "TFLITE") {
this->fmk = FmkType_TFLITE; this->fmk = FmkType_TFLITE;
} else { } else {
MS_LOG(ERROR) << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MS";
std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MS";
return 1; return 1;
} }


if (this->fmk != FmkType_CAFFE && !weightFile.empty()) { if (this->fmk != FmkType_CAFFE && !weightFile.empty()) {
MS_LOG(ERROR) << "INPUT ILLEGAL: weightFile is not a valid flag";
std::cerr << "INPUT ILLEGAL: weightFile is not a valid flag";
return 1; return 1;
} }
if (this->quantTypeIn == "AwareTrainning") { if (this->quantTypeIn == "AwareTrainning") {
@@ -102,7 +101,7 @@ int Flags::Init(int argc, const char **argv) {
} else if (this->quantTypeIn.empty()) { } else if (this->quantTypeIn.empty()) {
this->quantType = QuantType_QUANT_NONE; this->quantType = QuantType_QUANT_NONE;
} else { } else {
MS_LOG(ERROR) << "INPUT ILLEGAL: quantType must be AwareTrainning|WeightQuant|PostTraining";
std::cerr << "INPUT ILLEGAL: quantType must be AwareTrainning|WeightQuant|PostTraining";
return 1; return 1;
} }




+ 22
- 14
mindspore/lite/tools/time_profile/time_profile.cc View File

@@ -71,7 +71,7 @@ int TimeProfile::ReadInputFile() {
} }
auto tensor_data_size = inTensor->Size(); auto tensor_data_size = inTensor->Size();
if (size != tensor_data_size) { if (size != tensor_data_size) {
MS_LOG(ERROR) << "Input binary file size error, required: " << tensor_data_size << " in fact: %zu" << size;
MS_LOG(ERROR) << "Input binary file size error, required: " << tensor_data_size << " in fact: " << size;
return RET_ERROR; return RET_ERROR;
} }
auto input_data = inTensor->MutableData(); auto input_data = inTensor->MutableData();
@@ -90,7 +90,7 @@ int TimeProfile::LoadInput() {
} else { } else {
auto status = ReadInputFile(); auto status = ReadInputFile();
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "ReadInputFile error, " << status;
MS_LOG(ERROR) << "ReadInputFile error " << status;
return RET_ERROR; return RET_ERROR;
} }
} }
@@ -149,10 +149,10 @@ int TimeProfile::InitCallbackParameter() {
uint64_t opEnd = GetTimeUs(); uint64_t opEnd = GetTimeUs();


if (after_inputs.empty()) { if (after_inputs.empty()) {
MS_LOG(INFO) << "The num of beforeInputs is empty";
MS_LOG(INFO) << "The num of after inputs is empty";
} }
if (after_outputs.empty()) { if (after_outputs.empty()) {
MS_LOG(INFO) << "The num of beforeOutputs is empty";
MS_LOG(INFO) << "The num of after outputs is empty";
} }


float cost = static_cast<float>(opEnd - op_begin_) / 1000.0f; float cost = static_cast<float>(opEnd - op_begin_) / 1000.0f;
@@ -297,22 +297,34 @@ int TimeProfile::RunTimeProfile() {
size_t size = 0; size_t size = 0;
char *graphBuf = ReadFile(_flags->model_path_.c_str(), &size); char *graphBuf = ReadFile(_flags->model_path_.c_str(), &size);
if (graphBuf == nullptr) { if (graphBuf == nullptr) {
MS_LOG(ERROR) << "Load graph failed while running %s", modelName.c_str();
MS_LOG(ERROR) << "Load graph failed while running " << modelName.c_str();
delete graphBuf;
delete session_;
return RET_ERROR; return RET_ERROR;
} }
auto model = lite::Model::Import(graphBuf, size); auto model = lite::Model::Import(graphBuf, size);

delete graphBuf;
if (model == nullptr) {
MS_LOG(ERROR) << "Import model file failed while running " << modelName.c_str();
delete session_;
delete model;
return RET_ERROR;
}
auto ret = session_->CompileGraph(model); auto ret = session_->CompileGraph(model);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Compile graph failed."; MS_LOG(ERROR) << "Compile graph failed.";
delete session_;
delete model;
return RET_ERROR; return RET_ERROR;
} }


// load input // load input
MS_LOG(INFO) << "start generate input data"; MS_LOG(INFO) << "start generate input data";
auto status = LoadInput(); auto status = LoadInput();
if (status != 0) {
if (status != RET_OK) {
MS_LOG(ERROR) << "Generate input data error"; MS_LOG(ERROR) << "Generate input data error";
delete session_;
delete model;
return status; return status;
} }


@@ -324,6 +336,8 @@ int TimeProfile::RunTimeProfile() {
ret = session_->RunGraph(before_call_back_, after_call_back_); ret = session_->RunGraph(before_call_back_, after_call_back_);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Run graph failed."; MS_LOG(ERROR) << "Run graph failed.";
delete session_;
delete model;
return RET_ERROR; return RET_ERROR;
} }
auto outputs = session_->GetOutputs(); auto outputs = session_->GetOutputs();
@@ -345,14 +359,8 @@ int TimeProfile::RunTimeProfile() {


printf("\n total time: %5.5f ms, kernel cost: %5.5f ms \n\n", runCost, op_cost_total_ / _flags->loop_count_); printf("\n total time: %5.5f ms, kernel cost: %5.5f ms \n\n", runCost, op_cost_total_ / _flags->loop_count_);
printf("-------------------------------------------------------------------------\n"); printf("-------------------------------------------------------------------------\n");

for (auto &msInput : ms_inputs_) {
delete msInput;
}
ms_inputs_.clear();
delete graphBuf;
delete session_;
delete model; delete model;
delete session_;
return ret; return ret;
} }




Loading…
Cancel
Save