|
|
|
@@ -39,10 +39,9 @@ using mindspore::dataset::vision::CenterCrop; |
|
|
|
using mindspore::dataset::vision::Normalize; |
|
|
|
using mindspore::dataset::vision::HWC2CHW; |
|
|
|
using mindspore::dataset::TensorTransform; |
|
|
|
using mindspore::GlobalContext; |
|
|
|
using mindspore::Serialization; |
|
|
|
using mindspore::Model; |
|
|
|
using mindspore::ModelContext; |
|
|
|
using mindspore::Context; |
|
|
|
using mindspore::Status; |
|
|
|
using mindspore::ModelType; |
|
|
|
using mindspore::GraphCell; |
|
|
|
@@ -62,14 +61,14 @@ int main(int argc, char **argv) { |
|
|
|
return 1; |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310); |
|
|
|
GlobalContext::SetGlobalDeviceID(FLAGS_device_id); |
|
|
|
auto graph = Serialization::LoadModel(FLAGS_mindir_path, ModelType::kMindIR); |
|
|
|
auto model_context = std::make_shared<mindspore::ModelContext>(); |
|
|
|
|
|
|
|
Model model(GraphCell(graph), model_context); |
|
|
|
Status ret = model.Build(); |
|
|
|
auto context = std::make_shared<Context>(); |
|
|
|
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>(); |
|
|
|
ascend310->SetDeviceID(FLAGS_device_id); |
|
|
|
context->MutableDeviceInfo().push_back(ascend310); |
|
|
|
mindspore::Graph graph; |
|
|
|
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph); |
|
|
|
Model model; |
|
|
|
Status ret = model.Build(GraphCell(graph), context); |
|
|
|
if (ret != kSuccess) { |
|
|
|
std::cout << "ERROR: Build failed." << std::endl; |
|
|
|
return 1; |
|
|
|
|