Browse Source

!1260 for second order subgraph switch

Merge pull request !1260 from zongha/master
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
bb73bfdf3a
3 changed files with 9 additions and 7 deletions
  1. +1
    -1
      mindspore/ccsrc/pipeline/init.cc
  2. +6
    -4
      mindspore/ccsrc/pipeline/pipeline.cc
  3. +2
    -2
      mindspore/ccsrc/pipeline/pipeline.h

+ 1
- 1
mindspore/ccsrc/pipeline/init.cc View File

@@ -95,7 +95,7 @@ PYBIND11_MODULE(_c_expression, m) {
(void)m.def("verify_inputs_signature", &mindspore::pipeline::VerifyInputSignature, "Verify input signature.");
(void)m.def("init_exec_dataset", &mindspore::pipeline::InitExecDataset, py::arg("queue_name"), py::arg("size"),
py::arg("batch_size"), py::arg("types"), py::arg("shapes"), py::arg("input_indexs"),
py::arg("phase") = py::str("dataset"), "Init and exec dataset.");
py::arg("phase") = py::str("dataset"), py::arg("need_run") = py::bool_(true), "Init and exec dataset.");
(void)m.def("_set_dataset_mode_config", &mindspore::ConfigManager::SetDatasetModeConfig, "API for set dataset mode.");
(void)m.def("init_backend", &mindspore::pipeline::InitBackend, "Init Backend.");



+ 6
- 4
mindspore/ccsrc/pipeline/pipeline.cc View File

@@ -694,7 +694,7 @@ void ExecutorPy::RunInitGraph(const py::dict &init_params, const std::string &ph

bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
const std::vector<int64_t> &input_indexes, const std::string &phase) {
const std::vector<int64_t> &input_indexes, const std::string &phase, bool need_run) {
std::string name = MsContext::GetInstance()->backend_policy();
#ifndef NO_DLIB
auto ms_context = MsContext::GetInstance();
@@ -704,7 +704,7 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba
}
#endif
if (name == kMsConvert || name == kMsVm) {
return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes);
return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes, need_run);
}
#if ENABLE_GE
return InitExecDatasetGe(queue_name, iter_num, batch_size, types, shapes, input_indexes, phase);
@@ -719,7 +719,7 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba

bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size,
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
const std::vector<int64_t> &input_indexes) {
const std::vector<int64_t> &input_indexes, bool need_run) {
MS_LOG(INFO) << "Start InitDataSet Entry";
std::vector<int> int_input_indexes;
(void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes),
@@ -772,7 +772,9 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
// launch init dataset runner without inputs and outputs
VectorRef args;
auto fn = runner.run;
(void)(*fn)(args);
if (need_run) {
(void)(*fn)(args);
}
MS_LOG(DEBUG) << "InitDataSetVm End.";
return true;
}


+ 2
- 2
mindspore/ccsrc/pipeline/pipeline.h View File

@@ -127,12 +127,12 @@ void ExportGraph(const std::string &file_name, const std::string &, const std::s
// init and exec dataset sub graph
bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
const std::vector<int64_t> &input_indexes, const std::string &phase);
const std::vector<int64_t> &input_indexes, const std::string &phase, bool need_run);

// Build and run dataset subgraph for ms backend
bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size,
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
const std::vector<int64_t> &input_indexes);
const std::vector<int64_t> &input_indexes, bool need_run);

} // namespace pipeline
} // namespace mindspore


Loading…
Cancel
Save