diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc index 3b2a89c909..07a76fb6fe 100644 --- a/mindspore/ccsrc/pipeline/init.cc +++ b/mindspore/ccsrc/pipeline/init.cc @@ -95,7 +95,7 @@ PYBIND11_MODULE(_c_expression, m) { (void)m.def("verify_inputs_signature", &mindspore::pipeline::VerifyInputSignature, "Verify input signature."); (void)m.def("init_exec_dataset", &mindspore::pipeline::InitExecDataset, py::arg("queue_name"), py::arg("size"), py::arg("batch_size"), py::arg("types"), py::arg("shapes"), py::arg("input_indexs"), - py::arg("phase") = py::str("dataset"), "Init and exec dataset."); + py::arg("phase") = py::str("dataset"), py::arg("need_run") = py::bool_(true), "Init and exec dataset."); (void)m.def("_set_dataset_mode_config", &mindspore::ConfigManager::SetDatasetModeConfig, "API for set dataset mode."); (void)m.def("init_backend", &mindspore::pipeline::InitBackend, "Init Backend."); diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc index 6799a6bd77..3606fb8cd6 100644 --- a/mindspore/ccsrc/pipeline/pipeline.cc +++ b/mindspore/ccsrc/pipeline/pipeline.cc @@ -694,7 +694,7 @@ void ExecutorPy::RunInitGraph(const py::dict &init_params, const std::string &ph bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, const std::vector &types, const std::vector> &shapes, - const std::vector &input_indexes, const std::string &phase) { + const std::vector &input_indexes, const std::string &phase, bool need_run) { std::string name = MsContext::GetInstance()->backend_policy(); #ifndef NO_DLIB auto ms_context = MsContext::GetInstance(); @@ -704,7 +704,7 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba } #endif if (name == kMsConvert || name == kMsVm) { - return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes); + return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes, need_run); } #if ENABLE_GE return InitExecDatasetGe(queue_name, iter_num, batch_size, types, shapes, input_indexes, phase); @@ -719,7 +719,7 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, const std::vector &types, const std::vector> &shapes, - const std::vector &input_indexes) { + const std::vector &input_indexes, bool need_run) { MS_LOG(INFO) << "Start InitDataSet Entry"; std::vector int_input_indexes; (void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes), @@ -772,7 +772,9 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc // launch init dataset runner without inputs and outputs VectorRef args; auto fn = runner.run; - (void)(*fn)(args); + if (need_run) { + (void)(*fn)(args); + } MS_LOG(DEBUG) << "InitDataSetVm End."; return true; } diff --git a/mindspore/ccsrc/pipeline/pipeline.h b/mindspore/ccsrc/pipeline/pipeline.h index 6a99d4dbcd..f2354c7474 100644 --- a/mindspore/ccsrc/pipeline/pipeline.h +++ b/mindspore/ccsrc/pipeline/pipeline.h @@ -127,12 +127,12 @@ void ExportGraph(const std::string &file_name, const std::string &, const std::s // init and exec dataset sub graph bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, const std::vector &types, const std::vector> &shapes, - const std::vector &input_indexes, const std::string &phase); + const std::vector &input_indexes, const std::string &phase, bool need_run); // Build and run dataset subgraph for ms backend bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, const std::vector &types, const std::vector> &shapes, - const std::vector &input_indexes); + const std::vector &input_indexes, bool need_run); } // namespace pipeline } // namespace mindspore