Browse Source

Implemented AutoNumWorker Pass which sets num_workers of selected parallel ops automatically if enabled

tags/v1.1.0
Zirui Wu 5 years ago
parent
commit
d6df1b0832
29 changed files with 483 additions and 69 deletions
  1. +13
    -10
      mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc
  2. +7
    -1
      mindspore/ccsrc/minddata/dataset/api/samplers.cc
  3. +8
    -2
      mindspore/ccsrc/minddata/dataset/core/config_manager.cc
  4. +36
    -1
      mindspore/ccsrc/minddata/dataset/core/config_manager.h
  5. +1
    -0
      mindspore/ccsrc/minddata/dataset/core/constants.h
  6. +2
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc
  7. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc
  8. +5
    -4
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc
  9. +2
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc
  10. +2
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc
  11. +5
    -4
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc
  12. +7
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
  13. +9
    -0
      mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc
  14. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h
  15. +9
    -13
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc
  16. +4
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h
  17. +7
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc
  18. +7
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc
  19. +7
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h
  20. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt
  21. +9
    -0
      mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc
  22. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/opt/pass.h
  23. +122
    -0
      mindspore/ccsrc/minddata/dataset/engine/opt/post/auto_worker_pass.cc
  24. +82
    -0
      mindspore/ccsrc/minddata/dataset/engine/opt/post/auto_worker_pass.h
  25. +8
    -2
      mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc
  26. +60
    -1
      mindspore/dataset/core/config.py
  27. +0
    -22
      tests/ut/cpp/dataset/c_api_dataset_ops_test.cc
  28. +33
    -0
      tests/ut/cpp/dataset/optimization_pass_test.cc
  29. +31
    -0
      tests/ut/python/dataset/test_config.py

+ 13
- 10
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc View File

@@ -35,20 +35,23 @@ PYBIND_REGISTER(GlobalContext, 0, ([](const py::module *m) {
PYBIND_REGISTER(ConfigManager, 0, ([](const py::module *m) {
(void)py::class_<ConfigManager, std::shared_ptr<ConfigManager>>(*m, "ConfigManager")
.def("__str__", &ConfigManager::ToString)
.def("set_rows_per_buffer", &ConfigManager::set_rows_per_buffer)
.def("set_num_parallel_workers", &ConfigManager::set_num_parallel_workers)
.def("set_worker_connector_size", &ConfigManager::set_worker_connector_size)
.def("set_op_connector_size", &ConfigManager::set_op_connector_size)
.def("set_seed", &ConfigManager::set_seed)
.def("set_monitor_sampling_interval", &ConfigManager::set_monitor_sampling_interval)
.def("get_rows_per_buffer", &ConfigManager::rows_per_buffer)
.def("get_auto_num_workers", &ConfigManager::auto_num_workers)
.def("get_callback_timeout", &ConfigManager::callback_timeout)
.def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval)
.def("get_num_parallel_workers", &ConfigManager::num_parallel_workers)
.def("get_worker_connector_size", &ConfigManager::worker_connector_size)
.def("get_op_connector_size", &ConfigManager::op_connector_size)
.def("get_rows_per_buffer", &ConfigManager::rows_per_buffer)
.def("get_seed", &ConfigManager::seed)
.def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval)
.def("get_callback_timeout", &ConfigManager::callback_timeout)
.def("get_worker_connector_size", &ConfigManager::worker_connector_size)
.def("set_auto_num_workers", &ConfigManager::set_auto_num_workers)
.def("set_auto_worker_config", &ConfigManager::set_auto_worker_config_)
.def("set_callback_timeout", &ConfigManager::set_callback_timeout)
.def("set_monitor_sampling_interval", &ConfigManager::set_monitor_sampling_interval)
.def("set_num_parallel_workers", &ConfigManager::set_num_parallel_workers)
.def("set_op_connector_size", &ConfigManager::set_op_connector_size)
.def("set_rows_per_buffer", &ConfigManager::set_rows_per_buffer)
.def("set_seed", &ConfigManager::set_seed)
.def("set_worker_connector_size", &ConfigManager::set_worker_connector_size)
.def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); });
}));



+ 7
- 1
mindspore/ccsrc/minddata/dataset/api/samplers.cc View File

@@ -123,7 +123,13 @@ DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_i
num_samples_(num_samples),
seed_(seed),
offset_(offset),
even_dist_(even_dist) {}
even_dist_(even_dist) {
// Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion
// is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't
// 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once
// PreBuildSampler is phased out, this can be cleaned up.
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
}

