|
|
|
@@ -33,8 +33,6 @@ |
|
|
|
#include "include/api/types.h" |
|
|
|
|
|
|
|
using mindspore::Context; |
|
|
|
using mindspore::GlobalContext; |
|
|
|
using mindspore::ModelContext; |
|
|
|
using mindspore::Serialization; |
|
|
|
using mindspore::Model; |
|
|
|
using mindspore::Status; |
|
|
|
@@ -63,24 +61,25 @@ int main(int argc, char **argv) { |
|
|
|
return 1; |
|
|
|
} |
|
|
|
|
|
|
|
GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310); |
|
|
|
GlobalContext::SetGlobalDeviceID(FLAGS_device_id); |
|
|
|
|
|
|
|
auto graph = Serialization::LoadModel(FLAGS_mindir_path, ModelType::kMindIR); |
|
|
|
auto model_context = std::make_shared<Context>(); |
|
|
|
auto context = std::make_shared<Context>(); |
|
|
|
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>(); |
|
|
|
ascend310->SetDeviceID(FLAGS_device_id); |
|
|
|
context->MutableDeviceInfo().push_back(ascend310); |
|
|
|
mindspore::Graph graph; |
|
|
|
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph); |
|
|
|
|
|
|
|
if (!FLAGS_precision_mode.empty()) { |
|
|
|
ModelContext::SetPrecisionMode(model_context, FLAGS_precision_mode); |
|
|
|
ascend310->SetPrecisionMode(FLAGS_precision_mode); |
|
|
|
} |
|
|
|
if (!FLAGS_op_select_impl_mode.empty()) { |
|
|
|
ModelContext::SetOpSelectImplMode(model_context, FLAGS_op_select_impl_mode); |
|
|
|
ascend310->SetOpSelectImplMode(FLAGS_op_select_impl_mode); |
|
|
|
} |
|
|
|
if (!FLAGS_aipp_path.empty()) { |
|
|
|
ModelContext::SetInsertOpConfigPath(model_context, FLAGS_aipp_path); |
|
|
|
ascend310->SetInsertOpConfigPath(FLAGS_aipp_path); |
|
|
|
} |
|
|
|
|
|
|
|
Model model(GraphCell(graph), model_context); |
|
|
|
Status ret = model.Build(); |
|
|
|
Model model; |
|
|
|
Status ret = model.Build(GraphCell(graph), context); |
|
|
|
if (ret != kSuccess) { |
|
|
|
std::cout << "EEEEEEEERROR Build failed." << std::endl; |
|
|
|
return 1; |
|
|
|
|