Browse Source

fix res18 310 infer

tags/v1.2.0
jiangzhenguang jiangzg001 5 years ago
parent
commit
7fd5508f59
1 changed files with 9 additions and 10 deletions
  1. +9
    -10
      model_zoo/official/cv/resnet/ascend310_infer/src/main.cc

+ 9
- 10
model_zoo/official/cv/resnet/ascend310_infer/src/main.cc View File

@@ -39,10 +39,9 @@ using mindspore::dataset::vision::CenterCrop;
using mindspore::dataset::vision::Normalize;
using mindspore::dataset::vision::HWC2CHW;
using mindspore::dataset::TensorTransform;
using mindspore::GlobalContext;
using mindspore::Serialization;
using mindspore::Model;
using mindspore::ModelContext;
using mindspore::Context;
using mindspore::Status;
using mindspore::ModelType;
using mindspore::GraphCell;
@@ -62,14 +61,14 @@ int main(int argc, char **argv) {
return 1;
}

GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310);
GlobalContext::SetGlobalDeviceID(FLAGS_device_id);
auto graph = Serialization::LoadModel(FLAGS_mindir_path, ModelType::kMindIR);
auto model_context = std::make_shared<mindspore::ModelContext>();
Model model(GraphCell(graph), model_context);
Status ret = model.Build();
auto context = std::make_shared<Context>();
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
ascend310->SetDeviceID(FLAGS_device_id);
context->MutableDeviceInfo().push_back(ascend310);
mindspore::Graph graph;
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
Model model;
Status ret = model.Build(GraphCell(graph), context);
if (ret != kSuccess) {
std::cout << "ERROR: Build failed." << std::endl;
return 1;


Loading…
Cancel
Save