bool DistributedSamplerObj::ValidateParams() {
if (num_shards_ <= 0) {


+ 8
- 2
mindspore/ccsrc/minddata/dataset/core/config_manager.cc View File

@@ -18,6 +18,7 @@
#include <fstream>
#include <iostream>
#include <string>
#include <thread>
#include <utility>

#ifndef ENABLE_ANDROID
@@ -40,7 +41,11 @@ ConfigManager::ConfigManager()
cache_host_(kCfgDefaultCacheHost),
cache_port_(kCfgDefaultCachePort),
num_connections_(kDftNumConnections),
prefetch_size_(kDftPrefetchSize) {
prefetch_size_(kDftPrefetchSize),
auto_num_workers_(kDftAutoNumWorkers),
num_cpu_threads_(std::thread::hardware_concurrency()),
auto_num_workers_num_shards_(1),
auto_worker_config_(0) {
auto env_cache_host = std::getenv("MS_CACHE_HOST");
auto env_cache_port = std::getenv("MS_CACHE_PORT");
if (env_cache_host != nullptr) {
@@ -68,7 +73,7 @@ void ConfigManager::Print(std::ostream &out) const {
<< "\nSize of each Connector : " << op_connector_size_ << std::endl;
}

// Private helper function that taks a nlohmann json format and populates the settings
// Private helper function that takes a nlohmann json format and populates the settings
Status ConfigManager::FromJson(const nlohmann::json &j) {
set_rows_per_buffer(j.value("rowsPerBuffer", rows_per_buffer_));
set_num_parallel_workers(j.value("numParallelWorkers", num_parallel_workers_));
@@ -136,5 +141,6 @@ void ConfigManager::set_cache_port(int32_t cache_port) { cache_port_ = cache_por
void ConfigManager::set_num_connections(int32_t num_connections) { num_connections_ = num_connections; }

void ConfigManager::set_prefetch_size(int32_t prefetch_size) { prefetch_size_ = prefetch_size; }

} // namespace dataset
} // namespace mindspore

+ 36
- 1
mindspore/ccsrc/minddata/dataset/core/config_manager.h View File

@@ -89,6 +89,8 @@ class ConfigManager {
// @return The internal worker-to-master connector queue size
int32_t worker_connector_size() const { return worker_connector_size_; }

int32_t num_cpu_threads() const { return num_cpu_threads_; }

// getter function
// @return The hostname of cache server
std::string cache_host() const { return cache_host_; }
@@ -105,6 +107,10 @@ class ConfigManager {
/// \return Prefetch size
int32_t prefetch_size() const { return prefetch_size_; }

/// getter function
/// \return auto_num_workers_
bool auto_num_workers() const { return auto_num_workers_; }

// setter function
// @param rows_per_buffer - The setting to apply to the config
void set_rows_per_buffer(int32_t rows_per_buffer);
@@ -151,6 +157,20 @@ class ConfigManager {
// @return The interval of monitor sampling
int32_t monitor_sampling_interval() const { return monitor_sampling_interval_; }

// setter function
// @param auto_num_workers - whether assign threads to each op automatically
void set_auto_num_workers(bool auto_num_workers) { auto_num_workers_ = auto_num_workers; }

// setter function
// this function will be called when a distributed sampler (RT and Obj) is created and will be used by AutoWorkerPass
// This is to get around the limitation of PreBuildSampler (which doesn't have a getter for sharding params)
// @param num_shards
void set_num_shards_for_auto_num_workers(int32_t num_shards) { auto_num_workers_num_shards_ = num_shards; }

// getter function, called by the AutoNumWorker pass; the user discretion note above about AutoNumWorker applies here too
// @return num_shards_
int32_t get_num_shards_for_auto_num_workers() const { return auto_num_workers_num_shards_; }

// setter function
// @param timeout - The setting to apply to the config
void set_callback_timeout(uint32_t timeout);
@@ -159,6 +179,18 @@ class ConfigManager {
// @return The timeout DSWaitedCallback would wait for before raising an error
int32_t callback_timeout() const { return callback_timout_; }

// getter function
// E.g. 0 corresponds to a 1:1:1 ratio of num_worker among leaf, batch and map.
// please refer to AutoWorkerPass for detail on what each option is.
// @return The experimental config used by AutoNumWorker; each value refers to a different setup configuration
uint8_t get_auto_worker_config_() { return auto_worker_config_; }

// setter function
// E.g. setting the value to 0 corresponds to a 1:1:1 ratio of num_worker among leaf, batch and map.
// please refer to AutoWorkerPass for detail on what each option is.
// @param cfg The experimental config used by AutoNumWorker; each value refers to a different setup configuration
void set_auto_worker_config_(uint8_t cfg) { auto_worker_config_ = cfg; }

private:
int32_t rows_per_buffer_;
int32_t num_parallel_workers_;
@@ -171,7 +203,10 @@ class ConfigManager {
int32_t cache_port_;
int32_t num_connections_;
int32_t prefetch_size_;

bool auto_num_workers_;
const int32_t num_cpu_threads_;
int32_t auto_num_workers_num_shards_;
uint8_t auto_worker_config_;
// Private helper function that takes a nlohmann json format and populates the settings
// @param j - The json nlohmann json info
Status FromJson(const nlohmann::json &j);


+ 1
- 0
mindspore/ccsrc/minddata/dataset/core/constants.h View File

@@ -91,6 +91,7 @@ constexpr int32_t kCfgDefaultCachePort = 50052;
// Default configuration values for cache connectivity and the AutoNumWorker feature.
constexpr char kCfgDefaultCacheHost[] = "127.0.0.1";
constexpr int32_t kDftPrefetchSize = 20;
constexpr int32_t kDftNumConnections = 12;
// AutoNumWorker is disabled by default. Declared as bool (not int32_t) to match
// ConfigManager::auto_num_workers_ and its getter, which are bool; the original
// declaration stored `false` in an int32_t.
constexpr bool kDftAutoNumWorkers = false;

// Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h)
constexpr uint8_t kCVInvalidType = 255;


+ 2
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc View File

@@ -263,7 +263,8 @@ Status BatchOp::LaunchThreadsAndInitOp() {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
}
RETURN_IF_NOT_OK(worker_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&BatchOp::WorkerEntry, this, std::placeholders::_1)));
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&BatchOp::WorkerEntry, this, std::placeholders::_1), Name()));
return Status::OK();
}



+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc View File

@@ -74,7 +74,7 @@ Status CacheBase::FetchSamplesToWorkers() {
// Kick off several threads which will prefetch prefetch_size_ rows in advance. The rows_per_buffers_
// is too small (1 by default) and won't help performance.
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_prefetchers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1)));
tree_->LaunchWorkers(num_prefetchers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1), Name()));
auto send_to_que = [](QueueList<std::unique_ptr<IOBlock>> &qList, int32_t worker_id,
std::vector<row_id_type> &keys) -> Status {
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));


