Merge pull request !508 from EricZ/master
@@ -48,6 +48,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
   {kMap, &DEPipeline::ParseMapOp},
   {kFilter, &DEPipeline::ParseFilterOp},
   {kBatch, &DEPipeline::ParseBatchOp},
+  {kBarrier, &DEPipeline::ParseBarrierOp},
   {kRepeat, &DEPipeline::ParseRepeatOp},
   {kSkip, &DEPipeline::ParseSkipOp},
   {kZip, &DEPipeline::ParseZipOp},
@@ -627,6 +628,30 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp>
   return Status::OK();
 }
+Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
+  std::shared_ptr<BarrierOp::Builder> builder = std::make_shared<BarrierOp::Builder>();
+  // Right now barrier should only take num_rows_per_buffer = 1,
+  // because other values can lead to blocking issues.
+  // See barrier_op.h for more details.
+  (void)builder->SetRowsPerBuffer(1);
+  for (auto arg : args) {
+    std::string key = py::str(arg.first);
+    py::handle value = arg.second;
+    if (!value.is_none()) {
+      if (key == "condition_name") {
+        (void)builder->SetConditionName(ToString(value));
+      } else if (key == "condition_func") {
+        (void)builder->SetConditionFunc(value.cast<py::function>());
+      }
+    }
+  }
+  std::shared_ptr<BarrierOp> op;
+  RETURN_IF_NOT_OK(builder->Build(&op));
+  *ptr = op;
+  return Status::OK();
+}
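For orientation, ParseBarrierOp consumes only two keys of the Python-side args dict; everything else is fixed (rows per buffer is pinned to 1). A purely illustrative sketch of that dict, with a trivial stand-in for the condition function (the real callable is installed by SyncWaitDataset, shown later in this PR):

```python
# Illustrative only: the keys mirror what ParseBarrierOp reads.
args = {
    "condition_name": "policy",       # -> BarrierOp::Builder::SetConditionName
    "condition_func": lambda: True,   # -> BarrierOp::Builder::SetConditionFunc (must return a bool)
}
```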
 Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
   int32_t prefetch_size = 0;
   if (args.contains("prefetch_size")) {
@@ -40,6 +40,7 @@ enum OpName {
   kShuffle,
   kMindrecord,
   kBatch,
+  kBarrier,
   kCache,
   kRepeat,
   kSkip,
@@ -115,6 +116,8 @@ class DEPipeline {
   Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
 
+  Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+
   Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
   Status ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
@@ -481,6 +481,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
     .value("STORAGE", OpName::kStorage)
     .value("SHUFFLE", OpName::kShuffle)
     .value("BATCH", OpName::kBatch)
+    .value("BARRIER", OpName::kBarrier)
     .value("MINDRECORD", OpName::kMindrecord)
     .value("CACHE", OpName::kCache)
     .value("REPEAT", OpName::kRepeat)
@@ -25,6 +25,7 @@
 #include "dataset/core/tensor_shape.h"
 #include "dataset/engine/data_schema.h"
 #include "dataset/engine/dataset_iterator.h"
+#include "dataset/engine/datasetops/barrier_op.h"
 #include "dataset/engine/datasetops/batch_op.h"
 #include "dataset/engine/datasetops/dataset_op.h"
 #include "dataset/engine/datasetops/device_queue_op.h"
@@ -4,6 +4,7 @@ add_library(engine-datasetops OBJECT
     dataset_op.cc
     parallel_op.cc
     pipeline_op.cc
+    barrier_op.cc
     batch_op.cc
     device_queue_op.cc
     map_op.cc
@@ -0,0 +1,235 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "dataset/engine/datasetops/barrier_op.h"
+#include <utility>
+#include "dataset/core/constants.h"
+#include "dataset/engine/data_buffer.h"
+#include "dataset/engine/db_connector.h"
+#include "dataset/core/config_manager.h"
+#include "dataset/core/global_context.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace dataset {
+BarrierOp::Builder::Builder() {
+  // Some arguments to the BarrierOp constructor have a default argument that is taken
+  // from the client config.
+  // The user may choose to change these values for the construction of the BarrierOp by
+  // using the various builder set methods.
+  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
+  builder_rows_per_buffer_ = cfg->rows_per_buffer();
+  builder_op_connector_size_ = cfg->op_connector_size();
+}
+
+Status BarrierOp::Builder::SanityCheck() const { return Status::OK(); }
+
+Status BarrierOp::Builder::Build(std::shared_ptr<BarrierOp> *ptr) {
+  RETURN_IF_NOT_OK(SanityCheck());
+  *ptr = std::make_shared<BarrierOp>(builder_rows_per_buffer_, builder_op_connector_size_, builder_condition_name_,
+                                     builder_condition_func_);
+  return Status::OK();
+}
+// Construct BarrierOp here, local variables initialized in operator due to tree construction restrictions
+BarrierOp::BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name,
+                     py::function condition_func)
+    : PipelineOp(op_connector_size),
+      rows_per_buffer_(rows_per_buffer),
+      buffer_id_(0),
+      clean_up_(false),
+      eof_(false),
+      condition_name_(condition_name),
+      condition_function_(condition_func) {}
+
+// destructor
+BarrierOp::~BarrierOp() {}
+
+// Entry point for Barrier, called by launch()
+Status BarrierOp::operator()() {
+  // The children_num_ parameter needs to be put here
+  // Synchronize with TaskManager once the thread is created.
+  TaskManager::FindMe()->Post();
+
+  // create child iterator, right now this barrier is a pipeline operator
+  int32_t worker_id = 0;
+  int32_t child_idx = 0;
+  child_iterator_ = std::make_unique<ChildIterator>(this, worker_id, child_idx);
+
+  // Loop until eof is true
+  while (!eof_) {
+    // Create new table to put the new tensor rows
+    std::unique_ptr<TensorQTable> curr_table = std::make_unique<TensorQTable>();
+    RETURN_IF_NOT_OK(prepare(curr_table.get()));
+
+    // If an eof got picked up during the above prepare, then we're done
+    if (eof_) {
+      break;
+    }
+
+    // we have to output new buffer with possibly different buffer size, possibly one row
+    while (!clean_up_) {
+      // 1. If a previous loop iteration sent the current table out, then create a new one.
+      if (curr_table == nullptr) {
+        curr_table = std::make_unique<TensorQTable>();
+      }
+
+      // 2. fill the table. Note: clean_up mode might get turned on if epoch is finished
+      RETURN_IF_NOT_OK(fillBuffer(curr_table.get()));
+
+      // 3. create and update buffer and send it to the out connector
+      if (!curr_table->empty()) {
+        std::unique_ptr<DataBuffer> curr_buffer = std::make_unique<DataBuffer>(buffer_id_, DataBuffer::kDeBFlagNone);
+        curr_buffer->set_tensor_table(std::move(curr_table));
+        curr_buffer->set_column_name_map(col_name_id_map_);
+        MS_LOG(DEBUG) << "Barrier operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols "
+                      << curr_buffer->NumCols() << ", map " << col_name_id_map_.size() << ".";
+        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
+        buffer_id_++;
+      }
+    }
+
+    // 4. handle drain state.
+    if (clean_up_) {
+      MS_LOG(DEBUG) << "Barrier operator sending epoch ending signal.";
+      // Send the eoe up.
+      RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
+    }
+  }
+
+  // 5. handle eof: propagate eof here.
+  MS_LOG(INFO) << "Barrier operator got EOF, propagating.";
+  RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
+  return Status::OK();
+}
+// Handles preprocessing of the main loop, used when starting new epoch
+Status BarrierOp::prepare(TensorQTable *const table) {
+  MS_LOG(DEBUG) << "Barrier operator prepares for new epoch.";
+  clean_up_ = false;
+  buffer_id_ = 0;
+  if (table == nullptr) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp prepare phase requires a tensor table.");
+  }
+  // fill initial row
+  TensorRow new_row = {};
+  // use iterator to get next row and invoke pyfunc wait
+  RETURN_IF_NOT_OK(getNextTensorRow(&new_row));
+
+  // If the first row fetching resulted in eof, then we are done.
+  if (eof_) {
+    return Status::OK();
+  }
+  if (new_row.empty()) {
+    // This epoch is empty
+    return Status::OK();
+  }
+  // Pack this first row into our tensor table
+  // for the first row we also have to check if we should block
+  RETURN_IF_NOT_OK(blockCond());
+
+  table->push_back(std::move(new_row));
+
+  // At this point we have 1 row produced, we take the old column map id and use it in the new table.
+  // Initializing col_name_id_map_ from the first data buffer.
+  col_name_id_map_ = child_iterator_->col_name_id_map();
+  return Status::OK();
+}
+// fillBuffer always expects a new table to fill
+Status BarrierOp::fillBuffer(TensorQTable *const table) {
+  if (table == nullptr) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer.");
+  }
+  TensorRow new_row = {};
+  while (table->size() < static_cast<size_t>(rows_per_buffer_)) {
+    RETURN_IF_NOT_OK(getNextTensorRow(&new_row));
+    // Early exit the loop if we got an empty row from any of our child iterations
+    if (new_row.empty()) {
+      return Status::OK();
+    }
+    // else we got a row so pack it into the tensor table.
+    RETURN_IF_NOT_OK(blockCond());
+    table->push_back(std::move(new_row));
+  }
+  return Status::OK();
+}
+// This function executes the py_func and blocks until the condition becomes true.
+Status BarrierOp::blockCond() {
+  {
+    py::gil_scoped_acquire gil_acquire;
+    if (Py_IsInitialized() == 0) {
+      return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
+    }
+    // we have a condition name, however the flexibility is in python today
+    try {
+      // Invoke python function
+      py::object ret_py_obj = condition_function_();
+      // Process the return value
+      if (!py::isinstance<py::bool_>(ret_py_obj)) {
+        return Status(StatusCode::kPyFuncException, "Condition wait function should return true/false");
+      }
+    } catch (const py::error_already_set &e) {
+      return Status(StatusCode::kPyFuncException, e.what());
+    }
+  }
+  return Status::OK();
+}
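From the Python side, the contract blockCond imposes on the condition function is small: a callable taking no arguments, which may block, and which must return a bool (anything else is reported as kPyFuncException). A hedged, stand-alone sketch of such a callable, using a plain threading.Event rather than the BlockReleasePair the dataset layer actually installs:

```python
import threading

_release = threading.Event()

def condition_func():
    """Example condition callable: blocks until released, then returns True."""
    _release.wait()   # another thread calls _release.set() to let the next row through
    return True

# Elsewhere (e.g. after a training step) the gate is opened:
# _release.set()
```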
+// fetches next Barrier buffer row
+Status BarrierOp::getNextTensorRow(TensorRow *new_row) {
+  // iterate over all iterators and generate a row
+  RETURN_IF_NOT_OK((child_iterator_)->FetchNextTensorRow(new_row));
+  // check if the row is empty; if the row from the iterator is empty, return an empty row
+  if (new_row->empty()) {
+    // If we did not get a row from any of the children, then it's the end of an epoch and we can move
+    // to drain state.
+    MS_LOG(INFO) << "Barrier operator child iterator produced empty row.";
+    clean_up_ = true;
+    // If we picked up an eof here, then we are completely done.
+    if ((child_iterator_)->eof_handled()) {
+      MS_LOG(INFO) << "Barrier operator iterator got EOF.";
+      eof_ = true;
+    }
+    return Status::OK();
+  }
+  return Status::OK();
+}
+
+// A function that prints info about the Operator
+void BarrierOp::Print(std::ostream &out, bool show_all) const {
+  // Call base class printer first
+  PipelineOp::Print(out, show_all);
+  out << "\nBarrierOp:\n"
+      << "\nCondition " << condition_name_ << "\n\n";
+}
+
+// override base-class function and handle eof
+Status BarrierOp::EofReceived(int32_t) {
+  MS_LOG(DEBUG) << "Barrier operator EOF received, do nothing now.";
+  return Status::OK();
+}
+
+// override base-class function and handle eoe
+Status BarrierOp::EoeReceived(int32_t) {
+  state_ = OpState::kDeOpIdle;
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
@@ -0,0 +1,172 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_
+#define DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+#include "dataset/core/tensor.h"
+#include "dataset/engine/dataset_iterator.h"
+#include "dataset/engine/datasetops/pipeline_op.h"
+#include "dataset/kernels/tensor_op.h"
+
+namespace mindspore {
+namespace dataset {
+// Forward declare
+class DataBuffer;
+class ExecutionTree;
+
+// BarrierOp class implements the Barrier operator. It blocks the sending of rows until a signal has
+// been received from the python layer. The current barrier design respects the rows-per-buffer
+// setting and will only output a buffer of rows once it has received rows_per_buffer signals
+// from python.
+class BarrierOp : public PipelineOp {
+ public:
+  // The nested builder class inside of the BarrierOp is used to help manage all of
+  // the arguments for constructing it. Use the builder by setting each argument
+  // with the provided set methods, and then finally call the build method to execute
+  // the actual construction.
+  class Builder {
+   public:
+    // Builder constructor. Creates the builder object.
+    // @note No default args
+    // @return This is a constructor.
+    Builder();
+
+    // Default destructor
+    ~Builder() = default;
+
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
+      builder_rows_per_buffer_ = rows_per_buffer;
+      return *this;
+    }
+
+    // Setter method.
+    // @param int32_t op_connector_size
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetOpConnectorSize(int32_t op_connector_size) {
+      builder_op_connector_size_ = op_connector_size;
+      return *this;
+    }
+
+    // Setter method.
+    // @param const std::string & condition_name
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetConditionName(const std::string &condition_name) {
+      builder_condition_name_ = condition_name;
+      return *this;
+    }
+
+    // Setter method.
+    // @param py::function condition_func - blocking condition function
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetConditionFunc(py::function condition_func) {
+      builder_condition_func_ = condition_func;
+      return *this;
+    }
+
+    // The builder "build" method creates the BarrierOp dataset Operator.
+    // @return shared_ptr to the new BarrierOp object
+    Status Build(std::shared_ptr<BarrierOp> *);
+
+   private:
+    int32_t builder_rows_per_buffer_;
+    int32_t builder_op_connector_size_;
+    std::string builder_condition_name_;
+    py::function builder_condition_func_;
+
+    Status SanityCheck() const;
+  };
+  // Constructor for BarrierOp
+  // @param rows_per_buffer - number of rows in output buffer
+  // @param op_connector_size - connector size
+  // @param condition_name - the condition name associated with this operator
+  // @param condition_func - the blocking condition check per row
+  // @note - currently rows_per_buffer should be 1 for barrier.
+  // The reason is that other values would complicate how the pipeline behaves with other operators.
+  // One example of such a case is having batch after barrier: batch would be waiting for data, and
+  // rows per buffer > 1 in this case can result in hanging.
+  BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name,
+            py::function condition_func);
+
+  // Destructor
+  ~BarrierOp();
+
+  Status EofReceived(int32_t) override;
+
+  Status EoeReceived(int32_t) override;
+
+  // Print function for Barrier
+  // @param out - output stream to print to
+  // @param show_all - if it should print everything
+  void Print(std::ostream &out, bool show_all) const override;
+
+  // Provide stream operator for displaying it
+  friend std::ostream &operator<<(std::ostream &out, const BarrierOp &bo) {
+    bo.Print(out, false);
+    return out;
+  }
+
+  // Class functor operator () override.
+  // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
+  // provide the master loop that drives the logic for performing the work.
+  // @return Status - The error code returned
+  Status operator()() override;
+
+  // Handles preprocessing of the main loop, used when starting a new epoch
+  // @param table - a table of tensors to be moved into a buffer
+  Status prepare(TensorQTable *const table);
+
+  // This function takes a table and repeatedly adds rows to it.
+  // @param table - a table of tensors to be moved into a buffer
+  Status fillBuffer(TensorQTable *const table);
+
+  // Gets the next tensor row and sets control signals
+  Status getNextTensorRow(TensorRow *new_row);
+
+  // This function runs the wait function on the condition
+  Status blockCond();
+ private:
+  // clean up variable used to return an incomplete buffer
+  bool clean_up_;
+  // end of file state; we stop reading data and shut down
+  bool eof_;
+  // rows per buffer
+  int32_t rows_per_buffer_;
+  // buffer id
+  int32_t buffer_id_;
+  // local variable to keep track of the buffer information
+  std::unordered_map<std::string, int32_t> col_name_id_map_;
+  // iterator to pull new rows; we only have one child
+  std::unique_ptr<ChildIterator> child_iterator_;
+  // condition name, to support multiple barriers
+  std::string condition_name_;
+  // function pointer of the blocking function
+  py::function condition_function_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_
@@ -34,7 +34,7 @@ class DataBuffer;
 class ZipOp : public PipelineOp {
  public:
-  // The nested builder class inside of the BatchOp is used to help manage all of
+  // The nested builder class inside of the ZipOp is used to help manage all of
   // the arguments for constructing it. Use the builder by setting each argument
   // with the provided set methods, and then finally call the build method to execute
   // the actual construction.
@@ -76,8 +76,8 @@ class ZipOp : public PipelineOp {
   };
 
   // Constructor for ZipOp
-  // @param rows_per_buffer number of rows in output buffer
-  // @param op_connector_size connector
+  // @param rows_per_buffer - number of rows in output buffer
+  // @param op_connector_size - connector size
   ZipOp(int32_t rows_per_buffer, int32_t op_connector_size);
 
   // Destructor
@@ -88,8 +88,8 @@ class ZipOp : public PipelineOp {
   Status EoeReceived(int32_t) override;
 
   // Print function for Zip
-  // @param out output stream to print to
-  // @param show_all if it should print everything
+  // @param out - output stream to print to
+  // @param show_all - if it should print everything
   void Print(std::ostream &out, bool show_all) const override;
 
   // Provide stream operator for displaying it
@@ -113,14 +113,14 @@ class ZipOp : public PipelineOp {
   Status fillBuffer(TensorQTable *const table);
 
   // Special handle case where an empty row has been received from child iterator
-  // @note we need to drain eoe signals from all children connectors.
-  // @details when this function is called, then we encountered eoe at child iterator
+  // @note - we need to drain eoe signals from all children connectors.
+  // @details - when this function is called, then we encountered eoe at child iterator
   // we have to drain rows from other child iterators until we hit eoe from all other child iterators
   Status drainPipeline();
 
   // Merges 1 row from each childIterator together
-  // @param new_zip_row input and output, will return a non-empty row if all rows from childConnectors are non-empty
-  // @param updateColumnMapping generates a new column name to index mapping (mColNameIdMap) if set to true
+  // @param new_zip_row - input and output, will be a non-empty row if all rows from childConnectors are non-empty
+  // @param updateColumnMapping - generates a new column name to index mapping (mColNameIdMap) if set to true
   // @details merge rows from iterator together. This is the main functionality for ZipOp
   // this function takes one row and fills it with tensors from rows fetched
   // from childIterators.
@@ -28,6 +28,7 @@ import multiprocessing
 import queue
 from enum import Enum
 from importlib import import_module
+import threading
 
 import numpy as np
 
 from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \
@@ -40,7 +41,7 @@ from .iterators import DictIterator, TupleIterator
 from .validators import check, check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, check_rename, \
     check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
     check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
-    check_zip_dataset, check_add_column, check_textfiledataset
+    check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset
 from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
 
 try:
@@ -141,6 +142,7 @@ class Dataset:
         self._batch_size = None
         self._num_classes = None
         self._repeat_count = None
+        self._sync = False
 
     def get_args(self):
         """
@@ -198,6 +200,30 @@
         """
         return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns)
+    @check_sync_wait
+    def sync_wait(self, condition_name, num_batch=1, callback=None):
+        """
+        Add a blocking condition to the input Dataset.
+
+        Args:
+            condition_name (str): The condition name that is used to toggle sending next row.
+            num_batch (int): The number of batches without blocking at the start of each epoch.
+            callback (function): The callback function that will be invoked when sync_update is called.
+
+        Raises:
+            RuntimeError: If condition name already exists.
+
+        Examples:
+            >>> import mindspore.dataset as ds
+            >>> # data is an instance of Dataset object.
+            >>> data = data.sync_wait("callback1")
+            >>> data = data.batch(batch_size)
+            >>> for batch_data in data.create_dict_iterator():
+            >>>     data.sync_update("callback1")
+        """
+        return SyncWaitDataset(self, condition_name, num_batch, callback)
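Putting sync_wait together with sync_update (defined further below), a hedged end-to-end sketch of the intended usage; the condition name "policy" and the generator are illustrative, and the pattern mirrors the tests added later in this PR:

```python
import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(20):
        yield (np.array(i),)

data = ds.GeneratorDataset(gen, column_names=["input"])
data = data.sync_wait(condition_name="policy")  # inserts a BarrierOp into the pipeline
data = data.batch(4)                            # scales the release unit to 4 rows

for batch in data.create_dict_iterator():
    # ... use the batch ...
    data.sync_update(condition_name="policy")   # release the next batch worth of rows
```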
     @check_shuffle
     def shuffle(self, buffer_size):
         """
@@ -220,6 +246,9 @@
         Returns:
             ShuffleDataset, dataset shuffled.
 
+        Raises:
+            RuntimeError: If sync operators exist before shuffle.
+
         Examples:
             >>> import mindspore.dataset as ds
             >>> # data is an instance of Dataset object
@@ -821,6 +850,9 @@
         self._input_indexs = value
 
     def _get_pipeline_info(self):
+        """
+        Gets pipeline information.
+        """
         device_iter = TupleIterator(self)
         self._output_shapes = device_iter.get_output_shapes()
         self._output_types = device_iter.get_output_types()
@@ -875,6 +907,30 @@
             return self.input[0].num_classes()
         return None
+    def get_sync_notifiers(self):
+        if self.input:
+            return self.input[0].get_sync_notifiers()
+        return {}
+
+    def is_sync(self):
+        if self.input:
+            return self.input[0].is_sync()
+        return False
+
+    def sync_update(self, condition_name, num_batch=None, data=None):
+        """
+        Release a blocking condition and trigger the callback with the given data.
+
+        Args:
+            condition_name (str): The condition name that is used to toggle sending next row.
+            num_batch (int or None): The number of batches (rows) that are released.
+                When num_batch is None, it defaults to the number specified by sync_wait.
+            data (dict or None): The data passed to the callback.
+        """
+        notifiers_dict = self.get_sync_notifiers()
+        if condition_name not in notifiers_dict:
+            raise RuntimeError("Condition name not found")
+        if num_batch is not None:
+            num_batch *= self.get_batch_size()
+        notifiers_dict[condition_name](num_batch, data)
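The only subtle step above is the unit conversion: num_batch is given in batches, but the notifier registered by SyncWaitDataset (BlockReleasePair.release_func) counts rows, so it is multiplied by get_batch_size(). A small, self-contained mirror of that bookkeeping with hypothetical names, just to make the arithmetic concrete:

```python
# Hypothetical stand-alone mirror of sync_update's bookkeeping: a notifier
# registry mapping condition names to release callables, with num_batch
# converted from batches to rows before the notifier is called.
def make_sync_update(notifiers, batch_size):
    def sync_update(condition_name, num_batch=None, data=None):
        if condition_name not in notifiers:
            raise RuntimeError("Condition name not found")
        if num_batch is not None:
            num_batch *= batch_size          # batches -> rows
        notifiers[condition_name](num_batch, data)
    return sync_update

released = []
notifiers = {"policy": lambda rows, data: released.append((rows, data))}
update = make_sync_update(notifiers, batch_size=4)
update("policy", num_batch=2, data={"loss": 0.1})
assert released == [(8, {"loss": 0.1})]     # 2 batches * 4 rows per batch
```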
     def get_batch_size(self):
         """
         Get the size of a batch.
@@ -978,6 +1034,8 @@ class BatchDataset(DatasetOp):
         if BatchDataset._is_ancestor_of_repeat(input_dataset):
             logger.warning("Repeat is located before batch, data from two epochs can be batched together.")
 
+        BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
+
         self.batch_size = batch_size
         self.drop_remainder = drop_remainder
         self.per_batch_map = per_batch_map
@@ -1034,6 +1092,20 @@
             flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset)
         return flag
 
+    @staticmethod
+    def _update_batch_size_for_syncwait(dataset, batch_size):
+        """
+        Utility function to notify sync_wait of the batch size.
+
+        Args:
+            dataset (Dataset): dataset to be checked.
+            batch_size (int): batch size to notify.
+        """
+        if isinstance(dataset, SyncWaitDataset):
+            dataset.update_sync_batch_size(batch_size)
+        for input_dataset in dataset.input:
+            BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
 
 class BatchInfo(CBatchInfo):
     """
@@ -1058,6 +1130,108 @@
         """
         return
+class BlockReleasePair:
+    """
+    The blocking condition class used by SyncWaitDataset.
+
+    Args:
+        init_release_rows (int): Number of rows to allow through the pipeline.
+        callback (function): The callback function that will be called when release is called.
+    """
+
+    def __init__(self, init_release_rows, callback=None):
+        self.row_count = -init_release_rows
+        self.cv = threading.Condition()
+        self.callback = callback
+        self.default_rows = init_release_rows
+
+    def __deepcopy__(self, memodict):
+        if id(self) in memodict:
+            return memodict[id(self)]
+        memodict[id(self)] = self
+        # condition variable and callback are the same, but reset the counter
+        self.reset()
+        return self
+
+    def reset(self):
+        with self.cv:
+            self.row_count = -self.default_rows
+            self.cv.notify_all()
+
+    def update_batched_size(self, batch_size):
+        # should only be used before the pipeline is created
+        self.row_count *= batch_size
+        self.default_rows *= batch_size
+
+    def block_func(self):
+        with self.cv:
+            self.cv.wait_for(lambda: self.row_count < 0)
+            self.row_count += 1
+        return True
+
+    def release_func(self, pass_rows=None, data=None):
+        with self.cv:
+            if pass_rows is None:
+                pass_rows = self.default_rows
+            self.row_count -= pass_rows
+            if self.callback is not None:
+                self.callback(data)
+            self.cv.notify_all()
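A hedged, stand-alone illustration of the counter mechanics (assuming the BlockReleasePair class above is in scope): the initial credit lets init_release_rows calls to block_func through, after which the producer blocks until release_func grants more rows.

```python
import threading

pair = BlockReleasePair(init_release_rows=2)
consumed = []

def producer():
    for i in range(4):
        pair.block_func()          # each call consumes one "row credit"
        consumed.append(i)

t = threading.Thread(target=producer)
t.start()
t.join(timeout=0.5)                # producer is now stuck on row index 2
assert consumed == [0, 1]

pair.release_func(pass_rows=2)     # grant two more rows; callback is None here
t.join()
assert consumed == [0, 1, 2, 3]
```

Note also that __deepcopy__ returning self is what lets deep-copied pipelines keep sharing one condition variable, while reset() restores the per-epoch row credit.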
+class SyncWaitDataset(DatasetOp):
+    """
+    The result of adding a blocking condition to the input Dataset.
+
+    Args:
+        input_dataset (Dataset): Input dataset to apply flow control.
+        num_batch (int): The number of batches without blocking at the start of each epoch.
+        condition_name (str): The condition name that is used to toggle sending next row.
+        callback (function): The callback function that will be invoked when sync_update is called.
+
+    Raises:
+        RuntimeError: If condition name already exists.
+    """
+
+    def __init__(self, input_dataset, condition_name, num_batch, callback=None):
+        super().__init__()
+        self.input.append(input_dataset)
+        input_dataset.output.append(self)
+        # set to the default value, waiting for the batch to update it
+        self._condition_name = condition_name
+        self._pair = BlockReleasePair(num_batch, callback)
+        if self._condition_name in self.input[0].get_sync_notifiers():
+            raise RuntimeError("Condition name is already in use")
+
+    def get_sync_notifiers(self):
+        return {**self.input[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}}
+
+    def is_sync(self):
+        return True
+
+    def get_args(self):
+        args = super().get_args()
+        args["condition_name"] = self._condition_name
+        args["condition_func"] = self._pair.block_func
+        return args
+
+    def update_sync_batch_size(self, batch_size):
+        self._pair.update_batched_size(batch_size)
+
+    @staticmethod
+    def _is_ancestor_of_batch(dataset):
+        """
+        Utility function to find the case where sync_wait is used before batch.
+
+        Args:
+            dataset (Dataset): dataset to be checked.
+        Return:
+            True or False.
+        """
+        if isinstance(dataset, BatchDataset):
+            return True
+        flag = False
+        for input_dataset in dataset.input:
+            flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset)
+        return flag
 class ShuffleDataset(DatasetOp):
     """
@@ -1066,6 +1240,9 @@ class ShuffleDataset(DatasetOp):
     Args:
         input_dataset (Dataset): Input Dataset to be shuffled.
         buffer_size (int): The size of the buffer.
+
+    Raises:
+        RuntimeError: If sync operators exist before shuffle.
     """
 
     def __init__(self, input_dataset, buffer_size):
@@ -1074,6 +1251,8 @@ class ShuffleDataset(DatasetOp):
         self.input.append(input_dataset)
         input_dataset.output.append(self)
         self._input_indexs = input_dataset.input_indexs
+        if self.is_sync():
+            raise RuntimeError("No shuffle after sync operators")
 
     def get_args(self):
         args = super().get_args()
@@ -1427,6 +1606,9 @@ class ZipDataset(DatasetOp):
         """
         return None
 
+    def is_sync(self):
+        return any([c.is_sync() for c in self.input])
+
     def get_args(self):
         args = super().get_args()
         return args
@@ -129,6 +129,8 @@
             op_type = OpName.MINDRECORD
         elif isinstance(dataset, de.BatchDataset):
             op_type = OpName.BATCH
+        elif isinstance(dataset, de.SyncWaitDataset):
+            op_type = OpName.BARRIER
         elif isinstance(dataset, de.ZipDataset):
             op_type = OpName.ZIP
         elif isinstance(dataset, de.MapDataset):
@@ -652,6 +652,22 @@
     return new_method
 
 
+def check_sync_wait(method):
+    """check the input arguments of sync_wait."""
+
+    @wraps(method)
+    def new_method(*args, **kwargs):
+        param_dict = make_param_dict(method, args, kwargs)
+
+        nreq_param_str = ['condition_name']
+        nreq_param_int = ['num_batch']
+
+        check_param_type(nreq_param_int, param_dict, int)
+        check_param_type(nreq_param_str, param_dict, str)
+
+        return method(*args, **kwargs)
+
+    return new_method
 def check_shuffle(method):
     """check the input arguments of shuffle."""
@@ -12,8 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""
+Testing configuration manager
+"""
+import filecmp
+import glob
+import os
+
 import mindspore.dataset as ds
+import mindspore.dataset.transforms.vision.c_transforms as vision
+
+DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
+SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
 def test_basic():
     ds.config.load('../data/dataset/declient.cfg')
@@ -36,6 +46,34 @@
     assert ds.config.get_prefetch_size() == 4
     assert ds.config.get_seed() == 5
 
+
+def test_pipeline():
+    """
+    Test that our configuration pipeline works when we set parameters at the dataset level
+    """
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+    ds.config.set_num_parallel_workers(2)
+    data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])
+    ds.serialize(data1, "testpipeline.json")
+
+    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+    ds.config.set_num_parallel_workers(4)
+    data2 = data2.map(input_columns=["image"], operations=[vision.Decode(True)])
+    ds.serialize(data2, "testpipeline2.json")
+
+    # check that the generated output is different
+    assert (filecmp.cmp('testpipeline.json', 'testpipeline2.json'))
+    # this test passes currently because our num_parallel_workers don't get updated.
+
+    # remove the generated json files
+    file_list = glob.glob('*.json')
+    for f in file_list:
+        try:
+            os.remove(f)
+        except IOError:
+            logger.info("Error while deleting: {}".format(f))
+
 
 if __name__ == '__main__':
     test_basic()
+    test_pipeline()
@@ -0,0 +1,182 @@
+# Copyright 2019 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import mindspore.dataset as ds
+from mindspore import log as logger
+import time
+import numpy as np
+
+
+def gen():
+    for i in range(100):
+        yield (np.array(i),)
+
+
+class Augment:
+    def __init__(self, loss):
+        self.loss = loss
+
+    def preprocess(self, input):
+        return input
+
+    def update(self, data):
+        self.loss = data["loss"]
+
+
+def test_simple_sync_wait():
+    """
+    Test simple sync wait: test sync in dataset pipeline
+    """
+    logger.info("test_simple_sync_wait")
+    batch_size = 4
+    dataset = ds.GeneratorDataset(gen, column_names=["input"])
+
+    aug = Augment(0)
+    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
+    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
+    dataset = dataset.batch(batch_size)
+
+    count = 0
+    for data in dataset.create_dict_iterator():
+        assert (data["input"][0] == count)
+        count += batch_size
+        data = {"loss": count}
+        dataset.sync_update(condition_name="policy", data=data)
+
+def test_simple_shuffle_sync():
+    """
+    Test simple shuffle sync: test shuffle before sync
+    """
+    logger.info("test_simple_shuffle_sync")
+    shuffle_size = 4
+    batch_size = 10
+    dataset = ds.GeneratorDataset(gen, column_names=["input"])
+
+    aug = Augment(0)
+    dataset = dataset.shuffle(shuffle_size)
+    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
+    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
+    dataset = dataset.batch(batch_size)
+
+    count = 0
+    for data in dataset.create_dict_iterator():
+        count += 1
+        # time.sleep(0.5)
+        data = {"loss": count}
+        dataset.sync_update(condition_name="policy", data=data)
+
+
+def test_two_sync():
+    """
+    Test two sync: dataset pipeline with two sync operators
+    """
+    logger.info("test_two_sync")
+    batch_size = 6
+    dataset = ds.GeneratorDataset(gen, column_names=["input"])
+
+    aug = Augment(0)
+    # notice that with our design, we need to have step_size = shuffle size
+    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
+    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
+    dataset = dataset.sync_wait(num_batch=2, condition_name="every 2 batches")
+    dataset = dataset.batch(batch_size)
+
+    count = 0
+    for data in dataset.create_dict_iterator():
+        count += 1
+        data = {"loss": count}
+        dataset.sync_update(condition_name="every batch", data=data)
+        if count % 2 == 0:
+            dataset.sync_update(condition_name="every 2 batches")
+
+def test_sync_epoch():
+    """
+    Test sync wait with epochs: test sync with epochs in dataset pipeline
+    """
+    logger.info("test_sync_epoch")
+    batch_size = 30
+    dataset = ds.GeneratorDataset(gen, column_names=["input"])
+
+    aug = Augment(0)
+    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
+    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
+    dataset = dataset.batch(batch_size, drop_remainder=True)
+
+    for epochs in range(3):
+        aug.update({"loss": 0})
+        count = 0
+        for data in dataset.create_dict_iterator():
+            assert (data["input"][0] == count)
+            count += batch_size
+            data = {"loss": count}
+            dataset.sync_update(condition_name="policy", data=data)
+
+
+def test_sync_exception_01():
+    """
+    Test sync: with shuffle in sync mode
+    """
+    logger.info("test_sync_exception_01")
+    shuffle_size = 4
+    batch_size = 10
+    dataset = ds.GeneratorDataset(gen, column_names=["input"])
+
+    aug = Augment(0)
+    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
+    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
+
+    try:
+        dataset = dataset.shuffle(shuffle_size)
+    except BaseException as e:
+        assert "shuffle" in str(e)
+    dataset = dataset.batch(batch_size)
+
+
+def test_sync_exception_02():
+    """
+    Test sync: with duplicated condition name
+    """
+    logger.info("test_sync_exception_02")
+    batch_size = 6
+    dataset = ds.GeneratorDataset(gen, column_names=["input"])
+
+    aug = Augment(0)
+    # notice that with our design, we need to have step_size = shuffle size
+    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
+    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
+
+    try:
+        dataset = dataset.sync_wait(num_batch=2, condition_name="every batch")
+    except BaseException as e:
+        assert "name" in str(e)
+    dataset = dataset.batch(batch_size)
+
+
+if __name__ == "__main__":
+    test_simple_sync_wait()
+    test_simple_shuffle_sync()
+    test_two_sync()
+    test_sync_exception_01()
+    test_sync_exception_02()
+    test_sync_epoch()