Merge pull request !3212 from anzhengqi/epochs-readytags/v0.6.0-beta
| @@ -25,6 +25,8 @@ | |||
| #include "minddata/dataset/engine/dataset_iterator.h" | |||
| #include "minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||
| #include "minddata/dataset/engine/datasetops/device_queue_op.h" | |||
| #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | |||
| #include "minddata/dataset/engine/datasetops/filter_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/celeba_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/cifar_op.h" | |||
| @@ -84,7 +86,8 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = { | |||
| {kRandomData, &DEPipeline::ParseRandomDataOp}, | |||
| {kTextFile, &DEPipeline::ParseTextFileOp}, | |||
| {kBuildVocab, &DEPipeline::ParseBuildVocabOp}, | |||
| {kClue, &DEPipeline::ParseClueOp}}; | |||
| {kClue, &DEPipeline::ParseClueOp}, | |||
| {kEpochCtrl, &DEPipeline::ParseEpochCtrlOp}}; | |||
| DEPipeline::DEPipeline() : iterator_(nullptr) { | |||
| try { | |||
| @@ -166,8 +169,8 @@ Status DEPipeline::AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr & | |||
| Status DEPipeline::AssignRootNode(const DsOpPtr &dataset_op) { return (tree_->AssignRoot(dataset_op)); } | |||
| // Function to launch the tree execution. | |||
| Status DEPipeline::LaunchTreeExec() { | |||
| RETURN_IF_NOT_OK(tree_->Prepare()); | |||
| Status DEPipeline::LaunchTreeExec(const int32_t num_epochs) { | |||
| RETURN_IF_NOT_OK(tree_->Prepare(num_epochs)); | |||
| RETURN_IF_NOT_OK(tree_->Launch()); | |||
| iterator_ = std::make_unique<DatasetIterator>(tree_); | |||
| if (iterator_ == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create an Iterator."); | |||
| @@ -252,6 +255,16 @@ int DEPipeline::GetRepeatCount() const { return repeat_num_; } | |||
| float ToFloat(const py::handle &handle) { return py::reinterpret_borrow<py::float_>(handle); } | |||
| Status DEPipeline::StopSend() { | |||
| // tree_.root() must be DeviceQueueOp | |||
| DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_->root().get()); | |||
| if (op == nullptr) { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "StopSend only supported by DeviceQueueOp"); | |||
| } | |||
| op->StopSend(); | |||
| return Status::OK(); | |||
| } | |||
| int ToInt(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); } | |||
| bool ToBool(const py::handle &handle) { return py::reinterpret_borrow<py::bool_>(handle); } | |||
| @@ -804,6 +817,18 @@ Status DEPipeline::ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp> | |||
| return Status::OK(); | |||
| } | |||
| Status DEPipeline::ParseEpochCtrlOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, | |||
| std::shared_ptr<DatasetOp> *bottom) { | |||
| if (args["count"].is_none()) { | |||
| std::string err_msg = "Error: count is invalid or not set."; | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| std::shared_ptr<EpochCtrlOp> op; | |||
| RETURN_IF_NOT_OK(EpochCtrlOp::Builder(ToInt(args["count"])).Build(&op)); | |||
| *top = op; | |||
| return Status::OK(); | |||
| } | |||
| Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, | |||
| std::shared_ptr<DatasetOp> *bottom) { | |||
| std::shared_ptr<GeneratorOp::Builder> builder = std::make_shared<GeneratorOp::Builder>(); | |||
| @@ -973,8 +998,8 @@ Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<Data | |||
| (void)builder->SetDeviceType(ToString(value)); | |||
| } else if (key == "device_id") { | |||
| (void)builder->SetDeviceId(ToInt(value)); | |||
| } else if (key == "num_batch") { | |||
| (void)builder->SetNumBatch(ToInt(value)); | |||
| } else if (key == "send_epoch_end") { | |||
| (void)builder->SetSendEpochEnd(ToBool(value)); | |||
| } | |||
| } | |||
| } | |||
| @@ -70,7 +70,8 @@ enum OpName { | |||
| kRandomData, | |||
| kTextFile, | |||
| kBuildVocab, | |||
| kClue | |||
| kClue, | |||
| kEpochCtrl | |||
| }; | |||
| // The C++ binder class that we expose to the python script. | |||
| @@ -90,7 +91,7 @@ class DEPipeline { | |||
| Status AssignRootNode(const DsOpPtr &dataset_op); | |||
| // Function to launch the tree execution. | |||
| Status LaunchTreeExec(); | |||
| Status LaunchTreeExec(int32_t num_epochs); | |||
| // Get a row of data as dictionary of column name to the value. | |||
| Status GetNextAsMap(py::dict *output); | |||
| @@ -143,6 +144,10 @@ class DEPipeline { | |||
| Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, | |||
| std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseEpochCtrlOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| @@ -189,6 +194,8 @@ class DEPipeline { | |||
| Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status StopSend(); | |||
| Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| private: | |||
| @@ -159,7 +159,7 @@ void bindDEPipeline(py::module *m) { | |||
| [](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); }) | |||
| .def("SetBatchParameters", | |||
| [](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); }) | |||
| .def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); }) | |||
| .def("LaunchTreeExec", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.LaunchTreeExec(num_epochs)); }) | |||
| .def("GetNextAsMap", | |||
| [](DEPipeline &de) { | |||
| py::dict out; | |||
| @@ -188,6 +188,7 @@ void bindDEPipeline(py::module *m) { | |||
| .def("GetBatchSize", &DEPipeline::GetBatchSize) | |||
| .def("GetNumClasses", &DEPipeline::GetNumClasses) | |||
| .def("GetRepeatCount", &DEPipeline::GetRepeatCount) | |||
| .def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); }) | |||
| .def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) { | |||
| THROW_IF_ERROR(de.SaveDataset(file_names, file_type)); | |||
| return true; | |||
| @@ -999,7 +1000,8 @@ PYBIND11_MODULE(_c_dataengine, m) { | |||
| .value("BUILDVOCAB", OpName::kBuildVocab) | |||
| .value("CELEBA", OpName::kCelebA) | |||
| .value("TEXTFILE", OpName::kTextFile) | |||
| .value("CLUE", OpName::kClue); | |||
| .value("CLUE", OpName::kClue) | |||
| .value("EPOCHCTRL", OpName::kEpochCtrl); | |||
| (void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic()) | |||
| .value("DE_JIEBA_MIX", JiebaMode::kMix) | |||
| @@ -40,7 +40,9 @@ Status IteratorBase::GetNextAsMap(TensorMap *out_map) { | |||
| out_map->clear(); | |||
| TensorRow curr_row; | |||
| MS_LOG(INFO) << "get next as map start."; | |||
| RETURN_IF_NOT_OK(FetchNextTensorRow(&curr_row)); | |||
| MS_LOG(INFO) << "fetchNextTensor success."; | |||
| // Return empty map if there's no data | |||
| if (curr_row.empty()) { | |||
| @@ -105,7 +107,8 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) { | |||
| // Once eof is handled, always return empty row. Class must be destroyed and recreated if you | |||
| // want to iterate again. | |||
| if (eof_handled_) { | |||
| return Status::OK(); | |||
| std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."; | |||
| RETURN_STATUS_UNEXPECTED(err); | |||
| } | |||
| // Check if we need to get a new DataBuffer to iterate. | |||
| @@ -119,36 +122,22 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) { | |||
| // Since GetNextBuffer was used rather than GetNextInput(), it means we need to manually | |||
| // handle eoe and eof messages here. | |||
| // | |||
| // An eoe buffer means we have iterated fully to the end of the tree. | |||
| // An eoe buffer will be immediately followed by an eof buffer, which signals the shutdown of | |||
| // all operators. | |||
| // An eoe buffer means we have iterated an epoch. | |||
| // The next buffer in the pipeline might be an EOF or a databuffer for next epoch | |||
| if (curr_buffer_->eoe()) { | |||
| MS_LOG(DEBUG) << "End of data iteration. Fetch eof and then return empty row."; | |||
| // Before returning the last empty vector, fetch the eof buffer which should be the last | |||
| // buffer, and then free it. | |||
| RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_)); | |||
| if (!curr_buffer_->eof()) { | |||
| RETURN_STATUS_UNEXPECTED("Non-eof after getting eoe in iterator!"); | |||
| } | |||
| eof_handled_ = true; | |||
| curr_buffer_.reset(); // explicitly free the eof buffer | |||
| // Set tree to Finished state | |||
| root_->Tree()->SetFinished(); | |||
| MS_LOG(INFO) << "End of data iteration."; | |||
| curr_buffer_.reset(); // explicitly free the eoe buffer | |||
| return Status::OK(); | |||
| } | |||
| // An eof buffer means it is the end of execution and all operators are shutting down. | |||
| // Because there is no more data to return to the caller, this will change `eof_handled_` state and | |||
| // returns status unexpected error. | |||
| if (curr_buffer_->eof()) { | |||
| // An eof by itself, without being preceded by an eoe, is possible if a repeat operator | |||
| // exists below us in the stack. Repeat operator eats eoe's but eventually allows the | |||
| // flow of an eof up the pipeline by itself. | |||
| eof_handled_ = true; | |||
| curr_buffer_.reset(); // explicitly free the eof buffer | |||
| // Set tree to Finished state | |||
| root_->Tree()->SetFinished(); | |||
| return Status::OK(); | |||
| std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."; | |||
| RETURN_STATUS_UNEXPECTED(err); | |||
| } | |||
| } | |||
| @@ -208,20 +197,24 @@ Status ChildIterator::FetchNextTensorRow(TensorRow *out_row) { | |||
| // Once eof is handled, always return empty row. Class must be destroyed and recreated if you | |||
| // want to iterate again. | |||
| if (eof_handled_) { | |||
| return Status::OK(); | |||
| std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."; | |||
| RETURN_STATUS_UNEXPECTED(err); | |||
| } | |||
| // Check if we need to get a new DataBuffer to iterate. | |||
| if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) { | |||
| // GetNextInput() depends on current_op's EoeReceived. So, the EOE buffer might already be handled and | |||
| // this child iterator might not see EOE buffer. | |||
| RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_)); | |||
| // Unlike the DatasetIterator, this child iterator does not quit after eoe. | |||
| // Instead, if an eoe is picked up here, we simply return an empty vector and it's up to the | |||
| // If an eoe is picked up here, we simply return an empty vector and it's up to the | |||
| // caller to decide what it wants to do next. | |||
| if (curr_buffer_->eoe()) { | |||
| MS_LOG(DEBUG) << "Child iterator picked up EOE."; | |||
| end_epoch_ = true; | |||
| return Status::OK(); | |||
| } else { | |||
| end_epoch_ = false; | |||
| } | |||
| if (curr_buffer_->eof()) { | |||
| @@ -144,6 +144,9 @@ class ChildIterator : public IteratorBase { | |||
| // @return The string to column id mapping. | |||
| std::unordered_map<std::string, int32_t> GetColumnNameMap() const override; | |||
| // Return T/F if end of epoch | |||
| bool end_of_epoch() { return end_epoch_; } | |||
| private: | |||
| DatasetOp *current_op_; // The parent operator. We consume from it's children. | |||
| int32_t child_idx_; // The specific child this iterator will fetch from. | |||
| @@ -18,6 +18,7 @@ set(DATASET_ENGINE_DATASETOPS_SRC_FILES | |||
| shuffle_op.cc | |||
| zip_op.cc | |||
| concat_op.cc | |||
| epoch_ctrl_op.cc | |||
| cache_base_op.cc | |||
| cache_lookup_op.cc | |||
| cache_op.cc | |||
| @@ -17,11 +17,13 @@ | |||
| #include "minddata/dataset/engine/datasetops/build_vocab_op.h" | |||
| #include <algorithm> | |||
| #include <iomanip> | |||
| #include <limits> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include "minddata/dataset/core/config_manager.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -202,5 +204,29 @@ BuildVocabOp::Builder::Builder() | |||
| builder_num_workers_ = cfg->num_parallel_workers(); | |||
| builder_connector_size_ = cfg->op_connector_size(); | |||
| } | |||
| // A print method typically used for debugging | |||
| void BuildVocabOp::Print(std::ostream &out, bool show_all) const { | |||
| // Always show the id and name as first line regardless if this summary or detailed print | |||
| out << "(" << std::setw(2) << operator_id_ << ") <BuildVocabOp>:"; | |||
| if (!show_all) { | |||
| // Call the super class for displaying any common 1-liner info | |||
| ParallelOp::Print(out, show_all); | |||
| // Then show any custom derived-internal 1-liner info for this op | |||
| out << "\n"; | |||
| } else { | |||
| // Call the super class for displaying any common detailed info | |||
| ParallelOp::Print(out, show_all); | |||
| // Then show any custom derived-internal stuff | |||
| out << "\nCode is needed here to show more info about the op." | |||
| << "\n\n"; | |||
| } | |||
| } | |||
| // Pre-Visitor accept method for NodePass | |||
| Status BuildVocabOp::PreAccept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call the pre-visitation | |||
| return p->PreRunOnNode(shared_from_base<BuildVocabOp>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -131,6 +131,21 @@ class BuildVocabOp : public ParallelOp { | |||
| ~BuildVocabOp() = default; | |||
| /// \brief A print method typically used for debugging | |||
| /// \param[out] out The output stream to write output to | |||
| /// \param[in] show_all A bool to control if you want to show all info or just a summary | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| /// \brief Stream output operator overload | |||
| /// \notes This allows you to write the debug print info using stream operators | |||
| /// \param[out] out Reference to the output stream being overloaded | |||
| /// \param[in] vop - reference to the BuildVocabOp to display | |||
| /// \return - the output stream must be returned | |||
| friend std::ostream &operator<<(std::ostream &out, const BuildVocabOp &vop) { | |||
| vop.Print(out, false); | |||
| return out; | |||
| } | |||
| Status WorkerEntry(int32_t worker_id) override; | |||
| // collect the work product from each worker | |||
| @@ -152,6 +167,12 @@ class BuildVocabOp : public ParallelOp { | |||
| Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildVocabOp"); } | |||
| /// \brief Base-class override for NodePass pre-visit acceptor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status PreAccept(NodePass *p, bool *modified) override; | |||
| private: | |||
| const int32_t interval_; | |||
| bool special_first_; | |||
| @@ -96,7 +96,7 @@ Status CacheMergeOp::WorkerEntry(int32_t worker_id) { | |||
| RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id)); | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr))); | |||
| RETURN_IF_NOT_OK(EofReceived(worker_id)); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { | |||
| @@ -298,5 +298,19 @@ Status CacheMergeOp::EoeReceived(int32_t worker_id) { | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Base-class override for handling cases when an eof is received. | |||
| Status CacheMergeOp::EofReceived(int32_t worker_id) { | |||
| // If we are not in a repeated path, then the merge op gets a eof by itself, without first | |||
| // getting an eoe. However, the logic demands that all epochs close with an eoe first before eof. | |||
| // Thus, generate an eoe first, before flowing up the eof in the non-repeated case. Base class | |||
| // provides that for us. | |||
| if (!BitTest(op_ctrl_flags_, kDeOpRepeated)) { | |||
| MS_LOG(DEBUG) << "Cache merge sending eoe"; | |||
| RETURN_IF_NOT_OK(DatasetOp::EoeReceived(worker_id)); | |||
| } | |||
| MS_LOG(DEBUG) << "Cache merge sending eof"; | |||
| return DatasetOp::EofReceived(worker_id); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -176,6 +176,11 @@ class CacheMergeOp : public ParallelOp { | |||
| /// \return Status object | |||
| Status EoeReceived(int32_t worker_id) override; | |||
| /// \brief Base-class override for handling cases when an eof is received. | |||
| /// \param worker_id - The worker id | |||
| /// \return Status - The error code return | |||
| Status EofReceived(int32_t worker_id) override; | |||
| protected: | |||
| Status ComputeColMap() override; | |||
| @@ -26,6 +26,7 @@ | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| #include "minddata/dataset/engine/datasetops/device_queue_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | |||
| #include "minddata/dataset/engine/data_buffer.h" | |||
| #include "minddata/dataset/engine/db_connector.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| @@ -102,6 +103,15 @@ Status DatasetOp::InsertAsParent(std::shared_ptr<DatasetOp> to_add) { | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Removes child operator in this operator. | |||
| Status DatasetOp::RemoveChildren() { | |||
| for (const auto &child : child_) { | |||
| child->RemoveParent(this); | |||
| } | |||
| child_.clear(); | |||
| return Status::OK(); | |||
| } | |||
| // Adds a parent operator to this operator | |||
| void DatasetOp::AddParent(DatasetOp *parent) { parent_.push_back(parent); } | |||
| @@ -185,6 +195,12 @@ void DatasetOp::Parent(DatasetOp **parent, int32_t parent_index) const { | |||
| } | |||
| } | |||
| // Getter function to get all of our children. | |||
| std::vector<std::shared_ptr<DatasetOp>> DatasetOp::children() const { return child_; } | |||
| // Getter function to get all of our parents. | |||
| std::vector<DatasetOp *> DatasetOp::parents() const { return parent_; } | |||
| // Creates the connector within this operator | |||
| void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) { | |||
| MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers | |||
| @@ -76,6 +76,9 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||
| /// \return Status error code returned | |||
| Status Remove(); | |||
| // Removes child operator in this operator. | |||
| Status RemoveChildren(); | |||
| /// \brief Getter function to get a shared pointer to our child | |||
| /// \param[in] child_index An operator can have n children. Indicates which child to return. | |||
| /// \return The shared pointer to the child. If there are no children, it returns null regardless of the given index | |||
| @@ -86,6 +89,12 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||
| /// \param[in] parent_index An operator can have n parents. Indicates which parent to return. | |||
| void Parent(DatasetOp **parent, int32_t parent_index) const; | |||
| // Getter function to get all of our children. | |||
| std::vector<std::shared_ptr<DatasetOp>> children() const; | |||
| // Getter function to get all of our parents. | |||
| std::vector<DatasetOp *> parents() const; | |||
| // Inserts an operator as the parent of the current op. | |||
| // Inserted op will become the sole parent of the current op. | |||
| // The existing parent of the current op will be transferred to the inserted op. | |||
| @@ -25,19 +25,21 @@ | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/engine/perf/profiling.h" | |||
| #include "minddata/dataset/engine/perf/device_queue_tracing.h" | |||
| #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "minddata/dataset/util/task_manager.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size, | |||
| int32_t op_connector_size, int64_t num_batch) | |||
| int32_t op_connector_size, bool send_epoch_end) | |||
| : PipelineOp(op_connector_size), | |||
| channel_name_(channel_name), | |||
| device_type_(device_type), | |||
| device_id_(device_id), | |||
| prefetch_size_(prefetch_size), | |||
| num_batch_(num_batch) {} | |||
| send_epoch_end_(send_epoch_end), | |||
| stop_send_(false) {} | |||
| DeviceQueueOp::~DeviceQueueOp() {} | |||
| @@ -53,8 +55,7 @@ DeviceQueueOp::Builder::Builder(int32_t prefetch_size) | |||
| : builder_prefetch_size_(prefetch_size), | |||
| builder_device_id_(0), | |||
| builder_device_type_(DeviceType::CPU), | |||
| builder_channel_name_(""), | |||
| builder_num_batch_(0) { | |||
| builder_channel_name_("") { | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| builder_op_connector_size_ = cfg->op_connector_size(); | |||
| } | |||
| @@ -64,6 +65,18 @@ Status DeviceQueueOp::EoeReceived(int32_t worker_id) { | |||
| return Status::OK(); | |||
| } | |||
| Status DeviceQueueOp::CheckExceptions(const std::unique_ptr<DataBuffer> &buffer) const { | |||
| // this method checks if the buffer meets the conditions to be sent to TDT | |||
| if (buffer->NumRows() != 0) { | |||
| TensorRow row; | |||
| buffer->GetRow(0, &row); | |||
| for (const auto &item : row) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(item->type().IsNumeric(), "Cannot send tensor of string type to device."); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status DeviceQueueOp::operator()() { | |||
| TaskManager::FindMe()->Post(); | |||
| @@ -82,23 +95,10 @@ Status DeviceQueueOp::operator()() { | |||
| return Status::OK(); | |||
| } | |||
| Status DeviceQueueOp::CheckExceptions(const std::unique_ptr<DataBuffer> &buffer) const { | |||
| // this method checks if the buffer meets the conditions to be sent to TDT | |||
| if (buffer->NumRows() != 0) { | |||
| TensorRow row; | |||
| buffer->GetRow(0, &row); | |||
| for (const auto &item : row) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(item->type().IsNumeric(), "Cannot send tensor of string type to device."); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| #ifdef ENABLE_TDTQUE | |||
| Status DeviceQueueOp::SendDataToAscend() { | |||
| MS_LOG(INFO) << "Device queue, sending data to Ascend."; | |||
| int64_t total_batch = 0; | |||
| bool is_break_loop = false; | |||
| double batch_start_time, end_time; | |||
| int32_t batch_cost, tdt_cost; | |||
| int32_t connector_size = 0; | |||
| @@ -115,15 +115,20 @@ Status DeviceQueueOp::SendDataToAscend() { | |||
| std::unique_ptr<DataBuffer> current_buffer; | |||
| RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); | |||
| while (!current_buffer->eof() && !is_break_loop) { | |||
| while (!current_buffer->eoe() && !is_break_loop) { | |||
| while (!current_buffer->eof()) { | |||
| while (!current_buffer->eoe()) { | |||
| RETURN_IF_NOT_OK(CheckExceptions(current_buffer)); | |||
| TensorRow currRow; | |||
| for (int row_id = 0; row_id < current_buffer->NumRows() && !is_break_loop; row_id++) { | |||
| for (int row_id = 0; row_id < current_buffer->NumRows(); row_id++) { | |||
| RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &currRow)); | |||
| auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost); | |||
| if (status == TdtStatus::FAILED) { | |||
| return Status(StatusCode::kTDTPushFailure, "TDT Push Failed"); | |||
| if (stop_send_) { | |||
| MS_LOG(INFO) << "stop_send received"; | |||
| return Status::OK(); | |||
| } else { | |||
| return Status(StatusCode::kTDTPushFailure, "TDT Push Failed"); | |||
| } | |||
| } | |||
| if (isProfilingEnable) { | |||
| @@ -140,9 +145,6 @@ Status DeviceQueueOp::SendDataToAscend() { | |||
| profiling_node->Record(CONNECTOR_DEPTH, connector_capacity, total_batch + 1, connector_size); | |||
| } | |||
| total_batch++; | |||
| if (num_batch_ > 0 && total_batch == num_batch_) { | |||
| is_break_loop = true; | |||
| } | |||
| } | |||
| if (isProfilingEnable) { | |||
| connector_size = ChildOpConnectorSize(); | |||
| @@ -150,6 +152,19 @@ Status DeviceQueueOp::SendDataToAscend() { | |||
| } | |||
| RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); | |||
| } | |||
| if (current_buffer->eoe() && send_epoch_end_) { | |||
| TensorRow currRow; | |||
| auto status = | |||
| tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost, tdt::TDT_END_OF_SEQUENCE); | |||
| if (status == TdtStatus::FAILED) { | |||
| if (stop_send_) { | |||
| MS_LOG(INFO) << "stop_send received"; | |||
| return Status::OK(); | |||
| } else { | |||
| return Status(StatusCode::kTDTPushFailure, "TDT Push Failed"); | |||
| } | |||
| } | |||
| } | |||
| if (isProfilingEnable) { | |||
| connector_size = ChildOpConnectorSize(); | |||
| connector_capacity = ChildOpConnectorCapacity(); | |||
| @@ -158,7 +173,7 @@ Status DeviceQueueOp::SendDataToAscend() { | |||
| } | |||
| tree_->SetFinished(); | |||
| MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << "."; | |||
| MS_LOG(INFO) << "Device queue total batch is " << total_batch; | |||
| return Status::OK(); | |||
| } | |||
| @@ -196,9 +211,6 @@ Status DeviceQueueOp::SendDataToGPU() { | |||
| } | |||
| RETURN_IF_NOT_OK(RetryPushGPUData(data_size, curr_row, handle)); | |||
| total_batch++; | |||
| if (num_batch_ > 0 && total_batch == num_batch_) { | |||
| is_break_loop = true; | |||
| } | |||
| } | |||
| if (!TaskManager::FindMe()->Interrupted()) | |||
| RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); | |||
| @@ -211,12 +223,10 @@ Status DeviceQueueOp::SendDataToGPU() { | |||
| is_break_loop = true; | |||
| } | |||
| MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << "."; | |||
| MS_LOG(INFO) << "Device queue total batch is " << total_batch << "."; | |||
| GpuBufferMgr::GetInstance().Close(handle); | |||
| GpuBufferMgr::GetInstance().CloseConfirm(); | |||
| return Status::OK(); | |||
| } | |||
| @@ -240,8 +250,11 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector<size_t> &data_size, con | |||
| if (ret == BlockQueueStatus_T::ERROR_INPUT) { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "invalid input Data, please check it."); | |||
| } else { | |||
| MS_LOG(WARNING) << "Retry pushing data..."; | |||
| continue; | |||
| if (!stop_send_) { | |||
| MS_LOG(WARNING) << "Retry pushing data..."; | |||
| continue; | |||
| } | |||
| break; | |||
| } | |||
| } else { | |||
| break; | |||
| @@ -283,13 +296,11 @@ Status DeviceQueueOp::SendDataToCPU() { | |||
| MS_LOG(DEBUG) << "Feature size is " << curr_row[0]->SizeInBytes() << "."; | |||
| MS_LOG(DEBUG) << "Label size is " << curr_row[1]->SizeInBytes() << "."; | |||
| total_batch++; | |||
| if (num_batch_ > 0 && total_batch == num_batch_) { | |||
| break; | |||
| } | |||
| if (stop_send_) break; | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << "."; | |||
| MS_LOG(INFO) << "Device queue total batch is " << total_batch << "."; | |||
| return Status::OK(); | |||
| } | |||
| @@ -21,6 +21,7 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/pipeline_op.h" | |||
| #include "minddata/dataset/engine/datasetops/repeat_op.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| #ifdef ENABLE_TDTQUE | |||
| @@ -84,8 +85,8 @@ class DeviceQueueOp : public PipelineOp { | |||
| return *this; | |||
| } | |||
| Builder &SetNumBatch(int64_t num_batch) { | |||
| builder_num_batch_ = num_batch; | |||
| Builder &SetSendEpochEnd(bool send_epoch_end) { | |||
| builder_send_epoch_end_ = send_epoch_end; | |||
| return *this; | |||
| } | |||
| @@ -94,8 +95,9 @@ class DeviceQueueOp : public PipelineOp { | |||
| // to call this Build() method. It will instantiate the DeviceQueueOp | |||
| // and return it to caller as a shared pointer. | |||
| Status Build(std::shared_ptr<DeviceQueueOp> *ptr) { | |||
| *ptr = std::make_shared<DeviceQueueOp>(builder_channel_name_, builder_device_type_, builder_device_id_, | |||
| builder_prefetch_size_, builder_op_connector_size_, builder_num_batch_); | |||
| *ptr = | |||
| std::make_shared<DeviceQueueOp>(builder_channel_name_, builder_device_type_, builder_device_id_, | |||
| builder_prefetch_size_, builder_op_connector_size_, builder_send_epoch_end_); | |||
| return Status::OK(); | |||
| } | |||
| @@ -104,14 +106,14 @@ class DeviceQueueOp : public PipelineOp { | |||
| int32_t builder_device_id_; | |||
| DeviceType builder_device_type_; | |||
| std::string builder_channel_name_; | |||
| int64_t builder_num_batch_; | |||
| int32_t builder_op_connector_size_; | |||
| bool builder_send_epoch_end_; | |||
| }; | |||
| // Name: constructor | |||
| // Description | |||
| DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size, | |||
| int32_t op_connector_size, int64_t num_batch); | |||
| int32_t op_connector_size, bool send_epoch_end); | |||
| // Name: destructor | |||
| // Description | |||
| @@ -121,6 +123,8 @@ class DeviceQueueOp : public PipelineOp { | |||
| const int32_t get_prefetch_size() { return prefetch_size_; } | |||
| void StopSend() { stop_send_ = true; } | |||
| // Name: Print() | |||
| // Description: A function that prints info about the node | |||
| void Print(std::ostream &out, // In: The output stream to print to | |||
| @@ -149,6 +153,7 @@ class DeviceQueueOp : public PipelineOp { | |||
| // Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp | |||
| Status CheckExceptions(const std::unique_ptr<DataBuffer> &buffer) const; | |||
| private: | |||
| #ifdef ENABLE_TDTQUE | |||
| Status SendDataToAscend(); | |||
| #endif | |||
| @@ -164,7 +169,8 @@ class DeviceQueueOp : public PipelineOp { | |||
| DeviceType device_type_; | |||
| const int32_t device_id_; | |||
| const int32_t prefetch_size_; | |||
| const int64_t num_batch_; | |||
| const bool send_epoch_end_; | |||
| bool stop_send_; | |||
| #ifdef ENABLE_TDTQUE | |||
| std::shared_ptr<TdtPlugin> tdtInstancePtr; | |||
| @@ -0,0 +1,130 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <iomanip> | |||
| #include <iostream> | |||
| #include <utility> | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | |||
| #include "minddata/dataset/engine/data_buffer.h" | |||
| #include "minddata/dataset/engine/db_connector.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Builder finalizer: runs the sanity check, then allocates the EpochCtrlOp and stores it in *ptr. | |||
| Status EpochCtrlOp::Builder::Build(std::shared_ptr<EpochCtrlOp> *ptr) { | |||
| RETURN_IF_NOT_OK(SanityCheck()); | |||
| // Built only after the check passed, so *ptr is left untouched on failure. | |||
| auto new_op = std::make_shared<EpochCtrlOp>(build_max_repeats_); | |||
| *ptr = std::move(new_op); | |||
| return Status::OK(); | |||
| } | |||
| // Constructor: forwards num_epoch to the RepeatOp base constructor and logs a greeting at INFO level. | |||
| EpochCtrlOp::EpochCtrlOp(int32_t num_epoch) | |||
| : RepeatOp(num_epoch) { | |||
| MS_LOG(INFO) << "Welcome to Epoch Ctrl Op."; | |||
| } | |||
| // Destructor: nothing to release beyond what the members and base class clean up themselves. | |||
| EpochCtrlOp::~EpochCtrlOp() = default; | |||
| // Debug dump of this operator: a one-line summary, or a detailed multi-line report when show_all is set. | |||
| void EpochCtrlOp::Print(std::ostream &out, bool show_all) const { | |||
| // The id/name header comes first in both modes. | |||
| out << "(" << std::setw(2) << operator_id_ << ") <EpochCtrlOp>:"; | |||
| // Both modes then let the base class emit its common info before the op-specific part. | |||
| PipelineOp::Print(out, show_all); | |||
| if (!show_all) { | |||
| // Summary mode: only the epoch limit is appended. | |||
| out << " [epochs: " << max_repeats_ << "]\n"; | |||
| return; | |||
| } | |||
| // Detailed mode: epoch counters followed by the ids of the tracked eoe operators. | |||
| out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << max_repeats_ | |||
| << "\nLeaf Nodes in execution path:"; | |||
| if (eoe_ops_.empty()) { | |||
| out << " None."; | |||
| } else { | |||
| for (const auto &tracked_op : eoe_ops_) { | |||
| out << "\n Operator: " << tracked_op->id(); | |||
| } | |||
| } | |||
| out << "\n\n"; | |||
| } | |||
| // Fetches the next buffer from the first child and hands it to the caller, EOE buffers included. | |||
| Status EpochCtrlOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) { | |||
| if (child_.empty()) { | |||
| RETURN_STATUS_UNEXPECTED("EpochCtrlOp can't be the leaf node."); | |||
| } | |||
| // The child is always asked with retry_if_eoe == false (the caller's value is not consulted), | |||
| // because EpochCtrlOp does not eat EOE. | |||
| std::unique_ptr<DataBuffer> fetched; | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&fetched, worker_id, false)); | |||
| // An EOE first triggers the end-of-epoch processing and is then forwarded like any other buffer. | |||
| // Buffers carrying data or EOF are forwarded untouched. | |||
| const bool is_end_of_epoch = fetched->eoe(); | |||
| if (is_end_of_epoch) { | |||
| RETURN_IF_NOT_OK(EoeReceived(worker_id)); | |||
| } | |||
| *p_buffer = std::move(fetched); | |||
| return Status::OK(); | |||
| } | |||
| // End-of-epoch handler: bumps the epoch counter, flags the final epoch on the tracked ops, | |||
| // and resets those ops unless the maximum epoch count has been reached. | |||
| Status EpochCtrlOp::EoeReceived(int32_t worker_id) { | |||
| ++repeat_count_; | |||
| MS_LOG(DEBUG) << "Epoch Control operator received end of epoch. Epoch count is now: " << repeat_count_ | |||
| << ". Repeated: " << BitTest(op_ctrl_flags_, kDeOpRepeated) << ". Max epochs: " << max_repeats_; | |||
| // With a finite epoch count and exactly one epoch left, tell every tracked op that its next | |||
| // pass is the last one, so it quits at the end of that epoch rather than looping again. | |||
| const bool entering_last_epoch = (max_repeats_ != kInfiniteRepeat) && (repeat_count_ == (max_repeats_ - 1)); | |||
| if (entering_last_epoch) { | |||
| for (auto &tracked_op : eoe_ops_) { | |||
| MS_LOG(DEBUG) << "EpochCtrl setting last repeat for eoe_op: " << tracked_op->id(); | |||
| tracked_op->set_control_flag(kDeOpLastRepeat); | |||
| } | |||
| } | |||
| // This will allow GetNextInput in DatasetOp class to pass EOE buffer instead of eating it. | |||
| state_ = OpState::kDeOpIdle; | |||
| // Nothing to reset once the epoch count equals the maximum. | |||
| if (repeat_count_ == max_repeats_) { | |||
| return Status::OK(); | |||
| } | |||
| for (auto &tracked_op : eoe_ops_) { | |||
| MS_LOG(DEBUG) << "Epoch Control driving reset to op: " << tracked_op->id(); | |||
| RETURN_IF_NOT_OK(tracked_op->Reset()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // NodePass pre-visit hook: presents this node to the pass under its concrete EpochCtrlOp type. | |||
| Status EpochCtrlOp::PreAccept(NodePass *p, bool *modified) { | |||
| // Obtain a shared pointer typed as EpochCtrlOp so the EpochCtrlOp overload of PreRunOnNode is chosen. | |||
| auto self = shared_from_base<EpochCtrlOp>(); | |||
| return p->PreRunOnNode(std::move(self), modified); | |||
| } | |||
| // NodePass visit hook: presents this node to the pass under its concrete EpochCtrlOp type. | |||
| Status EpochCtrlOp::Accept(NodePass *p, bool *modified) { | |||
| // Obtain a shared pointer typed as EpochCtrlOp so the EpochCtrlOp overload of RunOnNode is chosen. | |||
| auto self = shared_from_base<EpochCtrlOp>(); | |||
| return p->RunOnNode(std::move(self), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,82 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_DATASETOPS_EPOCH_CTRL_OP_H_ | |||
| #define DATASET_ENGINE_DATASETOPS_EPOCH_CTRL_OP_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/repeat_op.h" | |||
| #include "minddata/dataset/engine/datasetops/pipeline_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class EpochCtrlOp : public RepeatOp { | |||
| public: | |||
| class Builder : public RepeatOp::Builder { | |||
| public: | |||
| // Builder constructor. Creates the builder object. | |||
| // @note No default args | |||
| // @param count - The number of repeats (epochs) to do; forwarded to RepeatOp::Builder | |||
| // @return This is a constructor. | |||
| explicit Builder(int32_t count) : RepeatOp::Builder(count) {} | |||
| // Default destructor | |||
| ~Builder() = default; | |||
| // The builder "build" method runs the builder sanity check and then creates the final object. | |||
| // @return Status - The error code return; on success the output argument holds a shared_ptr to the new EpochCtrlOp | |||
| Status Build(std::shared_ptr<EpochCtrlOp> *); | |||
| }; | |||
| // Constructor. @param num_epoch - The epoch count, passed on to the RepeatOp base constructor | |||
| explicit EpochCtrlOp(int32_t num_epoch); | |||
| // Destructor | |||
| ~EpochCtrlOp(); | |||
| // A print method typically used for debugging | |||
| // @param out - The output stream to write output to | |||
| // @param show_all - A bool to control if you want to show all info or just a summary | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| // This function returns the buffer that is at the top of our output connector. The caller is | |||
| // typically our parent node, when the parent is asking us to provide the next buffer of data. | |||
| // Since EpochCtrlOp is derived from RepeatOp which is an inlined op, getting a buffer from us | |||
| // will simply bounce you to get a buffer from our child. | |||
| // Epoch Control Op does not eat the EOE, it will pass the EOE to the next op; the child is always fetched with retry_if_eoe = false. | |||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override; | |||
| // Base-class override for handling cases when an eoe is received: increments the epoch count and resets the tracked eoe ops unless the max epoch count was reached. | |||
| // @param worker_id - The worker id | |||
| Status EoeReceived(int32_t worker_id) override; | |||
| /// \brief Base-class override for NodePass pre-visit acceptor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status PreAccept(NodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for NodePass visitor acceptor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_DATASETOPS_EPOCH_CTRL_OP_H_ | |||
| @@ -132,6 +132,7 @@ Status RepeatOp::EoeReceived(int32_t worker_id) { | |||
| // Invoke a reset against the eoe nodes only. | |||
| for (auto &eoe_op : eoe_ops_) { | |||
| MS_LOG(DEBUG) << "Repeat operator sending reset to operator: " << eoe_op->id(); | |||
| RETURN_IF_NOT_OK(eoe_op->Reset()); | |||
| } | |||
| @@ -167,8 +168,9 @@ int32_t RepeatOp::num_consumers() const { | |||
| Status RepeatOp::Reset() { | |||
| // If there's nested repeats, an ascendant repeat may have ourself listed as an eoe op. | |||
| // In that case, we now have to bounce the reset down to our own eoe ops. | |||
| MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ << ") reset."; | |||
| MS_LOG(DEBUG) << "Repeat operator " << operator_id_ << " got reset."; | |||
| for (auto &eoe_op : eoe_ops_) { | |||
| MS_LOG(DEBUG) << "Nested repeat operator bouncing a reset to operator: " << eoe_op->id(); | |||
| RETURN_IF_NOT_OK(eoe_op->Reset()); | |||
| } | |||
| state_ = OpState::kDeOpRunning; | |||
| @@ -46,7 +46,7 @@ class RepeatOp : public PipelineOp { | |||
| // @return shared_ptr to the new RepeatOp object | |||
| Status Build(std::shared_ptr<RepeatOp> *); | |||
| private: | |||
| protected: | |||
| int32_t build_max_repeats_; | |||
| Status SanityCheck() const; | |||
| @@ -131,11 +131,11 @@ class RepeatOp : public PipelineOp { | |||
| // @return Name of the current Op | |||
| std::string Name() const override { return "RepeatOp"; } | |||
| /// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes | |||
| /// \param[in] eoe_op The input leaf/eoe operator to add to the list | |||
| // \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes | |||
| // \param[in] eoe_op The input leaf/eoe operator to add to the list | |||
| void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); } | |||
| private: | |||
| protected: | |||
| int32_t max_repeats_; // The number of repeats that the user requested | |||
| int32_t repeat_count_; // A counter for the current number of executed repeats | |||
| std::vector<std::shared_ptr<DatasetOp>> eoe_ops_; // List of operators that can generate EOE underneath this repeat. | |||
| @@ -132,8 +132,9 @@ Status ZipOp::prepare(TensorQTable *const table) { | |||
| if (eof_) { | |||
| return Status::OK(); | |||
| } | |||
| // One of our child iterators encounter EOE. Returns and proceed with draining phase. | |||
| if (new_row.empty()) { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp prepare phase got empty row!"); | |||
| return Status::OK(); | |||
| } | |||
| // Pack this first row into our tensor table | |||
| @@ -23,6 +23,7 @@ | |||
| #include "minddata/dataset/engine/opt/pre/removal_pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" | |||
| #include "minddata/dataset/engine/opt/post/repeat_pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/injection_pass.h" | |||
| #include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" | |||
| #include "minddata/dataset/engine/perf/profiling.h" | |||
| #include "minddata/dataset/engine/perf/monitor.h" | |||
| @@ -50,11 +51,11 @@ Status ExecutionTree::AssociateNode(const std::shared_ptr<DatasetOp> &op) { | |||
| if (op->tree_ == this) { | |||
| return Status::OK(); | |||
| } | |||
| if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding) { | |||
| if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding && tree_state_ != kDeTStatePrepare) { | |||
| std::string err_msg = | |||
| "Invalid tree state for adding a node. Current state: " + std::to_string(static_cast<int>(tree_state_)) + | |||
| " Expected states: " + std::to_string(static_cast<int>(kDeTStateInit)) + " or " + | |||
| std::to_string(static_cast<int>(kDeTStateBuilding)); | |||
| std::to_string(static_cast<int>(kDeTStateBuilding)) + " or " + std::to_string(static_cast<int>(kDeTStatePrepare)); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| @@ -200,7 +201,9 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui | |||
| // For example, repeatOp inlining | |||
| // | |||
| // @return Status - The error code return | |||
| Status ExecutionTree::Prepare() { | |||
| Status ExecutionTree::Prepare(int32_t num_epochs) { | |||
| num_epochs_ = num_epochs; | |||
| // Pre optimization compulsory transformation | |||
| RETURN_IF_NOT_OK(this->PrepareTreePreAction()); | |||
| @@ -222,6 +225,7 @@ Status ExecutionTree::PrepareTreePreAction() { | |||
| std::vector<std::unique_ptr<Pass>> pre_actions; | |||
| // Construct pre actions | |||
| MS_LOG(INFO) << "Running pre pass loops."; | |||
| pre_actions.push_back(std::make_unique<InjectionPass>()); | |||
| pre_actions.push_back(std::make_unique<RemovalPass>()); | |||
| pre_actions.push_back(std::make_unique<CacheTransformPass>()); | |||
| // Apply pre action passes | |||
| @@ -278,6 +282,11 @@ Status ExecutionTree::PrepareDeprecated() { | |||
| " Expected state: " + std::to_string(static_cast<int>(kDeTStatePrepare)); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| if (root_ == nullptr) { | |||
| RETURN_STATUS_UNEXPECTED("Please assign one operator as the root of this tree."); | |||
| } | |||
| // Start the recursive prepare | |||
| RETURN_IF_NOT_OK(this->PrepareNode(root_)); | |||
| tree_state_ = kDeTStateReady; | |||
| @@ -176,7 +176,7 @@ class ExecutionTree { | |||
| // For example, repeatOp inlining | |||
| // | |||
| // @return Status - The error code return | |||
| Status Prepare(); | |||
| Status Prepare(int num_epochs = -1); | |||
| // Compulsory transformation/action pre optimization. | |||
| // @return Status - The error code return | |||
| @@ -193,6 +193,7 @@ class ExecutionTree { | |||
| // The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively | |||
| // walk the tree to perform modifications to the tree or specific nodes within the tree to get | |||
| // it ready for execution. | |||
| // @note Takes no parameters; the total number of epochs is passed to Prepare() and is readable through num_epochs(). | |||
| // @return Status - The error code return | |||
| Status PrepareDeprecated(); | |||
| @@ -231,6 +232,10 @@ class ExecutionTree { | |||
| // Optional optimizations status | |||
| bool OptimizationEnabled() const { return optimize_; } | |||
| // Getter function to get the total number of epochs to be run on this tree. | |||
| // Const-qualified, like OptimizationEnabled(), because it only reads num_epochs_. | |||
| // @return total number of epochs | |||
| int32_t num_epochs() const { return num_epochs_; } | |||
| private: | |||
| // A helper functions for doing the recursive printing | |||
| // @param dataset_op - The dataset op to print | |||
| @@ -245,6 +250,7 @@ class ExecutionTree { | |||
| int32_t id_count_; // Counter for generating operator id's | |||
| uint32_t prepare_flags_; // Flags used during tree prepare | |||
| TreeState tree_state_; // Tracking the current tree state | |||
| int32_t num_epochs_; // Total number of epochs to run for this tree | |||
| std::unique_ptr<Monitor> perf_monitor_; // Performance Monitor | |||
| std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager | |||
| bool optimize_; // Flag to enable optional optimizations | |||
| @@ -5,6 +5,7 @@ add_library(engine-opt OBJECT | |||
| post/repeat_pass.cc | |||
| pre/cache_pass.cc | |||
| pre/cache_transform_pass.cc | |||
| pre/injection_pass.cc | |||
| pre/removal_nodes.cc | |||
| pre/removal_pass.cc | |||
| optional/tensor_op_fusion_pass.cc | |||
| @@ -16,11 +16,13 @@ | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/engine/datasetops/batch_op.h" | |||
| #include "minddata/dataset/engine/datasetops/build_vocab_op.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_merge_op.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_lookup_op.h" | |||
| #include "minddata/dataset/engine/datasetops/dataset_op.h" | |||
| #include "minddata/dataset/engine/datasetops/device_queue_op.h" | |||
| #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | |||
| #include "minddata/dataset/engine/datasetops/map_op.h" | |||
| #include "minddata/dataset/engine/datasetops/project_op.h" | |||
| #include "minddata/dataset/engine/datasetops/rename_op.h" | |||
| @@ -230,6 +232,11 @@ Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) { | |||
| // No EpochCtrlOp-specific handling by default: upcast and defer to the generic DatasetOp visitor. | |||
| std::shared_ptr<DatasetOp> base_node = std::static_pointer_cast<DatasetOp>(node); | |||
| return RunOnNode(base_node, modified); | |||
| } | |||
| Status NodePass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| @@ -244,5 +251,15 @@ Status NodePass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified | |||
| // Fallback to base class visitor by default | |||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) { | |||
| // No EpochCtrlOp-specific handling by default: upcast and defer to the generic DatasetOp pre-visitor. | |||
| std::shared_ptr<DatasetOp> base_node = std::static_pointer_cast<DatasetOp>(node); | |||
| return PreRunOnNode(base_node, modified); | |||
| } | |||
| Status NodePass::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) { | |||
| // No BuildVocabOp-specific handling by default: upcast and defer to the generic DatasetOp pre-visitor. | |||
| std::shared_ptr<DatasetOp> base_node = std::static_pointer_cast<DatasetOp>(node); | |||
| return PreRunOnNode(base_node, modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -77,6 +77,10 @@ class CacheMergeOp; | |||
| class CacheLookupOp; | |||
| class EpochCtrlOp; | |||
| class BuildVocabOp; | |||
| // The base class Pass is the basic unit of tree transformation. | |||
| // The actual implementation of the passes will be derived from here. | |||
| class Pass : public std::enable_shared_from_this<Pass> { | |||
| @@ -190,12 +194,18 @@ class NodePass : public Pass { | |||
| virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified); | |||
| private: | |||
| // Helper function to perform DFS visit | |||
| Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified); | |||
| @@ -20,6 +20,7 @@ | |||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_lookup_op.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_merge_op.h" | |||
| #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -28,6 +29,9 @@ RepeatPass::RepeatPass() : is_repeated_(false), nested_repeats_(0), is_merge_(fa | |||
| // Identifies the subtree below this node as being in a repeated path of the tree. | |||
| Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||
| // Create a new stack for eoe operators and push onto our stack of stacks. | |||
| std::unique_ptr<eoe_op_stack> new_stack = std::make_unique<eoe_op_stack>(); | |||
| eoe_op_stacks_.push(std::move(new_stack)); | |||
| // If we are already repeated, then this is a nested repeat. | |||
| if (is_repeated_) { | |||
| nested_repeats_++; | |||
| @@ -36,6 +40,18 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) | |||
| return Status::OK(); | |||
| } | |||
| // Pre-visit of the epoch control op: marks everything below it as being on a repeated path. | |||
| Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) { | |||
| // EpochCtrlOp derives from RepeatOp, so the setup mirrors the RepeatOp pre-visit. It is simpler, | |||
| // though: epoch control can only exist as the root node, so none of the nested-repeat | |||
| // bookkeeping is required here. | |||
| // Open a fresh save area for the eoe operators found below this node. | |||
| eoe_op_stacks_.push(std::make_unique<eoe_op_stack>()); | |||
| is_repeated_ = true; | |||
| return Status::OK(); | |||
| } | |||
| // Identifies the subtree below this node as being in a cache merge path | |||
| Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) { | |||
| // Turn on the flag that we're under a merge op | |||
| @@ -47,13 +63,24 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modifi | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||
| // Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking | |||
| std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack(); | |||
| while (leaf_op != nullptr) { | |||
| node->AddToEoeList(leaf_op); | |||
| leaf_op = PopFromEOEOpStack(); | |||
| } | |||
| // At this point, we are done with the save area stack. It's a unique pointer to an empty stack | |||
| // at this time, so we can pop it to get rid of it. | |||
| eoe_op_stack *current_stack = eoe_op_stacks_.top().get(); | |||
| if (!current_stack->empty()) { | |||
| RETURN_STATUS_UNEXPECTED("The eoe op stack should be empty right now!"); | |||
| } | |||
| eoe_op_stacks_.pop(); | |||
| // We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up | |||
| // and add it to the list of eoe/leaf ops for the repeat, removing it from the save area. | |||
| // and add it to the list of eoe/leaf ops for the repeat. It is important that the op is removed | |||
| // from the save area, because the merge op above us may also take action on it later for a different | |||
| // case when there is no repeat in the merge leg. | |||
| if (is_merge_ && cache_lookup_) { | |||
| cache_lookup_->set_control_flag(DatasetOp::kDeOpRepeated); | |||
| node->AddToEoeList(std::move(cache_lookup_)); | |||
| @@ -65,16 +92,29 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||
| node->set_control_flag(DatasetOp::kDeOpRepeated); | |||
| AddToEOEOpStack(node); | |||
| nested_repeats_--; | |||
| } | |||
| // If we are not nested, or we were the top-most repeat, now we clear the flag | |||
| if (nested_repeats_ == 0) { | |||
| } else { | |||
| // If we are not nested, or we were the top-most repeat, now we clear the flag | |||
| if (nested_repeats_ != 0) { | |||
| RETURN_STATUS_UNEXPECTED("Nested repeat counter cannot be negative!"); | |||
| } | |||
| is_repeated_ = false; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Post-visit of the epoch control op: attaches every saved eoe node to it and ends the repeated region. | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) { | |||
| // Drain the current save area; PopFromEOEOpStack yields nullptr once it is empty. | |||
| for (std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack(); leaf_op != nullptr; leaf_op = PopFromEOEOpStack()) { | |||
| node->AddToEoeList(leaf_op); | |||
| } | |||
| is_repeated_ = false; | |||
| return Status::OK(); | |||
| } | |||
| // CacheOp removes previous leaf ops and replaces them with itself | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| if (is_repeated_) { | |||
| @@ -118,9 +158,16 @@ Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { | |||
| // Turns off the tracking for operations under merge op | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) { | |||
| // Setting the flag is needed since we didn't call the base class DatasetOp version | |||
| if (is_repeated_) node->set_control_flag(DatasetOp::kDeOpRepeated); | |||
| if (is_repeated_) { | |||
| node->set_control_flag(DatasetOp::kDeOpRepeated); | |||
| // If there was not any repeat in the merge cache miss leg, then the cache_lookup | |||
| // would not have been consumed yet. In that case, we need to assign it to the upper repeat eoe stack | |||
| if (cache_lookup_) { | |||
| AddToEOEOpStack(std::move(cache_lookup_)); | |||
| } | |||
| } | |||
| cache_lookup_.reset(); // If we are not repeated then the saved lookup is no longer needed or used | |||
| is_merge_ = false; | |||
| cache_lookup_.reset(); // If a repeat op did not consume this then it's no longer needed | |||
| return Status::OK(); | |||
| } | |||
| @@ -135,25 +182,32 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified | |||
| // In this case, we naturally are a repeating leaf op so add the required setup for leafs under repeat here. | |||
| if (is_repeated_) { | |||
| node->set_control_flag(DatasetOp::kDeOpRepeated); | |||
| AddToEOEOpStack(node); | |||
| } else { | |||
| // save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we | |||
| // may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself | |||
| // into the pass so that the decision can be made during the processing of the cache miss leg of the merge. | |||
| cache_lookup_ = std::static_pointer_cast<DatasetOp>(node); | |||
| // Delay the assignment of this leaf to the eoe stack and allow the merge op processing to handle that. | |||
| } | |||
| // save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we | |||
| // may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself | |||
| // into the pass so that the decision can be made during the processing of the cache miss leg of the merge. | |||
| // Further, if there's a repeat above the merge but no repeat in the cache miss leg, then the merge op will | |||
| // add the lookup to the eoe stack | |||
| cache_lookup_ = std::static_pointer_cast<DatasetOp>(node); | |||
| return Status::OK(); | |||
| } | |||
| // Adds an operator to the eoe operator stack save area | |||
| void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { eoe_stack_.push(dataset_op); } | |||
| void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { | |||
| // The operator goes into the save area currently on top of the stack of stacks. | |||
| eoe_op_stacks_.top()->push(dataset_op); | |||
| } | |||
| // Pops an operator from the eoe operator stack save area | |||
| std::shared_ptr<DatasetOp> RepeatPass::PopFromEOEOpStack() { | |||
| std::shared_ptr<DatasetOp> top_op = nullptr; | |||
| if (!eoe_stack_.empty()) { | |||
| top_op = eoe_stack_.top(); | |||
| eoe_stack_.pop(); | |||
| eoe_op_stack *current_stack = eoe_op_stacks_.top().get(); | |||
| if (current_stack != nullptr && !current_stack->empty()) { | |||
| top_op = current_stack->top(); | |||
| current_stack->pop(); | |||
| } | |||
| return top_op; | |||
| } | |||
| @@ -30,6 +30,8 @@ namespace dataset { | |||
| /// to the eoe-producing (typically leaf) nodes underneath it. | |||
| class RepeatPass : public NodePass { | |||
| public: | |||
| using eoe_op_stack = std::stack<std::shared_ptr<DatasetOp>>; | |||
| /// \brief Constructor | |||
| RepeatPass(); | |||
| @@ -39,6 +41,12 @@ class RepeatPass : public NodePass { | |||
| /// \return Status The error code return | |||
| Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; | |||
| /// \brief Identifies the subtree below this node as being in a repeated path of the tree. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override; | |||
| /// \brief Identifies the subtree below this node as being in a cache merge path | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| @@ -51,6 +59,12 @@ class RepeatPass : public NodePass { | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; | |||
| /// \brief Hooks up any identified eoe nodes under this repeat. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override; | |||
| /// \brief CacheOp removes previous leaf ops and replaces them with itself | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| @@ -86,11 +100,11 @@ class RepeatPass : public NodePass { | |||
| /// \return shared_ptr to the popped operator | |||
| std::shared_ptr<DatasetOp> PopFromEOEOpStack(); | |||
| bool is_repeated_; // T/F if we are processing under a repeat | |||
| bool is_merge_; // T/F if we are processing under a cache merge op | |||
| int32_t nested_repeats_; // A counter for nested repeats | |||
| std::stack<std::shared_ptr<DatasetOp>> eoe_stack_; // A save area for leaf/eoe ops | |||
| std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op | |||
| bool is_repeated_; // T/F if we are processing under a repeat | |||
| bool is_merge_; // T/F if we are processing under a cache merge op | |||
| int32_t nested_repeats_; // A counter for nested repeats | |||
| std::stack<std::unique_ptr<eoe_op_stack>> eoe_op_stacks_; // A save area for leaf/eoe ops (with nesting) | |||
| std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,82 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "minddata/dataset/engine/opt/pre/injection_pass.h" | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | |||
| #include "minddata/dataset/engine/datasetops/device_queue_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor: remembers the outer InjectionPass so the finder can record its findings on it. | |||
| InjectionPass::InjectionFinder::InjectionFinder(InjectionPass *injection_pass) | |||
| : injection_pass_(injection_pass) {} | |||
| // Finder hook for BuildVocabOp: its presence switches on the epoch control bypass of the outer pass. | |||
| Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) { | |||
| // Without the outer pass object there is nowhere to record the finding. | |||
| if (injection_pass_ == nullptr) { | |||
| RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!"); | |||
| } | |||
| injection_pass_->epoch_ctrl_bypass_ = true; | |||
| return Status::OK(); | |||
| } | |||
| // Temporary code: a CacheOp in the tree turns off epoch control injection. | |||
| // Remove this code in cache op phase 2 | |||
| Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
|   // Without the outer pass object there is nowhere to record the finding. | |||
|   if (injection_pass_ == nullptr) { | |||
|     RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!"); | |||
|   } | |||
|   injection_pass_->epoch_ctrl_bypass_ = true; | |||
|   return Status::OK(); | |||
| } | |||
| // constructor | |||
| InjectionPass::InjectionPass() : epoch_ctrl_bypass_(false) {} | |||
| // Runs an injection pass to inject in operators needed at the pre pass stage. | |||
| // \param[inout] tree The execution tree to operate on | |||
| // \param[inout] modified Handed to the finder pass; this function does not write it itself | |||
| // \return Status The error code return; any failure from the finder or from the injection is propagated | |||
| Status InjectionPass::RunOnTree(ExecutionTree *tree, bool *modified) { | |||
|   MS_LOG(INFO) << "Pre pass: Injection pass started."; | |||
|   // First, run the finder to perform any injection info before we can go ahead to drive the op injection work. | |||
|   // The finder can make updates to the InjectionPass object. | |||
|   InjectionPass::InjectionFinder finder(this); | |||
|   // Do not continue on a failed finder run: epoch_ctrl_bypass_ would not be trustworthy. | |||
|   RETURN_IF_NOT_OK(finder.Run(tree, modified)); | |||
|   // The first injection logic is to check if we should inject the epoch control op as the root node. | |||
|   // Do not inject the op if the number of epochs is 1. | |||
|   int32_t num_epochs = tree->num_epochs(); | |||
|   if (num_epochs != 1 && !epoch_ctrl_bypass_) { | |||
|     std::shared_ptr<EpochCtrlOp> epoch_ctrl_op; | |||
|     RETURN_IF_NOT_OK(EpochCtrlOp::Builder(num_epochs).Build(&epoch_ctrl_op)); | |||
|     RETURN_IF_NOT_OK(tree->AssociateNode(epoch_ctrl_op)); | |||
|     std::shared_ptr<DatasetOp> node = tree->root(); | |||
|     if (std::dynamic_pointer_cast<DeviceQueueOp>(node) == nullptr) { | |||
|       // Root is not a device queue: the epoch control op becomes the parent of the root. | |||
|       RETURN_IF_NOT_OK(node->InsertAsParent(epoch_ctrl_op)); | |||
|     } else { | |||
|       // Root is a device queue: inject the epoch control op as the parent of its first child instead. | |||
|       RETURN_IF_NOT_OK(node->child(0)->InsertAsParent(epoch_ctrl_op)); | |||
|     } | |||
|   } | |||
|   MS_LOG(INFO) << "Pre pass: Injection pass complete."; | |||
|   return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,75 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_ | |||
| #define DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class DatasetOp; | |||
| /// \class InjectionPass injection_pass.h | |||
| /// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api | |||
| /// parsing. | |||
| class InjectionPass : public TreePass { | |||
| /// \class InjectionFinder | |||
| /// \brief This is a nested node pass class whose job is to parse the tree and perform any identification logic for | |||
| /// operators that need to be injected. It is run first by the main injection pass to find out what operators | |||
| /// it may need to inject. | |||
| class InjectionFinder : public NodePass { | |||
| public: | |||
| /// \brief Constructor | |||
| /// \param[in] injection_pass The outer InjectionPass object that the finder records its findings on | |||
| explicit InjectionFinder(InjectionPass *injection_pass); | |||
| /// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) override; | |||
| /// \brief Temporary code to prevent the injection of epoch control when cache op is present. | |||
| /// Remove this code in cache op phase 2 | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| private: | |||
| /// Non-owning pointer to the outer pass; the finder sets the outer pass's epoch_ctrl_bypass_ flag through it. | |||
| InjectionPass *injection_pass_; | |||
| }; | |||
| public: | |||
| /// \brief Constructor | |||
| InjectionPass(); | |||
| /// \brief Runs an injection pass to inject in operators needed at the pre pass stage | |||
| /// \param[inout] tree The tree to operate on. | |||
| /// \param[inout] modified Indicates whether the tree was modified. | |||
| /// \return Status The error code return | |||
| Status RunOnTree(ExecutionTree *tree, bool *modified) override; | |||
| private: | |||
| /// Set to true by InjectionFinder when a BuildVocabOp or CacheOp is visited; while true, RunOnTree does not | |||
| /// inject the epoch control op. | |||
| bool epoch_ctrl_bypass_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_ | |||
| @@ -29,20 +29,27 @@ std::shared_ptr<TdtPlugin> TdtPlugin::GetInstance() { | |||
| return instance_ptr_; | |||
| } | |||
| TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time) { | |||
| TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time, | |||
| tdt::TdtDataType tdt_type) { | |||
| MS_LOG(DEBUG) << "TDT channel name is " << channel_name << "."; | |||
| std::vector<DataItem> items; | |||
| double start_time; | |||
| auto ret = translate(ts_row, items); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "TDT converting tensor failed!"; | |||
| return FAILED; | |||
| if (tdt_type == tdt::TDT_TENSOR) { | |||
| auto ret = translate(ts_row, items); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "TDT converting tensor failed!"; | |||
| return FAILED; | |||
| } | |||
| } else if (tdt_type == tdt::TDT_END_OF_SEQUENCE) { | |||
| DataItem data_item; | |||
| data_item.dataType_ = tdt::TDT_END_OF_SEQUENCE; | |||
| items.emplace_back(data_item); | |||
| MS_LOG(INFO) << "TDT data type is TDT_END_OF_SEQUENCE"; | |||
| } | |||
| if (profiling) { | |||
| start_time = ProfilingTime::GetCurMilliSecond(); | |||
| } | |||
| if (tdt::TdtHostPushData(channel_name, items) != 0) { | |||
| MS_LOG(ERROR) << "TDT pushing data failed!"; | |||
| return FAILED; | |||
| } | |||
| if (profiling) { | |||
| @@ -122,8 +129,8 @@ TdtStatus TdtPlugin::translate(const TensorRow &ts_row, std::vector<DataItem> &i | |||
| data_item.dataPtr_ = | |||
| std::shared_ptr<void>(reinterpret_cast<uchar *>(&(*ts->begin<uint8_t>())), [](const void *elem) {}); | |||
| items.emplace_back(data_item); | |||
| MS_LOG(DEBUG) << "TDT data type is " << datatype << ", data shape is " << dataShapes << ", data length is " | |||
| << ts->Size() << "."; | |||
| MS_LOG(INFO) << "TDT data type is TDT_TENSOR, tensor type is " << datatype << ", tensor shape is " << dataShapes | |||
| << ", data length is " << ts->Size() << "."; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -38,7 +38,8 @@ class TdtPlugin { | |||
| public: | |||
| static std::shared_ptr<TdtPlugin> GetInstance(); | |||
| TdtStatus hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profilig, int32_t &time); | |||
| TdtStatus hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profilig, int32_t &time, | |||
| tdt::TdtDataType tdt_type = tdt::TDT_TENSOR); | |||
| private: | |||
| TdtPlugin() {} | |||
| @@ -797,6 +797,9 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba | |||
| (void)InitBackend(); | |||
| } | |||
| #endif | |||
| if (iter_num == -1) { | |||
| iter_num = INT32_MAX; | |||
| } | |||
| if (name == kMsConvert || name == kMsVm) { | |||
| return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes, need_run); | |||
| } | |||
| @@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che | |||
| check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ | |||
| check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ | |||
| check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ | |||
| check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32, check_save | |||
| check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save | |||
| from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist | |||
| try: | |||
| @@ -946,14 +946,14 @@ class Dataset: | |||
| raise TypeError("apply_func must return a dataset.") | |||
| return dataset | |||
| @check_positive_int32 | |||
| def device_que(self, prefetch_size=None): | |||
| def device_que(self, prefetch_size=None, send_epoch_end=True): | |||
| """ | |||
| Return a transferredDataset that transfer data through device. | |||
| Args: | |||
| prefetch_size (int, optional): prefetch number of records ahead of the | |||
| user's request (default=None). | |||
| send_epoch_end (bool, optional): Whether to send the end-of-sequence signal to the device (default=True). | |||
| Note: | |||
| If device is Ascend, features of data will be transferred one by one. The limitation | |||
| @@ -962,15 +962,14 @@ class Dataset: | |||
| Return: | |||
| TransferDataset, dataset for transferring. | |||
| """ | |||
| return self.to_device() | |||
| return self.to_device(send_epoch_end=send_epoch_end) | |||
| @check_positive_int32 | |||
| def to_device(self, num_batch=None): | |||
| def to_device(self, send_epoch_end=True): | |||
| """ | |||
| Transfer data through CPU, GPU or Ascend devices. | |||
| Args: | |||
| num_batch (int, optional): limit the number of batch to be sent to device (default=None). | |||
| send_epoch_end (bool, optional): Whether to send the end-of-sequence signal to the device (default=True). | |||
| Note: | |||
| If device is Ascend, features of data will be transferred one by one. The limitation | |||
| @@ -982,19 +981,9 @@ class Dataset: | |||
| Raises: | |||
| TypeError: If device_type is empty. | |||
| ValueError: If device_type is not 'Ascend', 'GPU' or 'CPU'. | |||
| ValueError: If num_batch is not positive or larger than int_max. | |||
| ValueError: If dataset size is None or 0. | |||
| RuntimeError: If dataset is unknown. | |||
| RuntimeError: If distribution file path is given but failed to read. | |||
| """ | |||
| if self.get_dataset_size() is None or 0: | |||
| raise ValueError("dataset size is None or 0.") | |||
| if num_batch is None: | |||
| num_batch = self.get_dataset_size() | |||
| repeat_count = self.get_repeat_count() | |||
| num_batch = num_batch * repeat_count | |||
| queue_name = str(uuid.uuid1()) | |||
| if context: | |||
| @@ -1008,9 +997,6 @@ class Dataset: | |||
| if device_type not in ('Ascend', 'GPU', 'CPU'): | |||
| raise ValueError("Only support CPU, Ascend, GPU") | |||
| if num_batch == 0: | |||
| raise ValueError("num_batch is 0.") | |||
| def get_distribution(output_dataset): | |||
| dev_id = 0 | |||
| if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDatasetV2, | |||
| @@ -1032,7 +1018,7 @@ class Dataset: | |||
| distribution_path, device_id = get_distribution(self) | |||
| if distribution_path == "": | |||
| return TransferDataset(self, queue_name, device_id, device_type, num_batch) | |||
| return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end) | |||
| try: | |||
| with open(distribution_path, 'r') as distribution_f: | |||
| dist = json.load(distribution_f) | |||
| @@ -1042,7 +1028,7 @@ class Dataset: | |||
| except Exception: | |||
| raise RuntimeError("Distribution file failed to read") | |||
| return TransferDataset(self, queue_name, device_id, device_type, num_batch) | |||
| return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end) | |||
| @check_save | |||
| def save(self, file_name, num_files=1, file_type='mindrecord'): | |||
| @@ -1072,7 +1058,7 @@ class Dataset: | |||
| return SaveOp(self).save(file_names, file_type) | |||
| def create_tuple_iterator(self, columns=None): | |||
| def create_tuple_iterator(self, columns=None, num_epochs=-1): | |||
| """ | |||
| Create an Iterator over the dataset. The data retrieved will be a list of ndarray of data. | |||
| @@ -1098,9 +1084,9 @@ class Dataset: | |||
| """ | |||
| if self._noop_mode(): | |||
| return DummyIterator(self, 'tuple') | |||
| return TupleIterator(self, columns) | |||
| return TupleIterator(self, columns, num_epochs) | |||
| def create_dict_iterator(self): | |||
| def create_dict_iterator(self, num_epochs=-1): | |||
| """ | |||
| Create an Iterator over the dataset. | |||
| @@ -1123,7 +1109,7 @@ class Dataset: | |||
| """ | |||
| if self._noop_mode(): | |||
| return DummyIterator(self, 'dict') | |||
| return DictIterator(self) | |||
| return DictIterator(self, num_epochs) | |||
| def __iter__(self): | |||
| """Create an Iterator over the dataset.""" | |||
| @@ -1149,7 +1135,7 @@ class Dataset: | |||
| self._batch_size = device_iter.get_batch_size() | |||
| self._num_classes = device_iter.num_classes() | |||
| self._repeat_count = device_iter.get_repeat_count() | |||
| device_iter.release() | |||
| device_iter.stop() | |||
| def output_shapes(self): | |||
| """ | |||
| @@ -2085,7 +2071,7 @@ class RepeatDataset(DatasetOp): | |||
| """ | |||
| child_size = self.children[0].get_dataset_size() | |||
| if child_size is not None: | |||
| return child_size | |||
| return child_size * self.count | |||
| return None | |||
| def get_repeat_count(self): | |||
| @@ -2097,7 +2083,6 @@ class RepeatDataset(DatasetOp): | |||
| """ | |||
| return self.count | |||
| class SkipDataset(DatasetOp): | |||
| """ | |||
| The result of applying Skip operator to the input Dataset. | |||
| @@ -2317,10 +2302,10 @@ class TransferDataset(DatasetOp): | |||
| queue_name (str): Name of device queue. | |||
| device_id (int): Id of device. | |||
| device_type (str): Type of device, including "CPU", "GPU", and "Ascend". | |||
| num_batch (int): limit the number of batch to be sent to device (default=None). | |||
| send_epoch_end (bool, optional): Whether to send the end-of-sequence signal to the device (default=True). | |||
| """ | |||
| def __init__(self, input_dataset, queue_name, device_id, device_type, num_batch=None): | |||
| def __init__(self, input_dataset, queue_name, device_id, device_type, send_epoch_end=True): | |||
| super().__init__() | |||
| self.children.append(input_dataset) | |||
| input_dataset.parent.append(self) | |||
| @@ -2328,7 +2313,7 @@ class TransferDataset(DatasetOp): | |||
| self._input_indexs = input_dataset.input_indexs | |||
| self._device_type = device_type | |||
| self._device_id = device_id | |||
| self.__num_batch = num_batch | |||
| self._send_epoch_end = send_epoch_end | |||
| self.iterator = None | |||
| def get_args(self): | |||
| @@ -2336,13 +2321,13 @@ class TransferDataset(DatasetOp): | |||
| args["queue_name"] = self.queue_name | |||
| args["device_type"] = self._device_type | |||
| args["device_id"] = self._device_id | |||
| args["num_batch"] = self.__num_batch | |||
| args["send_epoch_end"] = self._send_epoch_end | |||
| return args | |||
| def create_dict_iterator(self): | |||
| def create_dict_iterator(self, num_epochs=-1): | |||
| raise RuntimeError("TransferDataset is not iterable") | |||
| def create_tuple_iterator(self, columns=None): | |||
| def create_tuple_iterator(self, columns=None, num_epochs=-1): | |||
| raise RuntimeError("TransferDataset is not iterable") | |||
| def __iter__(self): | |||
| @@ -2354,12 +2339,14 @@ class TransferDataset(DatasetOp): | |||
| def output_types(self): | |||
| raise RuntimeError("TransferDataset does not support output_types") | |||
| def send(self): | |||
| def send(self, num_epochs=-1): | |||
| # need to keep iterator alive so the executionTree is not destroyed | |||
| if self._noop_mode(): | |||
| return | |||
| self.iterator = TupleIterator(self) | |||
| self.iterator = TupleIterator(self, num_epochs=-1) | |||
| def stop_send(self): | |||
| self.iterator.depipeline.StopSend() | |||
| class RangeDataset(MappableDataset): | |||
| """ | |||
| @@ -29,7 +29,6 @@ from . import datasets as de | |||
| ITERATORS_LIST = list() | |||
| def _cleanup(): | |||
| """Release all the Iterator.""" | |||
| for itr_ref in ITERATORS_LIST: | |||
| @@ -60,7 +59,6 @@ def _alter_node(node): | |||
| node.iterator_bootstrap() | |||
| return node | |||
| class Iterator: | |||
| """ | |||
| General Iterator over a dataset. | |||
| @@ -69,10 +67,21 @@ class Iterator: | |||
| dataset: Dataset to be iterated over | |||
| """ | |||
| def __init__(self, dataset): | |||
| def __init__(self, dataset, num_epochs=-1): | |||
| self.num_epochs = num_epochs | |||
| ITERATORS_LIST.append(weakref.ref(self)) | |||
| # create a copy of tree and work on it. | |||
| self.dataset = copy.deepcopy(dataset) | |||
| self.parent_subtree = [] | |||
| # The dataset passed into the iterator is not the root of the tree. | |||
| # Trim the tree by saving the parent subtree into self.parent_subtree and | |||
| # restore it after launching our c++ pipeline. | |||
| if self.dataset.parent: | |||
| logger.warning("The dataset passed in is not the root of the pipeline. Ignoring parent subtree.") | |||
| self.parent_subtree = self.dataset.parent | |||
| self.dataset.parent = [] | |||
| self.dataset = alter_tree(self.dataset) | |||
| if not self.__is_tree(): | |||
| raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)") | |||
| @@ -83,9 +92,17 @@ class Iterator: | |||
| root = self.__convert_node_postorder(self.dataset) | |||
| self.depipeline.AssignRootNode(root) | |||
| self.depipeline.LaunchTreeExec() | |||
| self.depipeline.LaunchTreeExec(self.num_epochs) | |||
| self._index = 0 | |||
| def stop(self): | |||
| """ | |||
| Manually terminate python iterator instead of relying on out of scope destruction. | |||
| """ | |||
| logger.info("terminating python iterator. This will also terminate c++ pipeline.") | |||
| if hasattr(self, 'depipeline') and self.depipeline: | |||
| del self.depipeline | |||
| def __is_tree_node(self, node): | |||
| """Check if a node is tree node.""" | |||
| if not node.children: | |||
| @@ -214,9 +231,14 @@ class Iterator: | |||
| @abstractmethod | |||
| def get_next(self): | |||
| pass | |||
| raise RuntimeError("Calling base class Iterator's get_next is invalid.") | |||
| def __next__(self): | |||
| if not self.depipeline: | |||
| logger.warning("Iterator does not have a running c++ pipeline." + | |||
| "It can be because Iterator stop() had been called, or c++ pipeline crashed silently.") | |||
| raise RuntimeError("Iterator does not have a running c++ pipeline.") | |||
| data = self.get_next() | |||
| if not data: | |||
| if self._index == 0: | |||
| @@ -293,12 +315,12 @@ class TupleIterator(Iterator): | |||
| def check_node_type(self, node): | |||
| pass | |||
| def __init__(self, dataset, columns=None): | |||
| def __init__(self, dataset, columns=None, num_epochs=-1): | |||
| if columns is not None: | |||
| if not isinstance(columns, list): | |||
| columns = [columns] | |||
| dataset = dataset.project(columns) | |||
| super().__init__(dataset) | |||
| super().__init__(dataset, num_epochs) | |||
| def __iter__(self): | |||
| return self | |||
| @@ -57,7 +57,8 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'): | |||
| # transform data format | |||
| dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset) | |||
| exec_dataset = exec_dataset.device_que() | |||
| send_epoch_end = bool(dataset_size == -1) | |||
| exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end) | |||
| _executor.init_dataset(exec_dataset.queue_name, | |||
| dataset_size, | |||
| @@ -126,7 +127,7 @@ def _construct_tensor_list(types, shapes, batch_expand_num=1): | |||
| def _to_tensor(elem, scaling_sens=None): | |||
| """Conver numpy to tensor, adapt to minddata feed solution.""" | |||
| """Convert numpy to tensor, adapt to feed the data from host solution.""" | |||
| lst = [] | |||
| if not isinstance(elem, (tuple, list)): | |||
| elem = [elem] | |||
| @@ -145,7 +146,8 @@ def _to_tensor(elem, scaling_sens=None): | |||
| def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None): | |||
| """Conver numpy to tensor, expanding batch dimension according to device_num, adapt to minddata feed solution.""" | |||
| """Convert numpy to tensor, expanding batch dimension according to device_num, adapt to feed the data | |||
| from host solution.""" | |||
| lst = [] | |||
| if not isinstance(elem, (tuple, list)): | |||
| elem = [elem] | |||
| @@ -16,7 +16,7 @@ | |||
| import math | |||
| import os | |||
| from mindspore._checkparam import check_bool | |||
| from mindspore._checkparam import check_bool, check_int | |||
| from .. import context | |||
| from ._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \ | |||
| _construct_tensor_list, _to_full_shapes, _to_full_tensor | |||
| @@ -42,17 +42,23 @@ class DatasetHelper: | |||
| The iter of DatasetHelper will give one epoch data. | |||
| Args: | |||
| dataset (DataSet): The dataset. | |||
| dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. | |||
| Default: True. | |||
| dataset (DataSet): The training dataset iterator. | |||
| dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. Default: True. | |||
| sink_size (int): Control the amount of data in each sink. | |||
| If sink_size=-1, sink the complete dataset each epoch. | |||
| If sink_size>0, sink sink_size data each epoch. Default: -1. | |||
| Examples: | |||
| >>> dataset_helper = DatasetHelper(dataset) | |||
| >>> for inputs in dataset_helper: | |||
| >>> outputs = network(*inputs) | |||
| """ | |||
| def __init__(self, dataset, dataset_sink_mode=True): | |||
| def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1): | |||
| check_bool(dataset_sink_mode) | |||
| check_int(sink_size) | |||
| if sink_size < -1 or sink_size == 0: | |||
| raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size)) | |||
| if dataset_sink_mode: | |||
| if context.get_context("enable_ge"): | |||
| @@ -68,9 +74,10 @@ class DatasetHelper: | |||
| iterclass = _DatasetIterMS | |||
| elif context.get_context("device_target") == "CPU": | |||
| raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.") | |||
| self.iter = iterclass(dataset, sink_size) | |||
| else: | |||
| iterclass = _DatasetIterFeed | |||
| self.iter = iterclass(dataset) | |||
| iterclass = _DatasetIterNormal | |||
| self.iter = iterclass(dataset) | |||
| def __iter__(self): | |||
| return self.iter.__iter__() | |||
| @@ -80,21 +87,26 @@ class DatasetHelper: | |||
| """Get the types and shapes from dataset on current config.""" | |||
| return self.iter.types_shapes() | |||
| def loop_size(self): | |||
| """Get loop_size for every iteration.""" | |||
| return self.iter.loop_size | |||
| def sink_size(self): | |||
| """Get sink_size for every iteration.""" | |||
| return self.iter.get_sink_size() | |||
| def stop_send(self): | |||
| """Free up resources about data sink.""" | |||
| self.iter.stop_send() | |||
| class _DatasetIter: | |||
| """Base iter for dataset help""" | |||
| def __init__(self, dataset): | |||
| if not hasattr(dataset, '__loop_size__'): | |||
| self.loop_size = dataset.get_dataset_size() | |||
| else: | |||
| self.loop_size = dataset.__loop_size__ | |||
| """Base iter for dataset helper""" | |||
| def __init__(self, dataset, sink_size): | |||
| self.dataset = dataset | |||
| self.sink_size = sink_size | |||
| self.sink_count = 1 | |||
| if not hasattr(dataset, '__ME_INITED__'): | |||
| dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size) | |||
| if not hasattr(dataset, '__TRANSFER_DATASET__'): | |||
| if hasattr(dataset, '__loop_size__'): | |||
| self.sink_size = dataset.__loop_size__ | |||
| dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size) | |||
| dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name | |||
| if not hasattr(dataset, '__no_send__'): | |||
| @@ -102,43 +114,70 @@ class _DatasetIter: | |||
| else: | |||
| _send_data(dataset) | |||
| self.ind = 0 | |||
| self.dataset = dataset | |||
| dataset_types, dataset_shapes = _get_types_and_shapes(dataset) | |||
| self.dataset_types, self.dataset_shapes = dataset_types, dataset_shapes | |||
| self.stop_send = dataset.__TRANSFER_DATASET__.stop_send | |||
| self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset) | |||
| def __iter__(self): | |||
| self.ind = 0 | |||
| self.index = 0 | |||
| return self | |||
| def __next__(self): | |||
| if self.ind >= self.loop_count: | |||
| if self.index >= self.sink_count: | |||
| raise StopIteration() | |||
| self.ind += 1 | |||
| self.index += 1 | |||
| return self.op() | |||
| def types_shapes(self): | |||
| return self.dataset_types, self.dataset_shapes | |||
| def get_loop_count(self, dataset): | |||
| loop_count = 1 | |||
| def get_sink_count(self, dataset): | |||
| sink_count = 1 | |||
| if hasattr(dataset, '__loop_size__'): | |||
| loop_size = dataset.__loop_size__ | |||
| if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0: | |||
| raise ValueError(f'Dataset size {dataset.get_dataset_size()} and ' | |||
| f'loop_size {loop_size} are not matched.') | |||
| loop_count = math.ceil(dataset.get_dataset_size() / loop_size) | |||
| return loop_count | |||
| f'sink_size {loop_size} are not matched.') | |||
| sink_count = math.ceil(dataset.get_dataset_size() / loop_size) | |||
| return sink_count | |||
| def get_sink_size(self): | |||
| """Return the sink size: the dataset's __loop_size__ when that attribute exists, otherwise a value chosen from self.sink_size or the dataset size (starting from a default of 1).""" | |||
| sink_size = 1 | |||
| if hasattr(self.dataset, '__loop_size__'): | |||
| sink_size = self.dataset.__loop_size__ | |||
| else: | |||
| # NOTE(review): indentation is lost in this view; the inner else below presumably pairs with | |||
| # `if self.sink_size > 0`, so targets other than GE/Ascend would keep the default of 1 -- confirm. | |||
| if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend": | |||
| if self.sink_size > 0: | |||
| sink_size = self.sink_size | |||
| else: | |||
| sink_size = self.dataset.get_dataset_size() | |||
| return sink_size | |||
| class _DatasetIterGE(_DatasetIter): | |||
| """Iter for GE. Each step returns a tensor list that is constructed once from the dataset types and shapes.""" | |||
| def __init__(self, dataset, sink_size): | |||
| super().__init__(dataset, sink_size) | |||
| # Number of steps __next__ yields before StopIteration, as computed by get_sink_count. | |||
| self.sink_count = self.get_sink_count(dataset) | |||
| batch_expand_num = 1 | |||
| # When _need_to_full() is true, the device count is passed as batch_expand_num. | |||
| if _need_to_full(): | |||
| batch_expand_num = _get_device_num() | |||
| # Built once here; op() hands back this same list on every call. | |||
| tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num) | |||
| def op(): | |||
| return tensor_list_run | |||
| self.op = op | |||
| class _DatasetIterMSLoopSink(_DatasetIter): | |||
| """Iter for context (device_target=Ascend)""" | |||
| def __init__(self, dataset): | |||
| super(_DatasetIterMSLoopSink, self).__init__(dataset) | |||
| self.loop_count = self.get_loop_count(dataset) | |||
| def __init__(self, dataset, sink_size): | |||
| super().__init__(dataset, sink_size) | |||
| self.sink_count = self.get_sink_count(dataset) | |||
| ms_role = os.getenv("MS_ROLE") | |||
| if ms_role in ("MS_PSERVER", "MS_SCHED"): | |||
| self.loop_count = 1 | |||
| self.sink_count = 1 | |||
| # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch, | |||
| # use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for | |||
| # compile is device_number times the batch dimension of tensors for run. Now only support LoopSink. | |||
| @@ -153,66 +192,42 @@ class _DatasetIterMSLoopSink(_DatasetIter): | |||
| class _DatasetIterMS(_DatasetIter): | |||
| """Iter for context (device_target=GPU)""" | |||
| def __init__(self, dataset): | |||
| super(_DatasetIterMS, self).__init__(dataset) | |||
| self.loop_count = dataset.get_dataset_size() | |||
| self.loop_size = 1 | |||
| """Iter for MS(enable_loop_sink=False).""" | |||
| def __init__(self, dataset, sink_size): | |||
| super().__init__(dataset, sink_size) | |||
| if sink_size > 0: | |||
| self.sink_count = sink_size | |||
| else: | |||
| self.sink_count = dataset.get_dataset_size() | |||
| queue_name = dataset.__ME_INITED__ | |||
| self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name) | |||
| class _DatasetIterPSLite(_DatasetIter): | |||
| """Iter for context (device_target=GPU) on MS_PSERVER or MS_SCHED""" | |||
| def __init__(self, dataset): | |||
| super(_DatasetIterPSLite, self).__init__(dataset) | |||
| self.loop_count = 1 | |||
| self.loop_size = 1 | |||
| def __init__(self, dataset, sink_size): | |||
| super().__init__(dataset, sink_size) | |||
| self.sink_count = 1 | |||
| self.sink_size = 1 | |||
| self.op = None | |||
| def op(): | |||
| return _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num=1) | |||
| self.op = op | |||
| class _DatasetIterGE(_DatasetIter): | |||
| """Iter for ge""" | |||
| def __init__(self, dataset): | |||
| super(_DatasetIterGE, self).__init__(dataset) | |||
| self.loop_count = self.get_loop_count(dataset) | |||
| batch_expand_num = 1 | |||
| if _need_to_full(): | |||
| batch_expand_num = _get_device_num() | |||
| tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num) | |||
| def op(): | |||
| return tensor_list_run | |||
| self.op = op | |||
| class _DatasetIterFeed: | |||
| class _DatasetIterNormal: | |||
| """Iter for normal(non sink) mode, feed the data from host.""" | |||
| def __init__(self, dataset): | |||
| self.dataset = dataset | |||
| self.device_num = _get_device_num() | |||
| self.global_rank = _get_global_rank() | |||
| self.repeat_count = dataset.get_repeat_count() | |||
| self.repeat_ind = 0 | |||
| self.loop_count = dataset.get_dataset_size() | |||
| self.ind = 0 | |||
| def __iter__(self): | |||
| if self.repeat_ind % self.repeat_count == 0: | |||
| self.iter = self.dataset.__iter__() | |||
| self.repeat_ind += 1 | |||
| self.ind = 0 | |||
| self.iter = self.dataset.create_tuple_iterator() | |||
| return self | |||
| def __next__(self): | |||
| if self.ind >= self.loop_count: | |||
| raise StopIteration() | |||
| self.ind += 1 | |||
| data = self.iter.__next__() | |||
| if _need_to_full(): | |||
| return _to_full_tensor(data, self.device_num, self.global_rank) | |||
| @@ -21,7 +21,7 @@ import numpy as np | |||
| from mindspore import log as logger | |||
| from ..common.tensor import Tensor | |||
| from ..nn.metrics import get_metrics | |||
| from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool | |||
| from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool, check_int | |||
| from .callback import _InternalCallbackParam, RunContext, _CallbackManager | |||
| from .. import context | |||
| from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ | |||
| @@ -225,7 +225,7 @@ class Model: | |||
| scaling_sens /= self._device_number | |||
| return scaling_sens | |||
| def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode): | |||
| def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1): | |||
| """Initializes dataset.""" | |||
| need_wrap = False | |||
| if dataset_sink_mode: | |||
| @@ -237,7 +237,7 @@ class Model: | |||
| if not is_train: | |||
| dataset.__loop_size__ = 1 | |||
| dataset_helper = DatasetHelper(dataset, dataset_sink_mode) | |||
| dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size) | |||
| # remove later to deal with loop sink | |||
| if need_wrap: | |||
| @@ -317,7 +317,7 @@ class Model: | |||
| self._eval_network.compile(*inputs) | |||
| break | |||
| def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True): | |||
| def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1): | |||
| """ | |||
| Training. | |||
| @@ -332,6 +332,7 @@ class Model: | |||
| dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True. | |||
| Configure pynative mode, the training process will be performed with | |||
| dataset not sink. | |||
| sink_size (int): Control the amount of data in each sink. Default: -1. | |||
| """ | |||
| epoch = check_int_positive(epoch) | |||
| self._train_network.set_train() | |||
| @@ -342,7 +343,10 @@ class Model: | |||
| cb_params = _InternalCallbackParam() | |||
| cb_params.train_network = self._train_network | |||
| cb_params.epoch_num = epoch | |||
| cb_params.batch_num = train_dataset.get_dataset_size() | |||
| if dataset_sink_mode and sink_size > 0: | |||
| cb_params.batch_num = sink_size | |||
| else: | |||
| cb_params.batch_num = train_dataset.get_dataset_size() | |||
| cb_params.mode = "train" | |||
| cb_params.loss_fn = self._loss_fn | |||
| cb_params.optimizer = self._optimizer | |||
| @@ -364,7 +368,7 @@ class Model: | |||
| "So the training process will be performed with dataset not sink.") | |||
| self._train_process(epoch, train_dataset, list_callback, cb_params) | |||
| else: | |||
| self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params) | |||
| self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params, sink_size) | |||
| @staticmethod | |||
| def _transform_callbacks(callbacks): | |||
| @@ -377,7 +381,7 @@ class Model: | |||
| return [callbacks] | |||
| def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None): | |||
| def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1): | |||
| """ | |||
| Training process. The data would be passed to network through dataset channel. | |||
| @@ -390,17 +394,18 @@ class Model: | |||
| function respectively. | |||
| list_callback (Callback): Executor of callback list. Default: None. | |||
| cb_params (_InternalCallbackParam): Callback parameters. Default: None. | |||
| sink_size (int): Control the amount of data in each sink. Default: -1. | |||
| """ | |||
| dataset_helper, train_network = self._exec_preprocess(self._train_network, | |||
| is_train=True, | |||
| phase='train', | |||
| dataset=train_dataset, | |||
| dataset_sink_mode=True) | |||
| dataset_sink_mode=True, | |||
| sink_size=sink_size) | |||
| self._train_network = train_network | |||
| cb_params.train_network = self._train_network | |||
| cb_params.cur_step_num = 0 | |||
| loop_size = dataset_helper.loop_size() | |||
| run_context = RunContext(cb_params) | |||
| list_callback.begin(run_context) | |||
| @@ -412,9 +417,9 @@ class Model: | |||
| # In data sink mode dataset_helper iterates only once; otherwise it iterates epoch_size times. | |||
| for inputs in dataset_helper: | |||
| cb_params.cur_step_num += loop_size | |||
| list_callback.step_begin(run_context) | |||
| outputs = self._train_network(*inputs) | |||
| cb_params.cur_step_num += dataset_helper.sink_size() | |||
| cb_params.net_outputs = outputs | |||
| list_callback.step_end(run_context) | |||
| @@ -422,6 +427,7 @@ class Model: | |||
| should_stop = should_stop or run_context.get_stop_requested() | |||
| if should_stop: | |||
| break | |||
| dataset_helper.stop_send() | |||
| list_callback.end(run_context) | |||
| @@ -490,7 +496,7 @@ class Model: | |||
| list_callback.end(run_context) | |||
| def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True): | |||
| def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1): | |||
| """ | |||
| Training API where the iteration is controlled by python front-end. | |||
| @@ -515,7 +521,10 @@ class Model: | |||
| dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True. | |||
| Configure pynative mode, the training process will be performed with | |||
| dataset not sink. | |||
| sink_size (int): Control the amount of data in each sink. | |||
| If sink_size=-1, sink the complete dataset each epoch. | |||
| If sink_size>0, sink sink_size data each epoch. | |||
| If dataset_sink_mode is False, sink_size has no effect. Default: -1. | |||
| Examples: | |||
| >>> dataset = get_dataset() | |||
| @@ -526,17 +535,19 @@ class Model: | |||
| >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager) | |||
| >>> model.train(2, dataset) | |||
| """ | |||
| repeat_count = train_dataset.get_repeat_count() | |||
| if epoch != repeat_count and dataset_sink_mode is True: | |||
| logger.warning(f"The epoch_size {epoch} is not the same with dataset repeat_count {repeat_count}") | |||
| check_bool(dataset_sink_mode) | |||
| check_int(sink_size) | |||
| if sink_size < -1 or sink_size == 0: | |||
| raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size)) | |||
| _device_number_check(self._parallel_mode, self._device_number) | |||
| _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast) | |||
| self._train(epoch, | |||
| train_dataset, | |||
| callbacks=callbacks, | |||
| dataset_sink_mode=dataset_sink_mode) | |||
| dataset_sink_mode=dataset_sink_mode, | |||
| sink_size=sink_size) | |||
| def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None): | |||
| """ | |||
| @@ -43,7 +43,7 @@ if __name__ == "__main__": | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||
| ds_train = create_dataset_cifar10(args.data_path, cfg.batch_size, cfg.epoch_size) | |||
| ds_train = create_dataset_cifar10(args.data_path, cfg.batch_size, 1) | |||
| network = AlexNet(cfg.num_classes) | |||
| loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||
| lr = Tensor(get_lr(0, cfg.learning_rate, cfg.epoch_size, ds_train.get_dataset_size())) | |||
| @@ -57,7 +57,7 @@ if __name__ == '__main__': | |||
| ds_train = create_dataset(args_opt.dataset_path, | |||
| train_mode=True, | |||
| epochs=train_config.train_epochs, | |||
| epochs=1, | |||
| batch_size=train_config.batch_size, | |||
| data_type=DataType(data_config.data_format), | |||
| rank_size=rank_size, | |||
| @@ -82,7 +82,7 @@ if __name__ == '__main__': | |||
| if args_opt.do_eval: | |||
| ds_eval = create_dataset(args_opt.dataset_path, train_mode=False, | |||
| epochs=train_config.train_epochs, | |||
| epochs=1, | |||
| batch_size=train_config.batch_size, | |||
| data_type=DataType(data_config.data_format)) | |||
| eval_callback = EvalCallBack(model, ds_eval, auc_metric, | |||
| @@ -66,7 +66,7 @@ if __name__ == "__main__": | |||
| init() | |||
| args_opt.base_size = config.crop_size | |||
| args_opt.crop_size = config.crop_size | |||
| train_dataset = create_dataset(args_opt, args_opt.data_url, config.epoch_size, config.batch_size, usage="train") | |||
| train_dataset = create_dataset(args_opt, args_opt.data_url, 1, config.batch_size, usage="train") | |||
| dataset_size = train_dataset.get_dataset_size() | |||
| time_cb = TimeMonitor(data_size=dataset_size) | |||
| callback = [time_cb, LossCallBack()] | |||
| @@ -94,7 +94,7 @@ if __name__ == '__main__': | |||
| loss_scale = float(config.loss_scale) | |||
| # When creating MindDataset, use the first mindrecord file, such as FasterRcnn.mindrecord0. | |||
| dataset = create_fasterrcnn_dataset(mindrecord_file, repeat_num=config.epoch_size, | |||
| dataset = create_fasterrcnn_dataset(mindrecord_file, repeat_num=1, | |||
| batch_size=config.batch_size, device_num=device_num, rank_id=rank) | |||
| dataset_size = dataset.get_dataset_size() | |||
| @@ -78,7 +78,7 @@ if __name__ == '__main__': | |||
| mirror_mean=True) | |||
| init() | |||
| dataset = create_dataset(cfg.data_path, cfg.epoch_size) | |||
| dataset = create_dataset(cfg.data_path, 1) | |||
| batch_num = dataset.get_dataset_size() | |||
| net = GoogleNet(num_classes=cfg.num_classes) | |||
| @@ -45,8 +45,7 @@ if __name__ == "__main__": | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||
| ds_train = create_dataset(os.path.join(args.data_path, "train"), | |||
| cfg.batch_size, | |||
| cfg.epoch_size) | |||
| cfg.batch_size) | |||
| network = LeNet5(cfg.num_classes) | |||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||
| @@ -44,7 +44,7 @@ args = parser.parse_args() | |||
| if __name__ == "__main__": | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||
| ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size) | |||
| ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, 1) | |||
| step_size = ds_train.get_dataset_size() | |||
| # define fusion network | |||
| @@ -77,7 +77,7 @@ if __name__ == '__main__': | |||
| model = Model(network, loss, opt, {'acc': Accuracy()}) | |||
| print("============== Starting Training ==============") | |||
| ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, cfg.num_epochs) | |||
| ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, 1) | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, | |||
| keep_checkpoint_max=cfg.keep_checkpoint_max) | |||
| ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck) | |||
| @@ -249,7 +249,7 @@ def train_parallel(config: TransformerConfig): | |||
| pre_train_dataset = load_dataset( | |||
| data_files=config.pre_train_dataset, | |||
| batch_size=config.batch_size, epoch_count=config.epochs, | |||
| batch_size=config.batch_size, epoch_count=1, | |||
| sink_mode=config.dataset_sink_mode, | |||
| sink_step=config.dataset_sink_step, | |||
| rank_size=MultiAscend.get_group_size(), | |||
| @@ -257,7 +257,7 @@ def train_parallel(config: TransformerConfig): | |||
| ) if config.pre_train_dataset else None | |||
| fine_tune_dataset = load_dataset( | |||
| data_files=config.fine_tune_dataset, | |||
| batch_size=config.batch_size, epoch_count=config.epochs, | |||
| batch_size=config.batch_size, epoch_count=1, | |||
| sink_mode=config.dataset_sink_mode, | |||
| sink_step=config.dataset_sink_step, | |||
| rank_size=MultiAscend.get_group_size(), | |||
| @@ -265,7 +265,7 @@ def train_parallel(config: TransformerConfig): | |||
| ) if config.fine_tune_dataset else None | |||
| test_dataset = load_dataset( | |||
| data_files=config.test_dataset, | |||
| batch_size=config.batch_size, epoch_count=config.epochs, | |||
| batch_size=config.batch_size, epoch_count=1, | |||
| sink_mode=config.dataset_sink_mode, | |||
| sink_step=config.dataset_sink_step, | |||
| rank_size=MultiAscend.get_group_size(), | |||
| @@ -288,17 +288,17 @@ def train_single(config: TransformerConfig): | |||
| print(" | Starting training on single device.") | |||
| pre_train_dataset = load_dataset(data_files=config.pre_train_dataset, | |||
| batch_size=config.batch_size, | |||
| epoch_count=config.epochs, | |||
| epoch_count=1, | |||
| sink_mode=config.dataset_sink_mode, | |||
| sink_step=config.dataset_sink_step) if config.pre_train_dataset else None | |||
| fine_tune_dataset = load_dataset(data_files=config.fine_tune_dataset, | |||
| batch_size=config.batch_size, | |||
| epoch_count=config.epochs, | |||
| epoch_count=1, | |||
| sink_mode=config.dataset_sink_mode, | |||
| sink_step=config.dataset_sink_step) if config.fine_tune_dataset else None | |||
| test_dataset = load_dataset(data_files=config.test_dataset, | |||
| batch_size=config.batch_size, | |||
| epoch_count=config.epochs, | |||
| epoch_count=1, | |||
| sink_mode=config.dataset_sink_mode, | |||
| sink_step=config.dataset_sink_step) if config.test_dataset else None | |||
| @@ -180,7 +180,7 @@ if __name__ == '__main__': | |||
| do_train=True, | |||
| config=config_gpu, | |||
| platform=args_opt.platform, | |||
| repeat_num=epoch_size, | |||
| repeat_num=1, | |||
| batch_size=config_gpu.batch_size) | |||
| step_size = dataset.get_dataset_size() | |||
| # resume | |||
| @@ -239,7 +239,7 @@ if __name__ == '__main__': | |||
| do_train=True, | |||
| config=config_ascend, | |||
| platform=args_opt.platform, | |||
| repeat_num=epoch_size, | |||
| repeat_num=1, | |||
| batch_size=config_ascend.batch_size) | |||
| step_size = dataset.get_dataset_size() | |||
| if args_opt.pre_trained: | |||
| @@ -86,7 +86,7 @@ if __name__ == '__main__': | |||
| do_train=True, | |||
| config=config, | |||
| device_target=args_opt.device_target, | |||
| repeat_num=epoch_size, | |||
| repeat_num=1, | |||
| batch_size=config.batch_size) | |||
| step_size = dataset.get_dataset_size() | |||
| # load pre trained ckpt | |||
| @@ -181,7 +181,7 @@ if __name__ == '__main__': | |||
| do_train=True, | |||
| config=config_gpu, | |||
| platform=args_opt.platform, | |||
| repeat_num=epoch_size, | |||
| repeat_num=1, | |||
| batch_size=config_gpu.batch_size) | |||
| step_size = dataset.get_dataset_size() | |||
| # resume | |||
| @@ -240,7 +240,7 @@ if __name__ == '__main__': | |||
| do_train=True, | |||
| config=config_ascend, | |||
| platform=args_opt.platform, | |||
| repeat_num=epoch_size, | |||
| repeat_num=1, | |||
| batch_size=config_ascend.batch_size) | |||
| step_size = dataset.get_dataset_size() | |||
| if args_opt.pre_trained: | |||
| @@ -36,12 +36,11 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| _cur_dir = os.getcwd() | |||
| def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""): | |||
| def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1): | |||
| """ do train """ | |||
| if load_checkpoint_path == "": | |||
| raise ValueError("Pretrain model missed, finetune task must load pretrain model!") | |||
| steps_per_epoch = dataset.get_dataset_size() | |||
| epoch_num = dataset.get_repeat_count() | |||
| # optimizer | |||
| if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR': | |||
| optimizer = AdamWeightDecayDynamicLR(network.trainable_params(), | |||
| @@ -176,11 +175,11 @@ def run_classifier(): | |||
| assessment_method=assessment_method) | |||
| if args_opt.do_train.lower() == "true": | |||
| ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, | |||
| ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1, | |||
| assessment_method=assessment_method, | |||
| data_file_path=args_opt.train_data_file_path, | |||
| schema_file_path=args_opt.schema_file_path) | |||
| do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path) | |||
| do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num) | |||
| if args_opt.do_eval.lower() == "true": | |||
| if save_finetune_checkpoint_path == "": | |||
| @@ -191,7 +190,7 @@ def run_classifier(): | |||
| ds.get_dataset_size(), epoch_num, "classifier") | |||
| if args_opt.do_eval.lower() == "true": | |||
| ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, | |||
| ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1, | |||
| assessment_method=assessment_method, | |||
| data_file_path=args_opt.eval_data_file_path, | |||
| schema_file_path=args_opt.schema_file_path) | |||
| @@ -38,12 +38,11 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| _cur_dir = os.getcwd() | |||
| def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""): | |||
| def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1): | |||
| """ do train """ | |||
| if load_checkpoint_path == "": | |||
| raise ValueError("Pretrain model missed, finetune task must load pretrain model!") | |||
| steps_per_epoch = dataset.get_dataset_size() | |||
| epoch_num = dataset.get_repeat_count() | |||
| # optimizer | |||
| if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR': | |||
| optimizer = AdamWeightDecayDynamicLR(network.trainable_params(), | |||
| @@ -204,10 +203,10 @@ def run_ner(): | |||
| use_crf=(args_opt.use_crf.lower() == "true"), | |||
| tag_to_index=tag_to_index, dropout_prob=0.1) | |||
| if args_opt.do_train.lower() == "true": | |||
| ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, | |||
| ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1, | |||
| assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path, | |||
| schema_file_path=args_opt.schema_file_path) | |||
| do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path) | |||
| do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num) | |||
| if args_opt.do_eval.lower() == "true": | |||
| if save_finetune_checkpoint_path == "": | |||
| @@ -218,7 +217,7 @@ def run_ner(): | |||
| ds.get_dataset_size(), epoch_num, "ner") | |||
| if args_opt.do_eval.lower() == "true": | |||
| ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, | |||
| ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1, | |||
| assessment_method=assessment_method, data_file_path=args_opt.eval_data_file_path, | |||
| schema_file_path=args_opt.schema_file_path) | |||
| do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method, args_opt.eval_data_file_path, | |||
| @@ -100,11 +100,12 @@ def run_pretrain(): | |||
| bert_net_cfg.compute_type = mstype.float32 | |||
| ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle, | |||
| args_opt.enable_data_sink, args_opt.data_sink_steps, | |||
| args_opt.data_dir, args_opt.schema_dir) | |||
| ds = create_bert_dataset(1, device_num, rank, args_opt.do_shuffle, | |||
| args_opt.enable_data_sink, args_opt.data_sink_steps, | |||
| args_opt.data_dir, args_opt.schema_dir) | |||
| new_repeat_count = args_opt.epoch_size | |||
| if args_opt.train_steps > 0: | |||
| new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps) | |||
| new_repeat_count = min(args_opt.epoch_size, args_opt.train_steps // args_opt.data_sink_steps) | |||
| netwithloss = BertNetworkWithLoss(bert_net_cfg, True) | |||
| if cfg.optimizer == 'Lamb': | |||
| @@ -38,12 +38,11 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| _cur_dir = os.getcwd() | |||
| def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""): | |||
| def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1): | |||
| """ do train """ | |||
| if load_checkpoint_path == "": | |||
| raise ValueError("Pretrain model missed, finetune task must load pretrain model!") | |||
| steps_per_epoch = dataset.get_dataset_size() | |||
| epoch_num = dataset.get_repeat_count() | |||
| # optimizer | |||
| if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR': | |||
| optimizer = AdamWeightDecayDynamicLR(network.trainable_params(), | |||
| @@ -181,10 +180,10 @@ def run_squad(): | |||
| netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1) | |||
| if args_opt.do_train.lower() == "true": | |||
| ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, | |||
| ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1, | |||
| data_file_path=args_opt.train_data_file_path, | |||
| schema_file_path=args_opt.schema_file_path) | |||
| do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path) | |||
| do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num) | |||
| if args_opt.do_eval.lower() == "true": | |||
| if save_finetune_checkpoint_path == "": | |||
| load_finetune_checkpoint_dir = _cur_dir | |||
| @@ -194,7 +193,7 @@ def run_squad(): | |||
| ds.get_dataset_size(), epoch_num, "squad") | |||
| if args_opt.do_eval.lower() == "true": | |||
| ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, | |||
| ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1, | |||
| data_file_path=args_opt.eval_data_file_path, | |||
| schema_file_path=args_opt.schema_file_path, is_training=False) | |||
| do_eval(ds, args_opt.vocab_file_path, args_opt.eval_json_path, | |||
| @@ -54,7 +54,6 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e | |||
| ds = ds.map(input_columns="input_ids", operations=type_cast_op) | |||
| # apply batch operations | |||
| ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True) | |||
| ds = ds.repeat(max(new_repeat_count, repeat_count)) | |||
| logger.info("data size: {}".format(ds.get_dataset_size())) | |||
| logger.info("repeatcount: {}".format(ds.get_repeat_count())) | |||
| return ds, new_repeat_count | |||
| @@ -17,7 +17,6 @@ | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.dataset.engine.datasets as de | |||
| import mindspore.dataset.transforms.c_transforms as deC | |||
| from mindspore import log as logger | |||
| from .config import transformer_net_cfg | |||
| def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle="true", enable_data_sink="true", | |||
| @@ -42,7 +41,4 @@ def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle | |||
| ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True) | |||
| ds = ds.repeat(repeat_count) | |||
| ds.channel_name = 'transformer' | |||
| logger.info("data size: {}".format(ds.get_dataset_size())) | |||
| logger.info("repeatcount: {}".format(ds.get_repeat_count())) | |||
| return ds, repeat_count | |||
| return ds | |||
| @@ -125,10 +125,10 @@ def run_transformer_train(): | |||
| else: | |||
| device_num = 1 | |||
| rank_id = 0 | |||
| dataset, repeat_count = create_transformer_dataset(epoch_count=args.epoch_size, rank_size=device_num, | |||
| rank_id=rank_id, do_shuffle=args.do_shuffle, | |||
| enable_data_sink=args.enable_data_sink, | |||
| dataset_path=args.data_path) | |||
| dataset = create_transformer_dataset(epoch_count=1, rank_size=device_num, | |||
| rank_id=rank_id, do_shuffle=args.do_shuffle, | |||
| enable_data_sink=args.enable_data_sink, | |||
| dataset_path=args.data_path) | |||
| netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True) | |||
| @@ -165,7 +165,7 @@ def run_transformer_train(): | |||
| netwithgrads.set_train(True) | |||
| model = Model(netwithgrads) | |||
| model.train(repeat_count, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true")) | |||
| model.train(args.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true")) | |||
| if __name__ == '__main__': | |||
| run_transformer_train() | |||
| @@ -88,10 +88,10 @@ if __name__ == '__main__': | |||
| # create dataset | |||
| if args_opt.net == "resnet50": | |||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=config.epoch_size, | |||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1, | |||
| batch_size=config.batch_size, target=target) | |||
| else: | |||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=config.epoch_size, | |||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1, | |||
| batch_size=config.batch_size) | |||
| step_size = dataset.get_dataset_size() | |||
| @@ -105,7 +105,7 @@ if __name__ == '__main__': | |||
| loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) | |||
| if args_opt.do_train: | |||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, | |||
| repeat_num=epoch_size, batch_size=config.batch_size) | |||
| batch_size=config.batch_size) | |||
| step_size = dataset.get_dataset_size() | |||
| loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) | |||
| @@ -91,7 +91,7 @@ def main(): | |||
| loss_scale = float(args_opt.loss_scale) | |||
| # When creating MindDataset, use the first mindrecord file, such as ssd.mindrecord0. | |||
| dataset = create_ssd_dataset(mindrecord_file, repeat_num=args_opt.epoch_size, | |||
| dataset = create_ssd_dataset(mindrecord_file, repeat_num=1, | |||
| batch_size=args_opt.batch_size, device_num=device_num, rank=rank) | |||
| dataset_size = dataset.get_dataset_size() | |||
| @@ -83,7 +83,7 @@ if __name__ == '__main__': | |||
| mirror_mean=True) | |||
| init() | |||
| dataset = vgg_create_dataset(args_opt.data_path, cfg.epoch_size) | |||
| dataset = vgg_create_dataset(args_opt.data_path, 1) | |||
| batch_num = dataset.get_dataset_size() | |||
| net = vgg16(num_classes=cfg.num_classes) | |||
| @@ -63,7 +63,7 @@ def test_train(configure): | |||
| data_path = configure.data_path | |||
| batch_size = configure.batch_size | |||
| epochs = configure.epochs | |||
| ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size) | |||
| ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size) | |||
| print("ds_train.size: {}".format(ds_train.get_dataset_size())) | |||
| net_builder = ModelBuilder() | |||
| @@ -67,8 +67,8 @@ def test_train_eval(config): | |||
| data_path = config.data_path | |||
| batch_size = config.batch_size | |||
| epochs = config.epochs | |||
| ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size) | |||
| ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, batch_size=batch_size) | |||
| ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size) | |||
| ds_eval = create_dataset(data_path, train_mode=False, epochs=1, batch_size=batch_size) | |||
| print("ds_train.size: {}".format(ds_train.get_dataset_size())) | |||
| print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) | |||
| @@ -85,14 +85,14 @@ def train_and_eval(config): | |||
| if config.full_batch: | |||
| context.set_auto_parallel_context(full_batch=True) | |||
| de.config.set_seed(1) | |||
| ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, | |||
| ds_train = create_dataset(data_path, train_mode=True, epochs=1, | |||
| batch_size=batch_size*get_group_size()) | |||
| ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, | |||
| ds_eval = create_dataset(data_path, train_mode=False, epochs=1, | |||
| batch_size=batch_size*get_group_size()) | |||
| else: | |||
| ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, | |||
| ds_train = create_dataset(data_path, train_mode=True, epochs=1, | |||
| batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) | |||
| ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, | |||
| ds_eval = create_dataset(data_path, train_mode=False, epochs=1, | |||
| batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) | |||
| print("ds_train.size: {}".format(ds_train.get_dataset_size())) | |||
| print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) | |||
| @@ -74,9 +74,9 @@ def train_and_eval(config): | |||
| batch_size = config.batch_size | |||
| epochs = config.epochs | |||
| print("epochs is {}".format(epochs)) | |||
| ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, | |||
| ds_train = create_dataset(data_path, train_mode=True, epochs=1, | |||
| batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) | |||
| ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, | |||
| ds_eval = create_dataset(data_path, train_mode=False, epochs=1, | |||
| batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) | |||
| print("ds_train.size: {}".format(ds_train.get_dataset_size())) | |||
| print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) | |||
| @@ -121,7 +121,7 @@ def main(): | |||
| loss_scale = float(args_opt.loss_scale) | |||
| # When creating MindDataset, use the first mindrecord file, such as yolo.mindrecord0. | |||
| dataset = create_yolo_dataset(mindrecord_file, repeat_num=args_opt.epoch_size, | |||
| dataset = create_yolo_dataset(mindrecord_file, | |||
| batch_size=args_opt.batch_size, device_num=device_num, rank=rank) | |||
| dataset_size = dataset.get_dataset_size() | |||
| print("Create dataset done!") | |||
| @@ -50,13 +50,20 @@ class MindData: | |||
| def input_indexs(self): | |||
| return self._input_indexs | |||
| def device_que(self): | |||
| def device_que(self, send_epoch_end=True): | |||
| self.queue_name = '6ba41974-209e-11ea-88b0-a24efeb2c736' | |||
| self.send_epoch_end = send_epoch_end | |||
| return self | |||
| def create_tuple_iterator(self): | |||
| return self.__iter__() | |||
| def send(self): | |||
| pass | |||
| def stop_send(self): | |||
| pass | |||
| def __len__(self): | |||
| return self._size | |||
| @@ -73,7 +73,7 @@ if __name__ == "__main__": | |||
| epoch_size = 3 | |||
| args_opt.base_size = config.crop_size | |||
| args_opt.crop_size = config.crop_size | |||
| train_dataset = create_dataset(args_opt, args_opt.data_url, epoch_size, config.batch_size, | |||
| train_dataset = create_dataset(args_opt, args_opt.data_url, 1, config.batch_size, | |||
| usage="train", shuffle=False) | |||
| dataset_size = train_dataset.get_dataset_size() | |||
| callback = LossCallBack(dataset_size) | |||
| @@ -120,10 +120,10 @@ def test_transformer(): | |||
| batch_size = 96 | |||
| epoch_size = 3 | |||
| config = get_config(version=version, batch_size=batch_size) | |||
| dataset, repeat_count = create_transformer_dataset(epoch_count=epoch_size, | |||
| do_shuffle="false", | |||
| enable_data_sink="false", | |||
| dataset_path=DATA_DIR) | |||
| dataset = create_transformer_dataset(epoch_count=1, | |||
| do_shuffle="false", | |||
| enable_data_sink="false", | |||
| dataset_path=DATA_DIR) | |||
| netwithloss = TransformerNetworkWithLoss(config, True) | |||
| @@ -146,7 +146,7 @@ def test_transformer(): | |||
| netwithgrads.set_train(True) | |||
| time_monitor_callback = TimeMonitor(dataset.get_dataset_size()) | |||
| model = Model(netwithgrads) | |||
| model.train(repeat_count, dataset, callbacks=[time_monitor_callback, callback], dataset_sink_mode=False) | |||
| model.train(epoch_size, dataset, callbacks=[time_monitor_callback, callback], dataset_sink_mode=False) | |||
| # assertion occurs while the loss value, overflow state or loss_scale value is wrong | |||
| loss_value = np.array(callback.loss_list) | |||
| @@ -79,9 +79,9 @@ def test_train_eval(): | |||
| batch_size = config.batch_size | |||
| epochs = config.epochs | |||
| print("epochs is {}".format(epochs)) | |||
| ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size, | |||
| ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size, | |||
| data_type=DataType.MINDRECORD, rank_id=get_rank(), rank_size=get_group_size()) | |||
| ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, batch_size=batch_size, | |||
| ds_eval = create_dataset(data_path, train_mode=False, epochs=1, batch_size=batch_size, | |||
| data_type=DataType.MINDRECORD, rank_id=get_rank(), rank_size=get_group_size()) | |||
| print("ds_train.size: {}".format(ds_train.get_dataset_size())) | |||
| print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) | |||
| @@ -76,9 +76,9 @@ def test_train_eval(): | |||
| batch_size = config.batch_size | |||
| epochs = config.epochs | |||
| print("epochs is {}".format(epochs)) | |||
| ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, | |||
| ds_train = create_dataset(data_path, train_mode=True, epochs=1, | |||
| batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) | |||
| ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, | |||
| ds_eval = create_dataset(data_path, train_mode=False, epochs=1, | |||
| batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) | |||
| print("ds_train.size: {}".format(ds_train.get_dataset_size())) | |||
| print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) | |||
| @@ -113,7 +113,7 @@ def test_yolov3(): | |||
| loss_scale = float(loss_scale) | |||
| # When creating the MindDataset, use the first mindrecord file, such as yolo.mindrecord0. | |||
| dataset = create_yolo_dataset(mindrecord_file, repeat_num=epoch_size, | |||
| dataset = create_yolo_dataset(mindrecord_file, repeat_num=1, | |||
| batch_size=batch_size, device_num=device_num, rank=rank) | |||
| dataset_size = dataset.get_dataset_size() | |||
| print("Create dataset done!") | |||
| @@ -146,12 +146,12 @@ def test_yolov3(): | |||
| assert loss_value[2] < expect_loss_value[2] | |||
| epoch_mseconds = np.array(time_monitor_callback.epoch_mseconds_list)[2] | |||
| expect_epoch_mseconds = 950 | |||
| expect_epoch_mseconds = 2000 | |||
| print("epoch mseconds: {}".format(epoch_mseconds)) | |||
| assert epoch_mseconds <= expect_epoch_mseconds | |||
| per_step_mseconds = np.array(time_monitor_callback.per_step_mseconds_list)[2] | |||
| expect_per_step_mseconds = 110 | |||
| expect_per_step_mseconds = 220 | |||
| print("per step mseconds: {}".format(per_step_mseconds)) | |||
| assert per_step_mseconds <= expect_per_step_mseconds | |||
| print("yolov3 test case passed.") | |||
| @@ -91,6 +91,7 @@ def me_de_train_dataset(sink_mode=False): | |||
| """test me de train dataset""" | |||
| # apply repeat operations | |||
| repeat_count = 1 | |||
| sink_size = -1 | |||
| batch_size = 16 | |||
| ds = de.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids", | |||
| "next_sentence_labels", "masked_lm_positions", | |||
| @@ -99,9 +100,9 @@ def me_de_train_dataset(sink_mode=False): | |||
| new_repeat_count = repeat_count | |||
| if sink_mode: | |||
| repeat_count = 30 | |||
| sink_steps = 100 | |||
| sink_size = 100 | |||
| ori_dataaet_size = ds.get_dataset_size() | |||
| new_size = sink_steps * batch_size | |||
| new_size = sink_size * batch_size | |||
| ds.set_dataset_size(new_size) | |||
| new_repeat_count = int(repeat_count * ori_dataaet_size // ds.get_dataset_size()) | |||
| ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) | |||
| @@ -112,10 +113,9 @@ def me_de_train_dataset(sink_mode=False): | |||
| ds = ds.map(input_columns="input_ids", operations=type_cast_op) | |||
| # apply batch operations | |||
| ds = ds.batch(batch_size, drop_remainder=True) | |||
| ds = ds.repeat(repeat_count) | |||
| logger.info("data size: {}".format(ds.get_dataset_size())) | |||
| logger.info("repeat_count: {}".format(ds.get_repeat_count())) | |||
| return ds, new_repeat_count | |||
| return ds, new_repeat_count, sink_size | |||
| def weight_variable(shape): | |||
| @@ -157,7 +157,7 @@ class TimeMonitor(Callback): | |||
| def test_bert_percision(): | |||
| """test bert percision""" | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False) | |||
| ds, new_repeat_count = me_de_train_dataset() | |||
| ds, new_repeat_count, _ = me_de_train_dataset() | |||
| version = os.getenv('VERSION', 'large') | |||
| batch_size = 16 | |||
| config = get_config(version=version, batch_size=batch_size) | |||
| @@ -215,7 +215,7 @@ def test_bert_percision(): | |||
| def test_bert_performance(): | |||
| """test bert performance""" | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False) | |||
| ds, new_repeat_count = me_de_train_dataset(sink_mode=True) | |||
| ds, new_repeat_count, sink_size = me_de_train_dataset(sink_mode=True) | |||
| version = os.getenv('VERSION', 'large') | |||
| batch_size = 16 | |||
| config = get_config(version=version, batch_size=batch_size) | |||
| @@ -251,7 +251,7 @@ def test_bert_performance(): | |||
| param.default_input = weight_variable(value.asnumpy().shape) | |||
| time_monitor_callback = TimeMonitor(ds.get_dataset_size()) | |||
| model.train(new_repeat_count, ds, callbacks=[time_monitor_callback, callback], | |||
| dataset_sink_mode=True) | |||
| dataset_sink_mode=True, sink_size=sink_size) | |||
| # assertion occurs while the loss value, overflow state or loss_scale value is wrong | |||
| loss_value = np.array(callback.loss_list) | |||
| @@ -79,7 +79,7 @@ def test_deeplabv3_1p(): | |||
| args_opt.base_size = config.crop_size | |||
| args_opt.crop_size = config.crop_size | |||
| args_opt.batch_size = config.batch_size | |||
| train_dataset = create_dataset(args_opt, data_url, epoch_size, config.batch_size, | |||
| train_dataset = create_dataset(args_opt, data_url, 1, config.batch_size, | |||
| usage="eval") | |||
| dataset_size = train_dataset.get_dataset_size() | |||
| callback = LossCallBack(dataset_size) | |||
| @@ -155,7 +155,7 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl): | |||
| # train dataset | |||
| dataset = create_dataset(dataset_path=dataset_path, do_train=True, | |||
| repeat_num=epoch_size, batch_size=config.batch_size) | |||
| repeat_num=1, batch_size=config.batch_size) | |||
| step_size = dataset.get_dataset_size() | |||
| eval_interval = config.eval_interval | |||
| @@ -163,7 +163,7 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl): | |||
| # evaluation dataset | |||
| eval_dataset = create_dataset(dataset_path=eval_path, do_train=False, | |||
| repeat_num=epoch_size, batch_size=config.eval_batch_size) | |||
| repeat_num=1, batch_size=config.eval_batch_size) | |||
| # loss scale | |||
| loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) | |||
| @@ -260,14 +260,14 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl): | |||
| # train dataset | |||
| dataset = create_dataset(dataset_path=dataset_path, do_train=True, | |||
| repeat_num=epoch_size, batch_size=thor_config.batch_size) | |||
| repeat_num=1, batch_size=thor_config.batch_size) | |||
| step_size = dataset.get_dataset_size() | |||
| eval_interval = thor_config.eval_interval | |||
| # evaluation dataset | |||
| eval_dataset = create_dataset(dataset_path=eval_path, do_train=False, | |||
| repeat_num=epoch_size, batch_size=thor_config.eval_batch_size) | |||
| repeat_num=1, batch_size=thor_config.eval_batch_size) | |||
| # loss scale | |||
| loss_scale = FixedLossScaleManager(thor_config.loss_scale, drop_overflow_update=False) | |||
| @@ -136,7 +136,7 @@ if __name__ == '__main__': | |||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) | |||
| if args_opt.do_train: | |||
| dataset = create_dataset(epoch_size) | |||
| dataset = create_dataset(1) | |||
| batch_num = dataset.get_dataset_size() | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=10) | |||
| ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10", directory="./", config=config_ck) | |||
| @@ -140,7 +140,7 @@ def train_process(epoch_size, num_classes, batch_size): | |||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) | |||
| dataset = create_dataset(epoch_size, training=True, batch_size=batch_size) | |||
| dataset = create_dataset(1, training=True, batch_size=batch_size) | |||
| loss_cb = LossGet() | |||
| model.train(epoch_size, dataset, callbacks=[loss_cb]) | |||
| @@ -164,7 +164,7 @@ def train_process(q, device_id, epoch_size, num_classes, device_num, batch_size, | |||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) | |||
| dataset = create_dataset(epoch_size, training=True, | |||
| dataset = create_dataset(1, training=True, | |||
| batch_size=batch_size, rank_id=device_id, rank_size=device_num, | |||
| enable_hccl=enable_hccl) | |||
| @@ -91,8 +91,9 @@ SET(DE_UT_SRCS | |||
| cyclic_array_test.cc | |||
| perf_data_test.cc | |||
| c_api_test.cc | |||
| tensor_op_fusion_pass_test.cc | |||
| tensor_op_fusion_pass_test.cc | |||
| sliding_window_op_test.cc | |||
| epoch_ctrl_op_test.cc | |||
| ) | |||
| add_executable(de_ut_tests ${DE_UT_SRCS}) | |||
| @@ -397,23 +397,21 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) { | |||
| std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 0, true); | |||
| std::shared_ptr<CacheMergeOp> myMergeOp; | |||
| rc = CacheMergeOp::Builder().SetNumWorkers(3).SetOpConnectorSize(3).SetNumCleaner(2).SetClient(myClient).Build( | |||
| &myMergeOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // A mappable dataset uses a complex interaction between the cache lookup op and the cache merge op. | |||
| // Rather than manually build this, the way to do it is to choose the position of the cache in the tree by | |||
| // adding a CacheOp. Then, the tree prepare code will drive a transform that will remove the CacheOp and | |||
| // replace it with the required tree structures for cache lookup op and cache merge op. | |||
| std::shared_ptr<CacheLookupOp> myLookupOp; | |||
| rc = CacheLookupOp::Builder() | |||
| .SetNumWorkers(3) | |||
| .SetOpConnectorSize(3) | |||
| std::shared_ptr<CacheOp> myCacheOp; | |||
| rc = CacheOp::Builder() | |||
| .SetNumWorkers(4) | |||
| .SetClient(myClient) | |||
| .SetSampler(seq_sampler) | |||
| .Build(&myLookupOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| .SetRowsPerBuffer(3) | |||
| .Build(&myCacheOp); | |||
| std::shared_ptr<ImageFolderOp> so; | |||
| ImageFolderOp::Builder builder; | |||
| builder.SetSampler(myLookupOp) | |||
| builder.SetSampler(std::move(seq_sampler)) | |||
| .SetOpConnectorSize(3) | |||
| .SetNumWorkers(3) | |||
| .SetRowsPerBuffer(2) | |||
| @@ -432,20 +430,18 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) { | |||
| auto myTree = std::make_shared<ExecutionTree>(); | |||
| rc = myTree->AssociateNode(so); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myLookupOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myMergeOp); | |||
| rc = myTree->AssociateNode(myCacheOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssignRoot(myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myRepeatOp->AddChild(myMergeOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myMergeOp->AddChild(myLookupOp); | |||
| rc = myRepeatOp->AddChild(myCacheOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myMergeOp->AddChild(so); | |||
| rc = myCacheOp->AddChild(so); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->Prepare(); | |||
| @@ -0,0 +1,639 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/core/client.h" | |||
| #include "minddata/dataset/engine/datasetops/source/image_folder_op.h" | |||
| #include "common/common.h" | |||
| #include "gtest/gtest.h" | |||
| #include "utils/log_adapter.h" | |||
| #include <memory> | |||
| using namespace mindspore::dataset; | |||
| using mindspore::MsLogLevel::INFO; | |||
| using mindspore::ExceptionType::NoExceptionType; | |||
| using mindspore::LogStream; | |||
| std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path, | |||
| bool shuf = false, std::shared_ptr<Sampler> sampler = nullptr, | |||
| std::map<std::string, int32_t> map = {}, bool decode = false); | |||
| std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops); | |||
| // Fixture shared by the EpochCtrlOp tests. | |||
| // SetUp() runs an ImageFolder-only pipeline once and stores the concatenated image bytes in golden_imgs, | |||
| // which the tests compare against the data they read in each epoch. | |||
| // The label check assumes one pass yields 44 rows whose labels come in 4 consecutive groups of 11. | |||
| class MindDataTestEpochCtrlOp : public UT::DatasetOpTesting { | |||
|  public: | |||
|   void SetUp() override { | |||
|     DatasetOpTesting::SetUp(); | |||
|     folder_path = datasets_root_path_ + "/testPK/data"; | |||
|     GlobalInit(); | |||
|     // Build the reference tree; Build() returns the tree, so no placeholder tree is created first. | |||
|     my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)}); | |||
|     // A failure here makes every later step meaningless, so stop the fixture (and skip the test body). | |||
|     rc = my_tree_->Prepare(); | |||
|     ASSERT_TRUE(rc.IsOk()); | |||
|     rc = my_tree_->Launch(); | |||
|     ASSERT_TRUE(rc.IsOk()); | |||
|     // Read the whole pipeline once and record the golden image data. | |||
|     DatasetIterator di(my_tree_); | |||
|     TensorMap tensor_map; | |||
|     rc = di.GetNextAsMap(&tensor_map); | |||
|     ASSERT_TRUE(rc.IsOk()); | |||
|     int32_t i = 0; | |||
|     while (tensor_map.size() != 0) { | |||
|       tensor_map["label"]->GetItemAt<int32_t>(&label, {}); | |||
|       EXPECT_TRUE(img_class[(i % 44) / 11] == label); | |||
|       // Append the raw image bytes to the golden string, to be used as a comparison later. | |||
|       golden_imgs.append(reinterpret_cast<const char *>(tensor_map["image"]->GetBuffer()), | |||
|                          static_cast<int64_t>(tensor_map["image"]->Size())); | |||
|       // Abort on a failed fetch: the loop condition only looks at the map, not at rc. | |||
|       rc = di.GetNextAsMap(&tensor_map); | |||
|       ASSERT_TRUE(rc.IsOk()); | |||
|       i++; | |||
|     } | |||
|   } | |||
|   std::shared_ptr<ExecutionTree> my_tree_;  // tree under test; rebuilt by each test | |||
|   Status rc;                                // status of the most recent pipeline call | |||
|   std::string golden_imgs;                  // image bytes of one full pass, recorded in SetUp() | |||
|   std::string folder_path;                  // ImageFolder dataset directory | |||
|   int32_t label = 0;                        // scratch: label of the row being checked | |||
|   std::string result;                       // scratch: image bytes collected in the current epoch | |||
|   int32_t img_class[4] = {0, 1, 2, 3};      // expected label for each group of 11 rows | |||
| }; | |||
| // Prepares the tree with the default Prepare() (no epoch count), then reads num_epoch (2..6) epochs. | |||
| // Each epoch must reproduce the golden data, and a fetch after the last epoch is still expected to succeed. | |||
| TEST_F(MindDataTestEpochCtrlOp, ImageFolder_AutoInjectEpoch) { | |||
|   MS_LOG(WARNING) << "Doing ImageFolder_AutoInjectEpoch."; | |||
|   int32_t num_epoch = 2 + std::rand() % 5; | |||
|   my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)}); | |||
|   // Nothing below can run on a tree that failed to prepare or launch, so these checks are fatal. | |||
|   rc = my_tree_->Prepare(); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   rc = my_tree_->Launch(); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   MS_LOG(DEBUG) << "num_epoch: " << num_epoch; | |||
|   // Start the loop of reading tensors from our pipeline | |||
|   DatasetIterator di(my_tree_); | |||
|   TensorMap tensor_map; | |||
|   uint64_t i = 0; | |||
|   for (int epoch = 0; epoch < num_epoch; epoch++) { | |||
|     rc = di.GetNextAsMap(&tensor_map); | |||
|     ASSERT_TRUE(rc.IsOk()); | |||
|     // An empty map marks the end of the current epoch. | |||
|     while (tensor_map.size() != 0) { | |||
|       tensor_map["label"]->GetItemAt<int32_t>(&label, {}); | |||
|       MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n"; | |||
|       EXPECT_TRUE(img_class[(i % 44) / 11] == label); | |||
|       // Collect the raw image bytes of this epoch for the comparison below. | |||
|       result.append(reinterpret_cast<const char *>(tensor_map["image"]->GetBuffer()), | |||
|                     static_cast<int64_t>(tensor_map["image"]->Size())); | |||
|       // Abort on a failed fetch instead of looping on a map that may never become empty. | |||
|       rc = di.GetNextAsMap(&tensor_map); | |||
|       ASSERT_TRUE(rc.IsOk()); | |||
|       i++; | |||
|     } | |||
|     EXPECT_TRUE(result == golden_imgs); | |||
|     result.clear(); | |||
|     MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i; | |||
|   } | |||
|   EXPECT_TRUE(i == 44 * num_epoch); | |||
|   // Try to fetch data beyond the specified number of epochs. | |||
|   rc = di.GetNextAsMap(&tensor_map); | |||
|   EXPECT_TRUE(rc.IsOk()); | |||
| } | |||
| // Prepares the tree for num_epoch (2..6) epochs and checks that each epoch reproduces the golden data. | |||
| // A fetch after the last epoch is expected to fail. | |||
| TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Epoch) { | |||
|   MS_LOG(WARNING) << "Doing ImageFolder_Epoch."; | |||
|   int32_t num_epoch = 2 + std::rand() % 5; | |||
|   my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)}); | |||
|   // Nothing below can run on a tree that failed to prepare or launch, so these checks are fatal. | |||
|   rc = my_tree_->Prepare(num_epoch); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   rc = my_tree_->Launch(); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   MS_LOG(DEBUG) << "num_epoch: " << num_epoch; | |||
|   // Start the loop of reading tensors from our pipeline | |||
|   DatasetIterator di(my_tree_); | |||
|   TensorMap tensor_map; | |||
|   uint64_t i = 0; | |||
|   for (int epoch = 0; epoch < num_epoch; epoch++) { | |||
|     rc = di.GetNextAsMap(&tensor_map); | |||
|     ASSERT_TRUE(rc.IsOk()); | |||
|     // An empty map marks the end of the current epoch. | |||
|     while (tensor_map.size() != 0) { | |||
|       tensor_map["label"]->GetItemAt<int32_t>(&label, {}); | |||
|       MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n"; | |||
|       EXPECT_TRUE(img_class[(i % 44) / 11] == label); | |||
|       // Collect the raw image bytes of this epoch for the comparison below. | |||
|       result.append(reinterpret_cast<const char *>(tensor_map["image"]->GetBuffer()), | |||
|                     static_cast<int64_t>(tensor_map["image"]->Size())); | |||
|       // Abort on a failed fetch instead of looping on a map that may never become empty. | |||
|       rc = di.GetNextAsMap(&tensor_map); | |||
|       ASSERT_TRUE(rc.IsOk()); | |||
|       i++; | |||
|     } | |||
|     EXPECT_TRUE(result == golden_imgs); | |||
|     result.clear(); | |||
|     MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i; | |||
|   } | |||
|   EXPECT_TRUE(i == 44 * num_epoch); | |||
|   // Try to fetch data beyond the specified number of epochs. | |||
|   rc = di.GetNextAsMap(&tensor_map); | |||
|   EXPECT_FALSE(rc.IsOk()); | |||
| } | |||
| // ImageFolder -> Repeat(2), prepared for num_epoch (2..6) epochs. | |||
| // Each epoch must yield the golden data num_repeats times; a fetch after the last epoch is expected to fail. | |||
| TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Epoch) { | |||
|   MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Epoch."; | |||
|   int32_t num_epoch = 2 + std::rand() % 5; | |||
|   int32_t num_repeats = 2; | |||
|   std::shared_ptr<RepeatOp> repeat_op; | |||
|   rc = RepeatOp::Builder(num_repeats).Build(&repeat_op); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op}); | |||
|   // Nothing below can run on a tree that failed to prepare or launch, so these checks are fatal. | |||
|   rc = my_tree_->Prepare(num_epoch); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   rc = my_tree_->Launch(); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats; | |||
|   // Expected data of one epoch: the golden data once per repeat. | |||
|   std::string golden; | |||
|   for (int r = 0; r < num_repeats; r++) { | |||
|     golden += golden_imgs; | |||
|   } | |||
|   // Start the loop of reading tensors from our pipeline | |||
|   DatasetIterator di(my_tree_); | |||
|   TensorMap tensor_map; | |||
|   uint64_t i = 0; | |||
|   for (int epoch = 0; epoch < num_epoch; epoch++) { | |||
|     rc = di.GetNextAsMap(&tensor_map); | |||
|     ASSERT_TRUE(rc.IsOk()); | |||
|     // An empty map marks the end of the current epoch. | |||
|     while (tensor_map.size() != 0) { | |||
|       tensor_map["label"]->GetItemAt<int32_t>(&label, {}); | |||
|       MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n"; | |||
|       EXPECT_TRUE(img_class[(i % 44) / 11] == label); | |||
|       // Collect the raw image bytes of this epoch for the comparison below. | |||
|       result.append(reinterpret_cast<const char *>(tensor_map["image"]->GetBuffer()), | |||
|                     static_cast<int64_t>(tensor_map["image"]->Size())); | |||
|       // Abort on a failed fetch instead of looping on a map that may never become empty. | |||
|       rc = di.GetNextAsMap(&tensor_map); | |||
|       ASSERT_TRUE(rc.IsOk()); | |||
|       i++; | |||
|     } | |||
|     EXPECT_TRUE(result == golden); | |||
|     result.clear(); | |||
|     MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i; | |||
|   } | |||
|   EXPECT_TRUE(i == 44 * num_repeats * num_epoch); | |||
|   // Try to fetch data beyond the specified number of epochs. | |||
|   rc = di.GetNextAsMap(&tensor_map); | |||
|   EXPECT_FALSE(rc.IsOk()); | |||
| } | |||
| // ImageFolder -> Repeat(2) -> Repeat(3), prepared for num_epoch (2..6) epochs. | |||
| // Each epoch must yield the golden data num_repeats * num_repeats_2 times; a fetch after the last epoch | |||
| // is expected to fail. | |||
| TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Repeat_Epoch) { | |||
|   MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Repeat_Epoch."; | |||
|   int32_t num_epoch = 2 + std::rand() % 5; | |||
|   int32_t num_repeats = 2; | |||
|   std::shared_ptr<RepeatOp> repeat_op; | |||
|   rc = RepeatOp::Builder(num_repeats).Build(&repeat_op); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   int32_t num_repeats_2 = 3; | |||
|   std::shared_ptr<RepeatOp> repeat_op_2; | |||
|   rc = RepeatOp::Builder(num_repeats_2).Build(&repeat_op_2); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op, repeat_op_2}); | |||
|   // Nothing below can run on a tree that failed to prepare or launch, so these checks are fatal. | |||
|   rc = my_tree_->Prepare(num_epoch); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   rc = my_tree_->Launch(); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats << ". num_repeat_2: " << num_repeats_2; | |||
|   // Expected data of one epoch: the golden data once per (outer repeat, inner repeat) pair. | |||
|   std::string golden; | |||
|   for (int j = 0; j < num_repeats_2; j++) { | |||
|     for (int k = 0; k < num_repeats; k++) { | |||
|       golden += golden_imgs; | |||
|     } | |||
|   } | |||
|   // Start the loop of reading tensors from our pipeline | |||
|   DatasetIterator di(my_tree_); | |||
|   TensorMap tensor_map; | |||
|   uint64_t i = 0; | |||
|   for (int epoch = 0; epoch < num_epoch; epoch++) { | |||
|     rc = di.GetNextAsMap(&tensor_map); | |||
|     ASSERT_TRUE(rc.IsOk()); | |||
|     // An empty map marks the end of the current epoch. | |||
|     while (tensor_map.size() != 0) { | |||
|       tensor_map["label"]->GetItemAt<int32_t>(&label, {}); | |||
|       MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n"; | |||
|       EXPECT_TRUE(img_class[(i % 44) / 11] == label); | |||
|       // Collect the raw image bytes of this epoch for the comparison below. | |||
|       result.append(reinterpret_cast<const char *>(tensor_map["image"]->GetBuffer()), | |||
|                     static_cast<int64_t>(tensor_map["image"]->Size())); | |||
|       // Abort on a failed fetch instead of looping on a map that may never become empty. | |||
|       rc = di.GetNextAsMap(&tensor_map); | |||
|       ASSERT_TRUE(rc.IsOk()); | |||
|       i++; | |||
|     } | |||
|     EXPECT_EQ(result.size(), golden.size()); | |||
|     EXPECT_TRUE(result == golden); | |||
|     result.clear(); | |||
|     MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i; | |||
|   } | |||
|   EXPECT_EQ(i, 44 * num_epoch * num_repeats * num_repeats_2); | |||
|   // Try to fetch data beyond the specified number of epochs. | |||
|   rc = di.GetNextAsMap(&tensor_map); | |||
|   EXPECT_FALSE(rc.IsOk()); | |||
| } | |||
| // Prepares the tree with num_epoch == -1 (infinite epochs) and reads a bounded number of epochs. | |||
| // Every epoch that is read must reproduce the golden data. | |||
| TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Epoch_Inf) { | |||
|   MS_LOG(WARNING) << "Doing ImageFolder_Epoch_Inf."; | |||
|   // if num_epoch == -1, it means infinity. | |||
|   int32_t num_epoch = -1; | |||
|   my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)}); | |||
|   // Nothing below can run on a tree that failed to prepare or launch, so these checks are fatal. | |||
|   rc = my_tree_->Prepare(num_epoch); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   rc = my_tree_->Launch(); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   // Start the loop of reading tensors from our pipeline | |||
|   DatasetIterator di(my_tree_); | |||
|   TensorMap tensor_map; | |||
|   uint64_t i = 0; | |||
|   // The test itself bounds the run: stop after a randomly chosen number of epochs (2..7). | |||
|   int32_t stop_at_epoch = 2 + std::rand() % 6; | |||
|   MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". Stop at epoch: " << stop_at_epoch; | |||
|   for (int epoch = 0; epoch < stop_at_epoch; epoch++) { | |||
|     rc = di.GetNextAsMap(&tensor_map); | |||
|     ASSERT_TRUE(rc.IsOk()); | |||
|     // An empty map marks the end of the current epoch. | |||
|     while (tensor_map.size() != 0) { | |||
|       tensor_map["label"]->GetItemAt<int32_t>(&label, {}); | |||
|       MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n"; | |||
|       EXPECT_TRUE(img_class[(i % 44) / 11] == label); | |||
|       // Collect the raw image bytes of this epoch for the comparison below. | |||
|       result.append(reinterpret_cast<const char *>(tensor_map["image"]->GetBuffer()), | |||
|                     static_cast<int64_t>(tensor_map["image"]->Size())); | |||
|       // Abort on a failed fetch instead of looping on a map that may never become empty. | |||
|       rc = di.GetNextAsMap(&tensor_map); | |||
|       ASSERT_TRUE(rc.IsOk()); | |||
|       i++; | |||
|     } | |||
|     EXPECT_EQ(result, golden_imgs); | |||
|     result.clear(); | |||
|     MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i; | |||
|   } | |||
|   EXPECT_TRUE(i == 44 * stop_at_epoch); | |||
| } | |||
| // ImageFolder -> Repeat(2) -> Repeat(3), prepared with num_epoch == -1 (infinite epochs). | |||
| // Reads a bounded number of epochs; each must yield the golden data num_repeats * num_repeats_2 times. | |||
| TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Repeat_Epoch_Inf) { | |||
|   // The log message now names this test; it previously named a different one. | |||
|   MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Repeat_Epoch_Inf."; | |||
|   // if num_epoch == -1, it means infinity. | |||
|   int32_t num_epoch = -1; | |||
|   int32_t num_repeats = 2; | |||
|   std::shared_ptr<RepeatOp> repeat_op; | |||
|   rc = RepeatOp::Builder(num_repeats).Build(&repeat_op); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   int32_t num_repeats_2 = 3; | |||
|   std::shared_ptr<RepeatOp> repeat_op_2; | |||
|   rc = RepeatOp::Builder(num_repeats_2).Build(&repeat_op_2); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op, repeat_op_2}); | |||
|   // Nothing below can run on a tree that failed to prepare or launch, so these checks are fatal. | |||
|   rc = my_tree_->Prepare(num_epoch); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   rc = my_tree_->Launch(); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats << ". num_repeat_2: " << num_repeats_2; | |||
|   // Expected data of one epoch: the golden data once per (outer repeat, inner repeat) pair. | |||
|   std::string golden; | |||
|   for (int j = 0; j < num_repeats_2; j++) { | |||
|     for (int k = 0; k < num_repeats; k++) { | |||
|       golden += golden_imgs; | |||
|     } | |||
|   } | |||
|   // Start the loop of reading tensors from our pipeline | |||
|   DatasetIterator di(my_tree_); | |||
|   TensorMap tensor_map; | |||
|   uint64_t i = 0; | |||
|   // The test itself bounds the run: stop after a randomly chosen number of epochs (2..7). | |||
|   int32_t stop_at_epoch = 2 + std::rand() % 6; | |||
|   MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". Stop at epoch: " << stop_at_epoch; | |||
|   for (int epoch = 0; epoch < stop_at_epoch; epoch++) { | |||
|     rc = di.GetNextAsMap(&tensor_map); | |||
|     ASSERT_TRUE(rc.IsOk()); | |||
|     // An empty map marks the end of the current epoch. | |||
|     while (tensor_map.size() != 0) { | |||
|       tensor_map["label"]->GetItemAt<int32_t>(&label, {}); | |||
|       MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n"; | |||
|       EXPECT_TRUE(img_class[(i % 44) / 11] == label); | |||
|       // Collect the raw image bytes of this epoch for the comparison below. | |||
|       result.append(reinterpret_cast<const char *>(tensor_map["image"]->GetBuffer()), | |||
|                     static_cast<int64_t>(tensor_map["image"]->Size())); | |||
|       // Abort on a failed fetch instead of looping on a map that may never become empty. | |||
|       rc = di.GetNextAsMap(&tensor_map); | |||
|       ASSERT_TRUE(rc.IsOk()); | |||
|       i++; | |||
|     } | |||
|     EXPECT_EQ(result, golden); | |||
|     result.clear(); | |||
|     MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i; | |||
|   } | |||
|   EXPECT_TRUE(i == 44 * stop_at_epoch * num_repeats * num_repeats_2); | |||
| } | |||
| // Same scenario as ImageFolder_Epoch, but rows are read through a ChildIterator on the tree root. | |||
| // Epochs are read until the iterator reports eof_handled(); a fetch after that is expected to fail | |||
| // and leave the row empty. | |||
| TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Epoch_ChildItr) { | |||
|   MS_LOG(WARNING) << "Doing ImageFolder_Epoch_ChildItr."; | |||
|   int32_t num_epoch = 2 + std::rand() % 5; | |||
|   my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)}); | |||
|   // Nothing below can run on a tree that failed to prepare or launch, so these checks are fatal. | |||
|   rc = my_tree_->Prepare(num_epoch); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   rc = my_tree_->Launch(); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   MS_LOG(INFO) << "num_epoch: " << num_epoch; | |||
|   // Start the loop of reading tensors from our pipeline | |||
|   ChildIterator ci(my_tree_->root().get(), 0, 0); | |||
|   TensorRow tensor_row; | |||
|   uint64_t total_sample = 0; | |||
|   uint64_t i = 0; | |||
|   uint32_t epoch = 0; | |||
|   rc = ci.FetchNextTensorRow(&tensor_row); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   while (!ci.eof_handled()) { | |||
|     i = 0; | |||
|     // An empty row marks the end of the current epoch. | |||
|     while (tensor_row.size() != 0) { | |||
|       // Column 0 is read as the image and column 1 as the label. | |||
|       tensor_row[1]->GetItemAt<int32_t>(&label, {}); | |||
|       MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n"; | |||
|       EXPECT_TRUE(img_class[(i % 44) / 11] == label); | |||
|       // Collect the raw image bytes of this epoch for the comparison below. | |||
|       result.append(reinterpret_cast<const char *>(tensor_row[0]->GetBuffer()), | |||
|                     static_cast<int64_t>(tensor_row[0]->Size())); | |||
|       // Abort on a failed fetch: neither loop condition looks at rc, so the test could hang. | |||
|       rc = ci.FetchNextTensorRow(&tensor_row); | |||
|       ASSERT_TRUE(rc.IsOk()); | |||
|       i++; | |||
|     } | |||
|     epoch++; | |||
|     MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i; | |||
|     EXPECT_TRUE(result == golden_imgs); | |||
|     result.clear(); | |||
|     EXPECT_TRUE(i == 44); | |||
|     total_sample += i; | |||
|     // Fetch the first row of the next epoch (or hit the end of the data). | |||
|     rc = ci.FetchNextTensorRow(&tensor_row); | |||
|     ASSERT_TRUE(rc.IsOk()); | |||
|   } | |||
|   EXPECT_TRUE(total_sample == 44 * num_epoch); | |||
|   // Try to fetch data after last epoch ends. | |||
|   rc = ci.FetchNextTensorRow(&tensor_row); | |||
|   EXPECT_TRUE(tensor_row.empty()); | |||
|   EXPECT_FALSE(rc.IsOk()); | |||
| } | |||
| // ImageFolder -> Repeat(2), prepared for num_epoch (2..6) epochs and read through a ChildIterator. | |||
| // Each epoch must yield the golden data num_repeats times; a fetch after the last epoch is expected | |||
| // to fail and leave the row empty. | |||
| TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Epoch_ChildItr) { | |||
|   MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Epoch_ChildItr."; | |||
|   int32_t num_epoch = 2 + std::rand() % 5; | |||
|   int32_t num_repeats = 2; | |||
|   std::shared_ptr<RepeatOp> repeat_op; | |||
|   rc = RepeatOp::Builder(num_repeats).Build(&repeat_op); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op}); | |||
|   // Nothing below can run on a tree that failed to prepare or launch, so these checks are fatal. | |||
|   rc = my_tree_->Prepare(num_epoch); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   rc = my_tree_->Launch(); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats; | |||
|   // Expected data of one epoch: the golden data once per repeat. | |||
|   std::string golden; | |||
|   for (int r = 0; r < num_repeats; r++) { | |||
|     golden += golden_imgs; | |||
|   } | |||
|   // Start the loop of reading tensors from our pipeline | |||
|   ChildIterator ci(my_tree_->root().get(), 0, 0); | |||
|   TensorRow tensor_row; | |||
|   uint64_t total_sample = 0; | |||
|   uint64_t i = 0; | |||
|   uint32_t epoch = 0; | |||
|   rc = ci.FetchNextTensorRow(&tensor_row); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   while (!ci.eof_handled()) { | |||
|     i = 0; | |||
|     // An empty row marks the end of the current epoch. | |||
|     while (tensor_row.size() != 0) { | |||
|       // Column 0 is read as the image and column 1 as the label. | |||
|       tensor_row[1]->GetItemAt<int32_t>(&label, {}); | |||
|       MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n"; | |||
|       EXPECT_TRUE(img_class[(i % 44) / 11] == label); | |||
|       // Collect the raw image bytes of this epoch for the comparison below. | |||
|       result.append(reinterpret_cast<const char *>(tensor_row[0]->GetBuffer()), | |||
|                     static_cast<int64_t>(tensor_row[0]->Size())); | |||
|       // Abort on a failed fetch: neither loop condition looks at rc, so the test could hang. | |||
|       rc = ci.FetchNextTensorRow(&tensor_row); | |||
|       ASSERT_TRUE(rc.IsOk()); | |||
|       i++; | |||
|     } | |||
|     epoch++; | |||
|     MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i; | |||
|     EXPECT_TRUE(result == golden); | |||
|     result.clear(); | |||
|     EXPECT_TRUE(i == 44 * num_repeats); | |||
|     total_sample += i; | |||
|     // Fetch the first row of the next epoch (or hit the end of the data). | |||
|     rc = ci.FetchNextTensorRow(&tensor_row); | |||
|     ASSERT_TRUE(rc.IsOk()); | |||
|   } | |||
|   EXPECT_TRUE(total_sample == 44 * num_epoch * num_repeats); | |||
|   // Try to fetch data after last epoch ends. | |||
|   rc = ci.FetchNextTensorRow(&tensor_row); | |||
|   EXPECT_TRUE(tensor_row.empty()); | |||
|   EXPECT_FALSE(rc.IsOk()); | |||
| } | |||
| // ImageFolder -> Repeat(2) -> Repeat(3), prepared for num_epoch (2..6) epochs and read through a | |||
| // ChildIterator. Each epoch must yield the golden data num_repeats * num_repeats_2 times; a fetch after | |||
| // the last epoch is expected to fail and leave the row empty. | |||
| TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Repeat_Epoch_ChildItr) { | |||
|   MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Repeat_Epoch_ChildItr."; | |||
|   int32_t num_epoch = 2 + std::rand() % 5; | |||
|   int32_t num_repeats = 2; | |||
|   std::shared_ptr<RepeatOp> repeat_op; | |||
|   rc = RepeatOp::Builder(num_repeats).Build(&repeat_op); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   int32_t num_repeats_2 = 3; | |||
|   std::shared_ptr<RepeatOp> repeat_op_2; | |||
|   rc = RepeatOp::Builder(num_repeats_2).Build(&repeat_op_2); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op, repeat_op_2}); | |||
|   // Nothing below can run on a tree that failed to prepare or launch, so these checks are fatal. | |||
|   rc = my_tree_->Prepare(num_epoch); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   rc = my_tree_->Launch(); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats << ". num_repeat_2: " << num_repeats_2; | |||
|   // Expected data of one epoch: the golden data once per (outer repeat, inner repeat) pair. | |||
|   std::string golden; | |||
|   for (int j = 0; j < num_repeats_2; j++) { | |||
|     for (int k = 0; k < num_repeats; k++) { | |||
|       golden += golden_imgs; | |||
|     } | |||
|   } | |||
|   // Start the loop of reading tensors from our pipeline | |||
|   ChildIterator ci(my_tree_->root().get(), 0, 0); | |||
|   TensorRow tensor_row; | |||
|   uint64_t total_sample = 0; | |||
|   uint64_t i = 0; | |||
|   uint32_t epoch = 0; | |||
|   rc = ci.FetchNextTensorRow(&tensor_row); | |||
|   ASSERT_TRUE(rc.IsOk()); | |||
|   while (!ci.eof_handled()) { | |||
|     i = 0; | |||
|     // An empty row marks the end of the current epoch. | |||
|     while (tensor_row.size() != 0) { | |||
|       // Column 0 is read as the image and column 1 as the label. | |||
|       tensor_row[1]->GetItemAt<int32_t>(&label, {}); | |||
|       MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n"; | |||
|       EXPECT_TRUE(img_class[(i % 44) / 11] == label); | |||
|       // Collect the raw image bytes of this epoch for the comparison below. | |||
|       result.append(reinterpret_cast<const char *>(tensor_row[0]->GetBuffer()), | |||
|                     static_cast<int64_t>(tensor_row[0]->Size())); | |||
|       // Abort on a failed fetch: neither loop condition looks at rc, so the test could hang. | |||
|       rc = ci.FetchNextTensorRow(&tensor_row); | |||
|       ASSERT_TRUE(rc.IsOk()); | |||
|       i++; | |||
|     } | |||
|     epoch++; | |||
|     MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i; | |||
|     EXPECT_TRUE(result == golden); | |||
|     result.clear(); | |||
|     EXPECT_TRUE(i == 44 * num_repeats * num_repeats_2); | |||
|     total_sample += i; | |||
|     // Fetch the first row of the next epoch (or hit the end of the data). | |||
|     rc = ci.FetchNextTensorRow(&tensor_row); | |||
|     ASSERT_TRUE(rc.IsOk()); | |||
|   } | |||
|   EXPECT_TRUE(total_sample == 44 * num_epoch * num_repeats * num_repeats_2); | |||
|   // Try to fetch data after last epoch ends. | |||
|   rc = ci.FetchNextTensorRow(&tensor_row); | |||
|   EXPECT_TRUE(tensor_row.empty()); | |||
|   EXPECT_FALSE(rc.IsOk()); | |||
| } | |||
| TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Epoch_Inf_ChildItr) { | |||
| MS_LOG(WARNING) << "Doing ImageFolder_Epoch_Inf_ChildItr."; | |||
| // if num_epoch == -1, it means infinity. | |||
| int32_t num_epoch = -1; | |||
| my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)}); | |||
| rc = my_tree_->Prepare(num_epoch); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = my_tree_->Launch(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // Start the loop of reading tensors from our pipeline | |||
| ChildIterator ci(my_tree_->root().get(), 0, 0); | |||
| TensorRow tensor_row; | |||
| uint64_t i = 0; | |||
| // For this test, we stop at a random number between 0 - 100 epochs. | |||
| int32_t stop_at_epoch = 2 + std::rand() % 5; | |||
| MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". Stop at epoch: " << stop_at_epoch; | |||
| for (int epoch = 0; epoch < stop_at_epoch; epoch++) { | |||
| rc = ci.FetchNextTensorRow(&tensor_row); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| while (tensor_row.size() != 0) { | |||
| tensor_row[1]->GetItemAt<int32_t>(&label, {}); | |||
| MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n"; | |||
| EXPECT_TRUE(img_class[(i % 44) / 11] == label); | |||
| // Dump all the image into string, to be used as a comparison later. | |||
| result.append((char *) tensor_row[0]->GetBuffer(), (int64_t) tensor_row[0]->Size()); | |||
| rc = ci.FetchNextTensorRow(&tensor_row); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| i++; | |||
| } | |||
| EXPECT_TRUE(result == golden_imgs); | |||
| result.clear(); | |||
| MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i; | |||
| } | |||
| EXPECT_TRUE(i == 44 * stop_at_epoch); | |||
| } | |||
| TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Epoch_Inf_ChildItr) { | |||
| MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Epoch_Inf_ChildItr."; | |||
| // if num_epoch == -1, it means infinity. | |||
| int32_t num_epoch = -1; | |||
| int32_t num_repeats = 2; | |||
| std::shared_ptr<RepeatOp> repeat_op; | |||
| rc = RepeatOp::Builder(num_repeats).Build(&repeat_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op}); | |||
| rc = my_tree_->Prepare(num_epoch); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = my_tree_->Launch(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats; | |||
| std::string golden; | |||
| for (int i = 0; i < num_repeats; i++) { | |||
| golden += golden_imgs; | |||
| } | |||
| // Start the loop of reading tensors from our pipeline | |||
| ChildIterator ci(my_tree_->root().get(), 0, 0); | |||
| TensorRow tensor_row; | |||
| uint64_t i = 0; | |||
| // For this test, we stop at a random number between 0 - 100 epochs. | |||
| int32_t stop_at_epoch = 2 + std::rand() % 5; | |||
| MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". Stop at epoch: " << stop_at_epoch; | |||
| for (int epoch = 0; epoch < stop_at_epoch; epoch++) { | |||
| rc = ci.FetchNextTensorRow(&tensor_row); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| while (tensor_row.size() != 0) { | |||
| tensor_row[1]->GetItemAt<int32_t>(&label, {}); | |||
| MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n"; | |||
| EXPECT_TRUE(img_class[(i % 44) / 11] == label); | |||
| // Dump all the image into string, to be used as a comparison later. | |||
| result.append((char *) tensor_row[0]->GetBuffer(), (int64_t) tensor_row[0]->Size()); | |||
| rc = ci.FetchNextTensorRow(&tensor_row); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| i++; | |||
| } | |||
| EXPECT_TRUE(result == golden); | |||
| result.clear(); | |||
| MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i; | |||
| } | |||
| EXPECT_TRUE(i == 44 * stop_at_epoch * num_repeats); | |||
| } | |||
| @@ -46,7 +46,8 @@ TEST_F(MindDataTestrepeat_op, Testrepeat_opFuntions) { | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = my_tree->AssociateNode(my_tfreader_op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| my_tree->AssociateNode(parent_op); | |||
| rc = my_tree->AssociateNode(parent_op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| ASSERT_NE(parent_op, nullptr); | |||
| ASSERT_NE(my_tfreader_op, nullptr); | |||
| parent_op->AddChild(std::move(my_tfreader_op)); | |||
| @@ -104,9 +104,11 @@ def test_cache_map_basic3(): | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.repeat(4) | |||
| ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) | |||
| print("ds1.dataset_size is ", ds1.get_dataset_size()) | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| print("get data from dataset") | |||
| num_iter += 1 | |||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||
| @@ -152,6 +154,10 @@ def test_cache_map_failure1(): | |||
| if __name__ == '__main__': | |||
| test_cache_map_basic1() | |||
| print("test_cache_map_basic1 success.") | |||
| test_cache_map_basic2() | |||
| print("test_cache_map_basic2 success.") | |||
| test_cache_map_basic3() | |||
| print("test_cache_map_basic3 success.") | |||
| test_cache_map_failure1() | |||
| print("test_cache_map_failure1 success.") | |||
| @@ -238,7 +238,7 @@ def test_tfrecord_shard_equal_rows(): | |||
| def test_tfrecord_no_schema_columns_list(): | |||
| logger.info("test_tfrecord_no_schema_columns_list") | |||
| data = ds.TFRecordDataset(FILES, shuffle=False, columns_list=["col_sint16"]) | |||
| row = data.create_dict_iterator().get_next() | |||
| row = data.create_dict_iterator().__next__() | |||
| assert row["col_sint16"] == [-32768] | |||
| with pytest.raises(KeyError) as info: | |||
| @@ -258,7 +258,7 @@ def test_tfrecord_schema_columns_list(): | |||
| schema.add_column('col_sint32', de_type=mstype.int64, shape=[1]) | |||
| schema.add_column('col_sint64', de_type=mstype.int64, shape=[1]) | |||
| data = ds.TFRecordDataset(FILES, schema=schema, shuffle=False, columns_list=["col_sint16"]) | |||
| row = data.create_dict_iterator().get_next() | |||
| row = data.create_dict_iterator().__next__() | |||
| assert row["col_sint16"] == [-32768] | |||
| with pytest.raises(KeyError) as info: | |||
| @@ -12,6 +12,8 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import time | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as vision | |||
| from mindspore import log as logger | |||
| @@ -35,6 +37,8 @@ def test_case_0(): | |||
| data = data.device_que() | |||
| data.send() | |||
| time.sleep(0.1) | |||
| data.stop_send() | |||
| def test_case_1(): | |||
| @@ -58,6 +62,8 @@ def test_case_1(): | |||
| data = data.device_que() | |||
| data.send() | |||
| time.sleep(0.1) | |||
| data.stop_send() | |||
| def test_case_2(): | |||
| @@ -84,6 +90,8 @@ def test_case_2(): | |||
| data = data.device_que() | |||
| assert data.get_repeat_count() == 2 | |||
| data.send() | |||
| time.sleep(0.1) | |||
| data.stop_send() | |||
| def test_case_3(): | |||
| @@ -109,13 +117,17 @@ def test_case_3(): | |||
| data = data.device_que() | |||
| data.send() | |||
| time.sleep(0.1) | |||
| data.stop_send() | |||
| def test_case_tf_file(): | |||
| data = ds.TFRecordDataset(TF_FILES, TF_SCHEMA_FILE, shuffle=ds.Shuffle.FILES) | |||
| data = data.to_device(num_batch=10) | |||
| data = data.to_device() | |||
| data.send() | |||
| time.sleep(0.1) | |||
| data.stop_send() | |||
| if __name__ == '__main__': | |||
| @@ -0,0 +1,608 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing Epoch Control op in DE | |||
| """ | |||
| import itertools | |||
| import cv2 | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as vision | |||
| from mindspore import log as logger | |||
| DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||
| SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||
def diff_mse(in1, in2):
    """
    Return the mean squared error between two arrays, scaled by 100.

    Both inputs are converted to float and divided by 255 before the
    element-wise squared difference is averaged.
    """
    scaled_first = in1.astype(float) / 255
    scaled_second = in2.astype(float) / 255
    squared_error = np.square(scaled_first - scaled_second)
    return squared_error.mean() * 100
