@@ -67,7 +67,7 @@ std::string RealPath(const char *path) {
     MS_LOG(ERROR) << "path is too long";
     return "";
   }
-  std::shared_ptr<char> resolvedPath(new (std::nothrow) char[PATH_MAX]{0});
+  auto resolvedPath = std::make_unique<char[]>(PATH_MAX);
   if (resolvedPath == nullptr) {
     MS_LOG(ERROR) << "new resolvedPath failed";
     return "";
@@ -26,7 +26,15 @@ namespace mindspore::lite {
 Model *Model::Import(const char *model_buf, size_t size) {
   auto model = new Model();
+  if (model_buf == nullptr) {
+    MS_LOG(ERROR) << "model buf is null";
+    return nullptr;
+  }
   model->model_impl_ = ModelImpl::Import(model_buf, size);
+  if (model->model_impl_ == nullptr) {
+    MS_LOG(ERROR) << "model impl is null";
+    return nullptr;
+  }
   return model;
 }
@@ -21,7 +21,10 @@
 namespace mindspore::lite {
 ModelImpl *ModelImpl::Import(const char *model_buf, size_t size) {
-  MS_EXCEPTION_IF_NULL(model_buf);
+  if (model_buf == nullptr) {
+    MS_LOG(ERROR) << "The model buf is nullptr";
+    return nullptr;
+  }
   flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
   if (!schema::VerifyMetaGraphBuffer(verify)) {
     MS_LOG(ERROR) << "The buffer is invalid and fail to create graph.";
@@ -153,6 +153,9 @@ void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *
 int RunConverter(int argc, const char **argv) {
   auto flags = new converter::Flags;
   auto status = flags->Init(argc, argv);
+  if (status == RET_SUCCESS_EXIT) {
+    return 0;
+  }
   if (status != 0) {
     MS_LOG(ERROR) << "converter::Flags Init failed: " << status;
     return 1;
@@ -14,7 +14,6 @@
  * limitations under the License.
  */
-#include <regex>
 #include <string>
 #include "tools/converter/converter_flags.h"
@@ -43,31 +42,31 @@ int Flags::Init(int argc, const char **argv) {
   Option<std::string> err = this->ParseFlags(argc, argv);
   if (err.IsSome()) {
-    MS_LOG(ERROR) << err.Get();
+    std::cerr << err.Get();
     std::cerr << this->Usage() << std::endl;
     return 1;
   }
   if (this->help) {
-    std::cerr << this->Usage() << std::endl;
-    return 0;
+    std::cout << this->Usage() << std::endl;
+    return RET_SUCCESS_EXIT;
   }
   if (this->modelFile.empty()) {
-    MS_LOG(ERROR) << "INPUT MISSING: model file path is necessary";
+    std::cerr << "INPUT MISSING: model file path is necessary";
     return 1;
   }
   if (this->outputFile.empty()) {
-    MS_LOG(ERROR) << "INPUT MISSING: output file path is necessary";
+    std::cerr << "INPUT MISSING: output file path is necessary";
     return 1;
   }
   if (this->outputFile.rfind('/') == this->outputFile.length() - 1) {
-    MS_LOG(ERROR) << "INPUT ILLEGAL: outputFile must be a valid file path";
+    std::cerr << "INPUT ILLEGAL: outputFile must be a valid file path";
     return 1;
   }
   if (this->fmkIn.empty()) {
-    MS_LOG(ERROR) << "INPUT MISSING: fmk is necessary";
+    std::cerr << "INPUT MISSING: fmk is necessary";
     return 1;
   }
   if (this->inputInferenceTypeIn == "FLOAT") {
@@ -75,7 +74,7 @@ int Flags::Init(int argc, const char **argv) {
   } else if (this->inputInferenceTypeIn == "UINT8") {
     this->inputInferenceType = 1;
   } else {
-    MS_LOG(ERROR) << "INPUT INVALID: inputInferenceType is invalid: %s", this->inputInferenceTypeIn.c_str();
+    std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s", this->inputInferenceTypeIn.c_str();
     return 1;
   }
   if (this->fmkIn == "CAFFE") {
@@ -85,12 +84,12 @@ int Flags::Init(int argc, const char **argv) {
   } else if (this->fmkIn == "TFLITE") {
     this->fmk = FmkType_TFLITE;
   } else {
-    MS_LOG(ERROR) << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MS";
+    std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MS";
     return 1;
   }
   if (this->fmk != FmkType_CAFFE && !weightFile.empty()) {
-    MS_LOG(ERROR) << "INPUT ILLEGAL: weightFile is not a valid flag";
+    std::cerr << "INPUT ILLEGAL: weightFile is not a valid flag";
     return 1;
   }
   if (this->quantTypeIn == "AwareTrainning") {
@@ -102,7 +101,7 @@ int Flags::Init(int argc, const char **argv) {
   } else if (this->quantTypeIn.empty()) {
     this->quantType = QuantType_QUANT_NONE;
   } else {
-    MS_LOG(ERROR) << "INPUT ILLEGAL: quantType must be AwareTrainning|WeightQuant|PostTraining";
+    std::cerr << "INPUT ILLEGAL: quantType must be AwareTrainning|WeightQuant|PostTraining";
     return 1;
   }
@@ -71,7 +71,7 @@ int TimeProfile::ReadInputFile() {
   }
   auto tensor_data_size = inTensor->Size();
   if (size != tensor_data_size) {
-    MS_LOG(ERROR) << "Input binary file size error, required: " << tensor_data_size << " in fact: %zu" << size;
+    MS_LOG(ERROR) << "Input binary file size error, required: " << tensor_data_size << " in fact: " << size;
     return RET_ERROR;
   }
   auto input_data = inTensor->MutableData();
@@ -90,7 +90,7 @@ int TimeProfile::LoadInput() {
   } else {
     auto status = ReadInputFile();
     if (status != RET_OK) {
-      MS_LOG(ERROR) << "ReadInputFile error, " << status;
+      MS_LOG(ERROR) << "ReadInputFile error " << status;
      return RET_ERROR;
    }
  }
@@ -149,10 +149,10 @@ int TimeProfile::InitCallbackParameter() {
     uint64_t opEnd = GetTimeUs();
     if (after_inputs.empty()) {
-      MS_LOG(INFO) << "The num of beforeInputs is empty";
+      MS_LOG(INFO) << "The num of after inputs is empty";
     }
     if (after_outputs.empty()) {
-      MS_LOG(INFO) << "The num of beforeOutputs is empty";
+      MS_LOG(INFO) << "The num of after outputs is empty";
     }
     float cost = static_cast<float>(opEnd - op_begin_) / 1000.0f;
@@ -297,22 +297,34 @@ int TimeProfile::RunTimeProfile() {
   size_t size = 0;
   char *graphBuf = ReadFile(_flags->model_path_.c_str(), &size);
   if (graphBuf == nullptr) {
-    MS_LOG(ERROR) << "Load graph failed while running %s", modelName.c_str();
+    MS_LOG(ERROR) << "Load graph failed while running " << modelName.c_str();
+    delete graphBuf;
+    delete session_;
     return RET_ERROR;
   }
   auto model = lite::Model::Import(graphBuf, size);
+  delete graphBuf;
+  if (model == nullptr) {
+    MS_LOG(ERROR) << "Import model file failed while running " << modelName.c_str();
+    delete session_;
+    delete model;
+    return RET_ERROR;
+  }
   auto ret = session_->CompileGraph(model);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Compile graph failed.";
+    delete session_;
+    delete model;
     return RET_ERROR;
   }
   // load input
   MS_LOG(INFO) << "start generate input data";
   auto status = LoadInput();
-  if (status != 0) {
+  if (status != RET_OK) {
     MS_LOG(ERROR) << "Generate input data error";
+    delete session_;
+    delete model;
     return status;
   }
@@ -324,6 +336,8 @@ int TimeProfile::RunTimeProfile() {
   ret = session_->RunGraph(before_call_back_, after_call_back_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Run graph failed.";
+    delete session_;
+    delete model;
     return RET_ERROR;
   }
   auto outputs = session_->GetOutputs();
@@ -345,14 +359,8 @@ int TimeProfile::RunTimeProfile() {
   printf("\n total time: %5.5f ms, kernel cost: %5.5f ms \n\n", runCost, op_cost_total_ / _flags->loop_count_);
   printf("-------------------------------------------------------------------------\n");
-  for (auto &msInput : ms_inputs_) {
-    delete msInput;
-  }
-  ms_inputs_.clear();
-  delete graphBuf;
-  delete session_;
   delete model;
+  delete session_;
   return ret;
 }