+ 5
- 4
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc View File

@@ -59,10 +59,11 @@ Status CacheMergeOp::operator()() {
static const int32_t queue_sz = 512;
io_que_ = std::make_unique<Queue<row_id_type>>(queue_sz);
RETURN_IF_NOT_OK(io_que_->Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&CacheMergeOp::WorkerEntry, this, std::placeholders::_1)));
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&CacheMergeOp::CacheMissWorkerEntry, this, std::placeholders::_1)));
RETURN_IF_NOT_OK(tree_->LaunchWorkers(
num_workers_, std::bind(&CacheMergeOp::WorkerEntry, this, std::placeholders::_1), Name() + "::WorkerEntry"));
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_,
std::bind(&CacheMergeOp::CacheMissWorkerEntry, this, std::placeholders::_1),
Name() + "::CacheMissWorkerEntry"));
// One dedicated thread to move TensorRow from the pool to the cache server
for (auto i = 0; i < num_cleaners_; ++i) {
RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Cleaner", std::bind(&CacheMergeOp::Cleaner, this)));


+ 2
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc View File

@@ -83,7 +83,8 @@ Status CacheOp::operator()() {
}
RETURN_IF_NOT_OK(RegisterResources());
// Kick off the workers
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheOp::WorkerEntry, this, std::placeholders::_1)));
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&CacheOp::WorkerEntry, this, std::placeholders::_1), Name()));
// required task group sync after launching workers
TaskManager::FindMe()->Post();
// Wait for the workers to finish caching the rows.


+ 2
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc View File

@@ -70,7 +70,8 @@ Status FilterOp::operator()() {
}
filter_queues_.Init(num_workers_, oc_queue_size_);
RETURN_IF_NOT_OK(filter_queues_.Register(tree_->AllTasks()));
Status rc = tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1));
Status rc =
tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1), Name());
// Synchronize with TaskManager.
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(rc);


+ 5
- 4
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc View File

@@ -385,10 +385,11 @@ Status ImageFolderOp::LaunchThreadsAndInitOp() {
// 2) Workers that pull foldername from folder_name_queue_, walk it and return the sorted images to image_name_queue
// 3) Launch main workers that load DataBuffers by reading all images
RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("walk dir", std::bind(&ImageFolderOp::StartAsyncWalk, this)));
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&ImageFolderOp::PrescanWorkerEntry, this, std::placeholders::_1)));
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&ImageFolderOp::WorkerEntry, this, std::placeholders::_1)));
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_,
std::bind(&ImageFolderOp::PrescanWorkerEntry, this, std::placeholders::_1),
Name() + "::PrescanWorkerEntry"));
RETURN_IF_NOT_OK(tree_->LaunchWorkers(
num_workers_, std::bind(&ImageFolderOp::WorkerEntry, this, std::placeholders::_1), Name() + "::WorkerEntry"));
TaskManager::FindMe()->Post();
// The order of the following 2 functions must not be changed!
RETURN_IF_NOT_OK(this->PrescanMasterEntry(folder_path_)); // Master thread of pre-scan workers, blocking


+ 7
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc View File

@@ -34,7 +34,13 @@ DistributedSamplerRT::DistributedSamplerRT(int64_t num_samples, int64_t num_dev,
shuffle_(shuffle),
even_dist_(even_dist),
offset_(offset),
non_empty_(true) {}
non_empty_(true) {
// Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion
// is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't
// 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once
// PreBuildSampler is phased out, this can be cleaned up.
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_devices_);
}

Status DistributedSamplerRT::InitSampler() {
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.


+ 9
- 0
mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc View File

@@ -204,7 +204,16 @@ ExecutionTree::Iterator::Iterator(const std::shared_ptr<DatasetOp> &root) : ind_
// Given the number of workers, launches the worker entry function for each. Essentially a
// wrapper for the TaskGroup handling that is stored inside the execution tree.
Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name) {
int32_t num_cpu_threads = GlobalContext::Instance()->config_manager()->num_cpu_threads();
// this performs a check that num_workers is positive and not unreasonably large, which could happen
// with, for example, an uninitialized variable. uint16 max is 65535, which is large enough to cover everything
CHECK_FAIL_RETURN_UNEXPECTED(num_workers > 0 && num_workers < std::numeric_limits<uint16_t>::max(),
name + "'s num_worker=" + std::to_string(num_workers) + ", is negative or too large.");
// Launch the workers
if (num_workers > num_cpu_threads) {
MS_LOG(WARNING) << name + " is launched with " << std::to_string(num_workers) << " worker threads which exceeds "
<< std::to_string(num_cpu_threads) << ", the maximum number of threads on this CPU.";
}
for (int32_t i = 0; i < num_workers; ++i) {
RETURN_IF_NOT_OK(tg_->CreateAsyncTask(name, std::bind(func, i)));
}


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h View File

@@ -24,6 +24,7 @@
#include <vector>

#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/opt/pass.h"

namespace mindspore {
namespace dataset {


+ 9
- 13
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc View File

@@ -216,19 +216,6 @@ Status DatasetNode::AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops
DatasetNode::DatasetNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode() { cache_ = dataset_cache; }

std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) {
#if !defined(_WIN32) && !defined(_WIN64)
#ifndef ENABLE_ANDROID
int32_t cpu_count = sysconf(_SC_NPROCESSORS_CONF);
if (cpu_count < 0 || cpu_count > INT32_MAX) {
MS_LOG(ERROR) << "Error determining current CPU: " << cpu_count;
return nullptr;
}
if (num_workers < 1 || num_workers > cpu_count) {
MS_LOG(ERROR) << "num_workers exceeds the boundary between 1 and " << cpu_count;
return nullptr;
}
#endif
#endif
num_workers_ = num_workers;
return shared_from_this();
}
@@ -431,5 +418,14 @@ Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz
RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override");
}
}