def test_cifar10():
    """
    Iterate a repeated, batched Cifar10 dataset for several epochs with one tuple iterator.

    The iterator is created once without num_epochs and reused for every epoch.
    Each epoch must yield int(limit_dataset * num_repeat / batch_size) batches.
    """
    logger.info("Test dataset parameter")
    data_dir_10 = "../data/dataset/testCifar10Data"
    num_repeat = 2
    batch_size = 32
    limit_dataset = 100
    # apply dataset operations
    data1 = ds.Cifar10Dataset(data_dir_10, limit_dataset)
    data1 = data1.repeat(num_repeat)
    data1 = data1.batch(batch_size, True)

    num_epoch = 5
    # Batches expected per epoch (batch is called with drop_remainder=True).
    rows_per_epoch = int(limit_dataset * num_repeat / batch_size)

    # iter1 will always assume there is a next epoch and never shutdown.
    iter1 = data1.create_tuple_iterator()
    epoch_count = 0
    sample_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            # each row produced by the tuple iterator is one batch
            row_count += 1
        assert row_count == rows_per_epoch
        # Format the value into the message; passing it as a bare extra argument
        # without a placeholder makes the logging call fail to format.
        logger.debug("row_count: {}".format(row_count))
        epoch_count += 1
        sample_count += row_count

    assert epoch_count == num_epoch
    logger.debug("total epochs: {}".format(epoch_count))
    assert sample_count == rows_per_epoch * num_epoch
    logger.debug("total sample: {}".format(sample_count))
