From: @yuzhenhua666 Reviewed-by: @c_34,@wuxuejian Signed-off-by: @c_34pull/13861/MERGE
| @@ -33,12 +33,12 @@ | |||||
| #include "inc/utils.h" | #include "inc/utils.h" | ||||
| using mindspore::GlobalContext; | |||||
| using mindspore::Context; | |||||
| using mindspore::Serialization; | using mindspore::Serialization; | ||||
| using mindspore::Model; | using mindspore::Model; | ||||
| using mindspore::ModelContext; | |||||
| using mindspore::Status; | using mindspore::Status; | ||||
| using mindspore::ModelType; | using mindspore::ModelType; | ||||
| using mindspore::Graph; | |||||
| using mindspore::GraphCell; | using mindspore::GraphCell; | ||||
| using mindspore::kSuccess; | using mindspore::kSuccess; | ||||
| using mindspore::MSTensor; | using mindspore::MSTensor; | ||||
| @@ -64,21 +64,28 @@ int main(int argc, char **argv) { | |||||
| return 1; | return 1; | ||||
| } | } | ||||
| GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310); | |||||
| GlobalContext::SetGlobalDeviceID(FLAGS_device_id); | |||||
| auto graph = Serialization::LoadModel(FLAGS_mindir_path, ModelType::kMindIR); | |||||
| auto model_context = std::make_shared<mindspore::ModelContext>(); | |||||
| if (!FLAGS_aipp_path.empty()) { | |||||
| ModelContext::SetInsertOpConfigPath(model_context, FLAGS_aipp_path); | |||||
| auto context = std::make_shared<Context>(); | |||||
| auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>(); | |||||
| ascend310_info->SetDeviceID(FLAGS_device_id); | |||||
| ascend310_info->SetInsertOpConfigPath({FLAGS_aipp_path}); | |||||
| context->MutableDeviceInfo().push_back(ascend310_info); | |||||
| Graph graph; | |||||
| Status ret = Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph); | |||||
| if (ret != kSuccess) { | |||||
| std::cout << "Load model failed." << std::endl; | |||||
| return 1; | |||||
| } | } | ||||
| Model model(GraphCell(graph), model_context); | |||||
| Status ret = model.Build(); | |||||
| Model model; | |||||
| ret = model.Build(GraphCell(graph), context); | |||||
| if (ret != kSuccess) { | if (ret != kSuccess) { | ||||
| std::cout << "ERROR: Build failed." << std::endl; | std::cout << "ERROR: Build failed." << std::endl; | ||||
| return 1; | return 1; | ||||
| } | } | ||||
| std::vector<MSTensor> modelInputs = model.GetInputs(); | |||||
| auto allFiles = GetAllFiles(FLAGS_dataset_path); | auto allFiles = GetAllFiles(FLAGS_dataset_path); | ||||
| if (allFiles.empty()) { | if (allFiles.empty()) { | ||||
| std::cout << "ERROR: no input data." << std::endl; | std::cout << "ERROR: no input data." << std::endl; | ||||
| @@ -108,11 +115,12 @@ int main(int argc, char **argv) { | |||||
| std::cout << "wrong file format: " << allFiles[i] << std::endl; | std::cout << "wrong file format: " << allFiles[i] << std::endl; | ||||
| continue; | continue; | ||||
| } | } | ||||
| auto img = std::make_shared<MSTensor>(); | |||||
| compose(ReadFileToTensor(allFiles[i]), img.get()); | |||||
| inputs.emplace_back(img->Name(), img->DataType(), img->Shape(), | |||||
| img->Data().get(), img->DataSize()); | |||||
| mindspore::MSTensor img; | |||||
| compose(ReadFileToTensor(allFiles[i]), &img); | |||||
| inputs.emplace_back(modelInputs[0].Name(), modelInputs[0].DataType(), modelInputs[0].Shape(), | |||||
| img.Data().get(), img.DataSize()); | |||||
| gettimeofday(&start, NULL); | gettimeofday(&start, NULL); | ||||
| ret = model.Predict(inputs, &outputs); | ret = model.Predict(inputs, &outputs); | ||||
| @@ -34,12 +34,12 @@ | |||||
| #include "include/api/serialization.h" | #include "include/api/serialization.h" | ||||
| #include "include/api/context.h" | #include "include/api/context.h" | ||||
| using mindspore::GlobalContext; | |||||
| using mindspore::Serialization; | using mindspore::Serialization; | ||||
| using mindspore::Model; | using mindspore::Model; | ||||
| using mindspore::ModelContext; | |||||
| using mindspore::Context; | |||||
| using mindspore::Status; | using mindspore::Status; | ||||
| using mindspore::ModelType; | using mindspore::ModelType; | ||||
| using mindspore::Graph; | |||||
| using mindspore::GraphCell; | using mindspore::GraphCell; | ||||
| using mindspore::kSuccess; | using mindspore::kSuccess; | ||||
| using mindspore::MSTensor; | using mindspore::MSTensor; | ||||
| @@ -71,18 +71,27 @@ int main(int argc, char **argv) { | |||||
| return 1; | return 1; | ||||
| } | } | ||||
| GlobalContext::SetGlobalDeviceTarget(FLAGS_device_target); | |||||
| GlobalContext::SetGlobalDeviceID(FLAGS_device_id); | |||||
| auto context = std::make_shared<Context>(); | |||||
| auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>(); | |||||
| ascend310_info->SetDeviceID(FLAGS_device_id); | |||||
| context->MutableDeviceInfo().push_back(ascend310_info); | |||||
| auto graph = Serialization::LoadModel(FLAGS_model_path, ModelType::kMindIR); | |||||
| Graph graph; | |||||
| Status ret = Serialization::Load(FLAGS_model_path, ModelType::kMindIR, &graph); | |||||
| if (ret != kSuccess) { | |||||
| std::cout << "Load model failed." << std::endl; | |||||
| return 1; | |||||
| } | |||||
| Model model((GraphCell(graph))); | |||||
| Status ret = model.Build(); | |||||
| Model model; | |||||
| ret = model.Build(GraphCell(graph), context); | |||||
| if (ret != kSuccess) { | if (ret != kSuccess) { | ||||
| std::cout << "ERROR Build failed." << std::endl; | |||||
| std::cout << "ERROR: Build failed." << std::endl; | |||||
| return 1; | return 1; | ||||
| } | } | ||||
| std::vector<MSTensor> modelInputs = model.GetInputs(); | |||||
| auto all_files = GetAllFiles(FLAGS_dataset_path); | auto all_files = GetAllFiles(FLAGS_dataset_path); | ||||
| if (all_files.empty()) { | if (all_files.empty()) { | ||||
| std::cout << "ERROR: no input data." << std::endl; | std::cout << "ERROR: no input data." << std::endl; | ||||
| @@ -118,7 +127,8 @@ int main(int argc, char **argv) { | |||||
| transform(image, &image); | transform(image, &image); | ||||
| transformCast(image, &image); | transformCast(image, &image); | ||||
| inputs.emplace_back(image); | |||||
| inputs.emplace_back(modelInputs[0].Name(), modelInputs[0].DataType(), modelInputs[0].Shape(), | |||||
| image.Data().get(), image.DataSize()); | |||||
| gettimeofday(&start, NULL); | gettimeofday(&start, NULL); | ||||
| model.Predict(inputs, &outputs); | model.Predict(inputs, &outputs); | ||||
| @@ -165,6 +165,10 @@ int AclProcess::WriteResult(const std::string& imageFile) { | |||||
| std::string outFileName = homePath + "/" + fileName; | std::string outFileName = homePath + "/" + fileName; | ||||
| try { | try { | ||||
| FILE * outputFile = fopen(outFileName.c_str(), "wb"); | FILE * outputFile = fopen(outFileName.c_str(), "wb"); | ||||
| if (outputFile == nullptr) { | |||||
| std::cout << "open result file " << outFileName << " failed" << std::endl; | |||||
| return INVALID_POINTER; | |||||
| } | |||||
| fwrite(resHostBuf, output_size, sizeof(char), outputFile); | fwrite(resHostBuf, output_size, sizeof(char), outputFile); | ||||
| fclose(outputFile); | fclose(outputFile); | ||||
| outputFile = nullptr; | outputFile = nullptr; | ||||
| @@ -79,7 +79,11 @@ int main(int argc, char* argv[]) { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| if (is_file(FLAGS_data_path)) { | if (is_file(FLAGS_data_path)) { | ||||
| aclProcess.Process(FLAGS_data_path, &costTime_map); | |||||
| ret = aclProcess.Process(FLAGS_data_path, &costTime_map); | |||||
| if (ret != OK) { | |||||
| std::cout << "model process failed, errno = " << ret << std::endl; | |||||
| return ret; | |||||
| } | |||||
| } else if (is_dir(FLAGS_data_path)) { | } else if (is_dir(FLAGS_data_path)) { | ||||
| struct dirent * filename; | struct dirent * filename; | ||||
| DIR * dir; | DIR * dir; | ||||
| @@ -93,7 +97,11 @@ int main(int argc, char* argv[]) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| std::string wholePath = FLAGS_data_path + "/" + filename->d_name; | std::string wholePath = FLAGS_data_path + "/" + filename->d_name; | ||||
| aclProcess.Process(wholePath, &costTime_map); | |||||
| ret = aclProcess.Process(wholePath, &costTime_map); | |||||
| if (ret != OK) { | |||||
| std::cout << "model process failed, errno = " << ret << std::endl; | |||||
| return ret; | |||||
| } | |||||
| } | } | ||||
| } else { | } else { | ||||
| std::cout << " input image path error" << std::endl; | std::cout << " input image path error" << std::endl; | ||||
| @@ -165,6 +165,10 @@ int AclProcess::WriteResult(const std::string& imageFile) { | |||||
| std::string outFileName = homePath + "/" + fileName; | std::string outFileName = homePath + "/" + fileName; | ||||
| try { | try { | ||||
| FILE * outputFile = fopen(outFileName.c_str(), "wb"); | FILE * outputFile = fopen(outFileName.c_str(), "wb"); | ||||
| if (outputFile == nullptr) { | |||||
| std::cout << "open result file " << outFileName << " failed" << std::endl; | |||||
| return INVALID_POINTER; | |||||
| } | |||||
| fwrite(resHostBuf, output_size, sizeof(char), outputFile); | fwrite(resHostBuf, output_size, sizeof(char), outputFile); | ||||
| fclose(outputFile); | fclose(outputFile); | ||||
| outputFile = nullptr; | outputFile = nullptr; | ||||
| @@ -79,7 +79,11 @@ int main(int argc, char* argv[]) { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| if (is_file(FLAGS_data_path)) { | if (is_file(FLAGS_data_path)) { | ||||
| aclProcess.Process(FLAGS_data_path, &costTime_map); | |||||
| ret = aclProcess.Process(FLAGS_data_path, &costTime_map); | |||||
| if (ret != OK) { | |||||
| std::cout << "model process failed, errno = " << ret << std::endl; | |||||
| return ret; | |||||
| } | |||||
| } else if (is_dir(FLAGS_data_path)) { | } else if (is_dir(FLAGS_data_path)) { | ||||
| struct dirent * filename; | struct dirent * filename; | ||||
| DIR * dir; | DIR * dir; | ||||
| @@ -93,7 +97,11 @@ int main(int argc, char* argv[]) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| std::string wholePath = FLAGS_data_path + "/" + filename->d_name; | std::string wholePath = FLAGS_data_path + "/" + filename->d_name; | ||||
| aclProcess.Process(wholePath, &costTime_map); | |||||
| ret = aclProcess.Process(wholePath, &costTime_map); | |||||
| if (ret != OK) { | |||||
| std::cout << "model process failed, errno = " << ret << std::endl; | |||||
| return ret; | |||||
| } | |||||
| } | } | ||||
| } else { | } else { | ||||
| std::cout << " input image path error" << std::endl; | std::cout << " input image path error" << std::endl; | ||||