// Visitor pattern entry point: dispatches to the IRNodePass overload for
// MappableSourceNode so passes (e.g. AutoWorkerPass) can handle all
// mappable (random-access) leaf nodes with one Visit() overload.
Status MappableSourceNode::Accept(IRNodePass *p, bool *modified) {
  return p->Visit(shared_from_base<MappableSourceNode>(), modified);
}

// Visitor pattern entry point: dispatches to the IRNodePass overload for
// NonMappableSourceNode so passes (e.g. AutoWorkerPass) can handle all
// non-mappable (streaming) leaf nodes with one Visit() overload.
Status NonMappableSourceNode::Accept(IRNodePass *p, bool *modified) {
  // Bug fix: the original cast to MappableSourceNode (copy-paste from the
  // sibling Accept above), which routed every non-mappable source to the
  // wrong Visit() overload and gave it the mappable-source weight.
  return p->Visit(shared_from_base<NonMappableSourceNode>(), modified);
}

} // namespace dataset
} // namespace mindspore

+ 4
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h View File

@@ -41,7 +41,6 @@ constexpr char kBucketBatchByLengthNode[] = "BucketBatchByLength";
constexpr char kBuildSentencePieceVocabNode[] = "BuildSentencePieceVocab";
constexpr char kBuildVocabNode[] = "BuildVocab";
constexpr char kConcatNode[] = "Concat";
constexpr char kDatasetNode[] = "Dataset";
constexpr char kEpochCtrlNode[] = "EpochCtrl";
constexpr char kFilterNode[] = "Filter";
constexpr char kMapNode[] = "Map";
@@ -290,6 +289,8 @@ class MappableSourceNode : public DatasetNode {
descendant_of_cache_ = false;
}

Status Accept(IRNodePass *p, bool *modified) override;

/// \brief Destructor
~MappableSourceNode() = default;

@@ -312,6 +313,8 @@ class NonMappableSourceNode : public DatasetNode {
descendant_of_cache_ = false;
}

Status Accept(IRNodePass *p, bool *modified) override;

/// \brief Destructor
~NonMappableSourceNode() = default;



+ 7
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc View File