def test_decode_op():
    """
    Compare vision.Decode(True) against OpenCV decoding over several epochs.

    One pipeline decodes inside the dataset; the other yields the raw bytes, which are
    decoded with cv2 and converted from BGR to RGB. The first iterator is created without
    num_epochs and stopped manually; the second is created with num_epoch and must raise
    once that many epochs have been consumed.
    """
    logger.info("test_decode_op")

    # Pipeline under test: decode inside the dataset.
    decoded_data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    decoded_data = decoded_data.map(input_columns=["image"], operations=[vision.Decode(True)])

    # Reference pipeline: raw encoded images.
    raw_data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)

    num_epoch = 5
    # iter1 has no epoch limit, so it keeps expecting another epoch and never shuts down by itself.
    iter1 = decoded_data.create_dict_iterator()
    # iter2 is limited to num_epoch epochs.
    iter2 = raw_data.create_dict_iterator(num_epoch)

    for _ in range(num_epoch):
        rows_seen = 0
        for decoded_row, raw_row in itertools.zip_longest(iter1, iter2):
            actual = decoded_row["image"]
            bgr_image = cv2.imdecode(raw_row["image"], cv2.IMREAD_COLOR)
            expected = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
            assert actual.shape == expected.shape
            mse = np.sum(np.power(actual - expected, 2))
            assert mse == 0
            rows_seen += 1
        assert rows_seen == 3

    # Stopping manually is optional; the garbage collector would do it otherwise.
    iter1.stop()
    # A stopped iterator no longer has its pipeline attribute.
    with pytest.raises(AttributeError) as excinfo:
        next(iter1)
    assert "object has no attribute 'depipeline'" in str(excinfo.value)

    # iter2 has already delivered all num_epoch epochs.
    with pytest.raises(RuntimeError) as excinfo:
        next(iter2)
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(excinfo.value)
# Source generator for the tests: 64 rows holding the integers 0 through 63.
def generator_1d():
    """
    Yield 64 one-element tuples, each containing a 1-D numpy array [value] for value in 0..63.
    """
    value = 0
    while value < 64:
        yield (np.array([value]),)
        value += 1
