| @@ -35,20 +35,23 @@ PYBIND_REGISTER(GlobalContext, 0, ([](const py::module *m) { | |||
| PYBIND_REGISTER(ConfigManager, 0, ([](const py::module *m) { | |||
| (void)py::class_<ConfigManager, std::shared_ptr<ConfigManager>>(*m, "ConfigManager") | |||
| .def("__str__", &ConfigManager::ToString) | |||
| .def("set_rows_per_buffer", &ConfigManager::set_rows_per_buffer) | |||
| .def("set_num_parallel_workers", &ConfigManager::set_num_parallel_workers) | |||
| .def("set_worker_connector_size", &ConfigManager::set_worker_connector_size) | |||
| .def("set_op_connector_size", &ConfigManager::set_op_connector_size) | |||
| .def("set_seed", &ConfigManager::set_seed) | |||
| .def("set_monitor_sampling_interval", &ConfigManager::set_monitor_sampling_interval) | |||
| .def("get_rows_per_buffer", &ConfigManager::rows_per_buffer) | |||
| .def("get_auto_num_workers", &ConfigManager::auto_num_workers) | |||
| .def("get_callback_timeout", &ConfigManager::callback_timeout) | |||
| .def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval) | |||
| .def("get_num_parallel_workers", &ConfigManager::num_parallel_workers) | |||
| .def("get_worker_connector_size", &ConfigManager::worker_connector_size) | |||
| .def("get_op_connector_size", &ConfigManager::op_connector_size) | |||
| .def("get_rows_per_buffer", &ConfigManager::rows_per_buffer) | |||
| .def("get_seed", &ConfigManager::seed) | |||
| .def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval) | |||
| .def("get_callback_timeout", &ConfigManager::callback_timeout) | |||
| .def("get_worker_connector_size", &ConfigManager::worker_connector_size) | |||
| .def("set_auto_num_workers", &ConfigManager::set_auto_num_workers) | |||
| .def("set_auto_worker_config", &ConfigManager::set_auto_worker_config_) | |||
| .def("set_callback_timeout", &ConfigManager::set_callback_timeout) | |||
| .def("set_monitor_sampling_interval", &ConfigManager::set_monitor_sampling_interval) | |||
| .def("set_num_parallel_workers", &ConfigManager::set_num_parallel_workers) | |||
| .def("set_op_connector_size", &ConfigManager::set_op_connector_size) | |||
| .def("set_rows_per_buffer", &ConfigManager::set_rows_per_buffer) | |||
| .def("set_seed", &ConfigManager::set_seed) | |||
| .def("set_worker_connector_size", &ConfigManager::set_worker_connector_size) | |||
| .def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); }); | |||
| })); | |||
| @@ -123,7 +123,13 @@ DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_i | |||
| num_samples_(num_samples), | |||
| seed_(seed), | |||
| offset_(offset), | |||
| even_dist_(even_dist) {} | |||
| even_dist_(even_dist) { | |||
| // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion | |||
| // is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't | |||
| // 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once | |||
| // PreBuildSampler is phased out, this can be cleaned up. | |||
| GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_); | |||
| } | |||
| bool DistributedSamplerObj::ValidateParams() { | |||
| if (num_shards_ <= 0) { | |||
| @@ -18,6 +18,7 @@ | |||
| #include <fstream> | |||
| #include <iostream> | |||
| #include <string> | |||
| #include <thread> | |||
| #include <utility> | |||
| #ifndef ENABLE_ANDROID | |||
| @@ -40,7 +41,11 @@ ConfigManager::ConfigManager() | |||
| cache_host_(kCfgDefaultCacheHost), | |||
| cache_port_(kCfgDefaultCachePort), | |||
| num_connections_(kDftNumConnections), | |||
| prefetch_size_(kDftPrefetchSize) { | |||
| prefetch_size_(kDftPrefetchSize), | |||
| auto_num_workers_(kDftAutoNumWorkers), | |||
| num_cpu_threads_(std::thread::hardware_concurrency()), | |||
| auto_num_workers_num_shards_(1), | |||
| auto_worker_config_(0) { | |||
| auto env_cache_host = std::getenv("MS_CACHE_HOST"); | |||
| auto env_cache_port = std::getenv("MS_CACHE_PORT"); | |||
| if (env_cache_host != nullptr) { | |||
| @@ -68,7 +73,7 @@ void ConfigManager::Print(std::ostream &out) const { | |||
| << "\nSize of each Connector : " << op_connector_size_ << std::endl; | |||
| } | |||
| // Private helper function that taks a nlohmann json format and populates the settings | |||
| // Private helper function that takes a nlohmann json format and populates the settings | |||
| Status ConfigManager::FromJson(const nlohmann::json &j) { | |||
| set_rows_per_buffer(j.value("rowsPerBuffer", rows_per_buffer_)); | |||
| set_num_parallel_workers(j.value("numParallelWorkers", num_parallel_workers_)); | |||
| @@ -136,5 +141,6 @@ void ConfigManager::set_cache_port(int32_t cache_port) { cache_port_ = cache_por | |||
| void ConfigManager::set_num_connections(int32_t num_connections) { num_connections_ = num_connections; } | |||
| void ConfigManager::set_prefetch_size(int32_t prefetch_size) { prefetch_size_ = prefetch_size; } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -89,6 +89,8 @@ class ConfigManager { | |||
| // @return The internal worker-to-master connector queue size | |||
| int32_t worker_connector_size() const { return worker_connector_size_; } | |||
| // getter function | |||
| // @return The CPU thread count stored in num_cpu_threads_ when this ConfigManager was constructed | |||
| int32_t num_cpu_threads() const { | |||
|   return num_cpu_threads_; | |||
| } | |||
| // getter function | |||
| // @return The hostname of cache server | |||
| std::string cache_host() const { return cache_host_; } | |||
| @@ -105,6 +107,10 @@ class ConfigManager { | |||
| /// \return Prefetch size | |||
| int32_t prefetch_size() const { return prefetch_size_; } | |||
| /// Getter for the auto-num-workers flag | |||
| /// \return The current value of auto_num_workers_ | |||
| bool auto_num_workers() const { | |||
|   return auto_num_workers_; | |||
| } | |||
| // setter function | |||
| // @param rows_per_buffer - The setting to apply to the config | |||
| void set_rows_per_buffer(int32_t rows_per_buffer); | |||
| @@ -151,6 +157,20 @@ class ConfigManager { | |||
| // @return The interval of monitor sampling | |||
| int32_t monitor_sampling_interval() const { return monitor_sampling_interval_; } | |||
| // setter function | |||
| // @param auto_num_workers - whether to assign threads to each op automatically; stored in auto_num_workers_ | |||
| void set_auto_num_workers(bool auto_num_workers) { | |||
|   auto_num_workers_ = auto_num_workers; | |||
| } | |||
| // setter function | |||
| // Called when a distributed sampler (RT and Obj) is created; the stored value is later read by AutoWorkerPass. | |||
| // This works around the limitation of PreBuildSampler, which doesn't have a getter for sharding params. | |||
| // @param num_shards - the number of shards to record in auto_num_workers_num_shards_ | |||
| void set_num_shards_for_auto_num_workers(int32_t num_shards) { | |||
|   auto_num_workers_num_shards_ = num_shards; | |||
| } | |||
| // getter function, called by AutoNumWorker; user discretion above AutoNumWorker is advised | |||
| // @return The value currently held in auto_num_workers_num_shards_ | |||
| int32_t get_num_shards_for_auto_num_workers() const { | |||
|   return auto_num_workers_num_shards_; | |||
| } | |||
| // setter function | |||
| // @param timeout - The setting to apply to the config | |||
| void set_callback_timeout(uint32_t timeout); | |||
| @@ -159,6 +179,18 @@ class ConfigManager { | |||
| // @return The timeout DSWaitedCallback would wait for before raising an error | |||
| int32_t callback_timeout() const { return callback_timout_; } | |||
| // getter function | |||
| // E.g. 0 corresponds to a 1:1:1 ratio of num_workers among leaf, batch and map. | |||
| // Please refer to AutoWorkerPass for details on what each option is. | |||
| // @return The experimental config used by AutoNumWorker; each value refers to a different setup configuration | |||
| uint8_t get_auto_worker_config_() const { return auto_worker_config_; } | |||
| // setter function | |||
| // E.g. a value of 0 corresponds to a 1:1:1 ratio of num_workers among leaf, batch and map. | |||
| // Please refer to AutoWorkerPass for details on what each option is. | |||
| // @param cfg - The experimental config used by AutoNumWorker; each value selects a different setup configuration | |||
| void set_auto_worker_config_(uint8_t cfg) { | |||
|   auto_worker_config_ = cfg; | |||
| } | |||
| private: | |||
| int32_t rows_per_buffer_; | |||
| int32_t num_parallel_workers_; | |||
| @@ -171,7 +203,10 @@ class ConfigManager { | |||
| int32_t cache_port_; | |||
| int32_t num_connections_; | |||
| int32_t prefetch_size_; | |||
| bool auto_num_workers_; | |||
| const int32_t num_cpu_threads_; | |||
| int32_t auto_num_workers_num_shards_; | |||
| uint8_t auto_worker_config_; | |||
| // Private helper function that takes a nlohmann json format and populates the settings | |||
| // @param j - The json nlohmann json info | |||
| Status FromJson(const nlohmann::json &j); | |||
| @@ -91,6 +91,7 @@ constexpr int32_t kCfgDefaultCachePort = 50052; | |||
| constexpr char kCfgDefaultCacheHost[] = "127.0.0.1"; | |||
| constexpr int32_t kDftPrefetchSize = 20; | |||
| constexpr int32_t kDftNumConnections = 12; | |||
| // Default for ConfigManager::auto_num_workers_ (a bool flag), so the constant is declared as bool as well | |||
| constexpr bool kDftAutoNumWorkers = false; | |||
| // Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h) | |||
| constexpr uint8_t kCVInvalidType = 255; | |||
| @@ -263,7 +263,8 @@ Status BatchOp::LaunchThreadsAndInitOp() { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set."); | |||
| } | |||
| RETURN_IF_NOT_OK(worker_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&BatchOp::WorkerEntry, this, std::placeholders::_1))); | |||
| RETURN_IF_NOT_OK( | |||
| tree_->LaunchWorkers(num_workers_, std::bind(&BatchOp::WorkerEntry, this, std::placeholders::_1), Name())); | |||
| return Status::OK(); | |||
| } | |||
| @@ -74,7 +74,7 @@ Status CacheBase::FetchSamplesToWorkers() { | |||
| // Kick off several threads which will prefetch prefetch_size_ rows in advance. The rows_per_buffers_ | |||
| // is too small (1 by default) and won't help performance. | |||
| RETURN_IF_NOT_OK( | |||
| tree_->LaunchWorkers(num_prefetchers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1))); | |||
| tree_->LaunchWorkers(num_prefetchers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1), Name())); | |||
| auto send_to_que = [](QueueList<std::unique_ptr<IOBlock>> &qList, int32_t worker_id, | |||
| std::vector<row_id_type> &keys) -> Status { | |||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||
| @@ -59,10 +59,11 @@ Status CacheMergeOp::operator()() { | |||
| static const int32_t queue_sz = 512; | |||
| io_que_ = std::make_unique<Queue<row_id_type>>(queue_sz); | |||
| RETURN_IF_NOT_OK(io_que_->Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK( | |||
| tree_->LaunchWorkers(num_workers_, std::bind(&CacheMergeOp::WorkerEntry, this, std::placeholders::_1))); | |||
| RETURN_IF_NOT_OK( | |||
| tree_->LaunchWorkers(num_workers_, std::bind(&CacheMergeOp::CacheMissWorkerEntry, this, std::placeholders::_1))); | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers( | |||
| num_workers_, std::bind(&CacheMergeOp::WorkerEntry, this, std::placeholders::_1), Name() + "::WorkerEntry")); | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, | |||
| std::bind(&CacheMergeOp::CacheMissWorkerEntry, this, std::placeholders::_1), | |||
| Name() + "::CacheMissWorkerEntry")); | |||
| // One dedicated thread to move TensorRow from the pool to the cache server | |||
| for (auto i = 0; i < num_cleaners_; ++i) { | |||
| RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Cleaner", std::bind(&CacheMergeOp::Cleaner, this))); | |||
| @@ -83,7 +83,8 @@ Status CacheOp::operator()() { | |||
| } | |||
| RETURN_IF_NOT_OK(RegisterResources()); | |||
| // Kick off the workers | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheOp::WorkerEntry, this, std::placeholders::_1))); | |||
| RETURN_IF_NOT_OK( | |||
| tree_->LaunchWorkers(num_workers_, std::bind(&CacheOp::WorkerEntry, this, std::placeholders::_1), Name())); | |||
| // required task group sync after launching workers | |||
| TaskManager::FindMe()->Post(); | |||
| // Wait for the workers to finish caching the rows. | |||
| @@ -70,7 +70,8 @@ Status FilterOp::operator()() { | |||
| } | |||
| filter_queues_.Init(num_workers_, oc_queue_size_); | |||
| RETURN_IF_NOT_OK(filter_queues_.Register(tree_->AllTasks())); | |||
| Status rc = tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1)); | |||
| Status rc = | |||
| tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1), Name()); | |||
| // Synchronize with TaskManager. | |||
| TaskManager::FindMe()->Post(); | |||
| RETURN_IF_NOT_OK(rc); | |||
| @@ -385,10 +385,11 @@ Status ImageFolderOp::LaunchThreadsAndInitOp() { | |||
| // 2) Workers that pull foldername from folder_name_queue_, walk it and return the sorted images to image_name_queue | |||
| // 3) Launch main workers that load DataBuffers by reading all images | |||
| RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("walk dir", std::bind(&ImageFolderOp::StartAsyncWalk, this))); | |||
| RETURN_IF_NOT_OK( | |||
| tree_->LaunchWorkers(num_workers_, std::bind(&ImageFolderOp::PrescanWorkerEntry, this, std::placeholders::_1))); | |||
| RETURN_IF_NOT_OK( | |||
| tree_->LaunchWorkers(num_workers_, std::bind(&ImageFolderOp::WorkerEntry, this, std::placeholders::_1))); | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, | |||
| std::bind(&ImageFolderOp::PrescanWorkerEntry, this, std::placeholders::_1), | |||
| Name() + "::PrescanWorkerEntry")); | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers( | |||
| num_workers_, std::bind(&ImageFolderOp::WorkerEntry, this, std::placeholders::_1), Name() + "::WorkerEntry")); | |||
| TaskManager::FindMe()->Post(); | |||
| // The order of the following 2 functions must not be changed! | |||
| RETURN_IF_NOT_OK(this->PrescanMasterEntry(folder_path_)); // Master thread of pre-scan workers, blocking | |||
| @@ -34,7 +34,13 @@ DistributedSamplerRT::DistributedSamplerRT(int64_t num_samples, int64_t num_dev, | |||
| shuffle_(shuffle), | |||
| even_dist_(even_dist), | |||
| offset_(offset), | |||
| non_empty_(true) {} | |||
| non_empty_(true) { | |||
| // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion | |||
| // is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't | |||
| // 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once | |||
| // PreBuildSampler is phased out, this can be cleaned up. | |||
| GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_devices_); | |||
| } | |||
| Status DistributedSamplerRT::InitSampler() { | |||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | |||
| @@ -204,7 +204,16 @@ ExecutionTree::Iterator::Iterator(const std::shared_ptr<DatasetOp> &root) : ind_ | |||
| // Given the number of workers, launches the worker entry function for each. Essentially a | |||
| // wrapper for the TaskGroup handling that is stored inside the execution tree. | |||
| Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name) { | |||
| int32_t num_cpu_threads = GlobalContext::Instance()->config_manager()->num_cpu_threads(); | |||
| // This checks that num_workers is positive and not unreasonably large, which could happen, | |||
| // for example, with an un-initialized variable. uint16 max is 65535 which is large enough to cover everything | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_workers > 0 && num_workers < std::numeric_limits<uint16_t>::max(), | |||
| name + "'s num_worker=" + std::to_string(num_workers) + ", is negative or too large."); | |||
| // Launch the workers | |||
| if (num_workers > num_cpu_threads) { | |||
| MS_LOG(WARNING) << name + " is launched with " << std::to_string(num_workers) << " worker threads which exceeds " | |||
| << std::to_string(num_cpu_threads) << ", the maximum number of threads on this CPU."; | |||
| } | |||
| for (int32_t i = 0; i < num_workers; ++i) { | |||
| RETURN_IF_NOT_OK(tg_->CreateAsyncTask(name, std::bind(func, i))); | |||
| } | |||
| @@ -24,6 +24,7 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -216,19 +216,6 @@ Status DatasetNode::AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops | |||
| DatasetNode::DatasetNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode() { cache_ = dataset_cache; } | |||
| std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| #ifndef ENABLE_ANDROID | |||
| int32_t cpu_count = sysconf(_SC_NPROCESSORS_CONF); | |||
| if (cpu_count < 0 || cpu_count > INT32_MAX) { | |||
| MS_LOG(ERROR) << "Error determining current CPU: " << cpu_count; | |||
| return nullptr; | |||
| } | |||
| if (num_workers < 1 || num_workers > cpu_count) { | |||
| MS_LOG(ERROR) << "num_workers exceeds the boundary between 1 and " << cpu_count; | |||
| return nullptr; | |||
| } | |||
| #endif | |||
| #endif | |||
| num_workers_ = num_workers; | |||
| return shared_from_this(); | |||
| } | |||
| @@ -431,5 +418,14 @@ Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz | |||
| RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override"); | |||
| } | |||
| } | |||
| // Visitor entry point: hands this node to the pass typed as a MappableSourceNode | |||
| // @param p - the IR node pass to dispatch to | |||
| // @param modified - forwarded unchanged to the pass | |||
| // @return Status returned by the pass's Visit | |||
| Status MappableSourceNode::Accept(IRNodePass *p, bool *modified) { | |||
|   auto self = shared_from_base<MappableSourceNode>(); | |||
|   return p->Visit(self, modified); | |||
| } | |||
| // Visitor entry point: hands this node to the pass typed as a NonMappableSourceNode, so that the | |||
| // NonMappableSourceNode overload of IRNodePass::Visit is the one selected. | |||
| // @param p - the IR node pass to dispatch to | |||
| // @param modified - forwarded unchanged to the pass | |||
| // @return Status returned by the pass's Visit | |||
| Status NonMappableSourceNode::Accept(IRNodePass *p, bool *modified) { | |||
|   return p->Visit(shared_from_base<NonMappableSourceNode>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -41,7 +41,6 @@ constexpr char kBucketBatchByLengthNode[] = "BucketBatchByLength"; | |||
| constexpr char kBuildSentencePieceVocabNode[] = "BuildSentencePieceVocab"; | |||
| constexpr char kBuildVocabNode[] = "BuildVocab"; | |||
| constexpr char kConcatNode[] = "Concat"; | |||
| constexpr char kDatasetNode[] = "Dataset"; | |||
| constexpr char kEpochCtrlNode[] = "EpochCtrl"; | |||
| constexpr char kFilterNode[] = "Filter"; | |||
| constexpr char kMapNode[] = "Map"; | |||
| @@ -290,6 +289,8 @@ class MappableSourceNode : public DatasetNode { | |||
| descendant_of_cache_ = false; | |||
| } | |||
| Status Accept(IRNodePass *p, bool *modified) override; | |||
| /// \brief Destructor | |||
| ~MappableSourceNode() = default; | |||
| @@ -312,6 +313,8 @@ class NonMappableSourceNode : public DatasetNode { | |||
| descendant_of_cache_ = false; | |||
| } | |||
| Status Accept(IRNodePass *p, bool *modified) override; | |||
| /// \brief Destructor | |||
| ~NonMappableSourceNode() = default; | |||
| @@ -41,7 +41,13 @@ CSVNode::CSVNode(const std::vector<std::string> &csv_files, char field_delim, | |||
| num_samples_(num_samples), | |||
| shuffle_(shuffle), | |||
| num_shards_(num_shards), | |||
| shard_id_(shard_id) {} | |||
| shard_id_(shard_id) { | |||
| // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion | |||
| // is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't | |||
| // 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once | |||
| // PreBuildSampler is phased out, this can be cleaned up. | |||
| GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_); | |||
| } | |||
| std::shared_ptr<DatasetNode> CSVNode::Copy() { | |||
| auto node = std::make_shared<CSVNode>(dataset_files_, field_delim_, column_defaults_, column_names_, num_samples_, | |||
| @@ -36,7 +36,13 @@ TextFileNode::TextFileNode(std::vector<std::string> dataset_files, int32_t num_s | |||
| num_samples_(num_samples), | |||
| shuffle_(shuffle), | |||
| num_shards_(num_shards), | |||
| shard_id_(shard_id) {} | |||
| shard_id_(shard_id) { | |||
| // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion | |||
| // is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't | |||
| // 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once | |||
| // PreBuildSampler is phased out, this can be cleaned up. | |||
| GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_); | |||
| } | |||
| std::shared_ptr<DatasetNode> TextFileNode::Copy() { | |||
| auto node = std::make_shared<TextFileNode>(dataset_files_, num_samples_, shuffle_, num_shards_, shard_id_, cache_); | |||
| @@ -44,7 +44,13 @@ class TFRecordNode : public NonMappableSourceNode { | |||
| shuffle_(shuffle), | |||
| num_shards_(num_shards), | |||
| shard_id_(shard_id), | |||
| shard_equal_rows_(shard_equal_rows) {} | |||
| shard_equal_rows_(shard_equal_rows) { | |||
| // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User | |||
| // discretion is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the | |||
| // num_shards_ isn't 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return | |||
| // num_shards. Once PreBuildSampler is phased out, this can be cleaned up. | |||
| GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_); | |||
| } | |||
| /// \brief Constructor | |||
| /// \note Parameter 'schema' is shared pointer to Schema object | |||
| @@ -3,6 +3,7 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE | |||
| add_library(engine-opt OBJECT | |||
| optional/tensor_op_fusion_pass.cc | |||
| pass.cc | |||
| post/auto_worker_pass.cc | |||
| post/repeat_pass.cc | |||
| pre/cache_error_pass.cc | |||
| pre/cache_transform_pass.cc | |||
| @@ -258,6 +258,15 @@ Status IRNodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool | |||
| } | |||
| #endif | |||
| // leaf-IR Node | |||
| // Default handling for mappable leaf nodes: fall back to the generic DatasetNode visit | |||
| Status IRNodePass::Visit(std::shared_ptr<MappableSourceNode> node, bool *modified) { | |||
|   std::shared_ptr<DatasetNode> base_node = std::static_pointer_cast<DatasetNode>(node); | |||
|   return Visit(base_node, modified); | |||
| } | |||
| // Default handling for non-mappable leaf nodes: fall back to the generic DatasetNode visit | |||
| Status IRNodePass::Visit(std::shared_ptr<NonMappableSourceNode> node, bool *modified) { | |||
|   std::shared_ptr<DatasetNode> base_node = std::static_pointer_cast<DatasetNode>(node); | |||
|   return Visit(base_node, modified); | |||
| } | |||
| ////////////////////////////////// | |||
| // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. | |||
| // Driver method for TreePass | |||
| @@ -229,6 +229,10 @@ class IRNodePass : public IRPass { | |||
| virtual Status VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified); | |||
| #endif | |||
| // leaf-IR Node | |||
| virtual Status Visit(std::shared_ptr<MappableSourceNode> node, bool *modified); | |||
| virtual Status Visit(std::shared_ptr<NonMappableSourceNode> node, bool *modified); | |||
| private: | |||
| // Helper function to perform DFS visit | |||
| Status DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified); | |||
| @@ -0,0 +1,122 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <cmath> | |||
| #include <algorithm> | |||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||
| #include "minddata/dataset/engine/opt/post/auto_worker_pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // this will become the RootNode:DatasetNode when it is turned on | |||
| // Collects the parallel ops of the IR tree together with their weights (via OpWeightPass) and overrides each | |||
| // op's num_workers in proportion to its weight, the CPU thread count and the recorded number of shards. | |||
| // @param root_ir - root of the IR tree to process | |||
| // @param modified - forwarded unchanged to the node pass | |||
| // @return Status - the error from the node pass if it fails, otherwise OK | |||
| Status AutoWorkerPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) { | |||
|   uint8_t config = GlobalContext::config_manager()->get_auto_worker_config_(); | |||
|   // an out-of-range config index falls back to the first weight profile | |||
|   OpWeightPass pass(kOpWeightConfigs[config < kOpWeightConfigs.size() ? config : 0]); | |||
|   std::string weight_str; | |||
|   for (const auto &p : pass.weight_profile_) { | |||
|     weight_str += ("(" + p.first + "=" + std::to_string(p.second) + ")"); | |||
|   } | |||
|   // thread_cnt_ originates from std::thread::hardware_concurrency(), which may report 0 when the value is not | |||
|   // computable. Use at least 1 so that num_shards below can never be clamped to 0 and used as a divisor. | |||
|   const int32_t thread_cnt = std::max(1, thread_cnt_); | |||
|   int32_t num_shards = GlobalContext::config_manager()->get_num_shards_for_auto_num_workers(); | |||
|   num_shards = std::min(std::max(1, num_shards), thread_cnt); | |||
|  | |||
|   MS_LOG(INFO) << "AutoWorkerPass is enabled; this could override existing num_workers set in each parallel op." | |||
|                << "total number of threads on this CPU: " << thread_cnt_ << ", " | |||
|                << "min num_workers to override:" << min_num_workers_ << ", " | |||
|                << "max num_workers to override:" << max_num_workers_ << ", " | |||
|                << "adjusted num_shards (between 1 and total thread cnt): " << num_shards | |||
|                << ", weight profile:" << weight_str << "."; | |||
|   // get the maximum weight of all the ops, this value is used to ensure the ratio of num_workers between ops | |||
|   float max_weight = 0; | |||
|   for (const auto &p : pass.weight_profile_) { | |||
|     max_weight = std::max(max_weight, p.second); | |||
|   } | |||
|   RETURN_IF_NOT_OK(pass.Run(root_ir, modified)); | |||
|   if (pass.parallel_ops_.size() > 3) { | |||
|     MS_LOG(WARNING) << "AutoWorkerPass at current stage is only optimized for simple network that has LeafNode, " | |||
|                     << "BatchNode and MapNode. User discretion is advised for usage on other complex networks."; | |||
|   } | |||
|  | |||
|   for (auto &p : pass.parallel_ops_) { | |||
|     // get the num worker via the weight ratio; the division is done in floating point on purpose so that | |||
|     // std::ceil rounds the real ratio up instead of receiving an already truncated integer | |||
|     const float worker_share = | |||
|       (static_cast<float>(thread_cnt) * p.second) / (static_cast<float>(pass.weight_sum_) * num_shards); | |||
|     int32_t num_workers = std::ceil(worker_share); | |||
|     // this is to ensure when thread_cnt_ is very large let's say 192, the num_worker ratio is still kept | |||
|     // e.g. the optional 2:1 ratio between minddataset and batch | |||
|     int32_t cur_node_max = std::ceil(p.second * max_num_workers_ / max_weight); | |||
|     // cap at cur_node_max first, then raise to at least min_num_workers_ | |||
|     int32_t cur_node_num_worker = std::max(std::min(num_workers, cur_node_max), min_num_workers_); | |||
|     // log the change via warning msg so user can see what the num_worker is being set for which op | |||
|     MS_LOG(WARNING) << "num_workers in " << p.first->Name() << " is auto-adjusted from " | |||
|                     << std::to_string(p.first->num_workers()) + " to " + std::to_string(cur_node_num_worker); | |||
|     p.first->SetNumWorkers(cur_node_num_worker); | |||
|   } | |||
|   return Status::OK(); | |||
| } | |||
| // Looks up the weight registered under the node's name, records the node as a parallel op and adds the | |||
| // weight to the running sum. Fails if the profile has no entry for this node's name. | |||
| Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<MapNode> node, bool *modified) { | |||
|   const auto found = weight_profile_.find(node->Name()); | |||
|   CHECK_FAIL_RETURN_UNEXPECTED(found != weight_profile_.end(), node->Name() + "'s weight doesn't exist."); | |||
|   const int32_t node_weight = found->second; | |||
|   weight_sum_ += node_weight; | |||
|   auto base_node = std::static_pointer_cast<DatasetNode>(node); | |||
|   parallel_ops_.emplace_back(std::make_pair(base_node, node_weight)); | |||
|   return Status::OK(); | |||
| } | |||
| // Looks up the weight registered under the node's name, records the node as a parallel op and adds the | |||
| // weight to the running sum. Fails if the profile has no entry for this node's name. | |||
| Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<BatchNode> node, bool *modified) { | |||
|   const auto found = weight_profile_.find(node->Name()); | |||
|   CHECK_FAIL_RETURN_UNEXPECTED(found != weight_profile_.end(), node->Name() + "'s weight doesn't exist."); | |||
|   const int32_t node_weight = found->second; | |||
|   weight_sum_ += node_weight; | |||
|   auto base_node = std::static_pointer_cast<DatasetNode>(node); | |||
|   parallel_ops_.emplace_back(std::make_pair(base_node, node_weight)); | |||
|   return Status::OK(); | |||
| } | |||
| // Handles mappable leaf nodes: all of them share the "MappableSource" profile entry. The node is recorded | |||
| // as a parallel op and its weight added to the running sum. | |||
| Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<MappableSourceNode> node, bool *modified) { | |||
|   // generator is pipeline op, skip this | |||
|   RETURN_OK_IF_TRUE(node->Name() == kGeneratorNode); | |||
|   const auto found = weight_profile_.find("MappableSource"); | |||
|   CHECK_FAIL_RETURN_UNEXPECTED(found != weight_profile_.end(), | |||
|                                "LeafSourceNode::" + node->Name() + "'s weight doesn't exist."); | |||
|   const int32_t node_weight = found->second; | |||
|   weight_sum_ += node_weight; | |||
|   auto base_node = std::static_pointer_cast<DatasetNode>(node); | |||
|   parallel_ops_.emplace_back(std::make_pair(base_node, node_weight)); | |||
|   return Status::OK(); | |||
| } | |||
| // Handles non-mappable leaf nodes: all of them share the "NonMappableSource" profile entry (the key used by | |||
| // every profile in kOpWeightConfigs). The node is recorded as a parallel op and its weight added to the sum. | |||
| // @return error Status if the profile has no "NonMappableSource" entry, otherwise OK | |||
| Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<NonMappableSourceNode> node, bool *modified) { | |||
|   // the key must match the spelling used in the weight profiles, i.e. without a trailing "Node" | |||
|   auto itr = weight_profile_.find("NonMappableSource"); | |||
|   CHECK_FAIL_RETURN_UNEXPECTED(itr != weight_profile_.end(), | |||
|                                "NonLeafSource::" + node->Name() + "'s weight doesn't exist."); | |||
|   int32_t weight = itr->second; | |||
|   weight_sum_ += weight; | |||
|   parallel_ops_.emplace_back(std::make_pair(std::static_pointer_cast<DatasetNode>(node), weight)); | |||
|   return Status::OK(); | |||
| } | |||
| // Generic visit: adds the node's profile weight (0 when the profile has no entry) to the running sum. | |||
| // The node is not recorded as a parallel op. | |||
| Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<DatasetNode> node, bool *modified) { | |||
|   const float node_weight = GetNodeWeightFromProfile(node); | |||
|   weight_sum_ += node_weight; | |||
|   return Status::OK(); | |||
| } | |||
| // Looks up the weight stored in the profile under the node's name. | |||
| // @return the stored weight, or 0 if the name doesn't exist in the weight profile | |||
| float AutoWorkerPass::OpWeightPass::GetNodeWeightFromProfile(std::shared_ptr<DatasetNode> node) { | |||
|   const auto found = weight_profile_.find(node->Name()); | |||
|   if (found == weight_profile_.end()) { | |||
|     return 0; | |||
|   } | |||
|   return found->second; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,82 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_OPT_POST_AUTO_WORKER_PASS_H_ | |||
| #define DATASET_ENGINE_OPT_POST_AUTO_WORKER_PASS_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <thread> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class AutoWorkerPass : public IRTreePass { | |||
| public: | |||
| // this map will contain weight for the basic pipeline ops. Pipeline op takes up 1 thread but doesn't have workers | |||
| const std::vector<std::map<std::string, float>> kOpWeightConfigs = { | |||
| {{"MappableSource", 8}, {"NonMappableSource", 8}, {kBatchNode, 8}, {kMapNode, 8}}, // config1 leaf:batch:map=1:1:1 | |||
| {{"MappableSource", 8}, {"NonMappableSource", 8}, {kBatchNode, 4}, {kMapNode, 4}}, // config2 leaf:batch:map=2:1:1 | |||
| {{"MappableSource", 4}, {"NonMappableSource", 4}, {kBatchNode, 8}, {kMapNode, 4}}, // config3 leaf:batch:map=1:2:1 | |||
| {{"MappableSource", 4}, {"NonMappableSource", 4}, {kBatchNode, 4}, {kMapNode, 8}}, // config4 leaf:batch:map=1:1:2 | |||
| {{"MappableSource", 8}, {"NonMappableSource", 8}, {kBatchNode, 8}, {kMapNode, 4}}, // config5 leaf:batch:map=2:2:1 | |||
| {{"MappableSource", 8}, {"NonMappableSource", 8}, {kBatchNode, 4}, {kMapNode, 8}}, // config6 leaf:batch:map=2:1:2 | |||
| {{"MappableSource", 4}, {"NonMappableSource", 4}, {kBatchNode, 8}, {kMapNode, 8}}, // config7 leaf:batch:map=1:2:2 | |||
| }; | |||
| AutoWorkerPass() | |||
| : min_num_workers_(1), | |||
| max_num_workers_(8), | |||
| thread_cnt_(GlobalContext::Instance()->config_manager()->num_cpu_threads()) {} | |||
| Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *) override; | |||
| private: | |||
| class OpWeightPass : public IRNodePass { | |||
| public: | |||
| explicit OpWeightPass(const std::map<std::string, float> &weight_profile) | |||
| : IRNodePass(), weight_sum_(0), weight_profile_(weight_profile) {} | |||
| // this is the base class function which contains the logic to handle most of the pipeline ops | |||
| // pipeline ops although can't config num_workers it still runs 1 thread they need to be factored into weight | |||
| Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override; | |||
| // these functions calculate the weights of more complex Nodes which may depend on its input arg. these functions | |||
| // will also push these nodes to a vector whose num_workers will be set int the Tree Pass | |||
| Status Visit(std::shared_ptr<BatchNode> node, bool *modified) override; | |||
| Status Visit(std::shared_ptr<MapNode> node, bool *modified) override; | |||
| Status Visit(std::shared_ptr<MappableSourceNode> node, bool *modified) override; | |||
| Status Visit(std::shared_ptr<NonMappableSourceNode> node, bool *modified) override; | |||
| // helper function to look up weight according to the name of this Op. | |||
| float GetNodeWeightFromProfile(std::shared_ptr<DatasetNode> node); | |||
| int32_t weight_sum_; // sum of all weights in the pipeline | |||
| const std::map<std::string, float> weight_profile_; // key: name of ir node, val: weight of this node | |||
| std::vector<std::pair<std::shared_ptr<DatasetNode>, float>> parallel_ops_; // first: node second: weight | |||
| }; | |||
| const int32_t min_num_workers_; // minimum number of threads allowed for each op | |||
| const int32_t max_num_workers_; // maximum number of threads allowed for each op | |||
| const int32_t thread_cnt_; // thread cnt of current CPU, obtained through config manager | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_OPT_POST_AUTO_WORKER_PASS_H_ | |||
| @@ -19,6 +19,7 @@ | |||
| #include "minddata/dataset/core/client.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/root_node.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/engine/opt/post/auto_worker_pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/cache_validation_pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/deep_copy_pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h" | |||
| @@ -30,7 +31,7 @@ namespace dataset { | |||
| TreeAdapter::TreeAdapter() { | |||
| tree_state_ = kCompileStateInit; | |||
| optimize_ = common::GetEnv("OPTIMIZE") == "true" ? true : false; | |||
| optimize_ = common::GetEnv("OPTIMIZE") == "true"; | |||
| } | |||
| Status TreeAdapter::PrePass(std::shared_ptr<DatasetNode> ir) { | |||
| @@ -79,6 +80,11 @@ Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) { | |||
| std::vector<std::unique_ptr<IRPass>> actions; | |||
| MS_LOG(INFO) << "Running post pass loops."; | |||
| // AutoWorkerPass should ideally precede CacheTransForm Pass to avoid complications of the setting | |||
| if (GlobalContext::config_manager()->auto_num_workers()) { | |||
| actions.emplace_back(std::make_unique<AutoWorkerPass>()); | |||
| } | |||
| // We will gradually move RepeatPass from ExecutionTree::PrepareTreePostAction to here. | |||
| // Vector of flags for each action | |||
| @@ -235,7 +241,7 @@ Status TreeAdapter::GetNext(TensorRow *row) { | |||
| // Record profiling info | |||
| if (tracing_ != nullptr) { | |||
| cur_batch_num_++; | |||
| tracing_->Record(CONNECTOR_DEPTH, cur_connector_capacity_, cur_batch_num_, cur_connector_size_); | |||
| RETURN_IF_NOT_OK(tracing_->Record(CONNECTOR_DEPTH, cur_connector_capacity_, cur_batch_num_, cur_connector_size_)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -22,7 +22,7 @@ import mindspore._c_dataengine as cde | |||
# Names exported by `from <this module> import *`.
__all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers',
           'get_num_parallel_workers', 'set_monitor_sampling_interval', 'get_monitor_sampling_interval', 'load',
           'get_callback_timeout', 'set_auto_num_workers', 'get_auto_num_workers']
| INT32_MAX = 2147483647 | |||
| UINT32_MAX = 4294967295 | |||
| @@ -165,6 +165,65 @@ def get_monitor_sampling_interval(): | |||
| return _config.get_monitor_sampling_interval() | |||
def set_auto_num_workers(enable):
    """
    Turn the automatic num_workers feature on or off. (The feature is off by default)

    When enabled, the number of workers of each op is adjusted automatically and any value preset by the
    user is overridden. For now, this function is only optimized for Yolo3 dataset with per_batch_map
    (running map in batch). It aims to provide a baseline for optimized num_workers assignment.
    The adjusted value will be logged.

    Args:
        enable (bool): Whether to enable auto num_workers.

    Raises:
        ValueError: If enable is not of boolean type.

    Examples:
        >>> import mindspore.dataset as ds
        >>>
        >>> # Enable the auto_num_worker, will override user's preset num_worker values
        >>> ds.config.set_auto_num_workers(True)
    """
    if isinstance(enable, bool):
        _config.set_auto_num_workers(enable)
        return
    raise ValueError("enable isn't of type bool.")
def _set_auto_workers_config(option):
    """
    INTERNAL USE ONLY!
    Select the weight profile of auto_num_workers. Currently these 7 options are supported.

    Option #0 leaf_num_workers:batch_num_workers:map_num_workers=1:1:1
    Option #1 leaf_num_workers:batch_num_workers:map_num_workers=2:1:1
    Option #2 leaf_num_workers:batch_num_workers:map_num_workers=1:2:1
    Option #3 leaf_num_workers:batch_num_workers:map_num_workers=1:1:2
    Option #4 leaf_num_workers:batch_num_workers:map_num_workers=2:2:1
    Option #5 leaf_num_workers:batch_num_workers:map_num_workers=2:1:2
    Option #6 leaf_num_workers:batch_num_workers:map_num_workers=1:2:2

    Args:
        option (int): The id of the profile to use.

    Raises:
        ValueError: If option is not int (bool is rejected as well) or not within the range of [0, 6].
    """
    # bool is a subclass of int, so it must be excluded explicitly; otherwise True/False would be silently
    # accepted as profile ids 1/0.
    if isinstance(option, bool) or not isinstance(option, int):
        raise ValueError("option isn't of type int.")
    if not 0 <= option <= 6:
        raise ValueError("option isn't within the required range of [0, 6].")
    _config.set_auto_worker_config(option)
def get_auto_num_workers():
    """
    Report whether the automatic num_workers feature is currently turned on.

    Returns:
        Bool, whether auto num worker feature is turned on

    Examples:
        >>> ds.config.get_auto_num_workers()
    """
    enabled = _config.get_auto_num_workers()
    return enabled
| def set_callback_timeout(timeout): | |||
| """ | |||
| Set the default timeout (in seconds) for DSWaitedCallback. | |||
| @@ -1771,25 +1771,3 @@ TEST_F(MindDataTestPipeline, TestZipSuccess2) { | |||
| // Manually terminate the pipeline | |||
| iter->Stop(); | |||
| } | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| #ifndef ENABLE_ANDROID | |||
| TEST_F(MindDataTestPipeline, TestNumWorkersValidate) { | |||
| // Testing the static zip() function | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestNumWorkersValidate."; | |||
| // Create an ImageFolder Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | |||
| std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 9)); | |||
| EXPECT_NE(ds, nullptr); | |||
| // test if set num_workers=-1 | |||
| std::shared_ptr<Dataset> ds1 = ds->SetNumWorkers(-1); | |||
| EXPECT_EQ(ds1, nullptr); | |||
| // test if set num_workers>cpu_count | |||
| std::shared_ptr<Dataset> ds2 = ds->SetNumWorkers(UINT32_MAX); | |||
| EXPECT_EQ(ds2, nullptr); | |||
| } | |||
| #endif | |||
| #endif | |||
| @@ -22,6 +22,7 @@ | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||
| #include "minddata/dataset/engine/opt/post/auto_worker_pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/getter_pass.h" | |||
| using namespace mindspore::dataset; | |||
| @@ -136,3 +137,35 @@ TEST_F(MindDataTestOptimizationPass, MindDataTestDatasetSizePass) { | |||
| EXPECT_NE(ss_str.find("ProjectOp"), ss_str.npos); | |||
| EXPECT_NE(ss_str.find("BatchOp"), ss_str.npos); | |||
| } | |||
| TEST_F(MindDataTestOptimizationPass, MindDataTestAutoWorkerPass) { | |||
| MS_LOG(INFO) << "Doing MindDataTestOptimizationPass-MindDataTestAutoWorkerPass."; | |||
| std::shared_ptr<SchemaObj> schema = std::make_shared<SchemaObj>(); | |||
| ASSERT_TRUE(schema->add_column("label", "uint32", {})); | |||
| std::shared_ptr<Dataset> map_leaf = ImageFolder("dir")->SetNumWorkers(0); | |||
| std::shared_ptr<Dataset> nonmap_leaf = RandomData(44, schema)->SetNumWorkers(0); | |||
| std::shared_ptr<Dataset> batch = Zip({map_leaf, nonmap_leaf})->Batch(1)->SetNumWorkers(0); | |||
| std::shared_ptr<Dataset> map = batch->Map({})->SetNumWorkers(0); | |||
| // {ImageFolder, RandomData} -> zip -> batch | |||
| EXPECT_EQ(map_leaf->IRNode()->num_workers(), 0); | |||
| EXPECT_EQ(nonmap_leaf->IRNode()->num_workers(), 0); | |||
| EXPECT_EQ(batch->IRNode()->num_workers(), 0); | |||
| EXPECT_EQ(map->IRNode()->num_workers(), 0); | |||
| std::unique_ptr<IRPass> pass = std::make_unique<AutoWorkerPass>(); | |||
| bool m = false; | |||
| ASSERT_OK(pass->Run(map->IRNode(), &m)); | |||
| // checking that after this pass, num_workers are set correctly (aka a positive number) | |||
| // It is hard to test a exact value because num_threads are different for different machine | |||
| // however, this will for sure succeed bc regardless of the total threads on cpu, this would always be >= 1 | |||
| EXPECT_NE(map_leaf->IRNode()->num_workers(), 0); | |||
| EXPECT_NE(nonmap_leaf->IRNode()->num_workers(), 0); | |||
| EXPECT_NE(batch->IRNode()->num_workers(), 0); | |||
| EXPECT_NE(map->IRNode()->num_workers(), 0); | |||
| MS_LOG(DEBUG) << map_leaf->IRNode()->Name() << ": num_worker=" << map_leaf->IRNode()->num_workers(); | |||
| MS_LOG(DEBUG) << nonmap_leaf->IRNode()->Name() << ": num_worker=" << nonmap_leaf->IRNode()->num_workers(); | |||
| MS_LOG(DEBUG) << batch->IRNode()->Name() << ": num_worker=" << batch->IRNode()->num_workers(); | |||
| MS_LOG(DEBUG) << map->IRNode()->Name() << ": num_worker=" << map->IRNode()->num_workers(); | |||
| } | |||
| @@ -357,6 +357,35 @@ def test_deterministic_python_seed_multi_thread(): | |||
| ds.config.set_seed(seed_original) | |||
def test_auto_num_workers_error():
    """
    Check that set_auto_num_workers raises a ValueError when given a non-bool argument
    """
    captured = ""
    try:
        ds.config.set_auto_num_workers([1, 2])
    except ValueError as err:
        captured = str(err)

    assert "isn't of type bool" in captured
def test_auto_num_workers():
    """
    Test auto_num_workers can be set.
    """
    saved_config = ds.config.get_auto_num_workers()
    assert isinstance(saved_config, bool)

    try:
        # change to a different config
        flipped_config = not saved_config
        ds.config.set_auto_num_workers(flipped_config)
        assert flipped_config == ds.config.get_auto_num_workers()
    finally:
        # now flip this back; done in `finally` so a failed assertion above cannot leak the flipped
        # global setting into tests that run afterwards
        ds.config.set_auto_num_workers(saved_config)

    assert saved_config == ds.config.get_auto_num_workers()
| if __name__ == '__main__': | |||
| test_basic() | |||
| test_get_seed() | |||
| @@ -367,3 +396,5 @@ if __name__ == '__main__': | |||
| test_deterministic_run_distribution() | |||
| test_deterministic_python_seed() | |||
| test_deterministic_python_seed_multi_thread() | |||
| test_auto_num_workers_error() | |||
| test_auto_num_workers() | |||