@@ -41,7 +41,13 @@ CSVNode::CSVNode(const std::vector<std::string> &csv_files, char field_delim,
num_samples_(num_samples),
shuffle_(shuffle),
num_shards_(num_shards),
shard_id_(shard_id) {}
shard_id_(shard_id) {
// Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion
// is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't
// 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once
// PreBuildSampler is phased out, this can be cleaned up.
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
}

std::shared_ptr<DatasetNode> CSVNode::Copy() {
auto node = std::make_shared<CSVNode>(dataset_files_, field_delim_, column_defaults_, column_names_, num_samples_,


+ 7
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc View File

@@ -36,7 +36,13 @@ TextFileNode::TextFileNode(std::vector<std::string> dataset_files, int32_t num_s
num_samples_(num_samples),
shuffle_(shuffle),
num_shards_(num_shards),
shard_id_(shard_id) {}
shard_id_(shard_id) {
// Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion
// is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't
// 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once
// PreBuildSampler is phased out, this can be cleaned up.
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
}

std::shared_ptr<DatasetNode> TextFileNode::Copy() {
auto node = std::make_shared<TextFileNode>(dataset_files_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);


+ 7
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h View File

@@ -44,7 +44,13 @@ class TFRecordNode : public NonMappableSourceNode {
shuffle_(shuffle),
num_shards_(num_shards),
shard_id_(shard_id),
shard_equal_rows_(shard_equal_rows) {}
shard_equal_rows_(shard_equal_rows) {
// Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User
// discretion is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the
// num_shards_ isn't 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return
// num_shards. Once PreBuildSampler is phased out, this can be cleaned up.
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
}

/// \brief Constructor
/// \note Parameter 'schema' is shared pointer to Schema object


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt View File

@@ -3,6 +3,7 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE
add_library(engine-opt OBJECT
optional/tensor_op_fusion_pass.cc
pass.cc
post/auto_worker_pass.cc
post/repeat_pass.cc
pre/cache_error_pass.cc
pre/cache_transform_pass.cc


+ 9
- 0
mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc View File

@@ -258,6 +258,15 @@ Status IRNodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool
}
#endif

// leaf-IR Node
Status IRNodePass::Visit(std::shared_ptr<MappableSourceNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status IRNodePass::Visit(std::shared_ptr<NonMappableSourceNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

//////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
// Driver method for TreePass


+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/opt/pass.h View File

@@ -229,6 +229,10 @@ class IRNodePass : public IRPass {
virtual Status VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);
#endif

// leaf-IR Node
virtual Status Visit(std::shared_ptr<MappableSourceNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<NonMappableSourceNode> node, bool *modified);

private:
// Helper function to perform DFS visit
Status DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified);


+ 122
- 0
mindspore/ccsrc/minddata/dataset/engine/opt/post/auto_worker_pass.cc View File

@@ -0,0 +1,122 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cmath>
#include <algorithm>

#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
#include "minddata/dataset/engine/opt/post/auto_worker_pass.h"

namespace mindspore {
namespace dataset {

// this will become the RootNode:DatasetNode when it is turned on
// Tree-level driver of the AutoNumWorker feature: walks the IR tree with an
// OpWeightPass to collect the parallel ops (leaf/batch/map) and their weights,
// then overrides each op's num_workers proportionally to its weight, scaled by
// the CPU thread count and divided across shards. Experimental feature.
Status AutoWorkerPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
  // Select the user-chosen weight profile; fall back to profile 0 when the
  // stored config index is out of range.
  uint8_t config = GlobalContext::config_manager()->get_auto_worker_config_();

  OpWeightPass pass(kOpWeightConfigs[config < kOpWeightConfigs.size() ? config : 0]);

  // Build a "(name=weight)" string purely for the log message below.
  std::string weight_str;
  for (const auto &p : pass.weight_profile_) weight_str += ("(" + p.first + "=" + std::to_string(p.second) + ")");
  int32_t num_shards = GlobalContext::config_manager()->get_num_shards_for_auto_num_workers();
  // Clamp num_shards to [1, thread_cnt_]; the stored value is best-effort
  // (see set_num_shards_for_auto_num_workers) and may not be exact.
  num_shards = std::min(std::max(1, num_shards), thread_cnt_);

  MS_LOG(INFO) << "AutoWorkerPass is enabled; this could override existing num_workers set in each parallel op."
               << "total number of threads on this CPU: " << thread_cnt_ << ", "
               << "min num_workers to override:" << min_num_workers_ << ", "
               << "max num_workers to override:" << max_num_workers_ << ", "
               << "adjusted num_shards (between 1 and total thread cnt): " << num_shards
               << ", weight profile:" << weight_str << ".";

  // get the maximum weight of all the ops, this value is used to ensure the ratio of num_workers between ops
  float max_weight = 0;
  for (const auto &p : pass.weight_profile_) max_weight = std::max(max_weight, p.second);
  // Walk the tree: fills pass.parallel_ops_ and pass.weight_sum_.
  RETURN_IF_NOT_OK(pass.Run(root_ir, modified));
  if (pass.parallel_ops_.size() > 3) {
    MS_LOG(WARNING) << "AutoWorkerPass at current stage is only optimized for simple network that has LeafNode, "
                    << "BatchNode and MapNode. User discretion is advised for usage on other complex networks.";
  }

  for (auto &p : pass.parallel_ops_) {
    // get the num worker via the weight ratio
    int32_t num_workers = std::ceil((thread_cnt_ * p.second) / (pass.weight_sum_ * num_shards));
    // this is to ensure when thread_cnt_ is very large let's say 192, the num_worker ratio is still kept
    // e.g. the optional 2:1 ratio between minddataset and batch
    int32_t cur_node_max = std::ceil(p.second * max_num_workers_ / max_weight);
    // this will ensure that num_workers will fall with the range of [1,cur_node_max]
    int32_t cur_node_num_worker = std::max(std::min(num_workers, cur_node_max), min_num_workers_);
    // log the change via warning msg so user can see what the num_worker is being set for which op
    MS_LOG(WARNING) << "num_workers in " << p.first->Name() << " is auto-adjusted from "
                    << std::to_string(p.first->num_workers()) + " to " + std::to_string(cur_node_num_worker);
    p.first->SetNumWorkers(cur_node_num_worker);
  }
  return Status::OK();
}

// Record a MapNode as a parallel op: accumulate its configured weight into
// weight_sum_ and remember the node so RunOnTree can rewrite its num_workers.
Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<MapNode> node, bool *modified) {
  const auto profile_entry = weight_profile_.find(node->Name());
  CHECK_FAIL_RETURN_UNEXPECTED(profile_entry != weight_profile_.end(), node->Name() + "'s weight doesn't exist.");
  const int32_t node_weight = profile_entry->second;
  weight_sum_ += node_weight;
  parallel_ops_.emplace_back(std::static_pointer_cast<DatasetNode>(node), node_weight);
  return Status::OK();
}

// Record a BatchNode as a parallel op: accumulate its configured weight into
// weight_sum_ and remember the node so RunOnTree can rewrite its num_workers.
Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<BatchNode> node, bool *modified) {
  const auto profile_entry = weight_profile_.find(node->Name());
  CHECK_FAIL_RETURN_UNEXPECTED(profile_entry != weight_profile_.end(), node->Name() + "'s weight doesn't exist.");
  const int32_t node_weight = profile_entry->second;
  weight_sum_ += node_weight;
  parallel_ops_.emplace_back(std::static_pointer_cast<DatasetNode>(node), node_weight);
  return Status::OK();
}

// Record a mappable (random-access) leaf as a parallel op under the shared
// "MappableSource" weight key, except GeneratorNode which is a pipeline op.
Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<MappableSourceNode> node, bool *modified) {
  RETURN_OK_IF_TRUE(node->Name() == kGeneratorNode);  // generator is pipeline op, skip this
  const auto profile_entry = weight_profile_.find("MappableSource");
  CHECK_FAIL_RETURN_UNEXPECTED(profile_entry != weight_profile_.end(),
                               "LeafSourceNode::" + node->Name() + "'s weight doesn't exist.");
  const int32_t leaf_weight = profile_entry->second;
  weight_sum_ += leaf_weight;
  parallel_ops_.emplace_back(std::static_pointer_cast<DatasetNode>(node), leaf_weight);
  return Status::OK();
}

// Record a non-mappable (streaming) leaf as a parallel op under the shared
// "NonMappableSource" weight key.
Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<NonMappableSourceNode> node, bool *modified) {
  // Bug fix: kOpWeightConfigs (auto_worker_pass.h) registers this category as
  // "NonMappableSource"; the original looked up "NonMappableSourceNode", which
  // never matches, so the CHECK below failed for every TFRecord/CSV/TextFile leaf.
  auto itr = weight_profile_.find("NonMappableSource");
  CHECK_FAIL_RETURN_UNEXPECTED(itr != weight_profile_.end(),
                               "NonLeafSource::" + node->Name() + "'s weight doesn't exist.");
  int32_t weight = itr->second;
  weight_sum_ += weight;
  parallel_ops_.emplace_back(std::make_pair(std::static_pointer_cast<DatasetNode>(node), weight));
  return Status::OK();
}

// Fallback for all other (pipeline) ops: each occupies one thread, so its
// weight still contributes to weight_sum_, but it is not added to
// parallel_ops_ and its num_workers is never modified.
Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<DatasetNode> node, bool *modified) {
  weight_sum_ += GetNodeWeightFromProfile(node);
  return Status::OK();
}

// Look up a node's weight by its name; nodes absent from the profile weigh 0.
float AutoWorkerPass::OpWeightPass::GetNodeWeightFromProfile(std::shared_ptr<DatasetNode> node) {
  const auto entry = weight_profile_.find(node->Name());
  if (entry == weight_profile_.end()) {
    return 0;  // name not present in the weight profile
  }
  return entry->second;
}

} // namespace dataset
} // namespace mindspore

+ 82
- 0
mindspore/ccsrc/minddata/dataset/engine/opt/post/auto_worker_pass.h View File

@@ -0,0 +1,82 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef DATASET_ENGINE_OPT_POST_AUTO_WORKER_PASS_H_
#define DATASET_ENGINE_OPT_POST_AUTO_WORKER_PASS_H_

#include <map>
#include <memory>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/opt/pass.h"

namespace mindspore {
namespace dataset {

class AutoWorkerPass : public IRTreePass {
 public:
  // Weight profiles for the basic pipeline ops. A pipeline op takes up 1 thread even though it has no workers.
  // Each profile maps an IR node name (or source-node family) to its relative weight; the pass divides the
  // CPU thread budget among the selected parallel ops proportionally to these weights.
  const std::vector<std::map<std::string, float>> kOpWeightConfigs = {
    {{"MappableSource", 8}, {"NonMappableSource", 8}, {kBatchNode, 8}, {kMapNode, 8}},  // config1 leaf:batch:map=1:1:1
    {{"MappableSource", 8}, {"NonMappableSource", 8}, {kBatchNode, 4}, {kMapNode, 4}},  // config2 leaf:batch:map=2:1:1
    {{"MappableSource", 4}, {"NonMappableSource", 4}, {kBatchNode, 8}, {kMapNode, 4}},  // config3 leaf:batch:map=1:2:1
    {{"MappableSource", 4}, {"NonMappableSource", 4}, {kBatchNode, 4}, {kMapNode, 8}},  // config4 leaf:batch:map=1:1:2
    {{"MappableSource", 8}, {"NonMappableSource", 8}, {kBatchNode, 8}, {kMapNode, 4}},  // config5 leaf:batch:map=2:2:1
    {{"MappableSource", 8}, {"NonMappableSource", 8}, {kBatchNode, 4}, {kMapNode, 8}},  // config6 leaf:batch:map=2:1:2
    {{"MappableSource", 4}, {"NonMappableSource", 4}, {kBatchNode, 8}, {kMapNode, 8}},  // config7 leaf:batch:map=1:2:2
  };
  AutoWorkerPass()
      : min_num_workers_(1),
        max_num_workers_(8),
        thread_cnt_(GlobalContext::Instance()->config_manager()->num_cpu_threads()) {}

  // Entry point of the pass: walks the IR tree, computes per-op weights, and sets
  // num_workers on the collected parallel ops.
  Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *) override;

 private:
  class OpWeightPass : public IRNodePass {
   public:
    explicit OpWeightPass(const std::map<std::string, float> &weight_profile)
        : IRNodePass(), weight_sum_(0), weight_profile_(weight_profile) {}
    // this is the base class function which contains the logic to handle most of the pipeline ops
    // pipeline ops, although they can't configure num_workers, still run 1 thread each and therefore
    // need to be factored into the weight sum
    Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override;
    // these functions calculate the weights of more complex nodes which may depend on their input args.
    // they also push these nodes into parallel_ops_, whose num_workers will be set in the tree pass
    Status Visit(std::shared_ptr<BatchNode> node, bool *modified) override;
    Status Visit(std::shared_ptr<MapNode> node, bool *modified) override;
    Status Visit(std::shared_ptr<MappableSourceNode> node, bool *modified) override;
    Status Visit(std::shared_ptr<NonMappableSourceNode> node, bool *modified) override;

    // helper function to look up weight according to the name of this op; returns 0 when unprofiled
    float GetNodeWeightFromProfile(std::shared_ptr<DatasetNode> node);

    // sum of all weights in the pipeline; float (not int32_t) because the profile values are float
    // and an integer accumulator would silently truncate fractional weights
    float weight_sum_;
    const std::map<std::string, float> weight_profile_;  // key: name of ir node, val: weight of this node
    std::vector<std::pair<std::shared_ptr<DatasetNode>, float>> parallel_ops_;  // first: node second: weight
  };

  const int32_t min_num_workers_;  // minimum number of threads allowed for each op
  const int32_t max_num_workers_;  // maximum number of threads allowed for each op
  const int32_t thread_cnt_;       // thread cnt of current CPU, obtained through config manager
};
} // namespace dataset
} // namespace mindspore