def test_generator_dict_0():
    """
    Read one epoch from a dict iterator created directly in the for statement.

    Every row must carry the next integer of the generator, and the epoch must contain
    all 64 rows.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    i = 0
    # create the iterator inside the loop declaration
    for item in data1.create_dict_iterator():  # each data is a dictionary
        golden = np.array([i])
        assert np.array_equal(item["data"], golden)
        i = i + 1

    # Without this check the test would pass even if the iterator produced no rows.
    assert i == 64
def test_generator_dict_1():
    """
    Read ten epochs, building a brand-new dict iterator for each one.

    This test deliberately uses the discouraged pattern: the iterator should normally be
    created once, outside the epoch loop.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    for _ in range(10):
        rows_seen = 0
        # BAD: a fresh iterator is created for every epoch.
        for index, row in enumerate(data1.create_dict_iterator()):  # row is a dictionary
            assert np.array_equal(row["data"], np.array([index]))
            rows_seen = index + 1
        assert rows_seen == 64
def test_generator_dict_2():
    """
    Reuse one dict iterator (no num_epochs given) for ten epochs, then fetch once more.

    The extra fetch must still return a row, showing the iterator is still running.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_dict_iterator()

    for _ in range(10):
        rows_seen = 0
        for index, row in enumerate(iter1):  # row is a dictionary
            assert np.array_equal(row["data"], np.array([index]))
            rows_seen = index + 1
        assert rows_seen == 64

    # iter1 is still alive and running.
    extra_row = next(iter1)
    assert extra_row
    # iter1 is left for the garbage collector to destroy.
def test_generator_dict_3():
    """
    Reuse one dict iterator for ten epochs, stop it explicitly, then fetch again.

    Fetching from the stopped iterator must raise AttributeError.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_dict_iterator()

    for _ in range(10):
        rows_seen = 0
        for index, row in enumerate(iter1):  # row is a dictionary
            assert np.array_equal(row["data"], np.array([index]))
            rows_seen = index + 1
        assert rows_seen == 64

    # Explicit stop is optional.
    iter1.stop()
    # The stopped iterator no longer has its pipeline attribute.
    with pytest.raises(AttributeError) as excinfo:
        next(iter1)
    assert "object has no attribute 'depipeline'" in str(excinfo.value)
