From 212eccbbae64290945eda1e42e3bcef51562c1cf Mon Sep 17 00:00:00 2001 From: Eric Date: Sun, 29 Nov 2020 22:09:25 -0500 Subject: [PATCH] Added fix for sync wait input check Fix review 1 Address comments 2 Fix review 3 --- mindspore/ccsrc/minddata/dataset/core/tensor.h | 1 - .../minddata/dataset/engine/datasetops/batch_op.cc | 4 ++-- .../dataset/engine/datasetops/parallel_op.cc | 12 ++---------- .../minddata/dataset/engine/datasetops/shuffle_op.cc | 4 ++-- .../minddata/dataset/kernels/data/data_utils.cc | 4 +++- .../dataset/kernels/data/random_choice_op.cc | 3 ++- mindspore/dataset/engine/datasets.py | 7 ++++--- 7 files changed, 15 insertions(+), 20 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor.h b/mindspore/ccsrc/minddata/dataset/core/tensor.h index a5d0c5b59d..12bebf8d8a 100644 --- a/mindspore/ccsrc/minddata/dataset/core/tensor.h +++ b/mindspore/ccsrc/minddata/dataset/core/tensor.h @@ -781,7 +781,6 @@ inline Status Tensor::CreateFromVector(const std::vectordata_end_ = (*out)->data_ + offset_arr[i]; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc index 9689f75ee7..fac8fb6f29 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc @@ -546,9 +546,9 @@ Status BatchOp::GetDatasetSize(int64_t *dataset_size) { RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows)); if (num_rows > 0 && start_batch_size_ > 0) { if (drop_) { - num_rows = floor(num_rows / (1.0 * start_batch_size_)); + num_rows = static_cast(floor(num_rows / (1.0 * start_batch_size_))); } else { - num_rows = ceil(num_rows / (1.0 * start_batch_size_)); + num_rows = static_cast(ceil(num_rows / (1.0 * start_batch_size_))); } } *dataset_size = num_rows; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc index 63921aec69..a92690c097 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc @@ -52,16 +52,8 @@ Status ParallelOp::CreateWorkerConnector(int32_t worker_connector_size) { // A print method typically used for debugging void ParallelOp::Print(std::ostream &out, bool show_all) const { - // Summary 1-liner print - if (!show_all) { - // Call super class printer - DatasetOp::Print(out, show_all); - out << " [workers: " << num_workers_ << "]"; - } else { - // Detailed print - DatasetOp::Print(out, show_all); - out << "\nNum workers: " << num_workers_; - } + DatasetOp::Print(out, show_all); + out << " [workers: " << num_workers_ << "]"; } // Override base class reset to provide reset actions specific to the ParallelOp class. diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.cc index fd4c81c2e7..e3139f5680 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.cc @@ -220,7 +220,7 @@ Status ShuffleOp::operator()() { } // Since we overloaded eoeReceived function, we are responsible to flow the EOE up the - // pipepline manually now that we are done draining the shuffle buffer + // pipeline manually now that we are done draining the shuffle buffer MS_LOG(DEBUG) << "Shuffle operator sending EOE."; auto eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); @@ -283,7 +283,7 @@ Status ShuffleOp::InitShuffleBuffer() { shuffle_buffer_state_ = kShuffleStateDrain; } - MS_LOG(DEBUG) << "Shuffle operator finished intializing the shuffle buffer."; + MS_LOG(DEBUG) << "Shuffle operator finished initializing the shuffle buffer."; return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc index e9250ea219..920903620e 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc @@ -111,7 +111,9 @@ Status OneHotEncoding(std::shared_ptr input, std::shared_ptr *ou *output = out; return Status::OK(); } catch (const std::exception &e) { - RETURN_STATUS_UNEXPECTED("Unexpected error in OneHotOp"); + std::string err_msg = "Unexpected error in OneHotOp: "; + err_msg += e.what(); + RETURN_STATUS_UNEXPECTED(err_msg); } } diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/random_choice_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/random_choice_op.cc index b9444a3298..bcb7855aea 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/random_choice_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/random_choice_op.cc @@ -89,7 +89,8 @@ RandomChoiceOp::RandomChoiceOp(const std::vector> &ops : ops_(ops), gen_(GetSeed()), rand_int_(0, ops.size() - 1) { if (ops_.empty()) { MS_LOG(ERROR) << "op_list in RandomChoiceOp is empty."; - } else if (ops_.size() == 1) { + } + if (ops_.size() == 1) { MS_LOG(WARNING) << "op_list has only 1 op, this op would be picked every time."; } is_deterministic_ = false; diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 2a9e4a4bc7..483f62e9ec 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -343,8 +343,8 @@ class Dataset: Add a blocking condition to the input Dataset. Args: - num_batch (int): the number of batches without blocking at the start of each epoch. condition_name (str): The condition name that is used to toggle sending next row. + num_batch (int): the number of batches without blocking at the start of each epoch. callback (function): The callback funciton that will be invoked when sync_update is called. Raises: @@ -1452,9 +1452,10 @@ class Dataset: num_batch (Union[int, None]): The number of batches (rows) that are released. When num_batch is None, it will default to the number specified by the sync_wait operator (default=None). - data (Union[dict, None]): The data passed to the callback (default=None). + data (Any): The data passed to the callback, user defined (default=None). """ - if isinstance(num_batch, int) and num_batch <= 0: + if (not isinstance(num_batch, int) and num_batch is not None) or \ + (isinstance(num_batch, int) and num_batch <= 0): # throwing exception, disable all sync_wait in pipeline self.disable_sync() raise RuntimeError("Sync_update batch size can only be positive, got : {}.".format(num_batch))