#endif // DATASET_ENGINE_OPT_POST_AUTO_WORKER_PASS_H_

+ 8
- 2
mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc View File

@@ -19,6 +19,7 @@
#include "minddata/dataset/core/client.h"
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/opt/post/auto_worker_pass.h"
#include "minddata/dataset/engine/opt/pre/cache_validation_pass.h"
#include "minddata/dataset/engine/opt/pre/deep_copy_pass.h"
#include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h"
@@ -30,7 +31,7 @@ namespace dataset {

TreeAdapter::TreeAdapter() {
  // Every adapter starts in the initial compile state; optimization is opt-in via the env var.
  tree_state_ = kCompileStateInit;
  // The comparison already yields a bool; the redundant duplicate assignment
  // (`== "true" ? true : false`) was a dead store and is removed.
  optimize_ = common::GetEnv("OPTIMIZE") == "true";
}

Status TreeAdapter::PrePass(std::shared_ptr<DatasetNode> ir) {
@@ -79,6 +80,11 @@ Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) {
std::vector<std::unique_ptr<IRPass>> actions;
MS_LOG(INFO) << "Running post pass loops.";

// AutoWorkerPass should ideally precede CacheTransForm Pass to avoid complications of the setting
if (GlobalContext::config_manager()->auto_num_workers()) {
actions.emplace_back(std::make_unique<AutoWorkerPass>());
}

// We will gradually move RepeatPass from ExecutionTree::PrepareTreePostAction to here.

// Vector of flags for each action
@@ -235,7 +241,7 @@ Status TreeAdapter::GetNext(TensorRow *row) {
// Record profiling info
if (tracing_ != nullptr) {
cur_batch_num_++;
tracing_->Record(CONNECTOR_DEPTH, cur_connector_capacity_, cur_batch_num_, cur_connector_size_);
RETURN_IF_NOT_OK(tracing_->Record(CONNECTOR_DEPTH, cur_connector_capacity_, cur_batch_num_, cur_connector_size_));
}
return Status::OK();
}


+ 60
- 1
mindspore/dataset/core/config.py View File

@@ -22,7 +22,7 @@ import mindspore._c_dataengine as cde

__all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers',
'get_num_parallel_workers', 'set_monitor_sampling_interval', 'get_monitor_sampling_interval', 'load',
'get_callback_timeout']
'get_callback_timeout', 'set_auto_num_workers', 'get_auto_num_workers']