def test_generator_dict_4():
    """
    Consume exactly num_epochs=10 epochs from a dict iterator, then fetch once more.

    The additional fetch must raise RuntimeError with the EOF message.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_dict_iterator(num_epochs=10)

    for _ in range(10):
        rows_seen = 0
        for index, row in enumerate(iter1):  # row is a dictionary
            assert np.array_equal(row["data"], np.array([index]))
            rows_seen = index + 1
        assert rows_seen == 64

    # All ten epochs have been read; another fetch goes past the limit.
    with pytest.raises(RuntimeError) as excinfo:
        next(iter1)
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(excinfo.value)
def test_generator_dict_4_1():
    """
    Consume the single epoch of a dict iterator created with num_epochs=1, then fetch again.

    The additional fetch must raise RuntimeError with the EOF message.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    # epoch ctrl op will not be injected if num_epochs is 1.
    iter1 = data1.create_dict_iterator(num_epochs=1)

    # Exactly one epoch is read.
    rows_seen = 0
    for index, row in enumerate(iter1):  # row is a dictionary
        assert np.array_equal(row["data"], np.array([index]))
        rows_seen = index + 1
    assert rows_seen == 64

    with pytest.raises(RuntimeError) as excinfo:
        next(iter1)
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(excinfo.value)
def test_generator_dict_4_2():
    """
    Same as the num_epochs=1 dict test, with a repeat(1) added to the dataset first.

    One epoch of 64 rows is read, then a further fetch must raise RuntimeError with the
    EOF message.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    # repeat will not be injected when num repeat is 1.
    data1 = data1.repeat(1)
    # epoch ctrl op will not be injected if num_epochs is 1.
    iter1 = data1.create_dict_iterator(num_epochs=1)

    # Exactly one epoch is read.
    rows_seen = 0
    for index, row in enumerate(iter1):  # row is a dictionary
        assert np.array_equal(row["data"], np.array([index]))
        rows_seen = index + 1
    assert rows_seen == 64

    with pytest.raises(RuntimeError) as excinfo:
        next(iter1)
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(excinfo.value)
def test_generator_dict_5():
    """
    Read ten epochs and then the eleventh from a dict iterator created with num_epochs=11.

    After the eleventh epoch, a further fetch must raise RuntimeError with the EOF message.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_dict_iterator(num_epochs=11)

    for _ in range(10):
        rows_seen = 0
        for index, row in enumerate(iter1):  # row is a dictionary
            assert np.array_equal(row["data"], np.array([index]))
            rows_seen = index + 1
        assert rows_seen == 64

    # still one more epoch left in the iter1.
    rows_seen = 0
    for index, row in enumerate(iter1):  # row is a dictionary
        assert np.array_equal(row["data"], np.array([index]))
        rows_seen = index + 1
    assert rows_seen == 64

    # now iter1 has been exhausted, c++ pipeline has been shut down.
    with pytest.raises(RuntimeError) as excinfo:
        next(iter1)
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(excinfo.value)
# Test tuple iterator
def test_generator_tuple_0():
    """
    Read one epoch from a tuple iterator created directly in the for statement.

    Every row must carry the next integer of the generator in its first column, and the
    epoch must contain all 64 rows.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    i = 0
    # create the iterator inside the loop declaration
    for item in data1.create_tuple_iterator():  # each row is indexed by column position
        golden = np.array([i])
        assert np.array_equal(item[0], golden)
        i = i + 1

    # Without this check the test would pass even if the iterator produced no rows.
    assert i == 64
def test_generator_tuple_1():
    """
    Read ten epochs, building a brand-new tuple iterator for each one.

    This test deliberately uses the discouraged pattern: the iterator should normally be
    created once, outside the epoch loop.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    for _ in range(10):
        rows_seen = 0
        # BAD: a fresh iterator is created for every epoch.
        for index, row in enumerate(data1.create_tuple_iterator()):  # row is indexed by column position
            assert np.array_equal(row[0], np.array([index]))
            rows_seen = index + 1
        assert rows_seen == 64
