diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h index bf754e352e..7f7dabc925 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -258,10 +258,6 @@ class DatasetOp : public std::enable_shared_from_this { /// \return The number of required repeats for the operator int32_t op_total_repeats() { return op_total_repeats_; } - /// \brief Getter function - /// \return The number of required epochs for the operator - int32_t op_total_epochs() { return op_total_repeats_ / op_num_repeats_per_epoch_; } - /// \brief Getter function /// \return The number of repeats per epoch for the operator int32_t op_num_repeats_per_epoch() const { return op_num_repeats_per_epoch_; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc index 4a59ecfdff..5efa38b683 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,10 +17,8 @@ #include #include -#include "minddata/dataset/engine/execution_tree.h" #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" #include "minddata/dataset/engine/data_buffer.h" -#include "minddata/dataset/engine/db_connector.h" #include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/util/log_adapter.h" diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h index 891a0ca96f..af52d91a29 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -304,6 +304,10 @@ class CsvOp : public ParallelOp { // @return - Status of the node visit. Status Accept(NodePass *p, bool *const modified) override; + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "CsvOp"; } + private: // The entry point for when workers are launched. // @param worker_id - the id of the worker that is executing this function. diff --git a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc index 27543bc6c2..83ac5a495f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,20 +16,9 @@ #include "minddata/dataset/engine/execution_tree.h" #include #include -#include #include #include "minddata/dataset/engine/datasetops/dataset_op.h" -#include "minddata/dataset/engine/datasetops/shuffle_op.h" #include "minddata/dataset/engine/datasetops/device_queue_op.h" -#include "minddata/dataset/engine/opt/pass.h" -#include "minddata/dataset/engine/opt/pre/removal_pass.h" -#ifndef ENABLE_ANDROID -#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" -#include "minddata/dataset/engine/opt/post/repeat_pass.h" -#include "minddata/dataset/engine/opt/pre/cache_error_pass.h" -#include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" -#endif -#include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h" #include "minddata/dataset/engine/perf/profiling.h" #include "minddata/dataset/engine/perf/monitor.h" #if defined(ENABLE_GPUQUE) || defined(ENABLE_TDTQUE) @@ -255,97 +244,13 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::functionPreAction()); - - // Post optimization compulsory transformation - RETURN_IF_NOT_OK(this->PostAction()); - - // The tree is ready to be prepared. - tree_state_ = kDeTStatePrepare; - - // Existing transformation implementation, will be removed later - RETURN_IF_NOT_OK(this->PrepareDeprecated()); - return Status::OK(); -} - -Status ExecutionTree::PreAction() { - bool modified = false; - std::vector> pre_actions; - // Construct pre actions - if (!partially_prepare_) { -#ifndef ENABLE_ANDROID - pre_actions.push_back(std::make_unique()); -#endif - pre_actions.push_back(std::make_unique()); - pre_actions.push_back(std::make_unique()); - } - - MS_LOG(INFO) << "Running " << pre_actions.size() << " pre pass loops."; - - // Apply pre action passes - for (auto &pass : pre_actions) { - RETURN_IF_NOT_OK(pass->Run(this, &modified)); - } - MS_LOG(INFO) << "Pre passes complete."; - return Status::OK(); -} - -Status ExecutionTree::PostAction() { - bool modified = false; - OptPass post_actions; - // Construct pre actions - MS_LOG(INFO) << "Running post pass loops."; -#ifndef ENABLE_ANDROID - // Calling CacheErrorPass again. This is a temporary fix until the TensorOperation is properly done in Pybind. - // The IR version cannot detect an invalid case of a cache on Map with random tensor operation from Python API. - // This is because Python API binding to TensorOperation is still in progress. - post_actions.push_back(std::make_unique()); - post_actions.push_back(std::make_unique()); -#endif - - // Apply post action passes - for (auto &pass : post_actions) { - RETURN_IF_NOT_OK(pass->Run(this, &modified)); - } - MS_LOG(INFO) << "Post passes complete."; - - return Status::OK(); -} - -// The driver of the prepare phase of the execution tree. The prepare phase will recursively // walk the tree to perform modifications to the tree or specific nodes within the tree to get // it ready for execution. // // This driver is deprecated. -Status ExecutionTree::PrepareDeprecated() { - // Tree must be in pending prepare state before we can assign root to it - if (tree_state_ != kDeTStatePrepare) { - std::string err_msg = - "Invalid tree state for preparing the tree. Current state: " + std::to_string(static_cast(tree_state_)) + - " Expected state: " + std::to_string(static_cast(kDeTStatePrepare)); - RETURN_STATUS_UNEXPECTED(err_msg); - } +Status ExecutionTree::Prepare() { + // The tree is ready to be prepared. + tree_state_ = kDeTStatePrepare; if (root_ == nullptr) { RETURN_STATUS_UNEXPECTED("Please assign one operator as the root of this tree."); diff --git a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h index 87a6bd069f..059a53aadd 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h +++ b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -169,24 +169,6 @@ class ExecutionTree { // @return the prepare flags uint32_t PrepareFlags() const { return prepare_flags_; } - // The driver of the prepare phase of the execution tree. - // Prepare phase consists of three sub phases - // - // 1. PreAction() - // Compulsory transformation/action pre optimization. - // For example, CacheOp Insertion - // - // 2. Optimize() - // Optimization transformation/action, optional - // For example, MapOp Fusion - // - // 3. PostAction() - // Compulsory transformation/action post optimization. - // For example, repeatOp inlining - // - // @return Status The status code returned - Status Prepare(int num_epochs = -1, bool partial = false); - // Compulsory transformation/action pre optimization. // @return Status The status code returned Status PreAction(); @@ -200,7 +182,7 @@ class ExecutionTree { // it ready for execution. // @param Total number of epochs that will be run on this tree // @return Status The status code returned - Status PrepareDeprecated(); + Status Prepare(); // Recursive function used during prepare phase to visit a node and drive any pre- and post- // node actions during a tree walk. @@ -239,10 +221,6 @@ class ExecutionTree { // Getter for profiling manager, no ownership ProfilingManager *GetProfilingManager() { return profiling_manager_.get(); } - // Getter function to get the total number of epochs to be run on this tree. - // @return total number of epochs - int32_t num_epochs() { return num_epochs_; } - private: // A helper functions for doing the recursive printing // @param dataset_op - The dataset op to print @@ -257,9 +235,7 @@ class ExecutionTree { int32_t id_count_; // Counter for generating operator id's uint32_t prepare_flags_; // Flags used during tree prepare TreeState tree_state_; // Tracking the current tree state - int32_t num_epochs_; // Total number of epochs to run for this tree std::unique_ptr profiling_manager_; // Profiling manager - bool partially_prepare_; // Temp: during migration to IR, if true, run remaining passes. #if defined(ENABLE_GPUQUE) || defined(ENABLE_TDTQUE) // This rank_id is for numa and device_queue, one process work with only one rank_id, // for standalone scenario, this rank_id may come from env 'CUDA_VISIBLE_DEVICES', diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc index 9d254de7d2..9856b6332c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -102,9 +102,11 @@ Status BatchNode::Build(std::vector> *const node_ops) node_ops->push_back(project_op); } - node_ops->push_back(std::make_shared(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_, - in_col_names_, out_col_names_, batch_size_func_, batch_map_func_, - pad_map_)); + auto op = std::make_shared(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_, + in_col_names_, out_col_names_, batch_size_func_, batch_map_func_, pad_map_); + op->set_total_repeats(GetTotalRepeats()); + op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(op); #else node_ops->push_back(std::make_shared(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_, in_col_names_, pad_map_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc index c72faf666c..99e0a16708 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -84,9 +84,12 @@ void BucketBatchByLengthNode::Print(std::ostream &out) const { Status BucketBatchByLengthNode::Build(std::vector> *const node_ops) { bucket_boundaries_.insert(bucket_boundaries_.begin(), 0); - node_ops->push_back(std::make_shared( - column_names_, bucket_boundaries_, bucket_batch_sizes_, element_length_function_, pad_info_, - pad_to_bucket_boundary_, drop_remainder_, connector_que_size_)); + auto op = std::make_shared(column_names_, bucket_boundaries_, bucket_batch_sizes_, + element_length_function_, pad_info_, pad_to_bucket_boundary_, + drop_remainder_, connector_que_size_); + op->set_total_repeats(GetTotalRepeats()); + op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(op); if (bucket_boundaries_[0] == 0) { bucket_boundaries_.erase(bucket_boundaries_.begin()); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc index 86b1f91a69..c7f4b11a64 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -55,10 +55,11 @@ void BuildSentenceVocabNode::Print(std::ostream &out) const { // Function to build BuildSentenceVocabNode Status BuildSentenceVocabNode::Build(std::vector> *const node_ops) { - std::shared_ptr build_sentence_piece_vocab_op; - build_sentence_piece_vocab_op = std::make_shared( - vocab_, col_names_, vocab_size_, character_coverage_, model_type_, params_, connector_que_size_); - node_ops->push_back(build_sentence_piece_vocab_op); + auto op = std::make_shared(vocab_, col_names_, vocab_size_, character_coverage_, + model_type_, params_, connector_que_size_); + op->set_total_repeats(GetTotalRepeats()); + op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(op); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc index 7886a0b1b4..0e743f015e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -54,6 +54,8 @@ Status BuildVocabNode::Build(std::vector> *const node std::shared_ptr build_vocab_op; build_vocab_op = std::make_shared(vocab_, columns_, freq_range_, top_k_, special_tokens_, special_first_, num_workers_, connector_que_size_); + build_vocab_op->set_total_repeats(GetTotalRepeats()); + build_vocab_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(build_vocab_op); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.cc index 63217c381d..0497b91c1d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.cc @@ -51,10 +51,24 @@ Status CacheLookupNode::Build(std::vector> *node_ops) "Internal error. Attempt to create a cache lookup node without cache client."); RETURN_IF_NOT_OK(cache_->Build()); RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, &lookup_op_, sampler_)); + lookup_op_->set_total_repeats(GetTotalRepeats()); + lookup_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(lookup_op_); return Status::OK(); } +// Visitor accepting method for IRNodePass +Status CacheLookupNode::Accept(IRNodePass *const p, bool *const modified) { + // Downcast shared pointer then call visitor + return p->Visit(shared_from_base(), modified); +} + +// Visitor accepting method for IRNodePass +Status CacheLookupNode::AcceptAfter(IRNodePass *const p, bool *const modified) { + // Downcast shared pointer then call visitor + return p->VisitAfter(shared_from_base(), modified); +} + std::shared_ptr CacheLookupNode::SamplerCopy() { // CacheLookupNode should already been copied, so we just return it here return std::static_pointer_cast(lookup_node_copy_); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.h index 04a510d34d..eda1e1f560 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.h @@ -64,6 +64,18 @@ class CacheLookupNode : public DatasetNode, public SamplerObj { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Base-class override for accepting IRNodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(IRNodePass *const p, bool *const modified) override; + + /// \brief Base-class override for accepting IRNodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status AcceptAfter(IRNodePass *const p, bool *const modified) override; + private: std::shared_ptr sampler_; std::shared_ptr lookup_op_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_merge_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_merge_node.cc index 66faa99890..2a08a5e2c9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_merge_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_merge_node.cc @@ -48,9 +48,23 @@ Status CacheMergeNode::Build(std::vector> *node_ops) RETURN_IF_NOT_OK(cache_->Build()); std::shared_ptr merge_op = nullptr; RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, &merge_op)); + merge_op->set_total_repeats(GetTotalRepeats()); + merge_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(merge_op); return Status::OK(); } +// Visitor accepting method for IRNodePass +Status CacheMergeNode::Accept(IRNodePass *const p, bool *const modified) { + // Downcast shared pointer then call visitor + return p->Visit(shared_from_base(), modified); +} + +// Visitor accepting method for IRNodePass +Status CacheMergeNode::AcceptAfter(IRNodePass *const p, bool *const modified) { + // Downcast shared pointer then call visitor + return p->VisitAfter(shared_from_base(), modified); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_merge_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_merge_node.h index 0afcbc1922..d49a19e37e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_merge_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_merge_node.h @@ -53,6 +53,18 @@ class CacheMergeNode : public DatasetNode { /// \brief Parameters validation /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + + /// \brief Base-class override for accepting IRNodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(IRNodePass *const p, bool *const modified) override; + + /// \brief Base-class override for accepting IRNodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status AcceptAfter(IRNodePass *const p, bool *const modified) override; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_node.cc index 38edbeb468..9b02ff2651 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_node.cc @@ -53,9 +53,23 @@ Status CacheNode::Build(std::vector> *node_ops) { std::shared_ptr cache_op = nullptr; RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op)); cache_op->SetSampler(sampler_->SamplerBuild()); + cache_op->set_total_repeats(GetTotalRepeats()); + cache_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(cache_op); return Status::OK(); } +// Visitor accepting method for IRNodePass +Status CacheNode::Accept(IRNodePass *const p, bool *const modified) { + // Downcast shared pointer then call visitor + return p->Visit(shared_from_base(), modified); +} + +// Visitor accepting method for IRNodePass +Status CacheNode::AcceptAfter(IRNodePass *const p, bool *const modified) { + // Downcast shared pointer then call visitor + return p->VisitAfter(shared_from_base(), modified); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_node.h index 25d969a23f..1951017752 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_node.h @@ -55,6 +55,18 @@ class CacheNode : public DatasetNode { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Base-class override for accepting IRNodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(IRNodePass *const p, bool *const modified) override; + + /// \brief Base-class override for accepting IRNodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status AcceptAfter(IRNodePass *const p, bool *const modified) override; + private: std::shared_ptr sampler_; }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc index 318a496791..58f6bcdb18 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -119,12 +119,16 @@ Status ConcatNode::GetDatasetSize(const std::shared_ptr &size } Status ConcatNode::Build(std::vector> *const node_ops) { + std::shared_ptr op; if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) { - node_ops->push_back(std::make_shared(connector_que_size_)); + op = std::make_shared(connector_que_size_); } else { - node_ops->push_back(std::make_shared(connector_que_size_, sampler_->SamplerBuild(), - children_flag_and_nums_, children_start_end_index_)); + op = std::make_shared(connector_que_size_, sampler_->SamplerBuild(), children_flag_and_nums_, + children_start_end_index_); } + op->set_total_repeats(GetTotalRepeats()); + op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(op); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc index 174a995859..52086c0077 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -219,7 +219,9 @@ DatasetNode::DatasetNode() dataset_size_(-1), mappable_(kNotADataSource), nary_op_(false), - descendant_of_cache_(false) { + descendant_of_cache_(false), + total_repeats_(-1), + num_epochs_(1) { // Fetch some default value from config manager std::shared_ptr cfg = GlobalContext::config_manager(); num_workers_ = cfg->num_parallel_workers(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h index b96f89523a..3f9e47a8f5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ #include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/engine/consumers/tree_consumer.h" #include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/dataset_op.h" #include "minddata/dataset/engine/datasetops/filter_op.h" #include "minddata/dataset/engine/datasetops/map_op/map_op.h" #include "minddata/dataset/engine/datasetops/project_op.h" @@ -292,6 +293,24 @@ class DatasetNode : public std::enable_shared_from_this { /// \return Status of the function virtual Status to_json(nlohmann::json *out_json); + /// \brief Setter function, set the number of total repeats for the operator + void SetTotalRepeats(int32_t total_repeats) { total_repeats_ = total_repeats; } + + /// \brief Setter function, set the number of epochs for the operator + void SetNumEpochs(int32_t num_epochs) { num_epochs_ = num_epochs; } + + /// \brief Getter function + /// \return The number of required repeats for the operator + int32_t GetTotalRepeats() const { return total_repeats_; } + + /// \brief Getter function + /// \return The number of epochs for the operator + int32_t GetNumEpochs() const { return num_epochs_; } + + /// \brief Getter function + /// \return The number of repeats per epoch for the operator + int32_t GetNumRepeatsPerEpoch() const { return total_repeats_ / num_epochs_; } + protected: std::vector> children_; DatasetNode *parent_; // used to record the only one parent of an IR node after parsing phase @@ -301,6 +320,8 @@ class DatasetNode : public std::enable_shared_from_this { int32_t rows_per_buffer_; int32_t connector_que_size_; int32_t worker_connector_size_; + int32_t total_repeats_; // Number of times required to run this operator + int32_t num_epochs_; // Number of epochs // Establish a parent-child relationship between this node and the input node. // Used only in the constructor of the class and its derived classes. void AddChild(std::shared_ptr child); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.cc index 08622327cc..5e29313049 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,6 +44,8 @@ void EpochCtrlNode::Print(std::ostream &out) const { out << Name() + "(epoch:" + // Function to build the EpochCtrlOp Status EpochCtrlNode::Build(std::vector> *const node_ops) { auto new_op_ = std::make_shared(repeat_count_); + new_op_->set_total_repeats(GetTotalRepeats()); + new_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(new_op_); op_ = new_op_; return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc index 24be067e73..2be8738d6b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,7 +44,10 @@ void FilterNode::Print(std::ostream &out) const { } Status FilterNode::Build(std::vector> *const node_ops) { - node_ops->push_back(std::make_shared(input_columns_, num_workers_, connector_que_size_, predicate_)); + auto op = std::make_shared(input_columns_, num_workers_, connector_que_size_, predicate_); + op->set_total_repeats(GetTotalRepeats()); + op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(op); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc index d325f9d099..4097753173 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,7 +38,8 @@ MapNode::MapNode(std::shared_ptr child, std::vectorAddChild(child); } @@ -64,6 +65,17 @@ Status MapNode::Build(std::vector> *const node_ops) { operations_.begin(), operations_.end(), std::back_inserter(tensor_ops), [](std::shared_ptr operation) -> std::shared_ptr { return operation->Build(); }); + // This is temporary code. + // Because the randomness of its tensor operations is not known in TensorOperation form until we convert them + // to TensorOp, we need to check the randomness here. + // When TensorOperation captures the randomness behaviour, remove this code and the member "under_a_cache_" + // and the temporary code in CacheValidation pre pass in IR optimizer. + if (under_a_cache_) { + auto itr = std::find_if(tensor_ops.begin(), tensor_ops.end(), [](const auto &it) { return !it->Deterministic(); }); + if (itr != tensor_ops.end()) { + RETURN_STATUS_UNEXPECTED("MapNode containing random operation is not supported as a descendant of cache."); + } + } // This parameter will be removed with next rebase std::vector col_orders; auto map_op = std::make_shared(input_columns_, output_columns_, tensor_ops, num_workers_, connector_que_size_); @@ -74,9 +86,12 @@ Status MapNode::Build(std::vector> *const node_ops) { if (!project_columns_.empty()) { auto project_op = std::make_shared(project_columns_); + project_op->set_total_repeats(GetTotalRepeats()); + project_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(project_op); } - + map_op->set_total_repeats(GetTotalRepeats()); + map_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(map_op); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h index b6628b057a..4589ee4087 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -79,6 +79,9 @@ class MapNode : public DatasetNode { /// \brief setter to set all tensor operations void setOperations(const std::vector> &operations); + /// \brief indicate this Map will be cached + void Cached() { under_a_cache_ = true; } + /// \brief Getter functions /// \brief Getter of tensor operations /// \return Vector of operations the Map node will process @@ -95,12 +98,11 @@ class MapNode : public DatasetNode { private: std::vector> operations_; - - private: std::vector input_columns_; std::vector output_columns_; std::vector project_columns_; std::vector> callbacks_; + bool under_a_cache_; }; } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc index e0fb878579..647999ad5d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -53,7 +53,10 @@ Status ProjectNode::ValidateParams() { } Status ProjectNode::Build(std::vector> *const node_ops) { - node_ops->push_back(std::make_shared(columns_)); + auto op = std::make_shared(columns_); + op->set_total_repeats(GetTotalRepeats()); + op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(op); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc index e9f3ca6ddc..496bbaebeb 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -58,7 +58,10 @@ Status RenameNode::ValidateParams() { } Status RenameNode::Build(std::vector> *const node_ops) { - node_ops->push_back(std::make_shared(input_columns_, output_columns_, connector_que_size_)); + auto op = std::make_shared(input_columns_, output_columns_, connector_que_size_); + op->set_total_repeats(GetTotalRepeats()); + op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(op); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc index 63452cd69d..8c95f26a00 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,6 +40,8 @@ void RepeatNode::Print(std::ostream &out) const { out << Name() + "(count:" + st Status RepeatNode::Build(std::vector> *const node_ops) { auto new_op = std::make_shared(repeat_count_); + new_op->set_total_repeats(GetTotalRepeats()); + new_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(new_op); op_ = new_op; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc index d496b94895..85267c0785 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,8 +44,11 @@ void ShuffleNode::Print(std::ostream &out) const { // Function to build the ShuffleOp Status ShuffleNode::Build(std::vector> *const node_ops) { - node_ops->push_back(std::make_shared(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_, - rows_per_buffer_)); + auto op = std::make_shared(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_, + rows_per_buffer_); + op->set_total_repeats(GetTotalRepeats()); + op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(op); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc index eb29c30391..4ab2751ff4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,7 +39,10 @@ void SkipNode::Print(std::ostream &out) const { out << Name() + "(skip_count:" + // Function to build the SkipOp Status SkipNode::Build(std::vector> *const node_ops) { - node_ops->push_back(std::make_shared(skip_count_, connector_que_size_)); + auto op = std::make_shared(skip_count_, connector_que_size_); + op->set_total_repeats(GetTotalRepeats()); + op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(op); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc index 1e2d20428d..1a9359155a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -72,9 +72,11 @@ Status AlbumNode::Build(std::vector> *const node_ops) // Argument that is not exposed to user in the API. std::set extensions = {}; - node_ops->push_back(std::make_shared(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, - decode_, extensions, std::move(schema), - std::move(sampler_->SamplerBuild()))); + auto album_op = std::make_shared(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, decode_, + extensions, std::move(schema), std::move(sampler_->SamplerBuild())); + album_op->set_total_repeats(GetTotalRepeats()); + album_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(album_op); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc index 386afe33ee..369db7079a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -67,9 +67,12 @@ Status CelebANode::Build(std::vector> *const node_ops // label is like this:0 1 0 0 1...... RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); - node_ops->push_back(std::make_shared(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, - decode_, usage_, extensions_, std::move(schema), - std::move(sampler_->SamplerBuild()))); + auto celeba_op = + std::make_shared(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, decode_, usage_, + extensions_, std::move(schema), std::move(sampler_->SamplerBuild())); + celeba_op->set_total_repeats(GetTotalRepeats()); + celeba_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(celeba_op); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc index 4e78fe5dcd..b718490e48 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -64,9 +64,12 @@ Status Cifar100Node::Build(std::vector> *const node_o RETURN_IF_NOT_OK( schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); - node_ops->push_back(std::make_shared(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_, - dataset_dir_, connector_que_size_, std::move(schema), - std::move(sampler_->SamplerBuild()))); + auto cifar_op = + std::make_shared(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_, dataset_dir_, + connector_que_size_, std::move(schema), std::move(sampler_->SamplerBuild())); + cifar_op->set_total_repeats(GetTotalRepeats()); + cifar_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(cifar_op); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc index 7d188331c8..73e843fe27 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -62,9 +62,12 @@ Status Cifar10Node::Build(std::vector> *const node_op RETURN_IF_NOT_OK( schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); - node_ops->push_back(std::make_shared(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_, - dataset_dir_, connector_que_size_, std::move(schema), - std::move(sampler_->SamplerBuild()))); + auto cifar_op = + std::make_shared(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_, dataset_dir_, + connector_que_size_, std::move(schema), std::move(sampler_->SamplerBuild())); + cifar_op->set_total_repeats(GetTotalRepeats()); + cifar_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(cifar_op); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc index f37b62e656..a1b10b0f5c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -193,9 +193,12 @@ Status CLUENode::Build(std::vector> *const node_ops) // Add the shuffle op after this op RETURN_IF_NOT_OK(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, rows_per_buffer_, &shuffle_op)); + shuffle_op->set_total_repeats(GetTotalRepeats()); + shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(shuffle_op); } - + clue_op->set_total_repeats(GetTotalRepeats()); + clue_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(clue_op); return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc index 309619737d..81762c6e24 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -122,7 +122,8 @@ Status CocoNode::Build(std::vector> *const node_ops) std::shared_ptr op = std::make_shared(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_, connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild())); - + op->set_total_repeats(GetTotalRepeats()); + op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(op); return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc index b9cc83e2c2..fd6471c455 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -130,10 +130,12 @@ Status CSVNode::Build(std::vector> *const node_ops) { // Add the shuffle op after this op RETURN_IF_NOT_OK(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, rows_per_buffer_, &shuffle_op)); - + shuffle_op->set_total_repeats(GetTotalRepeats()); + shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(shuffle_op); } - + csv_op->set_total_repeats(GetTotalRepeats()); + csv_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(csv_op); return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc index 09b9830f99..4ab4350e2a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -91,7 +91,8 @@ Status GeneratorNode::Build(std::vector> *const node_ if (reset_ancestor_ != nullptr) { reset_ancestor_->op_->AddToEoeList(op); } - + op->set_total_repeats(GetTotalRepeats()); + op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(op); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc index b2c852764f..647225fbec 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -70,9 +70,12 @@ Status ImageFolderNode::Build(std::vector> *const nod RETURN_IF_NOT_OK( schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar))); - node_ops->push_back(std::make_shared(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, - recursive_, decode_, exts_, class_indexing_, std::move(schema), - std::move(sampler_->SamplerBuild()))); + auto op = std::make_shared(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, + recursive_, decode_, exts_, class_indexing_, std::move(schema), + std::move(sampler_->SamplerBuild())); + op->set_total_repeats(GetTotalRepeats()); + op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(op); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc index 16077144b3..6bd1a5ce15 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -94,7 +94,8 @@ Status ManifestNode::Build(std::vector> *const node_o manifest_op = std::make_shared(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_, class_index_, std::move(schema), std::move(sampler_->SamplerBuild()), usage_); - + manifest_op->set_total_repeats(GetTotalRepeats()); + manifest_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(manifest_op); return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc index 804abb0763..0c8ecf7479 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -169,6 +169,8 @@ Status MindDataNode::Build(std::vector> *const node_o } RETURN_IF_NOT_OK(mindrecord_op->Init()); + mindrecord_op->set_total_repeats(GetTotalRepeats()); + mindrecord_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(mindrecord_op); return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc index 5a3efcfba1..0d43fbeb96 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -58,9 +58,11 @@ Status MnistNode::Build(std::vector> *const node_ops) RETURN_IF_NOT_OK( schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); - node_ops->push_back(std::make_shared(usage_, num_workers_, rows_per_buffer_, dataset_dir_, - connector_que_size_, std::move(schema), - std::move(sampler_->SamplerBuild()))); + auto op = std::make_shared(usage_, num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, + std::move(schema), std::move(sampler_->SamplerBuild())); + op->set_total_repeats(GetTotalRepeats()); + op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(op); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc index 4c6cf420c7..b4a312fc8b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -109,7 +109,8 @@ Status RandomNode::Build(std::vector> *const node_ops std::shared_ptr op; op = std::make_shared(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_, std::move(data_schema_)); - + op->set_total_repeats(GetTotalRepeats()); + op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(op); return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc index e836f88c7c..315d220949 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -98,9 +98,12 @@ Status TextFileNode::Build(std::vector> *const node_o // Add the shuffle op after this op RETURN_IF_NOT_OK(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, rows_per_buffer_, &shuffle_op)); + shuffle_op->set_total_repeats(GetTotalRepeats()); + shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(shuffle_op); } - + text_file_op->set_total_repeats(GetTotalRepeats()); + text_file_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); // Add TextFileOp node_ops->push_back(text_file_op); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc index b9fbc77e1c..bd1e3e6176 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -140,9 +140,12 @@ Status TFRecordNode::Build(std::vector> *const node_o // Add the shuffle op after this op RETURN_IF_NOT_OK(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_, rows_per_buffer_, &shuffle_op)); + shuffle_op->set_total_repeats(GetTotalRepeats()); + shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(shuffle_op); } - + tf_reader_op->set_total_repeats(GetTotalRepeats()); + tf_reader_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); // Add TFReaderOp node_ops->push_back(tf_reader_op); return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc index f08d40e493..a8c071d9f6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -113,7 +113,8 @@ Status VOCNode::Build(std::vector> *const node_ops) { voc_op = std::make_shared(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_, connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild())); - + voc_op->set_total_repeats(GetTotalRepeats()); + voc_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(voc_op); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc index fdb7ae45be..6b9d19b0ed 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,7 +47,10 @@ Status SyncWaitNode::Build(std::vector> *const node_o // The reason for this is because having it otherwise can lead to blocking issues // See barrier_op.h for more details int32_t rows_per_buffer = 1; - node_ops->push_back(std::make_shared(rows_per_buffer, connector_que_size_, condition_name_, callback_)); + auto op = std::make_shared(rows_per_buffer, connector_que_size_, condition_name_, callback_); + op->set_total_repeats(GetTotalRepeats()); + op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(op); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc index 94efe21b5d..ebc0809d08 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 20202-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,7 +40,10 @@ void TakeNode::Print(std::ostream &out) const { out << Name() + "(num_rows:" + s // Function to build the TakeOp Status TakeNode::Build(std::vector> *const node_ops) { - node_ops->push_back(std::make_shared(take_count_, connector_que_size_)); + auto op = std::make_shared(take_count_, connector_que_size_); + op->set_total_repeats(GetTotalRepeats()); + op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(op); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc index 79732c516f..c85f555e0d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -100,8 +100,11 @@ Status TransferNode::Build(std::vector> *const node_o device_id_ = 0; RETURN_IF_NOT_OK(this->GetShardId(&device_id_)); - node_ops->push_back(std::make_shared(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_, - total_batch_, create_data_info_queue_)); + auto op = std::make_shared(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_, + total_batch_, create_data_info_queue_); + op->set_total_repeats(GetTotalRepeats()); + op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(op); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc index 4c69f8b747..9aef58823f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -58,7 +58,10 @@ Status ZipNode::ValidateParams() { } Status ZipNode::Build(std::vector> *const node_ops) { - node_ops->push_back(std::make_shared(rows_per_buffer_, connector_que_size_)); + auto op = std::make_shared(rows_per_buffer_, connector_que_size_); + op->set_total_repeats(GetTotalRepeats()); + op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(op); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt index 50df9e2c91..78bf0fc4a3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt @@ -2,29 +2,19 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc" set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) set(DATASET_ENGINE_OPT_SRC_FILES + optional/tensor_op_fusion_pass.cc pass.cc post/auto_worker_pass.cc + post/repeat_pass.cc + pre/cache_transform_pass.cc pre/cache_validation_pass.cc pre/deep_copy_pass.cc + pre/epoch_ctrl_pass.cc pre/getter_pass.cc pre/input_validation_pass.cc - pre/epoch_ctrl_pass.cc pre/node_removal_pass.cc ) -# This set of files is for ExecTree's optimizer. It is being migrated to IR's optimizer. -# When the migration is complete, we will remove these files. -set(DATASET_ENGINE_OPT_SRC_FILES - ${DATASET_ENGINE_OPT_SRC_FILES} - optional/tensor_op_fusion_pass.cc - pre/cache_error_pass.cc - post/repeat_pass.cc - pre/cache_transform_pass.cc - pre/epoch_injection_pass.cc - util/printer_pass.cc - pre/removal_pass.cc - ) - if(ENABLE_PYTHON) set(DATASET_ENGINE_OPT_SRC_FILES ${DATASET_ENGINE_OPT_SRC_FILES} diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc index e0cbbc6930..b5981e6442 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,11 @@ #include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h" #endif #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h" +#ifndef ENABLE_ANDROID +#include "minddata/dataset/engine/ir/datasetops/cache_node.h" +#include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h" +#include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h" +#endif #include "minddata/dataset/engine/ir/datasetops/concat_node.h" #include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h" #include "minddata/dataset/engine/ir/datasetops/filter_node.h" @@ -187,6 +192,26 @@ Status IRNodePass::Visit(std::shared_ptr node, bool *const modified) Status IRNodePass::VisitAfter(std::shared_ptr node, bool *const modified) { return VisitAfter(std::static_pointer_cast(node), modified); } +#ifndef ENABLE_ANDROID +Status IRNodePass::Visit(std::shared_ptr node, bool *const modified) { + return Visit(std::static_pointer_cast(node), modified); +} +Status IRNodePass::VisitAfter(std::shared_ptr node, bool *const modified) { + return VisitAfter(std::static_pointer_cast(node), modified); +} +Status IRNodePass::Visit(std::shared_ptr node, bool *const modified) { + return Visit(std::static_pointer_cast(node), modified); +} +Status IRNodePass::VisitAfter(std::shared_ptr node, bool *const modified) { + return VisitAfter(std::static_pointer_cast(node), modified); +} +Status IRNodePass::Visit(std::shared_ptr node, bool *const modified) { + return Visit(std::static_pointer_cast(node), modified); +} +Status IRNodePass::VisitAfter(std::shared_ptr node, bool *const modified) { + return VisitAfter(std::static_pointer_cast(node), modified); +} +#endif Status IRNodePass::Visit(std::shared_ptr node, bool *const modified) { return Visit(std::static_pointer_cast(node), modified); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h index 2e3de51947..19c9a4989b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,6 +30,11 @@ namespace dataset { class BatchNode; class BucketBatchByLengthNode; class BuildVocabNode; +#ifndef ENABLE_ANDROID +class CacheLookupNode; +class CacheMergeNode; +class CacheNode; +#endif class ConcatNode; class EpochCtrlNode; class FilterNode; @@ -199,6 +204,14 @@ class IRNodePass : public IRPass { virtual Status VisitAfter(std::shared_ptr node, bool *const modified); virtual Status Visit(std::shared_ptr node, bool *const modified); virtual Status VisitAfter(std::shared_ptr node, bool *const modified); +#ifndef ENABLE_ANDROID + virtual Status Visit(std::shared_ptr node, bool *const modified); + virtual Status VisitAfter(std::shared_ptr node, bool *const modified); + virtual Status Visit(std::shared_ptr node, bool *const modified); + virtual Status VisitAfter(std::shared_ptr node, bool *const modified); + virtual Status Visit(std::shared_ptr node, bool *const modified); + virtual Status VisitAfter(std::shared_ptr node, bool *const modified); +#endif virtual Status Visit(std::shared_ptr node, bool *const modified); virtual Status VisitAfter(std::shared_ptr node, bool *const modified); virtual Status Visit(std::shared_ptr node, bool *const modified); diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc index 4a1bfef512..8eb9b5599f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,15 +14,16 @@ * limitations under the License. */ -#include #include "minddata/dataset/engine/opt/post/repeat_pass.h" -#include "minddata/dataset/engine/datasetops/repeat_op.h" -#include "minddata/dataset/engine/datasetops/cache_op.h" -#include "minddata/dataset/engine/datasetops/cache_lookup_op.h" -#include "minddata/dataset/engine/datasetops/cache_merge_op.h" -#include "minddata/dataset/engine/datasetops/device_queue_op.h" -#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" -#include "minddata/dataset/engine/datasetops/source/generator_op.h" + +#include + +#include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h" +#include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h" +#include "minddata/dataset/engine/ir/datasetops/cache_node.h" +#include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h" +#include "minddata/dataset/engine/ir/datasetops/repeat_node.h" +#include "minddata/dataset/engine/ir/datasetops/transfer_node.h" namespace mindspore { namespace dataset { @@ -31,10 +32,10 @@ RepeatPass::RepeatPass() : num_repeats_(1), num_epochs_(1), is_merge_(false), is_cached_(false), cache_lookup_(nullptr) {} // Identifies the subtree below this node as being in a repeated path of the tree. -Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *const modified) { +Status RepeatPass::Visit(std::shared_ptr node, bool *const modified) { // If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_. // Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely. - if (node->num_repeats() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) { + if (node->Count() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) { num_repeats_ = -num_repeats_; } // This RepeatOp and its descendent nodes should be repeated for another num_repeats() times. @@ -49,14 +50,14 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *const modi // num_repeats_ is originally 4, after repeat2(2), num_repeats_ becomes 8 (2*4), // meaning repeat2 and map op should be set to read 8 times (2*4). // Then, after repeat1(3), num_repeats_ becomes 24 (3*2*4), meaning repeat1 and tfreader op should repeat 24 times. - num_repeats_ *= node->num_repeats(); + num_repeats_ *= node->Count(); return Status::OK(); } // Identifies the subtree below this node as being in a repeated path of the tree. -Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *const modified) { +Status RepeatPass::Visit(std::shared_ptr node, bool *const modified) { // Get the total number of epochs from the EpochCtrlOp parameter - num_epochs_ = node->num_repeats(); + num_epochs_ = node->Count(); // Every node below this EpochCtrlOp should be repeated for num_epochs_ times. // For example: tfreader --> epoch ctrl(3) // num_repeats_ is originally 1 (default initialization), after this epoch ctrl(3), num_repeats_ becomes 3 (1*3), @@ -65,115 +66,108 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *const m return Status::OK(); } +#ifndef ENABLE_ANDROID // Identifies the subtree below this node as being in a cache merge path -Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *const modified) { +Status RepeatPass::Visit(std::shared_ptr node, bool *const modified) { // Turn on the flag that we're under a merge op is_merge_ = true; return Status::OK(); } // Identifies the subtree below this node as being cached -Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *const modified) { +Status RepeatPass::Visit(std::shared_ptr node, bool *const modified) { // Turn on the flag that we're under a merge op is_cached_ = true; return Status::OK(); } +#endif // Hooks up any identified eoe nodes under this repeat. -Status RepeatPass::RunOnNode(std::shared_ptr node, bool *const modified) { +Status RepeatPass::VisitAfter(std::shared_ptr node, bool *const modified) { // We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up // and set its total repeats. It is important that the op is removed from the save area, // because the merge op above us may also take action on it later for a different case when // there is no repeat in the merge leg. if (is_merge_ && cache_lookup_) { - cache_lookup_->set_total_repeats(num_repeats_); - cache_lookup_->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); + cache_lookup_->SetTotalRepeats(num_repeats_); + cache_lookup_->SetNumEpochs(num_epochs_); cache_lookup_.reset(); } if (is_cached_) { - AddToCachedOpStack(node); + AddToCachedNodeStack(node); } - node->set_total_repeats(num_repeats_); - node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); + node->SetTotalRepeats(num_repeats_); + node->SetNumEpochs(num_epochs_); // We finish the walk of this RepeatOp's descendent nodes. // The total repeats of nodes above this Repeat(n) have nothing to do with this RepeatOp's parameter n. // But num_repeats_ has been multiplied by n during this Repeat(n)'s PreRunOnNode, // so we divide num_repeats_ by n to be able to correctly set total repeats for nodes above this RepeatOp. - num_repeats_ /= node->num_repeats(); + num_repeats_ /= node->Count(); return Status::OK(); } // Hooks up any identified eoe nodes under this repeat. -Status RepeatPass::RunOnNode(std::shared_ptr node, bool *const modified) { - node->set_total_repeats(num_repeats_); - node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); +Status RepeatPass::VisitAfter(std::shared_ptr node, bool *const modified) { + node->SetTotalRepeats(num_repeats_); + node->SetNumEpochs(num_epochs_); // We finish the walk of this EpochCtrl's descendent nodes. - num_repeats_ /= node->num_repeats(); + num_repeats_ /= node->Count(); + return Status::OK(); +} + +// All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up +// for use with a controlling repeat above it. +Status RepeatPass::VisitAfter(std::shared_ptr node, bool *const modified) { + // If we are under a cache op, then save ourselves to the cached op stack. + if (is_cached_) { + AddToCachedNodeStack(node); + } + // Set total repeats and total epochs for the node + node->SetTotalRepeats(num_repeats_); + node->SetNumEpochs(num_epochs_); return Status::OK(); } +#ifndef ENABLE_ANDROID // CacheOp removes previous leaf ops and replaces them with itself -Status RepeatPass::RunOnNode(std::shared_ptr node, bool *const modified) { +Status RepeatPass::VisitAfter(std::shared_ptr node, bool *const modified) { is_cached_ = false; // if we are a cache within a repeat path of the tree, then adjust the total repeats and total epochs for cached ops. // So that those cached nodes become 1-time use (up to eoe), never repeated. Instead // the repeating behaviours shall be invoked against the cache op. - std::shared_ptr cached_op = PopFromCachedOpStack(); - while (cached_op != nullptr) { - int32_t cached_op_total_repeats = cached_op->op_total_repeats() / num_repeats_; - cached_op->set_total_repeats(cached_op_total_repeats); + std::shared_ptr cached_node = PopFromCachedNodeStack(); + while (cached_node != nullptr) { + int32_t cached_op_total_repeats = cached_node->GetTotalRepeats() / num_repeats_; + cached_node->SetTotalRepeats(cached_op_total_repeats); // Cached ops will only be executed on the first epoch, therefore, num_epochs_ = 1 - cached_op->set_num_repeats_per_epoch(cached_op_total_repeats); - cached_op = PopFromCachedOpStack(); + cached_node->SetNumEpochs(1); + cached_node = PopFromCachedNodeStack(); } - node->set_total_repeats(num_repeats_); - node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); - return Status::OK(); -} - -Status RepeatPass::RunOnNode(std::shared_ptr node, bool *const modified) { - // If we are under a cache op, then save ourselves to the cached op stack. - if (is_cached_) { - AddToCachedOpStack(node); - } - // Set total repeats and total epochs for the node - node->set_total_repeats(num_repeats_); - node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); - return Status::OK(); -} -// All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up -// for use with a controlling repeat above it. -Status RepeatPass::RunOnNode(std::shared_ptr node, bool *const modified) { - // If we are under a cache op, then save ourselves to the cached op stack. - if (is_cached_) { - AddToCachedOpStack(node); - } - // Set total repeats and total epochs for the node - node->set_total_repeats(num_repeats_); - node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); + node->SetTotalRepeats(num_repeats_); + node->SetNumEpochs(num_epochs_); return Status::OK(); } // Turns off the tracking for operations under merge op -Status RepeatPass::RunOnNode(std::shared_ptr node, bool *const modified) { +Status RepeatPass::VisitAfter(std::shared_ptr node, bool *const modified) { // If there was not any repeat in the merge cache miss leg, then the cache_lookup // would not have been consumed yet. In that case, we need to set its total repeats for it. if (cache_lookup_) { - cache_lookup_->set_total_repeats(num_repeats_); - cache_lookup_->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); + cache_lookup_->SetTotalRepeats(num_repeats_); + cache_lookup_->SetNumEpochs(num_epochs_); } - node->set_total_repeats(num_repeats_); - node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); + node->SetTotalRepeats(num_repeats_); + node->SetNumEpochs(num_epochs_); cache_lookup_.reset(); // If we are not repeated then the saved lookup is no longer needed or used is_merge_ = false; return Status::OK(); } // Saves the lookup up in case it needs to be referenced by a repeat -Status RepeatPass::RunOnNode(std::shared_ptr node, bool *const modified) { +Status RepeatPass::VisitAfter(std::shared_ptr node, bool *const modified) { if (!node->IsLeaf()) { // By definition, the CacheLookup must be a leaf op. Make that clear here. RETURN_STATUS_UNEXPECTED("CacheLookupOp must be a leaf node!"); @@ -184,29 +178,30 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *const mo // into the pass so that the decision can be made during the processing of the cache miss leg of the merge. // Further, if there's a repeat above the merge but no repeat in the cache miss leg, then the merge op will // add the lookup to the eoe stack - cache_lookup_ = std::static_pointer_cast(node); + cache_lookup_ = std::static_pointer_cast(node); return Status::OK(); } +#endif -Status RepeatPass::RunOnNode(std::shared_ptr node, bool *const modified) { - // Set total repeats and total epochs for the DeviceQueueOp - node->set_total_repeats(num_epochs_); - node->set_num_repeats_per_epoch(1); +Status RepeatPass::VisitAfter(std::shared_ptr node, bool *const modified) { + // Set total repeats and total epochs for the TransferNode + node->SetTotalRepeats(num_epochs_); + node->SetNumEpochs(num_epochs_); return Status::OK(); } // Adds an operator to the cached operator stack save area -void RepeatPass::AddToCachedOpStack(std::shared_ptr dataset_op) { cached_op_stacks_.push(dataset_op); } +void RepeatPass::AddToCachedNodeStack(std::shared_ptr node) { cached_node_stacks_.push(node); } // Pops an operator from the cached operator stack save area -std::shared_ptr RepeatPass::PopFromCachedOpStack() { - std::shared_ptr top_op = nullptr; - if (!cached_op_stacks_.empty()) { - top_op = cached_op_stacks_.top(); - cached_op_stacks_.pop(); +std::shared_ptr RepeatPass::PopFromCachedNodeStack() { + std::shared_ptr top_node = nullptr; + if (!cached_node_stacks_.empty()) { + top_node = cached_node_stacks_.top(); + cached_node_stacks_.pop(); } - return top_op; + return top_node; } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h index 8eb1ce13c1..6c9f257bd0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,12 +25,11 @@ namespace mindspore { namespace dataset { -/// \class RepeatPass repeat_pass.h -/// \brief This is a NodePass who's job is to perform setup actions for RepeatOps. A RepeatOp needs to have references -/// to the eoe-producing (typically leaf) nodes underneath it. -class RepeatPass : public NodePass { +/// \class RepeatPass +/// \brief This is a post pass that calculate the number of repeats the pipeline needs to fetch the data. +class RepeatPass : public IRNodePass { public: - using op_stack = std::stack>; + using op_stack = std::stack>; /// \brief Constructor RepeatPass(); @@ -40,93 +39,91 @@ class RepeatPass : public NodePass { /// \brief Identifies the subtree below this node as being in a repeated path of the tree. /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all + /// \param[in,out] modified Indicator if the node was changed at all /// \return Status The status code returned - Status PreRunOnNode(std::shared_ptr node, bool *const modified) override; + Status Visit(std::shared_ptr node, bool *const modified) override; /// \brief Identifies the subtree below this node as being in a repeated path of the tree. /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all + /// \param[in,out] modified Indicator if the node was changed at all /// \return Status The status code returned - Status PreRunOnNode(std::shared_ptr node, bool *const modified) override; + Status Visit(std::shared_ptr node, bool *const modified) override; +#ifndef ENABLE_ANDROID /// \brief Identifies the subtree below this node as being in a cache merge path /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all + /// \param[in,out] modified Indicator if the node was changed at all /// \return Status The status code returned - Status PreRunOnNode(std::shared_ptr node, bool *const modified) override; + Status Visit(std::shared_ptr node, bool *const modified) override; /// \brief Identifies the subtree below this node as being cached /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all + /// \param[in,out] modified Indicator if the node was changed at all /// \return Status The status code returned - Status PreRunOnNode(std::shared_ptr node, bool *const modified) override; + Status Visit(std::shared_ptr node, bool *const modified) override; +#endif /// \brief Hooks up any identified eoe nodes under this repeat. /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all + /// \param[in,out] modified Indicator if the node was changed at all /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; + Status VisitAfter(std::shared_ptr node, bool *const modified) override; /// \brief Hooks up any identified eoe nodes under this repeat. /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all + /// \param[in,out] modified Indicator if the node was changed at all /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; + Status VisitAfter(std::shared_ptr node, bool *const modified) override; - /// \brief CacheOp removes previous leaf ops and replaces them with itself +#ifndef ENABLE_ANDROID + /// \brief CacheNode removes previous leaf ops and replaces them with itself /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all + /// \param[in,out] modified Indicator if the node was changed at all /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; + Status VisitAfter(std::shared_ptr node, bool *const modified) override; - /// \brief Turns of the tracking for operations under merge op + /// \brief Turns off the tracking for operations under merge op /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all + /// \param[in,out] modified Indicator if the node was changed at all /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; + Status VisitAfter(std::shared_ptr node, bool *const modified) override; /// \brief Saves the lookup up in case it needs to be referenced by a repeat /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all + /// \param[in,out] modified Indicator if the node was changed at all /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; + Status VisitAfter(std::shared_ptr node, bool *const modified) override; +#endif - /// \brief Set the epoch count for DeviceQueue + /// \brief Sets the epoch count for TransferNode /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all + /// \param[in,out] modified Indicator if the node was changed at all /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Special case for GeneratorOp - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; + Status VisitAfter(std::shared_ptr node, bool *const modified) override; /// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up /// for use with a controlling repeat above it. /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all + /// \param[in,out] modified Indicator if the node was changed at all /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; + Status VisitAfter(std::shared_ptr node, bool *const modified) override; private: - /// \brief Adds an operator to the cached operator stack save area - /// \param op - The dataset op to work add to cached stack + /// \brief Adds an operator to the cached stack save area + /// \param node - The dataset node to add to cached stack /// \return Status The status code returned - void AddToCachedOpStack(std::shared_ptr dataset_op); - - /// \brief Pops an operator from the cached operator stack save area - /// \return shared_ptr to the popped operator - std::shared_ptr PopFromCachedOpStack(); - - bool is_merge_; // T/F if we are processing under a cache merge op - bool is_cached_; // T/F is we are processing under a cache op - int32_t num_repeats_; // A multiplier to the total number of repeats - int32_t num_epochs_; // To save the total number of epochs - op_stack cached_op_stacks_; // A save area for ops under a cache op - std::shared_ptr cache_lookup_; // A save area for a cache lookup op + void AddToCachedNodeStack(std::shared_ptr node); + + /// \brief Pops an operator from the cached stack save area + /// \return shared_ptr to the popped dataset node + std::shared_ptr PopFromCachedNodeStack(); + + bool is_merge_; // T/F if we are processing under a cache merge node + bool is_cached_; // T/F is we are processing under a cache node + int32_t num_repeats_; // A multiplier to the total number of repeats + int32_t num_epochs_; // To save the total number of epochs + op_stack cached_node_stacks_; // A save area for operators under a cache node + std::shared_ptr cache_lookup_; // A save area for a cache lookup node }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.cc deleted file mode 100644 index 2ab4a5ca21..0000000000 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.cc +++ /dev/null @@ -1,189 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "minddata/dataset/engine/datasetops/cache_op.h" -#include "minddata/dataset/engine/datasetops/zip_op.h" -#include "minddata/dataset/engine/datasetops/map_op/map_op.h" -#include "minddata/dataset/engine/opt/pre/cache_error_pass.h" - -namespace mindspore { -namespace dataset { - -// Constructor -CacheErrorPass::CacheErrorPass() : is_cached_(false), is_mappable_(false) {} - -// Identifies the subtree below this node as being cached -Status CacheErrorPass::PreRunOnNode(std::shared_ptr node, bool *const modified) { - // Turn on the flag that we're under a merge op - is_cached_ = true; - return Status::OK(); -} - -// Returns an error if ZipOp exists under a cache -Status CacheErrorPass::PreRunOnNode(std::shared_ptr node, bool *const modified) { - if (is_cached_) { - return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, - "ZipOp is currently not supported as a descendant operator under a cache."); - } - - return Status::OK(); -} - -// Returns an error if MapOp with non-deterministic TensorOps exists under a cache -Status CacheErrorPass::PreRunOnNode(std::shared_ptr node, bool *const modified) { - if (is_cached_) { - auto tfuncs = node->TFuncs(); - for (size_t i = 0; i < tfuncs.size(); i++) { - if (!tfuncs[i]->Deterministic()) { - return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, - "MapOp with non-deterministic TensorOps is currently not supported as a descendant of cache."); - } - } - } - return Status::OK(); -} - -// Returns an error if ConcatOp exists under a cache -Status CacheErrorPass::PreRunOnNode(std::shared_ptr node, bool *const modified) { - if (is_cached_) { - return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, - "ConcatOp is currently not supported as a descendant operator under a cache."); - } - - return Status::OK(); -} - -// Returns an error if TakeOp exists under a cache -Status CacheErrorPass::PreRunOnNode(std::shared_ptr node, bool *const modified) { - if (is_cached_) { - return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, - "TakeOp/SplitOp is currently not supported as a descendant operator under a cache."); - } - - return Status::OK(); -} - -// Returns an error if SkipOp exists under a cache -Status CacheErrorPass::PreRunOnNode(std::shared_ptr node, bool *const modified) { - if (is_cached_) { - return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, - "SkipOp is currently not supported as a descendant operator under a cache."); - } - - return Status::OK(); -} - -// Returns an error if SkipOp exists under a cache -Status CacheErrorPass::PreRunOnNode(std::shared_ptr node, bool *const modified) { - if (is_cached_) { - return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, - "BatchOp is currently not supported as a descendant operator under a cache."); - } - - return Status::OK(); -} - -#ifdef ENABLE_PYTHON -// Returns an error if FilterOp exists under a cache -Status CacheErrorPass::PreRunOnNode(std::shared_ptr node, bool *const modified) { - if (is_cached_) { - return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, - "FilterOp is currently not supported as a descendant operator under a cache."); - } - - return Status::OK(); -} -#endif - -Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *const modified) { - // Turn on the flag that this is a tree with mappable leaf dataset - is_mappable_ = true; - return Status::OK(); -} - -Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *const modified) { - // Turn on the flag that this is a tree with mappable leaf dataset - is_mappable_ = true; - return Status::OK(); -} - -Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *const modified) { - // Turn on the flag that this is a tree with mappable leaf dataset - is_mappable_ = true; - return Status::OK(); -} - -Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *const modified) { - // Turn on the flag that this is a tree with mappable leaf dataset - is_mappable_ = true; - return Status::OK(); -} - -Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *const modified) { - // Turn on the flag that this is a tree with mappable leaf dataset - is_mappable_ = true; - return Status::OK(); -} - -Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *const modified) { - // Turn on the flag that this is a tree with mappable leaf dataset - is_mappable_ = true; - return Status::OK(); -} - -Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *const modified) { - // Turn on the flag that this is a tree with mappable leaf dataset - is_mappable_ = true; - return Status::OK(); -} - -Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *const modified) { - // Turn on the flag that this is a tree with mappable leaf dataset - is_mappable_ = true; - return Status::OK(); -} - -Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *const modified) { - // Turn on the flag that this is a tree with mappable leaf dataset - is_mappable_ = true; - return Status::OK(); -} - -Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *const modified) { - // Turn on the flag that this is a tree with mappable leaf dataset - is_mappable_ = true; - return Status::OK(); -} - -Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *const modified) { - // Turn off the flag that we're under a merge op - is_cached_ = false; - return Status::OK(); -} - -// Currently, returns an error if RepeatOp exists under a cache -// Because there is no operator in the cache hit stream to consume eoes, caching above repeat causes problem. -Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *const modified) { - if (is_cached_ && is_mappable_) { - return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, - "Repeat is not supported as a descendant operator under a mappable cache."); - } - - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.h deleted file mode 100644 index a3a5d502ac..0000000000 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.h +++ /dev/null @@ -1,167 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_ERROR_PASS_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_ERROR_PASS_ - -#include -#include -#include -#include "minddata/dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { - -/// \class CacheErrorPass cache_error_pass.h -/// \brief This is a NodePass who's job is to catch invalid tree configurations related to cache and generate failures. -class CacheErrorPass : public NodePass { - public: - /// \brief Constructor - CacheErrorPass(); - - /// \brief Destructor - ~CacheErrorPass() = default; - - /// \brief Identifies the subtree below this node as being cached - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status PreRunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Returns an error if ZipOp exists under a cache - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status PreRunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Returns an error if MapOp with non-deterministic TensorOps exists under a cache - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status PreRunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Returns an error if ConcatOp exists under a cache - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status PreRunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Returns an error if TakeOp exists under a cache - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status PreRunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Returns an error if SkipOp exists under a cache - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status PreRunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Returns an error if SkipOp exists under a cache - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status PreRunOnNode(std::shared_ptr node, bool *const modified) override; - -#ifdef ENABLE_PYTHON - /// \brief Returns an error if FilterOp exists under a cache - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status PreRunOnNode(std::shared_ptr node, bool *const modified) override; -#endif - - /// \brief Identifies the leaf dataset as being mappable - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Identifies the leaf dataset as being mappable - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Identifies the leaf dataset as being mappable - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Identifies the leaf dataset as being mappable - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Identifies the leaf dataset as being mappable - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Identifies the leaf dataset as being mappable - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Identifies the leaf dataset as being mappable - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Identifies the leaf dataset as being mappable - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Identifies the leaf dataset as being mappable - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Identifies the leaf dataset as being mappable - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Identifies the subtree above this node as not being cached - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Identifies and block repeat under cache scenarios - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - private: - bool is_cached_; - bool is_mappable_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_POST_CACHE_ERROR_PASS_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_validation_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_validation_pass.cc index 981334e0d0..5ae54a99a9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_validation_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_validation_pass.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,7 +25,6 @@ #include "minddata/dataset/engine/ir/datasetops/skip_node.h" #include "minddata/dataset/engine/ir/datasetops/take_node.h" #include "minddata/dataset/engine/ir/datasetops/zip_node.h" -#include "minddata/dataset/include/transforms.h" namespace mindspore { namespace dataset { @@ -114,11 +113,18 @@ Status CacheValidationPass::Visit(std::shared_ptr node, bool *const mod } // If Map is created to be cached, set the flag indicating we found an operation with a cache. is_cached_ = true; + + // This is temporary code. + // Because the randomness of its tensor operations is not known in TensorOperation form until we convert them + // to TensorOp, we need to check the randomness in MapNode::Build(). + // By setting this MapNode is under a cache, we will check the randomness of its tensor operations without the need + // to walk the IR tree again. + node->Cached(); + auto tfuncs = node->TensorOperations(); for (size_t i = 0; i < tfuncs.size(); i++) { if (tfuncs[i]->IsRandomOp()) { - RETURN_STATUS_UNEXPECTED( - "MapNode with non-deterministic operations is not supported as a descendant of cache."); + RETURN_STATUS_UNEXPECTED("MapNode containing random operation is not supported as a descendant of cache."); } } } diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.cc deleted file mode 100644 index 2d279ea1fe..0000000000 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.cc +++ /dev/null @@ -1,78 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h" -#include "minddata/dataset/engine/execution_tree.h" -#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" -#include "minddata/dataset/engine/datasetops/device_queue_op.h" - -namespace mindspore { -namespace dataset { - -// constructor -EpochInjectionPass::InjectionFinder::InjectionFinder(std::shared_ptr node) : injection_point_(node) {} - -#ifndef ENABLE_ANDROID -// Performs finder work for BuildVocabOp that has special rules about epoch control injection -Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr node, bool *const modified) { - injection_point_ = nullptr; - return Status::OK(); -} - -// Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection -Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr node, - bool *const modified) { - injection_point_ = nullptr; - return Status::OK(); -} -#endif - -Status EpochInjectionPass::InjectionFinder::RunOnNode(std::shared_ptr node, bool *const modified) { - // Assumption: There is only one DeviceQueueOp in a pipeline. This assumption is not validated here. - injection_point_ = node->child(0); - return Status::OK(); -} - -// constructor -EpochInjectionPass::EpochInjectionPass() {} - -// Runs an injection pass to inject in operators needed at the pre pass stage -Status EpochInjectionPass::RunOnTree(ExecutionTree *tree, bool *const modified) { - MS_LOG(INFO) << "Pre pass: Injection pass started."; - - // First, run the finder to perform any injection info before we can go ahead to drive the op injection work. - // The finder can make updates to the EpochInjectionPass object. - EpochInjectionPass::InjectionFinder finder(tree->root()); - RETURN_IF_NOT_OK(finder.Run(tree, modified)); - - // The first injection logic is to check if we should inject the epoch control op as the root node. - // Do not inject the op if the number of epochs is 1. - int32_t num_epochs = tree->num_epochs(); - std::shared_ptr epoch_inject_node = finder.injection_point(); - if (num_epochs != 1 && epoch_inject_node != nullptr) { - std::shared_ptr epoch_ctrl_op; - RETURN_IF_NOT_OK(EpochCtrlOp::Builder(num_epochs).Build(&epoch_ctrl_op)); - RETURN_IF_NOT_OK(tree->AssociateNode(epoch_ctrl_op)); - RETURN_IF_NOT_OK(epoch_inject_node->InsertAsParent(epoch_ctrl_op)); - } - - MS_LOG(INFO) << "Pre pass: Injection pass complete."; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.h deleted file mode 100644 index d5c3b281b2..0000000000 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.h +++ /dev/null @@ -1,88 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ -#define DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ - -#include -#include -#include "minddata/dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { - -class DatasetOp; - -/// \class EpochInjectionPass epoch_injection_pass.h -/// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api -/// parsing. -class EpochInjectionPass : public TreePass { - /// \class InjectionFinder - /// \brief This is a nested node pass class who's job is to parse the tree and perform any identification logic for - /// operators that need to be injected. It is run first by the main injection pass to find out what operators - /// it may need to inject. - class InjectionFinder : public NodePass { - public: - /// \brief Constructor - explicit InjectionFinder(std::shared_ptr node); - - /// \brief Destructor - ~InjectionFinder() = default; - -#ifndef ENABLE_ANDROID - /// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection. - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status PreRunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection. - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status PreRunOnNode(std::shared_ptr node, bool *const modified) override; -#endif - - /// \brief Register the DeviceQueueOp for further action. - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Getter - std::shared_ptr injection_point() { return injection_point_; } - - private: - std::shared_ptr injection_point_; - }; - - public: - /// \brief Constructor - EpochInjectionPass(); - - /// \brief Destructor - ~EpochInjectionPass() = default; - - /// \brief Runs an injection pass to inject in operators needed at the pre pass stage - /// \param[inout] tree The tree to operate on. - /// \param[inout] Indicate of the tree was modified. - /// \return Status The status code returned - Status RunOnTree(ExecutionTree *tree, bool *const modified) override; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/input_validation_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/input_validation_pass.cc index d4019c3c17..3da5b15076 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/input_validation_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/input_validation_pass.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,7 +15,6 @@ */ #include -#include #include "minddata/dataset/include/datasets.h" #include "minddata/dataset/engine/opt/pre/input_validation_pass.h" diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.cc deleted file mode 100644 index 688ad8adf1..0000000000 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.cc +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "minddata/dataset/engine/opt/pre/removal_pass.h" -#include "minddata/dataset/engine/datasetops/shuffle_op.h" -#include "minddata/dataset/engine/execution_tree.h" - -namespace mindspore { -namespace dataset { - -RemovalPass::RemovalNodes::RemovalNodes() : is_caching_(false) {} - -#ifndef ENABLE_ANDROID -// Identifies the subtree below this node as a cached descendant tree. -Status RemovalPass::RemovalNodes::PreRunOnNode(std::shared_ptr node, bool *const modified) { - *modified = false; - MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree."; - is_caching_ = true; - return Status::OK(); -} - -// Resets the tracking of the cache within the tree -Status RemovalPass::RemovalNodes::RunOnNode(std::shared_ptr node, bool *const modified) { - *modified = false; - MS_LOG(INFO) << "Removal pass: cache descendant tree complete."; - is_caching_ = false; - return Status::OK(); -} -#endif - -// Perform ShuffleOp removal check. -Status RemovalPass::RemovalNodes::RunOnNode(std::shared_ptr node, bool *const modified) { - *modified = false; - // If we are in a cache descendant tree, then this shuffle op needs to be removed - if (is_caching_) { - MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)"; - nodes_to_remove_.push_back(std::static_pointer_cast(node)); - } - return Status::OK(); -} - -// constructor -RemovalPass::RemovalPass() {} - -// Walk the tree to collect the nodes to remove, then removes them. -Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *const modified) { - MS_LOG(INFO) << "Pre pass: removal pass started."; - // Create the removal node pass which can identify which nodes need to be removed. - std::unique_ptr removal_nodes = std::make_unique(); - RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified)); - - // Then, execute the removal of any nodes that were set up for removal - for (auto node : removal_nodes->nodes_to_remove()) { - RETURN_IF_NOT_OK(node->Remove()); - } - MS_LOG(INFO) << "Pre pass: removal pass complete."; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h deleted file mode 100644 index 849ee23174..0000000000 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h +++ /dev/null @@ -1,90 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ - -#include -#include -#include "minddata/dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { - -class DatasetOp; - -/// \class RemovalPass removal_pass.h -/// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which -/// nodes should be removed, and then removes them. -class RemovalPass : public TreePass { - /// \class RemovalNodes - /// \brief This is a NodePass who's job is to identify which nodes should be removed. - /// It works in conjunction with the removal_pass. - class RemovalNodes : public NodePass { - public: - /// \brief Constructor - /// \param[in] removal_pass Raw pointer back to controlling tree pass - RemovalNodes(); - - /// \brief Destructor - ~RemovalNodes() = default; - -#ifndef ENABLE_ANDROID - /// \brief Identifies the subtree below this node as a cached descendant tree. - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status PreRunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Resets the tracking of the cache within the tree - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; -#endif - - /// \brief Perform ShuffleOp removal check - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Getter - /// \return All the nodes to be removed - std::vector> nodes_to_remove() { return nodes_to_remove_; } - - private: - bool is_caching_; - std::vector> nodes_to_remove_; - }; - - public: - /// \brief Constructor - RemovalPass(); - - /// \brief Destructor - ~RemovalPass() = default; - - /// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them. - /// \param[inout] tree The tree to operate on. - /// \param[inout] Indicate of the tree was modified. - /// \return Status The status code returned - Status RunOnTree(ExecutionTree *tree, bool *const modified) override; -}; -} // namespace dataset -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.cc deleted file mode 100644 index 7ff2c39794..0000000000 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.cc +++ /dev/null @@ -1,121 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "minddata/dataset/engine/opt/util/printer_pass.h" - -namespace mindspore { -namespace dataset { - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *const modified) { - *modified = false; - std::cout << "Visiting DatasetOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *const modified) { - *modified = false; - std::cout << "Visiting BatchOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *const modified) { - *modified = false; - std::cout << "Visiting MapOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *const modified) { - *modified = false; - std::cout << "Visiting ProjectOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *const modified) { - *modified = false; - std::cout << "Visiting RenameOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *const modified) { - *modified = false; - std::cout << "Visiting SkipOp" << '\n'; - return Status::OK(); -} -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *const modified) { - *modified = false; - std::cout << "Visiting ShuffleOp" << '\n'; - return Status::OK(); -} -#ifndef ENABLE_ANDROID -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *const modified) { - *modified = false; - std::cout << "Visiting MindRecordOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *const modified) { - *modified = false; - std::cout << "Visiting TFReaderOp" << '\n'; - return Status::OK(); -} -#endif - -#ifdef ENABLE_PYTHON -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *const modified) { - *modified = false; - std::cout << "Visiting FilterOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *const modified) { - *modified = false; - std::cout << "Visiting GeneratorOp" << '\n'; - return Status::OK(); -} -#endif - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *const modified) { - *modified = false; - std::cout << "Visiting TakeOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *const modified) { - *modified = false; - std::cout << "Visiting ZipOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *const modified) { - *modified = false; - std::cout << "Visiting DeviceQueueOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *const modified) { - *modified = false; - std::cout << "Visiting ImageFolderOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *const modified) { - *modified = false; - std::cout << "Visiting ImageFolderOp" << '\n'; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.h deleted file mode 100644 index 62463a694e..0000000000 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.h +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H - -#include -#include "minddata/dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { - -class PrinterPass : public NodePass { - public: - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - -#ifndef ENABLE_ANDROID - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - Status RunOnNode(std::shared_ptr node, bool *const modified) override; -#endif - -#ifdef ENABLE_PYTHON - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - Status RunOnNode(std::shared_ptr node, bool *const modified) override; -#endif - - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - Status RunOnNode(std::shared_ptr node, bool *const modified) override; -}; - -} // namespace dataset -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc index a2726d2c61..bc0adc221f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ #include "minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" #include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/engine/opt/post/auto_worker_pass.h" +#include "minddata/dataset/engine/opt/post/repeat_pass.h" #ifdef ENABLE_PYTHON #include "minddata/dataset/engine/opt/post/generator_node_pass.h" #endif @@ -94,6 +95,7 @@ Status TreeAdapter::PostPass(std::shared_ptr ir) { #ifdef ENABLE_PYTHON actions.emplace_back(std::make_unique()); #endif + actions.emplace_back(std::make_unique()); // We will gradually move RepeatPass from ExecutionTree::PrepareTreePostAction to here. @@ -133,7 +135,7 @@ Status TreeAdapter::BuildExecutionTreeRecur(std::shared_ptr ir, std return Status::OK(); } -Status TreeAdapter::Build(std::shared_ptr root_ir, int32_t num_epochs) { +Status TreeAdapter::Build(std::shared_ptr root_ir) { // This will evolve in the long run tree_ = std::make_unique(); // disable profiling if this is only a getter pass @@ -146,7 +148,7 @@ Status TreeAdapter::Build(std::shared_ptr root_ir, int32_t num_epoc // Note: We will gradually move the pre pass, optimizer pass, and post pass // on ExecutionTree to perform on IR tree. // Prepare the tree - RETURN_IF_NOT_OK(tree_->Prepare(num_epochs, true)); + RETURN_IF_NOT_OK(tree_->Prepare()); // After the tree is prepared, the col_name_id_map can safely be obtained column_name_map_ = tree_->root()->column_name_id_map(); @@ -192,7 +194,7 @@ Status TreeAdapter::Compile(std::shared_ptr input_ir, int32_t num_e // Remember the root node root_ir_ = root_ir; - RETURN_IF_NOT_OK(Build(root_ir_, num_epochs)); + RETURN_IF_NOT_OK(Build(root_ir_)); tree_state_ = kCompileStateReady; return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h index 71644f0f99..b2d46550d9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h +++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -83,7 +83,7 @@ class TreeAdapter { Status PostPass(std::shared_ptr ir); // Build an Execution tree - Status Build(std::shared_ptr root_ir, int32_t num_epochs); + Status Build(std::shared_ptr root_ir); // This RECURSIVE function walks the (optimized) IR tree in DFS to build its corresponding Execution tree. Status BuildExecutionTreeRecur(std::shared_ptr ir, std::shared_ptr *op); diff --git a/tests/ut/cpp/dataset/album_op_test.cc b/tests/ut/cpp/dataset/album_op_test.cc index 49af1a63d0..d347dcb3f6 100644 --- a/tests/ut/cpp/dataset/album_op_test.cc +++ b/tests/ut/cpp/dataset/album_op_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,22 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include -#include #include #include #include "common/common.h" #include "minddata/dataset/core/client.h" -#include "minddata/dataset/core/global_context.h" #include "minddata/dataset/engine/datasetops/source/album_op.h" -#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" -#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h" -#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" -#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" -#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" -#include "minddata/dataset/util/path.h" #include "minddata/dataset/util/status.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" @@ -89,7 +79,11 @@ TEST_F(MindDataTestAlbum, TestSequentialAlbumWithSchema) { std::string folder_path = datasets_root_path_ + "/testAlbum/images"; std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json"; std::vector column_names = {"image", "label", "id"}; - auto tree = Build({AlbumSchema(16, 2, 32, folder_path, schema_file, column_names, false), Repeat(2)}); + auto op1 = AlbumSchema(16, 2, 32, folder_path, schema_file, column_names, false); + auto op2 = Repeat(2); + op1->set_total_repeats(2); + op1->set_num_repeats_per_epoch(2); + auto tree = Build({op1, op2}); ASSERT_OK(tree->Prepare()); ASSERT_OK(tree->Launch()); DatasetIterator di(tree); @@ -111,7 +105,11 @@ TEST_F(MindDataTestAlbum, TestSequentialAlbumWithSchema) { TEST_F(MindDataTestAlbum, TestSequentialAlbumWithSchemaNoOrder) { std::string folder_path = datasets_root_path_ + "/testAlbum/images"; std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json"; - auto tree = Build({AlbumSchema(16, 2, 32, folder_path, schema_file), Repeat(2)}); + auto op1 = AlbumSchema(16, 2, 32, folder_path, schema_file); + auto op2 = Repeat(2); + op1->set_total_repeats(2); + op1->set_num_repeats_per_epoch(2); + auto tree = Build({op1, op2}); ASSERT_OK(tree->Prepare()); ASSERT_OK(tree->Launch()); DatasetIterator di(tree); @@ -134,7 +132,11 @@ TEST_F(MindDataTestAlbum, TestSequentialAlbumWithSchemaFloat) { std::string folder_path = datasets_root_path_ + "/testAlbum/images"; // add the priority column std::string schema_file = datasets_root_path_ + "/testAlbum/floatSchema.json"; - auto tree = Build({AlbumSchema(16, 2, 32, folder_path, schema_file), Repeat(2)}); + auto op1 = AlbumSchema(16, 2, 32, folder_path, schema_file); + auto op2 = Repeat(2); + op1->set_total_repeats(2); + op1->set_num_repeats_per_epoch(2); + auto tree = Build({op1, op2}); tree->Prepare(); ASSERT_OK(tree->Launch()); DatasetIterator di(tree); @@ -159,7 +161,11 @@ TEST_F(MindDataTestAlbum, TestSequentialAlbumWithFullSchema) { std::string folder_path = datasets_root_path_ + "/testAlbum/images"; // add the priority column std::string schema_file = datasets_root_path_ + "/testAlbum/fullSchema.json"; - auto tree = Build({AlbumSchema(16, 2, 32, folder_path, schema_file), Repeat(2)}); + auto op1 = AlbumSchema(16, 2, 32, folder_path, schema_file); + auto op2 = Repeat(2); + op1->set_total_repeats(2); + op1->set_num_repeats_per_epoch(2); + auto tree = Build({op1, op2}); ASSERT_OK(tree->Prepare()); ASSERT_OK(tree->Launch()); DatasetIterator di(tree); diff --git a/tests/ut/cpp/dataset/batch_op_test.cc b/tests/ut/cpp/dataset/batch_op_test.cc index 69090ec6b6..2b65e8b3a8 100644 --- a/tests/ut/cpp/dataset/batch_op_test.cc +++ b/tests/ut/cpp/dataset/batch_op_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,14 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include #include #include #include "minddata/dataset/core/client.h" #include "common/common.h" -#include "utils/ms_utils.h" #include "gtest/gtest.h" -#include "minddata/dataset/core/global_context.h" #include "utils/log_adapter.h" #include "securec.h" #include "minddata/dataset/util/status.h" @@ -112,7 +109,12 @@ TEST_F(MindDataTestBatchOp, TestSimpleBatch) { TEST_F(MindDataTestBatchOp, TestRepeatBatchDropTrue) { std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data"; bool success = false; - auto tree = Build({TFReader(schema_file), Repeat(2), Batch(7, true, 99)}); + auto op1 = TFReader(schema_file); + auto op2 = Repeat(2); + auto op3 = Batch(7, true, 99); + op1->set_total_repeats(2); + op1->set_num_repeats_per_epoch(2); + auto tree = Build({op1, op2, op3}); tree->Prepare(); Status rc = tree->Launch(); if (rc.IsError()) { @@ -157,7 +159,12 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropTrue) { TEST_F(MindDataTestBatchOp, TestRepeatBatchDropFalse) { std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data"; bool success = false; - auto tree = Build({TFReader(schema_file), Repeat(2), Batch(7, false, 99)}); + auto op1 = TFReader(schema_file); + auto op2 = Repeat(2); + auto op3 = Batch(7, false, 99); + op1->set_total_repeats(2); + op1->set_num_repeats_per_epoch(2); + auto tree = Build({op1, op2, op3}); tree->Prepare(); Status rc = tree->Launch(); if (rc.IsError()) { @@ -209,7 +216,14 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropFalse) { TEST_F(MindDataTestBatchOp, TestBatchDropFalseRepeat) { std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data"; bool success = false; - auto tree = Build({TFReader(schema_file), Batch(7, false, 99), Repeat(2)}); + auto op1 = TFReader(schema_file); + auto op2 = Batch(7, false, 99); + auto op3 = Repeat(2); + op1->set_total_repeats(2); + op1->set_num_repeats_per_epoch(2); + op2->set_total_repeats(2); + op2->set_num_repeats_per_epoch(2); + auto tree = Build({op1, op2, op3}); tree->Prepare(); Status rc = tree->Launch(); if (rc.IsError()) { @@ -255,7 +269,14 @@ TEST_F(MindDataTestBatchOp, TestBatchDropFalseRepeat) { TEST_F(MindDataTestBatchOp, TestBatchDropTrueRepeat) { std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data"; bool success = false; - auto tree = Build({TFReader(schema_file), Batch(5, true, 99), Repeat(2)}); + auto op1 = TFReader(schema_file); + auto op2 = Batch(5, true, 99); + auto op3 = Repeat(2); + op1->set_total_repeats(2); + op1->set_num_repeats_per_epoch(2); + op2->set_total_repeats(2); + op2->set_num_repeats_per_epoch(2); + auto tree = Build({op1, op2, op3}); tree->Prepare(); Status rc = tree->Launch(); if (rc.IsError()) { diff --git a/tests/ut/cpp/dataset/cache_op_test.cc b/tests/ut/cpp/dataset/cache_op_test.cc index 5f01261ad9..a85e6a6c33 100644 --- a/tests/ut/cpp/dataset/cache_op_test.cc +++ b/tests/ut/cpp/dataset/cache_op_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -293,15 +293,20 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) { ASSERT_TRUE(rc.IsOk()); // Assign tree relations and root + myCacheOp->set_total_repeats(numRepeats); + myCacheOp->set_num_repeats_per_epoch(numRepeats); rc = myRepeatOp->AddChild(myCacheOp); ASSERT_TRUE(rc.IsOk()); + // Always set to 1 under a CacheOp because we read from it only once. The CacheOp is the one that repeats. + myRandomDataOp->set_total_repeats(1); + myRandomDataOp->set_num_repeats_per_epoch(1); rc = myCacheOp->AddChild(myRandomDataOp); ASSERT_TRUE(rc.IsOk()); rc = myTree->AssignRoot(myRepeatOp); ASSERT_TRUE(rc.IsOk()); MS_LOG(INFO) << "Launching tree and begin iteration"; - rc = myTree->Prepare(1); + rc = myTree->Prepare(); ASSERT_TRUE(rc.IsOk()); // quick check to see what tree looks like @@ -412,15 +417,20 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) { ASSERT_TRUE(rc.IsOk()); // Assign tree relations and root + myCacheOp->set_total_repeats(numRepeats); + myCacheOp->set_num_repeats_per_epoch(numRepeats); rc = myRepeatOp->AddChild(myCacheOp); ASSERT_TRUE(rc.IsOk()); + // Always set to 1 under a CacheOp because we read from it only once. The CacheOp is the one that repeats. + myRandomDataOp->set_total_repeats(1); + myRandomDataOp->set_num_repeats_per_epoch(1); rc = myCacheOp->AddChild(myRandomDataOp); ASSERT_TRUE(rc.IsOk()); rc = myTree->AssignRoot(myRepeatOp); ASSERT_TRUE(rc.IsOk()); MS_LOG(INFO) << "Launching tree and begin iteration"; - rc = myTree->Prepare(1); + rc = myTree->Prepare(); ASSERT_TRUE(rc.IsOk()); std::cout << *myClient << std::endl; @@ -502,14 +512,20 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { rc = myTree->AssignRoot(myRepeatOp); ASSERT_TRUE(rc.IsOk()); + myMergeOp->set_total_repeats(numRepeats); + myMergeOp->set_num_repeats_per_epoch(numRepeats); rc = myRepeatOp->AddChild(myMergeOp); ASSERT_TRUE(rc.IsOk()); + myLookupOp->set_total_repeats(numRepeats); + myLookupOp->set_num_repeats_per_epoch(numRepeats); rc = myMergeOp->AddChild(myLookupOp); ASSERT_TRUE(rc.IsOk()); + so->set_total_repeats(numRepeats); + so->set_num_repeats_per_epoch(numRepeats); rc = myMergeOp->AddChild(so); ASSERT_TRUE(rc.IsOk()); - rc = myTree->Prepare(1); + rc = myTree->Prepare(); ASSERT_TRUE(rc.IsOk()); rc = myTree->Launch(); ASSERT_TRUE(rc.IsOk()); diff --git a/tests/ut/cpp/dataset/celeba_op_test.cc b/tests/ut/cpp/dataset/celeba_op_test.cc index 988ad82f74..202a6a8c95 100644 --- a/tests/ut/cpp/dataset/celeba_op_test.cc +++ b/tests/ut/cpp/dataset/celeba_op_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,14 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include -#include #include #include #include "common/common.h" #include "minddata/dataset/core/client.h" -#include "minddata/dataset/core/global_context.h" #include "minddata/dataset/engine/datasetops/source/celeba_op.h" #include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" #include "minddata/dataset/util/status.h" @@ -98,7 +95,11 @@ TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) { {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}, {0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}}; uint32_t count = 0; - auto tree = Build({Celeba(16, 2, 32, dir), Repeat(2)}); + auto op1 = Celeba(16, 2, 32, dir); + auto op2 = Repeat(2); + auto tree = Build({op1, op2}); + op1->set_total_repeats(2); + op1->set_num_repeats_per_epoch(2); tree->Prepare(); Status rc = tree->Launch(); if (rc.IsError()) { diff --git a/tests/ut/cpp/dataset/cifar_op_test.cc b/tests/ut/cpp/dataset/cifar_op_test.cc index dc26e8d64f..caa77a3aed 100644 --- a/tests/ut/cpp/dataset/cifar_op_test.cc +++ b/tests/ut/cpp/dataset/cifar_op_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,8 +39,6 @@ using mindspore::MsLogLevel::ERROR; using mindspore::ExceptionType::NoExceptionType; using mindspore::LogStream; -std::shared_ptr Repeat(int repeatCnt); - std::shared_ptr Build(std::vector> ops); std::shared_ptr Cifarop(uint64_t num_works, uint64_t rows, uint64_t conns, std::string path, diff --git a/tests/ut/cpp/dataset/coco_op_test.cc b/tests/ut/cpp/dataset/coco_op_test.cc index 3c786e6f81..4b2f1c7bcd 100644 --- a/tests/ut/cpp/dataset/coco_op_test.cc +++ b/tests/ut/cpp/dataset/coco_op_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -45,8 +45,6 @@ using mindspore::LogStream; std::shared_ptr Batch(int batch_size = 1, bool drop = false, int rows_per_buf = 2); -std::shared_ptr Repeat(int repeat_cnt); - std::shared_ptr Build(std::vector> ops); class MindDataTestCocoOp : public UT::DatasetOpTesting { @@ -261,4 +259,4 @@ TEST_F(MindDataTestCocoOp, TestCocoPanoptic) { } ASSERT_EQ(row_count, 2); -} \ No newline at end of file +} diff --git a/tests/ut/cpp/dataset/image_folder_op_test.cc b/tests/ut/cpp/dataset/image_folder_op_test.cc index 109536ebc6..8383563440 100644 --- a/tests/ut/cpp/dataset/image_folder_op_test.cc +++ b/tests/ut/cpp/dataset/image_folder_op_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include #include #include #include @@ -29,7 +28,6 @@ #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" -#include "minddata/dataset/util/path.h" #include "minddata/dataset/util/status.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" @@ -82,7 +80,11 @@ class MindDataTestImageFolderSampler : public UT::DatasetOpTesting { TEST_F(MindDataTestImageFolderSampler, TestSequentialImageFolderWithRepeat) { std::string folder_path = datasets_root_path_ + "/testPK/data"; - auto tree = Build({ImageFolder(16, 2, 32, folder_path, false), Repeat(2)}); + auto op1 = ImageFolder(16, 2, 32, folder_path, false); + auto op2 = Repeat(2); + op1->set_total_repeats(2); + op1->set_num_repeats_per_epoch(2); + auto tree = Build({op1, op2}); tree->Prepare(); int32_t res[] = {0, 1, 2, 3}; Status rc = tree->Launch(); @@ -166,7 +168,12 @@ TEST_F(MindDataTestImageFolderSampler, TestRandomSamplerImageFolder) { TEST_F(MindDataTestImageFolderSampler, TestSequentialImageFolderWithRepeatBatch) { std::string folder_path = datasets_root_path_ + "/testPK/data"; - auto tree = Build({ImageFolder(16, 2, 32, folder_path, false), Repeat(2), Batch(11)}); + auto op1 = ImageFolder(16, 2, 32, folder_path, false); + auto op2 = Repeat(2); + auto op3 = Batch(11); + op1->set_total_repeats(2); + op1->set_num_repeats_per_epoch(2); + auto tree = Build({op1, op2, op3}); tree->Prepare(); int32_t res[4][11] = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, @@ -297,7 +304,11 @@ TEST_F(MindDataTestImageFolderSampler, TestDistributedSampler) { int64_t num_samples = 0; std::shared_ptr sampler = std::make_shared(num_samples, 11, 10, false); std::string folder_path = datasets_root_path_ + "/testPK/data"; - auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler)), Repeat(4)}); + auto op1 = ImageFolder(16, 2, 32, folder_path, false, std::move(sampler)); + auto op2 = Repeat(4); + op1->set_total_repeats(4); + op1->set_num_repeats_per_epoch(4); + auto tree = Build({op1, op2}); tree->Prepare(); Status rc = tree->Launch(); if (rc.IsError()) { diff --git a/tests/ut/cpp/dataset/ir_callback_test.cc b/tests/ut/cpp/dataset/ir_callback_test.cc index 69359518f6..b4ee54b5f9 100644 --- a/tests/ut/cpp/dataset/ir_callback_test.cc +++ b/tests/ut/cpp/dataset/ir_callback_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ #include "common/common.h" #include "minddata/dataset/callback/ds_callback.h" #include "minddata/dataset/core/client.h" +#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" #include "minddata/dataset/engine/datasetops/source/random_data_op.h" #include "minddata/dataset/engine/tree_adapter.h" #include "minddata/dataset/include/datasets.h" @@ -166,6 +167,10 @@ TEST_F(MindDataTestCallback, TestBasicCallback) { std::shared_ptr repeat_op; rc = RepeatOp::Builder(2).Build(&repeat_op); // start build then launch tree + leaf->set_total_repeats(2); + leaf->set_num_repeats_per_epoch(2); + map_op->set_total_repeats(2); + map_op->set_num_repeats_per_epoch(2); std::shared_ptr tree = test::BuildTree({leaf, map_op, repeat_op}); rc = tree->Prepare(); EXPECT_TRUE(rc.IsOk()); @@ -213,8 +218,15 @@ TEST_F(MindDataTestCallback, TestMultiEpochCallback) { // config RepeatOp std::shared_ptr repeat_op; rc = RepeatOp::Builder(2).Build(&repeat_op); + // config EpochCtrlOp + std::shared_ptr epoch_ctrl_op; + rc = EpochCtrlOp::Builder(-1).Build(&epoch_ctrl_op); // start build then launch tree - std::shared_ptr tree = test::BuildTree({leaf, map_op, repeat_op}); + leaf->set_total_repeats(-2); + leaf->set_num_repeats_per_epoch(2); + map_op->set_total_repeats(-2); + map_op->set_num_repeats_per_epoch(2); + std::shared_ptr tree = test::BuildTree({leaf, map_op, repeat_op, epoch_ctrl_op}); rc = tree->Prepare(); EXPECT_TRUE(rc.IsOk()); rc = tree->Launch(); @@ -271,8 +283,15 @@ TEST_F(MindDataTestCallback, TestSelectedCallback) { // config RepeatOp std::shared_ptr repeat_op; rc = RepeatOp::Builder(2).Build(&repeat_op); + // config EpochCtrlOp + std::shared_ptr epoch_ctrl_op; + rc = EpochCtrlOp::Builder(-1).Build(&epoch_ctrl_op); // start build then launch tree - std::shared_ptr tree = test::BuildTree({leaf, map_op, repeat_op}); + leaf->set_total_repeats(-2); + leaf->set_num_repeats_per_epoch(2); + map_op->set_total_repeats(-2); + map_op->set_num_repeats_per_epoch(2); + std::shared_ptr tree = test::BuildTree({leaf, map_op, repeat_op, epoch_ctrl_op}); rc = tree->Prepare(); EXPECT_TRUE(rc.IsOk()); rc = tree->Launch(); diff --git a/tests/ut/cpp/dataset/manifest_op_test.cc b/tests/ut/cpp/dataset/manifest_op_test.cc index af84b2cbcd..dbc45345ab 100644 --- a/tests/ut/cpp/dataset/manifest_op_test.cc +++ b/tests/ut/cpp/dataset/manifest_op_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -58,7 +58,11 @@ class MindDataTestManifest : public UT::DatasetOpTesting { TEST_F(MindDataTestManifest, TestSequentialManifestWithRepeat) { std::string file = datasets_root_path_ + "/testManifestData/cpp.json"; - auto tree = Build({Manifest(16, 2, 32, file), Repeat(2)}); + auto op1 = Manifest(16, 2, 32, file); + auto op2 = Repeat(2); + op1->set_total_repeats(2); + op1->set_num_repeats_per_epoch(2); + auto tree = Build({op1, op2}); tree->Prepare(); uint32_t res[] = {0, 1, 0, 1}; Status rc = tree->Launch(); @@ -148,7 +152,11 @@ TEST_F(MindDataTestManifest, MindDataTestManifestNumSamples) { int64_t num_samples = 1; int64_t start_index = 0; auto seq_sampler = std::make_shared(num_samples, start_index); - auto tree = Build({Manifest(16, 2, 32, file, "train", std::move(seq_sampler), {}), Repeat(4)}); + auto op1 = Manifest(16, 2, 32, file, "train", std::move(seq_sampler), {}); + auto op2 = Repeat(4); + op1->set_total_repeats(4); + op1->set_num_repeats_per_epoch(4); + auto tree = Build({op1, op2}); tree->Prepare(); Status rc = tree->Launch(); if (rc.IsError()) { diff --git a/tests/ut/cpp/dataset/map_op_test.cc b/tests/ut/cpp/dataset/map_op_test.cc index aa1bd29add..64dd8bc15a 100644 --- a/tests/ut/cpp/dataset/map_op_test.cc +++ b/tests/ut/cpp/dataset/map_op_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,6 @@ #include #include - #include "common/common.h" #include "minddata/dataset/core/client.h" #include "minddata/dataset/core/tensor.h" @@ -416,6 +415,8 @@ TEST_F(MindDataTestMapOp, TestTFReaderRepeatMap) { rc = my_map_op->AddChild(my_repeat_op); EXPECT_TRUE(rc.IsOk()); + my_tfreader_op->set_total_repeats(num_repeats); + my_tfreader_op->set_num_repeats_per_epoch(num_repeats); rc = my_repeat_op->AddChild(my_tfreader_op); EXPECT_TRUE(rc.IsOk()); @@ -471,9 +472,13 @@ TEST_F(MindDataTestMapOp, TestTFReaderMapRepeat) { rc = my_tree_->AssociateNode(my_map_op); EXPECT_TRUE(rc.IsOk()); + my_map_op->set_total_repeats(num_repeats); + my_map_op->set_num_repeats_per_epoch(num_repeats); rc = my_repeat_op->AddChild(my_map_op); EXPECT_TRUE(rc.IsOk()); + my_tfreader_op->set_total_repeats(num_repeats); + my_tfreader_op->set_num_repeats_per_epoch(num_repeats); rc = my_map_op->AddChild(my_tfreader_op); EXPECT_TRUE(rc.IsOk()); @@ -548,9 +553,13 @@ TEST_F(MindDataTestMapOp, TFReader_Decode_Repeat_Resize) { rc = my_tree_->AssociateNode(my_map_resize_op); EXPECT_TRUE(rc.IsOk()); + my_tfreader_op->set_total_repeats(num_repeats); + my_tfreader_op->set_num_repeats_per_epoch(num_repeats); rc = my_map_decode_op->AddChild(my_tfreader_op); EXPECT_TRUE(rc.IsOk()); + my_map_decode_op->set_total_repeats(num_repeats); + my_map_decode_op->set_num_repeats_per_epoch(num_repeats); rc = my_repeat_op->AddChild(my_map_decode_op); EXPECT_TRUE(rc.IsOk()); @@ -611,7 +620,12 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize) { rc = map_resize_builder.Build(&map_resize_op); EXPECT_TRUE(rc.IsOk()); - my_tree_ = Build({ImageFolder(16, 2, 32, folder_path, false), map_decode_map, repeat_op, map_resize_op}); + auto image_folder_op = ImageFolder(16, 2, 32, folder_path, false); + image_folder_op->set_total_repeats(num_repeats); + image_folder_op->set_num_repeats_per_epoch(num_repeats); + map_decode_map->set_total_repeats(num_repeats); + map_decode_map->set_num_repeats_per_epoch(num_repeats); + my_tree_ = Build({image_folder_op, map_decode_map, repeat_op, map_resize_op}); rc = my_tree_->Prepare(); EXPECT_TRUE(rc.IsOk()); rc = my_tree_->Launch(); @@ -656,7 +670,12 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize) { rc = map_resize_builder.Build(&map_resize_op); EXPECT_TRUE(rc.IsOk()); - auto my_tree_2 = Build({ImageFolder(16, 2, 32, folder_path, false), map_decode_map, repeat_op, map_resize_op}); + image_folder_op = ImageFolder(16, 2, 32, folder_path, false); + image_folder_op->set_total_repeats(num_repeats); + image_folder_op->set_num_repeats_per_epoch(num_repeats); + map_decode_map->set_total_repeats(num_repeats); + map_decode_map->set_num_repeats_per_epoch(num_repeats); + auto my_tree_2 = Build({image_folder_op, map_decode_map, repeat_op, map_resize_op}); rc = my_tree_2->Prepare(); EXPECT_TRUE(rc.IsOk()); @@ -714,7 +733,12 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize_NoInputColumns) { rc = map_resize_builder.Build(&map_resize_op); EXPECT_TRUE(rc.IsOk()); - my_tree_ = Build({ImageFolder(16, 2, 32, folder_path, false), map_decode_map, repeat_op, map_resize_op}); + auto image_folder_op = ImageFolder(16, 2, 32, folder_path, false); + image_folder_op->set_total_repeats(num_repeats); + image_folder_op->set_num_repeats_per_epoch(num_repeats); + map_decode_map->set_total_repeats(num_repeats); + map_decode_map->set_num_repeats_per_epoch(num_repeats); + my_tree_ = Build({image_folder_op, map_decode_map, repeat_op, map_resize_op}); rc = my_tree_->Prepare(); EXPECT_TRUE(rc.IsOk()); rc = my_tree_->Launch(); diff --git a/tests/ut/cpp/dataset/mind_record_op_test.cc b/tests/ut/cpp/dataset/mind_record_op_test.cc index bed97a740d..81da99453f 100644 --- a/tests/ut/cpp/dataset/mind_record_op_test.cc +++ b/tests/ut/cpp/dataset/mind_record_op_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -370,9 +370,12 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordRepeat) { rc = my_tree->AssociateNode(my_repeat_op); EXPECT_TRUE(rc.IsOk()); + my_mindrecord_op->set_total_repeats(num_repeats); + my_mindrecord_op->set_num_repeats_per_epoch(num_repeats); rc = my_repeat_op->AddChild(my_mindrecord_op); EXPECT_TRUE(rc.IsOk()); + // Set children/root layout. rc = my_tree->AssignRoot(my_repeat_op); EXPECT_TRUE(rc.IsOk()); @@ -452,6 +455,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBlockReaderRepeat) { rc = my_tree->AssociateNode(my_repeat_op); EXPECT_TRUE(rc.IsOk()); + my_mindrecord_op->set_total_repeats(num_repeats); + my_mindrecord_op->set_num_repeats_per_epoch(num_repeats); rc = my_repeat_op->AddChild(my_mindrecord_op); EXPECT_TRUE(rc.IsOk()); diff --git a/tests/ut/cpp/dataset/mnist_op_test.cc b/tests/ut/cpp/dataset/mnist_op_test.cc index a6b03c288b..814bc687f5 100644 --- a/tests/ut/cpp/dataset/mnist_op_test.cc +++ b/tests/ut/cpp/dataset/mnist_op_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -78,7 +78,11 @@ TEST_F(MindDataTestMnistSampler, TestSequentialMnistWithRepeat) { int64_t num_samples = 10; int64_t start_index = 0; auto seq_sampler = std::make_shared(num_samples, start_index); - auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)), Repeat(2)}); + auto op1 = CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)); + auto op2 = Repeat(2); + op1->set_total_repeats(2); + op1->set_num_repeats_per_epoch(2); + auto tree = Build({op1, op2}); tree->Prepare(); uint32_t res[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; Status rc = tree->Launch(); @@ -108,7 +112,12 @@ TEST_F(MindDataTestMnistSampler, TestSequentialImageFolderWithRepeatBatch) { int64_t num_samples = 10; int64_t start_index = 0; auto seq_sampler = std::make_shared(num_samples, start_index); - auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)), Repeat(2), Batch(5)}); + auto op1 = CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)); + auto op2 = Repeat(2); + auto op3 = Batch(5); + op1->set_total_repeats(2); + op1->set_num_repeats_per_epoch(2); + auto tree = Build({op1, op2, op3}); tree->Prepare(); uint32_t res[4][5] = { {0, 0, 0, 0, 0 }, {0, 0, 0, 0, 0 }, diff --git a/tests/ut/cpp/dataset/random_data_op_test.cc b/tests/ut/cpp/dataset/random_data_op_test.cc index 3cb7b57ad6..ac1a5013fe 100644 --- a/tests/ut/cpp/dataset/random_data_op_test.cc +++ b/tests/ut/cpp/dataset/random_data_op_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ class MindDataTestRandomDataOp : public UT::DatasetOpTesting { // Test info: // - Simple test with a user-provided schema generated purely from DataSchema C API -// - has an interation loop +// - has an interaction loop // // Tree: single node tree with RandomDataOp // @@ -213,7 +213,7 @@ TEST_F(MindDataTestRandomDataOp, RandomDataOpBasic3) { // Test info: // - json schema input it's a fairly simple one -// - has an interation loop +// - has an interaction loop // // Tree: RepeatOp over RandomDataOp // @@ -253,6 +253,8 @@ TEST_F(MindDataTestRandomDataOp, RandomDataOpBasic4) { rc = myTree->AssociateNode(myRepeatOp); EXPECT_TRUE(rc.IsOk()); + myRandomDataOp->set_total_repeats(numRepeats); + myRandomDataOp->set_num_repeats_per_epoch(numRepeats); rc = myRepeatOp->AddChild(myRandomDataOp); EXPECT_TRUE(rc.IsOk()); @@ -290,7 +292,7 @@ TEST_F(MindDataTestRandomDataOp, RandomDataOpBasic4) { // Test info: // - json schema input it's a fairly simple one -// - has an interation loop +// - has an interaction loop // - same as MindDataTestRandomDataOpBasic4 except that this one will have parallel workers // // Tree: RepeatOp over RandomDataOp @@ -331,6 +333,8 @@ TEST_F(MindDataTestRandomDataOp, RandomDataOpBasic5) { rc = myTree->AssociateNode(myRepeatOp); EXPECT_TRUE(rc.IsOk()); + myRandomDataOp->set_total_repeats(numRepeats); + myRandomDataOp->set_num_repeats_per_epoch(numRepeats); rc = myRepeatOp->AddChild(myRandomDataOp); EXPECT_TRUE(rc.IsOk()); @@ -418,9 +422,13 @@ TEST_F(MindDataTestRandomDataOp, RandomDataOpTree1) { rc = myTree->AssociateNode(myRepeatOp); EXPECT_TRUE(rc.IsOk()); + myShuffleOp->set_total_repeats(numRepeats); + myShuffleOp->set_num_repeats_per_epoch(numRepeats); rc = myRepeatOp->AddChild(myShuffleOp); EXPECT_TRUE(rc.IsOk()); - + + myRandomDataOp->set_total_repeats(numRepeats); + myRandomDataOp->set_num_repeats_per_epoch(numRepeats); rc = myShuffleOp->AddChild(myRandomDataOp); EXPECT_TRUE(rc.IsOk()); diff --git a/tests/ut/cpp/dataset/sentence_piece_vocab_op_test.cc b/tests/ut/cpp/dataset/sentence_piece_vocab_op_test.cc index 19f7291079..6db89eb7b7 100644 --- a/tests/ut/cpp/dataset/sentence_piece_vocab_op_test.cc +++ b/tests/ut/cpp/dataset/sentence_piece_vocab_op_test.cc @@ -75,6 +75,8 @@ TEST_F(MindDataTestSentencePieceVocabOp, TestSentencePieceFromDatasetFuntions) { rc = spv_op->AddChild(file_op); ASSERT_TRUE(rc.IsOk()); + file_op->set_total_repeats(1); + file_op->set_num_repeats_per_epoch(1); rc = tree->AssignRoot(spv_op); ASSERT_TRUE(rc.IsOk()); rc = tree->Prepare(); @@ -147,6 +149,8 @@ TEST_F(MindDataTestSentencePieceVocabOp, TestSentencePieceTokenizerFuntions) { rc = spv_op->AddChild(file_op); ASSERT_TRUE(rc.IsOk()); + file_op->set_total_repeats(1); + file_op->set_num_repeats_per_epoch(1); rc = tree->AssignRoot(spv_op); ASSERT_TRUE(rc.IsOk()); rc = tree->Prepare(); diff --git a/tests/ut/cpp/dataset/shuffle_op_test.cc b/tests/ut/cpp/dataset/shuffle_op_test.cc index 45d2d7f608..6bde46a90e 100644 --- a/tests/ut/cpp/dataset/shuffle_op_test.cc +++ b/tests/ut/cpp/dataset/shuffle_op_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -300,8 +300,12 @@ TEST_F(MindDataTestShuffleOp, TestRepeatShuffle) { EXPECT_TRUE(rc.IsOk()); // Set children/root layout. + my_shuffle_op->set_total_repeats(numRepeats); + my_shuffle_op->set_num_repeats_per_epoch(numRepeats); rc = my_repeat_op->AddChild(my_shuffle_op); EXPECT_TRUE(rc.IsOk()); + my_tfreader_op->set_total_repeats(numRepeats); + my_tfreader_op->set_num_repeats_per_epoch(numRepeats); rc = my_shuffle_op->AddChild(my_tfreader_op); EXPECT_TRUE(rc.IsOk()); rc = my_tree->AssignRoot(my_repeat_op); diff --git a/tests/ut/cpp/dataset/tfReader_op_test.cc b/tests/ut/cpp/dataset/tfReader_op_test.cc index 9f0919cd96..5577a51e05 100644 --- a/tests/ut/cpp/dataset/tfReader_op_test.cc +++ b/tests/ut/cpp/dataset/tfReader_op_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,6 @@ #include "minddata/dataset/core/client.h" #include "minddata/dataset/engine/data_schema.h" #include "common/common.h" -#include "utils/ms_utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" @@ -330,11 +329,14 @@ TEST_F(MindDataTestTFReaderOp, TestTFReaderRepeat) { ASSERT_TRUE(rc.IsOk()); // RepeatOp - std::shared_ptr my_repeat_op = std::make_shared(3); + uint32_t num_repeats = 3; + std::shared_ptr my_repeat_op = std::make_shared(num_repeats); rc = my_tree->AssociateNode(my_repeat_op); ASSERT_TRUE(rc.IsOk()); // Set children/root layout. + my_tfreader_op->set_total_repeats(num_repeats); + my_tfreader_op->set_num_repeats_per_epoch(num_repeats); rc = my_repeat_op->AddChild(my_tfreader_op); ASSERT_TRUE(rc.IsOk()); rc = my_tree->AssignRoot(my_repeat_op); @@ -705,7 +707,7 @@ TEST_F(MindDataTestTFReaderOp, TestTFReaderInvalidFiles) { std::string valid_file = datasets_root_path_ + "/testTFTestAllTypes/test.data"; std::string schema_file = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json"; std::string invalid_file = datasets_root_path_ + "/testTFTestAllTypes/invalidFile.txt"; - std::string nonexistent_file = "this/file/doesnt/exist"; + std::string nonexistent_file = "this/file/not/exist"; std::shared_ptr my_tfreader_op; TFReaderOp::Builder builder; diff --git a/tests/ut/cpp/dataset/voc_op_test.cc b/tests/ut/cpp/dataset/voc_op_test.cc index e58b07b35c..2bafbddf6d 100644 --- a/tests/ut/cpp/dataset/voc_op_test.cc +++ b/tests/ut/cpp/dataset/voc_op_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -45,8 +45,6 @@ using mindspore::LogStream; std::shared_ptr Batch(int batch_size = 1, bool drop = false, int rows_per_buf = 2); -std::shared_ptr Repeat(int repeat_cnt); - std::shared_ptr Build(std::vector> ops); class MindDataTestVOCOp : public UT::DatasetOpTesting { diff --git a/tests/ut/cpp/dataset/zip_op_test.cc b/tests/ut/cpp/dataset/zip_op_test.cc index b55578f672..eee1457048 100644 --- a/tests/ut/cpp/dataset/zip_op_test.cc +++ b/tests/ut/cpp/dataset/zip_op_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -141,6 +141,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpRepeat) { MS_LOG(INFO) << "UT test TestZipRepeat."; auto my_tree = std::make_shared(); + uint32_t num_repeats = 3; std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data"; std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data"; std::shared_ptr my_tfreader_op; @@ -169,17 +170,23 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpRepeat) { EXPECT_TRUE(rc.IsOk()); rc = my_tree->AssociateNode(zip_op); EXPECT_TRUE(rc.IsOk()); + my_tfreader_op->set_total_repeats(num_repeats); + my_tfreader_op->set_num_repeats_per_epoch(num_repeats); rc = zip_op->AddChild(std::move(my_tfreader_op)); EXPECT_TRUE(rc.IsOk()); + my_tfreader_op2->set_total_repeats(num_repeats); + my_tfreader_op2->set_num_repeats_per_epoch(num_repeats); rc = zip_op->AddChild(std::move(my_tfreader_op2)); EXPECT_TRUE(rc.IsOk()); // Builder(num_of_repeats) std::shared_ptr my_repeat_op; - rc = RepeatOp::Builder(3).Build(&my_repeat_op); + rc = RepeatOp::Builder(num_repeats).Build(&my_repeat_op); EXPECT_TRUE(rc.IsOk()); rc = my_tree->AssociateNode(my_repeat_op); EXPECT_TRUE(rc.IsOk()); + zip_op->set_total_repeats(num_repeats); + zip_op->set_num_repeats_per_epoch(num_repeats); rc = my_repeat_op->AddChild(zip_op); EXPECT_TRUE(rc.IsOk()); rc = my_tree->AssignRoot(my_repeat_op); diff --git a/tests/ut/python/dataset/test_cache_map.py b/tests/ut/python/dataset/test_cache_map.py index 3f03d6a065..dedef3a349 100644 --- a/tests/ut/python/dataset/test_cache_map.py +++ b/tests/ut/python/dataset/test_cache_map.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -138,7 +138,7 @@ def test_cache_map_basic3(): @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_map_basic4(): """ - Test Map with non-deterministic TensorOps above cache + Test Map containing random operation above cache repeat | @@ -374,7 +374,7 @@ def test_cache_map_failure4(): @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_map_failure5(): """ - Test Map with non-deterministic TensorOps under cache (failure) + Test Map containing random operation under cache (failure) repeat | @@ -406,7 +406,7 @@ def test_cache_map_failure5(): num_iter = 0 for _ in data.create_dict_iterator(): num_iter += 1 - assert "MapNode with non-deterministic operations is not supported as a descendant of cache" in str(e.value) + assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value) assert num_iter == 0 logger.info('test_cache_failure5 Ended.\n') diff --git a/tests/ut/python/dataset/test_cache_nomap.py b/tests/ut/python/dataset/test_cache_nomap.py index 5316b37aff..915e14bdd8 100644 --- a/tests/ut/python/dataset/test_cache_nomap.py +++ b/tests/ut/python/dataset/test_cache_nomap.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -2087,7 +2087,7 @@ def test_cache_nomap_failure4(): @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_nomap_failure5(): """ - Test Map with non-deterministic TensorOps under cache (failure) + Test Map containing random operation under cache (failure) repeat | @@ -2118,7 +2118,7 @@ def test_cache_nomap_failure5(): num_iter = 0 for _ in data.create_dict_iterator(): num_iter += 1 - assert "MapNode with non-deterministic operations is not supported as a descendant of cache" in str(e.value) + assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value) assert num_iter == 0 logger.info('test_cache_nomap_failure5 Ended.\n')