diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 29487b6c7d..4d3fe4f7e6 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -568,8 +568,8 @@ std::shared_ptr Dataset::BuildSentencePieceVocab( const std::vector &col_names, uint32_t vocab_size, float character_coverage, SentencePieceModel model_type, const std::unordered_map ¶ms) { auto vocab = std::make_shared(); - auto ds = std::make_shared(IRNode(), vocab, col_names, vocab_size, character_coverage, - model_type, params); + auto ds = std::make_shared(IRNode()->DeepCopy(), vocab, col_names, vocab_size, + character_coverage, model_type, params); std::unique_ptr runtime_context = std::make_unique(); Status rc = runtime_context->Init(); @@ -600,8 +600,8 @@ std::shared_ptr Dataset::BuildVocab(const std::vector &colum const std::pair &freq_range, int64_t top_k, const std::vector &special_tokens, bool special_first) { auto vocab = std::make_shared(); - auto ds = - std::make_shared(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first); + auto ds = std::make_shared(IRNode()->DeepCopy(), vocab, columns, freq_range, top_k, special_tokens, + special_first); std::unique_ptr runtime_context = std::make_unique(); Status rc = runtime_context->Init(); diff --git a/mindspore/ccsrc/minddata/dataset/api/samplers.cc b/mindspore/ccsrc/minddata/dataset/api/samplers.cc index 44e6f6869d..98a1dc219e 100644 --- a/mindspore/ccsrc/minddata/dataset/api/samplers.cc +++ b/mindspore/ccsrc/minddata/dataset/api/samplers.cc @@ -190,13 +190,12 @@ std::shared_ptr PKSamplerObj::Build() { return sampler; } -#ifndef ENABLE_ANDROID // PreBuiltOperation -PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr sampler) - : sp_(std::move(sampler)), sp_minddataset_(nullptr) {} +PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr sampler) : sp_(std::move(sampler)) {} +#ifndef ENABLE_ANDROID PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr sampler) - : sp_(nullptr), sp_minddataset_(std::move(sampler)) {} + : sp_minddataset_(std::move(sampler)) {} #endif bool PreBuiltSamplerObj::ValidateParams() { return true; } @@ -207,6 +206,13 @@ std::shared_ptr PreBuiltSamplerObj::Build() { return sp_; } std::shared_ptr PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; } #endif +std::shared_ptr PreBuiltSamplerObj::Copy() { +#ifndef ENABLE_ANDROID + if (sp_minddataset_ != nullptr) return std::make_shared(sp_minddataset_); +#endif + return std::make_shared(sp_); +} + #ifndef ENABLE_ANDROID std::shared_ptr PKSamplerObj::BuildForMindDataset() { // runtime mindrecord sampler object diff --git a/mindspore/ccsrc/minddata/dataset/api/transforms.cc b/mindspore/ccsrc/minddata/dataset/api/transforms.cc index 5c163bd74a..4b41e85b9d 100644 --- a/mindspore/ccsrc/minddata/dataset/api/transforms.cc +++ b/mindspore/ccsrc/minddata/dataset/api/transforms.cc @@ -30,8 +30,6 @@ namespace mindspore { namespace dataset { -TensorOperation::TensorOperation() {} - /* ####################################### Validator Functions ############################################ */ Status ValidateVectorFillvalue(const std::string &transform_name, const std::vector &fill_value) { if (fill_value.empty() || (fill_value.size() != 1 && fill_value.size() != 3)) { @@ -231,7 +229,7 @@ std::shared_ptr PreBuiltOperation::Build() { return op_; } // RandomApplyOperation RandomApplyOperation::RandomApplyOperation(const std::vector> &transforms, double prob) - : transforms_(transforms), prob_(prob) {} + : TensorOperation(true), transforms_(transforms), prob_(prob) {} Status RandomApplyOperation::ValidateParams() { RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomApply", transforms_)); @@ -248,7 +246,7 @@ std::shared_ptr RandomApplyOperation::Build() { // RandomChoiceOperation RandomChoiceOperation::RandomChoiceOperation(const std::vector> &transforms) - : transforms_(transforms) {} + : TensorOperation(true), transforms_(transforms) {} Status RandomChoiceOperation::ValidateParams() { RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomChoice", transforms_)); diff --git a/mindspore/ccsrc/minddata/dataset/api/vision.cc b/mindspore/ccsrc/minddata/dataset/api/vision.cc index fc99b80c01..3083672f59 100644 --- a/mindspore/ccsrc/minddata/dataset/api/vision.cc +++ b/mindspore/ccsrc/minddata/dataset/api/vision.cc @@ -734,7 +734,9 @@ RandomAffineOperation::RandomAffineOperation(const std::vector °rees scale_range_(scale_range), shear_ranges_(shear_ranges), interpolation_(interpolation), - fill_value_(fill_value) {} + fill_value_(fill_value) { + random_op_ = true; +} Status RandomAffineOperation::ValidateParams() { // Degrees @@ -867,7 +869,7 @@ std::shared_ptr RandomAffineOperation::Build() { } // RandomColorOperation. -RandomColorOperation::RandomColorOperation(float t_lb, float t_ub) : t_lb_(t_lb), t_ub_(t_ub) {} +RandomColorOperation::RandomColorOperation(float t_lb, float t_ub) : t_lb_(t_lb), t_ub_(t_ub) { random_op_ = true; } Status RandomColorOperation::ValidateParams() { // Do some input validation. @@ -891,7 +893,9 @@ Status RandomColorOperation::ValidateParams() { // RandomColorAdjustOperation. RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector brightness, std::vector contrast, std::vector saturation, std::vector hue) - : brightness_(brightness), contrast_(contrast), saturation_(saturation), hue_(hue) {} + : brightness_(brightness), contrast_(contrast), saturation_(saturation), hue_(hue) { + random_op_ = true; +} Status RandomColorAdjustOperation::ValidateParams() { // brightness @@ -1012,11 +1016,14 @@ std::shared_ptr RandomColorAdjustOperation::Build() { // RandomCropOperation RandomCropOperation::RandomCropOperation(std::vector size, std::vector padding, bool pad_if_needed, std::vector fill_value, BorderType padding_mode) - : size_(size), + : TensorOperation(true), + size_(size), padding_(padding), pad_if_needed_(pad_if_needed), fill_value_(fill_value), - padding_mode_(padding_mode) {} + padding_mode_(padding_mode) { + random_op_ = true; +} Status RandomCropOperation::ValidateParams() { // size @@ -1083,7 +1090,12 @@ std::shared_ptr RandomCropOperation::Build() { RandomCropDecodeResizeOperation::RandomCropDecodeResizeOperation(std::vector size, std::vector scale, std::vector ratio, InterpolationMode interpolation, int32_t max_attempts) - : size_(size), scale_(scale), ratio_(ratio), interpolation_(interpolation), max_attempts_(max_attempts) {} + : TensorOperation(true), + size_(size), + scale_(scale), + ratio_(ratio), + interpolation_(interpolation), + max_attempts_(max_attempts) {} Status RandomCropDecodeResizeOperation::ValidateParams() { // size @@ -1176,7 +1188,8 @@ std::shared_ptr RandomCropDecodeResizeOperation::Build() { RandomCropWithBBoxOperation::RandomCropWithBBoxOperation(std::vector size, std::vector padding, bool pad_if_needed, std::vector fill_value, BorderType padding_mode) - : size_(size), + : TensorOperation(true), + size_(size), padding_(padding), pad_if_needed_(pad_if_needed), fill_value_(fill_value), @@ -1245,7 +1258,8 @@ std::shared_ptr RandomCropWithBBoxOperation::Build() { } // RandomHorizontalFlipOperation -RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability) : probability_(probability) {} +RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability) + : TensorOperation(true), probability_(probability) {} Status RandomHorizontalFlipOperation::ValidateParams() { RETURN_IF_NOT_OK(ValidateProbability("RandomHorizontalFlip", probability_)); @@ -1260,7 +1274,7 @@ std::shared_ptr RandomHorizontalFlipOperation::Build() { // RandomHorizontalFlipWithBBoxOperation RandomHorizontalFlipWithBBoxOperation::RandomHorizontalFlipWithBBoxOperation(float probability) - : probability_(probability) {} + : TensorOperation(true), probability_(probability) {} Status RandomHorizontalFlipWithBBoxOperation::ValidateParams() { RETURN_IF_NOT_OK(ValidateProbability("RandomHorizontalFlipWithBBox", probability_)); @@ -1275,7 +1289,8 @@ std::shared_ptr RandomHorizontalFlipWithBBoxOperation::Build() { } // RandomPosterizeOperation -RandomPosterizeOperation::RandomPosterizeOperation(const std::vector &bit_range) : bit_range_(bit_range) {} +RandomPosterizeOperation::RandomPosterizeOperation(const std::vector &bit_range) + : TensorOperation(true), bit_range_(bit_range) {} Status RandomPosterizeOperation::ValidateParams() { if (bit_range_.size() != 2) { @@ -1309,7 +1324,7 @@ std::shared_ptr RandomPosterizeOperation::Build() { } // RandomResizeOperation -RandomResizeOperation::RandomResizeOperation(std::vector size) : size_(size) {} +RandomResizeOperation::RandomResizeOperation(std::vector size) : TensorOperation(true), size_(size) {} Status RandomResizeOperation::ValidateParams() { // size @@ -1343,7 +1358,8 @@ std::shared_ptr RandomResizeOperation::Build() { } // RandomResizeWithBBoxOperation -RandomResizeWithBBoxOperation::RandomResizeWithBBoxOperation(std::vector size) : size_(size) {} +RandomResizeWithBBoxOperation::RandomResizeWithBBoxOperation(std::vector size) + : TensorOperation(true), size_(size) {} Status RandomResizeWithBBoxOperation::ValidateParams() { // size @@ -1380,7 +1396,12 @@ std::shared_ptr RandomResizeWithBBoxOperation::Build() { RandomResizedCropOperation::RandomResizedCropOperation(std::vector size, std::vector scale, std::vector ratio, InterpolationMode interpolation, int32_t max_attempts) - : size_(size), scale_(scale), ratio_(ratio), interpolation_(interpolation), max_attempts_(max_attempts) {} + : TensorOperation(true), + size_(size), + scale_(scale), + ratio_(ratio), + interpolation_(interpolation), + max_attempts_(max_attempts) {} Status RandomResizedCropOperation::ValidateParams() { // size @@ -1536,7 +1557,8 @@ std::shared_ptr RandomResizedCropWithBBoxOperation::Build() { RandomRotationOperation::RandomRotationOperation(std::vector degrees, InterpolationMode interpolation_mode, bool expand, std::vector center, std::vector fill_value) - : degrees_(degrees), + : TensorOperation(true), + degrees_(degrees), interpolation_mode_(interpolation_mode), expand_(expand), center_(center), @@ -1603,7 +1625,7 @@ std::shared_ptr RandomRotationOperation::Build() { // RandomSelectSubpolicyOperation. RandomSelectSubpolicyOperation::RandomSelectSubpolicyOperation( std::vector, double>>> policy) - : policy_(policy) {} + : TensorOperation(true), policy_(policy) {} Status RandomSelectSubpolicyOperation::ValidateParams() { if (policy_.empty()) { @@ -1650,7 +1672,8 @@ std::shared_ptr RandomSelectSubpolicyOperation::Build() { } // Function to create RandomSharpness. -RandomSharpnessOperation::RandomSharpnessOperation(std::vector degrees) : degrees_(degrees) {} +RandomSharpnessOperation::RandomSharpnessOperation(std::vector degrees) + : TensorOperation(true), degrees_(degrees) {} Status RandomSharpnessOperation::ValidateParams() { if (degrees_.size() != 2 || degrees_[0] < 0 || degrees_[1] < 0) { @@ -1674,7 +1697,8 @@ std::shared_ptr RandomSharpnessOperation::Build() { } // RandomSolarizeOperation. -RandomSolarizeOperation::RandomSolarizeOperation(std::vector threshold) : threshold_(threshold) {} +RandomSolarizeOperation::RandomSolarizeOperation(std::vector threshold) + : TensorOperation(true), threshold_(threshold) {} Status RandomSolarizeOperation::ValidateParams() { if (threshold_.size() != 2) { @@ -1705,7 +1729,8 @@ std::shared_ptr RandomSolarizeOperation::Build() { } // RandomVerticalFlipOperation -RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability) : probability_(probability) {} +RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability) + : TensorOperation(true), probability_(probability) {} Status RandomVerticalFlipOperation::ValidateParams() { RETURN_IF_NOT_OK(ValidateProbability("RandomVerticalFlip", probability_)); @@ -1720,7 +1745,7 @@ std::shared_ptr RandomVerticalFlipOperation::Build() { // RandomVerticalFlipWithBBoxOperation RandomVerticalFlipWithBBoxOperation::RandomVerticalFlipWithBBoxOperation(float probability) - : probability_(probability) {} + : TensorOperation(true), probability_(probability) {} Status RandomVerticalFlipWithBBoxOperation::ValidateParams() { RETURN_IF_NOT_OK(ValidateProbability("RandomVerticalFlipWithBBox", probability_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt index 18f3bcce12..3974ab2a1b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt @@ -9,11 +9,13 @@ set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES build_sentence_piece_vocab_node.cc build_vocab_node.cc concat_node.cc + epoch_ctrl_node.cc filter_node.cc map_node.cc project_node.cc rename_node.cc repeat_node.cc + root_node.cc shuffle_node.cc skip_node.cc sync_wait_node.cc diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc index b0e716e649..ee67b21ca7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc @@ -43,14 +43,29 @@ BatchNode::BatchNode(std::shared_ptr child, int32_t batch_size, boo batch_size_func_(batch_size_func), batch_map_func_(batch_map_func), pad_map_(pad_map) { - this->children.push_back(child); + this->AddChild(child); } #endif // constructor #2, called by C++ API BatchNode::BatchNode(std::shared_ptr child, int32_t batch_size, bool drop_remainder) : batch_size_(batch_size), drop_remainder_(drop_remainder), pad_(false) { - this->children.push_back(child); + this->AddChild(child); +} + +std::shared_ptr BatchNode::Copy() { +#ifdef ENABLE_PYTHON + auto node = std::make_shared(nullptr, batch_size_, drop_remainder_, pad_, in_col_names_, out_col_names_, + col_order_, batch_size_func_, batch_map_func_, pad_map_); +#else + auto node = std::make_shared(nullptr, batch_size_, drop_remainder_); +#endif + return node; +} + +void BatchNode::Print(std::ostream &out) const { + out << Name() + "(batch_size:" + std::to_string(batch_size_) + + " drop_remainder:" + (drop_remainder_ ? "true" : "false") + ")"; } Status BatchNode::ValidateParams() { diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h index 0e4d693b46..9bb1802a4e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h @@ -44,6 +44,18 @@ class BatchNode : public DatasetNode { /// \brief Destructor ~BatchNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kBatchNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc index c038a183cc..1cdb6cbd26 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc @@ -41,7 +41,17 @@ BucketBatchByLengthNode::BucketBatchByLengthNode( pad_info_(pad_info), pad_to_bucket_boundary_(pad_to_bucket_boundary), drop_remainder_(drop_remainder) { - this->children.push_back(child); + this->AddChild(child); +} + +std::shared_ptr BucketBatchByLengthNode::Copy() { + auto node = std::make_shared(nullptr, column_names_, bucket_boundaries_, bucket_batch_sizes_, + element_length_function_, pad_info_, pad_to_bucket_boundary_); + return node; +} + +void BucketBatchByLengthNode::Print(std::ostream &out) const { + out << Name() + "(columns:" + PrintColumns(column_names_) + ",...)"; } std::vector> BucketBatchByLengthNode::Build() { diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h index cffaabe923..1cdf46cd6e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h @@ -40,6 +40,18 @@ class BucketBatchByLengthNode : public DatasetNode { /// \brief Destructor ~BucketBatchByLengthNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kBucketBatchByLengthNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc index eaf8378639..fae47f00ba 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc @@ -22,6 +22,7 @@ #include #include "minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h" +#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/util/status.h" namespace mindspore { @@ -38,7 +39,18 @@ BuildSentenceVocabNode::BuildSentenceVocabNode(std::shared_ptr chil character_coverage_(character_coverage), model_type_(model_type), params_(params) { - this->children.push_back(child); + this->AddChild(child); +} + +std::shared_ptr BuildSentenceVocabNode::Copy() { + auto node = std::make_shared(nullptr, vocab_, col_names_, vocab_size_, character_coverage_, + model_type_, params_); + return node; +} + +void BuildSentenceVocabNode::Print(std::ostream &out) const { + out << Name() + "," + "columns:" + PrintColumns(col_names_) + ",vocab_size:" + std::to_string(vocab_size_) + + ",...)"; } // Function to build BuildSentenceVocabNode @@ -81,5 +93,16 @@ Status BuildSentenceVocabNode::ValidateParams() { return Status::OK(); } +// Visitor accepting method for NodePass +Status BuildSentenceVocabNode::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->Visit(shared_from_base(), modified); +} + +// Visitor accepting method for NodePass +Status BuildSentenceVocabNode::AcceptAfter(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->VisitAfter(shared_from_base(), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h index 01b36a8e6f..65954d7b3f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h @@ -38,6 +38,18 @@ class BuildSentenceVocabNode : public DatasetNode { /// \brief Destructor ~BuildSentenceVocabNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kBuildSentencePieceVocabNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; @@ -46,6 +58,18 @@ class BuildSentenceVocabNode : public DatasetNode { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Base-class override for accepting NodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + + /// \brief Base-class override for accepting NodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status AcceptAfter(NodePass *p, bool *modified) override; + private: std::shared_ptr vocab_; std::vector col_names_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc index 623eccb86a..ff85f8e2a4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc @@ -22,7 +22,7 @@ #include #include "minddata/dataset/engine/datasetops/build_vocab_op.h" - +#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { @@ -36,7 +36,17 @@ BuildVocabNode::BuildVocabNode(std::shared_ptr child, std::shared_p top_k_(top_k), special_tokens_(special_tokens), special_first_(special_first) { - this->children.push_back(child); + this->AddChild(child); +} + +std::shared_ptr BuildVocabNode::Copy() { + auto node = + std::make_shared(nullptr, vocab_, columns_, freq_range_, top_k_, special_tokens_, special_first_); + return node; +} + +void BuildVocabNode::Print(std::ostream &out) const { + out << Name() + "(," + "columns:" + PrintColumns(columns_) + ",...)"; } // Function to build BuildVocabNode @@ -78,5 +88,16 @@ Status BuildVocabNode::ValidateParams() { return Status::OK(); } +// Visitor accepting method for NodePass +Status BuildVocabNode::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->Visit(shared_from_base(), modified); +} + +// Visitor accepting method for NodePass +Status BuildVocabNode::AcceptAfter(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->VisitAfter(shared_from_base(), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h index 408115a4aa..27a86b0604 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h @@ -37,6 +37,18 @@ class BuildVocabNode : public DatasetNode { /// \brief Destructor ~BuildVocabNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kBuildVocabNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; @@ -45,6 +57,18 @@ class BuildVocabNode : public DatasetNode { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Base-class override for accepting NodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + + /// \brief Base-class override for accepting NodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status AcceptAfter(NodePass *p, bool *modified) override; + private: std::shared_ptr vocab_; std::vector columns_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc index cbf098249d..ae76a27be4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc @@ -22,7 +22,7 @@ #include #include "minddata/dataset/engine/datasetops/concat_op.h" - +#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { @@ -35,17 +35,25 @@ ConcatNode::ConcatNode(const std::vector> &datasets : sampler_(sampler), children_flag_and_nums_(children_flag_and_nums), children_start_end_index_(children_start_end_index) { - this->children = datasets; + for (auto const &child : datasets) AddChild(child); +} + +std::shared_ptr ConcatNode::Copy() { + // create an empty vector to copy a concat + auto node = std::make_shared(std::vector>()); + return node; } +void ConcatNode::Print(std::ostream &out) const { out << Name(); } + Status ConcatNode::ValidateParams() { - if (children.size() < 2) { + if (children_.size() < 2) { std::string err_msg = "ConcatNode: concatenated datasets are not specified."; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } - if (find(children.begin(), children.end(), nullptr) != children.end()) { + if (find(children_.begin(), children_.end(), nullptr) != children_.end()) { std::string err_msg = "ConcatNode: concatenated datasets should not be null."; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); @@ -73,5 +81,16 @@ std::vector> ConcatNode::Build() { return node_ops; } +// Visitor accepting method for NodePass +Status ConcatNode::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->Visit(shared_from_base(), modified); +} + +// Visitor accepting method for NodePass +Status ConcatNode::AcceptAfter(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->VisitAfter(shared_from_base(), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h index 53be272761..c542e46e2b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h @@ -38,6 +38,18 @@ class ConcatNode : public DatasetNode { /// \brief Destructor ~ConcatNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kConcatNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; @@ -50,6 +62,18 @@ class ConcatNode : public DatasetNode { std::shared_ptr sampler_; std::vector> children_flag_and_nums_; std::vector> children_start_end_index_; + + /// \brief Base-class override for accepting NodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + + /// \brief Base-class override for accepting NodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status AcceptAfter(NodePass *p, bool *modified) override; }; } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc index 47e49d0976..e92b3b3642 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc @@ -233,14 +233,92 @@ std::shared_ptr DatasetNode::SetNumWorkers(int32_t num_workers) { return shared_from_this(); } -DatasetNode::DatasetNode() { +DatasetNode::DatasetNode() : cache_(nullptr), parent_(nullptr), children_({}) { // Fetch some default value from config manager std::shared_ptr cfg = GlobalContext::config_manager(); num_workers_ = cfg->num_parallel_workers(); rows_per_buffer_ = cfg->rows_per_buffer(); connector_que_size_ = cfg->op_connector_size(); worker_connector_size_ = cfg->worker_connector_size(); - build_status = Status::OK(); // remove me after changing return val of Build() +} + +// this function will preform a deep copy of current node (and its descendants), the parent* pointer will not be copied +std::shared_ptr DatasetNode::DeepCopy() { + std::shared_ptr new_node = this->Copy(); + for (const auto &child : children_) { + new_node->AddChild(child->DeepCopy()); + } + return new_node; +} + +std::string DatasetNode::PrintColumns(const std::vector &columns) const { + std::string me; + if (columns.empty()) { + me = ""; + } else { + me = "["; + auto i = 0; + for (auto it = columns.begin(); it < columns.end(); ++it, ++i) { + me += *it; + if (i < columns.size() - 1) { + me += ", "; + } else { + me += "]"; + } + } + } + return me; +} + +void DatasetNode::PrintTree(std::ostream &out) const { + int level = 0; + PrintNode(out, &level); +} + +void DatasetNode::PrintNode(std::ostream &out, int *level) const { + const std::string prefix = "+-"; + const std::string indent = " "; + out << prefix; + Print(out); + for (const auto &c : this->Children()) { + out << '\n'; + ++(*level); + for (auto i = 0; i < *level; i++) { + out << indent; + } + c->PrintNode(out, level); + --(*level); + } +} + +// Add a node as a child, node's parent needs to be nullptr +// this function will allow child to be a nullptr, in which case it will simply skip +void DatasetNode::AddChild(std::shared_ptr child) { + if (child != nullptr && child->parent_ == nullptr) { + children_.push_back(child); + child->parent_ = this; + } else if (child != nullptr) { + MS_LOG(WARNING) << "DatasetNode::AddChild() Fail" + child->Name() + "'s parent isn't a nullptr."; + } +} + +// Remove this node from its parent. Add the child of this node to its parent. +// for now, this remove is limited to node with a single child or no child +Status DatasetNode::Remove() { + CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "Cannot remove root or a node without parent."); + CHECK_FAIL_RETURN_UNEXPECTED(children_.size() < 2, "Cannot remove node with more than 1 child."); + if (children_.empty()) { // I am a leaf node, remove me from my parent's children list + parent_->children_.erase(std::remove(parent_->children_.begin(), parent_->children_.end(), shared_from_this()), + parent_->children_.end()); // removal using "erase remove idiom" + } else { // replace my position in my parent's children list with my single child + auto itr = std::find(parent_->children_.begin(), parent_->children_.end(), shared_from_this()); + CHECK_FAIL_RETURN_UNEXPECTED(itr != parent_->children_.end(), "I am not in my parent's children list."); + children_[0]->parent_ = parent_; // set my single child's parent ptr to my parent + *itr = std::move(children_[0]); // replace me in my parent's children list with my single child + children_.clear(); // release my single child from my children list + } + parent_ = nullptr; + return Status::OK(); } // In DFS tree traversal, each node is visited twice. Accept is called on the first visit. @@ -255,13 +333,25 @@ Status DatasetNode::AcceptAfter(NodePass *p, bool *modified) { // This method will only be called if its derived class does not implement one. return p->VisitAfter(shared_from_this(), modified); } + Status DatasetNode::GetShardId(int32_t *shard_id) { if (!Children().empty()) { // Get shard id from the child node return Children()[0]->GetShardId(shard_id); } else { - RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node"); + RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node: " + Name() + "\n"); } } +// Visitor accepting method for NodePass +Status SourceNode::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->Visit(shared_from_base(), modified); +} + +// Visitor accepting method for NodePass +Status SourceNode::AcceptAfter(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->VisitAfter(shared_from_base(), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h index 46a9f0d4c4..57ed1de171 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h @@ -42,6 +42,45 @@ class NodePass; } \ } while (false) +// Names for non-leaf IR node +constexpr char kBatchNode[] = "Batch"; +constexpr char kBucketBatchByLengthNode[] = "BucketBatchByLength"; +constexpr char kBuildSentencePieceVocabNode[] = "BuildSentencePieceVocab"; +constexpr char kBuildVocabNode[] = "BuildVocab"; +constexpr char kConcatNode[] = "Concat"; +constexpr char kDatasetNode[] = "Dataset"; +constexpr char kEpochCtrlNode[] = "EpochCtrl"; +constexpr char kFilterNode[] = "Filter"; +constexpr char kMapNode[] = "Map"; +constexpr char kProjectNode[] = "Project"; +constexpr char kRenameNode[] = "Rename"; +constexpr char kRepeatNode[] = "Repeat"; +constexpr char kRootNode[] = "Top"; +constexpr char kShuffleNode[] = "Shuffle"; +constexpr char kSkipNode[] = "Skip"; +constexpr char kSyncWaitNode[] = "SyncWait"; +constexpr char kTakeNode[] = "Take"; +constexpr char kTransferNode[] = "Transfer"; +constexpr char kZipNode[] = "Zip"; + +// Names for leaf IR node +constexpr char kAlbumNode[] = "AlbumDataset"; +constexpr char kCelebANode[] = "CelebADataset"; +constexpr char kCifar100Node[] = "Cifar100Dataset"; +constexpr char kCifar10Node[] = "Cifar10Dataset"; +constexpr char kCLUENode[] = "CLUEDataset"; +constexpr char kCocoNode[] = "CocoDataset"; +constexpr char kCSVNode[] = "CSVDataset"; +constexpr char kGeneratorNode[] = "GeneratorDataset"; +constexpr char kImageFolderNode[] = "ImageFolderDataset"; +constexpr char kManifestNode[] = "ManifestDataset"; +constexpr char kMindDataNode[] = "MindDataDataset"; +constexpr char kMnistNode[] = "MnistDataset"; +constexpr char kRandomNode[] = "RandomDataset"; +constexpr char kTextFileNode[] = "TextFileDataset"; +constexpr char kTFRecordNode[] = "TFRecordDataset"; +constexpr char kVOCNode[] = "VOCDataset"; + Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr *shuffle_op); @@ -75,6 +114,7 @@ Status ValidateDatasetDirParam(const std::string &dataset_name, std::string data /// \return Shared pointer to the current Sampler. std::shared_ptr SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id); +// The base class of all IR nodes class DatasetNode : public std::enable_shared_from_this { public: /// \brief Constructor @@ -87,6 +127,36 @@ class DatasetNode : public std::enable_shared_from_this { /// \brief Destructor ~DatasetNode() = default; + /// \brief Node name getter + /// \return Name of the current node + virtual std::string Name() const = 0; + + /// \brief Pure virtual function to print the description + /// \param out - The output stream to write output to + virtual void Print(std::ostream &out) const = 0; + + /// \brief Pure virtual function to make a new copy of the node + /// \return The new copy of the node + virtual std::shared_ptr Copy() = 0; + + /// \brief Print the IR tree to output stream + /// \param out - The output stream to write output to + void PrintTree(std::ostream &out) const; + + /// \brief << Stream output operator overload + /// \notes This allows you to write the debug print info using stream operators + /// \param out - reference to the output stream being overloaded + /// \param dO - reference to the DatasetOp to display + /// \return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const DatasetNode &node) { + node.PrintTree(out); + return out; + } + + /// \brief Make a new copy of the tree from the current node + /// \return The new copy of the tree + std::shared_ptr DeepCopy(); + /// \brief Pure virtual function to convert a DatasetNode class into a runtime dataset object /// \return The list of shared pointers to the newly created DatasetOps virtual std::vector> Build() = 0; @@ -95,17 +165,38 @@ class DatasetNode : public std::enable_shared_from_this { /// \return Status Status::OK() if all the parameters are valid virtual Status ValidateParams() = 0; - const std::vector> Children() const { return children; } - /// \brief Pure virtual function for derived class to get the shard id of specific node /// \return Status Status::OK() if get shard id successfully virtual Status GetShardId(int32_t *shard_id); + /// \brief Getter function for child nodes + /// \return Child nodes + const std::vector> Children() const { return children_; } + + /// \brief Establish the parent-child relationship between this node and its child. + void AddChild(std::shared_ptr child); + + /// \brief detach this node from its parent, add its child (if any) to its parent + /// \return error code, return error if node has more than 1 children + Status Remove(); + + /// \brief Check if this node has cache + /// \return True if the data of this node will be cached + const bool IsCached() const { return (cache_ != nullptr); } + /// \brief Setter function for runtime number of workers /// \param[in] num_workers The number of threads in this operator /// \return Shared pointer to the original object std::shared_ptr SetNumWorkers(int32_t num_workers); + /// \brief A helper templated function for casting "this" pointer to shared_ptr + /// Similar to shared_from_this, except this one will give you the derived class as shared_ptr + /// \return A shared_ptr casted to the derived class + template + std::shared_ptr shared_from_base() { + return std::static_pointer_cast(shared_from_this()); + } + /// \brief Base method for NodePass visit. A tree walk consists of walking down the tree and also walking back up /// in a depth-first order. Accept is the node visit on the way down, whereas AcceptAfter is the node /// visit on the way back up the tree after its descendants are visited. @@ -129,17 +220,123 @@ class DatasetNode : public std::enable_shared_from_this { Status BuildStatus() { return build_status; } protected: - std::vector> children; + std::vector> children_; + DatasetNode *parent_; std::shared_ptr cache_; - Status AddCacheOp(std::vector> *node_ops); - int32_t num_workers_; int32_t rows_per_buffer_; int32_t connector_que_size_; int32_t worker_connector_size_; Status build_status; // remove me after changing return val of Build() + std::string PrintColumns(const std::vector &columns) const; + Status AddCacheOp(std::vector> *node_ops); + void PrintNode(std::ostream &out, int *level) const; }; +// SourceNode represents the leaf nodes of a pipeline where the data is pulled into. +class SourceNode : public DatasetNode { + public: + /// \brief Constructor + SourceNode() : DatasetNode() {} + + /// \brief Constructor that initializes the cache + /// \param dataset_cache DatasetCache + explicit SourceNode(const std::shared_ptr &dataset_cache) : DatasetNode(dataset_cache) {} + + /// \brief Destructor + ~SourceNode() = default; + + /// \brief Node name getter + /// \return Name of the current node + virtual std::string Name() const = 0; + + /// \brief Base-class override for accepting NodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + + /// \brief Base-class override for accepting NodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status AcceptAfter(NodePass *p, bool *modified) override; + + /// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes + /// \return True if the dataset represented by this node is a mappable dataset + const bool IsMappable() const { return mappable_; } + + protected: + bool mappable_; +}; + +// MappableSourceNode represents the leaf nodes that can be randomly accessed with indexes. +class MappableSourceNode : public SourceNode { + public: + /// \brief Constructor + MappableSourceNode() : SourceNode() { mappable_ = true; } + + /// \brief Constructor that initializes the cache + /// \param dataset_cache DatasetCache + explicit MappableSourceNode(const std::shared_ptr &dataset_cache) : SourceNode(dataset_cache) { + mappable_ = true; + } + + /// \brief Destructor + ~MappableSourceNode() = default; + + /// \brief Node name getter + /// \return Name of the current node + virtual std::string Name() const = 0; +}; + +// NonMappableSourceNode represents the leaf nodes that can not be randomly accessed. +class NonMappableSourceNode : public SourceNode { + public: + /// \brief Constructor + NonMappableSourceNode() : SourceNode() { mappable_ = false; } + + /// \brief Constructor that initializes the cache + /// \param dataset_cache DatasetCache + explicit NonMappableSourceNode(const std::shared_ptr &dataset_cache) : SourceNode(dataset_cache) { + mappable_ = false; + } + + /// \brief Destructor + ~NonMappableSourceNode() = default; + + /// \brief Node name getter + /// \return Name of the current node + virtual std::string Name() const = 0; +}; + +// NonLeafNode represents operations over data in a pipeline. +class NonLeafNode : public DatasetNode { + public: + /// \brief Constructor + NonLeafNode() = default; + + /// \brief Destructor + ~NonLeafNode() = default; + + /// \brief Node name getter + /// \return Name of the current node + virtual std::string Name() const = 0; +}; + +// SinkNode represents the end node of a pipeline where the data is pushed out +class SinkNode : public DatasetNode { + public: + /// \brief Constructor + SinkNode() = default; + + /// \brief Destructor + ~SinkNode() = default; + + /// \brief Node name getter + /// \return Name of the current node + virtual std::string Name() const = 0; +}; } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.cc new file mode 100644 index 0000000000..705c8b59cc --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h" + +#include +#include +#include + +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +// Constructor for EpochCtrlNode +EpochCtrlNode::EpochCtrlNode(std::shared_ptr child, int32_t num_epochs) : num_epochs_(num_epochs) { + // The root node's parent must set to null pointer. + this->AddChild(child); +} +std::shared_ptr EpochCtrlNode::Copy() { + auto node = std::make_shared(nullptr, this->num_epochs_); + return node; +} + +void EpochCtrlNode::Print(std::ostream &out) const { out << Name() + "(epoch:" + std::to_string(num_epochs_) + ")"; } + +// Function to build the EpochCtrlOp +std::vector> EpochCtrlNode::Build() { + // A dummy vector + std::vector> node_ops; + node_ops.push_back(std::make_shared(num_epochs_)); + return node_ops; +} + +// Function to validate the parameters for EpochCtrlNode +Status EpochCtrlNode::ValidateParams() { + if (num_epochs_ <= 0 && num_epochs_ != -1) { + std::string err_msg = + "EpochCtrlNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(num_epochs_); + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + if (children_.size() != 1 || children_[0] == nullptr) { + std::string err_msg = "Internal error: epoch control node should have one child node"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h new file mode 100644 index 0000000000..8964a9f9fc --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_EPOCH_CTRL_NODE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_EPOCH_CTRL_NODE_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" + +namespace mindspore { +namespace dataset { + +class EpochCtrlNode : public DatasetNode { + public: + /// \brief Constructor + explicit EpochCtrlNode(std::shared_ptr child, int32_t num_epochs); + + /// \brief Destructor + ~EpochCtrlNode() = default; + + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kEpochCtrlNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return shared pointer to the list of newly created DatasetOps + std::vector> Build() override; + + /// \brief Parameters validation + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; + + private: + int32_t num_epochs_; +}; + +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_EPOCH_CTRL_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc index 9bfa61863d..a357740dca 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc @@ -21,7 +21,7 @@ #include #include "minddata/dataset/engine/datasetops/filter_op.h" - +#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/util/status.h" namespace mindspore { @@ -31,7 +31,16 @@ namespace dataset { FilterNode::FilterNode(std::shared_ptr child, std::shared_ptr predicate, std::vector input_columns) : predicate_(predicate), input_columns_(input_columns) { - this->children.push_back(child); + this->AddChild(child); +} + +std::shared_ptr FilterNode::Copy() { + auto node = std::make_shared(nullptr, predicate_, input_columns_); + return node; +} + +void FilterNode::Print(std::ostream &out) const { + out << Name() + "(," + "input_cols:" + PrintColumns(input_columns_) + ")"; } std::vector> FilterNode::Build() { @@ -54,5 +63,17 @@ Status FilterNode::ValidateParams() { return Status::OK(); } +// Visitor accepting method for NodePass +Status FilterNode::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->Visit(shared_from_base(), modified); +} + +// Visitor accepting method for NodePass +Status FilterNode::AcceptAfter(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->VisitAfter(shared_from_base(), modified); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.h index fcc8bf9885..7e66168c61 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.h @@ -35,6 +35,18 @@ class FilterNode : public DatasetNode { /// \brief Destructor ~FilterNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kFilterNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; @@ -43,6 +55,18 @@ class FilterNode : public DatasetNode { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Base-class override for accepting NodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + + /// \brief Base-class override for accepting NodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status AcceptAfter(NodePass *p, bool *modified) override; + private: std::shared_ptr predicate_; std::vector input_columns_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc index 56e03e2f33..31873d07a1 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc @@ -22,6 +22,7 @@ #include #include "minddata/dataset/engine/datasetops/map_op/map_op.h" +#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/include/transforms.h" #include "minddata/dataset/util/status.h" namespace mindspore { @@ -37,7 +38,18 @@ MapNode::MapNode(std::shared_ptr child, std::vectorchildren.push_back(child); + this->AddChild(child); +} + +std::shared_ptr MapNode::Copy() { + auto node = std::make_shared(nullptr, operations_, input_columns_, output_columns_, project_columns_, cache_, + callbacks_); + return node; +} + +void MapNode::Print(std::ostream &out) const { + out << Name() + "(" + ",input:" + PrintColumns(input_columns_) + ",output:" + PrintColumns(output_columns_) + + "," + ",...)"; } std::vector> MapNode::Build() { @@ -93,5 +105,16 @@ Status MapNode::ValidateParams() { return Status::OK(); } +// Visitor accepting method for NodePass +Status MapNode::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->Visit(shared_from_base(), modified); +} + +// Visitor accepting method for NodePass +Status MapNode::AcceptAfter(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->VisitAfter(shared_from_base(), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h index 101e73382a..20e63ae288 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h @@ -37,6 +37,18 @@ class MapNode : public DatasetNode { /// \brief Destructor ~MapNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kMapNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; @@ -45,6 +57,23 @@ class MapNode : public DatasetNode { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Getter of tensor operations + /// \return Vector of operations the Map node will process + const auto &TensorOperations() const { return operations_; } + auto &TensorOperations() { return operations_; } + + /// \brief Base-class override for accepting NodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + + /// \brief Base-class override for accepting NodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status AcceptAfter(NodePass *p, bool *modified) override; + private: std::vector> operations_; std::vector input_columns_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc index 3ea08f2b05..caafe9b1da 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc @@ -29,9 +29,16 @@ namespace dataset { // Function to build ProjectOp ProjectNode::ProjectNode(std::shared_ptr child, const std::vector &columns) : columns_(columns) { - this->children.push_back(child); + this->AddChild(child); } +std::shared_ptr ProjectNode::Copy() { + auto node = std::make_shared(nullptr, this->columns_); + return node; +} + +void ProjectNode::Print(std::ostream &out) const { out << Name() + "(column: " + PrintColumns(columns_) + ")"; } + Status ProjectNode::ValidateParams() { if (columns_.empty()) { std::string err_msg = "ProjectNode: No columns are specified."; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h index e90f6d68d9..1c7c1e69fc 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h @@ -34,6 +34,18 @@ class ProjectNode : public DatasetNode { /// \brief Destructor ~ProjectNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kProjectNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc index 8761102356..03bea5a0e0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc @@ -30,7 +30,16 @@ namespace dataset { RenameNode::RenameNode(std::shared_ptr child, const std::vector &input_columns, const std::vector &output_columns) : input_columns_(input_columns), output_columns_(output_columns) { - this->children.push_back(child); + this->AddChild(child); +} + +std::shared_ptr RenameNode::Copy() { + auto node = std::make_shared(nullptr, input_columns_, output_columns_); + return node; +} + +void RenameNode::Print(std::ostream &out) const { + out << Name() + "(input:" + PrintColumns(input_columns_) + ",output:" + PrintColumns(output_columns_) + ")"; } Status RenameNode::ValidateParams() { diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.h index 8a8faf2a4a..42bc2bcc45 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.h @@ -35,6 +35,18 @@ class RenameNode : public DatasetNode { /// \brief Destructor ~RenameNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kRenameNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc index 7fe738a20d..65ef8bacea 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc @@ -21,15 +21,22 @@ #include #include "minddata/dataset/engine/datasetops/repeat_op.h" - +#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { RepeatNode::RepeatNode(std::shared_ptr child, int32_t count) : repeat_count_(count) { - this->children.push_back(child); + this->AddChild(child); +} + +std::shared_ptr RepeatNode::Copy() { + auto node = std::make_shared(nullptr, this->repeat_count_); + return node; } +void RepeatNode::Print(std::ostream &out) const { out << Name() + "(count:" + std::to_string(repeat_count_) + ")"; } + std::vector> RepeatNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -49,5 +56,16 @@ Status RepeatNode::ValidateParams() { return Status::OK(); } +// Visitor accepting method for NodePass +Status RepeatNode::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->Visit(shared_from_base(), modified); +} + +// Visitor accepting method for NodePass +Status RepeatNode::AcceptAfter(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->VisitAfter(shared_from_base(), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h index b582dcb326..318a7dda3e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h @@ -36,6 +36,18 @@ class RepeatNode : public DatasetNode { /// \brief Destructor ~RepeatNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kRepeatNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; @@ -44,6 +56,18 @@ class RepeatNode : public DatasetNode { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Base-class override for accepting NodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + + /// \brief Base-class override for accepting NodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status AcceptAfter(NodePass *p, bool *modified) override; + private: int32_t repeat_count_; }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.cc new file mode 100644 index 0000000000..6bc13c2022 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.cc @@ -0,0 +1,85 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/ir/datasetops/root_node.h" + +#include +#include +#include + +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +// Constructor for RootNode +RootNode::RootNode(std::shared_ptr child, int32_t num_epochs) : DatasetNode(), num_epochs_(num_epochs) { + // The root node's parent must remain nullptr. (which is set in the constructor of DatasetNode) + AddChild(child); +} + +std::shared_ptr RootNode::Copy() { + auto node = std::make_shared(nullptr, num_epochs_); + return node; +} + +void RootNode::Print(std::ostream &out) const { out << Name(); } + +std::vector> RootNode::Build() { + // root node doesn't build a runtime Op. this function should return Status::Error when called. + return {}; +} + +// Function to validate the parameters for RootNode +Status RootNode::ValidateParams() { + if (num_epochs_ <= 0 && num_epochs_ != -1) { + std::string err_msg = + "RootNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(num_epochs_); + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + if (parent_ != nullptr) { + std::string err_msg = "Internal error: root node should not have a parent"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + if (children_.size() != 1) { + std::string err_msg = "Internal error: root node should have one child node"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + if (children_[0] == nullptr) { + std::string err_msg = "Internal error: root node's child is a null pointer"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + return Status::OK(); +} + +// Visitor accepting method for NodePass +Status RootNode::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->Visit(shared_from_base(), modified); +} + +// Visitor accepting method for NodePass +Status RootNode::AcceptAfter(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->VisitAfter(shared_from_base(), modified); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.h new file mode 100644 index 0000000000..83146db92e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.h @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ROOT_NODE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ROOT_NODE_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" + +namespace mindspore { +namespace dataset { + +class RootNode : public DatasetNode { + public: + /// \brief Constructor + RootNode(std::shared_ptr child, int32_t num_epochs); + + /// \brief Destructor + ~RootNode() = default; + + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kRootNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return shared pointer to the list of newly created DatasetOps + std::vector> Build() override; + + /// \brief Getter of number of epochs + int32_t num_epochs() { return num_epochs_; } + + /// \brief Parameters validation + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; + + /// \brief Base-class override for accepting NodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + + /// \brief Base-class override for accepting NodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status AcceptAfter(NodePass *p, bool *modified) override; + + private: + int32_t num_epochs_; +}; + +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ROOT_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc index e722547aae..3e470586e3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc @@ -29,7 +29,17 @@ namespace dataset { // Constructor for ShuffleNode ShuffleNode::ShuffleNode(std::shared_ptr child, int32_t shuffle_size, bool reset_every_epoch) : shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) { - this->children.push_back(child); + this->AddChild(child); +} + +std::shared_ptr ShuffleNode::Copy() { + auto node = std::make_shared(nullptr, shuffle_size_, reset_every_epoch_); + return node; +} + +void ShuffleNode::Print(std::ostream &out) const { + out << Name() + "(shuffle_size:" + std::to_string(shuffle_size_) + + ",reset_every_epoch:" + (reset_every_epoch_ ? "true" : "false") + ")"; } // Function to build the ShuffleOp diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.h index 0b81684e61..bcc4f99956 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.h @@ -34,6 +34,18 @@ class ShuffleNode : public DatasetNode { ~ShuffleNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kShuffleNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + std::vector> Build() override; Status ValidateParams() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc index 8590c47c42..b2e4eae252 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc @@ -27,10 +27,15 @@ namespace mindspore { namespace dataset { // Constructor for SkipNode -SkipNode::SkipNode(std::shared_ptr child, int32_t count) : skip_count_(count) { - this->children.push_back(child); +SkipNode::SkipNode(std::shared_ptr child, int32_t count) : skip_count_(count) { this->AddChild(child); } + +std::shared_ptr SkipNode::Copy() { + auto node = std::make_shared(nullptr, skip_count_); + return node; } +void SkipNode::Print(std::ostream &out) const { out << Name() + "(skip_count:" + std::to_string(skip_count_) + ")"; } + // Function to build the SkipOp std::vector> SkipNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h index 19e7cc9031..376ea96869 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h @@ -34,6 +34,18 @@ class SkipNode : public DatasetNode { /// \brief Destructor ~SkipNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kSkipNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc index 48193fed80..a6127f4470 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc @@ -32,13 +32,23 @@ namespace dataset { AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_schema, const std::vector &column_names, bool decode, const std::shared_ptr &sampler, const std::shared_ptr &cache) - : DatasetNode(std::move(cache)), + : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), schema_path_(data_schema), column_names_(column_names), decode_(decode), sampler_(sampler) {} +std::shared_ptr AlbumNode::Copy() { + std::shared_ptr sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); + auto node = std::make_shared(dataset_dir_, schema_path_, column_names_, decode_, sampler, cache_); + return node; +} + +void AlbumNode::Print(std::ostream &out) const { + out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ")"; +} + Status AlbumNode::ValidateParams() { RETURN_IF_NOT_OK(ValidateDatasetDirParam("AlbumNode", dataset_dir_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h index 21fce849c0..1cacbfcfc2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h @@ -26,7 +26,7 @@ namespace mindspore { namespace dataset { -class AlbumNode : public DatasetNode { +class AlbumNode : public MappableSourceNode { public: /// \brief Constructor AlbumNode(const std::string &dataset_dir, const std::string &data_schema, @@ -36,6 +36,18 @@ class AlbumNode : public DatasetNode { /// \brief Destructor ~AlbumNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kAlbumNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create a runtime dataset op object from this class /// \return shared pointer to the newly created DatasetOp std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc index b86fdcfd8a..9b888eb544 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc @@ -31,13 +31,23 @@ namespace dataset { CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr &sampler, const bool &decode, const std::set &extensions, const std::shared_ptr &cache) - : DatasetNode(std::move(cache)), + : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler), decode_(decode), extensions_(extensions) {} +std::shared_ptr CelebANode::Copy() { + std::shared_ptr sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); + auto node = std::make_shared(dataset_dir_, usage_, sampler, decode_, extensions_, cache_); + return node; +} + +void CelebANode::Print(std::ostream &out) const { + out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ")"; +} + Status CelebANode::ValidateParams() { RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebANode", dataset_dir_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h index ed30adfd9c..a2518d5220 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h @@ -28,7 +28,7 @@ namespace mindspore { namespace dataset { -class CelebANode : public DatasetNode { +class CelebANode : public MappableSourceNode { public: /// \brief Constructor CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr &sampler, @@ -37,6 +37,18 @@ class CelebANode : public DatasetNode { /// \brief Destructor ~CelebANode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kCelebANode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return shared pointer to the list of newly created DatasetOps std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc index d19e831437..f24cc21bfa 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc @@ -30,7 +30,17 @@ namespace dataset { // Constructor for Cifar100Node Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr sampler, std::shared_ptr cache) - : DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} + : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} + +std::shared_ptr Cifar100Node::Copy() { + std::shared_ptr sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); + auto node = std::make_shared(dataset_dir_, usage_, sampler, cache_); + return node; +} + +void Cifar100Node::Print(std::ostream &out) const { + out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ")"; +} Status Cifar100Node::ValidateParams() { RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h index fe24f8f3a0..25e63e2b91 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h @@ -26,7 +26,7 @@ namespace mindspore { namespace dataset { -class Cifar100Node : public DatasetNode { +class Cifar100Node : public MappableSourceNode { public: /// \brief Constructor Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr sampler, @@ -35,6 +35,18 @@ class Cifar100Node : public DatasetNode { /// \brief Destructor ~Cifar100Node() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kCifar100Node; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc index 5fcfa36c99..0b3eec3cdf 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc @@ -30,7 +30,17 @@ namespace dataset { // Constructor for Cifar10Node Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr sampler, std::shared_ptr cache) - : DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} + : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} + +std::shared_ptr Cifar10Node::Copy() { + std::shared_ptr sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); + auto node = std::make_shared(dataset_dir_, usage_, sampler, cache_); + return node; +} + +void Cifar10Node::Print(std::ostream &out) const { + out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ")"; +} Status Cifar10Node::ValidateParams() { RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h index 716474ae2e..63140f4b67 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h @@ -26,7 +26,7 @@ namespace mindspore { namespace dataset { -class Cifar10Node : public DatasetNode { +class Cifar10Node : public MappableSourceNode { public: /// \brief Constructor Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr sampler, @@ -35,6 +35,18 @@ class Cifar10Node : public DatasetNode { /// \brief Destructor ~Cifar10Node() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kCifar10Node; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc index cb0a2116ab..f1e8ad82a9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc @@ -32,7 +32,7 @@ namespace dataset { // Constructor for CLUENode CLUENode::CLUENode(const std::vector clue_files, std::string task, std::string usage, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr cache) - : DatasetNode(std::move(cache)), + : NonMappableSourceNode(std::move(cache)), dataset_files_(clue_files), task_(task), usage_(usage), @@ -41,6 +41,17 @@ CLUENode::CLUENode(const std::vector clue_files, std::string task, num_shards_(num_shards), shard_id_(shard_id) {} +std::shared_ptr CLUENode::Copy() { + auto node = + std::make_shared(dataset_files_, task_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_); + return node; +} + +void CLUENode::Print(std::ostream &out) const { + out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ",..." + + ",num_shards:" + std::to_string(num_shards_) + ",shard_id:" + std::to_string(shard_id_) + ")"; +} + Status CLUENode::ValidateParams() { RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUENode", dataset_files_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h index 76da0501b8..c315ab8386 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h @@ -28,7 +28,7 @@ namespace dataset { /// \class CLUENode /// \brief A Dataset derived class to represent CLUE dataset -class CLUENode : public DatasetNode { +class CLUENode : public NonMappableSourceNode { public: /// \brief Constructor CLUENode(const std::vector dataset_files, std::string task, std::string usage, int64_t num_samples, @@ -37,6 +37,18 @@ class CLUENode : public DatasetNode { /// \brief Destructor ~CLUENode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kCLUENode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc index beaf3c9d45..40ce5e808d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc @@ -30,13 +30,21 @@ namespace dataset { // Constructor for CocoNode CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, const bool &decode, const std::shared_ptr &sampler, std::shared_ptr cache) - : DatasetNode(std::move(cache)), + : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), annotation_file_(annotation_file), task_(task), decode_(decode), sampler_(sampler) {} +std::shared_ptr CocoNode::Copy() { + std::shared_ptr sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); + auto node = std::make_shared(dataset_dir_, annotation_file_, task_, decode_, sampler, cache_); + return node; +} + +void CocoNode::Print(std::ostream &out) const { out << Name(); } + Status CocoNode::ValidateParams() { RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoNode", dataset_dir_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h index d3b4275d7f..d4c3db57ab 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h @@ -26,7 +26,7 @@ namespace mindspore { namespace dataset { -class CocoNode : public DatasetNode { +class CocoNode : public MappableSourceNode { public: /// \brief Constructor CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, @@ -35,6 +35,18 @@ class CocoNode : public DatasetNode { /// \brief Destructor ~CocoNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kCocoNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return shared pointer to the list of newly created DatasetOps std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc index fb11973da2..8aacf0de42 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc @@ -33,7 +33,7 @@ CSVNode::CSVNode(const std::vector &csv_files, char field_delim, const std::vector> &column_defaults, const std::vector &column_names, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr cache) - : DatasetNode(std::move(cache)), + : NonMappableSourceNode(std::move(cache)), dataset_files_(csv_files), field_delim_(field_delim), column_defaults_(column_defaults), @@ -43,6 +43,17 @@ CSVNode::CSVNode(const std::vector &csv_files, char field_delim, num_shards_(num_shards), shard_id_(shard_id) {} +std::shared_ptr CSVNode::Copy() { + auto node = std::make_shared(dataset_files_, field_delim_, column_defaults_, column_names_, num_samples_, + shuffle_, num_shards_, shard_id_, cache_); + return node; +} + +void CSVNode::Print(std::ostream &out) const { + out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ",..." + + ",num_shards:" + std::to_string(num_shards_) + ",shard_id:" + std::to_string(shard_id_) + ")"; +} + Status CSVNode::ValidateParams() { RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CSVNode", dataset_files_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h index bb7dae493f..88cb488d08 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h @@ -47,7 +47,7 @@ class CsvRecord : public CsvBase { T value; }; -class CSVNode : public DatasetNode { +class CSVNode : public NonMappableSourceNode { public: /// \brief Constructor CSVNode(const std::vector &dataset_files, char field_delim, @@ -58,6 +58,18 @@ class CSVNode : public DatasetNode { /// \brief Destructor ~CSVNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kCSVNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return shared pointer to the list of newly created DatasetOps std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc index 9b48c27b70..b991818333 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc @@ -28,7 +28,19 @@ namespace dataset { GeneratorNode::GeneratorNode(py::function generator_function, const std::vector &column_names, const std::vector &column_types) - : generator_function_(generator_function), column_names_(column_names), column_types_(column_types) {} + : MappableSourceNode(), + generator_function_(generator_function), + column_names_(column_names), + column_types_(column_types) {} + +std::shared_ptr GeneratorNode::Copy() { + auto node = std::make_shared(generator_function_, column_names_, column_types_); + return node; +} + +void GeneratorNode::Print(std::ostream &out) const { + out << Name() + "(:" + ",columns:" + PrintColumns(column_names_) + ",)"; +} GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr &schema) : generator_function_(generator_function), schema_(schema) {} diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h index fc7fab1076..d06dedc25c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h @@ -26,10 +26,9 @@ namespace mindspore { namespace dataset { - /// \class GeneratorNode /// \brief A Dataset derived class to represent GeneratorNode dataset -class GeneratorNode : public DatasetNode { +class GeneratorNode : public MappableSourceNode { public: /// \brief Constructor GeneratorNode(py::function generator_function, const std::vector &column_names, @@ -41,6 +40,18 @@ class GeneratorNode : public DatasetNode { /// \brief Destructor ~GeneratorNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kGeneratorNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc index 62c094a4a6..2dc8fa57b7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc @@ -33,13 +33,24 @@ ImageFolderNode::ImageFolderNode(std::string dataset_dir, bool decode, std::shar bool recursive, std::set extensions, std::map class_indexing, std::shared_ptr cache = nullptr) - : dataset_dir_(dataset_dir), + : MappableSourceNode(std::move(cache)), + dataset_dir_(dataset_dir), decode_(decode), sampler_(sampler), recursive_(recursive), class_indexing_(class_indexing), - exts_(extensions), - DatasetNode(std::move(cache)) {} + exts_(extensions) {} + +std::shared_ptr ImageFolderNode::Copy() { + std::shared_ptr sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); + auto node = + std::make_shared(dataset_dir_, decode_, sampler, recursive_, exts_, class_indexing_, cache_); + return node; +} + +void ImageFolderNode::Print(std::ostream &out) const { + out << Name() + "(path:" + dataset_dir_ + ",decode:" + (decode_ ? "true" : "false") + ",...)"; +} Status ImageFolderNode::ValidateParams() { RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderNode", dataset_dir_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h index 22045ed791..a112407d3b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h @@ -31,7 +31,7 @@ namespace dataset { /// \class ImageFolderNode /// \brief A Dataset derived class to represent ImageFolder dataset -class ImageFolderNode : public DatasetNode { +class ImageFolderNode : public MappableSourceNode { public: /// \brief Constructor ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr sampler, bool recursive, @@ -41,6 +41,18 @@ class ImageFolderNode : public DatasetNode { /// \brief Destructor ~ImageFolderNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kImageFolderNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc index 2740fb45ac..b9fd2ff9f6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc @@ -32,13 +32,30 @@ ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &u const std::shared_ptr &sampler, const std::map &class_indexing, bool decode, std::shared_ptr cache) - : DatasetNode(std::move(cache)), + : MappableSourceNode(std::move(cache)), dataset_file_(dataset_file), usage_(usage), decode_(decode), class_index_(class_indexing), sampler_(sampler) {} +std::shared_ptr ManifestNode::Copy() { + std::shared_ptr sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); + auto node = std::make_shared(dataset_file_, usage_, sampler, class_index_, decode_, cache_); + return node; +} + +void ManifestNode::Print(std::ostream &out) const { + out << Name() + "(file:" + dataset_file_; + if (sampler_ != nullptr) { + out << ",sampler"; + } + if (cache_ != nullptr) { + out << ",cache"; + } + out << ")"; +} + Status ManifestNode::ValidateParams() { std::vector forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'}; for (char c : dataset_file_) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h index b623868bcd..eb11bb1003 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h @@ -27,7 +27,7 @@ namespace mindspore { namespace dataset { -class ManifestNode : public DatasetNode { +class ManifestNode : public MappableSourceNode { public: /// \brief Constructor ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr &sampler, @@ -36,6 +36,18 @@ class ManifestNode : public DatasetNode { /// \brief Destructor ~ManifestNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kManifestNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc index d4b4afb7df..5662b80d0c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc @@ -30,7 +30,8 @@ namespace dataset { MindDataNode::MindDataNode(const std::vector &dataset_files, const std::vector &columns_list, const std::shared_ptr &sampler, nlohmann::json padded_sample, int64_t num_padded) - : dataset_file_(std::string()), + : MappableSourceNode(), + dataset_file_(std::string()), dataset_files_(dataset_files), search_for_pattern_(false), columns_list_(columns_list), @@ -41,7 +42,8 @@ MindDataNode::MindDataNode(const std::vector &dataset_files, const MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector &columns_list, const std::shared_ptr &sampler, nlohmann::json padded_sample, int64_t num_padded) - : dataset_file_(dataset_file), + : MappableSourceNode(), + dataset_file_(dataset_file), dataset_files_({}), search_for_pattern_(true), columns_list_(columns_list), @@ -50,6 +52,19 @@ MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector MindDataNode::Copy() { + std::shared_ptr node; + std::shared_ptr sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); + if (dataset_files_.empty()) { + node = std::make_shared(dataset_file_, columns_list_, sampler, padded_sample_, num_padded_); + } else { + node = std::make_shared(dataset_files_, columns_list_, sampler, padded_sample_, num_padded_); + } + return node; +} + +void MindDataNode::Print(std::ostream &out) const { out << Name() + "(file:" + dataset_file_ + ",...)"; } + Status MindDataNode::ValidateParams() { if (!search_for_pattern_ && dataset_files_.size() > 4096) { std::string err_msg = diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h index 850137fcb2..4078125a07 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h @@ -27,7 +27,7 @@ namespace mindspore { namespace dataset { -class MindDataNode : public DatasetNode { +class MindDataNode : public MappableSourceNode { public: /// \brief Constructor MindDataNode(const std::vector &dataset_files, const std::vector &columns_list, @@ -40,6 +40,18 @@ class MindDataNode : public DatasetNode { /// \brief Destructor ~MindDataNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kMindDataNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc index 41baf8086a..57371b17ca 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc @@ -29,7 +29,15 @@ namespace dataset { MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr sampler, std::shared_ptr cache) - : DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} + : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} + +std::shared_ptr MnistNode::Copy() { + std::shared_ptr sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); + auto node = std::make_shared(dataset_dir_, usage_, sampler, cache_); + return node; +} + +void MnistNode::Print(std::ostream &out) const { out << Name(); } Status MnistNode::ValidateParams() { RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h index 5e614ad335..9386849631 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h @@ -26,7 +26,7 @@ namespace mindspore { namespace dataset { -class MnistNode : public DatasetNode { +class MnistNode : public MappableSourceNode { public: /// \brief Constructor MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr sampler, @@ -35,6 +35,18 @@ class MnistNode : public DatasetNode { /// \brief Destructor ~MnistNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kMnistNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc index 76f42408cf..5e1f75ee7c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc @@ -27,6 +27,18 @@ namespace mindspore { namespace dataset { +std::shared_ptr RandomNode::Copy() { + std::shared_ptr node; + if (schema_ != nullptr) { + node = std::make_shared(total_rows_, schema_, columns_list_, cache_); + } else { + node = std::make_shared(total_rows_, schema_path_, columns_list_, cache_); + } + return node; +} + +void RandomNode::Print(std::ostream &out) const { out << Name() + "(num_row:" + std::to_string(total_rows_) + ",...)"; } + // ValidateParams for RandomNode Status RandomNode::ValidateParams() { if (total_rows_ < 0) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h index 0ec2eb34fe..ea1e1f6346 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h @@ -27,7 +27,7 @@ namespace mindspore { namespace dataset { -class RandomNode : public DatasetNode { +class RandomNode : public NonMappableSourceNode { public: // Some constants to provide limits to random generation. static constexpr int32_t kMaxNumColumns = 4; @@ -37,7 +37,7 @@ class RandomNode : public DatasetNode { /// \brief Constructor RandomNode(const int32_t &total_rows, std::shared_ptr schema, const std::vector &columns_list, std::shared_ptr cache) - : DatasetNode(std::move(cache)), + : NonMappableSourceNode(std::move(cache)), total_rows_(total_rows), schema_path_(""), schema_(std::move(schema)), @@ -46,14 +46,27 @@ class RandomNode : public DatasetNode { /// \brief Constructor RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector &columns_list, std::shared_ptr cache) - : DatasetNode(std::move(cache)), + : NonMappableSourceNode(std::move(cache)), total_rows_(total_rows), schema_path_(schema_path), + schema_(nullptr), columns_list_(columns_list) {} /// \brief Destructor ~RandomNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kRandomNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc index d7c64a9d95..f5745fc29c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc @@ -31,13 +31,23 @@ namespace dataset { // Constructor for TextFileNode TextFileNode::TextFileNode(std::vector dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr cache) - : DatasetNode(std::move(cache)), + : NonMappableSourceNode(std::move(cache)), dataset_files_(dataset_files), num_samples_(num_samples), shuffle_(shuffle), num_shards_(num_shards), shard_id_(shard_id) {} +std::shared_ptr TextFileNode::Copy() { + auto node = std::make_shared(dataset_files_, num_samples_, shuffle_, num_shards_, shard_id_, cache_); + return node; +} + +void TextFileNode::Print(std::ostream &out) const { + out << Name() + "(file:..." + ",num_shards:" + std::to_string(num_shards_) + + ",shard_id:" + std::to_string(shard_id_) + ",cache:" + ((cache_ != nullptr) ? "true" : "false") + ",...)"; +} + Status TextFileNode::ValidateParams() { RETURN_IF_NOT_OK(ValidateDatasetFilesParam("TextFileNode", dataset_files_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h index 96a76cef28..cdad8eaadc 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h @@ -28,7 +28,7 @@ namespace dataset { /// \class TextFileNode /// \brief A Dataset derived class to represent TextFile dataset -class TextFileNode : public DatasetNode { +class TextFileNode : public NonMappableSourceNode { public: /// \brief Constructor TextFileNode(std::vector dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, @@ -37,6 +37,18 @@ class TextFileNode : public DatasetNode { /// \brief Destructor ~TextFileNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kTextFileNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc index 7f4cd73cf1..a7a8a17658 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc @@ -30,6 +30,23 @@ namespace mindspore { namespace dataset { +std::shared_ptr TFRecordNode::Copy() { + std::shared_ptr node; + if (schema_obj_ != nullptr) { + node = std::make_shared(dataset_files_, schema_obj_, columns_list_, num_samples_, shuffle_, + num_shards_, shard_id_, shard_equal_rows_, cache_); + } else { + node = std::make_shared(dataset_files_, schema_path_, columns_list_, num_samples_, shuffle_, + num_shards_, shard_id_, shard_equal_rows_, cache_); + } + return node; +} + +void TFRecordNode::Print(std::ostream &out) const { + out << Name() + "(num_samples:" + std::to_string(num_samples_) + ",num_shards:" + std::to_string(num_shards_) + + ",shard_id:" + std::to_string(shard_id_) + ",...)"; +} + // Validator for TFRecordNode Status TFRecordNode::ValidateParams() { if (dataset_files_.empty()) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h index 6f12b0a64e..2c2a14be70 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h @@ -29,14 +29,14 @@ namespace dataset { /// \class TFRecordNode /// \brief A Dataset derived class to represent TFRecord dataset -class TFRecordNode : public DatasetNode { +class TFRecordNode : public NonMappableSourceNode { public: /// \brief Constructor /// \note Parameter 'schema' is the path to the schema file TFRecordNode(const std::vector &dataset_files, std::string schema, const std::vector &columns_list, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr cache) - : DatasetNode(std::move(cache)), + : NonMappableSourceNode(std::move(cache)), dataset_files_(dataset_files), schema_path_(schema), columns_list_(columns_list), @@ -51,7 +51,7 @@ class TFRecordNode : public DatasetNode { TFRecordNode(const std::vector &dataset_files, std::shared_ptr schema, const std::vector &columns_list, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr cache) - : DatasetNode(std::move(cache)), + : NonMappableSourceNode(std::move(cache)), dataset_files_(dataset_files), schema_obj_(schema), columns_list_(columns_list), @@ -64,6 +64,18 @@ class TFRecordNode : public DatasetNode { /// \brief Destructor ~TFRecordNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kTFRecordNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc index f0d49aa4b6..9b36424d7c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc @@ -32,7 +32,7 @@ namespace dataset { VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, const std::map &class_indexing, bool decode, std::shared_ptr sampler, std::shared_ptr cache) - : DatasetNode(std::move(cache)), + : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), task_(task), usage_(usage), @@ -40,6 +40,14 @@ VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const decode_(decode), sampler_(sampler) {} +std::shared_ptr VOCNode::Copy() { + std::shared_ptr sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); + auto node = std::make_shared(dataset_dir_, task_, usage_, class_index_, decode_, sampler, cache_); + return node; +} + +void VOCNode::Print(std::ostream &out) const { out << Name(); } + Status VOCNode::ValidateParams() { Path dir(dataset_dir_); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h index 4102e3189c..55750fa40f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h @@ -27,7 +27,7 @@ namespace mindspore { namespace dataset { -class VOCNode : public DatasetNode { +class VOCNode : public MappableSourceNode { public: /// \brief Constructor VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, @@ -37,6 +37,18 @@ class VOCNode : public DatasetNode { /// \brief Destructor ~VOCNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kVOCNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return shared pointer to the list of newly created DatasetOps std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc index ec03303f31..5a485abe1c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc @@ -29,7 +29,16 @@ namespace dataset { // Constructor for SyncWaitNode SyncWaitNode::SyncWaitNode(std::shared_ptr child, const std::string &condition_name, py::function callback) : condition_name_(condition_name), callback_(callback) { - this->children.push_back(child); + this->AddChild(child); +} + +std::shared_ptr SyncWaitNode::Copy() { + auto node = std::make_shared(nullptr, condition_name_, callback_); + return node; +} + +void SyncWaitNode::Print(std::ostream &out) const { + out << Name() + "(cond_name:" + condition_name_ + "" + ")"; } // Function to build the BarrierOp diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.h index 36320f848e..ecd3995cbd 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.h @@ -36,6 +36,18 @@ class SyncWaitNode : public DatasetNode { /// \brief Destructor ~SyncWaitNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kSyncWaitNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc index 917df6a781..509b4a7bfd 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc @@ -27,10 +27,15 @@ namespace mindspore { namespace dataset { // Constructor for TakeNode -TakeNode::TakeNode(std::shared_ptr child, int32_t count) : take_count_(count) { - this->children.push_back(child); +TakeNode::TakeNode(std::shared_ptr child, int32_t count) : take_count_(count) { this->AddChild(child); } + +std::shared_ptr TakeNode::Copy() { + auto node = std::make_shared(nullptr, take_count_); + return node; } +void TakeNode::Print(std::ostream &out) const { out << Name() + "(num_rows:" + std::to_string(take_count_) + ")"; } + // Function to build the TakeOp std::vector> TakeNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h index 93d735d15a..6d3b1e3d8a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h @@ -34,6 +34,18 @@ class TakeNode : public DatasetNode { /// \brief Destructor ~TakeNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kTakeNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return shared pointer to the list of newly created DatasetOps std::vector> Build() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc index 790a04e2e9..de243b951c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc @@ -22,6 +22,7 @@ #include #include "minddata/dataset/engine/datasetops/device_queue_op.h" +#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/util/status.h" #include "utils/ms_context.h" @@ -39,7 +40,19 @@ TransferNode::TransferNode(std::shared_ptr child, std::string queue total_batch_(total_batch), create_data_info_queue_(create_data_info_queue), device_id_(0) { - this->children.push_back(child); + this->AddChild(child); +} + +std::shared_ptr TransferNode::Copy() { + auto node = std::make_shared(nullptr, queue_name_, device_type_, send_epoch_end_, total_batch_, + create_data_info_queue_); + return node; +} + +void TransferNode::Print(std::ostream &out) const { + out << Name() + "(prefetch_size:" + std::to_string(prefetch_size_) + + ",send_epoch_end:" + (send_epoch_end_ ? "true" : "false") + ",total_batch:" + std::to_string(total_batch_) + + ")"; } // Validator for TransferNode @@ -94,5 +107,16 @@ std::vector> TransferNode::Build() { return node_ops; } +// Visitor accepting method for NodePass +Status TransferNode::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->Visit(shared_from_base(), modified); +} + +// Visitor accepting method for NodePass +Status TransferNode::AcceptAfter(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->VisitAfter(shared_from_base(), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h index 65cad58077..b6e03113e8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h @@ -35,6 +35,18 @@ class TransferNode : public DatasetNode { /// \brief Destructor ~TransferNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kTransferNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return shared pointer to the list of newly created DatasetOps std::vector> Build() override; @@ -43,6 +55,20 @@ class TransferNode : public DatasetNode { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + static Status get_distribution(std::shared_ptr ds, int32_t *device_id); + + /// \brief Base-class override for accepting NodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + + /// \brief Base-class override for accepting NodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status AcceptAfter(NodePass *p, bool *modified) override; + private: std::string queue_name_; int32_t device_id_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc index 0d2c068635..9cf11421e1 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc @@ -21,30 +21,36 @@ #include #include "minddata/dataset/engine/datasetops/zip_op.h" - +#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -ZipNode::ZipNode(const std::vector> &datasets) : datasets_(datasets) { - for (auto dataset : datasets_) { - this->children.push_back(dataset); - } +ZipNode::ZipNode(const std::vector> &datasets) { + for (auto const &child : datasets) AddChild(child); } +std::shared_ptr ZipNode::Copy() { + std::vector> empty_vector; + empty_vector.clear(); + auto node = std::make_shared(empty_vector); + return node; +} + +void ZipNode::Print(std::ostream &out) const { out << Name(); } + Status ZipNode::ValidateParams() { - if (datasets_.empty()) { - std::string err_msg = "ZipNode: datasets to zip are not specified."; + if (children_.size() < 2) { + std::string err_msg = "ZipNode: input datasets are not specified."; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } - if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) { - std::string err_msg = "ZipNode: zip datasets should not be null."; + if (find(children_.begin(), children_.end(), nullptr) != children_.end()) { + std::string err_msg = "ZipNode: input datasets should not be null."; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } - return Status::OK(); } @@ -56,5 +62,17 @@ std::vector> ZipNode::Build() { return node_ops; } +// Visitor accepting method for NodePass +Status ZipNode::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->Visit(shared_from_base(), modified); +} + +// Visitor accepting method for NodePass +Status ZipNode::AcceptAfter(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->VisitAfter(shared_from_base(), modified); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.h index 27f92e0da5..86ec73e65c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.h @@ -34,6 +34,18 @@ class ZipNode : public DatasetNode { /// \brief Destructor ~ZipNode() = default; + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kZipNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; @@ -42,8 +54,17 @@ class ZipNode : public DatasetNode { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; - private: - std::vector> datasets_; + /// \brief Base-class override for accepting NodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + + /// \brief Base-class override for accepting NodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status AcceptAfter(NodePass *p, bool *modified) override; }; } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc index d05fd04636..05b69ad97f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc @@ -22,10 +22,12 @@ #endif #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h" #include "minddata/dataset/engine/ir/datasetops/concat_node.h" +#include "minddata/dataset/engine/ir/datasetops/filter_node.h" #include "minddata/dataset/engine/ir/datasetops/map_node.h" #include "minddata/dataset/engine/ir/datasetops/project_node.h" #include "minddata/dataset/engine/ir/datasetops/rename_node.h" #include "minddata/dataset/engine/ir/datasetops/repeat_node.h" +#include "minddata/dataset/engine/ir/datasetops/root_node.h" #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" #include "minddata/dataset/engine/ir/datasetops/skip_node.h" #ifdef ENABLE_PYTHON @@ -34,34 +36,6 @@ #include "minddata/dataset/engine/ir/datasetops/take_node.h" #include "minddata/dataset/engine/ir/datasetops/transfer_node.h" #include "minddata/dataset/engine/ir/datasetops/zip_node.h" -#include "minddata/dataset/engine/ir/datasetops/source/album_node.h" -#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h" -#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h" -#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" -#ifndef ENABLE_ANDROID -#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h" -#endif -#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h" -#ifndef ENABLE_ANDROID -#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h" -#endif -#ifdef ENABLE_PYTHON -#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h" -#endif -#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" -#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h" -#ifndef ENABLE_ANDROID -#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" -#endif -#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h" -#include "minddata/dataset/engine/ir/datasetops/source/random_node.h" -#ifndef ENABLE_ANDROID -#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h" -#endif -#ifndef ENABLE_ANDROID -#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h" -#endif -#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h" ////////////////////////////////// // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. @@ -113,7 +87,12 @@ namespace mindspore { namespace dataset { // Driver method for TreePass -Status TreePass::Run(std::shared_ptr root_ir, bool *modified) { return Status::OK(); } +Status TreePass::Run(std::shared_ptr root_ir, bool *modified) { + if (root_ir == nullptr || modified == nullptr) { + return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass"); + } + return this->RunOnTree(root_ir, modified); +} // Driver method for NodePass Status NodePass::Run(std::shared_ptr root_ir, bool *modified) { @@ -132,15 +111,23 @@ Status NodePass::Run(std::shared_ptr root_ir, bool *modified) { // Helper function to perform DFS visit Status NodePass::DFSNodeVisit(std::shared_ptr node_ir, bool *modified) { - RETURN_IF_NOT_OK(node_ir->Accept(this, modified)); + bool m = false; + + RETURN_IF_NOT_OK(node_ir->Accept(this, &m)); + *modified |= m; for (const auto &c : node_ir->Children()) { - RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified)); + RETURN_IF_NOT_OK(this->DFSNodeVisit(c, &m)); + *modified |= m; } - return node_ir->AcceptAfter(this, modified); + RETURN_IF_NOT_OK(node_ir->AcceptAfter(this, &m)); + *modified |= m; + return Status::OK(); } // Helper function to perform BFS visit Status NodePass::BFSNodeVisit(std::shared_ptr node_ir, bool *modified) { + bool m = false; + // Initialize bfs queue with root std::queue> bfsQueue; bfsQueue.push(node_ir); @@ -152,7 +139,8 @@ Status NodePass::BFSNodeVisit(std::shared_ptr node_ir, bool *modifi bfsQueue.pop(); // Run node pass - RETURN_IF_NOT_OK(curNode->Accept(this, modified)); + RETURN_IF_NOT_OK(curNode->Accept(this, &m)); + *modified |= m; // Push children into bfs queue for (const auto &c : curNode->Children()) { @@ -162,331 +150,119 @@ Status NodePass::BFSNodeVisit(std::shared_ptr node_ir, bool *modifi return Status::OK(); } -// For datasetops IR +// For non-leaf IR node Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default return Visit(std::static_pointer_cast(node), modified); } - Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default return VisitAfter(std::static_pointer_cast(node), modified); } - Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default return Visit(std::static_pointer_cast(node), modified); } - Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return VisitAfter(std::static_pointer_cast(node), modified); -} - -#ifndef ENABLE_ANDROID -Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return Visit(std::static_pointer_cast(node), modified); -} - -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default return VisitAfter(std::static_pointer_cast(node), modified); } -#endif - Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default return Visit(std::static_pointer_cast(node), modified); } - Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default return VisitAfter(std::static_pointer_cast(node), modified); } - Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default return Visit(std::static_pointer_cast(node), modified); } - Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default return VisitAfter(std::static_pointer_cast(node), modified); } - +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + return Visit(std::static_pointer_cast(node), modified); +} +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + return VisitAfter(std::static_pointer_cast(node), modified); +} Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default return Visit(std::static_pointer_cast(node), modified); } - Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default return VisitAfter(std::static_pointer_cast(node), modified); } - Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default return Visit(std::static_pointer_cast(node), modified); } - Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default return VisitAfter(std::static_pointer_cast(node), modified); } - Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default return Visit(std::static_pointer_cast(node), modified); } - Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default return VisitAfter(std::static_pointer_cast(node), modified); } - Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default return Visit(std::static_pointer_cast(node), modified); } - Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default return VisitAfter(std::static_pointer_cast(node), modified); } - -Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default +Status NodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } - -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } - -Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default +Status NodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } - -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } - -#ifdef ENABLE_PYTHON -Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default +Status NodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } - -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } -#endif - Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default return Visit(std::static_pointer_cast(node), modified); } - Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default return VisitAfter(std::static_pointer_cast(node), modified); } - Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default return Visit(std::static_pointer_cast(node), modified); } - Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default return VisitAfter(std::static_pointer_cast(node), modified); } - Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default return Visit(std::static_pointer_cast(node), modified); } - Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return VisitAfter(std::static_pointer_cast(node), modified); -} - -// For datasetops/source IR -Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return Visit(std::static_pointer_cast(node), modified); -} - -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default return VisitAfter(std::static_pointer_cast(node), modified); } - -Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return Visit(std::static_pointer_cast(node), modified); -} - -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return VisitAfter(std::static_pointer_cast(node), modified); -} - -Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return Visit(std::static_pointer_cast(node), modified); -} - -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return VisitAfter(std::static_pointer_cast(node), modified); -} - -Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return Visit(std::static_pointer_cast(node), modified); -} - -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return VisitAfter(std::static_pointer_cast(node), modified); -} - -#ifndef ENABLE_ANDROID -Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return Visit(std::static_pointer_cast(node), modified); -} - -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return VisitAfter(std::static_pointer_cast(node), modified); -} -#endif - -Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return Visit(std::static_pointer_cast(node), modified); -} - -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return VisitAfter(std::static_pointer_cast(node), modified); -} - -#ifndef ENABLE_ANDROID -Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return Visit(std::static_pointer_cast(node), modified); -} - -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return VisitAfter(std::static_pointer_cast(node), modified); -} -#endif - #ifdef ENABLE_PYTHON -Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return Visit(std::static_pointer_cast(node), modified); -} - -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return VisitAfter(std::static_pointer_cast(node), modified); -} -#endif - -Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return Visit(std::static_pointer_cast(node), modified); -} - -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return VisitAfter(std::static_pointer_cast(node), modified); -} - -Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return Visit(std::static_pointer_cast(node), modified); -} - -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return VisitAfter(std::static_pointer_cast(node), modified); -} - -#ifndef ENABLE_ANDROID -Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return Visit(std::static_pointer_cast(node), modified); -} - -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return VisitAfter(std::static_pointer_cast(node), modified); -} -#endif - -Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return Visit(std::static_pointer_cast(node), modified); -} - -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return VisitAfter(std::static_pointer_cast(node), modified); -} - -Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return Visit(std::static_pointer_cast(node), modified); -} - -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return VisitAfter(std::static_pointer_cast(node), modified); -} - -#ifndef ENABLE_ANDROID -Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default +Status NodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } - -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } #endif - #ifndef ENABLE_ANDROID -Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default +Status NodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } - -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } #endif -Status NodePass::Visit(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default +// For leaf IR Node +Status NodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } - -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h index bebee15519..9b3ef0c849 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h @@ -26,123 +26,87 @@ namespace mindspore { namespace dataset { +// Non-leaf IR node class BatchNode; class BucketBatchByLengthNode; -#ifndef ENABLE_ANDROID -class BuildSentenceVocabNode; -#endif class BuildVocabNode; class ConcatNode; +class FilterNode; class MapNode; class ProjectNode; class RenameNode; class RepeatNode; +class RootNode; class ShuffleNode; class SkipNode; -#ifdef ENABLE_PYTHON -class SyncWaitNode; -#endif class TakeNode; class TransferNode; class ZipNode; +#ifdef ENABLE_PYTHON +class SyncWaitNode; +#endif +#ifndef ENABLE_ANDROID +class BuildSentenceVocabNode; +#endif +// Leaf IR node class AlbumNode; class CelebANode; class Cifar100Node; class Cifar10Node; -#ifndef ENABLE_ANDROID -class CLUENode; -#endif class CocoNode; -#ifndef ENABLE_ANDROID -class CSVNode; -#endif -#ifdef ENABLE_PYTHON -class GeneratorNode; -#endif class ImageFolderNode; class ManifestNode; -#ifndef ENABLE_ANDROID -class MindDataNode; -#endif class MnistNode; class RandomNode; -#ifndef ENABLE_ANDROID -class TextFileNode; +class VOCNode; +#ifdef ENABLE_PYTHON +class GeneratorNode; #endif #ifndef ENABLE_ANDROID +class CLUENode; +class CSVNode; +class MindDataNode; +class TextFileNode; class TFRecordNode; #endif -class VOCNode; ////////////////////////////////// // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. class BatchOp; - class MapOp; - class ProjectOp; - class RenameOp; - class SkipOp; - class ShuffleOp; - class AlbumOp; - class RandomDataOp; - class RepeatOp; - class TakeOp; - class ZipOp; - class DeviceQueueOp; - class ImageFolderOp; - class MnistOp; - class ManifestOp; - class CifarOp; - class VOCOp; - class CocoOp; - class CelebAOp; - class EpochCtrlOp; - class BuildVocabOp; - class ConcatOp; - #ifndef ENABLE_ANDROID class MindRecordOp; - class TFReaderOp; - class CacheOp; - class CacheMergeOp; - class CacheLookupOp; - class BuildSentencePieceVocabOp; - class ClueOp; - class CsvOp; - class TextFileOp; #endif - #ifdef ENABLE_PYTHON class FilterOp; - class GeneratorOp; #endif ////////////////////////////////// @@ -175,6 +139,13 @@ class TreePass : public Pass { /// \param[inout] modified Indicate if the tree was modified Status Run(std::shared_ptr root_ir, bool *modified) final; + /// \brief Derived classes may implement the runOnTree function to implement tree transformation. + /// "modified" flag needs to be set to true if tree is modified during the pass execution. + /// \param[inout] tree The tree to operate on. + /// \param[inout] Indicate of the tree was modified. + /// \return Status The error code return + virtual Status RunOnTree(std::shared_ptr root_ir, bool *modified) { return Status::OK(); } + ////////////////////////////////// // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. /// \brief Run the transformation pass against the execution tree. @@ -191,8 +162,17 @@ class TreePass : public Pass { ////////////////////////////////// }; -// NodePass is a basic Pass class which performs transformation on Node visiting. +// NodePass is a base Pass class which performs transformation on node visiting. // NodePass implements Visitor design pattern. +// The visiting happens twice for each node in the DFS traversal, one on the way down of the traversal, +// and the other when all the descending nodes are visited. +// Actual transformation is done by implementing a new derived class of NodePass. +// The derived class will implement the method Visit()/VisitAfter() passing specified node types +// it wants to action on them, overriding the ones defined in NodePass. +// If the derived class wants to perform the same action on all node types, +// it can simply implement the method Visit()/VisitAfter() passing the base class DatasetNode. +// This is made possible by overloading the method Visit()/VisitAfter() on each node type to fall back +// to call the Visit()/VisitAfter() in this parent NodePass class. class NodePass : public Pass { public: // Tree traversal order @@ -223,153 +203,57 @@ class NodePass : public Pass { /// \return Status The error code return virtual Status VisitAfter(std::shared_ptr node, bool *modified) { return Status::OK(); } - // For datasetops IR - // Visit method to be overridden. - // Note that member template can not be virtual, any node which wants to work with NodePass - // should declare Visit of its own type and override "Accept" from DatasetNode. + // Visit()/VisitAfter() method to be overridden. + // These pairs of Visit()/VisitAfter() for each derived class of DatasetNode are defined here. + // Their implementation are in .cc file to avoid adding the include files of those derived classes. + // The implementation simply falls back to call Visit()/VisitAfter of class DatasetNode, the parent of + // the derived classes. With this technique, the transformation classes derived from NodePass needs only to + // implement Visit()/VisitAfter() passing DatasetNode if it wants to action on any derived classes + // of DatasetNode in the same way. + // Note that virtual template functions are not permitted in C++. + // + // Non-leaf IR node virtual Status Visit(std::shared_ptr node, bool *modified); - - // VisitAfter method to be overridden. - // Note that member template can not be virtual, any node which wants to work with NodePass - // should declare VisitAfter of its own type and override "AcceptAfter" from DatasetNode. virtual Status VisitAfter(std::shared_ptr node, bool *modified); - virtual Status Visit(std::shared_ptr node, bool *modified); - virtual Status VisitAfter(std::shared_ptr node, bool *modified); - -#ifndef ENABLE_ANDROID - virtual Status Visit(std::shared_ptr node, bool *modified); - - virtual Status VisitAfter(std::shared_ptr node, bool *modified); -#endif - virtual Status Visit(std::shared_ptr node, bool *modified); - virtual Status VisitAfter(std::shared_ptr node, bool *modified); - virtual Status Visit(std::shared_ptr node, bool *modified); - virtual Status VisitAfter(std::shared_ptr node, bool *modified); - + virtual Status Visit(std::shared_ptr node, bool *modified); + virtual Status VisitAfter(std::shared_ptr node, bool *modified); virtual Status Visit(std::shared_ptr node, bool *modified); - virtual Status VisitAfter(std::shared_ptr node, bool *modified); - virtual Status Visit(std::shared_ptr node, bool *modified); - virtual Status VisitAfter(std::shared_ptr node, bool *modified); - virtual Status Visit(std::shared_ptr node, bool *modified); - virtual Status VisitAfter(std::shared_ptr node, bool *modified); - virtual Status Visit(std::shared_ptr node, bool *modified); - virtual Status VisitAfter(std::shared_ptr node, bool *modified); - + virtual Status Visit(std::shared_ptr node, bool *modified); + virtual Status VisitAfter(std::shared_ptr node, bool *modified); virtual Status Visit(std::shared_ptr node, bool *modified); - virtual Status VisitAfter(std::shared_ptr node, bool *modified); - virtual Status Visit(std::shared_ptr node, bool *modified); - virtual Status VisitAfter(std::shared_ptr node, bool *modified); - -#ifdef ENABLE_PYTHON - virtual Status Visit(std::shared_ptr node, bool *modified); - - virtual Status VisitAfter(std::shared_ptr node, bool *modified); -#endif - virtual Status Visit(std::shared_ptr node, bool *modified); - virtual Status VisitAfter(std::shared_ptr node, bool *modified); - virtual Status Visit(std::shared_ptr node, bool *modified); - virtual Status VisitAfter(std::shared_ptr node, bool *modified); - virtual Status Visit(std::shared_ptr node, bool *modified); - virtual Status VisitAfter(std::shared_ptr node, bool *modified); - - // For datasetops/source IR - virtual Status Visit(std::shared_ptr node, bool *modified); - - virtual Status VisitAfter(std::shared_ptr node, bool *modified); - - virtual Status Visit(std::shared_ptr node, bool *modified); - - virtual Status VisitAfter(std::shared_ptr node, bool *modified); - - virtual Status Visit(std::shared_ptr node, bool *modified); - - virtual Status VisitAfter(std::shared_ptr node, bool *modified); - - virtual Status Visit(std::shared_ptr node, bool *modified); - - virtual Status VisitAfter(std::shared_ptr node, bool *modified); - -#ifndef ENABLE_ANDROID - virtual Status Visit(std::shared_ptr node, bool *modified); - - virtual Status VisitAfter(std::shared_ptr node, bool *modified); -#endif - - virtual Status Visit(std::shared_ptr node, bool *modified); - - virtual Status VisitAfter(std::shared_ptr node, bool *modified); - -#ifndef ENABLE_ANDROID - virtual Status Visit(std::shared_ptr node, bool *modified); - - virtual Status VisitAfter(std::shared_ptr node, bool *modified); -#endif - #ifdef ENABLE_PYTHON - virtual Status Visit(std::shared_ptr node, bool *modified); - - virtual Status VisitAfter(std::shared_ptr node, bool *modified); -#endif - - virtual Status Visit(std::shared_ptr node, bool *modified); - - virtual Status VisitAfter(std::shared_ptr node, bool *modified); - - virtual Status Visit(std::shared_ptr node, bool *modified); - - virtual Status VisitAfter(std::shared_ptr node, bool *modified); - -#ifndef ENABLE_ANDROID - virtual Status Visit(std::shared_ptr node, bool *modified); - - virtual Status VisitAfter(std::shared_ptr node, bool *modified); -#endif - - virtual Status Visit(std::shared_ptr node, bool *modified); - - virtual Status VisitAfter(std::shared_ptr node, bool *modified); - - virtual Status Visit(std::shared_ptr node, bool *modified); - - virtual Status VisitAfter(std::shared_ptr node, bool *modified); - -#ifndef ENABLE_ANDROID - virtual Status Visit(std::shared_ptr node, bool *modified); - - virtual Status VisitAfter(std::shared_ptr node, bool *modified); + virtual Status Visit(std::shared_ptr node, bool *modified); + virtual Status VisitAfter(std::shared_ptr node, bool *modified); #endif - #ifndef ENABLE_ANDROID - virtual Status Visit(std::shared_ptr node, bool *modified); - - virtual Status VisitAfter(std::shared_ptr node, bool *modified); + virtual Status Visit(std::shared_ptr node, bool *modified); + virtual Status VisitAfter(std::shared_ptr node, bool *modified); #endif - - virtual Status Visit(std::shared_ptr node, bool *modified); - - virtual Status VisitAfter(std::shared_ptr node, bool *modified); + // Leaf IR node + virtual Status Visit(std::shared_ptr node, bool *modified); + virtual Status VisitAfter(std::shared_ptr node, bool *modified); ////////////////////////////////// // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. @@ -396,86 +280,47 @@ class NodePass : public Pass { // Note that member template can not be virtual, any op which wants to work with NodePass should declare RunOnNode // of its own type and override "Accept" from DatasetOp. virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); - virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); - virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); - virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); - virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); - virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); - #ifndef ENABLE_ANDROID virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); - virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); - virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); #endif - #ifdef ENABLE_PYTHON virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); #endif ////////////////////////////////// diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc index b3dbee4910..6e08a44dc4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc @@ -18,6 +18,7 @@ #include "minddata/dataset/core/client.h" #include "minddata/dataset/include/datasets.h" +#include "minddata/dataset/engine/ir/datasetops/root_node.h" #include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/engine/opt/pre/input_validation_pass.h" @@ -119,11 +120,16 @@ Status TreeAdapter::BuildExecutionTree(std::shared_ptr ir, std::sha return Status::OK(); } -Status TreeAdapter::Compile(std::shared_ptr root_ir, int32_t num_epochs) { - num_epochs_ = num_epochs; +Status TreeAdapter::Compile(std::shared_ptr input_ir, int32_t num_epochs) { optimize_ = true; // Always ON (temporary) - RETURN_UNEXPECTED_IF_NULL(root_ir); + RETURN_UNEXPECTED_IF_NULL(input_ir); + MS_LOG(INFO) << "Input plan:" << '\n' << *input_ir << '\n'; + + // Copy the input IR tree and insert under the root node + // Create a root node to host the input IR tree + auto root_ir = std::make_shared(input_ir->DeepCopy(), num_epochs); + MS_LOG(INFO) << "Plan before PrePass:" << '\n' << *root_ir << '\n'; // Pre-pass of the IR tree RETURN_IF_NOT_OK(PrePass(root_ir)); @@ -136,11 +142,15 @@ Status TreeAdapter::Compile(std::shared_ptr root_ir, int32_t num_ep // Post-pass of the IR tree RETURN_IF_NOT_OK(PostPass(root_ir)); + MS_LOG(INFO) << "Plan after PostPass:" << '\n' << *root_ir << '\n'; + // This will evolve in the long run tree_ = std::make_unique(); + // Build the Execution tree from the child of the root node std::shared_ptr root_op; - RETURN_IF_NOT_OK(BuildExecutionTree(root_ir, &root_op)); + // We will replace input_ir with root_ir->Children()[0] once IR optimizer is in + RETURN_IF_NOT_OK(BuildExecutionTree(input_ir, &root_op)); RETURN_IF_NOT_OK(tree_->AssignRoot(root_op)); if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_); diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h index d4ef3e85e3..6e95ae0a60 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h +++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h @@ -67,10 +67,6 @@ class TreeAdapter { // Optional optimizations status bool OptimizationEnabled() const { return optimize_; } - // Getter function to get the total number of epochs to be run on this tree. - // @return total number of epochs - int32_t num_epochs() { return num_epochs_; } - private: // This function runs a mandatory pass checking the syntax and semantics of the IR tree. Status PrePass(std::shared_ptr ir); diff --git a/mindspore/ccsrc/minddata/dataset/include/samplers.h b/mindspore/ccsrc/minddata/dataset/include/samplers.h index 6485c9534b..e3150fa751 100644 --- a/mindspore/ccsrc/minddata/dataset/include/samplers.h +++ b/mindspore/ccsrc/minddata/dataset/include/samplers.h @@ -47,6 +47,10 @@ class SamplerObj : public std::enable_shared_from_this { /// \return Shared pointers to the newly created Sampler virtual std::shared_ptr Build() = 0; + /// \brief Pure virtual function to copy a SamplerObj class + /// \return Shared pointers to the newly copied SamplerObj + virtual std::shared_ptr Copy() = 0; + /// \brief Function for derived class to get the shard id of sampler /// \return The shard id of the derived sampler virtual int64_t ShardId() { return 0; } @@ -132,6 +136,11 @@ class DistributedSamplerObj : public SamplerObj { std::shared_ptr Build() override; + std::shared_ptr Copy() override { + return std::make_shared(num_shards_, shard_id_, shuffle_, num_samples_, seed_, offset_, + even_dist_); + } + #ifndef ENABLE_ANDROID std::shared_ptr BuildForMindDataset() override; #endif @@ -160,6 +169,10 @@ class PKSamplerObj : public SamplerObj { std::shared_ptr Build() override; + std::shared_ptr Copy() override { + return std::make_shared(num_val_, shuffle_, num_samples_); + } + #ifndef ENABLE_ANDROID std::shared_ptr BuildForMindDataset() override; #endif @@ -174,9 +187,8 @@ class PKSamplerObj : public SamplerObj { class PreBuiltSamplerObj : public SamplerObj { public: -#ifndef ENABLE_ANDROID explicit PreBuiltSamplerObj(std::shared_ptr sampler); - +#ifndef ENABLE_ANDROID explicit PreBuiltSamplerObj(std::shared_ptr sampler); #endif @@ -188,6 +200,8 @@ class PreBuiltSamplerObj : public SamplerObj { std::shared_ptr BuildForMindDataset() override; #endif + std::shared_ptr Copy() override; + bool ValidateParams() override; private: @@ -205,6 +219,8 @@ class RandomSamplerObj : public SamplerObj { std::shared_ptr Build() override; + std::shared_ptr Copy() override { return std::make_shared(replacement_, num_samples_); } + #ifndef ENABLE_ANDROID std::shared_ptr BuildForMindDataset() override; #endif @@ -224,6 +240,10 @@ class SequentialSamplerObj : public SamplerObj { std::shared_ptr Build() override; + std::shared_ptr Copy() override { + return std::make_shared(start_index_, num_samples_); + } + #ifndef ENABLE_ANDROID std::shared_ptr BuildForMindDataset() override; #endif @@ -243,6 +263,10 @@ class SubsetRandomSamplerObj : public SamplerObj { std::shared_ptr Build() override; + std::shared_ptr Copy() override { + return std::make_shared(indices_, num_samples_); + } + #ifndef ENABLE_ANDROID std::shared_ptr BuildForMindDataset() override; #endif @@ -262,6 +286,10 @@ class WeightedRandomSamplerObj : public SamplerObj { std::shared_ptr Build() override; + std::shared_ptr Copy() override { + return std::make_shared(weights_, num_samples_, replacement_); + } + bool ValidateParams() override; private: diff --git a/mindspore/ccsrc/minddata/dataset/include/transforms.h b/mindspore/ccsrc/minddata/dataset/include/transforms.h index 1a38b7d869..3c9f624bb6 100644 --- a/mindspore/ccsrc/minddata/dataset/include/transforms.h +++ b/mindspore/ccsrc/minddata/dataset/include/transforms.h @@ -32,7 +32,10 @@ class TensorOp; class TensorOperation : public std::enable_shared_from_this { public: /// \brief Constructor - TensorOperation(); + TensorOperation() : random_op_(false) {} + + /// \brief Constructor + explicit TensorOperation(bool random) : random_op_(random) {} /// \brief Destructor ~TensorOperation() = default; @@ -42,6 +45,13 @@ class TensorOperation : public std::enable_shared_from_this { virtual std::shared_ptr Build() = 0; virtual Status ValidateParams() = 0; + + /// \brief Check whether the operation is deterministic. + /// \return true if this op is a random op (returns non-deterministic result e.g. RandomCrop) + bool IsRandomOp() const { return random_op_; } + + protected: + bool random_op_; }; // Helper function to validate fill value diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc index adc9d10d52..e9250ea219 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc @@ -427,7 +427,7 @@ Status PadEndNumeric(const std::shared_ptr &src, std::shared_ptr Status PadEndNumericHelper(const std::shared_ptr &src, std::shared_ptr dst, std::vector cur_ind, size_t cur_dim) { if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data - dst->CopyLastDimAt(src, cur_ind); + RETURN_IF_NOT_OK(dst->CopyLastDimAt(src, cur_ind)); } else { // not the last dimension, keep doing recursion dsize_t min_ind = std::min(dst->shape()[cur_dim], src->shape()[cur_dim]); for (dsize_t i = 0; i < min_ind; i++) { diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.h index 3dfb3f713d..20f3395697 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.h @@ -57,7 +57,7 @@ class RandomCropOp : public TensorOp { Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; // Function breaks out the compute function's image padding functionality and makes available to other Ops - // Using this class as a base - restructrued to allow for RandomCropWithBBox Augmentation Op + // Using this class as a base - re-structured to allow for RandomCropWithBBox Augmentation Op // @param input: Input is the original Image // @param pad_image: Pointer to new Padded image // @param t_pad_top: Total Top Padding - Based on input and value calculated in function if required diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index bebc632203..9107e95296 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -570,7 +570,7 @@ class WeightedRandomSampler(BuiltinSampler): Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities). Args: - weights (list[float]): A sequence of weights, not necessarily summing up to 1. + weights (list[float, int]): A sequence of weights, not necessarily summing up to 1. num_samples (int, optional): Number of elements to sample (default=None, all elements). replacement (bool): If True, put the sample ID back for the next draw (default=True). diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index 896c284896..4a6ff81abc 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -17,6 +17,7 @@ SET(DE_UT_SRCS c_api_dataset_coco_test.cc c_api_dataset_config_test.cc c_api_dataset_csv_test.cc + c_api_dataset_ir_node_test.cc c_api_dataset_iterator_test.cc c_api_dataset_manifest_test.cc c_api_dataset_minddata_test.cc diff --git a/tests/ut/cpp/dataset/c_api_dataset_ir_node_test.cc b/tests/ut/cpp/dataset/c_api_dataset_ir_node_test.cc new file mode 100644 index 0000000000..fba70d7cee --- /dev/null +++ b/tests/ut/cpp/dataset/c_api_dataset_ir_node_test.cc @@ -0,0 +1,142 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "minddata/dataset/core/client.h" +#include "common/common.h" +#include "gtest/gtest.h" + +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" +#include "minddata/dataset/engine/opt/pre/getter_pass.h" + +using namespace mindspore::dataset; +using mindspore::LogStream; +using mindspore::MsLogLevel::INFO; + +class MindDataTestIRNodes : public UT::DatasetOpTesting { + public: + MindDataTestIRNodes() = default; + void SetUp() override { GlobalInit(); } + + // compare the ptr of the nodes in two trees, used to test the deep copy of nodes, will return error code + // if (ptr1 == ptr2) does not equal to flag or the two tree has different structures (or node names are not the same) + Status CompareTwoTrees(std::shared_ptr root1, std::shared_ptr root2, bool flag) { + CHECK_FAIL_RETURN_UNEXPECTED(root1 != nullptr && root2 != nullptr, "Error in Compare, nullptr."); + if (((root1.get() == root2.get()) != flag) || (root1->Name() != root2->Name())) { + std::string err_msg = + "Expect node ptr " + root1->Name() + (flag ? "==" : "!=") + root2->Name() + " but they aren't!"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + size_t num_child = root1->Children().size(); + + CHECK_FAIL_RETURN_UNEXPECTED(num_child == root2->Children().size(), + root1->Name() + " has " + std::to_string(num_child) + "child, node #2 has " + + std::to_string(root2->Children().size()) + " child."); + + for (size_t ind = 0; ind < num_child; ind++) { + RETURN_IF_NOT_OK(CompareTwoTrees(root1->Children()[ind], root2->Children()[ind], flag)); + } + return Status::OK(); + } + + // print the node's name in post order + Status PostOrderPrintTree(std::shared_ptr ir, std::string &names) { + RETURN_UNEXPECTED_IF_NULL(ir); + for (auto child : ir->Children()) { + RETURN_IF_NOT_OK(PostOrderPrintTree(child, names)); + } + names += (ir->Name() + "->"); + return Status::OK(); + } +}; + +TEST_F(MindDataTestIRNodes, MindDataTestSimpleDeepCopy) { + MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestSimpleDeepCopy."; + + auto tree1 = RandomData(44)->Repeat(2)->Project({"label"})->Shuffle(10)->Batch(2)->IRNode(); + + auto tree2 = tree1->DeepCopy(); + std::string tree_1_names, tree_2_names; + + ASSERT_OK(PostOrderPrintTree(tree1, tree_1_names)); + ASSERT_OK(PostOrderPrintTree(tree2, tree_2_names)); + + // expected output for the 2 names: + // RandomDataset->Repeat->Project->Shuffle->Batch-> + EXPECT_EQ(tree_1_names, tree_2_names); + + ASSERT_OK(CompareTwoTrees(tree1, tree1, true)); + ASSERT_OK(CompareTwoTrees(tree1, tree2, false)); + + // verify compare function is correct + EXPECT_TRUE(CompareTwoTrees(tree2, tree2, false).IsError()); +} + +TEST_F(MindDataTestIRNodes, MindDataTestZipDeepCopy) { + MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestZipDeepCopy."; + + auto branch1 = RandomData(44)->Project({"label"}); + auto branch2 = RandomData(44)->Shuffle(10); + + auto tree1 = Zip({branch1, branch2})->Batch(2)->IRNode(); + + auto tree2 = tree1->DeepCopy(); + std::string tree_1_names, tree_2_names; + + ASSERT_OK(PostOrderPrintTree(tree1, tree_1_names)); + ASSERT_OK(PostOrderPrintTree(tree2, tree_2_names)); + + // expected output for the 2 names: + // RandomDataset->Project->RandomDataset->Shuffle->Zip->Batch-> + EXPECT_EQ(tree_1_names, tree_2_names); + + // verify the pointer within the same tree are the same + ASSERT_OK(CompareTwoTrees(tree1, tree1, true)); + // verify two trees + ASSERT_OK(CompareTwoTrees(tree1, tree2, false)); +} + +TEST_F(MindDataTestIRNodes, MindDataTestNodeRemove) { + MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestNodeRemove."; + + auto branch1 = RandomData(44)->Project({"label"}); + auto branch2 = ImageFolder("path"); + auto tree = Zip({branch1, branch2})->IRNode(); + /*** + tree looks like this, we will remove node and test its functionalities + Zip + / \ + Project ImageFolder + / + RandomData + ***/ + auto tree_copy_1 = tree->DeepCopy(); + ASSERT_EQ(tree_copy_1->Children().size(), 2); + // remove the project in the tree and test + ASSERT_OK(tree_copy_1->Children()[0]->Remove()); // remove Project from tree + ASSERT_OK(CompareTwoTrees(tree_copy_1, Zip({RandomData(44), ImageFolder("path")})->IRNode(), false)); + // remove the ImageFolder, a leaf node from the tree + std::string tree_1_names, tree_2_names; + ASSERT_OK(PostOrderPrintTree(tree_copy_1, tree_1_names)); + EXPECT_EQ(tree_1_names, "RandomDataset->ImageFolderDataset->Zip->"); + auto tree_copy_2 = tree->DeepCopy(); + ASSERT_EQ(tree_copy_2->Children().size(), 2); + tree_copy_2->Children()[1]->Remove(); + ASSERT_OK(PostOrderPrintTree(tree_copy_2, tree_2_names)); + EXPECT_EQ(tree_2_names, "RandomDataset->Project->Zip->"); +}