def test_generator_tuple_2():
    """
    Reuse one tuple iterator (no num_epochs given) for ten epochs, then fetch once more.

    The extra fetch must still return a row, showing the iterator is still running.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_tuple_iterator()

    for _ in range(10):
        rows_seen = 0
        for index, row in enumerate(iter1):  # row is indexed by column position
            assert np.array_equal(row[0], np.array([index]))
            rows_seen = index + 1
        assert rows_seen == 64

    # iter1 is still alive and running.
    extra_row = next(iter1)
    assert extra_row
    # iter1 is left for the garbage collector to destroy.
def test_generator_tuple_3():
    """
    Reuse one tuple iterator for ten epochs, stop it explicitly, then fetch again.

    Fetching from the stopped iterator must raise AttributeError.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_tuple_iterator()

    for _ in range(10):
        rows_seen = 0
        for index, row in enumerate(iter1):  # row is indexed by column position
            assert np.array_equal(row[0], np.array([index]))
            rows_seen = index + 1
        assert rows_seen == 64

    # Explicit stop is optional.
    iter1.stop()
    # The stopped iterator no longer has its pipeline attribute.
    with pytest.raises(AttributeError) as excinfo:
        next(iter1)
    assert "object has no attribute 'depipeline'" in str(excinfo.value)
