From ed6e49fc82044ce88afe167918e6d2137bed0f94 Mon Sep 17 00:00:00 2001 From: lihongkang <[lihongkang1@huawei.com]> Date: Wed, 24 Mar 2021 16:49:03 +0800 Subject: [PATCH] update interface for yolov4 and unet 310 infer --- .../cv/unet/ascend310_infer/src/main.cc | 18 +++++++-------- .../official/cv/unet/scripts/run_infer_310.sh | 4 ++-- .../cv/yolov4/ascend310_infer/src/main.cc | 23 +++++++++---------- .../cv/yolov4/scripts/run_infer_310.sh | 4 ++-- 4 files changed, 24 insertions(+), 25 deletions(-) diff --git a/model_zoo/official/cv/unet/ascend310_infer/src/main.cc b/model_zoo/official/cv/unet/ascend310_infer/src/main.cc index a12b558745..66324e2495 100644 --- a/model_zoo/official/cv/unet/ascend310_infer/src/main.cc +++ b/model_zoo/official/cv/unet/ascend310_infer/src/main.cc @@ -34,8 +34,6 @@ using mindspore::Context; -using mindspore::GlobalContext; -using mindspore::ModelContext; using mindspore::Serialization; using mindspore::Model; using mindspore::Status; @@ -57,19 +55,21 @@ int main(int argc, char **argv) { return 1; } - GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310); - GlobalContext::SetGlobalDeviceID(FLAGS_device_id); - auto graph = Serialization::LoadModel(FLAGS_mindir_path, ModelType::kMindIR); - auto model_context = std::make_shared(); - Model model(GraphCell(graph), model_context); + auto context = std::make_shared(); + auto ascend310 = std::make_shared(); + ascend310->SetDeviceID(FLAGS_device_id); + context->MutableDeviceInfo().push_back(ascend310); + mindspore::Graph graph; + Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph); - Status ret = model.Build(); + Model model; + Status ret = model.Build(GraphCell(graph), context); if (ret != kSuccess) { std::cout << "EEEEEEEERROR Build failed." << std::endl; return 1; } - std::vector model_inputs = model.GetInputs(); + std::vector model_inputs = model.GetInputs(); auto all_files = GetAllFiles(FLAGS_dataset_path); if (all_files.empty()) { std::cout << "ERROR: no input data." << std::endl; diff --git a/model_zoo/official/cv/unet/scripts/run_infer_310.sh b/model_zoo/official/cv/unet/scripts/run_infer_310.sh index 59c9745f1a..2e60fd5c07 100644 --- a/model_zoo/official/cv/unet/scripts/run_infer_310.sh +++ b/model_zoo/official/cv/unet/scripts/run_infer_310.sh @@ -15,7 +15,7 @@ # ============================================================================ if [[ $# -lt 2 || $# -gt 3 ]]; then - echo "Usage: sh run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] + echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero" exit 1 fi @@ -71,7 +71,7 @@ function compile_app() if [ -f "Makefile" ]; then make clean fi - sh build.sh &> build.log + bash build.sh &> build.log } function infer() diff --git a/model_zoo/official/cv/yolov4/ascend310_infer/src/main.cc b/model_zoo/official/cv/yolov4/ascend310_infer/src/main.cc index 46050f61f5..f908d976f4 100644 --- a/model_zoo/official/cv/yolov4/ascend310_infer/src/main.cc +++ b/model_zoo/official/cv/yolov4/ascend310_infer/src/main.cc @@ -33,8 +33,6 @@ #include "include/api/types.h" using mindspore::Context; -using mindspore::GlobalContext; -using mindspore::ModelContext; using mindspore::Serialization; using mindspore::Model; using mindspore::Status; @@ -63,24 +61,25 @@ int main(int argc, char **argv) { return 1; } - GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310); - GlobalContext::SetGlobalDeviceID(FLAGS_device_id); - - auto graph = Serialization::LoadModel(FLAGS_mindir_path, ModelType::kMindIR); - auto model_context = std::make_shared(); + auto context = std::make_shared(); + auto ascend310 = std::make_shared(); + ascend310->SetDeviceID(FLAGS_device_id); + context->MutableDeviceInfo().push_back(ascend310); + mindspore::Graph graph; + Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph); if (!FLAGS_precision_mode.empty()) { - ModelContext::SetPrecisionMode(model_context, FLAGS_precision_mode); + ascend310->SetPrecisionMode(FLAGS_precision_mode); } if (!FLAGS_op_select_impl_mode.empty()) { - ModelContext::SetOpSelectImplMode(model_context, FLAGS_op_select_impl_mode); + ascend310->SetOpSelectImplMode(FLAGS_op_select_impl_mode); } if (!FLAGS_aipp_path.empty()) { - ModelContext::SetInsertOpConfigPath(model_context, FLAGS_aipp_path); + ascend310->SetInsertOpConfigPath(FLAGS_aipp_path); } - Model model(GraphCell(graph), model_context); - Status ret = model.Build(); + Model model; + Status ret = model.Build(GraphCell(graph), context); if (ret != kSuccess) { std::cout << "EEEEEEEERROR Build failed." << std::endl; return 1; diff --git a/model_zoo/official/cv/yolov4/scripts/run_infer_310.sh b/model_zoo/official/cv/yolov4/scripts/run_infer_310.sh index 1b786ba0ed..9e16530bf1 100644 --- a/model_zoo/official/cv/yolov4/scripts/run_infer_310.sh +++ b/model_zoo/official/cv/yolov4/scripts/run_infer_310.sh @@ -15,7 +15,7 @@ # ============================================================================ if [[ $# -lt 3 || $# -gt 4 ]]; then - echo "Usage: sh run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] [ANN_FILE] + echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] [ANN_FILE] DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero" exit 1 fi @@ -64,7 +64,7 @@ function compile_app() if [ -f "Makefile" ]; then make clean fi - sh build.sh &> build.log + bash build.sh &> build.log } function infer()