add 2 test cases to IRNode Deepcopy() address review cmts fix ut samplerObj copy ci fix ci fix ci round III address further review cmts add a missing macro fix merge conflict fix complie err fix lite compile err fix compile err fix lite compile round III address an issue fix minor commentstags/v1.1.0
| @@ -568,8 +568,8 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab( | |||||
| const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage, | const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage, | ||||
| SentencePieceModel model_type, const std::unordered_map<std::string, std::string> ¶ms) { | SentencePieceModel model_type, const std::unordered_map<std::string, std::string> ¶ms) { | ||||
| auto vocab = std::make_shared<SentencePieceVocab>(); | auto vocab = std::make_shared<SentencePieceVocab>(); | ||||
| auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, col_names, vocab_size, character_coverage, | |||||
| model_type, params); | |||||
| auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode()->DeepCopy(), vocab, col_names, vocab_size, | |||||
| character_coverage, model_type, params); | |||||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | ||||
| Status rc = runtime_context->Init(); | Status rc = runtime_context->Init(); | ||||
| @@ -600,8 +600,8 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum | |||||
| const std::pair<int64_t, int64_t> &freq_range, int64_t top_k, | const std::pair<int64_t, int64_t> &freq_range, int64_t top_k, | ||||
| const std::vector<std::string> &special_tokens, bool special_first) { | const std::vector<std::string> &special_tokens, bool special_first) { | ||||
| auto vocab = std::make_shared<Vocab>(); | auto vocab = std::make_shared<Vocab>(); | ||||
| auto ds = | |||||
| std::make_shared<BuildVocabNode>(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first); | |||||
| auto ds = std::make_shared<BuildVocabNode>(IRNode()->DeepCopy(), vocab, columns, freq_range, top_k, special_tokens, | |||||
| special_first); | |||||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | ||||
| Status rc = runtime_context->Init(); | Status rc = runtime_context->Init(); | ||||
| @@ -190,13 +190,12 @@ std::shared_ptr<SamplerRT> PKSamplerObj::Build() { | |||||
| return sampler; | return sampler; | ||||
| } | } | ||||
| #ifndef ENABLE_ANDROID | |||||
| // PreBuiltOperation | // PreBuiltOperation | ||||
| PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) | |||||
| : sp_(std::move(sampler)), sp_minddataset_(nullptr) {} | |||||
| PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) : sp_(std::move(sampler)) {} | |||||
| #ifndef ENABLE_ANDROID | |||||
| PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler) | PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler) | ||||
| : sp_(nullptr), sp_minddataset_(std::move(sampler)) {} | |||||
| : sp_minddataset_(std::move(sampler)) {} | |||||
| #endif | #endif | ||||
| bool PreBuiltSamplerObj::ValidateParams() { return true; } | bool PreBuiltSamplerObj::ValidateParams() { return true; } | ||||
| @@ -207,6 +206,13 @@ std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() { return sp_; } | |||||
| std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; } | std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; } | ||||
| #endif | #endif | ||||
| std::shared_ptr<SamplerObj> PreBuiltSamplerObj::Copy() { | |||||
| #ifndef ENABLE_ANDROID | |||||
| if (sp_minddataset_ != nullptr) return std::make_shared<PreBuiltSamplerObj>(sp_minddataset_); | |||||
| #endif | |||||
| return std::make_shared<PreBuiltSamplerObj>(sp_); | |||||
| } | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() { | std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() { | ||||
| // runtime mindrecord sampler object | // runtime mindrecord sampler object | ||||
| @@ -30,8 +30,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| TensorOperation::TensorOperation() {} | |||||
| /* ####################################### Validator Functions ############################################ */ | /* ####################################### Validator Functions ############################################ */ | ||||
| Status ValidateVectorFillvalue(const std::string &transform_name, const std::vector<uint8_t> &fill_value) { | Status ValidateVectorFillvalue(const std::string &transform_name, const std::vector<uint8_t> &fill_value) { | ||||
| if (fill_value.empty() || (fill_value.size() != 1 && fill_value.size() != 3)) { | if (fill_value.empty() || (fill_value.size() != 1 && fill_value.size() != 3)) { | ||||
| @@ -231,7 +229,7 @@ std::shared_ptr<TensorOp> PreBuiltOperation::Build() { return op_; } | |||||
| // RandomApplyOperation | // RandomApplyOperation | ||||
| RandomApplyOperation::RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob) | RandomApplyOperation::RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob) | ||||
| : transforms_(transforms), prob_(prob) {} | |||||
| : TensorOperation(true), transforms_(transforms), prob_(prob) {} | |||||
| Status RandomApplyOperation::ValidateParams() { | Status RandomApplyOperation::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomApply", transforms_)); | RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomApply", transforms_)); | ||||
| @@ -248,7 +246,7 @@ std::shared_ptr<TensorOp> RandomApplyOperation::Build() { | |||||
| // RandomChoiceOperation | // RandomChoiceOperation | ||||
| RandomChoiceOperation::RandomChoiceOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms) | RandomChoiceOperation::RandomChoiceOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms) | ||||
| : transforms_(transforms) {} | |||||
| : TensorOperation(true), transforms_(transforms) {} | |||||
| Status RandomChoiceOperation::ValidateParams() { | Status RandomChoiceOperation::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomChoice", transforms_)); | RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomChoice", transforms_)); | ||||
| @@ -734,7 +734,9 @@ RandomAffineOperation::RandomAffineOperation(const std::vector<float_t> °rees | |||||
| scale_range_(scale_range), | scale_range_(scale_range), | ||||
| shear_ranges_(shear_ranges), | shear_ranges_(shear_ranges), | ||||
| interpolation_(interpolation), | interpolation_(interpolation), | ||||
| fill_value_(fill_value) {} | |||||
| fill_value_(fill_value) { | |||||
| random_op_ = true; | |||||
| } | |||||
| Status RandomAffineOperation::ValidateParams() { | Status RandomAffineOperation::ValidateParams() { | ||||
| // Degrees | // Degrees | ||||
| @@ -867,7 +869,7 @@ std::shared_ptr<TensorOp> RandomAffineOperation::Build() { | |||||
| } | } | ||||
| // RandomColorOperation. | // RandomColorOperation. | ||||
| RandomColorOperation::RandomColorOperation(float t_lb, float t_ub) : t_lb_(t_lb), t_ub_(t_ub) {} | |||||
| RandomColorOperation::RandomColorOperation(float t_lb, float t_ub) : t_lb_(t_lb), t_ub_(t_ub) { random_op_ = true; } | |||||
| Status RandomColorOperation::ValidateParams() { | Status RandomColorOperation::ValidateParams() { | ||||
| // Do some input validation. | // Do some input validation. | ||||
| @@ -891,7 +893,9 @@ Status RandomColorOperation::ValidateParams() { | |||||
| // RandomColorAdjustOperation. | // RandomColorAdjustOperation. | ||||
| RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector<float> brightness, std::vector<float> contrast, | RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector<float> brightness, std::vector<float> contrast, | ||||
| std::vector<float> saturation, std::vector<float> hue) | std::vector<float> saturation, std::vector<float> hue) | ||||
| : brightness_(brightness), contrast_(contrast), saturation_(saturation), hue_(hue) {} | |||||
| : brightness_(brightness), contrast_(contrast), saturation_(saturation), hue_(hue) { | |||||
| random_op_ = true; | |||||
| } | |||||
| Status RandomColorAdjustOperation::ValidateParams() { | Status RandomColorAdjustOperation::ValidateParams() { | ||||
| // brightness | // brightness | ||||
| @@ -1012,11 +1016,14 @@ std::shared_ptr<TensorOp> RandomColorAdjustOperation::Build() { | |||||
| // RandomCropOperation | // RandomCropOperation | ||||
| RandomCropOperation::RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed, | RandomCropOperation::RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed, | ||||
| std::vector<uint8_t> fill_value, BorderType padding_mode) | std::vector<uint8_t> fill_value, BorderType padding_mode) | ||||
| : size_(size), | |||||
| : TensorOperation(true), | |||||
| size_(size), | |||||
| padding_(padding), | padding_(padding), | ||||
| pad_if_needed_(pad_if_needed), | pad_if_needed_(pad_if_needed), | ||||
| fill_value_(fill_value), | fill_value_(fill_value), | ||||
| padding_mode_(padding_mode) {} | |||||
| padding_mode_(padding_mode) { | |||||
| random_op_ = true; | |||||
| } | |||||
| Status RandomCropOperation::ValidateParams() { | Status RandomCropOperation::ValidateParams() { | ||||
| // size | // size | ||||
| @@ -1083,7 +1090,12 @@ std::shared_ptr<TensorOp> RandomCropOperation::Build() { | |||||
| RandomCropDecodeResizeOperation::RandomCropDecodeResizeOperation(std::vector<int32_t> size, std::vector<float> scale, | RandomCropDecodeResizeOperation::RandomCropDecodeResizeOperation(std::vector<int32_t> size, std::vector<float> scale, | ||||
| std::vector<float> ratio, | std::vector<float> ratio, | ||||
| InterpolationMode interpolation, int32_t max_attempts) | InterpolationMode interpolation, int32_t max_attempts) | ||||
| : size_(size), scale_(scale), ratio_(ratio), interpolation_(interpolation), max_attempts_(max_attempts) {} | |||||
| : TensorOperation(true), | |||||
| size_(size), | |||||
| scale_(scale), | |||||
| ratio_(ratio), | |||||
| interpolation_(interpolation), | |||||
| max_attempts_(max_attempts) {} | |||||
| Status RandomCropDecodeResizeOperation::ValidateParams() { | Status RandomCropDecodeResizeOperation::ValidateParams() { | ||||
| // size | // size | ||||
| @@ -1176,7 +1188,8 @@ std::shared_ptr<TensorOp> RandomCropDecodeResizeOperation::Build() { | |||||
| RandomCropWithBBoxOperation::RandomCropWithBBoxOperation(std::vector<int32_t> size, std::vector<int32_t> padding, | RandomCropWithBBoxOperation::RandomCropWithBBoxOperation(std::vector<int32_t> size, std::vector<int32_t> padding, | ||||
| bool pad_if_needed, std::vector<uint8_t> fill_value, | bool pad_if_needed, std::vector<uint8_t> fill_value, | ||||
| BorderType padding_mode) | BorderType padding_mode) | ||||
| : size_(size), | |||||
| : TensorOperation(true), | |||||
| size_(size), | |||||
| padding_(padding), | padding_(padding), | ||||
| pad_if_needed_(pad_if_needed), | pad_if_needed_(pad_if_needed), | ||||
| fill_value_(fill_value), | fill_value_(fill_value), | ||||
| @@ -1245,7 +1258,8 @@ std::shared_ptr<TensorOp> RandomCropWithBBoxOperation::Build() { | |||||
| } | } | ||||
| // RandomHorizontalFlipOperation | // RandomHorizontalFlipOperation | ||||
| RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability) : probability_(probability) {} | |||||
| RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability) | |||||
| : TensorOperation(true), probability_(probability) {} | |||||
| Status RandomHorizontalFlipOperation::ValidateParams() { | Status RandomHorizontalFlipOperation::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(ValidateProbability("RandomHorizontalFlip", probability_)); | RETURN_IF_NOT_OK(ValidateProbability("RandomHorizontalFlip", probability_)); | ||||
| @@ -1260,7 +1274,7 @@ std::shared_ptr<TensorOp> RandomHorizontalFlipOperation::Build() { | |||||
| // RandomHorizontalFlipWithBBoxOperation | // RandomHorizontalFlipWithBBoxOperation | ||||
| RandomHorizontalFlipWithBBoxOperation::RandomHorizontalFlipWithBBoxOperation(float probability) | RandomHorizontalFlipWithBBoxOperation::RandomHorizontalFlipWithBBoxOperation(float probability) | ||||
| : probability_(probability) {} | |||||
| : TensorOperation(true), probability_(probability) {} | |||||
| Status RandomHorizontalFlipWithBBoxOperation::ValidateParams() { | Status RandomHorizontalFlipWithBBoxOperation::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(ValidateProbability("RandomHorizontalFlipWithBBox", probability_)); | RETURN_IF_NOT_OK(ValidateProbability("RandomHorizontalFlipWithBBox", probability_)); | ||||
| @@ -1275,7 +1289,8 @@ std::shared_ptr<TensorOp> RandomHorizontalFlipWithBBoxOperation::Build() { | |||||
| } | } | ||||
| // RandomPosterizeOperation | // RandomPosterizeOperation | ||||
| RandomPosterizeOperation::RandomPosterizeOperation(const std::vector<uint8_t> &bit_range) : bit_range_(bit_range) {} | |||||
| RandomPosterizeOperation::RandomPosterizeOperation(const std::vector<uint8_t> &bit_range) | |||||
| : TensorOperation(true), bit_range_(bit_range) {} | |||||
| Status RandomPosterizeOperation::ValidateParams() { | Status RandomPosterizeOperation::ValidateParams() { | ||||
| if (bit_range_.size() != 2) { | if (bit_range_.size() != 2) { | ||||
| @@ -1309,7 +1324,7 @@ std::shared_ptr<TensorOp> RandomPosterizeOperation::Build() { | |||||
| } | } | ||||
| // RandomResizeOperation | // RandomResizeOperation | ||||
| RandomResizeOperation::RandomResizeOperation(std::vector<int32_t> size) : size_(size) {} | |||||
| RandomResizeOperation::RandomResizeOperation(std::vector<int32_t> size) : TensorOperation(true), size_(size) {} | |||||
| Status RandomResizeOperation::ValidateParams() { | Status RandomResizeOperation::ValidateParams() { | ||||
| // size | // size | ||||
| @@ -1343,7 +1358,8 @@ std::shared_ptr<TensorOp> RandomResizeOperation::Build() { | |||||
| } | } | ||||
| // RandomResizeWithBBoxOperation | // RandomResizeWithBBoxOperation | ||||
| RandomResizeWithBBoxOperation::RandomResizeWithBBoxOperation(std::vector<int32_t> size) : size_(size) {} | |||||
| RandomResizeWithBBoxOperation::RandomResizeWithBBoxOperation(std::vector<int32_t> size) | |||||
| : TensorOperation(true), size_(size) {} | |||||
| Status RandomResizeWithBBoxOperation::ValidateParams() { | Status RandomResizeWithBBoxOperation::ValidateParams() { | ||||
| // size | // size | ||||
| @@ -1380,7 +1396,12 @@ std::shared_ptr<TensorOp> RandomResizeWithBBoxOperation::Build() { | |||||
| RandomResizedCropOperation::RandomResizedCropOperation(std::vector<int32_t> size, std::vector<float> scale, | RandomResizedCropOperation::RandomResizedCropOperation(std::vector<int32_t> size, std::vector<float> scale, | ||||
| std::vector<float> ratio, InterpolationMode interpolation, | std::vector<float> ratio, InterpolationMode interpolation, | ||||
| int32_t max_attempts) | int32_t max_attempts) | ||||
| : size_(size), scale_(scale), ratio_(ratio), interpolation_(interpolation), max_attempts_(max_attempts) {} | |||||
| : TensorOperation(true), | |||||
| size_(size), | |||||
| scale_(scale), | |||||
| ratio_(ratio), | |||||
| interpolation_(interpolation), | |||||
| max_attempts_(max_attempts) {} | |||||
| Status RandomResizedCropOperation::ValidateParams() { | Status RandomResizedCropOperation::ValidateParams() { | ||||
| // size | // size | ||||
| @@ -1536,7 +1557,8 @@ std::shared_ptr<TensorOp> RandomResizedCropWithBBoxOperation::Build() { | |||||
| RandomRotationOperation::RandomRotationOperation(std::vector<float> degrees, InterpolationMode interpolation_mode, | RandomRotationOperation::RandomRotationOperation(std::vector<float> degrees, InterpolationMode interpolation_mode, | ||||
| bool expand, std::vector<float> center, | bool expand, std::vector<float> center, | ||||
| std::vector<uint8_t> fill_value) | std::vector<uint8_t> fill_value) | ||||
| : degrees_(degrees), | |||||
| : TensorOperation(true), | |||||
| degrees_(degrees), | |||||
| interpolation_mode_(interpolation_mode), | interpolation_mode_(interpolation_mode), | ||||
| expand_(expand), | expand_(expand), | ||||
| center_(center), | center_(center), | ||||
| @@ -1603,7 +1625,7 @@ std::shared_ptr<TensorOp> RandomRotationOperation::Build() { | |||||
| // RandomSelectSubpolicyOperation. | // RandomSelectSubpolicyOperation. | ||||
| RandomSelectSubpolicyOperation::RandomSelectSubpolicyOperation( | RandomSelectSubpolicyOperation::RandomSelectSubpolicyOperation( | ||||
| std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> policy) | std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> policy) | ||||
| : policy_(policy) {} | |||||
| : TensorOperation(true), policy_(policy) {} | |||||
| Status RandomSelectSubpolicyOperation::ValidateParams() { | Status RandomSelectSubpolicyOperation::ValidateParams() { | ||||
| if (policy_.empty()) { | if (policy_.empty()) { | ||||
| @@ -1650,7 +1672,8 @@ std::shared_ptr<TensorOp> RandomSelectSubpolicyOperation::Build() { | |||||
| } | } | ||||
| // Function to create RandomSharpness. | // Function to create RandomSharpness. | ||||
| RandomSharpnessOperation::RandomSharpnessOperation(std::vector<float> degrees) : degrees_(degrees) {} | |||||
| RandomSharpnessOperation::RandomSharpnessOperation(std::vector<float> degrees) | |||||
| : TensorOperation(true), degrees_(degrees) {} | |||||
| Status RandomSharpnessOperation::ValidateParams() { | Status RandomSharpnessOperation::ValidateParams() { | ||||
| if (degrees_.size() != 2 || degrees_[0] < 0 || degrees_[1] < 0) { | if (degrees_.size() != 2 || degrees_[0] < 0 || degrees_[1] < 0) { | ||||
| @@ -1674,7 +1697,8 @@ std::shared_ptr<TensorOp> RandomSharpnessOperation::Build() { | |||||
| } | } | ||||
| // RandomSolarizeOperation. | // RandomSolarizeOperation. | ||||
| RandomSolarizeOperation::RandomSolarizeOperation(std::vector<uint8_t> threshold) : threshold_(threshold) {} | |||||
| RandomSolarizeOperation::RandomSolarizeOperation(std::vector<uint8_t> threshold) | |||||
| : TensorOperation(true), threshold_(threshold) {} | |||||
| Status RandomSolarizeOperation::ValidateParams() { | Status RandomSolarizeOperation::ValidateParams() { | ||||
| if (threshold_.size() != 2) { | if (threshold_.size() != 2) { | ||||
| @@ -1705,7 +1729,8 @@ std::shared_ptr<TensorOp> RandomSolarizeOperation::Build() { | |||||
| } | } | ||||
| // RandomVerticalFlipOperation | // RandomVerticalFlipOperation | ||||
| RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability) : probability_(probability) {} | |||||
| RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability) | |||||
| : TensorOperation(true), probability_(probability) {} | |||||
| Status RandomVerticalFlipOperation::ValidateParams() { | Status RandomVerticalFlipOperation::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(ValidateProbability("RandomVerticalFlip", probability_)); | RETURN_IF_NOT_OK(ValidateProbability("RandomVerticalFlip", probability_)); | ||||
| @@ -1720,7 +1745,7 @@ std::shared_ptr<TensorOp> RandomVerticalFlipOperation::Build() { | |||||
| // RandomVerticalFlipWithBBoxOperation | // RandomVerticalFlipWithBBoxOperation | ||||
| RandomVerticalFlipWithBBoxOperation::RandomVerticalFlipWithBBoxOperation(float probability) | RandomVerticalFlipWithBBoxOperation::RandomVerticalFlipWithBBoxOperation(float probability) | ||||
| : probability_(probability) {} | |||||
| : TensorOperation(true), probability_(probability) {} | |||||
| Status RandomVerticalFlipWithBBoxOperation::ValidateParams() { | Status RandomVerticalFlipWithBBoxOperation::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(ValidateProbability("RandomVerticalFlipWithBBox", probability_)); | RETURN_IF_NOT_OK(ValidateProbability("RandomVerticalFlipWithBBox", probability_)); | ||||
| @@ -9,11 +9,13 @@ set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES | |||||
| build_sentence_piece_vocab_node.cc | build_sentence_piece_vocab_node.cc | ||||
| build_vocab_node.cc | build_vocab_node.cc | ||||
| concat_node.cc | concat_node.cc | ||||
| epoch_ctrl_node.cc | |||||
| filter_node.cc | filter_node.cc | ||||
| map_node.cc | map_node.cc | ||||
| project_node.cc | project_node.cc | ||||
| rename_node.cc | rename_node.cc | ||||
| repeat_node.cc | repeat_node.cc | ||||
| root_node.cc | |||||
| shuffle_node.cc | shuffle_node.cc | ||||
| skip_node.cc | skip_node.cc | ||||
| sync_wait_node.cc | sync_wait_node.cc | ||||
| @@ -43,14 +43,29 @@ BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, boo | |||||
| batch_size_func_(batch_size_func), | batch_size_func_(batch_size_func), | ||||
| batch_map_func_(batch_map_func), | batch_map_func_(batch_map_func), | ||||
| pad_map_(pad_map) { | pad_map_(pad_map) { | ||||
| this->children.push_back(child); | |||||
| this->AddChild(child); | |||||
| } | } | ||||
| #endif | #endif | ||||
| // constructor #2, called by C++ API | // constructor #2, called by C++ API | ||||
| BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder) | BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder) | ||||
| : batch_size_(batch_size), drop_remainder_(drop_remainder), pad_(false) { | : batch_size_(batch_size), drop_remainder_(drop_remainder), pad_(false) { | ||||
| this->children.push_back(child); | |||||
| this->AddChild(child); | |||||
| } | |||||
| std::shared_ptr<DatasetNode> BatchNode::Copy() { | |||||
| #ifdef ENABLE_PYTHON | |||||
| auto node = std::make_shared<BatchNode>(nullptr, batch_size_, drop_remainder_, pad_, in_col_names_, out_col_names_, | |||||
| col_order_, batch_size_func_, batch_map_func_, pad_map_); | |||||
| #else | |||||
| auto node = std::make_shared<BatchNode>(nullptr, batch_size_, drop_remainder_); | |||||
| #endif | |||||
| return node; | |||||
| } | |||||
| void BatchNode::Print(std::ostream &out) const { | |||||
| out << Name() + "(batch_size:" + std::to_string(batch_size_) + | |||||
| " drop_remainder:" + (drop_remainder_ ? "true" : "false") + ")"; | |||||
| } | } | ||||
| Status BatchNode::ValidateParams() { | Status BatchNode::ValidateParams() { | ||||
| @@ -44,6 +44,18 @@ class BatchNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~BatchNode() = default; | ~BatchNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kBatchNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -41,7 +41,17 @@ BucketBatchByLengthNode::BucketBatchByLengthNode( | |||||
| pad_info_(pad_info), | pad_info_(pad_info), | ||||
| pad_to_bucket_boundary_(pad_to_bucket_boundary), | pad_to_bucket_boundary_(pad_to_bucket_boundary), | ||||
| drop_remainder_(drop_remainder) { | drop_remainder_(drop_remainder) { | ||||
| this->children.push_back(child); | |||||
| this->AddChild(child); | |||||
| } | |||||
| std::shared_ptr<DatasetNode> BucketBatchByLengthNode::Copy() { | |||||
| auto node = std::make_shared<BucketBatchByLengthNode>(nullptr, column_names_, bucket_boundaries_, bucket_batch_sizes_, | |||||
| element_length_function_, pad_info_, pad_to_bucket_boundary_); | |||||
| return node; | |||||
| } | |||||
| void BucketBatchByLengthNode::Print(std::ostream &out) const { | |||||
| out << Name() + "(columns:" + PrintColumns(column_names_) + ",...)"; | |||||
| } | } | ||||
| std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() { | std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() { | ||||
| @@ -40,6 +40,18 @@ class BucketBatchByLengthNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~BucketBatchByLengthNode() = default; | ~BucketBatchByLengthNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kBucketBatchByLengthNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h" | #include "minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h" | ||||
| #include "minddata/dataset/engine/opt/pass.h" | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -38,7 +39,18 @@ BuildSentenceVocabNode::BuildSentenceVocabNode(std::shared_ptr<DatasetNode> chil | |||||
| character_coverage_(character_coverage), | character_coverage_(character_coverage), | ||||
| model_type_(model_type), | model_type_(model_type), | ||||
| params_(params) { | params_(params) { | ||||
| this->children.push_back(child); | |||||
| this->AddChild(child); | |||||
| } | |||||
| std::shared_ptr<DatasetNode> BuildSentenceVocabNode::Copy() { | |||||
| auto node = std::make_shared<BuildSentenceVocabNode>(nullptr, vocab_, col_names_, vocab_size_, character_coverage_, | |||||
| model_type_, params_); | |||||
| return node; | |||||
| } | |||||
| void BuildSentenceVocabNode::Print(std::ostream &out) const { | |||||
| out << Name() + "<vocab>," + "columns:" + PrintColumns(col_names_) + ",vocab_size:" + std::to_string(vocab_size_) + | |||||
| ",...)"; | |||||
| } | } | ||||
| // Function to build BuildSentenceVocabNode | // Function to build BuildSentenceVocabNode | ||||
| @@ -81,5 +93,16 @@ Status BuildSentenceVocabNode::ValidateParams() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Visitor accepting method for NodePass | |||||
| Status BuildSentenceVocabNode::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->Visit(shared_from_base<BuildSentenceVocabNode>(), modified); | |||||
| } | |||||
| // Visitor accepting method for NodePass | |||||
| Status BuildSentenceVocabNode::AcceptAfter(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->VisitAfter(shared_from_base<BuildSentenceVocabNode>(), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -38,6 +38,18 @@ class BuildSentenceVocabNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~BuildSentenceVocabNode() = default; | ~BuildSentenceVocabNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kBuildSentencePieceVocabNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -46,6 +58,18 @@ class BuildSentenceVocabNode : public DatasetNode { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Base-class override for accepting NodePass visitor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| /// \brief Base-class override for accepting NodePass visitor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| std::shared_ptr<SentencePieceVocab> vocab_; | std::shared_ptr<SentencePieceVocab> vocab_; | ||||
| std::vector<std::string> col_names_; | std::vector<std::string> col_names_; | ||||
| @@ -22,7 +22,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/datasetops/build_vocab_op.h" | #include "minddata/dataset/engine/datasetops/build_vocab_op.h" | ||||
| #include "minddata/dataset/engine/opt/pass.h" | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| @@ -36,7 +36,17 @@ BuildVocabNode::BuildVocabNode(std::shared_ptr<DatasetNode> child, std::shared_p | |||||
| top_k_(top_k), | top_k_(top_k), | ||||
| special_tokens_(special_tokens), | special_tokens_(special_tokens), | ||||
| special_first_(special_first) { | special_first_(special_first) { | ||||
| this->children.push_back(child); | |||||
| this->AddChild(child); | |||||
| } | |||||
| std::shared_ptr<DatasetNode> BuildVocabNode::Copy() { | |||||
| auto node = | |||||
| std::make_shared<BuildVocabNode>(nullptr, vocab_, columns_, freq_range_, top_k_, special_tokens_, special_first_); | |||||
| return node; | |||||
| } | |||||
| void BuildVocabNode::Print(std::ostream &out) const { | |||||
| out << Name() + "(<vocab>," + "columns:" + PrintColumns(columns_) + ",...)"; | |||||
| } | } | ||||
| // Function to build BuildVocabNode | // Function to build BuildVocabNode | ||||
| @@ -78,5 +88,16 @@ Status BuildVocabNode::ValidateParams() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Visitor accepting method for NodePass | |||||
| Status BuildVocabNode::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->Visit(shared_from_base<BuildVocabNode>(), modified); | |||||
| } | |||||
| // Visitor accepting method for NodePass | |||||
| Status BuildVocabNode::AcceptAfter(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->VisitAfter(shared_from_base<BuildVocabNode>(), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -37,6 +37,18 @@ class BuildVocabNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~BuildVocabNode() = default; | ~BuildVocabNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kBuildVocabNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -45,6 +57,18 @@ class BuildVocabNode : public DatasetNode { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Base-class override for accepting NodePass visitor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| /// \brief Base-class override for accepting NodePass visitor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| std::shared_ptr<Vocab> vocab_; | std::shared_ptr<Vocab> vocab_; | ||||
| std::vector<std::string> columns_; | std::vector<std::string> columns_; | ||||
| @@ -22,7 +22,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/datasetops/concat_op.h" | #include "minddata/dataset/engine/datasetops/concat_op.h" | ||||
| #include "minddata/dataset/engine/opt/pass.h" | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| @@ -35,17 +35,25 @@ ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets | |||||
| : sampler_(sampler), | : sampler_(sampler), | ||||
| children_flag_and_nums_(children_flag_and_nums), | children_flag_and_nums_(children_flag_and_nums), | ||||
| children_start_end_index_(children_start_end_index) { | children_start_end_index_(children_start_end_index) { | ||||
| this->children = datasets; | |||||
| for (auto const &child : datasets) AddChild(child); | |||||
| } | |||||
| std::shared_ptr<DatasetNode> ConcatNode::Copy() { | |||||
| // create an empty vector to copy a concat | |||||
| auto node = std::make_shared<ConcatNode>(std::vector<std::shared_ptr<DatasetNode>>()); | |||||
| return node; | |||||
| } | } | ||||
| void ConcatNode::Print(std::ostream &out) const { out << Name(); } | |||||
| Status ConcatNode::ValidateParams() { | Status ConcatNode::ValidateParams() { | ||||
| if (children.size() < 2) { | |||||
| if (children_.size() < 2) { | |||||
| std::string err_msg = "ConcatNode: concatenated datasets are not specified."; | std::string err_msg = "ConcatNode: concatenated datasets are not specified."; | ||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | RETURN_STATUS_SYNTAX_ERROR(err_msg); | ||||
| } | } | ||||
| if (find(children.begin(), children.end(), nullptr) != children.end()) { | |||||
| if (find(children_.begin(), children_.end(), nullptr) != children_.end()) { | |||||
| std::string err_msg = "ConcatNode: concatenated datasets should not be null."; | std::string err_msg = "ConcatNode: concatenated datasets should not be null."; | ||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | RETURN_STATUS_SYNTAX_ERROR(err_msg); | ||||
| @@ -73,5 +81,16 @@ std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() { | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| // Visitor accepting method for NodePass | |||||
| Status ConcatNode::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->Visit(shared_from_base<ConcatNode>(), modified); | |||||
| } | |||||
| // Visitor accepting method for NodePass | |||||
| Status ConcatNode::AcceptAfter(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->VisitAfter(shared_from_base<ConcatNode>(), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -38,6 +38,18 @@ class ConcatNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~ConcatNode() = default; | ~ConcatNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kConcatNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -50,6 +62,18 @@ class ConcatNode : public DatasetNode { | |||||
| std::shared_ptr<SamplerObj> sampler_; | std::shared_ptr<SamplerObj> sampler_; | ||||
| std::vector<std::pair<int, int>> children_flag_and_nums_; | std::vector<std::pair<int, int>> children_flag_and_nums_; | ||||
| std::vector<std::pair<int, int>> children_start_end_index_; | std::vector<std::pair<int, int>> children_start_end_index_; | ||||
| /// \brief Base-class override for accepting NodePass visitor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| /// \brief Base-class override for accepting NodePass visitor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -233,14 +233,92 @@ std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) { | |||||
| return shared_from_this(); | return shared_from_this(); | ||||
| } | } | ||||
| DatasetNode::DatasetNode() { | |||||
| DatasetNode::DatasetNode() : cache_(nullptr), parent_(nullptr), children_({}) { | |||||
| // Fetch some default value from config manager | // Fetch some default value from config manager | ||||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | ||||
| num_workers_ = cfg->num_parallel_workers(); | num_workers_ = cfg->num_parallel_workers(); | ||||
| rows_per_buffer_ = cfg->rows_per_buffer(); | rows_per_buffer_ = cfg->rows_per_buffer(); | ||||
| connector_que_size_ = cfg->op_connector_size(); | connector_que_size_ = cfg->op_connector_size(); | ||||
| worker_connector_size_ = cfg->worker_connector_size(); | worker_connector_size_ = cfg->worker_connector_size(); | ||||
| build_status = Status::OK(); // remove me after changing return val of Build() | |||||
| } | |||||
| // this function will preform a deep copy of current node (and its descendants), the parent* pointer will not be copied | |||||
| std::shared_ptr<DatasetNode> DatasetNode::DeepCopy() { | |||||
| std::shared_ptr<DatasetNode> new_node = this->Copy(); | |||||
| for (const auto &child : children_) { | |||||
| new_node->AddChild(child->DeepCopy()); | |||||
| } | |||||
| return new_node; | |||||
| } | |||||
| std::string DatasetNode::PrintColumns(const std::vector<std::string> &columns) const { | |||||
| std::string me; | |||||
| if (columns.empty()) { | |||||
| me = "<nil>"; | |||||
| } else { | |||||
| me = "["; | |||||
| auto i = 0; | |||||
| for (auto it = columns.begin(); it < columns.end(); ++it, ++i) { | |||||
| me += *it; | |||||
| if (i < columns.size() - 1) { | |||||
| me += ", "; | |||||
| } else { | |||||
| me += "]"; | |||||
| } | |||||
| } | |||||
| } | |||||
| return me; | |||||
| } | |||||
| void DatasetNode::PrintTree(std::ostream &out) const { | |||||
| int level = 0; | |||||
| PrintNode(out, &level); | |||||
| } | |||||
| void DatasetNode::PrintNode(std::ostream &out, int *level) const { | |||||
| const std::string prefix = "+-"; | |||||
| const std::string indent = " "; | |||||
| out << prefix; | |||||
| Print(out); | |||||
| for (const auto &c : this->Children()) { | |||||
| out << '\n'; | |||||
| ++(*level); | |||||
| for (auto i = 0; i < *level; i++) { | |||||
| out << indent; | |||||
| } | |||||
| c->PrintNode(out, level); | |||||
| --(*level); | |||||
| } | |||||
| } | |||||
| // Add a node as a child, node's parent needs to be nullptr | |||||
| // this function will allow child to be a nullptr, in which case it will simply skip | |||||
| void DatasetNode::AddChild(std::shared_ptr<DatasetNode> child) { | |||||
| if (child != nullptr && child->parent_ == nullptr) { | |||||
| children_.push_back(child); | |||||
| child->parent_ = this; | |||||
| } else if (child != nullptr) { | |||||
| MS_LOG(WARNING) << "DatasetNode::AddChild() Fail" + child->Name() + "'s parent isn't a nullptr."; | |||||
| } | |||||
| } | |||||
| // Remove this node from its parent. Add the child of this node to its parent. | |||||
| // for now, this remove is limited to node with a single child or no child | |||||
| Status DatasetNode::Remove() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "Cannot remove root or a node without parent."); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(children_.size() < 2, "Cannot remove node with more than 1 child."); | |||||
| if (children_.empty()) { // I am a leaf node, remove me from my parent's children list | |||||
| parent_->children_.erase(std::remove(parent_->children_.begin(), parent_->children_.end(), shared_from_this()), | |||||
| parent_->children_.end()); // removal using "erase remove idiom" | |||||
| } else { // replace my position in my parent's children list with my single child | |||||
| auto itr = std::find(parent_->children_.begin(), parent_->children_.end(), shared_from_this()); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(itr != parent_->children_.end(), "I am not in my parent's children list."); | |||||
| children_[0]->parent_ = parent_; // set my single child's parent ptr to my parent | |||||
| *itr = std::move(children_[0]); // replace me in my parent's children list with my single child | |||||
| children_.clear(); // release my single child from my children list | |||||
| } | |||||
| parent_ = nullptr; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| // In DFS tree traversal, each node is visited twice. Accept is called on the first visit. | // In DFS tree traversal, each node is visited twice. Accept is called on the first visit. | ||||
| @@ -255,13 +333,25 @@ Status DatasetNode::AcceptAfter(NodePass *p, bool *modified) { | |||||
| // This method will only be called if its derived class does not implement one. | // This method will only be called if its derived class does not implement one. | ||||
| return p->VisitAfter(shared_from_this(), modified); | return p->VisitAfter(shared_from_this(), modified); | ||||
| } | } | ||||
| Status DatasetNode::GetShardId(int32_t *shard_id) { | Status DatasetNode::GetShardId(int32_t *shard_id) { | ||||
| if (!Children().empty()) { | if (!Children().empty()) { | ||||
| // Get shard id from the child node | // Get shard id from the child node | ||||
| return Children()[0]->GetShardId(shard_id); | return Children()[0]->GetShardId(shard_id); | ||||
| } else { | } else { | ||||
| RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node"); | |||||
| RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node: " + Name() + "\n"); | |||||
| } | } | ||||
| } | } | ||||
| // Visitor accepting method for NodePass | |||||
| Status SourceNode::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->Visit(shared_from_base<SourceNode>(), modified); | |||||
| } | |||||
| // Visitor accepting method for NodePass | |||||
| Status SourceNode::AcceptAfter(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->VisitAfter(shared_from_base<SourceNode>(), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -42,6 +42,45 @@ class NodePass; | |||||
| } \ | } \ | ||||
| } while (false) | } while (false) | ||||
| // Names for non-leaf IR node | |||||
| constexpr char kBatchNode[] = "Batch"; | |||||
| constexpr char kBucketBatchByLengthNode[] = "BucketBatchByLength"; | |||||
| constexpr char kBuildSentencePieceVocabNode[] = "BuildSentencePieceVocab"; | |||||
| constexpr char kBuildVocabNode[] = "BuildVocab"; | |||||
| constexpr char kConcatNode[] = "Concat"; | |||||
| constexpr char kDatasetNode[] = "Dataset"; | |||||
| constexpr char kEpochCtrlNode[] = "EpochCtrl"; | |||||
| constexpr char kFilterNode[] = "Filter"; | |||||
| constexpr char kMapNode[] = "Map"; | |||||
| constexpr char kProjectNode[] = "Project"; | |||||
| constexpr char kRenameNode[] = "Rename"; | |||||
| constexpr char kRepeatNode[] = "Repeat"; | |||||
| constexpr char kRootNode[] = "Top"; | |||||
| constexpr char kShuffleNode[] = "Shuffle"; | |||||
| constexpr char kSkipNode[] = "Skip"; | |||||
| constexpr char kSyncWaitNode[] = "SyncWait"; | |||||
| constexpr char kTakeNode[] = "Take"; | |||||
| constexpr char kTransferNode[] = "Transfer"; | |||||
| constexpr char kZipNode[] = "Zip"; | |||||
| // Names for leaf IR node | |||||
| constexpr char kAlbumNode[] = "AlbumDataset"; | |||||
| constexpr char kCelebANode[] = "CelebADataset"; | |||||
| constexpr char kCifar100Node[] = "Cifar100Dataset"; | |||||
| constexpr char kCifar10Node[] = "Cifar10Dataset"; | |||||
| constexpr char kCLUENode[] = "CLUEDataset"; | |||||
| constexpr char kCocoNode[] = "CocoDataset"; | |||||
| constexpr char kCSVNode[] = "CSVDataset"; | |||||
| constexpr char kGeneratorNode[] = "GeneratorDataset"; | |||||
| constexpr char kImageFolderNode[] = "ImageFolderDataset"; | |||||
| constexpr char kManifestNode[] = "ManifestDataset"; | |||||
| constexpr char kMindDataNode[] = "MindDataDataset"; | |||||
| constexpr char kMnistNode[] = "MnistDataset"; | |||||
| constexpr char kRandomNode[] = "RandomDataset"; | |||||
| constexpr char kTextFileNode[] = "TextFileDataset"; | |||||
| constexpr char kTFRecordNode[] = "TFRecordDataset"; | |||||
| constexpr char kVOCNode[] = "VOCDataset"; | |||||
| Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, | Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, | ||||
| int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op); | int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op); | ||||
| @@ -75,6 +114,7 @@ Status ValidateDatasetDirParam(const std::string &dataset_name, std::string data | |||||
| /// \return Shared pointer to the current Sampler. | /// \return Shared pointer to the current Sampler. | ||||
| std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id); | std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id); | ||||
| // The base class of all IR nodes | |||||
| class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| @@ -87,6 +127,36 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~DatasetNode() = default; | ~DatasetNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| virtual std::string Name() const = 0; | |||||
| /// \brief Pure virtual function to print the description | |||||
| /// \param out - The output stream to write output to | |||||
| virtual void Print(std::ostream &out) const = 0; | |||||
| /// \brief Pure virtual function to make a new copy of the node | |||||
| /// \return The new copy of the node | |||||
| virtual std::shared_ptr<DatasetNode> Copy() = 0; | |||||
| /// \brief Print the IR tree to output stream | |||||
| /// \param out - The output stream to write output to | |||||
| void PrintTree(std::ostream &out) const; | |||||
| /// \brief << Stream output operator overload | |||||
| /// \notes This allows you to write the debug print info using stream operators | |||||
| /// \param out - reference to the output stream being overloaded | |||||
| /// \param dO - reference to the DatasetOp to display | |||||
| /// \return - the output stream must be returned | |||||
| friend std::ostream &operator<<(std::ostream &out, const DatasetNode &node) { | |||||
| node.PrintTree(out); | |||||
| return out; | |||||
| } | |||||
| /// \brief Make a new copy of the tree from the current node | |||||
| /// \return The new copy of the tree | |||||
| std::shared_ptr<DatasetNode> DeepCopy(); | |||||
| /// \brief Pure virtual function to convert a DatasetNode class into a runtime dataset object | /// \brief Pure virtual function to convert a DatasetNode class into a runtime dataset object | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| virtual std::vector<std::shared_ptr<DatasetOp>> Build() = 0; | virtual std::vector<std::shared_ptr<DatasetOp>> Build() = 0; | ||||
| @@ -95,17 +165,38 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| virtual Status ValidateParams() = 0; | virtual Status ValidateParams() = 0; | ||||
| const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children; } | |||||
| /// \brief Pure virtual function for derived class to get the shard id of specific node | /// \brief Pure virtual function for derived class to get the shard id of specific node | ||||
| /// \return Status Status::OK() if get shard id successfully | /// \return Status Status::OK() if get shard id successfully | ||||
| virtual Status GetShardId(int32_t *shard_id); | virtual Status GetShardId(int32_t *shard_id); | ||||
| /// \brief Getter function for child nodes | |||||
| /// \return Child nodes | |||||
| const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children_; } | |||||
| /// \brief Establish the parent-child relationship between this node and its child. | |||||
| void AddChild(std::shared_ptr<DatasetNode> child); | |||||
| /// \brief detach this node from its parent, add its child (if any) to its parent | |||||
| /// \return error code, return error if node has more than 1 children | |||||
| Status Remove(); | |||||
| /// \brief Check if this node has cache | |||||
| /// \return True if the data of this node will be cached | |||||
| const bool IsCached() const { return (cache_ != nullptr); } | |||||
| /// \brief Setter function for runtime number of workers | /// \brief Setter function for runtime number of workers | ||||
| /// \param[in] num_workers The number of threads in this operator | /// \param[in] num_workers The number of threads in this operator | ||||
| /// \return Shared pointer to the original object | /// \return Shared pointer to the original object | ||||
| std::shared_ptr<DatasetNode> SetNumWorkers(int32_t num_workers); | std::shared_ptr<DatasetNode> SetNumWorkers(int32_t num_workers); | ||||
| /// \brief A helper templated function for casting "this" pointer to shared_ptr<derived> | |||||
| /// Similar to shared_from_this, except this one will give you the derived class as shared_ptr | |||||
| /// \return A shared_ptr casted to the derived class | |||||
| template <typename Derived> | |||||
| std::shared_ptr<Derived> shared_from_base() { | |||||
| return std::static_pointer_cast<Derived>(shared_from_this()); | |||||
| } | |||||
| /// \brief Base method for NodePass visit. A tree walk consists of walking down the tree and also walking back up | /// \brief Base method for NodePass visit. A tree walk consists of walking down the tree and also walking back up | ||||
| /// in a depth-first order. Accept is the node visit on the way down, whereas AcceptAfter is the node | /// in a depth-first order. Accept is the node visit on the way down, whereas AcceptAfter is the node | ||||
| /// visit on the way back up the tree after its descendants are visited. | /// visit on the way back up the tree after its descendants are visited. | ||||
| @@ -129,17 +220,123 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||||
| Status BuildStatus() { return build_status; } | Status BuildStatus() { return build_status; } | ||||
| protected: | protected: | ||||
| std::vector<std::shared_ptr<DatasetNode>> children; | |||||
| std::vector<std::shared_ptr<DatasetNode>> children_; | |||||
| DatasetNode *parent_; | |||||
| std::shared_ptr<DatasetCache> cache_; | std::shared_ptr<DatasetCache> cache_; | ||||
| Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops); | |||||
| int32_t num_workers_; | int32_t num_workers_; | ||||
| int32_t rows_per_buffer_; | int32_t rows_per_buffer_; | ||||
| int32_t connector_que_size_; | int32_t connector_que_size_; | ||||
| int32_t worker_connector_size_; | int32_t worker_connector_size_; | ||||
| Status build_status; // remove me after changing return val of Build() | Status build_status; // remove me after changing return val of Build() | ||||
| std::string PrintColumns(const std::vector<std::string> &columns) const; | |||||
| Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops); | |||||
| void PrintNode(std::ostream &out, int *level) const; | |||||
| }; | }; | ||||
| // SourceNode represents the leaf nodes of a pipeline where the data is pulled into. | |||||
| class SourceNode : public DatasetNode { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| SourceNode() : DatasetNode() {} | |||||
| /// \brief Constructor that initializes the cache | |||||
| /// \param dataset_cache DatasetCache | |||||
| explicit SourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {} | |||||
| /// \brief Destructor | |||||
| ~SourceNode() = default; | |||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| virtual std::string Name() const = 0; | |||||
| /// \brief Base-class override for accepting NodePass visitor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| /// \brief Base-class override for accepting NodePass visitor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||||
| /// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes | |||||
| /// \return True if the dataset represented by this node is a mappable dataset | |||||
| const bool IsMappable() const { return mappable_; } | |||||
| protected: | |||||
| bool mappable_; | |||||
| }; | |||||
| // MappableSourceNode represents the leaf nodes that can be randomly accessed with indexes. | |||||
| class MappableSourceNode : public SourceNode { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| MappableSourceNode() : SourceNode() { mappable_ = true; } | |||||
| /// \brief Constructor that initializes the cache | |||||
| /// \param dataset_cache DatasetCache | |||||
| explicit MappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : SourceNode(dataset_cache) { | |||||
| mappable_ = true; | |||||
| } | |||||
| /// \brief Destructor | |||||
| ~MappableSourceNode() = default; | |||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| virtual std::string Name() const = 0; | |||||
| }; | |||||
| // NonMappableSourceNode represents the leaf nodes that can not be randomly accessed. | |||||
| class NonMappableSourceNode : public SourceNode { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| NonMappableSourceNode() : SourceNode() { mappable_ = false; } | |||||
| /// \brief Constructor that initializes the cache | |||||
| /// \param dataset_cache DatasetCache | |||||
| explicit NonMappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : SourceNode(dataset_cache) { | |||||
| mappable_ = false; | |||||
| } | |||||
| /// \brief Destructor | |||||
| ~NonMappableSourceNode() = default; | |||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| virtual std::string Name() const = 0; | |||||
| }; | |||||
| // NonLeafNode represents operations over data in a pipeline. | |||||
| class NonLeafNode : public DatasetNode { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| NonLeafNode() = default; | |||||
| /// \brief Destructor | |||||
| ~NonLeafNode() = default; | |||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| virtual std::string Name() const = 0; | |||||
| }; | |||||
| // SinkNode represents the end node of a pipeline where the data is pushed out | |||||
| class SinkNode : public DatasetNode { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| SinkNode() = default; | |||||
| /// \brief Destructor | |||||
| ~SinkNode() = default; | |||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| virtual std::string Name() const = 0; | |||||
| }; | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_ | #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_ | ||||
| @@ -0,0 +1,67 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h" | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/opt/pass.h" | |||||
| #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| // Constructor for EpochCtrlNode | |||||
| EpochCtrlNode::EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs) : num_epochs_(num_epochs) { | |||||
| // The root node's parent must set to null pointer. | |||||
| this->AddChild(child); | |||||
| } | |||||
| std::shared_ptr<DatasetNode> EpochCtrlNode::Copy() { | |||||
| auto node = std::make_shared<EpochCtrlNode>(nullptr, this->num_epochs_); | |||||
| return node; | |||||
| } | |||||
| void EpochCtrlNode::Print(std::ostream &out) const { out << Name() + "(epoch:" + std::to_string(num_epochs_) + ")"; } | |||||
| // Function to build the EpochCtrlOp | |||||
| std::vector<std::shared_ptr<DatasetOp>> EpochCtrlNode::Build() { | |||||
| // A dummy vector | |||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||||
| node_ops.push_back(std::make_shared<EpochCtrlOp>(num_epochs_)); | |||||
| return node_ops; | |||||
| } | |||||
| // Function to validate the parameters for EpochCtrlNode | |||||
| Status EpochCtrlNode::ValidateParams() { | |||||
| if (num_epochs_ <= 0 && num_epochs_ != -1) { | |||||
| std::string err_msg = | |||||
| "EpochCtrlNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(num_epochs_); | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| if (children_.size() != 1 || children_[0] == nullptr) { | |||||
| std::string err_msg = "Internal error: epoch control node should have one child node"; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,63 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_EPOCH_CTRL_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_EPOCH_CTRL_NODE_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| class EpochCtrlNode : public DatasetNode { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| explicit EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs); | |||||
| /// \brief Destructor | |||||
| ~EpochCtrlNode() = default; | |||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kEpochCtrlNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return shared pointer to the list of newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| int32_t num_epochs_; | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_EPOCH_CTRL_NODE_H_ | |||||
| @@ -21,7 +21,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/datasetops/filter_op.h" | #include "minddata/dataset/engine/datasetops/filter_op.h" | ||||
| #include "minddata/dataset/engine/opt/pass.h" | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -31,7 +31,16 @@ namespace dataset { | |||||
| FilterNode::FilterNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<TensorOp> predicate, | FilterNode::FilterNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<TensorOp> predicate, | ||||
| std::vector<std::string> input_columns) | std::vector<std::string> input_columns) | ||||
| : predicate_(predicate), input_columns_(input_columns) { | : predicate_(predicate), input_columns_(input_columns) { | ||||
| this->children.push_back(child); | |||||
| this->AddChild(child); | |||||
| } | |||||
| std::shared_ptr<DatasetNode> FilterNode::Copy() { | |||||
| auto node = std::make_shared<FilterNode>(nullptr, predicate_, input_columns_); | |||||
| return node; | |||||
| } | |||||
| void FilterNode::Print(std::ostream &out) const { | |||||
| out << Name() + "(<predicate>," + "input_cols:" + PrintColumns(input_columns_) + ")"; | |||||
| } | } | ||||
| std::vector<std::shared_ptr<DatasetOp>> FilterNode::Build() { | std::vector<std::shared_ptr<DatasetOp>> FilterNode::Build() { | ||||
| @@ -54,5 +63,17 @@ Status FilterNode::ValidateParams() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Visitor accepting method for NodePass | |||||
| Status FilterNode::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->Visit(shared_from_base<FilterNode>(), modified); | |||||
| } | |||||
| // Visitor accepting method for NodePass | |||||
| Status FilterNode::AcceptAfter(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->VisitAfter(shared_from_base<FilterNode>(), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -35,6 +35,18 @@ class FilterNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~FilterNode() = default; | ~FilterNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kFilterNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -43,6 +55,18 @@ class FilterNode : public DatasetNode { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Base-class override for accepting NodePass visitor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| /// \brief Base-class override for accepting NodePass visitor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| std::shared_ptr<TensorOp> predicate_; | std::shared_ptr<TensorOp> predicate_; | ||||
| std::vector<std::string> input_columns_; | std::vector<std::string> input_columns_; | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/datasetops/map_op/map_op.h" | #include "minddata/dataset/engine/datasetops/map_op/map_op.h" | ||||
| #include "minddata/dataset/engine/opt/pass.h" | |||||
| #include "minddata/dataset/include/transforms.h" | #include "minddata/dataset/include/transforms.h" | ||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -37,7 +38,18 @@ MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr | |||||
| project_columns_(project_columns), | project_columns_(project_columns), | ||||
| DatasetNode(std::move(cache)), | DatasetNode(std::move(cache)), | ||||
| callbacks_(callbacks) { | callbacks_(callbacks) { | ||||
| this->children.push_back(child); | |||||
| this->AddChild(child); | |||||
| } | |||||
| std::shared_ptr<DatasetNode> MapNode::Copy() { | |||||
| auto node = std::make_shared<MapNode>(nullptr, operations_, input_columns_, output_columns_, project_columns_, cache_, | |||||
| callbacks_); | |||||
| return node; | |||||
| } | |||||
| void MapNode::Print(std::ostream &out) const { | |||||
| out << Name() + "(<ops>" + ",input:" + PrintColumns(input_columns_) + ",output:" + PrintColumns(output_columns_) + | |||||
| ",<project_cols>" + ",...)"; | |||||
| } | } | ||||
| std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() { | std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() { | ||||
| @@ -93,5 +105,16 @@ Status MapNode::ValidateParams() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Visitor accepting method for NodePass | |||||
| Status MapNode::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->Visit(shared_from_base<MapNode>(), modified); | |||||
| } | |||||
| // Visitor accepting method for NodePass | |||||
| Status MapNode::AcceptAfter(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->VisitAfter(shared_from_base<MapNode>(), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -37,6 +37,18 @@ class MapNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~MapNode() = default; | ~MapNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kMapNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -45,6 +57,23 @@ class MapNode : public DatasetNode { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Getter of tensor operations | |||||
| /// \return Vector of operations the Map node will process | |||||
| const auto &TensorOperations() const { return operations_; } | |||||
| auto &TensorOperations() { return operations_; } | |||||
| /// \brief Base-class override for accepting NodePass visitor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| /// \brief Base-class override for accepting NodePass visitor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| std::vector<std::shared_ptr<TensorOperation>> operations_; | std::vector<std::shared_ptr<TensorOperation>> operations_; | ||||
| std::vector<std::string> input_columns_; | std::vector<std::string> input_columns_; | ||||
| @@ -29,9 +29,16 @@ namespace dataset { | |||||
| // Function to build ProjectOp | // Function to build ProjectOp | ||||
| ProjectNode::ProjectNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &columns) | ProjectNode::ProjectNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &columns) | ||||
| : columns_(columns) { | : columns_(columns) { | ||||
| this->children.push_back(child); | |||||
| this->AddChild(child); | |||||
| } | } | ||||
| std::shared_ptr<DatasetNode> ProjectNode::Copy() { | |||||
| auto node = std::make_shared<ProjectNode>(nullptr, this->columns_); | |||||
| return node; | |||||
| } | |||||
| void ProjectNode::Print(std::ostream &out) const { out << Name() + "(column: " + PrintColumns(columns_) + ")"; } | |||||
| Status ProjectNode::ValidateParams() { | Status ProjectNode::ValidateParams() { | ||||
| if (columns_.empty()) { | if (columns_.empty()) { | ||||
| std::string err_msg = "ProjectNode: No columns are specified."; | std::string err_msg = "ProjectNode: No columns are specified."; | ||||
| @@ -34,6 +34,18 @@ class ProjectNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~ProjectNode() = default; | ~ProjectNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kProjectNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -30,7 +30,16 @@ namespace dataset { | |||||
| RenameNode::RenameNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &input_columns, | RenameNode::RenameNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &input_columns, | ||||
| const std::vector<std::string> &output_columns) | const std::vector<std::string> &output_columns) | ||||
| : input_columns_(input_columns), output_columns_(output_columns) { | : input_columns_(input_columns), output_columns_(output_columns) { | ||||
| this->children.push_back(child); | |||||
| this->AddChild(child); | |||||
| } | |||||
| std::shared_ptr<DatasetNode> RenameNode::Copy() { | |||||
| auto node = std::make_shared<RenameNode>(nullptr, input_columns_, output_columns_); | |||||
| return node; | |||||
| } | |||||
| void RenameNode::Print(std::ostream &out) const { | |||||
| out << Name() + "(input:" + PrintColumns(input_columns_) + ",output:" + PrintColumns(output_columns_) + ")"; | |||||
| } | } | ||||
| Status RenameNode::ValidateParams() { | Status RenameNode::ValidateParams() { | ||||
| @@ -35,6 +35,18 @@ class RenameNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~RenameNode() = default; | ~RenameNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kRenameNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -21,15 +21,22 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/datasetops/repeat_op.h" | #include "minddata/dataset/engine/datasetops/repeat_op.h" | ||||
| #include "minddata/dataset/engine/opt/pass.h" | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| RepeatNode::RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count) : repeat_count_(count) { | RepeatNode::RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count) : repeat_count_(count) { | ||||
| this->children.push_back(child); | |||||
| this->AddChild(child); | |||||
| } | |||||
| std::shared_ptr<DatasetNode> RepeatNode::Copy() { | |||||
| auto node = std::make_shared<RepeatNode>(nullptr, this->repeat_count_); | |||||
| return node; | |||||
| } | } | ||||
| void RepeatNode::Print(std::ostream &out) const { out << Name() + "(count:" + std::to_string(repeat_count_) + ")"; } | |||||
| std::vector<std::shared_ptr<DatasetOp>> RepeatNode::Build() { | std::vector<std::shared_ptr<DatasetOp>> RepeatNode::Build() { | ||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | // A vector containing shared pointer to the Dataset Ops that this object will create | ||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | std::vector<std::shared_ptr<DatasetOp>> node_ops; | ||||
| @@ -49,5 +56,16 @@ Status RepeatNode::ValidateParams() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Visitor accepting method for NodePass | |||||
| Status RepeatNode::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->Visit(shared_from_base<RepeatNode>(), modified); | |||||
| } | |||||
| // Visitor accepting method for NodePass | |||||
| Status RepeatNode::AcceptAfter(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->VisitAfter(shared_from_base<RepeatNode>(), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -36,6 +36,18 @@ class RepeatNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~RepeatNode() = default; | ~RepeatNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kRepeatNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -44,6 +56,18 @@ class RepeatNode : public DatasetNode { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Base-class override for accepting NodePass visitor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| /// \brief Base-class override for accepting NodePass visitor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| int32_t repeat_count_; | int32_t repeat_count_; | ||||
| }; | }; | ||||
| @@ -0,0 +1,85 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/ir/datasetops/root_node.h" | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/opt/pass.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| // Constructor for RootNode | |||||
| RootNode::RootNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs) : DatasetNode(), num_epochs_(num_epochs) { | |||||
| // The root node's parent must remain nullptr. (which is set in the constructor of DatasetNode) | |||||
| AddChild(child); | |||||
| } | |||||
| std::shared_ptr<DatasetNode> RootNode::Copy() { | |||||
| auto node = std::make_shared<RootNode>(nullptr, num_epochs_); | |||||
| return node; | |||||
| } | |||||
| void RootNode::Print(std::ostream &out) const { out << Name(); } | |||||
| std::vector<std::shared_ptr<DatasetOp>> RootNode::Build() { | |||||
| // root node doesn't build a runtime Op. this function should return Status::Error when called. | |||||
| return {}; | |||||
| } | |||||
| // Function to validate the parameters for RootNode | |||||
| Status RootNode::ValidateParams() { | |||||
| if (num_epochs_ <= 0 && num_epochs_ != -1) { | |||||
| std::string err_msg = | |||||
| "RootNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(num_epochs_); | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| if (parent_ != nullptr) { | |||||
| std::string err_msg = "Internal error: root node should not have a parent"; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| if (children_.size() != 1) { | |||||
| std::string err_msg = "Internal error: root node should have one child node"; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| if (children_[0] == nullptr) { | |||||
| std::string err_msg = "Internal error: root node's child is a null pointer"; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Visitor accepting method for NodePass | |||||
| Status RootNode::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->Visit(shared_from_base<RootNode>(), modified); | |||||
| } | |||||
| // Visitor accepting method for NodePass | |||||
| Status RootNode::AcceptAfter(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->VisitAfter(shared_from_base<RootNode>(), modified); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,78 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ROOT_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ROOT_NODE_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| class RootNode : public DatasetNode { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| RootNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs); | |||||
| /// \brief Destructor | |||||
| ~RootNode() = default; | |||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kRootNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return shared pointer to the list of newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Getter of number of epochs | |||||
| int32_t num_epochs() { return num_epochs_; } | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| /// \brief Base-class override for accepting NodePass visitor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| /// \brief Base-class override for accepting NodePass visitor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||||
| private: | |||||
| int32_t num_epochs_; | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ROOT_NODE_H_ | |||||
| @@ -29,7 +29,17 @@ namespace dataset { | |||||
| // Constructor for ShuffleNode | // Constructor for ShuffleNode | ||||
| ShuffleNode::ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch) | ShuffleNode::ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch) | ||||
| : shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) { | : shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) { | ||||
| this->children.push_back(child); | |||||
| this->AddChild(child); | |||||
| } | |||||
| std::shared_ptr<DatasetNode> ShuffleNode::Copy() { | |||||
| auto node = std::make_shared<ShuffleNode>(nullptr, shuffle_size_, reset_every_epoch_); | |||||
| return node; | |||||
| } | |||||
| void ShuffleNode::Print(std::ostream &out) const { | |||||
| out << Name() + "(shuffle_size:" + std::to_string(shuffle_size_) + | |||||
| ",reset_every_epoch:" + (reset_every_epoch_ ? "true" : "false") + ")"; | |||||
| } | } | ||||
| // Function to build the ShuffleOp | // Function to build the ShuffleOp | ||||
| @@ -34,6 +34,18 @@ class ShuffleNode : public DatasetNode { | |||||
| ~ShuffleNode() = default; | ~ShuffleNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kShuffleNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| @@ -27,10 +27,15 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| // Constructor for SkipNode | // Constructor for SkipNode | ||||
| SkipNode::SkipNode(std::shared_ptr<DatasetNode> child, int32_t count) : skip_count_(count) { | |||||
| this->children.push_back(child); | |||||
| SkipNode::SkipNode(std::shared_ptr<DatasetNode> child, int32_t count) : skip_count_(count) { this->AddChild(child); } | |||||
| std::shared_ptr<DatasetNode> SkipNode::Copy() { | |||||
| auto node = std::make_shared<SkipNode>(nullptr, skip_count_); | |||||
| return node; | |||||
| } | } | ||||
| void SkipNode::Print(std::ostream &out) const { out << Name() + "(skip_count:" + std::to_string(skip_count_) + ")"; } | |||||
| // Function to build the SkipOp | // Function to build the SkipOp | ||||
| std::vector<std::shared_ptr<DatasetOp>> SkipNode::Build() { | std::vector<std::shared_ptr<DatasetOp>> SkipNode::Build() { | ||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | // A vector containing shared pointer to the Dataset Ops that this object will create | ||||
| @@ -34,6 +34,18 @@ class SkipNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~SkipNode() = default; | ~SkipNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kSkipNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -32,13 +32,23 @@ namespace dataset { | |||||
| AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_schema, | AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_schema, | ||||
| const std::vector<std::string> &column_names, bool decode, | const std::vector<std::string> &column_names, bool decode, | ||||
| const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache) | const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache) | ||||
| : DatasetNode(std::move(cache)), | |||||
| : MappableSourceNode(std::move(cache)), | |||||
| dataset_dir_(dataset_dir), | dataset_dir_(dataset_dir), | ||||
| schema_path_(data_schema), | schema_path_(data_schema), | ||||
| column_names_(column_names), | column_names_(column_names), | ||||
| decode_(decode), | decode_(decode), | ||||
| sampler_(sampler) {} | sampler_(sampler) {} | ||||
| std::shared_ptr<DatasetNode> AlbumNode::Copy() { | |||||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||||
| auto node = std::make_shared<AlbumNode>(dataset_dir_, schema_path_, column_names_, decode_, sampler, cache_); | |||||
| return node; | |||||
| } | |||||
| void AlbumNode::Print(std::ostream &out) const { | |||||
| out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ")"; | |||||
| } | |||||
| Status AlbumNode::ValidateParams() { | Status AlbumNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("AlbumNode", dataset_dir_)); | RETURN_IF_NOT_OK(ValidateDatasetDirParam("AlbumNode", dataset_dir_)); | ||||
| @@ -26,7 +26,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class AlbumNode : public DatasetNode { | |||||
| class AlbumNode : public MappableSourceNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| AlbumNode(const std::string &dataset_dir, const std::string &data_schema, | AlbumNode(const std::string &dataset_dir, const std::string &data_schema, | ||||
| @@ -36,6 +36,18 @@ class AlbumNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~AlbumNode() = default; | ~AlbumNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kAlbumNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create a runtime dataset op object from this class | /// \brief a base class override function to create a runtime dataset op object from this class | ||||
| /// \return shared pointer to the newly created DatasetOp | /// \return shared pointer to the newly created DatasetOp | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -31,13 +31,23 @@ namespace dataset { | |||||
| CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage, | CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage, | ||||
| const std::shared_ptr<SamplerObj> &sampler, const bool &decode, | const std::shared_ptr<SamplerObj> &sampler, const bool &decode, | ||||
| const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache) | const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache) | ||||
| : DatasetNode(std::move(cache)), | |||||
| : MappableSourceNode(std::move(cache)), | |||||
| dataset_dir_(dataset_dir), | dataset_dir_(dataset_dir), | ||||
| usage_(usage), | usage_(usage), | ||||
| sampler_(sampler), | sampler_(sampler), | ||||
| decode_(decode), | decode_(decode), | ||||
| extensions_(extensions) {} | extensions_(extensions) {} | ||||
| std::shared_ptr<DatasetNode> CelebANode::Copy() { | |||||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||||
| auto node = std::make_shared<CelebANode>(dataset_dir_, usage_, sampler, decode_, extensions_, cache_); | |||||
| return node; | |||||
| } | |||||
| void CelebANode::Print(std::ostream &out) const { | |||||
| out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ")"; | |||||
| } | |||||
| Status CelebANode::ValidateParams() { | Status CelebANode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebANode", dataset_dir_)); | RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebANode", dataset_dir_)); | ||||
| @@ -28,7 +28,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class CelebANode : public DatasetNode { | |||||
| class CelebANode : public MappableSourceNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler, | CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler, | ||||
| @@ -37,6 +37,18 @@ class CelebANode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~CelebANode() = default; | ~CelebANode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kCelebANode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return shared pointer to the list of newly created DatasetOps | /// \return shared pointer to the list of newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -30,7 +30,17 @@ namespace dataset { | |||||
| // Constructor for Cifar100Node | // Constructor for Cifar100Node | ||||
| Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &usage, | Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &usage, | ||||
| std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache) | std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache) | ||||
| : DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||||
| : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||||
| std::shared_ptr<DatasetNode> Cifar100Node::Copy() { | |||||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||||
| auto node = std::make_shared<Cifar100Node>(dataset_dir_, usage_, sampler, cache_); | |||||
| return node; | |||||
| } | |||||
| void Cifar100Node::Print(std::ostream &out) const { | |||||
| out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ")"; | |||||
| } | |||||
| Status Cifar100Node::ValidateParams() { | Status Cifar100Node::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_)); | RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_)); | ||||
| @@ -26,7 +26,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class Cifar100Node : public DatasetNode { | |||||
| class Cifar100Node : public MappableSourceNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, | Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, | ||||
| @@ -35,6 +35,18 @@ class Cifar100Node : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~Cifar100Node() = default; | ~Cifar100Node() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kCifar100Node; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -30,7 +30,17 @@ namespace dataset { | |||||
| // Constructor for Cifar10Node | // Constructor for Cifar10Node | ||||
| Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, | Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, | ||||
| std::shared_ptr<DatasetCache> cache) | std::shared_ptr<DatasetCache> cache) | ||||
| : DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||||
| : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||||
| std::shared_ptr<DatasetNode> Cifar10Node::Copy() { | |||||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||||
| auto node = std::make_shared<Cifar10Node>(dataset_dir_, usage_, sampler, cache_); | |||||
| return node; | |||||
| } | |||||
| void Cifar10Node::Print(std::ostream &out) const { | |||||
| out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ")"; | |||||
| } | |||||
| Status Cifar10Node::ValidateParams() { | Status Cifar10Node::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_)); | RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_)); | ||||
| @@ -26,7 +26,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class Cifar10Node : public DatasetNode { | |||||
| class Cifar10Node : public MappableSourceNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, | Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, | ||||
| @@ -35,6 +35,18 @@ class Cifar10Node : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~Cifar10Node() = default; | ~Cifar10Node() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kCifar10Node; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -32,7 +32,7 @@ namespace dataset { | |||||
| // Constructor for CLUENode | // Constructor for CLUENode | ||||
| CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task, std::string usage, int64_t num_samples, | CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task, std::string usage, int64_t num_samples, | ||||
| ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache) | ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache) | ||||
| : DatasetNode(std::move(cache)), | |||||
| : NonMappableSourceNode(std::move(cache)), | |||||
| dataset_files_(clue_files), | dataset_files_(clue_files), | ||||
| task_(task), | task_(task), | ||||
| usage_(usage), | usage_(usage), | ||||
| @@ -41,6 +41,17 @@ CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task, | |||||
| num_shards_(num_shards), | num_shards_(num_shards), | ||||
| shard_id_(shard_id) {} | shard_id_(shard_id) {} | ||||
| std::shared_ptr<DatasetNode> CLUENode::Copy() { | |||||
| auto node = | |||||
| std::make_shared<CLUENode>(dataset_files_, task_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_); | |||||
| return node; | |||||
| } | |||||
| void CLUENode::Print(std::ostream &out) const { | |||||
| out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ",..." + | |||||
| ",num_shards:" + std::to_string(num_shards_) + ",shard_id:" + std::to_string(shard_id_) + ")"; | |||||
| } | |||||
| Status CLUENode::ValidateParams() { | Status CLUENode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUENode", dataset_files_)); | RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUENode", dataset_files_)); | ||||
| @@ -28,7 +28,7 @@ namespace dataset { | |||||
| /// \class CLUENode | /// \class CLUENode | ||||
| /// \brief A Dataset derived class to represent CLUE dataset | /// \brief A Dataset derived class to represent CLUE dataset | ||||
| class CLUENode : public DatasetNode { | |||||
| class CLUENode : public NonMappableSourceNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| CLUENode(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples, | CLUENode(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples, | ||||
| @@ -37,6 +37,18 @@ class CLUENode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~CLUENode() = default; | ~CLUENode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kCLUENode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -30,13 +30,21 @@ namespace dataset { | |||||
| // Constructor for CocoNode | // Constructor for CocoNode | ||||
| CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, | CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, | ||||
| const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache) | const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache) | ||||
| : DatasetNode(std::move(cache)), | |||||
| : MappableSourceNode(std::move(cache)), | |||||
| dataset_dir_(dataset_dir), | dataset_dir_(dataset_dir), | ||||
| annotation_file_(annotation_file), | annotation_file_(annotation_file), | ||||
| task_(task), | task_(task), | ||||
| decode_(decode), | decode_(decode), | ||||
| sampler_(sampler) {} | sampler_(sampler) {} | ||||
| std::shared_ptr<DatasetNode> CocoNode::Copy() { | |||||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||||
| auto node = std::make_shared<CocoNode>(dataset_dir_, annotation_file_, task_, decode_, sampler, cache_); | |||||
| return node; | |||||
| } | |||||
| void CocoNode::Print(std::ostream &out) const { out << Name(); } | |||||
| Status CocoNode::ValidateParams() { | Status CocoNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoNode", dataset_dir_)); | RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoNode", dataset_dir_)); | ||||
| @@ -26,7 +26,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class CocoNode : public DatasetNode { | |||||
| class CocoNode : public MappableSourceNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, | CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, | ||||
| @@ -35,6 +35,18 @@ class CocoNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~CocoNode() = default; | ~CocoNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kCocoNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return shared pointer to the list of newly created DatasetOps | /// \return shared pointer to the list of newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -33,7 +33,7 @@ CSVNode::CSVNode(const std::vector<std::string> &csv_files, char field_delim, | |||||
| const std::vector<std::shared_ptr<CsvBase>> &column_defaults, | const std::vector<std::shared_ptr<CsvBase>> &column_defaults, | ||||
| const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle, | const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle, | ||||
| int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache) | int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache) | ||||
| : DatasetNode(std::move(cache)), | |||||
| : NonMappableSourceNode(std::move(cache)), | |||||
| dataset_files_(csv_files), | dataset_files_(csv_files), | ||||
| field_delim_(field_delim), | field_delim_(field_delim), | ||||
| column_defaults_(column_defaults), | column_defaults_(column_defaults), | ||||
| @@ -43,6 +43,17 @@ CSVNode::CSVNode(const std::vector<std::string> &csv_files, char field_delim, | |||||
| num_shards_(num_shards), | num_shards_(num_shards), | ||||
| shard_id_(shard_id) {} | shard_id_(shard_id) {} | ||||
| std::shared_ptr<DatasetNode> CSVNode::Copy() { | |||||
| auto node = std::make_shared<CSVNode>(dataset_files_, field_delim_, column_defaults_, column_names_, num_samples_, | |||||
| shuffle_, num_shards_, shard_id_, cache_); | |||||
| return node; | |||||
| } | |||||
| void CSVNode::Print(std::ostream &out) const { | |||||
| out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ",..." + | |||||
| ",num_shards:" + std::to_string(num_shards_) + ",shard_id:" + std::to_string(shard_id_) + ")"; | |||||
| } | |||||
| Status CSVNode::ValidateParams() { | Status CSVNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CSVNode", dataset_files_)); | RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CSVNode", dataset_files_)); | ||||
| @@ -47,7 +47,7 @@ class CsvRecord : public CsvBase { | |||||
| T value; | T value; | ||||
| }; | }; | ||||
| class CSVNode : public DatasetNode { | |||||
| class CSVNode : public NonMappableSourceNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| CSVNode(const std::vector<std::string> &dataset_files, char field_delim, | CSVNode(const std::vector<std::string> &dataset_files, char field_delim, | ||||
| @@ -58,6 +58,18 @@ class CSVNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~CSVNode() = default; | ~CSVNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kCSVNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return shared pointer to the list of newly created DatasetOps | /// \return shared pointer to the list of newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -28,7 +28,19 @@ namespace dataset { | |||||
| GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names, | GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names, | ||||
| const std::vector<DataType> &column_types) | const std::vector<DataType> &column_types) | ||||
| : generator_function_(generator_function), column_names_(column_names), column_types_(column_types) {} | |||||
| : MappableSourceNode(), | |||||
| generator_function_(generator_function), | |||||
| column_names_(column_names), | |||||
| column_types_(column_types) {} | |||||
| std::shared_ptr<DatasetNode> GeneratorNode::Copy() { | |||||
| auto node = std::make_shared<GeneratorNode>(generator_function_, column_names_, column_types_); | |||||
| return node; | |||||
| } | |||||
| void GeneratorNode::Print(std::ostream &out) const { | |||||
| out << Name() + "(<func>:" + ",columns:" + PrintColumns(column_names_) + ",<col_types>)"; | |||||
| } | |||||
| GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema) | GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema) | ||||
| : generator_function_(generator_function), schema_(schema) {} | : generator_function_(generator_function), schema_(schema) {} | ||||
| @@ -26,10 +26,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| /// \class GeneratorNode | /// \class GeneratorNode | ||||
| /// \brief A Dataset derived class to represent GeneratorNode dataset | /// \brief A Dataset derived class to represent GeneratorNode dataset | ||||
| class GeneratorNode : public DatasetNode { | |||||
| class GeneratorNode : public MappableSourceNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names, | GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names, | ||||
| @@ -41,6 +40,18 @@ class GeneratorNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~GeneratorNode() = default; | ~GeneratorNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kGeneratorNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -33,13 +33,24 @@ ImageFolderNode::ImageFolderNode(std::string dataset_dir, bool decode, std::shar | |||||
| bool recursive, std::set<std::string> extensions, | bool recursive, std::set<std::string> extensions, | ||||
| std::map<std::string, int32_t> class_indexing, | std::map<std::string, int32_t> class_indexing, | ||||
| std::shared_ptr<DatasetCache> cache = nullptr) | std::shared_ptr<DatasetCache> cache = nullptr) | ||||
| : dataset_dir_(dataset_dir), | |||||
| : MappableSourceNode(std::move(cache)), | |||||
| dataset_dir_(dataset_dir), | |||||
| decode_(decode), | decode_(decode), | ||||
| sampler_(sampler), | sampler_(sampler), | ||||
| recursive_(recursive), | recursive_(recursive), | ||||
| class_indexing_(class_indexing), | class_indexing_(class_indexing), | ||||
| exts_(extensions), | |||||
| DatasetNode(std::move(cache)) {} | |||||
| exts_(extensions) {} | |||||
| std::shared_ptr<DatasetNode> ImageFolderNode::Copy() { | |||||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||||
| auto node = | |||||
| std::make_shared<ImageFolderNode>(dataset_dir_, decode_, sampler, recursive_, exts_, class_indexing_, cache_); | |||||
| return node; | |||||
| } | |||||
| void ImageFolderNode::Print(std::ostream &out) const { | |||||
| out << Name() + "(path:" + dataset_dir_ + ",decode:" + (decode_ ? "true" : "false") + ",...)"; | |||||
| } | |||||
| Status ImageFolderNode::ValidateParams() { | Status ImageFolderNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderNode", dataset_dir_)); | RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderNode", dataset_dir_)); | ||||
| @@ -31,7 +31,7 @@ namespace dataset { | |||||
| /// \class ImageFolderNode | /// \class ImageFolderNode | ||||
| /// \brief A Dataset derived class to represent ImageFolder dataset | /// \brief A Dataset derived class to represent ImageFolder dataset | ||||
| class ImageFolderNode : public DatasetNode { | |||||
| class ImageFolderNode : public MappableSourceNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler, bool recursive, | ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler, bool recursive, | ||||
| @@ -41,6 +41,18 @@ class ImageFolderNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~ImageFolderNode() = default; | ~ImageFolderNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kImageFolderNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -32,13 +32,30 @@ ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &u | |||||
| const std::shared_ptr<SamplerObj> &sampler, | const std::shared_ptr<SamplerObj> &sampler, | ||||
| const std::map<std::string, int32_t> &class_indexing, bool decode, | const std::map<std::string, int32_t> &class_indexing, bool decode, | ||||
| std::shared_ptr<DatasetCache> cache) | std::shared_ptr<DatasetCache> cache) | ||||
| : DatasetNode(std::move(cache)), | |||||
| : MappableSourceNode(std::move(cache)), | |||||
| dataset_file_(dataset_file), | dataset_file_(dataset_file), | ||||
| usage_(usage), | usage_(usage), | ||||
| decode_(decode), | decode_(decode), | ||||
| class_index_(class_indexing), | class_index_(class_indexing), | ||||
| sampler_(sampler) {} | sampler_(sampler) {} | ||||
| std::shared_ptr<DatasetNode> ManifestNode::Copy() { | |||||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||||
| auto node = std::make_shared<ManifestNode>(dataset_file_, usage_, sampler, class_index_, decode_, cache_); | |||||
| return node; | |||||
| } | |||||
| void ManifestNode::Print(std::ostream &out) const { | |||||
| out << Name() + "(file:" + dataset_file_; | |||||
| if (sampler_ != nullptr) { | |||||
| out << ",sampler"; | |||||
| } | |||||
| if (cache_ != nullptr) { | |||||
| out << ",cache"; | |||||
| } | |||||
| out << ")"; | |||||
| } | |||||
| Status ManifestNode::ValidateParams() { | Status ManifestNode::ValidateParams() { | ||||
| std::vector<char> forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'}; | std::vector<char> forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'}; | ||||
| for (char c : dataset_file_) { | for (char c : dataset_file_) { | ||||
| @@ -27,7 +27,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class ManifestNode : public DatasetNode { | |||||
| class ManifestNode : public MappableSourceNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler, | ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler, | ||||
| @@ -36,6 +36,18 @@ class ManifestNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~ManifestNode() = default; | ~ManifestNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kManifestNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -30,7 +30,8 @@ namespace dataset { | |||||
| MindDataNode::MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list, | MindDataNode::MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list, | ||||
| const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded) | const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded) | ||||
| : dataset_file_(std::string()), | |||||
| : MappableSourceNode(), | |||||
| dataset_file_(std::string()), | |||||
| dataset_files_(dataset_files), | dataset_files_(dataset_files), | ||||
| search_for_pattern_(false), | search_for_pattern_(false), | ||||
| columns_list_(columns_list), | columns_list_(columns_list), | ||||
| @@ -41,7 +42,8 @@ MindDataNode::MindDataNode(const std::vector<std::string> &dataset_files, const | |||||
| MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<std::string> &columns_list, | MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<std::string> &columns_list, | ||||
| const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded) | const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded) | ||||
| : dataset_file_(dataset_file), | |||||
| : MappableSourceNode(), | |||||
| dataset_file_(dataset_file), | |||||
| dataset_files_({}), | dataset_files_({}), | ||||
| search_for_pattern_(true), | search_for_pattern_(true), | ||||
| columns_list_(columns_list), | columns_list_(columns_list), | ||||
| @@ -50,6 +52,19 @@ MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<st | |||||
| sample_bytes_({}), | sample_bytes_({}), | ||||
| num_padded_(num_padded) {} | num_padded_(num_padded) {} | ||||
| std::shared_ptr<DatasetNode> MindDataNode::Copy() { | |||||
| std::shared_ptr<MindDataNode> node; | |||||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||||
| if (dataset_files_.empty()) { | |||||
| node = std::make_shared<MindDataNode>(dataset_file_, columns_list_, sampler, padded_sample_, num_padded_); | |||||
| } else { | |||||
| node = std::make_shared<MindDataNode>(dataset_files_, columns_list_, sampler, padded_sample_, num_padded_); | |||||
| } | |||||
| return node; | |||||
| } | |||||
| void MindDataNode::Print(std::ostream &out) const { out << Name() + "(file:" + dataset_file_ + ",...)"; } | |||||
| Status MindDataNode::ValidateParams() { | Status MindDataNode::ValidateParams() { | ||||
| if (!search_for_pattern_ && dataset_files_.size() > 4096) { | if (!search_for_pattern_ && dataset_files_.size() > 4096) { | ||||
| std::string err_msg = | std::string err_msg = | ||||
| @@ -27,7 +27,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class MindDataNode : public DatasetNode { | |||||
| class MindDataNode : public MappableSourceNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list, | MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list, | ||||
| @@ -40,6 +40,18 @@ class MindDataNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~MindDataNode() = default; | ~MindDataNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kMindDataNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -29,7 +29,15 @@ namespace dataset { | |||||
| MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler, | MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler, | ||||
| std::shared_ptr<DatasetCache> cache) | std::shared_ptr<DatasetCache> cache) | ||||
| : DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||||
| : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||||
| std::shared_ptr<DatasetNode> MnistNode::Copy() { | |||||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||||
| auto node = std::make_shared<MnistNode>(dataset_dir_, usage_, sampler, cache_); | |||||
| return node; | |||||
| } | |||||
| void MnistNode::Print(std::ostream &out) const { out << Name(); } | |||||
| Status MnistNode::ValidateParams() { | Status MnistNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_)); | RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_)); | ||||
| @@ -26,7 +26,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class MnistNode : public DatasetNode { | |||||
| class MnistNode : public MappableSourceNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler, | MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler, | ||||
| @@ -35,6 +35,18 @@ class MnistNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~MnistNode() = default; | ~MnistNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kMnistNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -27,6 +27,18 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| std::shared_ptr<DatasetNode> RandomNode::Copy() { | |||||
| std::shared_ptr<RandomNode> node; | |||||
| if (schema_ != nullptr) { | |||||
| node = std::make_shared<RandomNode>(total_rows_, schema_, columns_list_, cache_); | |||||
| } else { | |||||
| node = std::make_shared<RandomNode>(total_rows_, schema_path_, columns_list_, cache_); | |||||
| } | |||||
| return node; | |||||
| } | |||||
| void RandomNode::Print(std::ostream &out) const { out << Name() + "(num_row:" + std::to_string(total_rows_) + ",...)"; } | |||||
| // ValidateParams for RandomNode | // ValidateParams for RandomNode | ||||
| Status RandomNode::ValidateParams() { | Status RandomNode::ValidateParams() { | ||||
| if (total_rows_ < 0) { | if (total_rows_ < 0) { | ||||
| @@ -27,7 +27,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class RandomNode : public DatasetNode { | |||||
| class RandomNode : public NonMappableSourceNode { | |||||
| public: | public: | ||||
| // Some constants to provide limits to random generation. | // Some constants to provide limits to random generation. | ||||
| static constexpr int32_t kMaxNumColumns = 4; | static constexpr int32_t kMaxNumColumns = 4; | ||||
| @@ -37,7 +37,7 @@ class RandomNode : public DatasetNode { | |||||
| /// \brief Constructor | /// \brief Constructor | ||||
| RandomNode(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, const std::vector<std::string> &columns_list, | RandomNode(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, const std::vector<std::string> &columns_list, | ||||
| std::shared_ptr<DatasetCache> cache) | std::shared_ptr<DatasetCache> cache) | ||||
| : DatasetNode(std::move(cache)), | |||||
| : NonMappableSourceNode(std::move(cache)), | |||||
| total_rows_(total_rows), | total_rows_(total_rows), | ||||
| schema_path_(""), | schema_path_(""), | ||||
| schema_(std::move(schema)), | schema_(std::move(schema)), | ||||
| @@ -46,14 +46,27 @@ class RandomNode : public DatasetNode { | |||||
| /// \brief Constructor | /// \brief Constructor | ||||
| RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list, | RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list, | ||||
| std::shared_ptr<DatasetCache> cache) | std::shared_ptr<DatasetCache> cache) | ||||
| : DatasetNode(std::move(cache)), | |||||
| : NonMappableSourceNode(std::move(cache)), | |||||
| total_rows_(total_rows), | total_rows_(total_rows), | ||||
| schema_path_(schema_path), | schema_path_(schema_path), | ||||
| schema_(nullptr), | |||||
| columns_list_(columns_list) {} | columns_list_(columns_list) {} | ||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~RandomNode() = default; | ~RandomNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kRandomNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -31,13 +31,23 @@ namespace dataset { | |||||
| // Constructor for TextFileNode | // Constructor for TextFileNode | ||||
| TextFileNode::TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, | TextFileNode::TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, | ||||
| int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache) | int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache) | ||||
| : DatasetNode(std::move(cache)), | |||||
| : NonMappableSourceNode(std::move(cache)), | |||||
| dataset_files_(dataset_files), | dataset_files_(dataset_files), | ||||
| num_samples_(num_samples), | num_samples_(num_samples), | ||||
| shuffle_(shuffle), | shuffle_(shuffle), | ||||
| num_shards_(num_shards), | num_shards_(num_shards), | ||||
| shard_id_(shard_id) {} | shard_id_(shard_id) {} | ||||
| std::shared_ptr<DatasetNode> TextFileNode::Copy() { | |||||
| auto node = std::make_shared<TextFileNode>(dataset_files_, num_samples_, shuffle_, num_shards_, shard_id_, cache_); | |||||
| return node; | |||||
| } | |||||
| void TextFileNode::Print(std::ostream &out) const { | |||||
| out << Name() + "(file:..." + ",num_shards:" + std::to_string(num_shards_) + | |||||
| ",shard_id:" + std::to_string(shard_id_) + ",cache:" + ((cache_ != nullptr) ? "true" : "false") + ",...)"; | |||||
| } | |||||
| Status TextFileNode::ValidateParams() { | Status TextFileNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(ValidateDatasetFilesParam("TextFileNode", dataset_files_)); | RETURN_IF_NOT_OK(ValidateDatasetFilesParam("TextFileNode", dataset_files_)); | ||||
| @@ -28,7 +28,7 @@ namespace dataset { | |||||
| /// \class TextFileNode | /// \class TextFileNode | ||||
| /// \brief A Dataset derived class to represent TextFile dataset | /// \brief A Dataset derived class to represent TextFile dataset | ||||
| class TextFileNode : public DatasetNode { | |||||
| class TextFileNode : public NonMappableSourceNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, | TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, | ||||
| @@ -37,6 +37,18 @@ class TextFileNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~TextFileNode() = default; | ~TextFileNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kTextFileNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -30,6 +30,23 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| std::shared_ptr<DatasetNode> TFRecordNode::Copy() { | |||||
| std::shared_ptr<TFRecordNode> node; | |||||
| if (schema_obj_ != nullptr) { | |||||
| node = std::make_shared<TFRecordNode>(dataset_files_, schema_obj_, columns_list_, num_samples_, shuffle_, | |||||
| num_shards_, shard_id_, shard_equal_rows_, cache_); | |||||
| } else { | |||||
| node = std::make_shared<TFRecordNode>(dataset_files_, schema_path_, columns_list_, num_samples_, shuffle_, | |||||
| num_shards_, shard_id_, shard_equal_rows_, cache_); | |||||
| } | |||||
| return node; | |||||
| } | |||||
| void TFRecordNode::Print(std::ostream &out) const { | |||||
| out << Name() + "(num_samples:" + std::to_string(num_samples_) + ",num_shards:" + std::to_string(num_shards_) + | |||||
| ",shard_id:" + std::to_string(shard_id_) + ",...)"; | |||||
| } | |||||
| // Validator for TFRecordNode | // Validator for TFRecordNode | ||||
| Status TFRecordNode::ValidateParams() { | Status TFRecordNode::ValidateParams() { | ||||
| if (dataset_files_.empty()) { | if (dataset_files_.empty()) { | ||||
| @@ -29,14 +29,14 @@ namespace dataset { | |||||
| /// \class TFRecordNode | /// \class TFRecordNode | ||||
| /// \brief A Dataset derived class to represent TFRecord dataset | /// \brief A Dataset derived class to represent TFRecord dataset | ||||
| class TFRecordNode : public DatasetNode { | |||||
| class TFRecordNode : public NonMappableSourceNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| /// \note Parameter 'schema' is the path to the schema file | /// \note Parameter 'schema' is the path to the schema file | ||||
| TFRecordNode(const std::vector<std::string> &dataset_files, std::string schema, | TFRecordNode(const std::vector<std::string> &dataset_files, std::string schema, | ||||
| const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, | const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, | ||||
| int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache) | int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache) | ||||
| : DatasetNode(std::move(cache)), | |||||
| : NonMappableSourceNode(std::move(cache)), | |||||
| dataset_files_(dataset_files), | dataset_files_(dataset_files), | ||||
| schema_path_(schema), | schema_path_(schema), | ||||
| columns_list_(columns_list), | columns_list_(columns_list), | ||||
| @@ -51,7 +51,7 @@ class TFRecordNode : public DatasetNode { | |||||
| TFRecordNode(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema, | TFRecordNode(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema, | ||||
| const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, | const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, | ||||
| int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache) | int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache) | ||||
| : DatasetNode(std::move(cache)), | |||||
| : NonMappableSourceNode(std::move(cache)), | |||||
| dataset_files_(dataset_files), | dataset_files_(dataset_files), | ||||
| schema_obj_(schema), | schema_obj_(schema), | ||||
| columns_list_(columns_list), | columns_list_(columns_list), | ||||
| @@ -64,6 +64,18 @@ class TFRecordNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~TFRecordNode() = default; | ~TFRecordNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kTFRecordNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -32,7 +32,7 @@ namespace dataset { | |||||
| VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, | VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, | ||||
| const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler, | const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler, | ||||
| std::shared_ptr<DatasetCache> cache) | std::shared_ptr<DatasetCache> cache) | ||||
| : DatasetNode(std::move(cache)), | |||||
| : MappableSourceNode(std::move(cache)), | |||||
| dataset_dir_(dataset_dir), | dataset_dir_(dataset_dir), | ||||
| task_(task), | task_(task), | ||||
| usage_(usage), | usage_(usage), | ||||
| @@ -40,6 +40,14 @@ VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const | |||||
| decode_(decode), | decode_(decode), | ||||
| sampler_(sampler) {} | sampler_(sampler) {} | ||||
| std::shared_ptr<DatasetNode> VOCNode::Copy() { | |||||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||||
| auto node = std::make_shared<VOCNode>(dataset_dir_, task_, usage_, class_index_, decode_, sampler, cache_); | |||||
| return node; | |||||
| } | |||||
| void VOCNode::Print(std::ostream &out) const { out << Name(); } | |||||
| Status VOCNode::ValidateParams() { | Status VOCNode::ValidateParams() { | ||||
| Path dir(dataset_dir_); | Path dir(dataset_dir_); | ||||
| @@ -27,7 +27,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class VOCNode : public DatasetNode { | |||||
| class VOCNode : public MappableSourceNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, | VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, | ||||
| @@ -37,6 +37,18 @@ class VOCNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~VOCNode() = default; | ~VOCNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kVOCNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return shared pointer to the list of newly created DatasetOps | /// \return shared pointer to the list of newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -29,7 +29,16 @@ namespace dataset { | |||||
| // Constructor for SyncWaitNode | // Constructor for SyncWaitNode | ||||
| SyncWaitNode::SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, py::function callback) | SyncWaitNode::SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, py::function callback) | ||||
| : condition_name_(condition_name), callback_(callback) { | : condition_name_(condition_name), callback_(callback) { | ||||
| this->children.push_back(child); | |||||
| this->AddChild(child); | |||||
| } | |||||
| std::shared_ptr<DatasetNode> SyncWaitNode::Copy() { | |||||
| auto node = std::make_shared<SyncWaitNode>(nullptr, condition_name_, callback_); | |||||
| return node; | |||||
| } | |||||
| void SyncWaitNode::Print(std::ostream &out) const { | |||||
| out << Name() + "(cond_name:" + condition_name_ + "<pyfunc>" + ")"; | |||||
| } | } | ||||
| // Function to build the BarrierOp | // Function to build the BarrierOp | ||||
| @@ -36,6 +36,18 @@ class SyncWaitNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~SyncWaitNode() = default; | ~SyncWaitNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kSyncWaitNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -27,10 +27,15 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| // Constructor for TakeNode | // Constructor for TakeNode | ||||
| TakeNode::TakeNode(std::shared_ptr<DatasetNode> child, int32_t count) : take_count_(count) { | |||||
| this->children.push_back(child); | |||||
| TakeNode::TakeNode(std::shared_ptr<DatasetNode> child, int32_t count) : take_count_(count) { this->AddChild(child); } | |||||
| std::shared_ptr<DatasetNode> TakeNode::Copy() { | |||||
| auto node = std::make_shared<TakeNode>(nullptr, take_count_); | |||||
| return node; | |||||
| } | } | ||||
| void TakeNode::Print(std::ostream &out) const { out << Name() + "(num_rows:" + std::to_string(take_count_) + ")"; } | |||||
| // Function to build the TakeOp | // Function to build the TakeOp | ||||
| std::vector<std::shared_ptr<DatasetOp>> TakeNode::Build() { | std::vector<std::shared_ptr<DatasetOp>> TakeNode::Build() { | ||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | // A vector containing shared pointer to the Dataset Ops that this object will create | ||||
| @@ -34,6 +34,18 @@ class TakeNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~TakeNode() = default; | ~TakeNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kTakeNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return shared pointer to the list of newly created DatasetOps | /// \return shared pointer to the list of newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/datasetops/device_queue_op.h" | #include "minddata/dataset/engine/datasetops/device_queue_op.h" | ||||
| #include "minddata/dataset/engine/opt/pass.h" | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| @@ -39,7 +40,19 @@ TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, std::string queue | |||||
| total_batch_(total_batch), | total_batch_(total_batch), | ||||
| create_data_info_queue_(create_data_info_queue), | create_data_info_queue_(create_data_info_queue), | ||||
| device_id_(0) { | device_id_(0) { | ||||
| this->children.push_back(child); | |||||
| this->AddChild(child); | |||||
| } | |||||
| std::shared_ptr<DatasetNode> TransferNode::Copy() { | |||||
| auto node = std::make_shared<TransferNode>(nullptr, queue_name_, device_type_, send_epoch_end_, total_batch_, | |||||
| create_data_info_queue_); | |||||
| return node; | |||||
| } | |||||
| void TransferNode::Print(std::ostream &out) const { | |||||
| out << Name() + "(prefetch_size:" + std::to_string(prefetch_size_) + | |||||
| ",send_epoch_end:" + (send_epoch_end_ ? "true" : "false") + ",total_batch:" + std::to_string(total_batch_) + | |||||
| ")"; | |||||
| } | } | ||||
| // Validator for TransferNode | // Validator for TransferNode | ||||
| @@ -94,5 +107,16 @@ std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() { | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| // Visitor accepting method for NodePass | |||||
| Status TransferNode::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->Visit(shared_from_base<TransferNode>(), modified); | |||||
| } | |||||
| // Visitor accepting method for NodePass | |||||
| Status TransferNode::AcceptAfter(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->VisitAfter(shared_from_base<TransferNode>(), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -35,6 +35,18 @@ class TransferNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~TransferNode() = default; | ~TransferNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kTransferNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return shared pointer to the list of newly created DatasetOps | /// \return shared pointer to the list of newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -43,6 +55,20 @@ class TransferNode : public DatasetNode { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| static Status get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id); | |||||
| /// \brief Base-class override for accepting NodePass visitor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| /// \brief Base-class override for accepting NodePass visitor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| std::string queue_name_; | std::string queue_name_; | ||||
| int32_t device_id_; | int32_t device_id_; | ||||
| @@ -21,30 +21,36 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/datasetops/zip_op.h" | #include "minddata/dataset/engine/datasetops/zip_op.h" | ||||
| #include "minddata/dataset/engine/opt/pass.h" | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| ZipNode::ZipNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) : datasets_(datasets) { | |||||
| for (auto dataset : datasets_) { | |||||
| this->children.push_back(dataset); | |||||
| } | |||||
| ZipNode::ZipNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) { | |||||
| for (auto const &child : datasets) AddChild(child); | |||||
| } | } | ||||
| std::shared_ptr<DatasetNode> ZipNode::Copy() { | |||||
| std::vector<std::shared_ptr<DatasetNode>> empty_vector; | |||||
| empty_vector.clear(); | |||||
| auto node = std::make_shared<ZipNode>(empty_vector); | |||||
| return node; | |||||
| } | |||||
| void ZipNode::Print(std::ostream &out) const { out << Name(); } | |||||
| Status ZipNode::ValidateParams() { | Status ZipNode::ValidateParams() { | ||||
| if (datasets_.empty()) { | |||||
| std::string err_msg = "ZipNode: datasets to zip are not specified."; | |||||
| if (children_.size() < 2) { | |||||
| std::string err_msg = "ZipNode: input datasets are not specified."; | |||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | RETURN_STATUS_SYNTAX_ERROR(err_msg); | ||||
| } | } | ||||
| if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) { | |||||
| std::string err_msg = "ZipNode: zip datasets should not be null."; | |||||
| if (find(children_.begin(), children_.end(), nullptr) != children_.end()) { | |||||
| std::string err_msg = "ZipNode: input datasets should not be null."; | |||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | RETURN_STATUS_SYNTAX_ERROR(err_msg); | ||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -56,5 +62,17 @@ std::vector<std::shared_ptr<DatasetOp>> ZipNode::Build() { | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| // Visitor accepting method for NodePass | |||||
| Status ZipNode::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->Visit(shared_from_base<ZipNode>(), modified); | |||||
| } | |||||
| // Visitor accepting method for NodePass | |||||
| Status ZipNode::AcceptAfter(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->VisitAfter(shared_from_base<ZipNode>(), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -34,6 +34,18 @@ class ZipNode : public DatasetNode { | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~ZipNode() = default; | ~ZipNode() = default; | ||||
| /// \brief Node name getter | |||||
| /// \return Name of the current node | |||||
| std::string Name() const override { return kZipNode; } | |||||
| /// \brief Print the description | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const override; | |||||
| /// \brief Copy the node to a new object | |||||
| /// \return A shared pointer to the new copy | |||||
| std::shared_ptr<DatasetNode> Copy() override; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | /// \brief a base class override function to create the required runtime dataset op objects for this class | ||||
| /// \return The list of shared pointers to the newly created DatasetOps | /// \return The list of shared pointers to the newly created DatasetOps | ||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | std::vector<std::shared_ptr<DatasetOp>> Build() override; | ||||
| @@ -42,8 +54,17 @@ class ZipNode : public DatasetNode { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| private: | |||||
| std::vector<std::shared_ptr<DatasetNode>> datasets_; | |||||
| /// \brief Base-class override for accepting NodePass visitor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| /// \brief Base-class override for accepting NodePass visitor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -22,10 +22,12 @@ | |||||
| #endif | #endif | ||||
| #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h" | #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/filter_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | #include "minddata/dataset/engine/ir/datasetops/map_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | #include "minddata/dataset/engine/ir/datasetops/project_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/repeat_node.h" | #include "minddata/dataset/engine/ir/datasetops/repeat_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/root_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | ||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| @@ -34,34 +36,6 @@ | |||||
| #include "minddata/dataset/engine/ir/datasetops/take_node.h" | #include "minddata/dataset/engine/ir/datasetops/take_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/transfer_node.h" | #include "minddata/dataset/engine/ir/datasetops/transfer_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/source/album_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" | |||||
| #ifndef ENABLE_ANDROID | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h" | |||||
| #endif | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/coco_node.h" | |||||
| #ifndef ENABLE_ANDROID | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h" | |||||
| #endif | |||||
| #ifdef ENABLE_PYTHON | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/generator_node.h" | |||||
| #endif | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h" | |||||
| #ifndef ENABLE_ANDROID | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" | |||||
| #endif | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/random_node.h" | |||||
| #ifndef ENABLE_ANDROID | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h" | |||||
| #endif | |||||
| #ifndef ENABLE_ANDROID | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h" | |||||
| #endif | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h" | |||||
| ////////////////////////////////// | ////////////////////////////////// | ||||
| // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. | // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. | ||||
| @@ -113,7 +87,12 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| // Driver method for TreePass | // Driver method for TreePass | ||||
| Status TreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); } | |||||
| Status TreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { | |||||
| if (root_ir == nullptr || modified == nullptr) { | |||||
| return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass"); | |||||
| } | |||||
| return this->RunOnTree(root_ir, modified); | |||||
| } | |||||
| // Driver method for NodePass | // Driver method for NodePass | ||||
| Status NodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { | Status NodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { | ||||
| @@ -132,15 +111,23 @@ Status NodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { | |||||
| // Helper function to perform DFS visit | // Helper function to perform DFS visit | ||||
| Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) { | Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) { | ||||
| RETURN_IF_NOT_OK(node_ir->Accept(this, modified)); | |||||
| bool m = false; | |||||
| RETURN_IF_NOT_OK(node_ir->Accept(this, &m)); | |||||
| *modified |= m; | |||||
| for (const auto &c : node_ir->Children()) { | for (const auto &c : node_ir->Children()) { | ||||
| RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified)); | |||||
| RETURN_IF_NOT_OK(this->DFSNodeVisit(c, &m)); | |||||
| *modified |= m; | |||||
| } | } | ||||
| return node_ir->AcceptAfter(this, modified); | |||||
| RETURN_IF_NOT_OK(node_ir->AcceptAfter(this, &m)); | |||||
| *modified |= m; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| // Helper function to perform BFS visit | // Helper function to perform BFS visit | ||||
| Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) { | Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) { | ||||
| bool m = false; | |||||
| // Initialize bfs queue with root | // Initialize bfs queue with root | ||||
| std::queue<std::shared_ptr<DatasetNode>> bfsQueue; | std::queue<std::shared_ptr<DatasetNode>> bfsQueue; | ||||
| bfsQueue.push(node_ir); | bfsQueue.push(node_ir); | ||||
| @@ -152,7 +139,8 @@ Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modifi | |||||
| bfsQueue.pop(); | bfsQueue.pop(); | ||||
| // Run node pass | // Run node pass | ||||
| RETURN_IF_NOT_OK(curNode->Accept(this, modified)); | |||||
| RETURN_IF_NOT_OK(curNode->Accept(this, &m)); | |||||
| *modified |= m; | |||||
| // Push children into bfs queue | // Push children into bfs queue | ||||
| for (const auto &c : curNode->Children()) { | for (const auto &c : curNode->Children()) { | ||||
| @@ -162,331 +150,119 @@ Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modifi | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // For datasetops IR | |||||
| // For non-leaf IR node | |||||
| Status NodePass::Visit(std::shared_ptr<BatchNode> node, bool *modified) { | Status NodePass::Visit(std::shared_ptr<BatchNode> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::VisitAfter(std::shared_ptr<BatchNode> node, bool *modified) { | Status NodePass::VisitAfter(std::shared_ptr<BatchNode> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) { | Status NodePass::Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) { | Status NodePass::VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| #ifndef ENABLE_ANDROID | |||||
| Status NodePass::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status NodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| #endif | |||||
| Status NodePass::Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) { | Status NodePass::Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified) { | Status NodePass::VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::Visit(std::shared_ptr<ConcatNode> node, bool *modified) { | Status NodePass::Visit(std::shared_ptr<ConcatNode> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified) { | Status NodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::Visit(std::shared_ptr<FilterNode> node, bool *modified) { | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status NodePass::VisitAfter(std::shared_ptr<FilterNode> node, bool *modified) { | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status NodePass::Visit(std::shared_ptr<MapNode> node, bool *modified) { | Status NodePass::Visit(std::shared_ptr<MapNode> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::VisitAfter(std::shared_ptr<MapNode> node, bool *modified) { | Status NodePass::VisitAfter(std::shared_ptr<MapNode> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::Visit(std::shared_ptr<ProjectNode> node, bool *modified) { | Status NodePass::Visit(std::shared_ptr<ProjectNode> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified) { | Status NodePass::VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::Visit(std::shared_ptr<RenameNode> node, bool *modified) { | Status NodePass::Visit(std::shared_ptr<RenameNode> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::VisitAfter(std::shared_ptr<RenameNode> node, bool *modified) { | Status NodePass::VisitAfter(std::shared_ptr<RenameNode> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::Visit(std::shared_ptr<RepeatNode> node, bool *modified) { | Status NodePass::Visit(std::shared_ptr<RepeatNode> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) { | Status NodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| Status NodePass::Visit(std::shared_ptr<RootNode> node, bool *modified) { | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| Status NodePass::VisitAfter(std::shared_ptr<RootNode> node, bool *modified) { | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::Visit(std::shared_ptr<SkipNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| Status NodePass::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) { | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::VisitAfter(std::shared_ptr<SkipNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| Status NodePass::VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified) { | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| #ifdef ENABLE_PYTHON | |||||
| Status NodePass::Visit(std::shared_ptr<SyncWaitNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| Status NodePass::Visit(std::shared_ptr<SkipNode> node, bool *modified) { | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| Status NodePass::VisitAfter(std::shared_ptr<SkipNode> node, bool *modified) { | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| #endif | |||||
| Status NodePass::Visit(std::shared_ptr<TakeNode> node, bool *modified) { | Status NodePass::Visit(std::shared_ptr<TakeNode> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::VisitAfter(std::shared_ptr<TakeNode> node, bool *modified) { | Status NodePass::VisitAfter(std::shared_ptr<TakeNode> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::Visit(std::shared_ptr<TransferNode> node, bool *modified) { | Status NodePass::Visit(std::shared_ptr<TransferNode> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) { | Status NodePass::VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::Visit(std::shared_ptr<ZipNode> node, bool *modified) { | Status NodePass::Visit(std::shared_ptr<ZipNode> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::VisitAfter(std::shared_ptr<ZipNode> node, bool *modified) { | Status NodePass::VisitAfter(std::shared_ptr<ZipNode> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| // For datasetops/source IR | |||||
| Status NodePass::Visit(std::shared_ptr<AlbumNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status NodePass::VisitAfter(std::shared_ptr<AlbumNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::Visit(std::shared_ptr<CelebANode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status NodePass::VisitAfter(std::shared_ptr<CelebANode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status NodePass::Visit(std::shared_ptr<Cifar100Node> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status NodePass::VisitAfter(std::shared_ptr<Cifar100Node> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status NodePass::Visit(std::shared_ptr<Cifar10Node> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status NodePass::VisitAfter(std::shared_ptr<Cifar10Node> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| #ifndef ENABLE_ANDROID | |||||
| Status NodePass::Visit(std::shared_ptr<CLUENode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status NodePass::VisitAfter(std::shared_ptr<CLUENode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| #endif | |||||
| Status NodePass::Visit(std::shared_ptr<CocoNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status NodePass::VisitAfter(std::shared_ptr<CocoNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| #ifndef ENABLE_ANDROID | |||||
| Status NodePass::Visit(std::shared_ptr<CSVNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status NodePass::VisitAfter(std::shared_ptr<CSVNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| #endif | |||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| Status NodePass::Visit(std::shared_ptr<GeneratorNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status NodePass::VisitAfter(std::shared_ptr<GeneratorNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| #endif | |||||
| Status NodePass::Visit(std::shared_ptr<ImageFolderNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status NodePass::VisitAfter(std::shared_ptr<ImageFolderNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status NodePass::Visit(std::shared_ptr<ManifestNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status NodePass::VisitAfter(std::shared_ptr<ManifestNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| #ifndef ENABLE_ANDROID | |||||
| Status NodePass::Visit(std::shared_ptr<MindDataNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status NodePass::VisitAfter(std::shared_ptr<MindDataNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| #endif | |||||
| Status NodePass::Visit(std::shared_ptr<MnistNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status NodePass::VisitAfter(std::shared_ptr<MnistNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status NodePass::Visit(std::shared_ptr<RandomNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status NodePass::VisitAfter(std::shared_ptr<RandomNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| #ifndef ENABLE_ANDROID | |||||
| Status NodePass::Visit(std::shared_ptr<TextFileNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| Status NodePass::Visit(std::shared_ptr<SyncWaitNode> node, bool *modified) { | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::VisitAfter(std::shared_ptr<TextFileNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| Status NodePass::VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified) { | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| #endif | #endif | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| Status NodePass::Visit(std::shared_ptr<TFRecordNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| Status NodePass::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) { | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::VisitAfter(std::shared_ptr<TFRecordNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| Status NodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) { | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| #endif | #endif | ||||
| Status NodePass::Visit(std::shared_ptr<VOCNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| // For leaf IR Node | |||||
| Status NodePass::Visit(std::shared_ptr<SourceNode> node, bool *modified) { | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status NodePass::VisitAfter(std::shared_ptr<VOCNode> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| Status NodePass::VisitAfter(std::shared_ptr<SourceNode> node, bool *modified) { | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| @@ -26,123 +26,87 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| // Non-leaf IR node | |||||
| class BatchNode; | class BatchNode; | ||||
| class BucketBatchByLengthNode; | class BucketBatchByLengthNode; | ||||
| #ifndef ENABLE_ANDROID | |||||
| class BuildSentenceVocabNode; | |||||
| #endif | |||||
| class BuildVocabNode; | class BuildVocabNode; | ||||
| class ConcatNode; | class ConcatNode; | ||||
| class FilterNode; | |||||
| class MapNode; | class MapNode; | ||||
| class ProjectNode; | class ProjectNode; | ||||
| class RenameNode; | class RenameNode; | ||||
| class RepeatNode; | class RepeatNode; | ||||
| class RootNode; | |||||
| class ShuffleNode; | class ShuffleNode; | ||||
| class SkipNode; | class SkipNode; | ||||
| #ifdef ENABLE_PYTHON | |||||
| class SyncWaitNode; | |||||
| #endif | |||||
| class TakeNode; | class TakeNode; | ||||
| class TransferNode; | class TransferNode; | ||||
| class ZipNode; | class ZipNode; | ||||
| #ifdef ENABLE_PYTHON | |||||
| class SyncWaitNode; | |||||
| #endif | |||||
| #ifndef ENABLE_ANDROID | |||||
| class BuildSentenceVocabNode; | |||||
| #endif | |||||
| // Leaf IR node | |||||
| class AlbumNode; | class AlbumNode; | ||||
| class CelebANode; | class CelebANode; | ||||
| class Cifar100Node; | class Cifar100Node; | ||||
| class Cifar10Node; | class Cifar10Node; | ||||
| #ifndef ENABLE_ANDROID | |||||
| class CLUENode; | |||||
| #endif | |||||
| class CocoNode; | class CocoNode; | ||||
| #ifndef ENABLE_ANDROID | |||||
| class CSVNode; | |||||
| #endif | |||||
| #ifdef ENABLE_PYTHON | |||||
| class GeneratorNode; | |||||
| #endif | |||||
| class ImageFolderNode; | class ImageFolderNode; | ||||
| class ManifestNode; | class ManifestNode; | ||||
| #ifndef ENABLE_ANDROID | |||||
| class MindDataNode; | |||||
| #endif | |||||
| class MnistNode; | class MnistNode; | ||||
| class RandomNode; | class RandomNode; | ||||
| #ifndef ENABLE_ANDROID | |||||
| class TextFileNode; | |||||
| class VOCNode; | |||||
| #ifdef ENABLE_PYTHON | |||||
| class GeneratorNode; | |||||
| #endif | #endif | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| class CLUENode; | |||||
| class CSVNode; | |||||
| class MindDataNode; | |||||
| class TextFileNode; | |||||
| class TFRecordNode; | class TFRecordNode; | ||||
| #endif | #endif | ||||
| class VOCNode; | |||||
| ////////////////////////////////// | ////////////////////////////////// | ||||
| // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. | // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. | ||||
| class BatchOp; | class BatchOp; | ||||
| class MapOp; | class MapOp; | ||||
| class ProjectOp; | class ProjectOp; | ||||
| class RenameOp; | class RenameOp; | ||||
| class SkipOp; | class SkipOp; | ||||
| class ShuffleOp; | class ShuffleOp; | ||||
| class AlbumOp; | class AlbumOp; | ||||
| class RandomDataOp; | class RandomDataOp; | ||||
| class RepeatOp; | class RepeatOp; | ||||
| class TakeOp; | class TakeOp; | ||||
| class ZipOp; | class ZipOp; | ||||
| class DeviceQueueOp; | class DeviceQueueOp; | ||||
| class ImageFolderOp; | class ImageFolderOp; | ||||
| class MnistOp; | class MnistOp; | ||||
| class ManifestOp; | class ManifestOp; | ||||
| class CifarOp; | class CifarOp; | ||||
| class VOCOp; | class VOCOp; | ||||
| class CocoOp; | class CocoOp; | ||||
| class CelebAOp; | class CelebAOp; | ||||
| class EpochCtrlOp; | class EpochCtrlOp; | ||||
| class BuildVocabOp; | class BuildVocabOp; | ||||
| class ConcatOp; | class ConcatOp; | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| class MindRecordOp; | class MindRecordOp; | ||||
| class TFReaderOp; | class TFReaderOp; | ||||
| class CacheOp; | class CacheOp; | ||||
| class CacheMergeOp; | class CacheMergeOp; | ||||
| class CacheLookupOp; | class CacheLookupOp; | ||||
| class BuildSentencePieceVocabOp; | class BuildSentencePieceVocabOp; | ||||
| class ClueOp; | class ClueOp; | ||||
| class CsvOp; | class CsvOp; | ||||
| class TextFileOp; | class TextFileOp; | ||||
| #endif | #endif | ||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| class FilterOp; | class FilterOp; | ||||
| class GeneratorOp; | class GeneratorOp; | ||||
| #endif | #endif | ||||
| ////////////////////////////////// | ////////////////////////////////// | ||||
| @@ -175,6 +139,13 @@ class TreePass : public Pass { | |||||
| /// \param[inout] modified Indicate if the tree was modified | /// \param[inout] modified Indicate if the tree was modified | ||||
| Status Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) final; | Status Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) final; | ||||
| /// \brief Derived classes may implement the runOnTree function to implement tree transformation. | |||||
| /// "modified" flag needs to be set to true if tree is modified during the pass execution. | |||||
| /// \param[inout] tree The tree to operate on. | |||||
| /// \param[inout] Indicate of the tree was modified. | |||||
| /// \return Status The error code return | |||||
| virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); } | |||||
| ////////////////////////////////// | ////////////////////////////////// | ||||
| // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. | // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. | ||||
| /// \brief Run the transformation pass against the execution tree. | /// \brief Run the transformation pass against the execution tree. | ||||
| @@ -191,8 +162,17 @@ class TreePass : public Pass { | |||||
| ////////////////////////////////// | ////////////////////////////////// | ||||
| }; | }; | ||||
| // NodePass is a basic Pass class which performs transformation on Node visiting. | |||||
| // NodePass is a base Pass class which performs transformation on node visiting. | |||||
| // NodePass implements Visitor design pattern. | // NodePass implements Visitor design pattern. | ||||
| // The visiting happens twice for each node in the DFS traversal, one on the way down of the traversal, | |||||
| // and the other when all the descending nodes are visited. | |||||
| // Actual transformation is done by implementing a new derived class of NodePass. | |||||
| // The derived class will implement the method Visit()/VisitAfter() passing specified node types | |||||
| // it wants to action on them, overriding the ones defined in NodePass. | |||||
| // If the derived class wants to perform the same action on all node types, | |||||
| // it can simply implement the method Visit()/VisitAfter() passing the base class DatasetNode. | |||||
| // This is made possible by overloading the method Visit()/VisitAfter() on each node type to fall back | |||||
| // to call the Visit()/VisitAfter() in this parent NodePass class. | |||||
| class NodePass : public Pass { | class NodePass : public Pass { | ||||
| public: | public: | ||||
| // Tree traversal order | // Tree traversal order | ||||
| @@ -223,153 +203,57 @@ class NodePass : public Pass { | |||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| virtual Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); } | virtual Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); } | ||||
| // For datasetops IR | |||||
| // Visit method to be overridden. | |||||
| // Note that member template can not be virtual, any node which wants to work with NodePass | |||||
| // should declare Visit of its own type and override "Accept" from DatasetNode. | |||||
| // Visit()/VisitAfter() method to be overridden. | |||||
| // These pairs of Visit()/VisitAfter() for each derived class of DatasetNode are defined here. | |||||
| // Their implementation are in .cc file to avoid adding the include files of those derived classes. | |||||
| // The implementation simply falls back to call Visit()/VisitAfter of class DatasetNode, the parent of | |||||
| // the derived classes. With this technique, the transformation classes derived from NodePass needs only to | |||||
| // implement Visit()/VisitAfter() passing DatasetNode if it wants to action on any derived classes | |||||
| // of DatasetNode in the same way. | |||||
| // Note that virtual template functions are not permitted in C++. | |||||
| // | |||||
| // Non-leaf IR node | |||||
| virtual Status Visit(std::shared_ptr<BatchNode> node, bool *modified); | virtual Status Visit(std::shared_ptr<BatchNode> node, bool *modified); | ||||
| // VisitAfter method to be overridden. | |||||
| // Note that member template can not be virtual, any node which wants to work with NodePass | |||||
| // should declare VisitAfter of its own type and override "AcceptAfter" from DatasetNode. | |||||
| virtual Status VisitAfter(std::shared_ptr<BatchNode> node, bool *modified); | virtual Status VisitAfter(std::shared_ptr<BatchNode> node, bool *modified); | ||||
| virtual Status Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified); | virtual Status Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified); | ||||
| virtual Status VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified); | virtual Status VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified); | ||||
| #ifndef ENABLE_ANDROID | |||||
| virtual Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified); | |||||
| #endif | |||||
| virtual Status Visit(std::shared_ptr<BuildVocabNode> node, bool *modified); | virtual Status Visit(std::shared_ptr<BuildVocabNode> node, bool *modified); | ||||
| virtual Status VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified); | virtual Status VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified); | ||||
| virtual Status Visit(std::shared_ptr<ConcatNode> node, bool *modified); | virtual Status Visit(std::shared_ptr<ConcatNode> node, bool *modified); | ||||
| virtual Status VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified); | virtual Status VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified); | ||||
| virtual Status Visit(std::shared_ptr<FilterNode> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<FilterNode> node, bool *modified); | |||||
| virtual Status Visit(std::shared_ptr<MapNode> node, bool *modified); | virtual Status Visit(std::shared_ptr<MapNode> node, bool *modified); | ||||
| virtual Status VisitAfter(std::shared_ptr<MapNode> node, bool *modified); | virtual Status VisitAfter(std::shared_ptr<MapNode> node, bool *modified); | ||||
| virtual Status Visit(std::shared_ptr<ProjectNode> node, bool *modified); | virtual Status Visit(std::shared_ptr<ProjectNode> node, bool *modified); | ||||
| virtual Status VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified); | virtual Status VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified); | ||||
| virtual Status Visit(std::shared_ptr<RenameNode> node, bool *modified); | virtual Status Visit(std::shared_ptr<RenameNode> node, bool *modified); | ||||
| virtual Status VisitAfter(std::shared_ptr<RenameNode> node, bool *modified); | virtual Status VisitAfter(std::shared_ptr<RenameNode> node, bool *modified); | ||||
| virtual Status Visit(std::shared_ptr<RepeatNode> node, bool *modified); | virtual Status Visit(std::shared_ptr<RepeatNode> node, bool *modified); | ||||
| virtual Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified); | virtual Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified); | ||||
| virtual Status Visit(std::shared_ptr<RootNode> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<RootNode> node, bool *modified); | |||||
| virtual Status Visit(std::shared_ptr<ShuffleNode> node, bool *modified); | virtual Status Visit(std::shared_ptr<ShuffleNode> node, bool *modified); | ||||
| virtual Status VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified); | virtual Status VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified); | ||||
| virtual Status Visit(std::shared_ptr<SkipNode> node, bool *modified); | virtual Status Visit(std::shared_ptr<SkipNode> node, bool *modified); | ||||
| virtual Status VisitAfter(std::shared_ptr<SkipNode> node, bool *modified); | virtual Status VisitAfter(std::shared_ptr<SkipNode> node, bool *modified); | ||||
| #ifdef ENABLE_PYTHON | |||||
| virtual Status Visit(std::shared_ptr<SyncWaitNode> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified); | |||||
| #endif | |||||
| virtual Status Visit(std::shared_ptr<TakeNode> node, bool *modified); | virtual Status Visit(std::shared_ptr<TakeNode> node, bool *modified); | ||||
| virtual Status VisitAfter(std::shared_ptr<TakeNode> node, bool *modified); | virtual Status VisitAfter(std::shared_ptr<TakeNode> node, bool *modified); | ||||
| virtual Status Visit(std::shared_ptr<TransferNode> node, bool *modified); | virtual Status Visit(std::shared_ptr<TransferNode> node, bool *modified); | ||||
| virtual Status VisitAfter(std::shared_ptr<TransferNode> node, bool *modified); | virtual Status VisitAfter(std::shared_ptr<TransferNode> node, bool *modified); | ||||
| virtual Status Visit(std::shared_ptr<ZipNode> node, bool *modified); | virtual Status Visit(std::shared_ptr<ZipNode> node, bool *modified); | ||||
| virtual Status VisitAfter(std::shared_ptr<ZipNode> node, bool *modified); | virtual Status VisitAfter(std::shared_ptr<ZipNode> node, bool *modified); | ||||
| // For datasetops/source IR | |||||
| virtual Status Visit(std::shared_ptr<AlbumNode> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<AlbumNode> node, bool *modified); | |||||
| virtual Status Visit(std::shared_ptr<CelebANode> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<CelebANode> node, bool *modified); | |||||
| virtual Status Visit(std::shared_ptr<Cifar100Node> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<Cifar100Node> node, bool *modified); | |||||
| virtual Status Visit(std::shared_ptr<Cifar10Node> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<Cifar10Node> node, bool *modified); | |||||
| #ifndef ENABLE_ANDROID | |||||
| virtual Status Visit(std::shared_ptr<CLUENode> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<CLUENode> node, bool *modified); | |||||
| #endif | |||||
| virtual Status Visit(std::shared_ptr<CocoNode> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<CocoNode> node, bool *modified); | |||||
| #ifndef ENABLE_ANDROID | |||||
| virtual Status Visit(std::shared_ptr<CSVNode> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<CSVNode> node, bool *modified); | |||||
| #endif | |||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| virtual Status Visit(std::shared_ptr<GeneratorNode> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<GeneratorNode> node, bool *modified); | |||||
| #endif | |||||
| virtual Status Visit(std::shared_ptr<ImageFolderNode> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<ImageFolderNode> node, bool *modified); | |||||
| virtual Status Visit(std::shared_ptr<ManifestNode> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<ManifestNode> node, bool *modified); | |||||
| #ifndef ENABLE_ANDROID | |||||
| virtual Status Visit(std::shared_ptr<MindDataNode> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<MindDataNode> node, bool *modified); | |||||
| #endif | |||||
| virtual Status Visit(std::shared_ptr<MnistNode> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<MnistNode> node, bool *modified); | |||||
| virtual Status Visit(std::shared_ptr<RandomNode> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<RandomNode> node, bool *modified); | |||||
| #ifndef ENABLE_ANDROID | |||||
| virtual Status Visit(std::shared_ptr<TextFileNode> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<TextFileNode> node, bool *modified); | |||||
| virtual Status Visit(std::shared_ptr<SyncWaitNode> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified); | |||||
| #endif | #endif | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| virtual Status Visit(std::shared_ptr<TFRecordNode> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<TFRecordNode> node, bool *modified); | |||||
| virtual Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified); | |||||
| #endif | #endif | ||||
| virtual Status Visit(std::shared_ptr<VOCNode> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<VOCNode> node, bool *modified); | |||||
| // Leaf IR node | |||||
| virtual Status Visit(std::shared_ptr<SourceNode> node, bool *modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<SourceNode> node, bool *modified); | |||||
| ////////////////////////////////// | ////////////////////////////////// | ||||
| // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. | // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. | ||||
| @@ -396,86 +280,47 @@ class NodePass : public Pass { | |||||
| // Note that member template can not be virtual, any op which wants to work with NodePass should declare RunOnNode | // Note that member template can not be virtual, any op which wants to work with NodePass should declare RunOnNode | ||||
| // of its own type and override "Accept" from DatasetOp. | // of its own type and override "Accept" from DatasetOp. | ||||
| virtual Status RunOnNode(std::shared_ptr<BatchOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<BatchOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified); | ||||
| virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified); | virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified); | ||||
| virtual Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified); | virtual Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified); | ||||
| virtual Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified); | virtual Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified); | ||||
| virtual Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified); | virtual Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified); | ||||
| virtual Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified); | virtual Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified); | ||||
| virtual Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified); | virtual Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified); | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified); | ||||
| virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified); | virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified); | ||||
| virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified); | virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified); | ||||
| virtual Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified); | virtual Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified); | ||||
| #endif | #endif | ||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| virtual Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified); | ||||
| virtual Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified); | virtual Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified); | ||||
| #endif | #endif | ||||
| ////////////////////////////////// | ////////////////////////////////// | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include "minddata/dataset/core/client.h" | #include "minddata/dataset/core/client.h" | ||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/root_node.h" | |||||
| #include "minddata/dataset/engine/opt/pass.h" | #include "minddata/dataset/engine/opt/pass.h" | ||||
| #include "minddata/dataset/engine/opt/pre/input_validation_pass.h" | #include "minddata/dataset/engine/opt/pre/input_validation_pass.h" | ||||
| @@ -119,11 +120,16 @@ Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::sha | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_epochs) { | |||||
| num_epochs_ = num_epochs; | |||||
| Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_epochs) { | |||||
| optimize_ = true; // Always ON (temporary) | optimize_ = true; // Always ON (temporary) | ||||
| RETURN_UNEXPECTED_IF_NULL(root_ir); | |||||
| RETURN_UNEXPECTED_IF_NULL(input_ir); | |||||
| MS_LOG(INFO) << "Input plan:" << '\n' << *input_ir << '\n'; | |||||
| // Copy the input IR tree and insert under the root node | |||||
| // Create a root node to host the input IR tree | |||||
| auto root_ir = std::make_shared<RootNode>(input_ir->DeepCopy(), num_epochs); | |||||
| MS_LOG(INFO) << "Plan before PrePass:" << '\n' << *root_ir << '\n'; | |||||
| // Pre-pass of the IR tree | // Pre-pass of the IR tree | ||||
| RETURN_IF_NOT_OK(PrePass(root_ir)); | RETURN_IF_NOT_OK(PrePass(root_ir)); | ||||
| @@ -136,11 +142,15 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_ep | |||||
| // Post-pass of the IR tree | // Post-pass of the IR tree | ||||
| RETURN_IF_NOT_OK(PostPass(root_ir)); | RETURN_IF_NOT_OK(PostPass(root_ir)); | ||||
| MS_LOG(INFO) << "Plan after PostPass:" << '\n' << *root_ir << '\n'; | |||||
| // This will evolve in the long run | // This will evolve in the long run | ||||
| tree_ = std::make_unique<ExecutionTree>(); | tree_ = std::make_unique<ExecutionTree>(); | ||||
| // Build the Execution tree from the child of the root node | |||||
| std::shared_ptr<DatasetOp> root_op; | std::shared_ptr<DatasetOp> root_op; | ||||
| RETURN_IF_NOT_OK(BuildExecutionTree(root_ir, &root_op)); | |||||
| // We will replace input_ir with root_ir->Children()[0] once IR optimizer is in | |||||
| RETURN_IF_NOT_OK(BuildExecutionTree(input_ir, &root_op)); | |||||
| RETURN_IF_NOT_OK(tree_->AssignRoot(root_op)); | RETURN_IF_NOT_OK(tree_->AssignRoot(root_op)); | ||||
| if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_); | if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_); | ||||
| @@ -67,10 +67,6 @@ class TreeAdapter { | |||||
| // Optional optimizations status | // Optional optimizations status | ||||
| bool OptimizationEnabled() const { return optimize_; } | bool OptimizationEnabled() const { return optimize_; } | ||||
| // Getter function to get the total number of epochs to be run on this tree. | |||||
| // @return total number of epochs | |||||
| int32_t num_epochs() { return num_epochs_; } | |||||
| private: | private: | ||||
| // This function runs a mandatory pass checking the syntax and semantics of the IR tree. | // This function runs a mandatory pass checking the syntax and semantics of the IR tree. | ||||
| Status PrePass(std::shared_ptr<DatasetNode> ir); | Status PrePass(std::shared_ptr<DatasetNode> ir); | ||||
| @@ -47,6 +47,10 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> { | |||||
| /// \return Shared pointers to the newly created Sampler | /// \return Shared pointers to the newly created Sampler | ||||
| virtual std::shared_ptr<SamplerRT> Build() = 0; | virtual std::shared_ptr<SamplerRT> Build() = 0; | ||||
| /// \brief Pure virtual function to copy a SamplerObj class | |||||
| /// \return Shared pointers to the newly copied SamplerObj | |||||
| virtual std::shared_ptr<SamplerObj> Copy() = 0; | |||||
| /// \brief Function for derived class to get the shard id of sampler | /// \brief Function for derived class to get the shard id of sampler | ||||
| /// \return The shard id of the derived sampler | /// \return The shard id of the derived sampler | ||||
| virtual int64_t ShardId() { return 0; } | virtual int64_t ShardId() { return 0; } | ||||
| @@ -132,6 +136,11 @@ class DistributedSamplerObj : public SamplerObj { | |||||
| std::shared_ptr<SamplerRT> Build() override; | std::shared_ptr<SamplerRT> Build() override; | ||||
| std::shared_ptr<SamplerObj> Copy() override { | |||||
| return std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, offset_, | |||||
| even_dist_); | |||||
| } | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | ||||
| #endif | #endif | ||||
| @@ -160,6 +169,10 @@ class PKSamplerObj : public SamplerObj { | |||||
| std::shared_ptr<SamplerRT> Build() override; | std::shared_ptr<SamplerRT> Build() override; | ||||
| std::shared_ptr<SamplerObj> Copy() override { | |||||
| return std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_); | |||||
| } | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | ||||
| #endif | #endif | ||||
| @@ -174,9 +187,8 @@ class PKSamplerObj : public SamplerObj { | |||||
| class PreBuiltSamplerObj : public SamplerObj { | class PreBuiltSamplerObj : public SamplerObj { | ||||
| public: | public: | ||||
| #ifndef ENABLE_ANDROID | |||||
| explicit PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler); | explicit PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler); | ||||
| #ifndef ENABLE_ANDROID | |||||
| explicit PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler); | explicit PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler); | ||||
| #endif | #endif | ||||
| @@ -188,6 +200,8 @@ class PreBuiltSamplerObj : public SamplerObj { | |||||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | ||||
| #endif | #endif | ||||
| std::shared_ptr<SamplerObj> Copy() override; | |||||
| bool ValidateParams() override; | bool ValidateParams() override; | ||||
| private: | private: | ||||
| @@ -205,6 +219,8 @@ class RandomSamplerObj : public SamplerObj { | |||||
| std::shared_ptr<SamplerRT> Build() override; | std::shared_ptr<SamplerRT> Build() override; | ||||
| std::shared_ptr<SamplerObj> Copy() override { return std::make_shared<RandomSamplerObj>(replacement_, num_samples_); } | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | ||||
| #endif | #endif | ||||
| @@ -224,6 +240,10 @@ class SequentialSamplerObj : public SamplerObj { | |||||
| std::shared_ptr<SamplerRT> Build() override; | std::shared_ptr<SamplerRT> Build() override; | ||||
| std::shared_ptr<SamplerObj> Copy() override { | |||||
| return std::make_shared<SequentialSamplerObj>(start_index_, num_samples_); | |||||
| } | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | ||||
| #endif | #endif | ||||
| @@ -243,6 +263,10 @@ class SubsetRandomSamplerObj : public SamplerObj { | |||||
| std::shared_ptr<SamplerRT> Build() override; | std::shared_ptr<SamplerRT> Build() override; | ||||
| std::shared_ptr<SamplerObj> Copy() override { | |||||
| return std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_); | |||||
| } | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | ||||
| #endif | #endif | ||||
| @@ -262,6 +286,10 @@ class WeightedRandomSamplerObj : public SamplerObj { | |||||
| std::shared_ptr<SamplerRT> Build() override; | std::shared_ptr<SamplerRT> Build() override; | ||||
| std::shared_ptr<SamplerObj> Copy() override { | |||||
| return std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_); | |||||
| } | |||||
| bool ValidateParams() override; | bool ValidateParams() override; | ||||
| private: | private: | ||||
| @@ -32,7 +32,10 @@ class TensorOp; | |||||
| class TensorOperation : public std::enable_shared_from_this<TensorOperation> { | class TensorOperation : public std::enable_shared_from_this<TensorOperation> { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| TensorOperation(); | |||||
| TensorOperation() : random_op_(false) {} | |||||
| /// \brief Constructor | |||||
| explicit TensorOperation(bool random) : random_op_(random) {} | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~TensorOperation() = default; | ~TensorOperation() = default; | ||||
| @@ -42,6 +45,13 @@ class TensorOperation : public std::enable_shared_from_this<TensorOperation> { | |||||
| virtual std::shared_ptr<TensorOp> Build() = 0; | virtual std::shared_ptr<TensorOp> Build() = 0; | ||||
| virtual Status ValidateParams() = 0; | virtual Status ValidateParams() = 0; | ||||
| /// \brief Check whether the operation is deterministic. | |||||
| /// \return true if this op is a random op (returns non-deterministic result e.g. RandomCrop) | |||||
| bool IsRandomOp() const { return random_op_; } | |||||
| protected: | |||||
| bool random_op_; | |||||
| }; | }; | ||||
| // Helper function to validate fill value | // Helper function to validate fill value | ||||
| @@ -427,7 +427,7 @@ Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> | |||||
| Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst, | Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst, | ||||
| std::vector<dsize_t> cur_ind, size_t cur_dim) { | std::vector<dsize_t> cur_ind, size_t cur_dim) { | ||||
| if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data | if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data | ||||
| dst->CopyLastDimAt(src, cur_ind); | |||||
| RETURN_IF_NOT_OK(dst->CopyLastDimAt(src, cur_ind)); | |||||
| } else { // not the last dimension, keep doing recursion | } else { // not the last dimension, keep doing recursion | ||||
| dsize_t min_ind = std::min(dst->shape()[cur_dim], src->shape()[cur_dim]); | dsize_t min_ind = std::min(dst->shape()[cur_dim], src->shape()[cur_dim]); | ||||
| for (dsize_t i = 0; i < min_ind; i++) { | for (dsize_t i = 0; i < min_ind; i++) { | ||||
| @@ -57,7 +57,7 @@ class RandomCropOp : public TensorOp { | |||||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | ||||
| // Function breaks out the compute function's image padding functionality and makes available to other Ops | // Function breaks out the compute function's image padding functionality and makes available to other Ops | ||||
| // Using this class as a base - restructrued to allow for RandomCropWithBBox Augmentation Op | |||||
| // Using this class as a base - re-structured to allow for RandomCropWithBBox Augmentation Op | |||||
| // @param input: Input is the original Image | // @param input: Input is the original Image | ||||
| // @param pad_image: Pointer to new Padded image | // @param pad_image: Pointer to new Padded image | ||||
| // @param t_pad_top: Total Top Padding - Based on input and value calculated in function if required | // @param t_pad_top: Total Top Padding - Based on input and value calculated in function if required | ||||
| @@ -570,7 +570,7 @@ class WeightedRandomSampler(BuiltinSampler): | |||||
| Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities). | Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities). | ||||
| Args: | Args: | ||||
| weights (list[float]): A sequence of weights, not necessarily summing up to 1. | |||||
| weights (list[float, int]): A sequence of weights, not necessarily summing up to 1. | |||||
| num_samples (int, optional): Number of elements to sample (default=None, all elements). | num_samples (int, optional): Number of elements to sample (default=None, all elements). | ||||
| replacement (bool): If True, put the sample ID back for the next draw (default=True). | replacement (bool): If True, put the sample ID back for the next draw (default=True). | ||||
| @@ -17,6 +17,7 @@ SET(DE_UT_SRCS | |||||
| c_api_dataset_coco_test.cc | c_api_dataset_coco_test.cc | ||||
| c_api_dataset_config_test.cc | c_api_dataset_config_test.cc | ||||
| c_api_dataset_csv_test.cc | c_api_dataset_csv_test.cc | ||||
| c_api_dataset_ir_node_test.cc | |||||
| c_api_dataset_iterator_test.cc | c_api_dataset_iterator_test.cc | ||||
| c_api_dataset_manifest_test.cc | c_api_dataset_manifest_test.cc | ||||
| c_api_dataset_minddata_test.cc | c_api_dataset_minddata_test.cc | ||||
| @@ -0,0 +1,142 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "minddata/dataset/core/client.h" | |||||
| #include "common/common.h" | |||||
| #include "gtest/gtest.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| #include "minddata/dataset/engine/opt/pre/getter_pass.h" | |||||
| using namespace mindspore::dataset; | |||||
| using mindspore::LogStream; | |||||
| using mindspore::MsLogLevel::INFO; | |||||
| class MindDataTestIRNodes : public UT::DatasetOpTesting { | |||||
| public: | |||||
| MindDataTestIRNodes() = default; | |||||
| void SetUp() override { GlobalInit(); } | |||||
| // compare the ptr of the nodes in two trees, used to test the deep copy of nodes, will return error code | |||||
| // if (ptr1 == ptr2) does not equal to flag or the two tree has different structures (or node names are not the same) | |||||
| Status CompareTwoTrees(std::shared_ptr<DatasetNode> root1, std::shared_ptr<DatasetNode> root2, bool flag) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(root1 != nullptr && root2 != nullptr, "Error in Compare, nullptr."); | |||||
| if (((root1.get() == root2.get()) != flag) || (root1->Name() != root2->Name())) { | |||||
| std::string err_msg = | |||||
| "Expect node ptr " + root1->Name() + (flag ? "==" : "!=") + root2->Name() + " but they aren't!"; | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } | |||||
| size_t num_child = root1->Children().size(); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(num_child == root2->Children().size(), | |||||
| root1->Name() + " has " + std::to_string(num_child) + "child, node #2 has " + | |||||
| std::to_string(root2->Children().size()) + " child."); | |||||
| for (size_t ind = 0; ind < num_child; ind++) { | |||||
| RETURN_IF_NOT_OK(CompareTwoTrees(root1->Children()[ind], root2->Children()[ind], flag)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // print the node's name in post order | |||||
| Status PostOrderPrintTree(std::shared_ptr<DatasetNode> ir, std::string &names) { | |||||
| RETURN_UNEXPECTED_IF_NULL(ir); | |||||
| for (auto child : ir->Children()) { | |||||
| RETURN_IF_NOT_OK(PostOrderPrintTree(child, names)); | |||||
| } | |||||
| names += (ir->Name() + "->"); | |||||
| return Status::OK(); | |||||
| } | |||||
| }; | |||||
| TEST_F(MindDataTestIRNodes, MindDataTestSimpleDeepCopy) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestSimpleDeepCopy."; | |||||
| auto tree1 = RandomData(44)->Repeat(2)->Project({"label"})->Shuffle(10)->Batch(2)->IRNode(); | |||||
| auto tree2 = tree1->DeepCopy(); | |||||
| std::string tree_1_names, tree_2_names; | |||||
| ASSERT_OK(PostOrderPrintTree(tree1, tree_1_names)); | |||||
| ASSERT_OK(PostOrderPrintTree(tree2, tree_2_names)); | |||||
| // expected output for the 2 names: | |||||
| // RandomDataset->Repeat->Project->Shuffle->Batch-> | |||||
| EXPECT_EQ(tree_1_names, tree_2_names); | |||||
| ASSERT_OK(CompareTwoTrees(tree1, tree1, true)); | |||||
| ASSERT_OK(CompareTwoTrees(tree1, tree2, false)); | |||||
| // verify compare function is correct | |||||
| EXPECT_TRUE(CompareTwoTrees(tree2, tree2, false).IsError()); | |||||
| } | |||||
| TEST_F(MindDataTestIRNodes, MindDataTestZipDeepCopy) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestZipDeepCopy."; | |||||
| auto branch1 = RandomData(44)->Project({"label"}); | |||||
| auto branch2 = RandomData(44)->Shuffle(10); | |||||
| auto tree1 = Zip({branch1, branch2})->Batch(2)->IRNode(); | |||||
| auto tree2 = tree1->DeepCopy(); | |||||
| std::string tree_1_names, tree_2_names; | |||||
| ASSERT_OK(PostOrderPrintTree(tree1, tree_1_names)); | |||||
| ASSERT_OK(PostOrderPrintTree(tree2, tree_2_names)); | |||||
| // expected output for the 2 names: | |||||
| // RandomDataset->Project->RandomDataset->Shuffle->Zip->Batch-> | |||||
| EXPECT_EQ(tree_1_names, tree_2_names); | |||||
| // verify the pointer within the same tree are the same | |||||
| ASSERT_OK(CompareTwoTrees(tree1, tree1, true)); | |||||
| // verify two trees | |||||
| ASSERT_OK(CompareTwoTrees(tree1, tree2, false)); | |||||
| } | |||||
| TEST_F(MindDataTestIRNodes, MindDataTestNodeRemove) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestNodeRemove."; | |||||
| auto branch1 = RandomData(44)->Project({"label"}); | |||||
| auto branch2 = ImageFolder("path"); | |||||
| auto tree = Zip({branch1, branch2})->IRNode(); | |||||
| /*** | |||||
| tree looks like this, we will remove node and test its functionalities | |||||
| Zip | |||||
| / \ | |||||
| Project ImageFolder | |||||
| / | |||||
| RandomData | |||||
| ***/ | |||||
| auto tree_copy_1 = tree->DeepCopy(); | |||||
| ASSERT_EQ(tree_copy_1->Children().size(), 2); | |||||
| // remove the project in the tree and test | |||||
| ASSERT_OK(tree_copy_1->Children()[0]->Remove()); // remove Project from tree | |||||
| ASSERT_OK(CompareTwoTrees(tree_copy_1, Zip({RandomData(44), ImageFolder("path")})->IRNode(), false)); | |||||
| // remove the ImageFolder, a leaf node from the tree | |||||
| std::string tree_1_names, tree_2_names; | |||||
| ASSERT_OK(PostOrderPrintTree(tree_copy_1, tree_1_names)); | |||||
| EXPECT_EQ(tree_1_names, "RandomDataset->ImageFolderDataset->Zip->"); | |||||
| auto tree_copy_2 = tree->DeepCopy(); | |||||
| ASSERT_EQ(tree_copy_2->Children().size(), 2); | |||||
| tree_copy_2->Children()[1]->Remove(); | |||||
| ASSERT_OK(PostOrderPrintTree(tree_copy_2, tree_2_names)); | |||||
| EXPECT_EQ(tree_2_names, "RandomDataset->Project->Zip->"); | |||||
| } | |||||