From cd94518769ba56638a02ccfa8fe1a56ae399dd8d Mon Sep 17 00:00:00 2001
From: eric
Date: Thu, 2 Apr 2020 18:58:41 -0400
Subject: [PATCH] python interface - sync_wait (!1)

Initial commit for dataset op python
Added signature to barrier
Added compiling barrier code
Rebasing, fixed new compile errors
Final fix for make_unique
Added pybind API for barrier
Fixed pyfunc invocation
Added sync_wait python interface; fixed and updated tests
Added new test cases, including shuffle + batch
Added two-sync test case
Restricted that no shuffle is allowed after sync
Added sync to pipeline info; block first databuffer as well
Intelligently get batch size; fix default case
Lock pair shared among all iterators
Added fix for empty character
Fixed up test case formatting
Fix end of epoch in sync_wait
Fixing CI
---
 mindspore/ccsrc/dataset/api/de_pipeline.cc | 25 ++
 mindspore/ccsrc/dataset/api/de_pipeline.h | 3 +
 .../ccsrc/dataset/api/python_bindings.cc | 1 +
 mindspore/ccsrc/dataset/core/client.h | 1 +
 .../dataset/engine/datasetops/CMakeLists.txt | 1 +
 .../dataset/engine/datasetops/barrier_op.cc | 235 ++++++++++++++++++
 .../dataset/engine/datasetops/barrier_op.h | 172 +++++++++++++
 .../ccsrc/dataset/engine/datasetops/zip_op.h | 18 +-
 mindspore/dataset/engine/datasets.py | 184 +++++++++++++-
 mindspore/dataset/engine/iterators.py | 2 +
 mindspore/dataset/engine/validators.py | 16 ++
 tests/ut/python/dataset/test_config.py | 38 +++
 tests/ut/python/dataset/test_sync_wait.py | 182 ++++++++++++++
 13 files changed, 868 insertions(+), 10 deletions(-)
 create mode 100644 mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc
 create mode 100644 mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h
 create mode 100644 tests/ut/python/dataset/test_sync_wait.py

diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc
index a02d995147..c3dfeafe48 100644
--- a/mindspore/ccsrc/dataset/api/de_pipeline.cc
+++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc
@@ -48,6 +48,7 @@ static std::unordered_map g_parse_op_func_ = {{kStorage, &D
     {kMap, &DEPipeline::ParseMapOp},
     {kFilter, &DEPipeline::ParseFilterOp},
     {kBatch, &DEPipeline::ParseBatchOp},
+    {kBarrier, &DEPipeline::ParseBarrierOp},
     {kRepeat, &DEPipeline::ParseRepeatOp},
     {kSkip, &DEPipeline::ParseSkipOp},
     {kZip, &DEPipeline::ParseZipOp},
@@ -627,6 +628,30 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr
   return Status::OK();
 }
 
+Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr *ptr) {
+  std::shared_ptr builder = std::make_shared();
+  // Right now barrier should only take num_rows_per_buffer = 1
+  // The reason is that other values can lead to blocking issues
+  // See barrier_op.h for more details
+  (void)builder->SetRowsPerBuffer(1);
+  for (auto arg : args) {
+    std::string key = py::str(arg.first);
+    py::handle value = arg.second;
+    if (!value.is_none()) {
+      if (key == "condition_name") {
+        (void)builder->SetConditionName(ToString(value));
+      } else if (key == "condition_func") {
+        (void)builder->SetConditionFunc(value.cast());
+      }
+    }
+  }
+
+  std::shared_ptr op;
+  RETURN_IF_NOT_OK(builder->Build(&op));
+  *ptr = op;
+  return Status::OK();
+}
+
 Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *ptr) {
   int32_t prefetch_size = 0;
   if (args.contains("prefetch_size")) {
diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.h b/mindspore/ccsrc/dataset/api/de_pipeline.h
index 25919afe58..7f9c6c459a
100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.h +++ b/mindspore/ccsrc/dataset/api/de_pipeline.h @@ -40,6 +40,7 @@ enum OpName { kShuffle, kMindrecord, kBatch, + kBarrier, kCache, kRepeat, kSkip, @@ -115,6 +116,8 @@ class DEPipeline { Status ParseBatchOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseBarrierOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseGeneratorOp(const py::dict &args, std::shared_ptr *ptr); Status ParseRenameOp(const py::dict &args, std::shared_ptr *ptr); diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 9865396a7d..2b8ce4e896 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -476,6 +476,7 @@ PYBIND11_MODULE(_c_dataengine, m) { .value("STORAGE", OpName::kStorage) .value("SHUFFLE", OpName::kShuffle) .value("BATCH", OpName::kBatch) + .value("BARRIER", OpName::kBarrier) .value("MINDRECORD", OpName::kMindrecord) .value("CACHE", OpName::kCache) .value("REPEAT", OpName::kRepeat) diff --git a/mindspore/ccsrc/dataset/core/client.h b/mindspore/ccsrc/dataset/core/client.h index 15064dee6b..40de887aea 100644 --- a/mindspore/ccsrc/dataset/core/client.h +++ b/mindspore/ccsrc/dataset/core/client.h @@ -25,6 +25,7 @@ #include "dataset/core/tensor_shape.h" #include "dataset/engine/data_schema.h" #include "dataset/engine/dataset_iterator.h" +#include "dataset/engine/datasetops/barrier_op.h" #include "dataset/engine/datasetops/batch_op.h" #include "dataset/engine/datasetops/dataset_op.h" #include "dataset/engine/datasetops/device_queue_op.h" diff --git a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt index 7de62d9d11..9e8272d513 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt @@ -4,6 +4,7 @@ add_library(engine-datasetops OBJECT dataset_op.cc parallel_op.cc pipeline_op.cc + barrier_op.cc batch_op.cc device_queue_op.cc map_op.cc diff --git a/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc new file mode 100644 index 0000000000..b0ea7dbd07 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc @@ -0,0 +1,235 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "dataset/engine/datasetops/barrier_op.h" +#include +#include "dataset/core/constants.h" +#include "dataset/engine/data_buffer.h" +#include "dataset/engine/db_connector.h" +#include "dataset/core/config_manager.h" +#include "dataset/core/global_context.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +BarrierOp::Builder::Builder() { + // Some arguments to the BarrierOp constructor have a default argument that is taken + // from the client config. 
+ // The user may choose to change these values for the construction of the BarrierOp by + // using the various builder set methods. + + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_rows_per_buffer_ = cfg->rows_per_buffer(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status BarrierOp::Builder::SanityCheck() const { return Status::OK(); } + +Status BarrierOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(builder_rows_per_buffer_, builder_op_connector_size_, builder_condition_name_, + builder_condition_func_); + return Status::OK(); +} + +// Construct BarrierOp here, local variables initialized in operator due to tree construction restrictions +BarrierOp::BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name, + py::function condition_func) + : PipelineOp(op_connector_size), + rows_per_buffer_(rows_per_buffer), + buffer_id_(0), + clean_up_(false), + eof_(false), + condition_name_(condition_name), + condition_function_(condition_func) {} + +// destructor +BarrierOp::~BarrierOp() {} + +// Entry point for Barrier, called by launch() +Status BarrierOp::operator()() { + // The children_num_ parameter needs to be put here + // Synchronize with TaskManager once the thread is created. + TaskManager::FindMe()->Post(); + + // create child iterator, right now this barrier is a pipeline operator + int32_t worker_id = 0; + int32_t child_idx = 0; + child_iterator_ = std::make_unique(this, worker_id, child_idx); + + // Loop until eof is true + while (!eof_) { + // Create new table to put the new tensor rows + std::unique_ptr curr_table = std::make_unique(); + RETURN_IF_NOT_OK(prepare(curr_table.get())); + + // If an eof got picked up during the above prepare, then we're done + if (eof_) { + break; + } + + // we have to output new buffer with possibly different buffer size, possibly one row + while (!clean_up_) { + // 1. If a previous loop iteration sent the current table out, then create a new one. + + if (curr_table == nullptr) { + curr_table = std::make_unique(); + } + + // 2 fill the table. Note: clean_up mode might get turned on if epoch is finished + RETURN_IF_NOT_OK(fillBuffer(curr_table.get())); + + // 3 create and update buffer and send it to the out connector + if (!curr_table->empty()) { + std::unique_ptr curr_buffer = std::make_unique(buffer_id_, DataBuffer::kDeBFlagNone); + curr_buffer->set_tensor_table(std::move(curr_table)); + curr_buffer->set_column_name_map(col_name_id_map_); + MS_LOG(DEBUG) << "Barrier operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols " + << curr_buffer->NumCols() << ", map " << col_name_id_map_.size() << "."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); + buffer_id_++; + } + } + + // 4 handle drain state. + if (clean_up_) { + MS_LOG(DEBUG) << "Barrier operator sending epoch ending signal."; + // Send the eoe up. + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); + } + } + // 5 handle eof + // propagate eof here. 
+ MS_LOG(INFO) << "Barrier operator got EOF, propagating."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); + return Status::OK(); +} + +// Handles preprocessing of the main loop, used when starting new epoch +Status BarrierOp::prepare(TensorQTable *const table) { + MS_LOG(DEBUG) << "Barrier operator prepares for new epoch."; + clean_up_ = false; + buffer_id_ = 0; + if (table == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp prepare phase requires a tensor table."); + } + // fill initial row + TensorRow new_row = {}; + // use iterator to get next row and invoke pyfunc wait + RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); + + // If the first row fetching resulted in eof, then we are done. + if (eof_) { + return Status::OK(); + } + if (new_row.empty()) { + // This epoch is empty + return Status::OK(); + } + // Pack this first row into our tensor table + // first row we also have to check if we should block + RETURN_IF_NOT_OK(blockCond()); + + table->push_back(std::move(new_row)); + // At this point we have 1 row produced, we take the old column map id and use it in the new table + // Initializing col_name_id_map_ from the first data buffer. + col_name_id_map_ = child_iterator_->col_name_id_map(); + // the update code below shouldn't do anything bad if the column name already exists. + return Status::OK(); +} + +// fillBuffer always expects a new table to fill +Status BarrierOp::fillBuffer(TensorQTable *const table) { + if (table == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer."); + } + TensorRow new_row = {}; + while (table->size() < static_cast(rows_per_buffer_)) { + RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); + // Early exit the loop if we got empty row from any of our child iterations + if (new_row.empty()) { + return Status::OK(); + } + // else we got a row so pack it into the tensor table. + RETURN_IF_NOT_OK(blockCond()); + + table->push_back(std::move(new_row)); + } + return Status::OK(); +} + +// function executes a py_func and blocks until condition becomes true. +Status BarrierOp::blockCond() { + { + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + // we have condition name, however the flexibility is in python today + try { + // Invoke python function + py::object ret_py_obj = condition_function_(); + // Process the return value + if (!py::isinstance(ret_py_obj)) { + return Status(StatusCode::kPyFuncException, "Condition wait function should return true/false"); + } + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } + } + return Status::OK(); +} + +// fetches next Barrier buffer row +Status BarrierOp::getNextTensorRow(TensorRow *new_row) { + // iterate over all iterators and generate a row + RETURN_IF_NOT_OK((child_iterator_)->FetchNextTensorRow(new_row)); + // add each new row to iterator, check if row is empty, if row from iterator is empty return empty row + if (new_row->empty()) { + // If we did not get a row from any of the children, then it's the end of an epoch and we can move + // to drain state. + MS_LOG(INFO) << "Barrier operator child iterator produced empty row."; + clean_up_ = true; + // If we picked up an eof here, then we are completely done. 
+ if ((child_iterator_)->eof_handled()) { + MS_LOG(INFO) << "Barrier operator iterator got EOF."; + eof_ = true; + } + return Status::OK(); + } + return Status::OK(); +} + +// A function that prints info about the Operator +void BarrierOp::Print(std::ostream &out, bool show_all) const { + // Call base class printer first + PipelineOp::Print(out, show_all); + out << "\nBarrierOp:\n" + << "\nCondition " << condition_name_ << "\n\n"; +} + +// overwrite function and handle eof +Status BarrierOp::EofReceived(int32_t) { + MS_LOG(DEBUG) << "Barrier operator EOF received, do nothing now."; + return Status::OK(); +} + +// overwrite function and handle eoe +Status BarrierOp::EoeReceived(int32_t) { + state_ = OpState::kDeOpIdle; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h b/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h new file mode 100644 index 0000000000..8be55fba7e --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h @@ -0,0 +1,172 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ +#define DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ + +#include +#include +#include +#include +#include +#include "dataset/core/tensor.h" +#include "dataset/engine/dataset_iterator.h" +#include "dataset/engine/datasetops/pipeline_op.h" +#include "dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +// Forward declare +class DataBuffer; +class ExecutionTree; + +// BarrierOp class implements the Barrier operator. It will block sending of rows until a signal has +// been received. This signal is given from python layer. The current barrier design respects the +// rows per buffer design and will only output a buffer with rows once it has received rows per buffer +// signals from python. + +class BarrierOp : public PipelineOp { + public: + // The nested builder class inside of the BarrierOp is used to help manage all of + // the arguments for constructing it. Use the builder by setting each argument + // with the provided set methods, and then finally call the build method to execute + // the actual construction. + + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method. + // @param int32_t op_connector_size + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // Setter method. 
+ // @param const std::string & condition_name + // @return Builder setter method returns reference to the builder. + Builder &SetConditionName(const std::string &condition_name) { + builder_condition_name_ = condition_name; + return *this; + } + + // Setter method. + // @param py::function condition_func - blocking condition function + // @return Builder setter method returns reference to the builder. + Builder &SetConditionFunc(py::function condition_func) { + builder_condition_func_ = condition_func; + return *this; + } + + // The builder "build" method creates the BarrierOp dataset Operator. + // @return shared_ptr to the new BarrierOp object + Status Build(std::shared_ptr *); + + private: + int32_t builder_rows_per_buffer_; + int32_t builder_op_connector_size_; + std::string builder_condition_name_; + py::function builder_condition_func_; + + Status SanityCheck() const; + }; + + // Constructor for BarrierOp + // @param rows_per_buffer - number of rows in output buffer + // @param op_connector_size - connector size + // @param condition_name - the condition name associated with this operator + // @param condition_func - the blocking condition check per row + // @note - currently rows_per_buffer should = 1 for barrier. + // The reason for this is having other values would complicate how the pipeline behaves with other operators + // One example of such case is having batch after barrier. Batch would be waiting for data and having + // rows per buffer in this case can result in hanging + BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name, + py::function condition_func); + + // Destructor + ~BarrierOp(); + + Status EofReceived(int32_t) override; + + Status EoeReceived(int32_t) override; + + // Print function for Barrier + // @param out - output stream to print to + // @param show_all - if it should print everything + void Print(std::ostream &out, bool show_all) const override; + + // Provide stream operator for displaying it + friend std::ostream &operator<<(std::ostream &out, const BarrierOp &bo) { + bo.Print(out, false); + return out; + } + + // Class functor operator () override. + // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status - The error code return + Status operator()() override; + + // Handles preprocessing of the main loop, used when starting new epoch + // @param table - a table of tensors to be moved into a buffer + Status prepare(TensorQTable *const table); + + // This function calls takes a table repeatedly adds rows to it. 
+ // @param table - a table of tensors to be moved into a buffer + Status fillBuffer(TensorQTable *const table); + + // Gets next tensor row and sets control signals + Status getNextTensorRow(TensorRow *new_row); + + // This function runs the wait function on condition + Status blockCond(); + + private: + // clean up variable to return imcomplete buffer + bool clean_up_; + // end of file state, we stop reading data and shut down + bool eof_; + // rows per buffer + int32_t rows_per_buffer_; + // buffer_id + int32_t buffer_id_; + // local variable to keep track of the buffer information + std::unordered_map col_name_id_map_; + // iterator to pull new rows, we only have one child + std::unique_ptr child_iterator_; + // condition name, to support multiple barriers + std::string condition_name_; + // Function pointer of blocking function + py::function condition_function_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h b/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h index f14ecba733..04d8ab0121 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h @@ -34,7 +34,7 @@ class DataBuffer; class ZipOp : public PipelineOp { public: - // The nested builder class inside of the BatchOp is used to help manage all of + // The nested builder class inside of the ZipOp is used to help manage all of // the arguments for constructing it. Use the builder by setting each argument // with the provided set methods, and then finally call the build method to execute // the actual construction. @@ -76,8 +76,8 @@ class ZipOp : public PipelineOp { }; // Constructor for ZipOp - // @param rows_per_buffer number of rows in output buffer - // @param op_connector_size connector + // @param rows_per_buffer - number of rows in output buffer + // @param op_connector_size - connector size ZipOp(int32_t rows_per_buffer, int32_t op_connector_size); // Destructor @@ -88,8 +88,8 @@ class ZipOp : public PipelineOp { Status EoeReceived(int32_t) override; // Print function for Zip - // @param out output stream to print to - // @param show_all if it should print everything + // @param out - output stream to print to + // @param show_all - if it should print everything void Print(std::ostream &out, bool show_all) const override; // Provide stream operator for displaying it @@ -113,14 +113,14 @@ class ZipOp : public PipelineOp { Status fillBuffer(TensorQTable *const table); // Special handle case where an empty row has been received from child iterator - // @note we need to drain eoe signals from all children connectors. - // @details when this function is called, then we encountered eoe at child iterator + // @note - we need to drain eoe signals from all children connectors. 
+  // @details - when this function is called, then we encountered eoe at child iterator
+  // we have to drain rows from other child iterators until we hit eoe from all other child iterators
   Status drainPipeline();
 
   // Merges 1 row from each childIterator together
-  // @param new_zip_row input and output, will return a non-empty row if all rows from childConnectors are non-empty
-  // @param updateColumnMapping generates a new column name to index mapping (mColNameIdMap) if set to true
+  // @param new_zip_row - input and output, will be a non-empty row if all rows from childConnectors are non-empty
+  // @param updateColumnMapping - generates a new column name to index mapping (mColNameIdMap) if set to true
   // @details merge rows from iterator together. This is the main functionality for ZipOp
   // this function takes one row and fills it with tensors from rows fetched
   // from childIterators.
diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index 855e4609bb..f67461eee3 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -26,6 +26,7 @@ import random
 import uuid
 from enum import Enum
 from importlib import import_module
+import threading
 
 import numpy as np
 from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \
@@ -38,7 +39,7 @@ from .iterators import DictIterator, TupleIterator
 from .validators import check, check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, check_rename, \
     check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
     check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
-    check_zip_dataset, check_add_column, check_textfiledataset
+    check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset
 from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
 
 try:
@@ -139,6 +140,7 @@ class Dataset:
         self._batch_size = None
         self._num_classes = None
         self._repeat_count = None
+        self._sync = False
 
     def get_args(self):
         """
@@ -196,6 +198,30 @@ class Dataset:
         """
         return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns)
 
+    @check_sync_wait
+    def sync_wait(self, condition_name, num_batch=1, callback=None):
+        """
+        Add a blocking condition to the input Dataset.
+
+        Args:
+            condition_name (str): The condition name that is used to toggle sending next row.
+            num_batch (int): The number of batches without blocking at the start of each epoch.
+            callback (function): The callback function that will be invoked when sync_update is called.
+
+        Raises:
+            RuntimeError: If condition name already exists.
+
+        Examples:
+            >>> import mindspore.dataset as ds
+            >>> # data is an instance of Dataset object.
+            >>> data = data.sync_wait("callback1")
+            >>> data = data.batch(batch_size)
+            >>> for batch_data in data.create_dict_iterator():
+            >>>     data.sync_update("callback1")
+        """
+        return SyncWaitDataset(self, condition_name, num_batch, callback)
+
     @check_shuffle
     def shuffle(self, buffer_size):
         """
@@ -218,6 +244,9 @@
         Returns:
             ShuffleDataset, dataset shuffled.
 
+        Raises:
+            RuntimeError: If a sync operator exists before shuffle.
+
         Examples:
             >>> import mindspore.dataset as ds
             >>> # data is an instance of Dataset object
@@ -816,6 +845,9 @@ class Dataset:
         self._input_indexs = value
 
     def _get_pipeline_info(self):
+        """
+        Gets pipeline information.
+        """
         device_iter = TupleIterator(self)
         self._output_shapes = device_iter.get_output_shapes()
         self._output_types = device_iter.get_output_types()
@@ -870,6 +902,30 @@ class Dataset:
             return self.input[0].num_classes()
         return None
 
+    def get_sync_notifiers(self):
+        if self.input:
+            return self.input[0].get_sync_notifiers()
+        return {}
+
+    def is_sync(self):
+        if self.input:
+            return self.input[0].is_sync()
+        return False
+
+    def sync_update(self, condition_name, num_batch=None, data=None):
+        """
+        condition_name (str): The condition name that is used to toggle sending next row.
+        num_batch (int or None): The number of batches (rows) that are released.
+            When num_batch is None, it will release the same number as sync_wait specified.
+        data (dict or None): The data passed to the callback.
+        """
+        notifiers_dict = self.get_sync_notifiers()
+        if condition_name not in notifiers_dict:
+            raise RuntimeError("Condition name not found")
+        if num_batch is not None:
+            num_batch *= self.get_batch_size()
+        notifiers_dict[condition_name](num_batch, data)
+
     def get_batch_size(self):
         """
         Get the size of a batch.
@@ -973,6 +1029,8 @@ class BatchDataset(DatasetOp):
         if BatchDataset._is_ancestor_of_repeat(input_dataset):
             logger.warning("Repeat is located before batch, data from two epochs can be batched together.")
 
+        BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
+
         self.batch_size = batch_size
         self.drop_remainder = drop_remainder
         self.per_batch_map = per_batch_map
@@ -1029,6 +1087,20 @@ class BatchDataset(DatasetOp):
             flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset)
         return flag
 
+    @staticmethod
+    def _update_batch_size_for_syncwait(dataset, batch_size):
+        """
+        Utility function to notify the batch size to sync_wait.
+
+        Args:
+            dataset (Dataset): dataset to be checked.
+            batch_size (int): batch size to notify.
+        """
+        if isinstance(dataset, SyncWaitDataset):
+            dataset.update_sync_batch_size(batch_size)
+        for input_dataset in dataset.input:
+            BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
+
 
 class BatchInfo(CBatchInfo):
     """
@@ -1053,6 +1125,108 @@ class BatchInfo(CBatchInfo):
         """
         return
 
+
+class BlockReleasePair:
+    """
+    The blocking condition class used by SyncWaitDataset.
+
+    Args:
+        init_release_rows (int): Number of rows to allow through the pipeline.
+        callback (function): The callback function that will be called when release is called.
+    """
+    def __init__(self, init_release_rows, callback=None):
+        self.row_count = -init_release_rows
+        self.cv = threading.Condition()
+        self.callback = callback
+        self.default_rows = init_release_rows
+
+    def __deepcopy__(self, memodict):
+        if id(self) in memodict:
+            return memodict[id(self)]
+        memodict[id(self)] = self
+        # condition variable and callback are the same, but reset the counter
+        self.reset()
+        return self
+
+    def reset(self):
+        with self.cv:
+            self.row_count = -self.default_rows
+            self.cv.notify_all()
+
+    def update_batched_size(self, batch_size):
+        # should only be used before the pipeline is created
+        self.row_count *= batch_size
+        self.default_rows *= batch_size
+
+    def block_func(self):
+        with self.cv:
+            self.cv.wait_for(lambda: self.row_count < 0)
+            self.row_count += 1
+        return True
+
+    def release_func(self, pass_rows=None, data=None):
+        with self.cv:
+            if pass_rows is None:
+                pass_rows = self.default_rows
+            self.row_count -= pass_rows
+            if self.callback is not None:
+                self.callback(data)
+            self.cv.notify_all()
+
+
+class SyncWaitDataset(DatasetOp):
+    """
+    The result of adding a blocking condition to the input Dataset.
+
+    Args:
+        input_dataset (Dataset): Input dataset to apply flow control.
+        condition_name (str): The condition name that is used to toggle sending next row.
+        num_batch (int): The number of batches without blocking at the start of each epoch.
+        callback (function): The callback function that will be invoked when sync_update is called.
+
+    Raises:
+        RuntimeError: If condition name already exists.
+    """
+
+    def __init__(self, input_dataset, condition_name, num_batch, callback=None):
+        super().__init__()
+        self.input.append(input_dataset)
+        input_dataset.output.append(self)
+        # set to the default value, waiting for the batch to update it
+        self._condition_name = condition_name
+        self._pair = BlockReleasePair(num_batch, callback)
+        if self._condition_name in self.input[0].get_sync_notifiers():
+            raise RuntimeError("Condition name is already in use")
+
+    def get_sync_notifiers(self):
+        return {**self.input[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}}
+
+    def is_sync(self):
+        return True
+
+    def get_args(self):
+        args = super().get_args()
+        args["condition_name"] = self._condition_name
+        args["condition_func"] = self._pair.block_func
+        return args
+
+    def update_sync_batch_size(self, batch_size):
+        self._pair.update_batched_size(batch_size)
+
+    @staticmethod
+    def _is_ancestor_of_batch(dataset):
+        """
+        Utility function to find the case where sync_wait is used before batch.
+
+        Args:
+            dataset (Dataset): dataset to be checked.
+
+        Returns:
+            True or False.
+        """
+        if isinstance(dataset, BatchDataset):
+            return True
+        flag = False
+        for input_dataset in dataset.input:
+            flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset)
+        return flag
 
 
 class ShuffleDataset(DatasetOp):
     """
@@ -1061,6 +1235,9 @@ class ShuffleDataset(DatasetOp):
     Args:
         input_dataset (Dataset): Input Dataset to be shuffled.
         buffer_size (int): The size of the buffer.
+
+    Raises:
+        RuntimeError: If a sync operator exists before shuffle.
     """
 
     def __init__(self, input_dataset, buffer_size):
@@ -1069,6 +1246,8 @@ class ShuffleDataset(DatasetOp):
         self.input.append(input_dataset)
         input_dataset.output.append(self)
         self._input_indexs = input_dataset.input_indexs
+        if self.is_sync():
+            raise RuntimeError("No shuffle after sync operators")
 
     def get_args(self):
         args = super().get_args()
@@ -1335,6 +1514,9 @@ class ZipDataset(DatasetOp):
         """
         return None
 
+    def is_sync(self):
+        return any([c.is_sync() for c in self.input])
+
     def get_args(self):
         args = super().get_args()
         return args
diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py
index 6af6c7dba8..a8d20df5f3 100644
--- a/mindspore/dataset/engine/iterators.py
+++ b/mindspore/dataset/engine/iterators.py
@@ -125,6 +125,8 @@ class Iterator:
             op_type = OpName.MINDRECORD
         elif isinstance(dataset, de.BatchDataset):
             op_type = OpName.BATCH
+        elif isinstance(dataset, de.SyncWaitDataset):
+            op_type = OpName.BARRIER
         elif isinstance(dataset, de.ZipDataset):
             op_type = OpName.ZIP
         elif isinstance(dataset, de.MapDataset):
diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py
index a68d723f1d..a8d18ab2c1 100644
--- a/mindspore/dataset/engine/validators.py
+++ b/mindspore/dataset/engine/validators.py
@@ -652,6 +652,22 @@ def check_batch(method):
     return new_method
 
 
+def check_sync_wait(method):
+    """check the input arguments of sync_wait."""
+    @wraps(method)
+    def new_method(*args, **kwargs):
+        param_dict = make_param_dict(method, args, kwargs)
+
+        nreq_param_str = ['condition_name']
+        nreq_param_int = ['num_batch']
+
+        check_param_type(nreq_param_int, param_dict, int)
+
+        check_param_type(nreq_param_str, param_dict, str)
+
+        return method(*args, **kwargs)
+
+    return new_method
 
 def check_shuffle(method):
     """check the input arguments of shuffle."""
diff --git a/tests/ut/python/dataset/test_config.py b/tests/ut/python/dataset/test_config.py
index 8cabe81aaa..0c1e0073af 100644
--- a/tests/ut/python/dataset/test_config.py
+++ b/tests/ut/python/dataset/test_config.py
@@ -12,8 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""
+Testing configuration manager
+"""
+import filecmp
+import glob
+import os
+
 import mindspore.dataset as ds
+import mindspore.dataset.transforms.vision.c_transforms as vision
+from mindspore import log as logger
+
+DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
+SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
 
 def test_basic():
     ds.config.load('../data/dataset/declient.cfg')
@@ -36,6 +46,34 @@ def test_basic():
     assert ds.config.get_prefetch_size() == 4
     assert ds.config.get_seed() == 5
 
+
+def test_pipeline():
+    """
+    Test that our configuration pipeline works when we set parameters at the dataset level
+    """
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+    ds.config.set_num_parallel_workers(2)
+    data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])
+    ds.serialize(data1, "testpipeline.json")
+
+    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+    ds.config.set_num_parallel_workers(4)
+    data2 = data2.map(input_columns=["image"], operations=[vision.Decode(True)])
+    ds.serialize(data2, "testpipeline2.json")
+
+    # the two serialized pipelines are currently identical because num_parallel_workers
+    # does not get updated in the serialized output, so this comparison passes
+    assert filecmp.cmp('testpipeline.json', 'testpipeline2.json')
+
+    # remove generated json files
+    file_list = glob.glob('*.json')
+    for f in file_list:
+        try:
+            os.remove(f)
+        except IOError:
+            logger.info("Error while deleting: {}".format(f))
+
 
 if __name__ == '__main__':
     test_basic()
+    test_pipeline()
diff --git a/tests/ut/python/dataset/test_sync_wait.py b/tests/ut/python/dataset/test_sync_wait.py
new file mode 100644
index 0000000000..277499d9ae
--- /dev/null
+++ b/tests/ut/python/dataset/test_sync_wait.py
@@ -0,0 +1,182 @@
+# Copyright 2019 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== + +import mindspore.dataset as ds +from mindspore import log as logger +import time +import numpy as np + + +def gen(): + for i in range(100): + yield np.array(i), + + +class Augment: + def __init__(self, loss): + self.loss = loss + + def preprocess(self, input): + return input + + def update(self, data): + self.loss = data["loss"] + + +def test_simple_sync_wait(): + """ + Test simple sync wait: test sync in dataset pipeline + """ + logger.info("test_simple_sync_wait") + batch_size = 4 + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + dataset = dataset.batch(batch_size) + + count = 0 + for data in dataset.create_dict_iterator(): + assert (data["input"][0] == count) + count += batch_size + data = {"loss": count} + dataset.sync_update(condition_name="policy", data=data) + + +def test_simple_shuffle_sync(): + """ + Test simple shuffle sync: test shuffle before sync + """ + logger.info("test_simple_shuffle_sync") + shuffle_size = 4 + batch_size = 10 + + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + dataset = dataset.shuffle(shuffle_size) + dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + dataset = dataset.batch(batch_size) + + count = 0 + for data in dataset.create_dict_iterator(): + count += 1 + #time.sleep(0.5) + data = {"loss": count} + dataset.sync_update(condition_name="policy", data=data) + + +def test_two_sync(): + """ + Test two sync: dataset pipeline with with two sync_operators + """ + logger.info("test_two_sync") + batch_size = 6 + + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + # notice that with our design, we need to have step_size = shuffle size + dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update) + + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + + dataset = dataset.sync_wait(num_batch=2, condition_name="every 2 batches") + + dataset = dataset.batch(batch_size) + + count = 0 + for data in dataset.create_dict_iterator(): + count += 1 + data = {"loss": count} + dataset.sync_update(condition_name="every batch", data=data) + if count % 2 == 0: + dataset.sync_update(condition_name="every 2 batches") + +def test_sync_epoch(): + """ + Test sync wait with epochs: test sync with epochs in dataset pipeline + """ + logger.info("test_sync_epoch") + batch_size = 30 + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + dataset = dataset.batch(batch_size, drop_remainder=True) + + for epochs in range(3): + aug.update({"loss": 0}) + count = 0 + for data in dataset.create_dict_iterator(): + assert (data["input"][0] == count) + count += batch_size + data = {"loss": count} + dataset.sync_update(condition_name="policy", data=data) + + +def test_sync_exception_01(): + """ + Test sync: with shuffle in sync mode + """ + logger.info("test_sync_exception_01") + shuffle_size = 4 + batch_size = 10 + + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + dataset = 
dataset.sync_wait(condition_name="policy", callback=aug.update) + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + + try: + dataset = dataset.shuffle(shuffle_size) + except BaseException as e: + assert "shuffle" in str(e) + dataset = dataset.batch(batch_size) + + +def test_sync_exception_02(): + """ + Test sync: with duplicated condition name + """ + logger.info("test_sync_exception_02") + batch_size = 6 + + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + # notice that with our design, we need to have step_size = shuffle size + dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update) + + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + + try: + dataset = dataset.sync_wait(num_batch=2, condition_name="every batch") + except BaseException as e: + assert "name" in str(e) + dataset = dataset.batch(batch_size) + + +if __name__ == "__main__": + test_simple_sync_wait() + test_simple_shuffle_sync() + test_two_sync() + test_sync_exception_01() + test_sync_exception_02() + test_sync_epoch() \ No newline at end of file
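A minimal end-to-end sketch of the sync_wait / sync_update flow introduced by this patch, mirroring test_simple_sync_wait above. The generator, condition name, callback, and loss value are illustrative only and are not part of the patch itself.

import numpy as np
import mindspore.dataset as ds

def gen():
    # toy source: 100 single-element rows
    for i in range(100):
        yield (np.array(i),)

def update_policy(data):
    # callback invoked by sync_update with the user-supplied dict
    print("received loss:", data["loss"])

dataset = ds.GeneratorDataset(gen, column_names=["input"])
# Let one batch through at the start of each epoch, then block until sync_update releases more rows.
dataset = dataset.sync_wait(condition_name="policy", callback=update_policy)
dataset = dataset.batch(4)

for item in dataset.create_dict_iterator():
    # ... consume item["input"], e.g. run a training step ...
    # Release the next batch and forward data to the callback.
    dataset.sync_update(condition_name="policy", data={"loss": 0.1})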