INT32_MAX = 2147483647
UINT32_MAX = 4294967295
@@ -165,6 +165,65 @@ def get_monitor_sampling_interval():
return _config.get_monitor_sampling_interval()


def set_auto_num_workers(enable):
    """
    Turn automatic assignment of num_workers on or off. (Off by default.)

    When enabled, the number of workers of each op is adjusted automatically,
    overriding any value preset by the user. For now this feature is only
    optimized for the Yolo3 dataset with per_batch_map (running map in batch);
    it aims to provide a baseline for optimized num_workers assignment.
    The adjusted values will be logged.

    Args:
        enable (bool): Whether to enable auto num_workers.

    Raises:
        ValueError: If enable is not of boolean type.

    Examples:
        >>> import mindspore.dataset as ds
        >>>
        >>> # Enable the auto_num_worker, will override user's preset num_worker values
        >>> ds.config.set_auto_num_workers(True)
    """
    if isinstance(enable, bool):
        _config.set_auto_num_workers(enable)
    else:
        raise ValueError("enable isn't of type bool.")


def _set_auto_workers_config(option):
"""
INTERNAL USE ONLY!
Select the weight profile of auto_num_workers. currently these 7 options are supported.
Option #0 leaf_num_workers:batch_num_workers:map_num_workers=1:1:1
Option #1 leaf_num_workers:batch_num_workers:map_num_workers=2:1:1
Option #2 leaf_num_workers:batch_num_workers:map_num_workers=1:2:1
Option #3 leaf_num_workers:batch_num_workers:map_num_workers=1:1:2
Option #4 leaf_num_workers:batch_num_workers:map_num_workers=2:2:1
Option #5 leaf_num_workers:batch_num_workers:map_num_workers=2:1:2
Option #6 leaf_num_workers:batch_num_workers:map_num_workers=1:2:2
Args:
option (int): The id of the profile to use.
Raises:
ValueError: If option is not int or not within the range of [0, 6]
"""
if not isinstance(option, int):
raise ValueError("option isn't of type int.")
if option < 0 or option > 6:
raise ValueError("option isn't within the required range of [0, 6].")
_config.set_auto_worker_config(option)


def get_auto_num_workers():
    """
    Query whether the automatic number-of-workers feature is enabled.

    Returns:
        Bool, True if the auto num worker feature is turned on.

    Examples:
        >>> ds.config.get_auto_num_workers()
    """
    return _config.get_auto_num_workers()