def test_generator_tuple_4():
    """
    Consume exactly num_epochs=10 epochs from a tuple iterator, then fetch once more.

    The additional fetch must raise RuntimeError with the EOF message.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_tuple_iterator(num_epochs=10)

    for _ in range(10):
        rows_seen = 0
        for index, row in enumerate(iter1):  # row is indexed by column position
            assert np.array_equal(row[0], np.array([index]))
            rows_seen = index + 1
        assert rows_seen == 64

    # All ten epochs have been read; another fetch goes past the limit.
    with pytest.raises(RuntimeError) as excinfo:
        next(iter1)
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(excinfo.value)
def test_generator_tuple_5():
    """
    Read ten epochs and then the eleventh from a tuple iterator created with num_epochs=11.

    After the eleventh epoch, a further fetch must raise RuntimeError with the EOF message.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_tuple_iterator(num_epochs=11)

    for _ in range(10):
        rows_seen = 0
        for index, row in enumerate(iter1):  # row is indexed by column position
            assert np.array_equal(row[0], np.array([index]))
            rows_seen = index + 1
        assert rows_seen == 64

    # still one more epoch left in the iter1.
    rows_seen = 0
    for index, row in enumerate(iter1):  # row is indexed by column position
        assert np.array_equal(row[0], np.array([index]))
        rows_seen = index + 1
    assert rows_seen == 64

    # now iter1 has been exhausted, c++ pipeline has been shut down.
    with pytest.raises(RuntimeError) as excinfo:
        next(iter1)
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(excinfo.value)
# Test with repeat
def test_generator_tuple_repeat_1():
    """
    Read eleven epochs of a repeat(2) dataset from a tuple iterator with num_epochs=11.

    Each epoch must contain 64 * 2 rows whose values cycle through 0..63. A fetch after
    the eleventh epoch must raise RuntimeError with the EOF message.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    iter1 = data1.create_tuple_iterator(num_epochs=11)

    for _ in range(10):
        rows_seen = 0
        for index, row in enumerate(iter1):  # row is indexed by column position
            assert np.array_equal(row[0], np.array([index % 64]))
            rows_seen = index + 1
        assert rows_seen == 64 * 2

    # still one more epoch left in the iter1.
    rows_seen = 0
    for index, row in enumerate(iter1):  # row is indexed by column position
        assert np.array_equal(row[0], np.array([index % 64]))
        rows_seen = index + 1
    assert rows_seen == 64 * 2

    # now iter1 has been exhausted, c++ pipeline has been shut down.
    with pytest.raises(RuntimeError) as excinfo:
        next(iter1)
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(excinfo.value)
# Test with repeat
def test_generator_tuple_repeat_repeat_1():
    """
    Read eleven epochs of a repeat(2).repeat(3) dataset from a tuple iterator with num_epochs=11.

    Each epoch must contain 64 * 2 * 3 rows whose values cycle through 0..63. A fetch
    after the eleventh epoch must raise RuntimeError with the EOF message.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    data1 = data1.repeat(3)
    iter1 = data1.create_tuple_iterator(num_epochs=11)

    for _ in range(10):
        rows_seen = 0
        for index, row in enumerate(iter1):  # row is indexed by column position
            assert np.array_equal(row[0], np.array([index % 64]))
            rows_seen = index + 1
        assert rows_seen == 64 * 2 * 3

    # still one more epoch left in the iter1.
    rows_seen = 0
    for index, row in enumerate(iter1):  # row is indexed by column position
        assert np.array_equal(row[0], np.array([index % 64]))
        rows_seen = index + 1
    assert rows_seen == 64 * 2 * 3

    # now iter1 has been exhausted, c++ pipeline has been shut down.
    with pytest.raises(RuntimeError) as excinfo:
        next(iter1)
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(excinfo.value)
def test_generator_tuple_repeat_repeat_2():
    """
    Read ten epochs of a repeat(2).repeat(3) dataset with an unlimited tuple iterator, then stop it.

    Each epoch must contain 64 * 2 * 3 rows. Fetching from the stopped iterator must
    raise AttributeError.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    data1 = data1.repeat(3)
    iter1 = data1.create_tuple_iterator()

    for _ in range(10):
        rows_seen = 0
        for index, row in enumerate(iter1):  # row is indexed by column position
            assert np.array_equal(row[0], np.array([index % 64]))
            rows_seen = index + 1
        assert rows_seen == 64 * 2 * 3

    # Explicit stop is optional.
    iter1.stop()
    # The stopped iterator no longer has its pipeline attribute.
    with pytest.raises(AttributeError) as excinfo:
        next(iter1)
    assert "object has no attribute 'depipeline'" in str(excinfo.value)
def test_generator_tuple_repeat_repeat_3():
    """
    Read ten and then five more epochs of a repeat(2).repeat(3) dataset from one tuple iterator.

    The iterator is created without num_epochs; every epoch must contain 64 * 2 * 3 rows.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    data1 = data1.repeat(3)
    iter1 = data1.create_tuple_iterator()

    # First run of epochs.
    for _ in range(10):
        rows_seen = 0
        for index, row in enumerate(iter1):  # row is indexed by column position
            assert np.array_equal(row[0], np.array([index % 64]))
            rows_seen = index + 1
        assert rows_seen == 64 * 2 * 3

    # Second run of epochs on the same iterator.
    for _ in range(5):
        rows_seen = 0
        for index, row in enumerate(iter1):  # row is indexed by column position
            assert np.array_equal(row[0], np.array([index % 64]))
            rows_seen = index + 1
        assert rows_seen == 64 * 2 * 3

    # iter1 is left for the garbage collector to destroy.
