diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc
index 5ec1270fd1..e30e1bbaa0 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc
@@ -389,8 +389,8 @@ Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBat
     TensorRow output_batch;
     // If user returns a type that is neither a list nor an array, issue a error msg.
     if (py::isinstance<py::array>(ret_tuple[i])) {
-      MS_LOG(WARNING) << "column: " << out_col_names_[i]
-                      << " returned by per_batch_map is a np.array. Please use list instead.";
+      MS_LOG(INFO) << "column: " << out_col_names_[i]
+                   << " returned by per_batch_map is a np.array. Please use list instead.";
     } else if (!py::isinstance<py::list>(ret_tuple[i])) {
       MS_LOG(ERROR) << "column: " << out_col_names_[i]
                     << " returned by per_batch_map is not a list, this could lead to conversion failure.";
diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc
index dbb8989619..8a6cab1487 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc
@@ -420,10 +420,13 @@ Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz
   }
 }
 
 Status DatasetNode::ValidateParams() {
+  int32_t num_threads = GlobalContext::config_manager()->num_cpu_threads();
+  // in case std::thread::hardware_concurrency returns 0, use an artificial upper limit
+  num_threads = num_threads > 0 ? num_threads : std::numeric_limits<int32_t>::max();
   CHECK_FAIL_RETURN_UNEXPECTED(
-    num_workers_ > 0 && num_workers_ < std::numeric_limits<int32_t>::max(),
-    Name() + "'s num_workers=" + std::to_string(num_workers_) + ", this value is less than 1 or too large.");
-
+    num_workers_ > 0 && num_workers_ <= num_threads,
+    Name() + "'s num_workers=" + std::to_string(num_workers_) +
+      ", this value is not within the required range of [1, cpu_thread_cnt=" + std::to_string(num_threads) + "].");
   return Status::OK();
 }
diff --git a/mindspore/ccsrc/minddata/dataset/engine/runtime_context.cc b/mindspore/ccsrc/minddata/dataset/engine/runtime_context.cc
index 2c6f62a5a2..8b4d7bd588 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/runtime_context.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/runtime_context.cc
@@ -23,16 +23,16 @@ void RuntimeContext::AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer)
   tree_consumer_ = std::move(tree_consumer);
 }
 
 Status NativeRuntimeContext::Terminate() {
-  MS_LOG(INFO) << "Terminating a NativeRuntime";
+  MS_LOG(INFO) << "Terminating a NativeRuntime.";
   if (tree_consumer_ != nullptr) {
     return TerminateImpl();
   }
-  MS_LOG(WARNING) << "TreeConsumer was not initialized";
+  MS_LOG(WARNING) << "TreeConsumer was not initialized.";
   return Status::OK();
 }
 
 Status NativeRuntimeContext::TerminateImpl() {
-  CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized");
+  CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " TreeConsumer is not initialized.");
   return tree_consumer_->Terminate();
 }
diff --git a/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc b/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc
index 880f715cd9..6b74674a2f 100644
--- a/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc
+++ b/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc
@@ -1839,7 +1839,7 @@ TEST_F(MindDataTestPipeline, TestNumWorkersValidate) {
 
   // Create an ImageFolder Dataset
   std::string folder_path = datasets_root_path_ + "/testPK/data/";
-  std::shared_ptr<Dataset> ds = ImageFolder(folder_path);
+  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, SequentialSampler(0, 1));
 
   // ds needs to be non nullptr otherwise, the subsequent logic will core dump
   ASSERT_NE(ds, nullptr);
@@ -1849,4 +1849,14 @@ TEST_F(MindDataTestPipeline, TestNumWorkersValidate) {
 
   // test if set num_workers can be very large
   EXPECT_EQ(ds->SetNumWorkers(INT32_MAX)->CreateIterator(), nullptr);
-}
\ No newline at end of file
+
+  int32_t cpu_core_cnt = GlobalContext::config_manager()->num_cpu_threads();
+
+  // only do this test if cpu_core_cnt can be successfully obtained
+  if (cpu_core_cnt > 0) {
+    EXPECT_EQ(ds->SetNumWorkers(cpu_core_cnt + 1)->CreateIterator(), nullptr);
+    // verify setting num_worker to 1 or cpu_core_cnt is allowed
+    ASSERT_OK(ds->SetNumWorkers(cpu_core_cnt)->IRNode()->ValidateParams());
+    ASSERT_OK(ds->SetNumWorkers(1)->IRNode()->ValidateParams());
+  }
+}
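Note (not part of the patch): the sketch below shows, in isolation, the shape of the range check that DatasetNode::ValidateParams() now performs. It is an approximation under stated assumptions: ValidateNumWorkers is a hypothetical helper, and std::thread::hardware_concurrency() stands in for GlobalContext::config_manager()->num_cpu_threads(), which is what the patch actually consults. The point it demonstrates is the fallback to an artificial upper limit when the CPU thread count cannot be determined (a return value of 0).

// Standalone illustration only; ValidateNumWorkers is a hypothetical helper, not MindSpore code.
#include <cstdint>
#include <iostream>
#include <limits>
#include <string>
#include <thread>

bool ValidateNumWorkers(int32_t num_workers, std::string *err_msg) {
  // Same shape as the patched check: if the CPU thread count cannot be determined
  // (hardware_concurrency() == 0), fall back to an artificial upper limit so that
  // any positive num_workers is still accepted.
  int32_t num_threads = static_cast<int32_t>(std::thread::hardware_concurrency());
  num_threads = num_threads > 0 ? num_threads : std::numeric_limits<int32_t>::max();
  if (num_workers > 0 && num_workers <= num_threads) {
    return true;
  }
  *err_msg = "num_workers=" + std::to_string(num_workers) +
             ", this value is not within the required range of [1, cpu_thread_cnt=" +
             std::to_string(num_threads) + "].";
  return false;
}

int main() {
  const int32_t candidates[] = {0, 1, 8, std::numeric_limits<int32_t>::max()};
  std::string err;
  for (int32_t candidate : candidates) {
    if (ValidateNumWorkers(candidate, &err)) {
      std::cout << "num_workers=" << candidate << " accepted" << std::endl;
    } else {
      std::cout << err << std::endl;
    }
  }
  return 0;
}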