Merge pull request !508 from EricZ/mastertags/v0.2.0-alpha
| @@ -48,6 +48,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D | |||
| {kMap, &DEPipeline::ParseMapOp}, | |||
| {kFilter, &DEPipeline::ParseFilterOp}, | |||
| {kBatch, &DEPipeline::ParseBatchOp}, | |||
| {kBarrier, &DEPipeline::ParseBarrierOp}, | |||
| {kRepeat, &DEPipeline::ParseRepeatOp}, | |||
| {kSkip, &DEPipeline::ParseSkipOp}, | |||
| {kZip, &DEPipeline::ParseZipOp}, | |||
| @@ -627,6 +628,30 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> | |||
| return Status::OK(); | |||
| } | |||
| Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { | |||
| std::shared_ptr<BarrierOp::Builder> builder = std::make_shared<BarrierOp::Builder>(); | |||
| // Right now barrier should only take num_rows_per_buffer = 1 | |||
| // The reason for this is because having it otherwise can lead to blocking issues | |||
| // See barrier_op.h for more details | |||
| (void)builder->SetRowsPerBuffer(1); | |||
| for (auto arg : args) { | |||
| std::string key = py::str(arg.first); | |||
| py::handle value = arg.second; | |||
| if (!value.is_none()) { | |||
| if (key == "condition_name") { | |||
| (void)builder->SetConditionName(ToString(value)); | |||
| } else if (key == "condition_func") { | |||
| (void)builder->SetConditionFunc(value.cast<py::function>()); | |||
| } | |||
| } | |||
| } | |||
| std::shared_ptr<BarrierOp> op; | |||
| RETURN_IF_NOT_OK(builder->Build(&op)); | |||
| *ptr = op; | |||
| return Status::OK(); | |||
| } | |||
| Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { | |||
| int32_t prefetch_size = 0; | |||
| if (args.contains("prefetch_size")) { | |||
| @@ -40,6 +40,7 @@ enum OpName { | |||
| kShuffle, | |||
| kMindrecord, | |||
| kBatch, | |||
| kBarrier, | |||
| kCache, | |||
| kRepeat, | |||
| kSkip, | |||
| @@ -115,6 +116,8 @@ class DEPipeline { | |||
| Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| @@ -481,6 +481,7 @@ PYBIND11_MODULE(_c_dataengine, m) { | |||
| .value("STORAGE", OpName::kStorage) | |||
| .value("SHUFFLE", OpName::kShuffle) | |||
| .value("BATCH", OpName::kBatch) | |||
| .value("BARRIER", OpName::kBarrier) | |||
| .value("MINDRECORD", OpName::kMindrecord) | |||
| .value("CACHE", OpName::kCache) | |||
| .value("REPEAT", OpName::kRepeat) | |||
| @@ -25,6 +25,7 @@ | |||
| #include "dataset/core/tensor_shape.h" | |||
| #include "dataset/engine/data_schema.h" | |||
| #include "dataset/engine/dataset_iterator.h" | |||
| #include "dataset/engine/datasetops/barrier_op.h" | |||
| #include "dataset/engine/datasetops/batch_op.h" | |||
| #include "dataset/engine/datasetops/dataset_op.h" | |||
| #include "dataset/engine/datasetops/device_queue_op.h" | |||
| @@ -4,6 +4,7 @@ add_library(engine-datasetops OBJECT | |||
| dataset_op.cc | |||
| parallel_op.cc | |||
| pipeline_op.cc | |||
| barrier_op.cc | |||
| batch_op.cc | |||
| device_queue_op.cc | |||
| map_op.cc | |||
| @@ -0,0 +1,235 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/engine/datasetops/barrier_op.h" | |||
| #include <utility> | |||
| #include "dataset/core/constants.h" | |||
| #include "dataset/engine/data_buffer.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/core/config_manager.h" | |||
| #include "dataset/core/global_context.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| BarrierOp::Builder::Builder() { | |||
| // Some arguments to the BarrierOp constructor have a default argument that is taken | |||
| // from the client config. | |||
| // The user may choose to change these values for the construction of the BarrierOp by | |||
| // using the various builder set methods. | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| builder_rows_per_buffer_ = cfg->rows_per_buffer(); | |||
| builder_op_connector_size_ = cfg->op_connector_size(); | |||
| } | |||
| Status BarrierOp::Builder::SanityCheck() const { return Status::OK(); } | |||
| Status BarrierOp::Builder::Build(std::shared_ptr<BarrierOp> *ptr) { | |||
| RETURN_IF_NOT_OK(SanityCheck()); | |||
| *ptr = std::make_shared<BarrierOp>(builder_rows_per_buffer_, builder_op_connector_size_, builder_condition_name_, | |||
| builder_condition_func_); | |||
| return Status::OK(); | |||
| } | |||
| // Construct BarrierOp here, local variables initialized in operator due to tree construction restrictions | |||
| BarrierOp::BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name, | |||
| py::function condition_func) | |||
| : PipelineOp(op_connector_size), | |||
| rows_per_buffer_(rows_per_buffer), | |||
| buffer_id_(0), | |||
| clean_up_(false), | |||
| eof_(false), | |||
| condition_name_(condition_name), | |||
| condition_function_(condition_func) {} | |||
| // destructor | |||
| BarrierOp::~BarrierOp() {} | |||
| // Entry point for Barrier, called by launch() | |||
| Status BarrierOp::operator()() { | |||
| // The children_num_ parameter needs to be put here | |||
| // Synchronize with TaskManager once the thread is created. | |||
| TaskManager::FindMe()->Post(); | |||
| // create child iterator, right now this barrier is a pipeline operator | |||
| int32_t worker_id = 0; | |||
| int32_t child_idx = 0; | |||
| child_iterator_ = std::make_unique<ChildIterator>(this, worker_id, child_idx); | |||
| // Loop until eof is true | |||
| while (!eof_) { | |||
| // Create new table to put the new tensor rows | |||
| std::unique_ptr<TensorQTable> curr_table = std::make_unique<TensorQTable>(); | |||
| RETURN_IF_NOT_OK(prepare(curr_table.get())); | |||
| // If an eof got picked up during the above prepare, then we're done | |||
| if (eof_) { | |||
| break; | |||
| } | |||
| // we have to output new buffer with possibly different buffer size, possibly one row | |||
| while (!clean_up_) { | |||
| // 1. If a previous loop iteration sent the current table out, then create a new one. | |||
| if (curr_table == nullptr) { | |||
| curr_table = std::make_unique<TensorQTable>(); | |||
| } | |||
| // 2 fill the table. Note: clean_up mode might get turned on if epoch is finished | |||
| RETURN_IF_NOT_OK(fillBuffer(curr_table.get())); | |||
| // 3 create and update buffer and send it to the out connector | |||
| if (!curr_table->empty()) { | |||
| std::unique_ptr<DataBuffer> curr_buffer = std::make_unique<DataBuffer>(buffer_id_, DataBuffer::kDeBFlagNone); | |||
| curr_buffer->set_tensor_table(std::move(curr_table)); | |||
| curr_buffer->set_column_name_map(col_name_id_map_); | |||
| MS_LOG(DEBUG) << "Barrier operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols " | |||
| << curr_buffer->NumCols() << ", map " << col_name_id_map_.size() << "."; | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); | |||
| buffer_id_++; | |||
| } | |||
| } | |||
| // 4 handle drain state. | |||
| if (clean_up_) { | |||
| MS_LOG(DEBUG) << "Barrier operator sending epoch ending signal."; | |||
| // Send the eoe up. | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)))); | |||
| } | |||
| } | |||
| // 5 handle eof | |||
| // propagate eof here. | |||
| MS_LOG(INFO) << "Barrier operator got EOF, propagating."; | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)))); | |||
| return Status::OK(); | |||
| } | |||
| // Handles preprocessing of the main loop, used when starting new epoch | |||
| Status BarrierOp::prepare(TensorQTable *const table) { | |||
| MS_LOG(DEBUG) << "Barrier operator prepares for new epoch."; | |||
| clean_up_ = false; | |||
| buffer_id_ = 0; | |||
| if (table == nullptr) { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp prepare phase requires a tensor table."); | |||
| } | |||
| // fill initial row | |||
| TensorRow new_row = {}; | |||
| // use iterator to get next row and invoke pyfunc wait | |||
| RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); | |||
| // If the first row fetching resulted in eof, then we are done. | |||
| if (eof_) { | |||
| return Status::OK(); | |||
| } | |||
| if (new_row.empty()) { | |||
| // This epoch is empty | |||
| return Status::OK(); | |||
| } | |||
| // Pack this first row into our tensor table | |||
| // first row we also have to check if we should block | |||
| RETURN_IF_NOT_OK(blockCond()); | |||
| table->push_back(std::move(new_row)); | |||
| // At this point we have 1 row produced, we take the old column map id and use it in the new table | |||
| // Initializing col_name_id_map_ from the first data buffer. | |||
| col_name_id_map_ = child_iterator_->col_name_id_map(); | |||
| // the update code below shouldn't do anything bad if the column name already exists. | |||
| return Status::OK(); | |||
| } | |||
| // fillBuffer always expects a new table to fill | |||
| Status BarrierOp::fillBuffer(TensorQTable *const table) { | |||
| if (table == nullptr) { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer."); | |||
| } | |||
| TensorRow new_row = {}; | |||
| while (table->size() < static_cast<size_t>(rows_per_buffer_)) { | |||
| RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); | |||
| // Early exit the loop if we got empty row from any of our child iterations | |||
| if (new_row.empty()) { | |||
| return Status::OK(); | |||
| } | |||
| // else we got a row so pack it into the tensor table. | |||
| RETURN_IF_NOT_OK(blockCond()); | |||
| table->push_back(std::move(new_row)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // function executes a py_func and blocks until condition becomes true. | |||
| Status BarrierOp::blockCond() { | |||
| { | |||
| py::gil_scoped_acquire gil_acquire; | |||
| if (Py_IsInitialized() == 0) { | |||
| return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); | |||
| } | |||
| // we have condition name, however the flexibility is in python today | |||
| try { | |||
| // Invoke python function | |||
| py::object ret_py_obj = condition_function_(); | |||
| // Process the return value | |||
| if (!py::isinstance<py::bool_>(ret_py_obj)) { | |||
| return Status(StatusCode::kPyFuncException, "Condition wait function should return true/false"); | |||
| } | |||
| } catch (const py::error_already_set &e) { | |||
| return Status(StatusCode::kPyFuncException, e.what()); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // fetches next Barrier buffer row | |||
| Status BarrierOp::getNextTensorRow(TensorRow *new_row) { | |||
| // iterate over all iterators and generate a row | |||
| RETURN_IF_NOT_OK((child_iterator_)->FetchNextTensorRow(new_row)); | |||
| // add each new row to iterator, check if row is empty, if row from iterator is empty return empty row | |||
| if (new_row->empty()) { | |||
| // If we did not get a row from any of the children, then it's the end of an epoch and we can move | |||
| // to drain state. | |||
| MS_LOG(INFO) << "Barrier operator child iterator produced empty row."; | |||
| clean_up_ = true; | |||
| // If we picked up an eof here, then we are completely done. | |||
| if ((child_iterator_)->eof_handled()) { | |||
| MS_LOG(INFO) << "Barrier operator iterator got EOF."; | |||
| eof_ = true; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // A function that prints info about the Operator | |||
| void BarrierOp::Print(std::ostream &out, bool show_all) const { | |||
| // Call base class printer first | |||
| PipelineOp::Print(out, show_all); | |||
| out << "\nBarrierOp:\n" | |||
| << "\nCondition " << condition_name_ << "\n\n"; | |||
| } | |||
| // overwrite function and handle eof | |||
| Status BarrierOp::EofReceived(int32_t) { | |||
| MS_LOG(DEBUG) << "Barrier operator EOF received, do nothing now."; | |||
| return Status::OK(); | |||
| } | |||
| // overwrite function and handle eoe | |||
| Status BarrierOp::EoeReceived(int32_t) { | |||
| state_ = OpState::kDeOpIdle; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,172 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ | |||
| #define DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/engine/dataset_iterator.h" | |||
| #include "dataset/engine/datasetops/pipeline_op.h" | |||
| #include "dataset/kernels/tensor_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Forward declare | |||
| class DataBuffer; | |||
| class ExecutionTree; | |||
| // BarrierOp class implements the Barrier operator. It will block sending of rows until a signal has | |||
| // been received. This signal is given from python layer. The current barrier design respects the | |||
| // rows per buffer design and will only output a buffer with rows once it has received rows per buffer | |||
| // signals from python. | |||
| class BarrierOp : public PipelineOp { | |||
| public: | |||
| // The nested builder class inside of the BarrierOp is used to help manage all of | |||
| // the arguments for constructing it. Use the builder by setting each argument | |||
| // with the provided set methods, and then finally call the build method to execute | |||
| // the actual construction. | |||
| class Builder { | |||
| public: | |||
| // Builder constructor. Creates the builder object. | |||
| // @note No default args | |||
| // @return This is a constructor. | |||
| Builder(); | |||
| // Default destructor | |||
| ~Builder() = default; | |||
| // Setter method. | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { | |||
| builder_rows_per_buffer_ = rows_per_buffer; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @param int32_t op_connector_size | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetOpConnectorSize(int32_t op_connector_size) { | |||
| builder_op_connector_size_ = op_connector_size; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @param const std::string & condition_name | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetConditionName(const std::string &condition_name) { | |||
| builder_condition_name_ = condition_name; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @param py::function condition_func - blocking condition function | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetConditionFunc(py::function condition_func) { | |||
| builder_condition_func_ = condition_func; | |||
| return *this; | |||
| } | |||
| // The builder "build" method creates the BarrierOp dataset Operator. | |||
| // @return shared_ptr to the new BarrierOp object | |||
| Status Build(std::shared_ptr<BarrierOp> *); | |||
| private: | |||
| int32_t builder_rows_per_buffer_; | |||
| int32_t builder_op_connector_size_; | |||
| std::string builder_condition_name_; | |||
| py::function builder_condition_func_; | |||
| Status SanityCheck() const; | |||
| }; | |||
| // Constructor for BarrierOp | |||
| // @param rows_per_buffer - number of rows in output buffer | |||
| // @param op_connector_size - connector size | |||
| // @param condition_name - the condition name associated with this operator | |||
| // @param condition_func - the blocking condition check per row | |||
| // @note - currently rows_per_buffer should = 1 for barrier. | |||
| // The reason for this is having other values would complicate how the pipeline behaves with other operators | |||
| // One example of such case is having batch after barrier. Batch would be waiting for data and having | |||
| // rows per buffer in this case can result in hanging | |||
| BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name, | |||
| py::function condition_func); | |||
| // Destructor | |||
| ~BarrierOp(); | |||
| Status EofReceived(int32_t) override; | |||
| Status EoeReceived(int32_t) override; | |||
| // Print function for Barrier | |||
| // @param out - output stream to print to | |||
| // @param show_all - if it should print everything | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| // Provide stream operator for displaying it | |||
| friend std::ostream &operator<<(std::ostream &out, const BarrierOp &bo) { | |||
| bo.Print(out, false); | |||
| return out; | |||
| } | |||
| // Class functor operator () override. | |||
| // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | |||
| // provide the master loop that drives the logic for performing the work | |||
| // @return Status - The error code return | |||
| Status operator()() override; | |||
| // Handles preprocessing of the main loop, used when starting new epoch | |||
| // @param table - a table of tensors to be moved into a buffer | |||
| Status prepare(TensorQTable *const table); | |||
| // This function calls takes a table repeatedly adds rows to it. | |||
| // @param table - a table of tensors to be moved into a buffer | |||
| Status fillBuffer(TensorQTable *const table); | |||
| // Gets next tensor row and sets control signals | |||
| Status getNextTensorRow(TensorRow *new_row); | |||
| // This function runs the wait function on condition | |||
| Status blockCond(); | |||
| private: | |||
| // clean up variable to return imcomplete buffer | |||
| bool clean_up_; | |||
| // end of file state, we stop reading data and shut down | |||
| bool eof_; | |||
| // rows per buffer | |||
| int32_t rows_per_buffer_; | |||
| // buffer_id | |||
| int32_t buffer_id_; | |||
| // local variable to keep track of the buffer information | |||
| std::unordered_map<std::string, int32_t> col_name_id_map_; | |||
| // iterator to pull new rows, we only have one child | |||
| std::unique_ptr<ChildIterator> child_iterator_; | |||
| // condition name, to support multiple barriers | |||
| std::string condition_name_; | |||
| // Function pointer of blocking function | |||
| py::function condition_function_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ | |||
| @@ -34,7 +34,7 @@ class DataBuffer; | |||
| class ZipOp : public PipelineOp { | |||
| public: | |||
| // The nested builder class inside of the BatchOp is used to help manage all of | |||
| // The nested builder class inside of the ZipOp is used to help manage all of | |||
| // the arguments for constructing it. Use the builder by setting each argument | |||
| // with the provided set methods, and then finally call the build method to execute | |||
| // the actual construction. | |||
| @@ -76,8 +76,8 @@ class ZipOp : public PipelineOp { | |||
| }; | |||
| // Constructor for ZipOp | |||
| // @param rows_per_buffer number of rows in output buffer | |||
| // @param op_connector_size connector | |||
| // @param rows_per_buffer - number of rows in output buffer | |||
| // @param op_connector_size - connector size | |||
| ZipOp(int32_t rows_per_buffer, int32_t op_connector_size); | |||
| // Destructor | |||
| @@ -88,8 +88,8 @@ class ZipOp : public PipelineOp { | |||
| Status EoeReceived(int32_t) override; | |||
| // Print function for Zip | |||
| // @param out output stream to print to | |||
| // @param show_all if it should print everything | |||
| // @param out - output stream to print to | |||
| // @param show_all - if it should print everything | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| // Provide stream operator for displaying it | |||
| @@ -113,14 +113,14 @@ class ZipOp : public PipelineOp { | |||
| Status fillBuffer(TensorQTable *const table); | |||
| // Special handle case where an empty row has been received from child iterator | |||
| // @note we need to drain eoe signals from all children connectors. | |||
| // @details when this function is called, then we encountered eoe at child iterator | |||
| // @note - we need to drain eoe signals from all children connectors. | |||
| // @details - when this function is called, then we encountered eoe at child iterator | |||
| // we have to drain rows from other child iterators until we hit eoe from all other child iterators | |||
| Status drainPipeline(); | |||
| // Merges 1 row from each childIterator together | |||
| // @param new_zip_row input and output, will return a non-empty row if all rows from childConnectors are non-empty | |||
| // @param updateColumnMapping generates a new column name to index mapping (mColNameIdMap) if set to true | |||
| // @param new_zip_row - input and output, will be a non-empty row if all rows from childConnectors are non-empty | |||
| // @param updateColumnMapping - generates a new column name to index mapping (mColNameIdMap) if set to true | |||
| // @details merge rows from iterator together. This is the main functionality for ZipOp | |||
| // this function takes one row and fills it with tensors from rows fetched | |||
| // from childIterators. | |||
| @@ -28,6 +28,7 @@ import multiprocessing | |||
| import queue | |||
| from enum import Enum | |||
| from importlib import import_module | |||
| import threading | |||
| import numpy as np | |||
| from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \ | |||
| @@ -40,7 +41,7 @@ from .iterators import DictIterator, TupleIterator | |||
| from .validators import check, check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, check_rename, \ | |||
| check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ | |||
| check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \ | |||
| check_zip_dataset, check_add_column, check_textfiledataset | |||
| check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset | |||
| from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist | |||
| try: | |||
| @@ -141,6 +142,7 @@ class Dataset: | |||
| self._batch_size = None | |||
| self._num_classes = None | |||
| self._repeat_count = None | |||
| self._sync = False | |||
| def get_args(self): | |||
| """ | |||
| @@ -198,6 +200,30 @@ class Dataset: | |||
| """ | |||
| return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns) | |||
| @check_sync_wait | |||
| def sync_wait(self, condition_name, num_batch=1, callback=None): | |||
| ''' | |||
| Add a blocking condition to the input Dataset | |||
| Args: | |||
| input_dataset (Dataset): Input dataset to apply flow control | |||
| num_batch (int): the number of batches without blocking at the start of each epoch | |||
| condition_name (str): The condition name that is used to toggle sending next row | |||
| callback (function): The callback funciton that will be invoked when sync_update is called | |||
| Raises: | |||
| RuntimeError: If condition name already exists. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| >>> # data is an instance of Dataset object. | |||
| >>> data = data.sync_wait("callback1") | |||
| >>> data = data.batch(batch_size) | |||
| >>> for batch_data in data.create_dict_iterator(): | |||
| >>> data = data.sync_update("callback1") | |||
| ''' | |||
| return SyncWaitDataset(self, condition_name, num_batch, callback) | |||
| @check_shuffle | |||
| def shuffle(self, buffer_size): | |||
| """ | |||
| @@ -220,6 +246,9 @@ class Dataset: | |||
| Returns: | |||
| ShuffleDataset, dataset shuffled. | |||
| Raises: | |||
| RuntimeError: If exist sync operators before shuffle. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| >>> # data is an instance of Dataset object | |||
| @@ -821,6 +850,9 @@ class Dataset: | |||
| self._input_indexs = value | |||
| def _get_pipeline_info(self): | |||
| """ | |||
| Gets pipeline information. | |||
| """ | |||
| device_iter = TupleIterator(self) | |||
| self._output_shapes = device_iter.get_output_shapes() | |||
| self._output_types = device_iter.get_output_types() | |||
| @@ -875,6 +907,30 @@ class Dataset: | |||
| return self.input[0].num_classes() | |||
| return None | |||
| def get_sync_notifiers(self): | |||
| if self.input: | |||
| return self.input[0].get_sync_notifiers() | |||
| return {} | |||
| def is_sync(self): | |||
| if self.input: | |||
| return self.input[0].is_sync() | |||
| return False | |||
| def sync_update(self, condition_name, num_batch=None, data=None): | |||
| """ | |||
| condition_name (str): The condition name that is used to toggle sending next row | |||
| step_size (int or None): The number of steps(rows) that are released | |||
| when pass_rows is None, will update the same number as sync_wait specified | |||
| data (dict or None): The data passed to the callback | |||
| """ | |||
| notifiers_dict = self.get_sync_notifiers() | |||
| if condition_name not in notifiers_dict: | |||
| raise RuntimeError("Condition name not found") | |||
| if num_batch is not None: | |||
| num_batch *= self.get_batch_size() | |||
| notifiers_dict[condition_name](num_batch, data) | |||
| def get_batch_size(self): | |||
| """ | |||
| Get the size of a batch. | |||
| @@ -978,6 +1034,8 @@ class BatchDataset(DatasetOp): | |||
| if BatchDataset._is_ancestor_of_repeat(input_dataset): | |||
| logger.warning("Repeat is located before batch, data from two epochs can be batched together.") | |||
| BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size) | |||
| self.batch_size = batch_size | |||
| self.drop_remainder = drop_remainder | |||
| self.per_batch_map = per_batch_map | |||
| @@ -1034,6 +1092,20 @@ class BatchDataset(DatasetOp): | |||
| flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset) | |||
| return flag | |||
| @staticmethod | |||
| def _update_batch_size_for_syncwait(dataset, batch_size): | |||
| """ | |||
| Utility function to notify batch size to sync_wait. | |||
| Args: | |||
| dataset (Dataset): dataset to be checked | |||
| batchsize (int): batch size to notify | |||
| """ | |||
| if isinstance(dataset, SyncWaitDataset): | |||
| dataset.update_sync_batch_size(batch_size) | |||
| for input_dataset in dataset.input: | |||
| BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size) | |||
| class BatchInfo(CBatchInfo): | |||
| """ | |||
| @@ -1058,6 +1130,108 @@ class BatchInfo(CBatchInfo): | |||
| """ | |||
| return | |||
| class BlockReleasePair: | |||
| """ | |||
| The blocking condition class used by SyncWaitDataset | |||
| Args: | |||
| init_release_rows (int): Number of lines to allow through the pipeline | |||
| callback (function): The callback funciton that will be called when release is called | |||
| """ | |||
| def __init__(self, init_release_rows, callback=None): | |||
| self.row_count = -init_release_rows | |||
| self.cv = threading.Condition() | |||
| self.callback = callback | |||
| self.default_rows = init_release_rows | |||
| def __deepcopy__(self, memodict): | |||
| if id(self) in memodict: | |||
| return memodict[id(self)] | |||
| memodict[id(self)] = self | |||
| # condition variable and callback are the same, but reset the counter | |||
| self.reset() | |||
| return self | |||
| def reset(self): | |||
| with self.cv: | |||
| self.row_count = -self.default_rows | |||
| self.cv.notify_all() | |||
| def update_batched_size(self, batch_size): | |||
| # should only use before the pipeline creates | |||
| self.row_count *= batch_size | |||
| self.default_rows *= batch_size | |||
| def block_func(self): | |||
| with self.cv: | |||
| self.cv.wait_for(lambda: self.row_count < 0) | |||
| self.row_count += 1 | |||
| return True | |||
| def release_func(self, pass_rows=None, data=None): | |||
| with self.cv: | |||
| if pass_rows is None: | |||
| pass_rows = self.default_rows | |||
| self.row_count -= pass_rows | |||
| if self.callback is not None: | |||
| self.callback(data) | |||
| self.cv.notify_all() | |||
| class SyncWaitDataset(DatasetOp): | |||
| """ | |||
| The result of adding a blocking condition to the input Dataset | |||
| Args: | |||
| input_dataset (Dataset): Input dataset to apply flow control | |||
| num_batch (int): the number of batches without blocking at the start of each epoch | |||
| condition_name (str): The condition name that is used to toggle sending next row | |||
| callback (function): The callback funciton that will be invoked when sync_update is called | |||
| Raises: | |||
| RuntimeError: If condition name already exists. | |||
| """ | |||
| def __init__(self, input_dataset, condition_name, num_batch, callback=None): | |||
| super().__init__() | |||
| self.input.append(input_dataset) | |||
| input_dataset.output.append(self) | |||
| # set to the default value, waiting for the batch to update it | |||
| self._condition_name = condition_name | |||
| self._pair = BlockReleasePair(num_batch, callback) | |||
| if self._condition_name in self.input[0].get_sync_notifiers(): | |||
| raise RuntimeError("Condition name is already in use") | |||
| def get_sync_notifiers(self): | |||
| return {**self.input[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}} | |||
| def is_sync(self): | |||
| return True | |||
| def get_args(self): | |||
| args = super().get_args() | |||
| args["condition_name"] = self._condition_name | |||
| args["condition_func"] = self._pair.block_func | |||
| return args | |||
| def update_sync_batch_size(self, batch_size): | |||
| self._pair.update_batched_size(batch_size) | |||
| @staticmethod | |||
| def _is_ancestor_of_batch(dataset): | |||
| """ | |||
| Utility function to find the case where sync_wait is used before batch. | |||
| Args: | |||
| dataset (Dataset): dataset to be checked | |||
| Return: | |||
| True or False | |||
| """ | |||
| if isinstance(dataset, BatchDataset): | |||
| return True | |||
| flag = False | |||
| for input_dataset in dataset.input: | |||
| flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset) | |||
| return flag | |||
| class ShuffleDataset(DatasetOp): | |||
| """ | |||
| @@ -1066,6 +1240,9 @@ class ShuffleDataset(DatasetOp): | |||
| Args: | |||
| input_dataset (Dataset): Input Dataset to be shuffled. | |||
| buffer_size (int): The size of the buffer. | |||
| Raises: | |||
| RuntimeError: If exist sync operators before shuffle. | |||
| """ | |||
| def __init__(self, input_dataset, buffer_size): | |||
| @@ -1074,6 +1251,8 @@ class ShuffleDataset(DatasetOp): | |||
| self.input.append(input_dataset) | |||
| input_dataset.output.append(self) | |||
| self._input_indexs = input_dataset.input_indexs | |||
| if self.is_sync(): | |||
| raise RuntimeError("No shuffle after sync operators") | |||
| def get_args(self): | |||
| args = super().get_args() | |||
| @@ -1427,6 +1606,9 @@ class ZipDataset(DatasetOp): | |||
| """ | |||
| return None | |||
| def is_sync(self): | |||
| return any([c.is_sync() for c in self.input]) | |||
| def get_args(self): | |||
| args = super().get_args() | |||
| return args | |||
| @@ -129,6 +129,8 @@ class Iterator: | |||
| op_type = OpName.MINDRECORD | |||
| elif isinstance(dataset, de.BatchDataset): | |||
| op_type = OpName.BATCH | |||
| elif isinstance(dataset, de.SyncWaitDataset): | |||
| op_type = OpName.BARRIER | |||
| elif isinstance(dataset, de.ZipDataset): | |||
| op_type = OpName.ZIP | |||
| elif isinstance(dataset, de.MapDataset): | |||
| @@ -652,6 +652,22 @@ def check_batch(method): | |||
| return new_method | |||
| def check_sync_wait(method): | |||
| """check the input arguments of sync_wait.""" | |||
| @wraps(method) | |||
| def new_method(*args, **kwargs): | |||
| param_dict = make_param_dict(method, args, kwargs) | |||
| nreq_param_str = ['condition_name'] | |||
| nreq_param_int = ['step_size'] | |||
| check_param_type(nreq_param_int, param_dict, int) | |||
| check_param_type(nreq_param_str, param_dict, str) | |||
| return method(*args, **kwargs) | |||
| return new_method | |||
| def check_shuffle(method): | |||
| """check the input arguments of shuffle.""" | |||
| @@ -12,8 +12,18 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing configuration manager | |||
| """ | |||
| import filecmp | |||
| import glob | |||
| import os | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as vision | |||
| DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||
| SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||
| def test_basic(): | |||
| ds.config.load('../data/dataset/declient.cfg') | |||
| @@ -36,6 +46,34 @@ def test_basic(): | |||
| assert ds.config.get_prefetch_size() == 4 | |||
| assert ds.config.get_seed() == 5 | |||
| def test_pipeline(): | |||
| """ | |||
| Test that our configuration pipeline works when we set parameters at dataset interval | |||
| """ | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) | |||
| ds.config.set_num_parallel_workers(2) | |||
| data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)]) | |||
| ds.serialize(data1, "testpipeline.json") | |||
| data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) | |||
| ds.config.set_num_parallel_workers(4) | |||
| data2 = data2.map(input_columns=["image"], operations=[vision.Decode(True)]) | |||
| ds.serialize(data2, "testpipeline2.json") | |||
| # check that the generated output is different | |||
| assert (filecmp.cmp('testpipeline.json', 'testpipeline2.json')) | |||
| # this test passes currently because our num_parallel_workers don't get updated. | |||
| # remove generated jason files | |||
| file_list = glob.glob('*.json') | |||
| for f in file_list: | |||
| try: | |||
| os.remove(f) | |||
| except IOError: | |||
| logger.info("Error while deleting: {}".format(f)) | |||
| if __name__ == '__main__': | |||
| test_basic() | |||
| test_pipeline() | |||
| @@ -0,0 +1,182 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import mindspore.dataset as ds | |||
| from mindspore import log as logger | |||
| import time | |||
| import numpy as np | |||
| def gen(): | |||
| for i in range(100): | |||
| yield np.array(i), | |||
| class Augment: | |||
| def __init__(self, loss): | |||
| self.loss = loss | |||
| def preprocess(self, input): | |||
| return input | |||
| def update(self, data): | |||
| self.loss = data["loss"] | |||
| def test_simple_sync_wait(): | |||
| """ | |||
| Test simple sync wait: test sync in dataset pipeline | |||
| """ | |||
| logger.info("test_simple_sync_wait") | |||
| batch_size = 4 | |||
| dataset = ds.GeneratorDataset(gen, column_names=["input"]) | |||
| aug = Augment(0) | |||
| dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) | |||
| dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) | |||
| dataset = dataset.batch(batch_size) | |||
| count = 0 | |||
| for data in dataset.create_dict_iterator(): | |||
| assert (data["input"][0] == count) | |||
| count += batch_size | |||
| data = {"loss": count} | |||
| dataset.sync_update(condition_name="policy", data=data) | |||
| def test_simple_shuffle_sync(): | |||
| """ | |||
| Test simple shuffle sync: test shuffle before sync | |||
| """ | |||
| logger.info("test_simple_shuffle_sync") | |||
| shuffle_size = 4 | |||
| batch_size = 10 | |||
| dataset = ds.GeneratorDataset(gen, column_names=["input"]) | |||
| aug = Augment(0) | |||
| dataset = dataset.shuffle(shuffle_size) | |||
| dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) | |||
| dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) | |||
| dataset = dataset.batch(batch_size) | |||
| count = 0 | |||
| for data in dataset.create_dict_iterator(): | |||
| count += 1 | |||
| #time.sleep(0.5) | |||
| data = {"loss": count} | |||
| dataset.sync_update(condition_name="policy", data=data) | |||
| def test_two_sync(): | |||
| """ | |||
| Test two sync: dataset pipeline with with two sync_operators | |||
| """ | |||
| logger.info("test_two_sync") | |||
| batch_size = 6 | |||
| dataset = ds.GeneratorDataset(gen, column_names=["input"]) | |||
| aug = Augment(0) | |||
| # notice that with our design, we need to have step_size = shuffle size | |||
| dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update) | |||
| dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) | |||
| dataset = dataset.sync_wait(num_batch=2, condition_name="every 2 batches") | |||
| dataset = dataset.batch(batch_size) | |||
| count = 0 | |||
| for data in dataset.create_dict_iterator(): | |||
| count += 1 | |||
| data = {"loss": count} | |||
| dataset.sync_update(condition_name="every batch", data=data) | |||
| if count % 2 == 0: | |||
| dataset.sync_update(condition_name="every 2 batches") | |||
| def test_sync_epoch(): | |||
| """ | |||
| Test sync wait with epochs: test sync with epochs in dataset pipeline | |||
| """ | |||
| logger.info("test_sync_epoch") | |||
| batch_size = 30 | |||
| dataset = ds.GeneratorDataset(gen, column_names=["input"]) | |||
| aug = Augment(0) | |||
| dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) | |||
| dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) | |||
| dataset = dataset.batch(batch_size, drop_remainder=True) | |||
| for epochs in range(3): | |||
| aug.update({"loss": 0}) | |||
| count = 0 | |||
| for data in dataset.create_dict_iterator(): | |||
| assert (data["input"][0] == count) | |||
| count += batch_size | |||
| data = {"loss": count} | |||
| dataset.sync_update(condition_name="policy", data=data) | |||
| def test_sync_exception_01(): | |||
| """ | |||
| Test sync: with shuffle in sync mode | |||
| """ | |||
| logger.info("test_sync_exception_01") | |||
| shuffle_size = 4 | |||
| batch_size = 10 | |||
| dataset = ds.GeneratorDataset(gen, column_names=["input"]) | |||
| aug = Augment(0) | |||
| dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) | |||
| dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) | |||
| try: | |||
| dataset = dataset.shuffle(shuffle_size) | |||
| except BaseException as e: | |||
| assert "shuffle" in str(e) | |||
| dataset = dataset.batch(batch_size) | |||
| def test_sync_exception_02(): | |||
| """ | |||
| Test sync: with duplicated condition name | |||
| """ | |||
| logger.info("test_sync_exception_02") | |||
| batch_size = 6 | |||
| dataset = ds.GeneratorDataset(gen, column_names=["input"]) | |||
| aug = Augment(0) | |||
| # notice that with our design, we need to have step_size = shuffle size | |||
| dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update) | |||
| dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) | |||
| try: | |||
| dataset = dataset.sync_wait(num_batch=2, condition_name="every batch") | |||
| except BaseException as e: | |||
| assert "name" in str(e) | |||
| dataset = dataset.batch(batch_size) | |||
| if __name__ == "__main__": | |||
| test_simple_sync_wait() | |||
| test_simple_shuffle_sync() | |||
| test_two_sync() | |||
| test_sync_exception_01() | |||
| test_sync_exception_02() | |||
| test_sync_epoch() | |||