def test_generator_reusedataset():
    """
    Extend the same dataset object in three stages and iterate each stage for several epochs.

    Stage 1: repeat(2), tuple iterator, 10 epochs of 64 * 2 rows.
    Stage 2: an additional repeat(3), tuple iterator, 5 epochs of 64 * 2 * 3 rows.
    Stage 3: an additional batch(2), dict iterator, 5 epochs of 64 * 3 batches.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    iter1 = data1.create_tuple_iterator()
    for _ in range(10):
        rows_seen = 0
        for index, row in enumerate(iter1):  # row is indexed by column position
            assert np.array_equal(row[0], np.array([index % 64]))
            rows_seen = index + 1
        assert rows_seen == 64 * 2

    data1 = data1.repeat(3)
    iter1 = data1.create_tuple_iterator()
    for _ in range(5):
        rows_seen = 0
        for index, row in enumerate(iter1):  # row is indexed by column position
            assert np.array_equal(row[0], np.array([index % 64]))
            rows_seen = index + 1
        assert rows_seen == 64 * 2 * 3

    data1 = data1.batch(2)
    iter1 = data1.create_dict_iterator()
    for _ in range(5):
        batches_seen = 0
        for batch_index, batch in enumerate(iter1):  # batch is a dictionary
            # Each batch holds two consecutive source values.
            first = 2 * batch_index
            expected = np.array([[first % 64], [(first + 1) % 64]])
            assert np.array_equal(batch["data"], expected)
            batches_seen = batch_index + 1
        assert batches_seen == 64 * 3

    # iter1 is left for the garbage collector to destroy.
| @@ -87,7 +87,7 @@ def test_five_crop_error_msg(): | |||
| data = data.map(input_columns=["image"], operations=transform()) | |||
| with pytest.raises(RuntimeError) as info: | |||
| data.create_tuple_iterator().get_next() | |||
| data.create_tuple_iterator().__next__() | |||
| error_msg = "TypeError: img should be PIL Image or Numpy array. Got <class 'tuple'>" | |||
| # error msg comes from ToTensor() | |||
| @@ -41,18 +41,18 @@ def test_case1(): | |||
| assert data.get_batch_size() == 2 | |||
| assert data.get_repeat_count() == 1 | |||
| data = data.repeat(10) | |||
| assert data.get_dataset_size() == 6 | |||
| assert data.get_dataset_size() == 60 | |||
| assert data.get_batch_size() == 2 | |||
| assert data.get_repeat_count() == 10 | |||
| data = data.project(["new_column"]) | |||
| assert data.get_dataset_size() == 6 | |||
| assert data.get_dataset_size() == 60 | |||
| assert data.get_batch_size() == 2 | |||
| assert data.get_repeat_count() == 10 | |||
| data2 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(10) | |||
| data1 = data.zip(data2) | |||
| assert data1.get_dataset_size() == 6 | |||
| assert data1.get_dataset_size() == 60 | |||
| def test_case2(): | |||
| @@ -65,14 +65,14 @@ def test_case2(): | |||
| data = data.rename("col_sint64", "new_column") | |||
| assert data.get_dataset_size() == 3 | |||
| data = data.repeat(10) | |||
| assert data.get_dataset_size() == 3 | |||
| assert data.get_dataset_size() == 30 | |||
| data = data.project(["new_column"]) | |||
| assert data.get_dataset_size() == 3 | |||
| assert data.get_dataset_size() == 30 | |||
| data2 = ds.TFRecordDataset(FILES, num_samples=6).batch(2).repeat(10) | |||
| data1 = data.zip(data2) | |||
| assert data1.get_dataset_size() == 3 | |||
| assert data1.get_dataset_size() == 30 | |||
| def test_case3(): | |||
| @@ -94,11 +94,11 @@ def test_case4(): | |||
| data2 = data2.shuffle(100) | |||
| assert data2.get_dataset_size() == 6 | |||
| data2 = data2.repeat(3) | |||
| assert data2.get_dataset_size() == 6 | |||
| assert data2.get_dataset_size() == 18 | |||
| data3 = ds.zip((data1, data2)) | |||
| assert data3.get_dataset_size() == 6 | |||
| assert data3.get_dataset_size() == 18 | |||
| def test_case5(): | |||
| @@ -73,7 +73,7 @@ def test_iterator_weak_ref(): | |||
| _cleanup() | |||
| with pytest.raises(AttributeError) as info: | |||
| itr2.get_next() | |||
| itr2.__next__() | |||
| assert "object has no attribute 'depipeline'" in str(info.value) | |||
| del itr1 | |||
| @@ -251,6 +251,49 @@ def test_nested_repeat11(): | |||
| assert sum([1 for _ in data]) == 2 * 3 * 4 * 5 * 3 | |||
def test_repeat_count1():
    """
    Check get_dataset_size() against the iterated row count when repeat comes before batch.

    The source dataset must report 3 rows; after repeat(4) and batch(2, drop_remainder=False)
    both the reported size and the number of iterated rows must be 6.
    """
    data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is {}".format(data1_size))

    batch_size = 2
    repeat_count = 4
    resize_height, resize_width = 32, 32
    decode_op = vision.Decode()
    resize_op = vision.Resize((resize_height, resize_width), interpolation=ds.transforms.vision.Inter.LINEAR)

    data1 = data1.map(input_columns=["image"], operations=decode_op)
    data1 = data1.map(input_columns=["image"], operations=resize_op)
    # Repeat first, then batch.
    data1 = data1.repeat(repeat_count)
    data1 = data1.batch(batch_size, drop_remainder=False)

    dataset_size = data1.get_dataset_size()
    logger.info("dataset repeat then batch's size is {}".format(dataset_size))

    num1_iter = sum(1 for _ in data1.create_dict_iterator())

    assert data1_size == 3
    assert dataset_size == num1_iter == 6
def test_repeat_count2():
    """
    Check get_dataset_size() against the iterated row count when batch comes before repeat.

    The source dataset must report 3 rows; after batch(2, drop_remainder=False) and repeat(4)
    both the reported size and the number of iterated rows must be 8.
    """
    data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is {}".format(data1_size))

    batch_size = 2
    repeat_count = 4
    resize_height, resize_width = 32, 32
    decode_op = vision.Decode()
    resize_op = vision.Resize((resize_height, resize_width), interpolation=ds.transforms.vision.Inter.LINEAR)

    data1 = data1.map(input_columns=["image"], operations=decode_op)
    data1 = data1.map(input_columns=["image"], operations=resize_op)
    # Batch first, then repeat.
    data1 = data1.batch(batch_size, drop_remainder=False)
    data1 = data1.repeat(repeat_count)

    dataset_size = data1.get_dataset_size()
    logger.info("dataset batch then repeat's size is {}".format(dataset_size))

    num1_iter = sum(1 for _ in data1.create_dict_iterator())

    assert data1_size == 3
    assert dataset_size == num1_iter == 8
| if __name__ == "__main__": | |||
| test_tf_repeat_01() | |||
| @@ -268,3 +311,5 @@ if __name__ == "__main__": | |||
| test_nested_repeat9() | |||
| test_nested_repeat10() | |||
| test_nested_repeat11() | |||
| test_repeat_count1() | |||
| test_repeat_count2() | |||
| @@ -252,14 +252,14 @@ def test_zip_exception_06(): | |||
if __name__ == '__main__':
    # Run the complete zip test suite when this file is executed as a script.
    # Every test stays enabled here: commenting calls out for local debugging
    # silently shrinks script-mode coverage to a single case.
    test_zip_01()
    test_zip_02()
    test_zip_03()
    test_zip_04()
    test_zip_05()
    test_zip_06()
    test_zip_exception_01()
    test_zip_exception_02()
    test_zip_exception_03()
    test_zip_exception_04()
    test_zip_exception_05()
    test_zip_exception_06()
| @@ -274,6 +274,9 @@ class DatasetLenet(): | |||
| # Mock accessor: always reports a repeat count of 1. | |||
| def get_repeat_count(self): | |||
| return 1 | |||
| # Mock accessor: hands back this object itself as the tuple iterator. | |||
| def create_tuple_iterator(self): | |||
| return self | |||
| def test_train_32k_8p(batch_size=32, num_classes=32768): | |||
| dev_num = 8 | |||
| @@ -61,6 +61,9 @@ class DatasetLenet(): | |||
| # Mock accessor: always reports a repeat count of 1. | |||
| def get_repeat_count(self): | |||
| return 1 | |||
| # Mock accessor: hands back this object itself as the tuple iterator. | |||
| def create_tuple_iterator(self): | |||
| return self | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| @@ -58,6 +58,9 @@ class Dataset(): | |||
| # Mock accessor: always reports a repeat count of 1. | |||
| def get_repeat_count(self): | |||
| return 1 | |||
| # Mock accessor: hands back this object itself as the tuple iterator. | |||
| def create_tuple_iterator(self): | |||
| return self | |||
| class GatherV2(_Loss): | |||
| def __init__(self, index_dim, strategy, index_size=16): | |||
| @@ -0,0 +1,107 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test dataset helper.""" | |||
| import pytest | |||
| import numpy as np | |||
| import mindspore.context as context | |||
| from mindspore.communication.management import init | |||
| from mindspore.train.dataset_helper import DatasetHelper | |||
| from ....dataset_mock import MindData | |||
def get_dataset(batch_size=1):
    """
    Create the MindData mock shared by the tests in this file.

    The mock is constructed with size=2 and seven int32 columns whose shapes
    are (batch_size, 128) x3, (batch_size, 1) and (batch_size, 20) x3.
    """
    column_widths = (128, 128, 128, 1, 20, 20, 20)
    shapes = tuple((batch_size, width) for width in column_widths)
    types = (np.int32,) * len(column_widths)
    return MindData(size=2, batch_size=batch_size, np_types=types,
                    output_shapes=shapes, input_indexs=(0, 1))
def test_dataset_helper_dataset_sink_mode_str():
    """DatasetHelper is expected to raise TypeError when dataset_sink_mode is a string."""
    mock_data = get_dataset(32)
    with pytest.raises(TypeError):
        DatasetHelper(mock_data, dataset_sink_mode="True")
def test_dataset_helper_dataset_sink_mode_int():
    """DatasetHelper is expected to raise TypeError when dataset_sink_mode is an int."""
    mock_data = get_dataset(32)
    with pytest.raises(TypeError):
        DatasetHelper(mock_data, dataset_sink_mode=1)
def test_dataset_helper_sink_size_bool():
    """DatasetHelper is expected to raise TypeError when sink_size is a bool."""
    mock_data = get_dataset(32)
    with pytest.raises(TypeError):
        DatasetHelper(mock_data, dataset_sink_mode=True, sink_size=True)
def test_dataset_helper_sink_size_float():
    """DatasetHelper is expected to raise TypeError when sink_size is a float."""
    mock_data = get_dataset(32)
    with pytest.raises(TypeError):
        DatasetHelper(mock_data, dataset_sink_mode=True, sink_size=1.0)
def test_dataset_helper_sink_size_negative():
    """DatasetHelper is expected to raise ValueError when sink_size is -2."""
    mock_data = get_dataset(32)
    with pytest.raises(ValueError):
        DatasetHelper(mock_data, dataset_sink_mode=True, sink_size=-2)
| # Non-sink mode: iterates the DatasetHelper inside a two-pass outer loop, | |||
| # counts every yielded item and expects a total of 6. | |||
| # NOTE(review): indentation is lost in this listing; dataset.reset() presumably | |||
| # runs once per outer pass (after the inner loop) - confirm against the real file. | |||
| def test_dataset_iter_normal(): | |||
| dataset = get_dataset(32) | |||
| dataset_helper = DatasetHelper(dataset, dataset_sink_mode=False) | |||
| count = 0 | |||
| for _ in range(2): | |||
| for _ in dataset_helper: | |||
| count += 1 | |||
| dataset.reset() | |||
| assert count == 6 | |||
| # Sink mode with sink_size=10; skipped when context.get_context("enable_ge") is falsy. | |||
| # Calls init() first, then counts items yielded by the helper over two passes and expects 2. | |||
| # NOTE(review): indentation is lost in this listing; the final assert presumably | |||
| # sits after both loops - confirm against the real file. | |||
| @pytest.mark.skipif('not context.get_context("enable_ge")') | |||
| def test_dataset_iter_ge(): | |||
| init() | |||
| dataset = get_dataset(32) | |||
| dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10) | |||
| count = 0 | |||
| for _ in range(2): | |||
| for _ in dataset_helper: | |||
| count += 1 | |||
| assert count == 2 | |||
| # Sink mode with enable_loop_sink=True; skipped when context.get_context("enable_ge") is truthy. | |||
| # Each item yielded by the helper is compared with an empty tuple, and 2 items | |||
| # are expected in total over two passes. | |||
| # NOTE(review): indentation is lost in this listing; the `inputs == tuple()` check | |||
| # presumably runs inside the inner loop and the count check after both loops - confirm. | |||
| @pytest.mark.skipif('context.get_context("enable_ge")') | |||
| def test_dataset_iter_ms_loop_sink(): | |||
| init() | |||
| context.set_context(enable_loop_sink=True) | |||
| dataset = get_dataset(32) | |||
| dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10) | |||
| count = 0 | |||
| for _ in range(2): | |||
| for inputs in dataset_helper: | |||
| count += 1 | |||
| assert inputs == tuple() | |||
| assert count == 2 | |||
| @pytest.mark.skipif('context.get_context("enable_ge")') | |||
| def test_dataset_iter_ms(): | |||
| init() | |||
| context.set_context(enable_loop_sink=False) | |||
| dataset = get_dataset(32) | |||
| DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10) | |||