From ae436cb454f6a8041d5c506aeb8916baacd123ea Mon Sep 17 00:00:00 2001 From: ms_yan Date: Wed, 28 Apr 2021 21:48:14 +0800 Subject: [PATCH] descrease num_worker when mechine has less worker then default --- .../api/python/bindings/dataset/core/bindings.cc | 3 ++- .../ccsrc/minddata/dataset/core/config_manager.cc | 13 +++++++++++-- .../ccsrc/minddata/dataset/core/config_manager.h | 5 +++-- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc index f9a4462269..4e88a8b158 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc @@ -50,7 +50,8 @@ PYBIND_REGISTER(ConfigManager, 0, ([](const py::module *m) { .def("set_monitor_sampling_interval", &ConfigManager::set_monitor_sampling_interval) .def("stop_dataset_profiler", &ConfigManager::stop_dataset_profiler) .def("get_profiler_file_status", &ConfigManager::get_profiler_file_status) - .def("set_num_parallel_workers", &ConfigManager::set_num_parallel_workers) + .def("set_num_parallel_workers", + [](ConfigManager &c, int32_t num) { THROW_IF_ERROR(c.set_num_parallel_workers(num)); }) .def("set_op_connector_size", &ConfigManager::set_op_connector_size) .def("set_seed", &ConfigManager::set_seed) .def("set_worker_connector_size", &ConfigManager::set_worker_connector_size) diff --git a/mindspore/ccsrc/minddata/dataset/core/config_manager.cc b/mindspore/ccsrc/minddata/dataset/core/config_manager.cc index f4f8469daa..4cbdd5b6f1 100644 --- a/mindspore/ccsrc/minddata/dataset/core/config_manager.cc +++ b/mindspore/ccsrc/minddata/dataset/core/config_manager.cc @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -49,6 +50,8 @@ ConfigManager::ConfigManager() num_cpu_threads_(std::thread::hardware_concurrency()), auto_num_workers_num_shards_(1), auto_worker_config_(0) { + num_cpu_threads_ = num_cpu_threads_ > 0 ? num_cpu_threads_ : std::numeric_limits::max(); + num_parallel_workers_ = num_parallel_workers_ < num_cpu_threads_ ? num_parallel_workers_ : num_cpu_threads_; auto env_cache_host = std::getenv("MS_CACHE_HOST"); auto env_cache_port = std::getenv("MS_CACHE_PORT"); if (env_cache_host != nullptr) { @@ -76,7 +79,7 @@ void ConfigManager::Print(std::ostream &out) const { // Private helper function that takes a nlohmann json format and populates the settings Status ConfigManager::FromJson(const nlohmann::json &j) { - set_num_parallel_workers(j.value("numParallelWorkers", num_parallel_workers_)); + RETURN_IF_NOT_OK(set_num_parallel_workers(j.value("numParallelWorkers", num_parallel_workers_))); set_worker_connector_size(j.value("workerConnectorSize", worker_connector_size_)); set_op_connector_size(j.value("opConnectorSize", op_connector_size_)); set_seed(j.value("seed", seed_)); @@ -113,8 +116,14 @@ Status ConfigManager::LoadFile(const std::string &settingsFile) { } // Setter function -void ConfigManager::set_num_parallel_workers(int32_t num_parallel_workers) { +Status ConfigManager::set_num_parallel_workers(int32_t num_parallel_workers) { + if (num_parallel_workers > num_cpu_threads_ || num_parallel_workers < 1) { + std::string err_msg = "Invalid Parameter, num_parallel_workers exceeds the boundary between 1 and " + + std::to_string(num_cpu_threads_) + ", as got " + std::to_string(num_parallel_workers) + "."; + RETURN_STATUS_UNEXPECTED(err_msg); + } num_parallel_workers_ = num_parallel_workers; + return Status::OK(); } // Setter function diff --git a/mindspore/ccsrc/minddata/dataset/core/config_manager.h b/mindspore/ccsrc/minddata/dataset/core/config_manager.h index e7c476d3ab..71912ef8e9 100644 --- a/mindspore/ccsrc/minddata/dataset/core/config_manager.h +++ b/mindspore/ccsrc/minddata/dataset/core/config_manager.h @@ -110,7 +110,8 @@ class ConfigManager { // setter function // @param num_parallel_workers - The setting to apply to the config - void set_num_parallel_workers(int32_t num_parallel_workers); + // @return Status error code + Status set_num_parallel_workers(int32_t num_parallel_workers); // setter function // @param connector_size - The setting to apply to the config @@ -240,7 +241,7 @@ class ConfigManager { bool numa_enable_; int32_t prefetch_size_; bool auto_num_workers_; - const int32_t num_cpu_threads_; + int32_t num_cpu_threads_; int32_t auto_num_workers_num_shards_; uint8_t auto_worker_config_; // Private helper function that takes a nlohmann json format and populates the settings