From: @nsyca Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -568,8 +568,8 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab( | |||
| const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage, | |||
| SentencePieceModel model_type, const std::unordered_map<std::string, std::string> ¶ms) { | |||
| auto vocab = std::make_shared<SentencePieceVocab>(); | |||
| auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, col_names, vocab_size, character_coverage, | |||
| model_type, params); | |||
| auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode()->DeepCopy(), vocab, col_names, vocab_size, | |||
| character_coverage, model_type, params); | |||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | |||
| Status rc = runtime_context->Init(); | |||
| @@ -600,8 +600,8 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum | |||
| const std::pair<int64_t, int64_t> &freq_range, int64_t top_k, | |||
| const std::vector<std::string> &special_tokens, bool special_first) { | |||
| auto vocab = std::make_shared<Vocab>(); | |||
| auto ds = | |||
| std::make_shared<BuildVocabNode>(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first); | |||
| auto ds = std::make_shared<BuildVocabNode>(IRNode()->DeepCopy(), vocab, columns, freq_range, top_k, special_tokens, | |||
| special_first); | |||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | |||
| Status rc = runtime_context->Init(); | |||
| @@ -190,13 +190,12 @@ std::shared_ptr<SamplerRT> PKSamplerObj::Build() { | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| // PreBuiltOperation | |||
| PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) | |||
| : sp_(std::move(sampler)), sp_minddataset_(nullptr) {} | |||
| PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) : sp_(std::move(sampler)) {} | |||
| #ifndef ENABLE_ANDROID | |||
| PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler) | |||
| : sp_(nullptr), sp_minddataset_(std::move(sampler)) {} | |||
| : sp_minddataset_(std::move(sampler)) {} | |||
| #endif | |||
| bool PreBuiltSamplerObj::ValidateParams() { return true; } | |||
| @@ -207,6 +206,13 @@ std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() { return sp_; } | |||
| std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; } | |||
| #endif | |||
| std::shared_ptr<SamplerObj> PreBuiltSamplerObj::Copy() { | |||
| #ifndef ENABLE_ANDROID | |||
| if (sp_minddataset_ != nullptr) return std::make_shared<PreBuiltSamplerObj>(sp_minddataset_); | |||
| #endif | |||
| return std::make_shared<PreBuiltSamplerObj>(sp_); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() { | |||
| // runtime mindrecord sampler object | |||
| @@ -30,8 +30,6 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| TensorOperation::TensorOperation() {} | |||
| /* ####################################### Validator Functions ############################################ */ | |||
| Status ValidateVectorFillvalue(const std::string &transform_name, const std::vector<uint8_t> &fill_value) { | |||
| if (fill_value.empty() || (fill_value.size() != 1 && fill_value.size() != 3)) { | |||
| @@ -231,7 +229,7 @@ std::shared_ptr<TensorOp> PreBuiltOperation::Build() { return op_; } | |||
| // RandomApplyOperation | |||
| RandomApplyOperation::RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob) | |||
| : transforms_(transforms), prob_(prob) {} | |||
| : TensorOperation(true), transforms_(transforms), prob_(prob) {} | |||
| Status RandomApplyOperation::ValidateParams() { | |||
| RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomApply", transforms_)); | |||
| @@ -248,7 +246,7 @@ std::shared_ptr<TensorOp> RandomApplyOperation::Build() { | |||
| // RandomChoiceOperation | |||
| RandomChoiceOperation::RandomChoiceOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms) | |||
| : transforms_(transforms) {} | |||
| : TensorOperation(true), transforms_(transforms) {} | |||
| Status RandomChoiceOperation::ValidateParams() { | |||
| RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomChoice", transforms_)); | |||
| @@ -734,7 +734,9 @@ RandomAffineOperation::RandomAffineOperation(const std::vector<float_t> °rees | |||
| scale_range_(scale_range), | |||
| shear_ranges_(shear_ranges), | |||
| interpolation_(interpolation), | |||
| fill_value_(fill_value) {} | |||
| fill_value_(fill_value) { | |||
| random_op_ = true; | |||
| } | |||
| Status RandomAffineOperation::ValidateParams() { | |||
| // Degrees | |||
| @@ -867,7 +869,7 @@ std::shared_ptr<TensorOp> RandomAffineOperation::Build() { | |||
| } | |||
| // RandomColorOperation. | |||
| RandomColorOperation::RandomColorOperation(float t_lb, float t_ub) : t_lb_(t_lb), t_ub_(t_ub) {} | |||
| RandomColorOperation::RandomColorOperation(float t_lb, float t_ub) : t_lb_(t_lb), t_ub_(t_ub) { random_op_ = true; } | |||
| Status RandomColorOperation::ValidateParams() { | |||
| // Do some input validation. | |||
| @@ -891,7 +893,9 @@ Status RandomColorOperation::ValidateParams() { | |||
| // RandomColorAdjustOperation. | |||
| RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector<float> brightness, std::vector<float> contrast, | |||
| std::vector<float> saturation, std::vector<float> hue) | |||
| : brightness_(brightness), contrast_(contrast), saturation_(saturation), hue_(hue) {} | |||
| : brightness_(brightness), contrast_(contrast), saturation_(saturation), hue_(hue) { | |||
| random_op_ = true; | |||
| } | |||
| Status RandomColorAdjustOperation::ValidateParams() { | |||
| // brightness | |||
| @@ -1012,11 +1016,14 @@ std::shared_ptr<TensorOp> RandomColorAdjustOperation::Build() { | |||
| // RandomCropOperation | |||
| RandomCropOperation::RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed, | |||
| std::vector<uint8_t> fill_value, BorderType padding_mode) | |||
| : size_(size), | |||
| : TensorOperation(true), | |||
| size_(size), | |||
| padding_(padding), | |||
| pad_if_needed_(pad_if_needed), | |||
| fill_value_(fill_value), | |||
| padding_mode_(padding_mode) {} | |||
| padding_mode_(padding_mode) { | |||
| random_op_ = true; | |||
| } | |||
| Status RandomCropOperation::ValidateParams() { | |||
| // size | |||
| @@ -1083,7 +1090,12 @@ std::shared_ptr<TensorOp> RandomCropOperation::Build() { | |||
| RandomCropDecodeResizeOperation::RandomCropDecodeResizeOperation(std::vector<int32_t> size, std::vector<float> scale, | |||
| std::vector<float> ratio, | |||
| InterpolationMode interpolation, int32_t max_attempts) | |||
| : size_(size), scale_(scale), ratio_(ratio), interpolation_(interpolation), max_attempts_(max_attempts) {} | |||
| : TensorOperation(true), | |||
| size_(size), | |||
| scale_(scale), | |||
| ratio_(ratio), | |||
| interpolation_(interpolation), | |||
| max_attempts_(max_attempts) {} | |||
| Status RandomCropDecodeResizeOperation::ValidateParams() { | |||
| // size | |||
| @@ -1176,7 +1188,8 @@ std::shared_ptr<TensorOp> RandomCropDecodeResizeOperation::Build() { | |||
| RandomCropWithBBoxOperation::RandomCropWithBBoxOperation(std::vector<int32_t> size, std::vector<int32_t> padding, | |||
| bool pad_if_needed, std::vector<uint8_t> fill_value, | |||
| BorderType padding_mode) | |||
| : size_(size), | |||
| : TensorOperation(true), | |||
| size_(size), | |||
| padding_(padding), | |||
| pad_if_needed_(pad_if_needed), | |||
| fill_value_(fill_value), | |||
| @@ -1245,7 +1258,8 @@ std::shared_ptr<TensorOp> RandomCropWithBBoxOperation::Build() { | |||
| } | |||
| // RandomHorizontalFlipOperation | |||
| RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability) : probability_(probability) {} | |||
| RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability) | |||
| : TensorOperation(true), probability_(probability) {} | |||
| Status RandomHorizontalFlipOperation::ValidateParams() { | |||
| RETURN_IF_NOT_OK(ValidateProbability("RandomHorizontalFlip", probability_)); | |||
| @@ -1260,7 +1274,7 @@ std::shared_ptr<TensorOp> RandomHorizontalFlipOperation::Build() { | |||
| // RandomHorizontalFlipWithBBoxOperation | |||
| RandomHorizontalFlipWithBBoxOperation::RandomHorizontalFlipWithBBoxOperation(float probability) | |||
| : probability_(probability) {} | |||
| : TensorOperation(true), probability_(probability) {} | |||
| Status RandomHorizontalFlipWithBBoxOperation::ValidateParams() { | |||
| RETURN_IF_NOT_OK(ValidateProbability("RandomHorizontalFlipWithBBox", probability_)); | |||
| @@ -1275,7 +1289,8 @@ std::shared_ptr<TensorOp> RandomHorizontalFlipWithBBoxOperation::Build() { | |||
| } | |||
| // RandomPosterizeOperation | |||
| RandomPosterizeOperation::RandomPosterizeOperation(const std::vector<uint8_t> &bit_range) : bit_range_(bit_range) {} | |||
| RandomPosterizeOperation::RandomPosterizeOperation(const std::vector<uint8_t> &bit_range) | |||
| : TensorOperation(true), bit_range_(bit_range) {} | |||
| Status RandomPosterizeOperation::ValidateParams() { | |||
| if (bit_range_.size() != 2) { | |||
| @@ -1309,7 +1324,7 @@ std::shared_ptr<TensorOp> RandomPosterizeOperation::Build() { | |||
| } | |||
| // RandomResizeOperation | |||
| RandomResizeOperation::RandomResizeOperation(std::vector<int32_t> size) : size_(size) {} | |||
| RandomResizeOperation::RandomResizeOperation(std::vector<int32_t> size) : TensorOperation(true), size_(size) {} | |||
| Status RandomResizeOperation::ValidateParams() { | |||
| // size | |||
| @@ -1343,7 +1358,8 @@ std::shared_ptr<TensorOp> RandomResizeOperation::Build() { | |||
| } | |||
| // RandomResizeWithBBoxOperation | |||
| RandomResizeWithBBoxOperation::RandomResizeWithBBoxOperation(std::vector<int32_t> size) : size_(size) {} | |||
| RandomResizeWithBBoxOperation::RandomResizeWithBBoxOperation(std::vector<int32_t> size) | |||
| : TensorOperation(true), size_(size) {} | |||
| Status RandomResizeWithBBoxOperation::ValidateParams() { | |||
| // size | |||
| @@ -1380,7 +1396,12 @@ std::shared_ptr<TensorOp> RandomResizeWithBBoxOperation::Build() { | |||
| RandomResizedCropOperation::RandomResizedCropOperation(std::vector<int32_t> size, std::vector<float> scale, | |||
| std::vector<float> ratio, InterpolationMode interpolation, | |||
| int32_t max_attempts) | |||
| : size_(size), scale_(scale), ratio_(ratio), interpolation_(interpolation), max_attempts_(max_attempts) {} | |||
| : TensorOperation(true), | |||
| size_(size), | |||
| scale_(scale), | |||
| ratio_(ratio), | |||
| interpolation_(interpolation), | |||
| max_attempts_(max_attempts) {} | |||
| Status RandomResizedCropOperation::ValidateParams() { | |||
| // size | |||
| @@ -1536,7 +1557,8 @@ std::shared_ptr<TensorOp> RandomResizedCropWithBBoxOperation::Build() { | |||
| RandomRotationOperation::RandomRotationOperation(std::vector<float> degrees, InterpolationMode interpolation_mode, | |||
| bool expand, std::vector<float> center, | |||
| std::vector<uint8_t> fill_value) | |||
| : degrees_(degrees), | |||
| : TensorOperation(true), | |||
| degrees_(degrees), | |||
| interpolation_mode_(interpolation_mode), | |||
| expand_(expand), | |||
| center_(center), | |||
| @@ -1603,7 +1625,7 @@ std::shared_ptr<TensorOp> RandomRotationOperation::Build() { | |||
| // RandomSelectSubpolicyOperation. | |||
| RandomSelectSubpolicyOperation::RandomSelectSubpolicyOperation( | |||
| std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> policy) | |||
| : policy_(policy) {} | |||
| : TensorOperation(true), policy_(policy) {} | |||
| Status RandomSelectSubpolicyOperation::ValidateParams() { | |||
| if (policy_.empty()) { | |||
| @@ -1650,7 +1672,8 @@ std::shared_ptr<TensorOp> RandomSelectSubpolicyOperation::Build() { | |||
| } | |||
| // Function to create RandomSharpness. | |||
| RandomSharpnessOperation::RandomSharpnessOperation(std::vector<float> degrees) : degrees_(degrees) {} | |||
| RandomSharpnessOperation::RandomSharpnessOperation(std::vector<float> degrees) | |||
| : TensorOperation(true), degrees_(degrees) {} | |||
| Status RandomSharpnessOperation::ValidateParams() { | |||
| if (degrees_.size() != 2 || degrees_[0] < 0 || degrees_[1] < 0) { | |||
| @@ -1674,7 +1697,8 @@ std::shared_ptr<TensorOp> RandomSharpnessOperation::Build() { | |||
| } | |||
| // RandomSolarizeOperation. | |||
| RandomSolarizeOperation::RandomSolarizeOperation(std::vector<uint8_t> threshold) : threshold_(threshold) {} | |||
| RandomSolarizeOperation::RandomSolarizeOperation(std::vector<uint8_t> threshold) | |||
| : TensorOperation(true), threshold_(threshold) {} | |||
| Status RandomSolarizeOperation::ValidateParams() { | |||
| if (threshold_.size() != 2) { | |||
| @@ -1705,7 +1729,8 @@ std::shared_ptr<TensorOp> RandomSolarizeOperation::Build() { | |||
| } | |||
| // RandomVerticalFlipOperation | |||
| RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability) : probability_(probability) {} | |||
| RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability) | |||
| : TensorOperation(true), probability_(probability) {} | |||
| Status RandomVerticalFlipOperation::ValidateParams() { | |||
| RETURN_IF_NOT_OK(ValidateProbability("RandomVerticalFlip", probability_)); | |||
| @@ -1720,7 +1745,7 @@ std::shared_ptr<TensorOp> RandomVerticalFlipOperation::Build() { | |||
| // RandomVerticalFlipWithBBoxOperation | |||
| RandomVerticalFlipWithBBoxOperation::RandomVerticalFlipWithBBoxOperation(float probability) | |||
| : probability_(probability) {} | |||
| : TensorOperation(true), probability_(probability) {} | |||
| Status RandomVerticalFlipWithBBoxOperation::ValidateParams() { | |||
| RETURN_IF_NOT_OK(ValidateProbability("RandomVerticalFlipWithBBox", probability_)); | |||
| @@ -9,11 +9,13 @@ set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES | |||
| build_sentence_piece_vocab_node.cc | |||
| build_vocab_node.cc | |||
| concat_node.cc | |||
| epoch_ctrl_node.cc | |||
| filter_node.cc | |||
| map_node.cc | |||
| project_node.cc | |||
| rename_node.cc | |||
| repeat_node.cc | |||
| root_node.cc | |||
| shuffle_node.cc | |||
| skip_node.cc | |||
| sync_wait_node.cc | |||
| @@ -43,14 +43,29 @@ BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, boo | |||
| batch_size_func_(batch_size_func), | |||
| batch_map_func_(batch_map_func), | |||
| pad_map_(pad_map) { | |||
| this->children.push_back(child); | |||
| this->AddChild(child); | |||
| } | |||
| #endif | |||
| // constructor #2, called by C++ API | |||
| BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder) | |||
| : batch_size_(batch_size), drop_remainder_(drop_remainder), pad_(false) { | |||
| this->children.push_back(child); | |||
| this->AddChild(child); | |||
| } | |||
| std::shared_ptr<DatasetNode> BatchNode::Copy() { | |||
| #ifdef ENABLE_PYTHON | |||
| auto node = std::make_shared<BatchNode>(nullptr, batch_size_, drop_remainder_, pad_, in_col_names_, out_col_names_, | |||
| col_order_, batch_size_func_, batch_map_func_, pad_map_); | |||
| #else | |||
| auto node = std::make_shared<BatchNode>(nullptr, batch_size_, drop_remainder_); | |||
| #endif | |||
| return node; | |||
| } | |||
| void BatchNode::Print(std::ostream &out) const { | |||
| out << Name() + "(batch_size:" + std::to_string(batch_size_) + | |||
| " drop_remainder:" + (drop_remainder_ ? "true" : "false") + ")"; | |||
| } | |||
| Status BatchNode::ValidateParams() { | |||
| @@ -44,6 +44,18 @@ class BatchNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~BatchNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kBatchNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -41,7 +41,17 @@ BucketBatchByLengthNode::BucketBatchByLengthNode( | |||
| pad_info_(pad_info), | |||
| pad_to_bucket_boundary_(pad_to_bucket_boundary), | |||
| drop_remainder_(drop_remainder) { | |||
| this->children.push_back(child); | |||
| this->AddChild(child); | |||
| } | |||
| std::shared_ptr<DatasetNode> BucketBatchByLengthNode::Copy() { | |||
| auto node = std::make_shared<BucketBatchByLengthNode>(nullptr, column_names_, bucket_boundaries_, bucket_batch_sizes_, | |||
| element_length_function_, pad_info_, pad_to_bucket_boundary_); | |||
| return node; | |||
| } | |||
| void BucketBatchByLengthNode::Print(std::ostream &out) const { | |||
| out << Name() + "(columns:" + PrintColumns(column_names_) + ",...)"; | |||
| } | |||
| std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() { | |||
| @@ -40,6 +40,18 @@ class BucketBatchByLengthNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~BucketBatchByLengthNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kBucketBatchByLengthNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -22,6 +22,7 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| @@ -38,7 +39,18 @@ BuildSentenceVocabNode::BuildSentenceVocabNode(std::shared_ptr<DatasetNode> chil | |||
| character_coverage_(character_coverage), | |||
| model_type_(model_type), | |||
| params_(params) { | |||
| this->children.push_back(child); | |||
| this->AddChild(child); | |||
| } | |||
| std::shared_ptr<DatasetNode> BuildSentenceVocabNode::Copy() { | |||
| auto node = std::make_shared<BuildSentenceVocabNode>(nullptr, vocab_, col_names_, vocab_size_, character_coverage_, | |||
| model_type_, params_); | |||
| return node; | |||
| } | |||
| void BuildSentenceVocabNode::Print(std::ostream &out) const { | |||
| out << Name() + "<vocab>," + "columns:" + PrintColumns(col_names_) + ",vocab_size:" + std::to_string(vocab_size_) + | |||
| ",...)"; | |||
| } | |||
| // Function to build BuildSentenceVocabNode | |||
| @@ -81,5 +93,16 @@ Status BuildSentenceVocabNode::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status BuildSentenceVocabNode::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<BuildSentenceVocabNode>(), modified); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status BuildSentenceVocabNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<BuildSentenceVocabNode>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -38,6 +38,18 @@ class BuildSentenceVocabNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~BuildSentenceVocabNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kBuildSentencePieceVocabNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -46,6 +58,18 @@ class BuildSentenceVocabNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||
| private: | |||
| std::shared_ptr<SentencePieceVocab> vocab_; | |||
| std::vector<std::string> col_names_; | |||
| @@ -22,7 +22,7 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/build_vocab_op.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -36,7 +36,17 @@ BuildVocabNode::BuildVocabNode(std::shared_ptr<DatasetNode> child, std::shared_p | |||
| top_k_(top_k), | |||
| special_tokens_(special_tokens), | |||
| special_first_(special_first) { | |||
| this->children.push_back(child); | |||
| this->AddChild(child); | |||
| } | |||
| std::shared_ptr<DatasetNode> BuildVocabNode::Copy() { | |||
| auto node = | |||
| std::make_shared<BuildVocabNode>(nullptr, vocab_, columns_, freq_range_, top_k_, special_tokens_, special_first_); | |||
| return node; | |||
| } | |||
| void BuildVocabNode::Print(std::ostream &out) const { | |||
| out << Name() + "(<vocab>," + "columns:" + PrintColumns(columns_) + ",...)"; | |||
| } | |||
| // Function to build BuildVocabNode | |||
| @@ -78,5 +88,16 @@ Status BuildVocabNode::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status BuildVocabNode::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<BuildVocabNode>(), modified); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status BuildVocabNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<BuildVocabNode>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -37,6 +37,18 @@ class BuildVocabNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~BuildVocabNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kBuildVocabNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -45,6 +57,18 @@ class BuildVocabNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||
| private: | |||
| std::shared_ptr<Vocab> vocab_; | |||
| std::vector<std::string> columns_; | |||
| @@ -22,7 +22,7 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/concat_op.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -35,17 +35,25 @@ ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets | |||
| : sampler_(sampler), | |||
| children_flag_and_nums_(children_flag_and_nums), | |||
| children_start_end_index_(children_start_end_index) { | |||
| this->children = datasets; | |||
| for (auto const &child : datasets) AddChild(child); | |||
| } | |||
| std::shared_ptr<DatasetNode> ConcatNode::Copy() { | |||
| // create an empty vector to copy a concat | |||
| auto node = std::make_shared<ConcatNode>(std::vector<std::shared_ptr<DatasetNode>>()); | |||
| return node; | |||
| } | |||
| void ConcatNode::Print(std::ostream &out) const { out << Name(); } | |||
| Status ConcatNode::ValidateParams() { | |||
| if (children.size() < 2) { | |||
| if (children_.size() < 2) { | |||
| std::string err_msg = "ConcatNode: concatenated datasets are not specified."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (find(children.begin(), children.end(), nullptr) != children.end()) { | |||
| if (find(children_.begin(), children_.end(), nullptr) != children_.end()) { | |||
| std::string err_msg = "ConcatNode: concatenated datasets should not be null."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| @@ -73,5 +81,16 @@ std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() { | |||
| return node_ops; | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status ConcatNode::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<ConcatNode>(), modified); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status ConcatNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<ConcatNode>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -38,6 +38,18 @@ class ConcatNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~ConcatNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kConcatNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -50,6 +62,18 @@ class ConcatNode : public DatasetNode { | |||
| std::shared_ptr<SamplerObj> sampler_; | |||
| std::vector<std::pair<int, int>> children_flag_and_nums_; | |||
| std::vector<std::pair<int, int>> children_start_end_index_; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||
| }; | |||
| } // namespace dataset | |||
| @@ -233,14 +233,92 @@ std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) { | |||
| return shared_from_this(); | |||
| } | |||
| DatasetNode::DatasetNode() { | |||
| DatasetNode::DatasetNode() : cache_(nullptr), parent_(nullptr), children_({}) { | |||
| // Fetch some default value from config manager | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| num_workers_ = cfg->num_parallel_workers(); | |||
| rows_per_buffer_ = cfg->rows_per_buffer(); | |||
| connector_que_size_ = cfg->op_connector_size(); | |||
| worker_connector_size_ = cfg->worker_connector_size(); | |||
| build_status = Status::OK(); // remove me after changing return val of Build() | |||
| } | |||
| // this function will preform a deep copy of current node (and its descendants), the parent* pointer will not be copied | |||
| std::shared_ptr<DatasetNode> DatasetNode::DeepCopy() { | |||
| std::shared_ptr<DatasetNode> new_node = this->Copy(); | |||
| for (const auto &child : children_) { | |||
| new_node->AddChild(child->DeepCopy()); | |||
| } | |||
| return new_node; | |||
| } | |||
| std::string DatasetNode::PrintColumns(const std::vector<std::string> &columns) const { | |||
| std::string me; | |||
| if (columns.empty()) { | |||
| me = "<nil>"; | |||
| } else { | |||
| me = "["; | |||
| auto i = 0; | |||
| for (auto it = columns.begin(); it < columns.end(); ++it, ++i) { | |||
| me += *it; | |||
| if (i < columns.size() - 1) { | |||
| me += ", "; | |||
| } else { | |||
| me += "]"; | |||
| } | |||
| } | |||
| } | |||
| return me; | |||
| } | |||
| void DatasetNode::PrintTree(std::ostream &out) const { | |||
| int level = 0; | |||
| PrintNode(out, &level); | |||
| } | |||
| void DatasetNode::PrintNode(std::ostream &out, int *level) const { | |||
| const std::string prefix = "+-"; | |||
| const std::string indent = " "; | |||
| out << prefix; | |||
| Print(out); | |||
| for (const auto &c : this->Children()) { | |||
| out << '\n'; | |||
| ++(*level); | |||
| for (auto i = 0; i < *level; i++) { | |||
| out << indent; | |||
| } | |||
| c->PrintNode(out, level); | |||
| --(*level); | |||
| } | |||
| } | |||
| // Add a node as a child, node's parent needs to be nullptr | |||
| // this function will allow child to be a nullptr, in which case it will simply skip | |||
| void DatasetNode::AddChild(std::shared_ptr<DatasetNode> child) { | |||
| if (child != nullptr && child->parent_ == nullptr) { | |||
| children_.push_back(child); | |||
| child->parent_ = this; | |||
| } else if (child != nullptr) { | |||
| MS_LOG(WARNING) << "DatasetNode::AddChild() Fail" + child->Name() + "'s parent isn't a nullptr."; | |||
| } | |||
| } | |||
| // Remove this node from its parent. Add the child of this node to its parent. | |||
| // for now, this remove is limited to node with a single child or no child | |||
| Status DatasetNode::Remove() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "Cannot remove root or a node without parent."); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(children_.size() < 2, "Cannot remove node with more than 1 child."); | |||
| if (children_.empty()) { // I am a leaf node, remove me from my parent's children list | |||
| parent_->children_.erase(std::remove(parent_->children_.begin(), parent_->children_.end(), shared_from_this()), | |||
| parent_->children_.end()); // removal using "erase remove idiom" | |||
| } else { // replace my position in my parent's children list with my single child | |||
| auto itr = std::find(parent_->children_.begin(), parent_->children_.end(), shared_from_this()); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(itr != parent_->children_.end(), "I am not in my parent's children list."); | |||
| children_[0]->parent_ = parent_; // set my single child's parent ptr to my parent | |||
| *itr = std::move(children_[0]); // replace me in my parent's children list with my single child | |||
| children_.clear(); // release my single child from my children list | |||
| } | |||
| parent_ = nullptr; | |||
| return Status::OK(); | |||
| } | |||
| // In DFS tree traversal, each node is visited twice. Accept is called on the first visit. | |||
| @@ -255,13 +333,25 @@ Status DatasetNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| // This method will only be called if its derived class does not implement one. | |||
| return p->VisitAfter(shared_from_this(), modified); | |||
| } | |||
| Status DatasetNode::GetShardId(int32_t *shard_id) { | |||
| if (!Children().empty()) { | |||
| // Get shard id from the child node | |||
| return Children()[0]->GetShardId(shard_id); | |||
| } else { | |||
| RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node"); | |||
| RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node: " + Name() + "\n"); | |||
| } | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status SourceNode::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<SourceNode>(), modified); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status SourceNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<SourceNode>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -42,6 +42,45 @@ class NodePass; | |||
| } \ | |||
| } while (false) | |||
| // Names for non-leaf IR node | |||
| constexpr char kBatchNode[] = "Batch"; | |||
| constexpr char kBucketBatchByLengthNode[] = "BucketBatchByLength"; | |||
| constexpr char kBuildSentencePieceVocabNode[] = "BuildSentencePieceVocab"; | |||
| constexpr char kBuildVocabNode[] = "BuildVocab"; | |||
| constexpr char kConcatNode[] = "Concat"; | |||
| constexpr char kDatasetNode[] = "Dataset"; | |||
| constexpr char kEpochCtrlNode[] = "EpochCtrl"; | |||
| constexpr char kFilterNode[] = "Filter"; | |||
| constexpr char kMapNode[] = "Map"; | |||
| constexpr char kProjectNode[] = "Project"; | |||
| constexpr char kRenameNode[] = "Rename"; | |||
| constexpr char kRepeatNode[] = "Repeat"; | |||
| constexpr char kRootNode[] = "Top"; | |||
| constexpr char kShuffleNode[] = "Shuffle"; | |||
| constexpr char kSkipNode[] = "Skip"; | |||
| constexpr char kSyncWaitNode[] = "SyncWait"; | |||
| constexpr char kTakeNode[] = "Take"; | |||
| constexpr char kTransferNode[] = "Transfer"; | |||
| constexpr char kZipNode[] = "Zip"; | |||
| // Names for leaf IR node | |||
| constexpr char kAlbumNode[] = "AlbumDataset"; | |||
| constexpr char kCelebANode[] = "CelebADataset"; | |||
| constexpr char kCifar100Node[] = "Cifar100Dataset"; | |||
| constexpr char kCifar10Node[] = "Cifar10Dataset"; | |||
| constexpr char kCLUENode[] = "CLUEDataset"; | |||
| constexpr char kCocoNode[] = "CocoDataset"; | |||
| constexpr char kCSVNode[] = "CSVDataset"; | |||
| constexpr char kGeneratorNode[] = "GeneratorDataset"; | |||
| constexpr char kImageFolderNode[] = "ImageFolderDataset"; | |||
| constexpr char kManifestNode[] = "ManifestDataset"; | |||
| constexpr char kMindDataNode[] = "MindDataDataset"; | |||
| constexpr char kMnistNode[] = "MnistDataset"; | |||
| constexpr char kRandomNode[] = "RandomDataset"; | |||
| constexpr char kTextFileNode[] = "TextFileDataset"; | |||
| constexpr char kTFRecordNode[] = "TFRecordDataset"; | |||
| constexpr char kVOCNode[] = "VOCDataset"; | |||
| Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, | |||
| int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op); | |||
| @@ -75,6 +114,7 @@ Status ValidateDatasetDirParam(const std::string &dataset_name, std::string data | |||
| /// \return Shared pointer to the current Sampler. | |||
| std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id); | |||
| // The base class of all IR nodes | |||
| class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||
| public: | |||
| /// \brief Constructor | |||
| @@ -87,6 +127,36 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||
| /// \brief Destructor | |||
| ~DatasetNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| virtual std::string Name() const = 0; | |||
| /// \brief Pure virtual function to print the description | |||
| /// \param out - The output stream to write output to | |||
| virtual void Print(std::ostream &out) const = 0; | |||
| /// \brief Pure virtual function to make a new copy of the node | |||
| /// \return The new copy of the node | |||
| virtual std::shared_ptr<DatasetNode> Copy() = 0; | |||
| /// \brief Print the IR tree to output stream | |||
| /// \param out - The output stream to write output to | |||
| void PrintTree(std::ostream &out) const; | |||
| /// \brief << Stream output operator overload | |||
| /// \notes This allows you to write the debug print info using stream operators | |||
| /// \param out - reference to the output stream being overloaded | |||
| /// \param dO - reference to the DatasetOp to display | |||
| /// \return - the output stream must be returned | |||
| friend std::ostream &operator<<(std::ostream &out, const DatasetNode &node) { | |||
| node.PrintTree(out); | |||
| return out; | |||
| } | |||
| /// \brief Make a new copy of the tree from the current node | |||
| /// \return The new copy of the tree | |||
| std::shared_ptr<DatasetNode> DeepCopy(); | |||
| /// \brief Pure virtual function to convert a DatasetNode class into a runtime dataset object | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| virtual std::vector<std::shared_ptr<DatasetOp>> Build() = 0; | |||
| @@ -95,17 +165,38 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| virtual Status ValidateParams() = 0; | |||
| const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children; } | |||
| /// \brief Pure virtual function for derived class to get the shard id of specific node | |||
| /// \return Status Status::OK() if get shard id successfully | |||
| virtual Status GetShardId(int32_t *shard_id); | |||
| /// \brief Getter function for child nodes | |||
| /// \return Child nodes | |||
| const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children_; } | |||
| /// \brief Establish the parent-child relationship between this node and its child. | |||
| void AddChild(std::shared_ptr<DatasetNode> child); | |||
| /// \brief detach this node from its parent, add its child (if any) to its parent | |||
| /// \return error code, return error if node has more than 1 children | |||
| Status Remove(); | |||
| /// \brief Check if this node has cache | |||
| /// \return True if the data of this node will be cached | |||
| const bool IsCached() const { return (cache_ != nullptr); } | |||
| /// \brief Setter function for runtime number of workers | |||
| /// \param[in] num_workers The number of threads in this operator | |||
| /// \return Shared pointer to the original object | |||
| std::shared_ptr<DatasetNode> SetNumWorkers(int32_t num_workers); | |||
| /// \brief A helper templated function for casting "this" pointer to shared_ptr<derived> | |||
| /// Similar to shared_from_this, except this one will give you the derived class as shared_ptr | |||
| /// \return A shared_ptr casted to the derived class | |||
| template <typename Derived> | |||
| std::shared_ptr<Derived> shared_from_base() { | |||
| return std::static_pointer_cast<Derived>(shared_from_this()); | |||
| } | |||
| /// \brief Base method for NodePass visit. A tree walk consists of walking down the tree and also walking back up | |||
| /// in a depth-first order. Accept is the node visit on the way down, whereas AcceptAfter is the node | |||
| /// visit on the way back up the tree after its descendants are visited. | |||
| @@ -129,17 +220,123 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||
| Status BuildStatus() { return build_status; } | |||
| protected: | |||
| std::vector<std::shared_ptr<DatasetNode>> children; | |||
| std::vector<std::shared_ptr<DatasetNode>> children_; | |||
| DatasetNode *parent_; | |||
| std::shared_ptr<DatasetCache> cache_; | |||
| Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops); | |||
| int32_t num_workers_; | |||
| int32_t rows_per_buffer_; | |||
| int32_t connector_que_size_; | |||
| int32_t worker_connector_size_; | |||
| Status build_status; // remove me after changing return val of Build() | |||
| std::string PrintColumns(const std::vector<std::string> &columns) const; | |||
| Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops); | |||
| void PrintNode(std::ostream &out, int *level) const; | |||
| }; | |||
| // SourceNode represents the leaf nodes of a pipeline where the data is pulled into. | |||
| class SourceNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| SourceNode() : DatasetNode() {} | |||
| /// \brief Constructor that initializes the cache | |||
| /// \param dataset_cache DatasetCache | |||
| explicit SourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {} | |||
| /// \brief Destructor | |||
| ~SourceNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| virtual std::string Name() const = 0; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||
| /// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes | |||
| /// \return True if the dataset represented by this node is a mappable dataset | |||
| const bool IsMappable() const { return mappable_; } | |||
| protected: | |||
| bool mappable_; | |||
| }; | |||
| // MappableSourceNode represents the leaf nodes that can be randomly accessed with indexes. | |||
| class MappableSourceNode : public SourceNode { | |||
| public: | |||
| /// \brief Constructor | |||
| MappableSourceNode() : SourceNode() { mappable_ = true; } | |||
| /// \brief Constructor that initializes the cache | |||
| /// \param dataset_cache DatasetCache | |||
| explicit MappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : SourceNode(dataset_cache) { | |||
| mappable_ = true; | |||
| } | |||
| /// \brief Destructor | |||
| ~MappableSourceNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| virtual std::string Name() const = 0; | |||
| }; | |||
| // NonMappableSourceNode represents the leaf nodes that can not be randomly accessed. | |||
| class NonMappableSourceNode : public SourceNode { | |||
| public: | |||
| /// \brief Constructor | |||
| NonMappableSourceNode() : SourceNode() { mappable_ = false; } | |||
| /// \brief Constructor that initializes the cache | |||
| /// \param dataset_cache DatasetCache | |||
| explicit NonMappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : SourceNode(dataset_cache) { | |||
| mappable_ = false; | |||
| } | |||
| /// \brief Destructor | |||
| ~NonMappableSourceNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| virtual std::string Name() const = 0; | |||
| }; | |||
| // NonLeafNode represents operations over data in a pipeline. | |||
| class NonLeafNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| NonLeafNode() = default; | |||
| /// \brief Destructor | |||
| ~NonLeafNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| virtual std::string Name() const = 0; | |||
| }; | |||
| // SinkNode represents the end node of a pipeline where the data is pushed out | |||
| class SinkNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| SinkNode() = default; | |||
| /// \brief Destructor | |||
| ~SinkNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| virtual std::string Name() const = 0; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_ | |||
| @@ -0,0 +1,67 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h" | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor for EpochCtrlNode | |||
| EpochCtrlNode::EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs) : num_epochs_(num_epochs) { | |||
| // The root node's parent must set to null pointer. | |||
| this->AddChild(child); | |||
| } | |||
| std::shared_ptr<DatasetNode> EpochCtrlNode::Copy() { | |||
| auto node = std::make_shared<EpochCtrlNode>(nullptr, this->num_epochs_); | |||
| return node; | |||
| } | |||
| void EpochCtrlNode::Print(std::ostream &out) const { out << Name() + "(epoch:" + std::to_string(num_epochs_) + ")"; } | |||
| // Function to build the EpochCtrlOp | |||
| std::vector<std::shared_ptr<DatasetOp>> EpochCtrlNode::Build() { | |||
| // A dummy vector | |||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||
| node_ops.push_back(std::make_shared<EpochCtrlOp>(num_epochs_)); | |||
| return node_ops; | |||
| } | |||
| // Function to validate the parameters for EpochCtrlNode | |||
| Status EpochCtrlNode::ValidateParams() { | |||
| if (num_epochs_ <= 0 && num_epochs_ != -1) { | |||
| std::string err_msg = | |||
| "EpochCtrlNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(num_epochs_); | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (children_.size() != 1 || children_[0] == nullptr) { | |||
| std::string err_msg = "Internal error: epoch control node should have one child node"; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,63 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_EPOCH_CTRL_NODE_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_EPOCH_CTRL_NODE_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class EpochCtrlNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| explicit EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs); | |||
| /// \brief Destructor | |||
| ~EpochCtrlNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kEpochCtrlNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return shared pointer to the list of newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| /// \brief Parameters validation | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| private: | |||
| int32_t num_epochs_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_EPOCH_CTRL_NODE_H_ | |||
| @@ -21,7 +21,7 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/filter_op.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| @@ -31,7 +31,16 @@ namespace dataset { | |||
| FilterNode::FilterNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<TensorOp> predicate, | |||
| std::vector<std::string> input_columns) | |||
| : predicate_(predicate), input_columns_(input_columns) { | |||
| this->children.push_back(child); | |||
| this->AddChild(child); | |||
| } | |||
| std::shared_ptr<DatasetNode> FilterNode::Copy() { | |||
| auto node = std::make_shared<FilterNode>(nullptr, predicate_, input_columns_); | |||
| return node; | |||
| } | |||
| void FilterNode::Print(std::ostream &out) const { | |||
| out << Name() + "(<predicate>," + "input_cols:" + PrintColumns(input_columns_) + ")"; | |||
| } | |||
| std::vector<std::shared_ptr<DatasetOp>> FilterNode::Build() { | |||
| @@ -54,5 +63,17 @@ Status FilterNode::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status FilterNode::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<FilterNode>(), modified); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status FilterNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<FilterNode>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -35,6 +35,18 @@ class FilterNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~FilterNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kFilterNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -43,6 +55,18 @@ class FilterNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||
| private: | |||
| std::shared_ptr<TensorOp> predicate_; | |||
| std::vector<std::string> input_columns_; | |||
| @@ -22,6 +22,7 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/map_op/map_op.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/include/transforms.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| @@ -37,7 +38,18 @@ MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr | |||
| project_columns_(project_columns), | |||
| DatasetNode(std::move(cache)), | |||
| callbacks_(callbacks) { | |||
| this->children.push_back(child); | |||
| this->AddChild(child); | |||
| } | |||
| std::shared_ptr<DatasetNode> MapNode::Copy() { | |||
| auto node = std::make_shared<MapNode>(nullptr, operations_, input_columns_, output_columns_, project_columns_, cache_, | |||
| callbacks_); | |||
| return node; | |||
| } | |||
| void MapNode::Print(std::ostream &out) const { | |||
| out << Name() + "(<ops>" + ",input:" + PrintColumns(input_columns_) + ",output:" + PrintColumns(output_columns_) + | |||
| ",<project_cols>" + ",...)"; | |||
| } | |||
| std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() { | |||
| @@ -93,5 +105,16 @@ Status MapNode::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status MapNode::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<MapNode>(), modified); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status MapNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<MapNode>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -37,6 +37,18 @@ class MapNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~MapNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kMapNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -45,6 +57,23 @@ class MapNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| /// \brief Getter of tensor operations | |||
| /// \return Vector of operations the Map node will process | |||
| const auto &TensorOperations() const { return operations_; } | |||
| auto &TensorOperations() { return operations_; } | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||
| private: | |||
| std::vector<std::shared_ptr<TensorOperation>> operations_; | |||
| std::vector<std::string> input_columns_; | |||
| @@ -29,9 +29,16 @@ namespace dataset { | |||
| // Function to build ProjectOp | |||
| ProjectNode::ProjectNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &columns) | |||
| : columns_(columns) { | |||
| this->children.push_back(child); | |||
| this->AddChild(child); | |||
| } | |||
| std::shared_ptr<DatasetNode> ProjectNode::Copy() { | |||
| auto node = std::make_shared<ProjectNode>(nullptr, this->columns_); | |||
| return node; | |||
| } | |||
| void ProjectNode::Print(std::ostream &out) const { out << Name() + "(column: " + PrintColumns(columns_) + ")"; } | |||
| Status ProjectNode::ValidateParams() { | |||
| if (columns_.empty()) { | |||
| std::string err_msg = "ProjectNode: No columns are specified."; | |||
| @@ -34,6 +34,18 @@ class ProjectNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~ProjectNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kProjectNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -30,7 +30,16 @@ namespace dataset { | |||
| RenameNode::RenameNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &input_columns, | |||
| const std::vector<std::string> &output_columns) | |||
| : input_columns_(input_columns), output_columns_(output_columns) { | |||
| this->children.push_back(child); | |||
| this->AddChild(child); | |||
| } | |||
| std::shared_ptr<DatasetNode> RenameNode::Copy() { | |||
| auto node = std::make_shared<RenameNode>(nullptr, input_columns_, output_columns_); | |||
| return node; | |||
| } | |||
| void RenameNode::Print(std::ostream &out) const { | |||
| out << Name() + "(input:" + PrintColumns(input_columns_) + ",output:" + PrintColumns(output_columns_) + ")"; | |||
| } | |||
| Status RenameNode::ValidateParams() { | |||
| @@ -35,6 +35,18 @@ class RenameNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~RenameNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kRenameNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -21,15 +21,22 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/repeat_op.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| RepeatNode::RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count) : repeat_count_(count) { | |||
| this->children.push_back(child); | |||
| this->AddChild(child); | |||
| } | |||
| std::shared_ptr<DatasetNode> RepeatNode::Copy() { | |||
| auto node = std::make_shared<RepeatNode>(nullptr, this->repeat_count_); | |||
| return node; | |||
| } | |||
| void RepeatNode::Print(std::ostream &out) const { out << Name() + "(count:" + std::to_string(repeat_count_) + ")"; } | |||
| std::vector<std::shared_ptr<DatasetOp>> RepeatNode::Build() { | |||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||
| @@ -49,5 +56,16 @@ Status RepeatNode::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status RepeatNode::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<RepeatNode>(), modified); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status RepeatNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<RepeatNode>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -36,6 +36,18 @@ class RepeatNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~RepeatNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kRepeatNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -44,6 +56,18 @@ class RepeatNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||
| private: | |||
| int32_t repeat_count_; | |||
| }; | |||
| @@ -0,0 +1,85 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/ir/datasetops/root_node.h" | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor for RootNode | |||
| RootNode::RootNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs) : DatasetNode(), num_epochs_(num_epochs) { | |||
| // The root node's parent must remain nullptr. (which is set in the constructor of DatasetNode) | |||
| AddChild(child); | |||
| } | |||
| std::shared_ptr<DatasetNode> RootNode::Copy() { | |||
| auto node = std::make_shared<RootNode>(nullptr, num_epochs_); | |||
| return node; | |||
| } | |||
| void RootNode::Print(std::ostream &out) const { out << Name(); } | |||
| std::vector<std::shared_ptr<DatasetOp>> RootNode::Build() { | |||
| // root node doesn't build a runtime Op. this function should return Status::Error when called. | |||
| return {}; | |||
| } | |||
| // Function to validate the parameters for RootNode | |||
| Status RootNode::ValidateParams() { | |||
| if (num_epochs_ <= 0 && num_epochs_ != -1) { | |||
| std::string err_msg = | |||
| "RootNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(num_epochs_); | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (parent_ != nullptr) { | |||
| std::string err_msg = "Internal error: root node should not have a parent"; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (children_.size() != 1) { | |||
| std::string err_msg = "Internal error: root node should have one child node"; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (children_[0] == nullptr) { | |||
| std::string err_msg = "Internal error: root node's child is a null pointer"; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status RootNode::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<RootNode>(), modified); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status RootNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<RootNode>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,78 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ROOT_NODE_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ROOT_NODE_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class RootNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| RootNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs); | |||
| /// \brief Destructor | |||
| ~RootNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kRootNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return shared pointer to the list of newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| /// \brief Getter of number of epochs | |||
| int32_t num_epochs() { return num_epochs_; } | |||
| /// \brief Parameters validation | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||
| private: | |||
| int32_t num_epochs_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ROOT_NODE_H_ | |||
| @@ -29,7 +29,17 @@ namespace dataset { | |||
| // Constructor for ShuffleNode | |||
| ShuffleNode::ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch) | |||
| : shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) { | |||
| this->children.push_back(child); | |||
| this->AddChild(child); | |||
| } | |||
| std::shared_ptr<DatasetNode> ShuffleNode::Copy() { | |||
| auto node = std::make_shared<ShuffleNode>(nullptr, shuffle_size_, reset_every_epoch_); | |||
| return node; | |||
| } | |||
| void ShuffleNode::Print(std::ostream &out) const { | |||
| out << Name() + "(shuffle_size:" + std::to_string(shuffle_size_) + | |||
| ",reset_every_epoch:" + (reset_every_epoch_ ? "true" : "false") + ")"; | |||
| } | |||
| // Function to build the ShuffleOp | |||
| @@ -34,6 +34,18 @@ class ShuffleNode : public DatasetNode { | |||
| ~ShuffleNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kShuffleNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| Status ValidateParams() override; | |||
| @@ -27,10 +27,15 @@ namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor for SkipNode | |||
| SkipNode::SkipNode(std::shared_ptr<DatasetNode> child, int32_t count) : skip_count_(count) { | |||
| this->children.push_back(child); | |||
| SkipNode::SkipNode(std::shared_ptr<DatasetNode> child, int32_t count) : skip_count_(count) { this->AddChild(child); } | |||
| std::shared_ptr<DatasetNode> SkipNode::Copy() { | |||
| auto node = std::make_shared<SkipNode>(nullptr, skip_count_); | |||
| return node; | |||
| } | |||
| void SkipNode::Print(std::ostream &out) const { out << Name() + "(skip_count:" + std::to_string(skip_count_) + ")"; } | |||
| // Function to build the SkipOp | |||
| std::vector<std::shared_ptr<DatasetOp>> SkipNode::Build() { | |||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||
| @@ -34,6 +34,18 @@ class SkipNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~SkipNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kSkipNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -32,13 +32,23 @@ namespace dataset { | |||
| AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_schema, | |||
| const std::vector<std::string> &column_names, bool decode, | |||
| const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache) | |||
| : DatasetNode(std::move(cache)), | |||
| : MappableSourceNode(std::move(cache)), | |||
| dataset_dir_(dataset_dir), | |||
| schema_path_(data_schema), | |||
| column_names_(column_names), | |||
| decode_(decode), | |||
| sampler_(sampler) {} | |||
| std::shared_ptr<DatasetNode> AlbumNode::Copy() { | |||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||
| auto node = std::make_shared<AlbumNode>(dataset_dir_, schema_path_, column_names_, decode_, sampler, cache_); | |||
| return node; | |||
| } | |||
| void AlbumNode::Print(std::ostream &out) const { | |||
| out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ")"; | |||
| } | |||
| Status AlbumNode::ValidateParams() { | |||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("AlbumNode", dataset_dir_)); | |||
| @@ -26,7 +26,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class AlbumNode : public DatasetNode { | |||
| class AlbumNode : public MappableSourceNode { | |||
| public: | |||
| /// \brief Constructor | |||
| AlbumNode(const std::string &dataset_dir, const std::string &data_schema, | |||
| @@ -36,6 +36,18 @@ class AlbumNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~AlbumNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kAlbumNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create a runtime dataset op object from this class | |||
| /// \return shared pointer to the newly created DatasetOp | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -31,13 +31,23 @@ namespace dataset { | |||
| CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage, | |||
| const std::shared_ptr<SamplerObj> &sampler, const bool &decode, | |||
| const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache) | |||
| : DatasetNode(std::move(cache)), | |||
| : MappableSourceNode(std::move(cache)), | |||
| dataset_dir_(dataset_dir), | |||
| usage_(usage), | |||
| sampler_(sampler), | |||
| decode_(decode), | |||
| extensions_(extensions) {} | |||
| std::shared_ptr<DatasetNode> CelebANode::Copy() { | |||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||
| auto node = std::make_shared<CelebANode>(dataset_dir_, usage_, sampler, decode_, extensions_, cache_); | |||
| return node; | |||
| } | |||
| void CelebANode::Print(std::ostream &out) const { | |||
| out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ")"; | |||
| } | |||
| Status CelebANode::ValidateParams() { | |||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebANode", dataset_dir_)); | |||
| @@ -28,7 +28,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class CelebANode : public DatasetNode { | |||
| class CelebANode : public MappableSourceNode { | |||
| public: | |||
| /// \brief Constructor | |||
| CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler, | |||
| @@ -37,6 +37,18 @@ class CelebANode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~CelebANode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kCelebANode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return shared pointer to the list of newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -30,7 +30,17 @@ namespace dataset { | |||
| // Constructor for Cifar100Node | |||
| Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &usage, | |||
| std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache) | |||
| : DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||
| : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||
| std::shared_ptr<DatasetNode> Cifar100Node::Copy() { | |||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||
| auto node = std::make_shared<Cifar100Node>(dataset_dir_, usage_, sampler, cache_); | |||
| return node; | |||
| } | |||
| void Cifar100Node::Print(std::ostream &out) const { | |||
| out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ")"; | |||
| } | |||
| Status Cifar100Node::ValidateParams() { | |||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_)); | |||
| @@ -26,7 +26,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class Cifar100Node : public DatasetNode { | |||
| class Cifar100Node : public MappableSourceNode { | |||
| public: | |||
| /// \brief Constructor | |||
| Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, | |||
| @@ -35,6 +35,18 @@ class Cifar100Node : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~Cifar100Node() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kCifar100Node; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -30,7 +30,17 @@ namespace dataset { | |||
| // Constructor for Cifar10Node | |||
| Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, | |||
| std::shared_ptr<DatasetCache> cache) | |||
| : DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||
| : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||
| std::shared_ptr<DatasetNode> Cifar10Node::Copy() { | |||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||
| auto node = std::make_shared<Cifar10Node>(dataset_dir_, usage_, sampler, cache_); | |||
| return node; | |||
| } | |||
| void Cifar10Node::Print(std::ostream &out) const { | |||
| out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ")"; | |||
| } | |||
| Status Cifar10Node::ValidateParams() { | |||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_)); | |||
| @@ -26,7 +26,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class Cifar10Node : public DatasetNode { | |||
| class Cifar10Node : public MappableSourceNode { | |||
| public: | |||
| /// \brief Constructor | |||
| Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, | |||
| @@ -35,6 +35,18 @@ class Cifar10Node : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~Cifar10Node() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kCifar10Node; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -32,7 +32,7 @@ namespace dataset { | |||
| // Constructor for CLUENode | |||
| CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task, std::string usage, int64_t num_samples, | |||
| ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache) | |||
| : DatasetNode(std::move(cache)), | |||
| : NonMappableSourceNode(std::move(cache)), | |||
| dataset_files_(clue_files), | |||
| task_(task), | |||
| usage_(usage), | |||
| @@ -41,6 +41,17 @@ CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task, | |||
| num_shards_(num_shards), | |||
| shard_id_(shard_id) {} | |||
| std::shared_ptr<DatasetNode> CLUENode::Copy() { | |||
| auto node = | |||
| std::make_shared<CLUENode>(dataset_files_, task_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_); | |||
| return node; | |||
| } | |||
| void CLUENode::Print(std::ostream &out) const { | |||
| out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ",..." + | |||
| ",num_shards:" + std::to_string(num_shards_) + ",shard_id:" + std::to_string(shard_id_) + ")"; | |||
| } | |||
| Status CLUENode::ValidateParams() { | |||
| RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUENode", dataset_files_)); | |||
| @@ -28,7 +28,7 @@ namespace dataset { | |||
| /// \class CLUENode | |||
| /// \brief A Dataset derived class to represent CLUE dataset | |||
| class CLUENode : public DatasetNode { | |||
| class CLUENode : public NonMappableSourceNode { | |||
| public: | |||
| /// \brief Constructor | |||
| CLUENode(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples, | |||
| @@ -37,6 +37,18 @@ class CLUENode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~CLUENode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kCLUENode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -30,13 +30,21 @@ namespace dataset { | |||
| // Constructor for CocoNode | |||
| CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, | |||
| const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache) | |||
| : DatasetNode(std::move(cache)), | |||
| : MappableSourceNode(std::move(cache)), | |||
| dataset_dir_(dataset_dir), | |||
| annotation_file_(annotation_file), | |||
| task_(task), | |||
| decode_(decode), | |||
| sampler_(sampler) {} | |||
| std::shared_ptr<DatasetNode> CocoNode::Copy() { | |||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||
| auto node = std::make_shared<CocoNode>(dataset_dir_, annotation_file_, task_, decode_, sampler, cache_); | |||
| return node; | |||
| } | |||
| void CocoNode::Print(std::ostream &out) const { out << Name(); } | |||
| Status CocoNode::ValidateParams() { | |||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoNode", dataset_dir_)); | |||
| @@ -26,7 +26,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class CocoNode : public DatasetNode { | |||
| class CocoNode : public MappableSourceNode { | |||
| public: | |||
| /// \brief Constructor | |||
| CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, | |||
| @@ -35,6 +35,18 @@ class CocoNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~CocoNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kCocoNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return shared pointer to the list of newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -33,7 +33,7 @@ CSVNode::CSVNode(const std::vector<std::string> &csv_files, char field_delim, | |||
| const std::vector<std::shared_ptr<CsvBase>> &column_defaults, | |||
| const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle, | |||
| int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache) | |||
| : DatasetNode(std::move(cache)), | |||
| : NonMappableSourceNode(std::move(cache)), | |||
| dataset_files_(csv_files), | |||
| field_delim_(field_delim), | |||
| column_defaults_(column_defaults), | |||
| @@ -43,6 +43,17 @@ CSVNode::CSVNode(const std::vector<std::string> &csv_files, char field_delim, | |||
| num_shards_(num_shards), | |||
| shard_id_(shard_id) {} | |||
| std::shared_ptr<DatasetNode> CSVNode::Copy() { | |||
| auto node = std::make_shared<CSVNode>(dataset_files_, field_delim_, column_defaults_, column_names_, num_samples_, | |||
| shuffle_, num_shards_, shard_id_, cache_); | |||
| return node; | |||
| } | |||
| void CSVNode::Print(std::ostream &out) const { | |||
| out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ",..." + | |||
| ",num_shards:" + std::to_string(num_shards_) + ",shard_id:" + std::to_string(shard_id_) + ")"; | |||
| } | |||
| Status CSVNode::ValidateParams() { | |||
| RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CSVNode", dataset_files_)); | |||
| @@ -47,7 +47,7 @@ class CsvRecord : public CsvBase { | |||
| T value; | |||
| }; | |||
| class CSVNode : public DatasetNode { | |||
| class CSVNode : public NonMappableSourceNode { | |||
| public: | |||
| /// \brief Constructor | |||
| CSVNode(const std::vector<std::string> &dataset_files, char field_delim, | |||
| @@ -58,6 +58,18 @@ class CSVNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~CSVNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kCSVNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return shared pointer to the list of newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -28,7 +28,19 @@ namespace dataset { | |||
| GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names, | |||
| const std::vector<DataType> &column_types) | |||
| : generator_function_(generator_function), column_names_(column_names), column_types_(column_types) {} | |||
| : MappableSourceNode(), | |||
| generator_function_(generator_function), | |||
| column_names_(column_names), | |||
| column_types_(column_types) {} | |||
| std::shared_ptr<DatasetNode> GeneratorNode::Copy() { | |||
| auto node = std::make_shared<GeneratorNode>(generator_function_, column_names_, column_types_); | |||
| return node; | |||
| } | |||
| void GeneratorNode::Print(std::ostream &out) const { | |||
| out << Name() + "(<func>:" + ",columns:" + PrintColumns(column_names_) + ",<col_types>)"; | |||
| } | |||
| GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema) | |||
| : generator_function_(generator_function), schema_(schema) {} | |||
| @@ -26,10 +26,9 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \class GeneratorNode | |||
| /// \brief A Dataset derived class to represent GeneratorNode dataset | |||
| class GeneratorNode : public DatasetNode { | |||
| class GeneratorNode : public MappableSourceNode { | |||
| public: | |||
| /// \brief Constructor | |||
| GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names, | |||
| @@ -41,6 +40,18 @@ class GeneratorNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~GeneratorNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kGeneratorNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -33,13 +33,24 @@ ImageFolderNode::ImageFolderNode(std::string dataset_dir, bool decode, std::shar | |||
| bool recursive, std::set<std::string> extensions, | |||
| std::map<std::string, int32_t> class_indexing, | |||
| std::shared_ptr<DatasetCache> cache = nullptr) | |||
| : dataset_dir_(dataset_dir), | |||
| : MappableSourceNode(std::move(cache)), | |||
| dataset_dir_(dataset_dir), | |||
| decode_(decode), | |||
| sampler_(sampler), | |||
| recursive_(recursive), | |||
| class_indexing_(class_indexing), | |||
| exts_(extensions), | |||
| DatasetNode(std::move(cache)) {} | |||
| exts_(extensions) {} | |||
| std::shared_ptr<DatasetNode> ImageFolderNode::Copy() { | |||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||
| auto node = | |||
| std::make_shared<ImageFolderNode>(dataset_dir_, decode_, sampler, recursive_, exts_, class_indexing_, cache_); | |||
| return node; | |||
| } | |||
| void ImageFolderNode::Print(std::ostream &out) const { | |||
| out << Name() + "(path:" + dataset_dir_ + ",decode:" + (decode_ ? "true" : "false") + ",...)"; | |||
| } | |||
| Status ImageFolderNode::ValidateParams() { | |||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderNode", dataset_dir_)); | |||
| @@ -31,7 +31,7 @@ namespace dataset { | |||
| /// \class ImageFolderNode | |||
| /// \brief A Dataset derived class to represent ImageFolder dataset | |||
| class ImageFolderNode : public DatasetNode { | |||
| class ImageFolderNode : public MappableSourceNode { | |||
| public: | |||
| /// \brief Constructor | |||
| ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler, bool recursive, | |||
| @@ -41,6 +41,18 @@ class ImageFolderNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~ImageFolderNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kImageFolderNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -32,13 +32,30 @@ ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &u | |||
| const std::shared_ptr<SamplerObj> &sampler, | |||
| const std::map<std::string, int32_t> &class_indexing, bool decode, | |||
| std::shared_ptr<DatasetCache> cache) | |||
| : DatasetNode(std::move(cache)), | |||
| : MappableSourceNode(std::move(cache)), | |||
| dataset_file_(dataset_file), | |||
| usage_(usage), | |||
| decode_(decode), | |||
| class_index_(class_indexing), | |||
| sampler_(sampler) {} | |||
| std::shared_ptr<DatasetNode> ManifestNode::Copy() { | |||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||
| auto node = std::make_shared<ManifestNode>(dataset_file_, usage_, sampler, class_index_, decode_, cache_); | |||
| return node; | |||
| } | |||
| void ManifestNode::Print(std::ostream &out) const { | |||
| out << Name() + "(file:" + dataset_file_; | |||
| if (sampler_ != nullptr) { | |||
| out << ",sampler"; | |||
| } | |||
| if (cache_ != nullptr) { | |||
| out << ",cache"; | |||
| } | |||
| out << ")"; | |||
| } | |||
| Status ManifestNode::ValidateParams() { | |||
| std::vector<char> forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'}; | |||
| for (char c : dataset_file_) { | |||
| @@ -27,7 +27,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class ManifestNode : public DatasetNode { | |||
| class ManifestNode : public MappableSourceNode { | |||
| public: | |||
| /// \brief Constructor | |||
| ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler, | |||
| @@ -36,6 +36,18 @@ class ManifestNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~ManifestNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kManifestNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -30,7 +30,8 @@ namespace dataset { | |||
| MindDataNode::MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list, | |||
| const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded) | |||
| : dataset_file_(std::string()), | |||
| : MappableSourceNode(), | |||
| dataset_file_(std::string()), | |||
| dataset_files_(dataset_files), | |||
| search_for_pattern_(false), | |||
| columns_list_(columns_list), | |||
| @@ -41,7 +42,8 @@ MindDataNode::MindDataNode(const std::vector<std::string> &dataset_files, const | |||
| MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<std::string> &columns_list, | |||
| const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded) | |||
| : dataset_file_(dataset_file), | |||
| : MappableSourceNode(), | |||
| dataset_file_(dataset_file), | |||
| dataset_files_({}), | |||
| search_for_pattern_(true), | |||
| columns_list_(columns_list), | |||
| @@ -50,6 +52,19 @@ MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<st | |||
| sample_bytes_({}), | |||
| num_padded_(num_padded) {} | |||
| std::shared_ptr<DatasetNode> MindDataNode::Copy() { | |||
| std::shared_ptr<MindDataNode> node; | |||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||
| if (dataset_files_.empty()) { | |||
| node = std::make_shared<MindDataNode>(dataset_file_, columns_list_, sampler, padded_sample_, num_padded_); | |||
| } else { | |||
| node = std::make_shared<MindDataNode>(dataset_files_, columns_list_, sampler, padded_sample_, num_padded_); | |||
| } | |||
| return node; | |||
| } | |||
| void MindDataNode::Print(std::ostream &out) const { out << Name() + "(file:" + dataset_file_ + ",...)"; } | |||
| Status MindDataNode::ValidateParams() { | |||
| if (!search_for_pattern_ && dataset_files_.size() > 4096) { | |||
| std::string err_msg = | |||
| @@ -27,7 +27,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class MindDataNode : public DatasetNode { | |||
| class MindDataNode : public MappableSourceNode { | |||
| public: | |||
| /// \brief Constructor | |||
| MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list, | |||
| @@ -40,6 +40,18 @@ class MindDataNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~MindDataNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kMindDataNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -29,7 +29,15 @@ namespace dataset { | |||
| MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler, | |||
| std::shared_ptr<DatasetCache> cache) | |||
| : DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||
| : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||
| std::shared_ptr<DatasetNode> MnistNode::Copy() { | |||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||
| auto node = std::make_shared<MnistNode>(dataset_dir_, usage_, sampler, cache_); | |||
| return node; | |||
| } | |||
| void MnistNode::Print(std::ostream &out) const { out << Name(); } | |||
| Status MnistNode::ValidateParams() { | |||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_)); | |||
| @@ -26,7 +26,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class MnistNode : public DatasetNode { | |||
| class MnistNode : public MappableSourceNode { | |||
| public: | |||
| /// \brief Constructor | |||
| MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler, | |||
| @@ -35,6 +35,18 @@ class MnistNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~MnistNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kMnistNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -27,6 +27,18 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| std::shared_ptr<DatasetNode> RandomNode::Copy() { | |||
| std::shared_ptr<RandomNode> node; | |||
| if (schema_ != nullptr) { | |||
| node = std::make_shared<RandomNode>(total_rows_, schema_, columns_list_, cache_); | |||
| } else { | |||
| node = std::make_shared<RandomNode>(total_rows_, schema_path_, columns_list_, cache_); | |||
| } | |||
| return node; | |||
| } | |||
| void RandomNode::Print(std::ostream &out) const { out << Name() + "(num_row:" + std::to_string(total_rows_) + ",...)"; } | |||
| // ValidateParams for RandomNode | |||
| Status RandomNode::ValidateParams() { | |||
| if (total_rows_ < 0) { | |||
| @@ -27,7 +27,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class RandomNode : public DatasetNode { | |||
| class RandomNode : public NonMappableSourceNode { | |||
| public: | |||
| // Some constants to provide limits to random generation. | |||
| static constexpr int32_t kMaxNumColumns = 4; | |||
| @@ -37,7 +37,7 @@ class RandomNode : public DatasetNode { | |||
| /// \brief Constructor | |||
| RandomNode(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, const std::vector<std::string> &columns_list, | |||
| std::shared_ptr<DatasetCache> cache) | |||
| : DatasetNode(std::move(cache)), | |||
| : NonMappableSourceNode(std::move(cache)), | |||
| total_rows_(total_rows), | |||
| schema_path_(""), | |||
| schema_(std::move(schema)), | |||
| @@ -46,14 +46,27 @@ class RandomNode : public DatasetNode { | |||
| /// \brief Constructor | |||
| RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list, | |||
| std::shared_ptr<DatasetCache> cache) | |||
| : DatasetNode(std::move(cache)), | |||
| : NonMappableSourceNode(std::move(cache)), | |||
| total_rows_(total_rows), | |||
| schema_path_(schema_path), | |||
| schema_(nullptr), | |||
| columns_list_(columns_list) {} | |||
| /// \brief Destructor | |||
| ~RandomNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kRandomNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -31,13 +31,23 @@ namespace dataset { | |||
| // Constructor for TextFileNode | |||
| TextFileNode::TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, | |||
| int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache) | |||
| : DatasetNode(std::move(cache)), | |||
| : NonMappableSourceNode(std::move(cache)), | |||
| dataset_files_(dataset_files), | |||
| num_samples_(num_samples), | |||
| shuffle_(shuffle), | |||
| num_shards_(num_shards), | |||
| shard_id_(shard_id) {} | |||
| std::shared_ptr<DatasetNode> TextFileNode::Copy() { | |||
| auto node = std::make_shared<TextFileNode>(dataset_files_, num_samples_, shuffle_, num_shards_, shard_id_, cache_); | |||
| return node; | |||
| } | |||
| void TextFileNode::Print(std::ostream &out) const { | |||
| out << Name() + "(file:..." + ",num_shards:" + std::to_string(num_shards_) + | |||
| ",shard_id:" + std::to_string(shard_id_) + ",cache:" + ((cache_ != nullptr) ? "true" : "false") + ",...)"; | |||
| } | |||
| Status TextFileNode::ValidateParams() { | |||
| RETURN_IF_NOT_OK(ValidateDatasetFilesParam("TextFileNode", dataset_files_)); | |||
| @@ -28,7 +28,7 @@ namespace dataset { | |||
| /// \class TextFileNode | |||
| /// \brief A Dataset derived class to represent TextFile dataset | |||
| class TextFileNode : public DatasetNode { | |||
| class TextFileNode : public NonMappableSourceNode { | |||
| public: | |||
| /// \brief Constructor | |||
| TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, | |||
| @@ -37,6 +37,18 @@ class TextFileNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~TextFileNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kTextFileNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -30,6 +30,23 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| std::shared_ptr<DatasetNode> TFRecordNode::Copy() { | |||
| std::shared_ptr<TFRecordNode> node; | |||
| if (schema_obj_ != nullptr) { | |||
| node = std::make_shared<TFRecordNode>(dataset_files_, schema_obj_, columns_list_, num_samples_, shuffle_, | |||
| num_shards_, shard_id_, shard_equal_rows_, cache_); | |||
| } else { | |||
| node = std::make_shared<TFRecordNode>(dataset_files_, schema_path_, columns_list_, num_samples_, shuffle_, | |||
| num_shards_, shard_id_, shard_equal_rows_, cache_); | |||
| } | |||
| return node; | |||
| } | |||
| void TFRecordNode::Print(std::ostream &out) const { | |||
| out << Name() + "(num_samples:" + std::to_string(num_samples_) + ",num_shards:" + std::to_string(num_shards_) + | |||
| ",shard_id:" + std::to_string(shard_id_) + ",...)"; | |||
| } | |||
| // Validator for TFRecordNode | |||
| Status TFRecordNode::ValidateParams() { | |||
| if (dataset_files_.empty()) { | |||
| @@ -29,14 +29,14 @@ namespace dataset { | |||
| /// \class TFRecordNode | |||
| /// \brief A Dataset derived class to represent TFRecord dataset | |||
| class TFRecordNode : public DatasetNode { | |||
| class TFRecordNode : public NonMappableSourceNode { | |||
| public: | |||
| /// \brief Constructor | |||
| /// \note Parameter 'schema' is the path to the schema file | |||
| TFRecordNode(const std::vector<std::string> &dataset_files, std::string schema, | |||
| const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, | |||
| int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache) | |||
| : DatasetNode(std::move(cache)), | |||
| : NonMappableSourceNode(std::move(cache)), | |||
| dataset_files_(dataset_files), | |||
| schema_path_(schema), | |||
| columns_list_(columns_list), | |||
| @@ -51,7 +51,7 @@ class TFRecordNode : public DatasetNode { | |||
| TFRecordNode(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema, | |||
| const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, | |||
| int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache) | |||
| : DatasetNode(std::move(cache)), | |||
| : NonMappableSourceNode(std::move(cache)), | |||
| dataset_files_(dataset_files), | |||
| schema_obj_(schema), | |||
| columns_list_(columns_list), | |||
| @@ -64,6 +64,18 @@ class TFRecordNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~TFRecordNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kTFRecordNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -32,7 +32,7 @@ namespace dataset { | |||
| VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, | |||
| const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler, | |||
| std::shared_ptr<DatasetCache> cache) | |||
| : DatasetNode(std::move(cache)), | |||
| : MappableSourceNode(std::move(cache)), | |||
| dataset_dir_(dataset_dir), | |||
| task_(task), | |||
| usage_(usage), | |||
| @@ -40,6 +40,14 @@ VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const | |||
| decode_(decode), | |||
| sampler_(sampler) {} | |||
| std::shared_ptr<DatasetNode> VOCNode::Copy() { | |||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||
| auto node = std::make_shared<VOCNode>(dataset_dir_, task_, usage_, class_index_, decode_, sampler, cache_); | |||
| return node; | |||
| } | |||
| void VOCNode::Print(std::ostream &out) const { out << Name(); } | |||
| Status VOCNode::ValidateParams() { | |||
| Path dir(dataset_dir_); | |||
| @@ -27,7 +27,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class VOCNode : public DatasetNode { | |||
| class VOCNode : public MappableSourceNode { | |||
| public: | |||
| /// \brief Constructor | |||
| VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, | |||
| @@ -37,6 +37,18 @@ class VOCNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~VOCNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kVOCNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return shared pointer to the list of newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -29,7 +29,16 @@ namespace dataset { | |||
| // Constructor for SyncWaitNode | |||
| SyncWaitNode::SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, py::function callback) | |||
| : condition_name_(condition_name), callback_(callback) { | |||
| this->children.push_back(child); | |||
| this->AddChild(child); | |||
| } | |||
| std::shared_ptr<DatasetNode> SyncWaitNode::Copy() { | |||
| auto node = std::make_shared<SyncWaitNode>(nullptr, condition_name_, callback_); | |||
| return node; | |||
| } | |||
| void SyncWaitNode::Print(std::ostream &out) const { | |||
| out << Name() + "(cond_name:" + condition_name_ + "<pyfunc>" + ")"; | |||
| } | |||
| // Function to build the BarrierOp | |||
| @@ -36,6 +36,18 @@ class SyncWaitNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~SyncWaitNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kSyncWaitNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -27,10 +27,15 @@ namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor for TakeNode | |||
| TakeNode::TakeNode(std::shared_ptr<DatasetNode> child, int32_t count) : take_count_(count) { | |||
| this->children.push_back(child); | |||
| TakeNode::TakeNode(std::shared_ptr<DatasetNode> child, int32_t count) : take_count_(count) { this->AddChild(child); } | |||
| std::shared_ptr<DatasetNode> TakeNode::Copy() { | |||
| auto node = std::make_shared<TakeNode>(nullptr, take_count_); | |||
| return node; | |||
| } | |||
| void TakeNode::Print(std::ostream &out) const { out << Name() + "(num_rows:" + std::to_string(take_count_) + ")"; } | |||
| // Function to build the TakeOp | |||
| std::vector<std::shared_ptr<DatasetOp>> TakeNode::Build() { | |||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||
| @@ -34,6 +34,18 @@ class TakeNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~TakeNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kTakeNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return shared pointer to the list of newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -22,6 +22,7 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/device_queue_op.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "utils/ms_context.h" | |||
| @@ -39,7 +40,19 @@ TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, std::string queue | |||
| total_batch_(total_batch), | |||
| create_data_info_queue_(create_data_info_queue), | |||
| device_id_(0) { | |||
| this->children.push_back(child); | |||
| this->AddChild(child); | |||
| } | |||
| std::shared_ptr<DatasetNode> TransferNode::Copy() { | |||
| auto node = std::make_shared<TransferNode>(nullptr, queue_name_, device_type_, send_epoch_end_, total_batch_, | |||
| create_data_info_queue_); | |||
| return node; | |||
| } | |||
| void TransferNode::Print(std::ostream &out) const { | |||
| out << Name() + "(prefetch_size:" + std::to_string(prefetch_size_) + | |||
| ",send_epoch_end:" + (send_epoch_end_ ? "true" : "false") + ",total_batch:" + std::to_string(total_batch_) + | |||
| ")"; | |||
| } | |||
| // Validator for TransferNode | |||
| @@ -94,5 +107,16 @@ std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() { | |||
| return node_ops; | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status TransferNode::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<TransferNode>(), modified); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status TransferNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<TransferNode>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -35,6 +35,18 @@ class TransferNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~TransferNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kTransferNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return shared pointer to the list of newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -43,6 +55,20 @@ class TransferNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| static Status get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id); | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||
| private: | |||
| std::string queue_name_; | |||
| int32_t device_id_; | |||
| @@ -21,30 +21,36 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/zip_op.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| ZipNode::ZipNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) : datasets_(datasets) { | |||
| for (auto dataset : datasets_) { | |||
| this->children.push_back(dataset); | |||
| } | |||
| ZipNode::ZipNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) { | |||
| for (auto const &child : datasets) AddChild(child); | |||
| } | |||
| std::shared_ptr<DatasetNode> ZipNode::Copy() { | |||
| std::vector<std::shared_ptr<DatasetNode>> empty_vector; | |||
| empty_vector.clear(); | |||
| auto node = std::make_shared<ZipNode>(empty_vector); | |||
| return node; | |||
| } | |||
| void ZipNode::Print(std::ostream &out) const { out << Name(); } | |||
| Status ZipNode::ValidateParams() { | |||
| if (datasets_.empty()) { | |||
| std::string err_msg = "ZipNode: datasets to zip are not specified."; | |||
| if (children_.size() < 2) { | |||
| std::string err_msg = "ZipNode: input datasets are not specified."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) { | |||
| std::string err_msg = "ZipNode: zip datasets should not be null."; | |||
| if (find(children_.begin(), children_.end(), nullptr) != children_.end()) { | |||
| std::string err_msg = "ZipNode: input datasets should not be null."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -56,5 +62,17 @@ std::vector<std::shared_ptr<DatasetOp>> ZipNode::Build() { | |||
| return node_ops; | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status ZipNode::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<ZipNode>(), modified); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status ZipNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<ZipNode>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -34,6 +34,18 @@ class ZipNode : public DatasetNode { | |||
| /// \brief Destructor | |||
| ~ZipNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kZipNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \return The list of shared pointers to the newly created DatasetOps | |||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||
| @@ -42,8 +54,17 @@ class ZipNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| private: | |||
| std::vector<std::shared_ptr<DatasetNode>> datasets_; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||
| }; | |||
| } // namespace dataset | |||
| @@ -22,10 +22,12 @@ | |||
| #endif | |||
| #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/filter_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/repeat_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/root_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||
| #ifdef ENABLE_PYTHON | |||
| @@ -34,34 +36,6 @@ | |||
| #include "minddata/dataset/engine/ir/datasetops/take_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/transfer_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/album_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h" | |||
| #endif | |||
| #include "minddata/dataset/engine/ir/datasetops/source/coco_node.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h" | |||
| #endif | |||
| #ifdef ENABLE_PYTHON | |||
| #include "minddata/dataset/engine/ir/datasetops/source/generator_node.h" | |||
| #endif | |||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" | |||
| #endif | |||
| #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/random_node.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h" | |||
| #endif | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h" | |||
| #endif | |||
| #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h" | |||
| ////////////////////////////////// | |||
| // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. | |||
| @@ -113,7 +87,12 @@ namespace mindspore { | |||
| namespace dataset { | |||
| // Driver method for TreePass | |||
| Status TreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); } | |||
| Status TreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { | |||
| if (root_ir == nullptr || modified == nullptr) { | |||
| return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass"); | |||
| } | |||
| return this->RunOnTree(root_ir, modified); | |||
| } | |||
| // Driver method for NodePass | |||
| Status NodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { | |||
| @@ -132,15 +111,23 @@ Status NodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { | |||
| // Helper function to perform DFS visit | |||
| Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) { | |||
| RETURN_IF_NOT_OK(node_ir->Accept(this, modified)); | |||
| bool m = false; | |||
| RETURN_IF_NOT_OK(node_ir->Accept(this, &m)); | |||
| *modified |= m; | |||
| for (const auto &c : node_ir->Children()) { | |||
| RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified)); | |||
| RETURN_IF_NOT_OK(this->DFSNodeVisit(c, &m)); | |||
| *modified |= m; | |||
| } | |||
| return node_ir->AcceptAfter(this, modified); | |||
| RETURN_IF_NOT_OK(node_ir->AcceptAfter(this, &m)); | |||
| *modified |= m; | |||
| return Status::OK(); | |||
| } | |||
| // Helper function to perform BFS visit | |||
| Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) { | |||
| bool m = false; | |||
| // Initialize bfs queue with root | |||
| std::queue<std::shared_ptr<DatasetNode>> bfsQueue; | |||
| bfsQueue.push(node_ir); | |||
| @@ -152,7 +139,8 @@ Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modifi | |||
| bfsQueue.pop(); | |||
| // Run node pass | |||
| RETURN_IF_NOT_OK(curNode->Accept(this, modified)); | |||
| RETURN_IF_NOT_OK(curNode->Accept(this, &m)); | |||
| *modified |= m; | |||
| // Push children into bfs queue | |||
| for (const auto &c : curNode->Children()) { | |||
| @@ -162,331 +150,119 @@ Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modifi | |||
| return Status::OK(); | |||
| } | |||
| // For datasetops IR | |||
| // For non-leaf IR node | |||
| Status NodePass::Visit(std::shared_ptr<BatchNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<BatchNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status NodePass::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| #endif | |||
| Status NodePass::Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<ConcatNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<FilterNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<FilterNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<MapNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<MapNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<ProjectNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<RenameNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<RenameNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<RepeatNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| Status NodePass::Visit(std::shared_ptr<RootNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| Status NodePass::VisitAfter(std::shared_ptr<RootNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<SkipNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| Status NodePass::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<SkipNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| Status NodePass::VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| #ifdef ENABLE_PYTHON | |||
| Status NodePass::Visit(std::shared_ptr<SyncWaitNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| Status NodePass::Visit(std::shared_ptr<SkipNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| Status NodePass::VisitAfter(std::shared_ptr<SkipNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| #endif | |||
| Status NodePass::Visit(std::shared_ptr<TakeNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<TakeNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<TransferNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<ZipNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<ZipNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| // For datasetops/source IR | |||
| Status NodePass::Visit(std::shared_ptr<AlbumNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<AlbumNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<CelebANode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<CelebANode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<Cifar100Node> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<Cifar100Node> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<Cifar10Node> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<Cifar10Node> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status NodePass::Visit(std::shared_ptr<CLUENode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<CLUENode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| #endif | |||
| Status NodePass::Visit(std::shared_ptr<CocoNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<CocoNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status NodePass::Visit(std::shared_ptr<CSVNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<CSVNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| #endif | |||
| #ifdef ENABLE_PYTHON | |||
| Status NodePass::Visit(std::shared_ptr<GeneratorNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<GeneratorNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| #endif | |||
| Status NodePass::Visit(std::shared_ptr<ImageFolderNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<ImageFolderNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<ManifestNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<ManifestNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status NodePass::Visit(std::shared_ptr<MindDataNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<MindDataNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| #endif | |||
| Status NodePass::Visit(std::shared_ptr<MnistNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<MnistNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<RandomNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<RandomNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status NodePass::Visit(std::shared_ptr<TextFileNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| Status NodePass::Visit(std::shared_ptr<SyncWaitNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<TextFileNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| Status NodePass::VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| #endif | |||
| #ifndef ENABLE_ANDROID | |||
| Status NodePass::Visit(std::shared_ptr<TFRecordNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| Status NodePass::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<TFRecordNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| Status NodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| #endif | |||
| Status NodePass::Visit(std::shared_ptr<VOCNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| // For leaf IR Node | |||
| Status NodePass::Visit(std::shared_ptr<SourceNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<VOCNode> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| Status NodePass::VisitAfter(std::shared_ptr<SourceNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| @@ -26,123 +26,87 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Non-leaf IR node | |||
| class BatchNode; | |||
| class BucketBatchByLengthNode; | |||
| #ifndef ENABLE_ANDROID | |||
| class BuildSentenceVocabNode; | |||
| #endif | |||
| class BuildVocabNode; | |||
| class ConcatNode; | |||
| class FilterNode; | |||
| class MapNode; | |||
| class ProjectNode; | |||
| class RenameNode; | |||
| class RepeatNode; | |||
| class RootNode; | |||
| class ShuffleNode; | |||
| class SkipNode; | |||
| #ifdef ENABLE_PYTHON | |||
| class SyncWaitNode; | |||
| #endif | |||
| class TakeNode; | |||
| class TransferNode; | |||
| class ZipNode; | |||
| #ifdef ENABLE_PYTHON | |||
| class SyncWaitNode; | |||
| #endif | |||
| #ifndef ENABLE_ANDROID | |||
| class BuildSentenceVocabNode; | |||
| #endif | |||
| // Leaf IR node | |||
| class AlbumNode; | |||
| class CelebANode; | |||
| class Cifar100Node; | |||
| class Cifar10Node; | |||
| #ifndef ENABLE_ANDROID | |||
| class CLUENode; | |||
| #endif | |||
| class CocoNode; | |||
| #ifndef ENABLE_ANDROID | |||
| class CSVNode; | |||
| #endif | |||
| #ifdef ENABLE_PYTHON | |||
| class GeneratorNode; | |||
| #endif | |||
| class ImageFolderNode; | |||
| class ManifestNode; | |||
| #ifndef ENABLE_ANDROID | |||
| class MindDataNode; | |||
| #endif | |||
| class MnistNode; | |||
| class RandomNode; | |||
| #ifndef ENABLE_ANDROID | |||
| class TextFileNode; | |||
| class VOCNode; | |||
| #ifdef ENABLE_PYTHON | |||
| class GeneratorNode; | |||
| #endif | |||
| #ifndef ENABLE_ANDROID | |||
| class CLUENode; | |||
| class CSVNode; | |||
| class MindDataNode; | |||
| class TextFileNode; | |||
| class TFRecordNode; | |||
| #endif | |||
| class VOCNode; | |||
| ////////////////////////////////// | |||
| // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. | |||
| class BatchOp; | |||
| class MapOp; | |||
| class ProjectOp; | |||
| class RenameOp; | |||
| class SkipOp; | |||
| class ShuffleOp; | |||
| class AlbumOp; | |||
| class RandomDataOp; | |||
| class RepeatOp; | |||
| class TakeOp; | |||
| class ZipOp; | |||
| class DeviceQueueOp; | |||
| class ImageFolderOp; | |||
| class MnistOp; | |||
| class ManifestOp; | |||
| class CifarOp; | |||
| class VOCOp; | |||
| class CocoOp; | |||
| class CelebAOp; | |||
| class EpochCtrlOp; | |||
| class BuildVocabOp; | |||
| class ConcatOp; | |||
| #ifndef ENABLE_ANDROID | |||
| class MindRecordOp; | |||
| class TFReaderOp; | |||
| class CacheOp; | |||
| class CacheMergeOp; | |||
| class CacheLookupOp; | |||
| class BuildSentencePieceVocabOp; | |||
| class ClueOp; | |||
| class CsvOp; | |||
| class TextFileOp; | |||
| #endif | |||
| #ifdef ENABLE_PYTHON | |||
| class FilterOp; | |||
| class GeneratorOp; | |||
| #endif | |||
| ////////////////////////////////// | |||
| @@ -175,6 +139,13 @@ class TreePass : public Pass { | |||
| /// \param[inout] modified Indicate if the tree was modified | |||
| Status Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) final; | |||
| /// \brief Derived classes may implement the runOnTree function to implement tree transformation. | |||
| /// "modified" flag needs to be set to true if tree is modified during the pass execution. | |||
| /// \param[inout] tree The tree to operate on. | |||
| /// \param[inout] Indicate of the tree was modified. | |||
| /// \return Status The error code return | |||
| virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); } | |||
| ////////////////////////////////// | |||
| // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. | |||
| /// \brief Run the transformation pass against the execution tree. | |||
| @@ -191,8 +162,17 @@ class TreePass : public Pass { | |||
| ////////////////////////////////// | |||
| }; | |||
| // NodePass is a basic Pass class which performs transformation on Node visiting. | |||
| // NodePass is a base Pass class which performs transformation on node visiting. | |||
| // NodePass implements Visitor design pattern. | |||
| // The visiting happens twice for each node in the DFS traversal, one on the way down of the traversal, | |||
| // and the other when all the descending nodes are visited. | |||
| // Actual transformation is done by implementing a new derived class of NodePass. | |||
| // The derived class will implement the method Visit()/VisitAfter() passing specified node types | |||
| // it wants to action on them, overriding the ones defined in NodePass. | |||
| // If the derived class wants to perform the same action on all node types, | |||
| // it can simply implement the method Visit()/VisitAfter() passing the base class DatasetNode. | |||
| // This is made possible by overloading the method Visit()/VisitAfter() on each node type to fall back | |||
| // to call the Visit()/VisitAfter() in this parent NodePass class. | |||
| class NodePass : public Pass { | |||
| public: | |||
| // Tree traversal order | |||
| @@ -223,153 +203,57 @@ class NodePass : public Pass { | |||
| /// \return Status The error code return | |||
| virtual Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); } | |||
| // For datasetops IR | |||
| // Visit method to be overridden. | |||
| // Note that member template can not be virtual, any node which wants to work with NodePass | |||
| // should declare Visit of its own type and override "Accept" from DatasetNode. | |||
| // Visit()/VisitAfter() method to be overridden. | |||
| // These pairs of Visit()/VisitAfter() for each derived class of DatasetNode are defined here. | |||
| // Their implementation are in .cc file to avoid adding the include files of those derived classes. | |||
| // The implementation simply falls back to call Visit()/VisitAfter of class DatasetNode, the parent of | |||
| // the derived classes. With this technique, the transformation classes derived from NodePass needs only to | |||
| // implement Visit()/VisitAfter() passing DatasetNode if it wants to action on any derived classes | |||
| // of DatasetNode in the same way. | |||
| // Note that virtual template functions are not permitted in C++. | |||
| // | |||
| // Non-leaf IR node | |||
| virtual Status Visit(std::shared_ptr<BatchNode> node, bool *modified); | |||
| // VisitAfter method to be overridden. | |||
| // Note that member template can not be virtual, any node which wants to work with NodePass | |||
| // should declare VisitAfter of its own type and override "AcceptAfter" from DatasetNode. | |||
| virtual Status VisitAfter(std::shared_ptr<BatchNode> node, bool *modified); | |||
| virtual Status Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified); | |||
| #ifndef ENABLE_ANDROID | |||
| virtual Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified); | |||
| #endif | |||
| virtual Status Visit(std::shared_ptr<BuildVocabNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified); | |||
| virtual Status Visit(std::shared_ptr<ConcatNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified); | |||
| virtual Status Visit(std::shared_ptr<FilterNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<FilterNode> node, bool *modified); | |||
| virtual Status Visit(std::shared_ptr<MapNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<MapNode> node, bool *modified); | |||
| virtual Status Visit(std::shared_ptr<ProjectNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified); | |||
| virtual Status Visit(std::shared_ptr<RenameNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<RenameNode> node, bool *modified); | |||
| virtual Status Visit(std::shared_ptr<RepeatNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified); | |||
| virtual Status Visit(std::shared_ptr<RootNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<RootNode> node, bool *modified); | |||
| virtual Status Visit(std::shared_ptr<ShuffleNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified); | |||
| virtual Status Visit(std::shared_ptr<SkipNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<SkipNode> node, bool *modified); | |||
| #ifdef ENABLE_PYTHON | |||
| virtual Status Visit(std::shared_ptr<SyncWaitNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified); | |||
| #endif | |||
| virtual Status Visit(std::shared_ptr<TakeNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<TakeNode> node, bool *modified); | |||
| virtual Status Visit(std::shared_ptr<TransferNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<TransferNode> node, bool *modified); | |||
| virtual Status Visit(std::shared_ptr<ZipNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<ZipNode> node, bool *modified); | |||
| // For datasetops/source IR | |||
| virtual Status Visit(std::shared_ptr<AlbumNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<AlbumNode> node, bool *modified); | |||
| virtual Status Visit(std::shared_ptr<CelebANode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<CelebANode> node, bool *modified); | |||
| virtual Status Visit(std::shared_ptr<Cifar100Node> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<Cifar100Node> node, bool *modified); | |||
| virtual Status Visit(std::shared_ptr<Cifar10Node> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<Cifar10Node> node, bool *modified); | |||
| #ifndef ENABLE_ANDROID | |||
| virtual Status Visit(std::shared_ptr<CLUENode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<CLUENode> node, bool *modified); | |||
| #endif | |||
| virtual Status Visit(std::shared_ptr<CocoNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<CocoNode> node, bool *modified); | |||
| #ifndef ENABLE_ANDROID | |||
| virtual Status Visit(std::shared_ptr<CSVNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<CSVNode> node, bool *modified); | |||
| #endif | |||
| #ifdef ENABLE_PYTHON | |||
| virtual Status Visit(std::shared_ptr<GeneratorNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<GeneratorNode> node, bool *modified); | |||
| #endif | |||
| virtual Status Visit(std::shared_ptr<ImageFolderNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<ImageFolderNode> node, bool *modified); | |||
| virtual Status Visit(std::shared_ptr<ManifestNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<ManifestNode> node, bool *modified); | |||
| #ifndef ENABLE_ANDROID | |||
| virtual Status Visit(std::shared_ptr<MindDataNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<MindDataNode> node, bool *modified); | |||
| #endif | |||
| virtual Status Visit(std::shared_ptr<MnistNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<MnistNode> node, bool *modified); | |||
| virtual Status Visit(std::shared_ptr<RandomNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<RandomNode> node, bool *modified); | |||
| #ifndef ENABLE_ANDROID | |||
| virtual Status Visit(std::shared_ptr<TextFileNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<TextFileNode> node, bool *modified); | |||
| virtual Status Visit(std::shared_ptr<SyncWaitNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified); | |||
| #endif | |||
| #ifndef ENABLE_ANDROID | |||
| virtual Status Visit(std::shared_ptr<TFRecordNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<TFRecordNode> node, bool *modified); | |||
| virtual Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified); | |||
| #endif | |||
| virtual Status Visit(std::shared_ptr<VOCNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<VOCNode> node, bool *modified); | |||
| // Leaf IR node | |||
| virtual Status Visit(std::shared_ptr<SourceNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<SourceNode> node, bool *modified); | |||
| ////////////////////////////////// | |||
| // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. | |||
| @@ -396,86 +280,47 @@ class NodePass : public Pass { | |||
| // Note that member template can not be virtual, any op which wants to work with NodePass should declare RunOnNode | |||
| // of its own type and override "Accept" from DatasetOp. | |||
| virtual Status RunOnNode(std::shared_ptr<BatchOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified); | |||
| #ifndef ENABLE_ANDROID | |||
| virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified); | |||
| #endif | |||
| #ifdef ENABLE_PYTHON | |||
| virtual Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified); | |||
| #endif | |||
| ////////////////////////////////// | |||
| @@ -18,6 +18,7 @@ | |||
| #include "minddata/dataset/core/client.h" | |||
| #include "minddata/dataset/include/datasets.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/root_node.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/input_validation_pass.h" | |||
| @@ -119,11 +120,16 @@ Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::sha | |||
| return Status::OK(); | |||
| } | |||
| Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_epochs) { | |||
| num_epochs_ = num_epochs; | |||
| Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_epochs) { | |||
| optimize_ = true; // Always ON (temporary) | |||
| RETURN_UNEXPECTED_IF_NULL(root_ir); | |||
| RETURN_UNEXPECTED_IF_NULL(input_ir); | |||
| MS_LOG(INFO) << "Input plan:" << '\n' << *input_ir << '\n'; | |||
| // Copy the input IR tree and insert under the root node | |||
| // Create a root node to host the input IR tree | |||
| auto root_ir = std::make_shared<RootNode>(input_ir->DeepCopy(), num_epochs); | |||
| MS_LOG(INFO) << "Plan before PrePass:" << '\n' << *root_ir << '\n'; | |||
| // Pre-pass of the IR tree | |||
| RETURN_IF_NOT_OK(PrePass(root_ir)); | |||
| @@ -136,11 +142,15 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_ep | |||
| // Post-pass of the IR tree | |||
| RETURN_IF_NOT_OK(PostPass(root_ir)); | |||
| MS_LOG(INFO) << "Plan after PostPass:" << '\n' << *root_ir << '\n'; | |||
| // This will evolve in the long run | |||
| tree_ = std::make_unique<ExecutionTree>(); | |||
| // Build the Execution tree from the child of the root node | |||
| std::shared_ptr<DatasetOp> root_op; | |||
| RETURN_IF_NOT_OK(BuildExecutionTree(root_ir, &root_op)); | |||
| // We will replace input_ir with root_ir->Children()[0] once IR optimizer is in | |||
| RETURN_IF_NOT_OK(BuildExecutionTree(input_ir, &root_op)); | |||
| RETURN_IF_NOT_OK(tree_->AssignRoot(root_op)); | |||
| if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_); | |||
| @@ -67,10 +67,6 @@ class TreeAdapter { | |||
| // Optional optimizations status | |||
| bool OptimizationEnabled() const { return optimize_; } | |||
| // Getter function to get the total number of epochs to be run on this tree. | |||
| // @return total number of epochs | |||
| int32_t num_epochs() { return num_epochs_; } | |||
| private: | |||
| // This function runs a mandatory pass checking the syntax and semantics of the IR tree. | |||
| Status PrePass(std::shared_ptr<DatasetNode> ir); | |||
| @@ -47,6 +47,10 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> { | |||
| /// \return Shared pointers to the newly created Sampler | |||
| virtual std::shared_ptr<SamplerRT> Build() = 0; | |||
| /// \brief Pure virtual function to copy a SamplerObj class | |||
| /// \return Shared pointers to the newly copied SamplerObj | |||
| virtual std::shared_ptr<SamplerObj> Copy() = 0; | |||
| /// \brief Function for derived class to get the shard id of sampler | |||
| /// \return The shard id of the derived sampler | |||
| virtual int64_t ShardId() { return 0; } | |||
| @@ -132,6 +136,11 @@ class DistributedSamplerObj : public SamplerObj { | |||
| std::shared_ptr<SamplerRT> Build() override; | |||
| std::shared_ptr<SamplerObj> Copy() override { | |||
| return std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, offset_, | |||
| even_dist_); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| @@ -160,6 +169,10 @@ class PKSamplerObj : public SamplerObj { | |||
| std::shared_ptr<SamplerRT> Build() override; | |||
| std::shared_ptr<SamplerObj> Copy() override { | |||
| return std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| @@ -174,9 +187,8 @@ class PKSamplerObj : public SamplerObj { | |||
| class PreBuiltSamplerObj : public SamplerObj { | |||
| public: | |||
| #ifndef ENABLE_ANDROID | |||
| explicit PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler); | |||
| #ifndef ENABLE_ANDROID | |||
| explicit PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler); | |||
| #endif | |||
| @@ -188,6 +200,8 @@ class PreBuiltSamplerObj : public SamplerObj { | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| std::shared_ptr<SamplerObj> Copy() override; | |||
| bool ValidateParams() override; | |||
| private: | |||
| @@ -205,6 +219,8 @@ class RandomSamplerObj : public SamplerObj { | |||
| std::shared_ptr<SamplerRT> Build() override; | |||
| std::shared_ptr<SamplerObj> Copy() override { return std::make_shared<RandomSamplerObj>(replacement_, num_samples_); } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| @@ -224,6 +240,10 @@ class SequentialSamplerObj : public SamplerObj { | |||
| std::shared_ptr<SamplerRT> Build() override; | |||
| std::shared_ptr<SamplerObj> Copy() override { | |||
| return std::make_shared<SequentialSamplerObj>(start_index_, num_samples_); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| @@ -243,6 +263,10 @@ class SubsetRandomSamplerObj : public SamplerObj { | |||
| std::shared_ptr<SamplerRT> Build() override; | |||
| std::shared_ptr<SamplerObj> Copy() override { | |||
| return std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| @@ -262,6 +286,10 @@ class WeightedRandomSamplerObj : public SamplerObj { | |||
| std::shared_ptr<SamplerRT> Build() override; | |||
| std::shared_ptr<SamplerObj> Copy() override { | |||
| return std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_); | |||
| } | |||
| bool ValidateParams() override; | |||
| private: | |||
| @@ -32,7 +32,10 @@ class TensorOp; | |||
| class TensorOperation : public std::enable_shared_from_this<TensorOperation> { | |||
| public: | |||
| /// \brief Constructor | |||
| TensorOperation(); | |||
| TensorOperation() : random_op_(false) {} | |||
| /// \brief Constructor | |||
| explicit TensorOperation(bool random) : random_op_(random) {} | |||
| /// \brief Destructor | |||
| ~TensorOperation() = default; | |||
| @@ -42,6 +45,13 @@ class TensorOperation : public std::enable_shared_from_this<TensorOperation> { | |||
| virtual std::shared_ptr<TensorOp> Build() = 0; | |||
| virtual Status ValidateParams() = 0; | |||
| /// \brief Check whether the operation is deterministic. | |||
| /// \return true if this op is a random op (returns non-deterministic result e.g. RandomCrop) | |||
| bool IsRandomOp() const { return random_op_; } | |||
| protected: | |||
| bool random_op_; | |||
| }; | |||
| // Helper function to validate fill value | |||
| @@ -427,7 +427,7 @@ Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> | |||
| Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst, | |||
| std::vector<dsize_t> cur_ind, size_t cur_dim) { | |||
| if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data | |||
| dst->CopyLastDimAt(src, cur_ind); | |||
| RETURN_IF_NOT_OK(dst->CopyLastDimAt(src, cur_ind)); | |||
| } else { // not the last dimension, keep doing recursion | |||
| dsize_t min_ind = std::min(dst->shape()[cur_dim], src->shape()[cur_dim]); | |||
| for (dsize_t i = 0; i < min_ind; i++) { | |||
| @@ -57,7 +57,7 @@ class RandomCropOp : public TensorOp { | |||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||
| // Function breaks out the compute function's image padding functionality and makes available to other Ops | |||
| // Using this class as a base - restructrued to allow for RandomCropWithBBox Augmentation Op | |||
| // Using this class as a base - re-structured to allow for RandomCropWithBBox Augmentation Op | |||
| // @param input: Input is the original Image | |||
| // @param pad_image: Pointer to new Padded image | |||
| // @param t_pad_top: Total Top Padding - Based on input and value calculated in function if required | |||
| @@ -570,7 +570,7 @@ class WeightedRandomSampler(BuiltinSampler): | |||
| Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities). | |||
| Args: | |||
| weights (list[float]): A sequence of weights, not necessarily summing up to 1. | |||
| weights (list[float, int]): A sequence of weights, not necessarily summing up to 1. | |||
| num_samples (int, optional): Number of elements to sample (default=None, all elements). | |||
| replacement (bool): If True, put the sample ID back for the next draw (default=True). | |||
| @@ -17,6 +17,7 @@ SET(DE_UT_SRCS | |||
| c_api_dataset_coco_test.cc | |||
| c_api_dataset_config_test.cc | |||
| c_api_dataset_csv_test.cc | |||
| c_api_dataset_ir_node_test.cc | |||
| c_api_dataset_iterator_test.cc | |||
| c_api_dataset_manifest_test.cc | |||
| c_api_dataset_minddata_test.cc | |||
| @@ -0,0 +1,142 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include <string> | |||
| #include "minddata/dataset/core/client.h" | |||
| #include "common/common.h" | |||
| #include "gtest/gtest.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||
| #include "minddata/dataset/engine/opt/pre/getter_pass.h" | |||
| using namespace mindspore::dataset; | |||
| using mindspore::LogStream; | |||
| using mindspore::MsLogLevel::INFO; | |||
| class MindDataTestIRNodes : public UT::DatasetOpTesting { | |||
| public: | |||
| MindDataTestIRNodes() = default; | |||
| void SetUp() override { GlobalInit(); } | |||
| // compare the ptr of the nodes in two trees, used to test the deep copy of nodes, will return error code | |||
| // if (ptr1 == ptr2) does not equal to flag or the two tree has different structures (or node names are not the same) | |||
| Status CompareTwoTrees(std::shared_ptr<DatasetNode> root1, std::shared_ptr<DatasetNode> root2, bool flag) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(root1 != nullptr && root2 != nullptr, "Error in Compare, nullptr."); | |||
| if (((root1.get() == root2.get()) != flag) || (root1->Name() != root2->Name())) { | |||
| std::string err_msg = | |||
| "Expect node ptr " + root1->Name() + (flag ? "==" : "!=") + root2->Name() + " but they aren't!"; | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| size_t num_child = root1->Children().size(); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_child == root2->Children().size(), | |||
| root1->Name() + " has " + std::to_string(num_child) + "child, node #2 has " + | |||
| std::to_string(root2->Children().size()) + " child."); | |||
| for (size_t ind = 0; ind < num_child; ind++) { | |||
| RETURN_IF_NOT_OK(CompareTwoTrees(root1->Children()[ind], root2->Children()[ind], flag)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // print the node's name in post order | |||
| Status PostOrderPrintTree(std::shared_ptr<DatasetNode> ir, std::string &names) { | |||
| RETURN_UNEXPECTED_IF_NULL(ir); | |||
| for (auto child : ir->Children()) { | |||
| RETURN_IF_NOT_OK(PostOrderPrintTree(child, names)); | |||
| } | |||
| names += (ir->Name() + "->"); | |||
| return Status::OK(); | |||
| } | |||
| }; | |||
| TEST_F(MindDataTestIRNodes, MindDataTestSimpleDeepCopy) { | |||
| MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestSimpleDeepCopy."; | |||
| auto tree1 = RandomData(44)->Repeat(2)->Project({"label"})->Shuffle(10)->Batch(2)->IRNode(); | |||
| auto tree2 = tree1->DeepCopy(); | |||
| std::string tree_1_names, tree_2_names; | |||
| ASSERT_OK(PostOrderPrintTree(tree1, tree_1_names)); | |||
| ASSERT_OK(PostOrderPrintTree(tree2, tree_2_names)); | |||
| // expected output for the 2 names: | |||
| // RandomDataset->Repeat->Project->Shuffle->Batch-> | |||
| EXPECT_EQ(tree_1_names, tree_2_names); | |||
| ASSERT_OK(CompareTwoTrees(tree1, tree1, true)); | |||
| ASSERT_OK(CompareTwoTrees(tree1, tree2, false)); | |||
| // verify compare function is correct | |||
| EXPECT_TRUE(CompareTwoTrees(tree2, tree2, false).IsError()); | |||
| } | |||
| TEST_F(MindDataTestIRNodes, MindDataTestZipDeepCopy) { | |||
| MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestZipDeepCopy."; | |||
| auto branch1 = RandomData(44)->Project({"label"}); | |||
| auto branch2 = RandomData(44)->Shuffle(10); | |||
| auto tree1 = Zip({branch1, branch2})->Batch(2)->IRNode(); | |||
| auto tree2 = tree1->DeepCopy(); | |||
| std::string tree_1_names, tree_2_names; | |||
| ASSERT_OK(PostOrderPrintTree(tree1, tree_1_names)); | |||
| ASSERT_OK(PostOrderPrintTree(tree2, tree_2_names)); | |||
| // expected output for the 2 names: | |||
| // RandomDataset->Project->RandomDataset->Shuffle->Zip->Batch-> | |||
| EXPECT_EQ(tree_1_names, tree_2_names); | |||
| // verify the pointer within the same tree are the same | |||
| ASSERT_OK(CompareTwoTrees(tree1, tree1, true)); | |||
| // verify two trees | |||
| ASSERT_OK(CompareTwoTrees(tree1, tree2, false)); | |||
| } | |||
| TEST_F(MindDataTestIRNodes, MindDataTestNodeRemove) { | |||
| MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestNodeRemove."; | |||
| auto branch1 = RandomData(44)->Project({"label"}); | |||
| auto branch2 = ImageFolder("path"); | |||
| auto tree = Zip({branch1, branch2})->IRNode(); | |||
| /*** | |||
| tree looks like this, we will remove node and test its functionalities | |||
| Zip | |||
| / \ | |||
| Project ImageFolder | |||
| / | |||
| RandomData | |||
| ***/ | |||
| auto tree_copy_1 = tree->DeepCopy(); | |||
| ASSERT_EQ(tree_copy_1->Children().size(), 2); | |||
| // remove the project in the tree and test | |||
| ASSERT_OK(tree_copy_1->Children()[0]->Remove()); // remove Project from tree | |||
| ASSERT_OK(CompareTwoTrees(tree_copy_1, Zip({RandomData(44), ImageFolder("path")})->IRNode(), false)); | |||
| // remove the ImageFolder, a leaf node from the tree | |||
| std::string tree_1_names, tree_2_names; | |||
| ASSERT_OK(PostOrderPrintTree(tree_copy_1, tree_1_names)); | |||
| EXPECT_EQ(tree_1_names, "RandomDataset->ImageFolderDataset->Zip->"); | |||
| auto tree_copy_2 = tree->DeepCopy(); | |||
| ASSERT_EQ(tree_copy_2->Children().size(), 2); | |||
| tree_copy_2->Children()[1]->Remove(); | |||
| ASSERT_OK(PostOrderPrintTree(tree_copy_2, tree_2_names)); | |||
| EXPECT_EQ(tree_2_names, "RandomDataset->Project->Zip->"); | |||
| } | |||