def set_callback_timeout(timeout):
"""
Set the default timeout (in seconds) for DSWaitedCallback.


+ 0
- 22
tests/ut/cpp/dataset/c_api_dataset_ops_test.cc View File

@@ -1771,25 +1771,3 @@ TEST_F(MindDataTestPipeline, TestZipSuccess2) {
// Manually terminate the pipeline
iter->Stop();
}

#if !defined(_WIN32) && !defined(_WIN64)
#ifndef ENABLE_ANDROID
TEST_F(MindDataTestPipeline, TestNumWorkersValidate) {
  // Verify that SetNumWorkers rejects out-of-range values.
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestNumWorkersValidate.";

  // Build an ImageFolder dataset to exercise SetNumWorkers on.
  std::string folder_path = datasets_root_path_ + "/testPK/data/";
  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 9));
  EXPECT_NE(ds, nullptr);

  // A negative worker count (num_workers=-1) must be rejected.
  std::shared_ptr<Dataset> neg_workers = ds->SetNumWorkers(-1);
  EXPECT_EQ(neg_workers, nullptr);

  // A worker count above the CPU count must be rejected as well.
  std::shared_ptr<Dataset> huge_workers = ds->SetNumWorkers(UINT32_MAX);
  EXPECT_EQ(huge_workers, nullptr);
}
#endif
#endif

+ 33
- 0
tests/ut/cpp/dataset/optimization_pass_test.cc View File

@@ -22,6 +22,7 @@

#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/opt/post/auto_worker_pass.h"
#include "minddata/dataset/engine/opt/pre/getter_pass.h"

using namespace mindspore::dataset;
@@ -136,3 +137,35 @@ TEST_F(MindDataTestOptimizationPass, MindDataTestDatasetSizePass) {
EXPECT_NE(ss_str.find("ProjectOp"), ss_str.npos);
EXPECT_NE(ss_str.find("BatchOp"), ss_str.npos);
}

TEST_F(MindDataTestOptimizationPass, MindDataTestAutoWorkerPass) {
  MS_LOG(INFO) << "Doing MindDataTestOptimizationPass-MindDataTestAutoWorkerPass.";

  // Build a small pipeline covering both leaf families plus batch and map, with every
  // op's num_workers preset to 0 so the pass's effect is observable afterwards.
  std::shared_ptr<SchemaObj> schema = std::make_shared<SchemaObj>();
  ASSERT_TRUE(schema->add_column("label", "uint32", {}));
  std::shared_ptr<Dataset> map_leaf = ImageFolder("dir")->SetNumWorkers(0);
  std::shared_ptr<Dataset> nonmap_leaf = RandomData(44, schema)->SetNumWorkers(0);
  std::shared_ptr<Dataset> batch = Zip({map_leaf, nonmap_leaf})->Batch(1)->SetNumWorkers(0);
  std::shared_ptr<Dataset> map = batch->Map({})->SetNumWorkers(0);
  // pipeline shape: {ImageFolder, RandomData} -> Zip -> Batch -> Map; sanity-check the presets
  EXPECT_EQ(map_leaf->IRNode()->num_workers(), 0);
  EXPECT_EQ(nonmap_leaf->IRNode()->num_workers(), 0);
  EXPECT_EQ(batch->IRNode()->num_workers(), 0);
  EXPECT_EQ(map->IRNode()->num_workers(), 0);

  // Run the pass directly on the IR tree rooted at the map node.
  std::unique_ptr<IRPass> pass = std::make_unique<AutoWorkerPass>();
  bool m = false;
  ASSERT_OK(pass->Run(map->IRNode(), &m));

  // checking that after this pass, num_workers are set correctly (aka a positive number)
  // It is hard to test an exact value because num_threads differ between machines;
  // however, this will for sure succeed because regardless of the total threads on the CPU,
  // each op's num_workers would always be >= 1
  EXPECT_NE(map_leaf->IRNode()->num_workers(), 0);
  EXPECT_NE(nonmap_leaf->IRNode()->num_workers(), 0);
  EXPECT_NE(batch->IRNode()->num_workers(), 0);
  EXPECT_NE(map->IRNode()->num_workers(), 0);
  MS_LOG(DEBUG) << map_leaf->IRNode()->Name() << ": num_worker=" << map_leaf->IRNode()->num_workers();
  MS_LOG(DEBUG) << nonmap_leaf->IRNode()->Name() << ": num_worker=" << nonmap_leaf->IRNode()->num_workers();
  MS_LOG(DEBUG) << batch->IRNode()->Name() << ": num_worker=" << batch->IRNode()->num_workers();
  MS_LOG(DEBUG) << map->IRNode()->Name() << ": num_worker=" << map->IRNode()->num_workers();
}

+ 31
- 0
tests/ut/python/dataset/test_config.py View File

@@ -357,6 +357,35 @@ def test_deterministic_python_seed_multi_thread():
ds.config.set_seed(seed_original)


def test_auto_num_workers_error():
    """
    Test that set_auto_num_workers rejects a non-bool argument with ValueError.
    """
    message = ""
    try:
        ds.config.set_auto_num_workers([1, 2])
    except ValueError as err:
        message = str(err)

    assert "isn't of type bool" in message


def test_auto_num_workers():
    """
    Test that auto_num_workers can be toggled and read back consistently.
    """
    original = ds.config.get_auto_num_workers()
    assert isinstance(original, bool)
    # switch to the opposite setting and verify it took effect
    ds.config.set_auto_num_workers(not original)
    assert ds.config.get_auto_num_workers() == (not original)
    # restore the saved setting and verify the round trip
    ds.config.set_auto_num_workers(original)
    assert ds.config.get_auto_num_workers() == original


if __name__ == '__main__':
test_basic()
test_get_seed()
@@ -367,3 +396,5 @@ if __name__ == '__main__':
test_deterministic_run_distribution()
test_deterministic_python_seed()
test_deterministic_python_seed_multi_thread()
test_auto_num_workers_error()
test_auto_num_workers()

Loading…
Cancel
Save