| @@ -39,10 +39,9 @@ using mindspore::dataset::vision::CenterCrop; | |||||
| using mindspore::dataset::vision::Normalize; | using mindspore::dataset::vision::Normalize; | ||||
| using mindspore::dataset::vision::HWC2CHW; | using mindspore::dataset::vision::HWC2CHW; | ||||
| using mindspore::dataset::TensorTransform; | using mindspore::dataset::TensorTransform; | ||||
| using mindspore::GlobalContext; | |||||
| using mindspore::Serialization; | using mindspore::Serialization; | ||||
| using mindspore::Model; | using mindspore::Model; | ||||
| using mindspore::ModelContext; | |||||
| using mindspore::Context; | |||||
| using mindspore::Status; | using mindspore::Status; | ||||
| using mindspore::ModelType; | using mindspore::ModelType; | ||||
| using mindspore::GraphCell; | using mindspore::GraphCell; | ||||
| @@ -62,14 +61,14 @@ int main(int argc, char **argv) { | |||||
| return 1; | return 1; | ||||
| } | } | ||||
| GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310); | |||||
| GlobalContext::SetGlobalDeviceID(FLAGS_device_id); | |||||
| auto graph = Serialization::LoadModel(FLAGS_mindir_path, ModelType::kMindIR); | |||||
| auto model_context = std::make_shared<mindspore::ModelContext>(); | |||||
| Model model(GraphCell(graph), model_context); | |||||
| Status ret = model.Build(); | |||||
| auto context = std::make_shared<Context>(); | |||||
| auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>(); | |||||
| ascend310->SetDeviceID(FLAGS_device_id); | |||||
| context->MutableDeviceInfo().push_back(ascend310); | |||||
| mindspore::Graph graph; | |||||
| Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph); | |||||
| Model model; | |||||
| Status ret = model.Build(GraphCell(graph), context); | |||||
| if (ret != kSuccess) { | if (ret != kSuccess) { | ||||
| std::cout << "ERROR: Build failed." << std::endl; | std::cout << "ERROR: Build failed." << std::endl; | ||||
| return 1; | return 1; | ||||