From: @nsyca Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -258,10 +258,6 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||
| /// \return The number of required repeats for the operator | |||
| int32_t op_total_repeats() { return op_total_repeats_; } | |||
| /// \brief Getter function | |||
| /// \return The number of required epochs for the operator | |||
| int32_t op_total_epochs() { return op_total_repeats_ / op_num_repeats_per_epoch_; } | |||
| /// \brief Getter function | |||
| /// \return The number of repeats per epoch for the operator | |||
| int32_t op_num_repeats_per_epoch() const { return op_num_repeats_per_epoch_; } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -17,10 +17,8 @@ | |||
| #include <iostream> | |||
| #include <utility> | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | |||
| #include "minddata/dataset/engine/data_buffer.h" | |||
| #include "minddata/dataset/engine/db_connector.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/util/log_adapter.h" | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -304,6 +304,10 @@ class CsvOp : public ParallelOp { | |||
| // @return - Status of the node visit. | |||
| Status Accept(NodePass *p, bool *const modified) override; | |||
| // Op name getter | |||
| // @return Name of the current Op | |||
| std::string Name() const override { return "CsvOp"; } | |||
| private: | |||
| // The entry point for when workers are launched. | |||
| // @param worker_id - the id of the worker that is executing this function. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -16,20 +16,9 @@ | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| #include <iostream> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <limits> | |||
| #include "minddata/dataset/engine/datasetops/dataset_op.h" | |||
| #include "minddata/dataset/engine/datasetops/shuffle_op.h" | |||
| #include "minddata/dataset/engine/datasetops/device_queue_op.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/removal_pass.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" | |||
| #include "minddata/dataset/engine/opt/post/repeat_pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/cache_error_pass.h" | |||
| #include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" | |||
| #endif | |||
| #include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h" | |||
| #include "minddata/dataset/engine/perf/profiling.h" | |||
| #include "minddata/dataset/engine/perf/monitor.h" | |||
| #if defined(ENABLE_GPUQUE) || defined(ENABLE_TDTQUE) | |||
| @@ -255,97 +244,13 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui | |||
| return Status::OK(); | |||
| } | |||
| // The driver of the prepare phase of the execution tree. | |||
| // Prepare phase consists of three sub phases | |||
| // | |||
| // 1. PreAction() | |||
| // Compulsory transformation/action pre optimization. | |||
| // For example, CacheOp Insertion | |||
| // | |||
| // 2. Optimize() | |||
| // Optimization transformation/action, optional | |||
| // For example, MapOp Fusion | |||
| // | |||
| // 3. PostAction() | |||
| // Compulsory transformation/action post optimization. | |||
| // For example, repeatOp inlining | |||
| // | |||
| // @return Status The status code returned | |||
| Status ExecutionTree::Prepare(int32_t num_epochs, bool partial) { | |||
| num_epochs_ = num_epochs; | |||
| partially_prepare_ = partial; | |||
| // Pre optimization compulsory transformation | |||
| RETURN_IF_NOT_OK(this->PreAction()); | |||
| // Post optimization compulsory transformation | |||
| RETURN_IF_NOT_OK(this->PostAction()); | |||
| // The tree is ready to be prepared. | |||
| tree_state_ = kDeTStatePrepare; | |||
| // Existing transformation implementation, will be removed later | |||
| RETURN_IF_NOT_OK(this->PrepareDeprecated()); | |||
| return Status::OK(); | |||
| } | |||
| Status ExecutionTree::PreAction() { | |||
| bool modified = false; | |||
| std::vector<std::unique_ptr<Pass>> pre_actions; | |||
| // Construct pre actions | |||
| if (!partially_prepare_) { | |||
| #ifndef ENABLE_ANDROID | |||
| pre_actions.push_back(std::make_unique<CacheErrorPass>()); | |||
| #endif | |||
| pre_actions.push_back(std::make_unique<EpochInjectionPass>()); | |||
| pre_actions.push_back(std::make_unique<RemovalPass>()); | |||
| } | |||
| MS_LOG(INFO) << "Running " << pre_actions.size() << " pre pass loops."; | |||
| // Apply pre action passes | |||
| for (auto &pass : pre_actions) { | |||
| RETURN_IF_NOT_OK(pass->Run(this, &modified)); | |||
| } | |||
| MS_LOG(INFO) << "Pre passes complete."; | |||
| return Status::OK(); | |||
| } | |||
| Status ExecutionTree::PostAction() { | |||
| bool modified = false; | |||
| OptPass post_actions; | |||
| // Construct pre actions | |||
| MS_LOG(INFO) << "Running post pass loops."; | |||
| #ifndef ENABLE_ANDROID | |||
| // Calling CacheErrorPass again. This is a temporary fix until the TensorOperation is properly done in Pybind. | |||
| // The IR version cannot detect an invalid case of a cache on Map with random tensor operation from Python API. | |||
| // This is because Python API binding to TensorOperation is still in progress. | |||
| post_actions.push_back(std::make_unique<CacheErrorPass>()); | |||
| post_actions.push_back(std::make_unique<RepeatPass>()); | |||
| #endif | |||
| // Apply post action passes | |||
| for (auto &pass : post_actions) { | |||
| RETURN_IF_NOT_OK(pass->Run(this, &modified)); | |||
| } | |||
| MS_LOG(INFO) << "Post passes complete."; | |||
| return Status::OK(); | |||
| } | |||
| // The driver of the prepare phase of the execution tree. The prepare phase will recursively | |||
| // walk the tree to perform modifications to the tree or specific nodes within the tree to get | |||
| // it ready for execution. | |||
| // | |||
| // This driver is deprecated. | |||
| Status ExecutionTree::PrepareDeprecated() { | |||
| // Tree must be in pending prepare state before we can assign root to it | |||
| if (tree_state_ != kDeTStatePrepare) { | |||
| std::string err_msg = | |||
| "Invalid tree state for preparing the tree. Current state: " + std::to_string(static_cast<int>(tree_state_)) + | |||
| " Expected state: " + std::to_string(static_cast<int>(kDeTStatePrepare)); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| Status ExecutionTree::Prepare() { | |||
| // The tree is ready to be prepared. | |||
| tree_state_ = kDeTStatePrepare; | |||
| if (root_ == nullptr) { | |||
| RETURN_STATUS_UNEXPECTED("Please assign one operator as the root of this tree."); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -169,24 +169,6 @@ class ExecutionTree { | |||
| // @return the prepare flags | |||
| uint32_t PrepareFlags() const { return prepare_flags_; } | |||
| // The driver of the prepare phase of the execution tree. | |||
| // Prepare phase consists of three sub phases | |||
| // | |||
| // 1. PreAction() | |||
| // Compulsory transformation/action pre optimization. | |||
| // For example, CacheOp Insertion | |||
| // | |||
| // 2. Optimize() | |||
| // Optimization transformation/action, optional | |||
| // For example, MapOp Fusion | |||
| // | |||
| // 3. PostAction() | |||
| // Compulsory transformation/action post optimization. | |||
| // For example, repeatOp inlining | |||
| // | |||
| // @return Status The status code returned | |||
| Status Prepare(int num_epochs = -1, bool partial = false); | |||
| // Compulsory transformation/action pre optimization. | |||
| // @return Status The status code returned | |||
| Status PreAction(); | |||
| @@ -200,7 +182,7 @@ class ExecutionTree { | |||
| // it ready for execution. | |||
| // @param Total number of epochs that will be run on this tree | |||
| // @return Status The status code returned | |||
| Status PrepareDeprecated(); | |||
| Status Prepare(); | |||
| // Recursive function used during prepare phase to visit a node and drive any pre- and post- | |||
| // node actions during a tree walk. | |||
| @@ -239,10 +221,6 @@ class ExecutionTree { | |||
| // Getter for profiling manager, no ownership | |||
| ProfilingManager *GetProfilingManager() { return profiling_manager_.get(); } | |||
| // Getter function to get the total number of epochs to be run on this tree. | |||
| // @return total number of epochs | |||
| int32_t num_epochs() { return num_epochs_; } | |||
| private: | |||
| // A helper functions for doing the recursive printing | |||
| // @param dataset_op - The dataset op to print | |||
| @@ -257,9 +235,7 @@ class ExecutionTree { | |||
| int32_t id_count_; // Counter for generating operator id's | |||
| uint32_t prepare_flags_; // Flags used during tree prepare | |||
| TreeState tree_state_; // Tracking the current tree state | |||
| int32_t num_epochs_; // Total number of epochs to run for this tree | |||
| std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager | |||
| bool partially_prepare_; // Temp: during migration to IR, if true, run remaining passes. | |||
| #if defined(ENABLE_GPUQUE) || defined(ENABLE_TDTQUE) | |||
| // This rank_id is for numa and device_queue, one process work with only one rank_id, | |||
| // for standalone scenario, this rank_id may come from env 'CUDA_VISIBLE_DEVICES', | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -102,9 +102,11 @@ Status BatchNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) | |||
| node_ops->push_back(project_op); | |||
| } | |||
| node_ops->push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_, | |||
| in_col_names_, out_col_names_, batch_size_func_, batch_map_func_, | |||
| pad_map_)); | |||
| auto op = std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_, | |||
| in_col_names_, out_col_names_, batch_size_func_, batch_map_func_, pad_map_); | |||
| op->set_total_repeats(GetTotalRepeats()); | |||
| op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| #else | |||
| node_ops->push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_, | |||
| in_col_names_, pad_map_)); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -84,9 +84,12 @@ void BucketBatchByLengthNode::Print(std::ostream &out) const { | |||
| Status BucketBatchByLengthNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| bucket_boundaries_.insert(bucket_boundaries_.begin(), 0); | |||
| node_ops->push_back(std::make_shared<BucketBatchByLengthOp>( | |||
| column_names_, bucket_boundaries_, bucket_batch_sizes_, element_length_function_, pad_info_, | |||
| pad_to_bucket_boundary_, drop_remainder_, connector_que_size_)); | |||
| auto op = std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_, | |||
| element_length_function_, pad_info_, pad_to_bucket_boundary_, | |||
| drop_remainder_, connector_que_size_); | |||
| op->set_total_repeats(GetTotalRepeats()); | |||
| op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| if (bucket_boundaries_[0] == 0) { | |||
| bucket_boundaries_.erase(bucket_boundaries_.begin()); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -55,10 +55,11 @@ void BuildSentenceVocabNode::Print(std::ostream &out) const { | |||
| // Function to build BuildSentenceVocabNode | |||
| Status BuildSentenceVocabNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| std::shared_ptr<BuildSentencePieceVocabOp> build_sentence_piece_vocab_op; | |||
| build_sentence_piece_vocab_op = std::make_shared<BuildSentencePieceVocabOp>( | |||
| vocab_, col_names_, vocab_size_, character_coverage_, model_type_, params_, connector_que_size_); | |||
| node_ops->push_back(build_sentence_piece_vocab_op); | |||
| auto op = std::make_shared<BuildSentencePieceVocabOp>(vocab_, col_names_, vocab_size_, character_coverage_, | |||
| model_type_, params_, connector_que_size_); | |||
| op->set_total_repeats(GetTotalRepeats()); | |||
| op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| return Status::OK(); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -54,6 +54,8 @@ Status BuildVocabNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node | |||
| std::shared_ptr<BuildVocabOp> build_vocab_op; | |||
| build_vocab_op = std::make_shared<BuildVocabOp>(vocab_, columns_, freq_range_, top_k_, special_tokens_, | |||
| special_first_, num_workers_, connector_que_size_); | |||
| build_vocab_op->set_total_repeats(GetTotalRepeats()); | |||
| build_vocab_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(build_vocab_op); | |||
| return Status::OK(); | |||
| } | |||
| @@ -51,10 +51,24 @@ Status CacheLookupNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) | |||
| "Internal error. Attempt to create a cache lookup node without cache client."); | |||
| RETURN_IF_NOT_OK(cache_->Build()); | |||
| RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, &lookup_op_, sampler_)); | |||
| lookup_op_->set_total_repeats(GetTotalRepeats()); | |||
| lookup_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(lookup_op_); | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accepting method for IRNodePass | |||
| Status CacheLookupNode::Accept(IRNodePass *const p, bool *const modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<CacheLookupNode>(), modified); | |||
| } | |||
| // Visitor accepting method for IRNodePass | |||
| Status CacheLookupNode::AcceptAfter(IRNodePass *const p, bool *const modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<CacheLookupNode>(), modified); | |||
| } | |||
| std::shared_ptr<SamplerObj> CacheLookupNode::SamplerCopy() { | |||
| // CacheLookupNode should already been copied, so we just return it here | |||
| return std::static_pointer_cast<SamplerObj>(lookup_node_copy_); | |||
| @@ -64,6 +64,18 @@ class CacheLookupNode : public DatasetNode, public SamplerObj { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(IRNodePass *const p, bool *const modified) override; | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | |||
| private: | |||
| std::shared_ptr<SamplerObj> sampler_; | |||
| std::shared_ptr<DatasetOp> lookup_op_; | |||
| @@ -48,9 +48,23 @@ Status CacheMergeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) | |||
| RETURN_IF_NOT_OK(cache_->Build()); | |||
| std::shared_ptr<DatasetOp> merge_op = nullptr; | |||
| RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, &merge_op)); | |||
| merge_op->set_total_repeats(GetTotalRepeats()); | |||
| merge_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(merge_op); | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accepting method for IRNodePass | |||
| Status CacheMergeNode::Accept(IRNodePass *const p, bool *const modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<CacheMergeNode>(), modified); | |||
| } | |||
| // Visitor accepting method for IRNodePass | |||
| Status CacheMergeNode::AcceptAfter(IRNodePass *const p, bool *const modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<CacheMergeNode>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -53,6 +53,18 @@ class CacheMergeNode : public DatasetNode { | |||
| /// \brief Parameters validation | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(IRNodePass *const p, bool *const modified) override; | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -53,9 +53,23 @@ Status CacheNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||
| std::shared_ptr<DatasetOp> cache_op = nullptr; | |||
| RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op)); | |||
| cache_op->SetSampler(sampler_->SamplerBuild()); | |||
| cache_op->set_total_repeats(GetTotalRepeats()); | |||
| cache_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(cache_op); | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accepting method for IRNodePass | |||
| Status CacheNode::Accept(IRNodePass *const p, bool *const modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<CacheNode>(), modified); | |||
| } | |||
| // Visitor accepting method for IRNodePass | |||
| Status CacheNode::AcceptAfter(IRNodePass *const p, bool *const modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<CacheNode>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -55,6 +55,18 @@ class CacheNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(IRNodePass *const p, bool *const modified) override; | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | |||
| private: | |||
| std::shared_ptr<SamplerObj> sampler_; | |||
| }; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -119,12 +119,16 @@ Status ConcatNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size | |||
| } | |||
| Status ConcatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| std::shared_ptr<ConcatOp> op; | |||
| if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) { | |||
| node_ops->push_back(std::make_shared<ConcatOp>(connector_que_size_)); | |||
| op = std::make_shared<ConcatOp>(connector_que_size_); | |||
| } else { | |||
| node_ops->push_back(std::make_shared<ConcatOp>(connector_que_size_, sampler_->SamplerBuild(), | |||
| children_flag_and_nums_, children_start_end_index_)); | |||
| op = std::make_shared<ConcatOp>(connector_que_size_, sampler_->SamplerBuild(), children_flag_and_nums_, | |||
| children_start_end_index_); | |||
| } | |||
| op->set_total_repeats(GetTotalRepeats()); | |||
| op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| return Status::OK(); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -219,7 +219,9 @@ DatasetNode::DatasetNode() | |||
| dataset_size_(-1), | |||
| mappable_(kNotADataSource), | |||
| nary_op_(false), | |||
| descendant_of_cache_(false) { | |||
| descendant_of_cache_(false), | |||
| total_repeats_(-1), | |||
| num_epochs_(1) { | |||
| // Fetch some default value from config manager | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| num_workers_ = cfg->num_parallel_workers(); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -27,6 +27,7 @@ | |||
| #include "minddata/dataset/core/config_manager.h" | |||
| #include "minddata/dataset/engine/consumers/tree_consumer.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/engine/datasetops/dataset_op.h" | |||
| #include "minddata/dataset/engine/datasetops/filter_op.h" | |||
| #include "minddata/dataset/engine/datasetops/map_op/map_op.h" | |||
| #include "minddata/dataset/engine/datasetops/project_op.h" | |||
| @@ -292,6 +293,24 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||
| /// \return Status of the function | |||
| virtual Status to_json(nlohmann::json *out_json); | |||
| /// \brief Setter function, set the number of total repeats for the operator | |||
| void SetTotalRepeats(int32_t total_repeats) { total_repeats_ = total_repeats; } | |||
| /// \brief Setter function, set the number of epochs for the operator | |||
| void SetNumEpochs(int32_t num_epochs) { num_epochs_ = num_epochs; } | |||
| /// \brief Getter function | |||
| /// \return The number of required repeats for the operator | |||
| int32_t GetTotalRepeats() const { return total_repeats_; } | |||
| /// \brief Getter function | |||
| /// \return The number of epochs for the operator | |||
| int32_t GetNumEpochs() const { return num_epochs_; } | |||
| /// \brief Getter function | |||
| /// \return The number of repeats per epoch for the operator | |||
| int32_t GetNumRepeatsPerEpoch() const { return total_repeats_ / num_epochs_; } | |||
| protected: | |||
| std::vector<std::shared_ptr<DatasetNode>> children_; | |||
| DatasetNode *parent_; // used to record the only one parent of an IR node after parsing phase | |||
| @@ -301,6 +320,8 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||
| int32_t rows_per_buffer_; | |||
| int32_t connector_que_size_; | |||
| int32_t worker_connector_size_; | |||
| int32_t total_repeats_; // Number of times required to run this operator | |||
| int32_t num_epochs_; // Number of epochs | |||
| // Establish a parent-child relationship between this node and the input node. | |||
| // Used only in the constructor of the class and its derived classes. | |||
| void AddChild(std::shared_ptr<DatasetNode> child); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -44,6 +44,8 @@ void EpochCtrlNode::Print(std::ostream &out) const { out << Name() + "(epoch:" + | |||
| // Function to build the EpochCtrlOp | |||
| Status EpochCtrlNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| auto new_op_ = std::make_shared<EpochCtrlOp>(repeat_count_); | |||
| new_op_->set_total_repeats(GetTotalRepeats()); | |||
| new_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(new_op_); | |||
| op_ = new_op_; | |||
| return Status::OK(); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -44,7 +44,10 @@ void FilterNode::Print(std::ostream &out) const { | |||
| } | |||
| Status FilterNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| node_ops->push_back(std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, predicate_)); | |||
| auto op = std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, predicate_); | |||
| op->set_total_repeats(GetTotalRepeats()); | |||
| op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| return Status::OK(); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -38,7 +38,8 @@ MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr | |||
| output_columns_(output_columns), | |||
| project_columns_(project_columns), | |||
| DatasetNode(std::move(cache)), | |||
| callbacks_(callbacks) { | |||
| callbacks_(callbacks), | |||
| under_a_cache_(false) { | |||
| this->AddChild(child); | |||
| } | |||
| @@ -64,6 +65,17 @@ Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| operations_.begin(), operations_.end(), std::back_inserter(tensor_ops), | |||
| [](std::shared_ptr<TensorOperation> operation) -> std::shared_ptr<TensorOp> { return operation->Build(); }); | |||
| // This is temporary code. | |||
| // Because the randomness of its tensor operations is not known in TensorOperation form until we convert them | |||
| // to TensorOp, we need to check the randomness here. | |||
| // When TensorOperation captures the randomness behaviour, remove this code and the member "under_a_cache_" | |||
| // and the temporary code in CacheValidation pre pass in IR optimizer. | |||
| if (under_a_cache_) { | |||
| auto itr = std::find_if(tensor_ops.begin(), tensor_ops.end(), [](const auto &it) { return !it->Deterministic(); }); | |||
| if (itr != tensor_ops.end()) { | |||
| RETURN_STATUS_UNEXPECTED("MapNode containing random operation is not supported as a descendant of cache."); | |||
| } | |||
| } | |||
| // This parameter will be removed with next rebase | |||
| std::vector<std::string> col_orders; | |||
| auto map_op = std::make_shared<MapOp>(input_columns_, output_columns_, tensor_ops, num_workers_, connector_que_size_); | |||
| @@ -74,9 +86,12 @@ Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| if (!project_columns_.empty()) { | |||
| auto project_op = std::make_shared<ProjectOp>(project_columns_); | |||
| project_op->set_total_repeats(GetTotalRepeats()); | |||
| project_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(project_op); | |||
| } | |||
| map_op->set_total_repeats(GetTotalRepeats()); | |||
| map_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(map_op); | |||
| return Status::OK(); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -79,6 +79,9 @@ class MapNode : public DatasetNode { | |||
| /// \brief setter to set all tensor operations | |||
| void setOperations(const std::vector<std::shared_ptr<TensorOperation>> &operations); | |||
| /// \brief indicate this Map will be cached | |||
| void Cached() { under_a_cache_ = true; } | |||
| /// \brief Getter functions | |||
| /// \brief Getter of tensor operations | |||
| /// \return Vector of operations the Map node will process | |||
| @@ -95,12 +98,11 @@ class MapNode : public DatasetNode { | |||
| private: | |||
| std::vector<std::shared_ptr<TensorOperation>> operations_; | |||
| private: | |||
| std::vector<std::string> input_columns_; | |||
| std::vector<std::string> output_columns_; | |||
| std::vector<std::string> project_columns_; | |||
| std::vector<std::shared_ptr<DSCallback>> callbacks_; | |||
| bool under_a_cache_; | |||
| }; | |||
| } // namespace dataset | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -53,7 +53,10 @@ Status ProjectNode::ValidateParams() { | |||
| } | |||
| Status ProjectNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| node_ops->push_back(std::make_shared<ProjectOp>(columns_)); | |||
| auto op = std::make_shared<ProjectOp>(columns_); | |||
| op->set_total_repeats(GetTotalRepeats()); | |||
| op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| return Status::OK(); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -58,7 +58,10 @@ Status RenameNode::ValidateParams() { | |||
| } | |||
| Status RenameNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| node_ops->push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_)); | |||
| auto op = std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_); | |||
| op->set_total_repeats(GetTotalRepeats()); | |||
| op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| return Status::OK(); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -40,6 +40,8 @@ void RepeatNode::Print(std::ostream &out) const { out << Name() + "(count:" + st | |||
| Status RepeatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| auto new_op = std::make_shared<RepeatOp>(repeat_count_); | |||
| new_op->set_total_repeats(GetTotalRepeats()); | |||
| new_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(new_op); | |||
| op_ = new_op; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -44,8 +44,11 @@ void ShuffleNode::Print(std::ostream &out) const { | |||
| // Function to build the ShuffleOp | |||
| Status ShuffleNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| node_ops->push_back(std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_, | |||
| rows_per_buffer_)); | |||
| auto op = std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_, | |||
| rows_per_buffer_); | |||
| op->set_total_repeats(GetTotalRepeats()); | |||
| op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| return Status::OK(); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -39,7 +39,10 @@ void SkipNode::Print(std::ostream &out) const { out << Name() + "(skip_count:" + | |||
| // Function to build the SkipOp | |||
| Status SkipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| node_ops->push_back(std::make_shared<SkipOp>(skip_count_, connector_que_size_)); | |||
| auto op = std::make_shared<SkipOp>(skip_count_, connector_que_size_); | |||
| op->set_total_repeats(GetTotalRepeats()); | |||
| op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| return Status::OK(); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -72,9 +72,11 @@ Status AlbumNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) | |||
| // Argument that is not exposed to user in the API. | |||
| std::set<std::string> extensions = {}; | |||
| node_ops->push_back(std::make_shared<AlbumOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, | |||
| decode_, extensions, std::move(schema), | |||
| std::move(sampler_->SamplerBuild()))); | |||
| auto album_op = std::make_shared<AlbumOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, decode_, | |||
| extensions, std::move(schema), std::move(sampler_->SamplerBuild())); | |||
| album_op->set_total_repeats(GetTotalRepeats()); | |||
| album_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(album_op); | |||
| return Status::OK(); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -67,9 +67,12 @@ Status CelebANode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops | |||
| // label is like this:0 1 0 0 1...... | |||
| RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | |||
| node_ops->push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, | |||
| decode_, usage_, extensions_, std::move(schema), | |||
| std::move(sampler_->SamplerBuild()))); | |||
| auto celeba_op = | |||
| std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, decode_, usage_, | |||
| extensions_, std::move(schema), std::move(sampler_->SamplerBuild())); | |||
| celeba_op->set_total_repeats(GetTotalRepeats()); | |||
| celeba_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(celeba_op); | |||
| return Status::OK(); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -64,9 +64,12 @@ Status Cifar100Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o | |||
| RETURN_IF_NOT_OK( | |||
| schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | |||
| node_ops->push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_, | |||
| dataset_dir_, connector_que_size_, std::move(schema), | |||
| std::move(sampler_->SamplerBuild()))); | |||
| auto cifar_op = | |||
| std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_, dataset_dir_, | |||
| connector_que_size_, std::move(schema), std::move(sampler_->SamplerBuild())); | |||
| cifar_op->set_total_repeats(GetTotalRepeats()); | |||
| cifar_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(cifar_op); | |||
| return Status::OK(); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -62,9 +62,12 @@ Status Cifar10Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_op | |||
| RETURN_IF_NOT_OK( | |||
| schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | |||
| node_ops->push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_, | |||
| dataset_dir_, connector_que_size_, std::move(schema), | |||
| std::move(sampler_->SamplerBuild()))); | |||
| auto cifar_op = | |||
| std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_, dataset_dir_, | |||
| connector_que_size_, std::move(schema), std::move(sampler_->SamplerBuild())); | |||
| cifar_op->set_total_repeats(GetTotalRepeats()); | |||
| cifar_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(cifar_op); | |||
| return Status::OK(); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -193,9 +193,12 @@ Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) | |||
| // Add the shuffle op after this op | |||
| RETURN_IF_NOT_OK(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, | |||
| rows_per_buffer_, &shuffle_op)); | |||
| shuffle_op->set_total_repeats(GetTotalRepeats()); | |||
| shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(shuffle_op); | |||
| } | |||
| clue_op->set_total_repeats(GetTotalRepeats()); | |||
| clue_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(clue_op); | |||
| return Status::OK(); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -122,7 +122,8 @@ Status CocoNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) | |||
| std::shared_ptr<CocoOp> op = | |||
| std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_, | |||
| connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild())); | |||
| op->set_total_repeats(GetTotalRepeats()); | |||
| op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| return Status::OK(); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -130,10 +130,12 @@ Status CSVNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| // Add the shuffle op after this op | |||
| RETURN_IF_NOT_OK(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, | |||
| rows_per_buffer_, &shuffle_op)); | |||
| shuffle_op->set_total_repeats(GetTotalRepeats()); | |||
| shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(shuffle_op); | |||
| } | |||
| csv_op->set_total_repeats(GetTotalRepeats()); | |||
| csv_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(csv_op); | |||
| return Status::OK(); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -91,7 +91,8 @@ Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ | |||
| if (reset_ancestor_ != nullptr) { | |||
| reset_ancestor_->op_->AddToEoeList(op); | |||
| } | |||
| op->set_total_repeats(GetTotalRepeats()); | |||
| op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| return Status::OK(); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -70,9 +70,12 @@ Status ImageFolderNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const nod | |||
| RETURN_IF_NOT_OK( | |||
| schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar))); | |||
| node_ops->push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, | |||
| recursive_, decode_, exts_, class_indexing_, std::move(schema), | |||
| std::move(sampler_->SamplerBuild()))); | |||
| auto op = std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, | |||
| recursive_, decode_, exts_, class_indexing_, std::move(schema), | |||
| std::move(sampler_->SamplerBuild())); | |||
| op->set_total_repeats(GetTotalRepeats()); | |||
| op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| return Status::OK(); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -94,7 +94,8 @@ Status ManifestNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o | |||
| manifest_op = | |||
| std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_, | |||
| class_index_, std::move(schema), std::move(sampler_->SamplerBuild()), usage_); | |||
| manifest_op->set_total_repeats(GetTotalRepeats()); | |||
| manifest_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(manifest_op); | |||
| return Status::OK(); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -169,6 +169,8 @@ Status MindDataNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o | |||
| } | |||
| RETURN_IF_NOT_OK(mindrecord_op->Init()); | |||
| mindrecord_op->set_total_repeats(GetTotalRepeats()); | |||
| mindrecord_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(mindrecord_op); | |||
| return Status::OK(); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -58,9 +58,11 @@ Status MnistNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) | |||
| RETURN_IF_NOT_OK( | |||
| schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | |||
| node_ops->push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_, | |||
| connector_que_size_, std::move(schema), | |||
| std::move(sampler_->SamplerBuild()))); | |||
| auto op = std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, | |||
| std::move(schema), std::move(sampler_->SamplerBuild())); | |||
| op->set_total_repeats(GetTotalRepeats()); | |||
| op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| return Status::OK(); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -109,7 +109,8 @@ Status RandomNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops | |||
| std::shared_ptr<RandomDataOp> op; | |||
| op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_, | |||
| std::move(data_schema_)); | |||
| op->set_total_repeats(GetTotalRepeats()); | |||
| op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| return Status::OK(); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -98,9 +98,12 @@ Status TextFileNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o | |||
| // Add the shuffle op after this op | |||
| RETURN_IF_NOT_OK(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, | |||
| rows_per_buffer_, &shuffle_op)); | |||
| shuffle_op->set_total_repeats(GetTotalRepeats()); | |||
| shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(shuffle_op); | |||
| } | |||
| text_file_op->set_total_repeats(GetTotalRepeats()); | |||
| text_file_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| // Add TextFileOp | |||
| node_ops->push_back(text_file_op); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -140,9 +140,12 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o | |||
| // Add the shuffle op after this op | |||
| RETURN_IF_NOT_OK(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_, | |||
| rows_per_buffer_, &shuffle_op)); | |||
| shuffle_op->set_total_repeats(GetTotalRepeats()); | |||
| shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(shuffle_op); | |||
| } | |||
| tf_reader_op->set_total_repeats(GetTotalRepeats()); | |||
| tf_reader_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| // Add TFReaderOp | |||
| node_ops->push_back(tf_reader_op); | |||
| return Status::OK(); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -113,7 +113,8 @@ Status VOCNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| voc_op = | |||
| std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_, | |||
| connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild())); | |||
| voc_op->set_total_repeats(GetTotalRepeats()); | |||
| voc_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(voc_op); | |||
| return Status::OK(); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -47,7 +47,10 @@ Status SyncWaitNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o | |||
| // The reason for this is because having it otherwise can lead to blocking issues | |||
| // See barrier_op.h for more details | |||
| int32_t rows_per_buffer = 1; | |||
| node_ops->push_back(std::make_shared<BarrierOp>(rows_per_buffer, connector_que_size_, condition_name_, callback_)); | |||
| auto op = std::make_shared<BarrierOp>(rows_per_buffer, connector_que_size_, condition_name_, callback_); | |||
| op->set_total_repeats(GetTotalRepeats()); | |||
| op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| return Status::OK(); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -40,7 +40,10 @@ void TakeNode::Print(std::ostream &out) const { out << Name() + "(num_rows:" + s | |||
| // Function to build the TakeOp | |||
| Status TakeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| node_ops->push_back(std::make_shared<TakeOp>(take_count_, connector_que_size_)); | |||
| auto op = std::make_shared<TakeOp>(take_count_, connector_que_size_); | |||
| op->set_total_repeats(GetTotalRepeats()); | |||
| op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| return Status::OK(); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -100,8 +100,11 @@ Status TransferNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o | |||
| device_id_ = 0; | |||
| RETURN_IF_NOT_OK(this->GetShardId(&device_id_)); | |||
| node_ops->push_back(std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_, | |||
| total_batch_, create_data_info_queue_)); | |||
| auto op = std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_, | |||
| total_batch_, create_data_info_queue_); | |||
| op->set_total_repeats(GetTotalRepeats()); | |||
| op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| return Status::OK(); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -58,7 +58,10 @@ Status ZipNode::ValidateParams() { | |||
| } | |||
| Status ZipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| node_ops->push_back(std::make_shared<ZipOp>(rows_per_buffer_, connector_que_size_)); | |||
| auto op = std::make_shared<ZipOp>(rows_per_buffer_, connector_que_size_); | |||
| op->set_total_repeats(GetTotalRepeats()); | |||
| op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| return Status::OK(); | |||
| } | |||
| @@ -2,29 +2,19 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc" | |||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | |||
| set(DATASET_ENGINE_OPT_SRC_FILES | |||
| optional/tensor_op_fusion_pass.cc | |||
| pass.cc | |||
| post/auto_worker_pass.cc | |||
| post/repeat_pass.cc | |||
| pre/cache_transform_pass.cc | |||
| pre/cache_validation_pass.cc | |||
| pre/deep_copy_pass.cc | |||
| pre/epoch_ctrl_pass.cc | |||
| pre/getter_pass.cc | |||
| pre/input_validation_pass.cc | |||
| pre/node_removal_pass.cc | |||
| ) | |||
| # This set of files is for ExecTree's optimizer. It is being migrated to IR's optimizer. | |||
| # When the migration is complete, we will remove these files. | |||
| set(DATASET_ENGINE_OPT_SRC_FILES | |||
| ${DATASET_ENGINE_OPT_SRC_FILES} | |||
| optional/tensor_op_fusion_pass.cc | |||
| pre/cache_error_pass.cc | |||
| post/repeat_pass.cc | |||
| pre/cache_transform_pass.cc | |||
| pre/epoch_injection_pass.cc | |||
| util/printer_pass.cc | |||
| pre/removal_pass.cc | |||
| ) | |||
| if(ENABLE_PYTHON) | |||
| set(DATASET_ENGINE_OPT_SRC_FILES | |||
| ${DATASET_ENGINE_OPT_SRC_FILES} | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -21,6 +21,11 @@ | |||
| #include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h" | |||
| #endif | |||
| #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/cache_node.h" | |||
| #endif | |||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/filter_node.h" | |||
| @@ -187,6 +192,26 @@ Status IRNodePass::Visit(std::shared_ptr<ConcatNode> node, bool *const modified) | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *const modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status IRNodePass::Visit(std::shared_ptr<CacheLookupNode> node, bool *const modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<CacheLookupNode> node, bool *const modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status IRNodePass::Visit(std::shared_ptr<CacheMergeNode> node, bool *const modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<CacheMergeNode> node, bool *const modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status IRNodePass::Visit(std::shared_ptr<CacheNode> node, bool *const modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<CacheNode> node, bool *const modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| #endif | |||
| Status IRNodePass::Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -30,6 +30,11 @@ namespace dataset { | |||
| class BatchNode; | |||
| class BucketBatchByLengthNode; | |||
| class BuildVocabNode; | |||
| #ifndef ENABLE_ANDROID | |||
| class CacheLookupNode; | |||
| class CacheMergeNode; | |||
| class CacheNode; | |||
| #endif | |||
| class ConcatNode; | |||
| class EpochCtrlNode; | |||
| class FilterNode; | |||
| @@ -199,6 +204,14 @@ class IRNodePass : public IRPass { | |||
| virtual Status VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *const modified); | |||
| virtual Status Visit(std::shared_ptr<ConcatNode> node, bool *const modified); | |||
| virtual Status VisitAfter(std::shared_ptr<ConcatNode> node, bool *const modified); | |||
| #ifndef ENABLE_ANDROID | |||
| virtual Status Visit(std::shared_ptr<CacheLookupNode> node, bool *const modified); | |||
| virtual Status VisitAfter(std::shared_ptr<CacheLookupNode> node, bool *const modified); | |||
| virtual Status Visit(std::shared_ptr<CacheMergeNode> node, bool *const modified); | |||
| virtual Status VisitAfter(std::shared_ptr<CacheMergeNode> node, bool *const modified); | |||
| virtual Status Visit(std::shared_ptr<CacheNode> node, bool *const modified); | |||
| virtual Status VisitAfter(std::shared_ptr<CacheNode> node, bool *const modified); | |||
| #endif | |||
| virtual Status Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified); | |||
| virtual Status VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified); | |||
| virtual Status Visit(std::shared_ptr<FilterNode> node, bool *const modified); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -14,15 +14,16 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include "minddata/dataset/engine/opt/post/repeat_pass.h" | |||
| #include "minddata/dataset/engine/datasetops/repeat_op.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_lookup_op.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_merge_op.h" | |||
| #include "minddata/dataset/engine/datasetops/device_queue_op.h" | |||
| #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/generator_op.h" | |||
| #include <memory> | |||
| #include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/cache_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/repeat_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/transfer_node.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -31,10 +32,10 @@ RepeatPass::RepeatPass() | |||
| : num_repeats_(1), num_epochs_(1), is_merge_(false), is_cached_(false), cache_lookup_(nullptr) {} | |||
| // Identifies the subtree below this node as being in a repeated path of the tree. | |||
| Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) { | |||
| Status RepeatPass::Visit(std::shared_ptr<RepeatNode> node, bool *const modified) { | |||
| // If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_. | |||
| // Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely. | |||
| if (node->num_repeats() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) { | |||
| if (node->Count() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) { | |||
| num_repeats_ = -num_repeats_; | |||
| } | |||
| // This RepeatOp and its descendent nodes should be repeated for another num_repeats() times. | |||
| @@ -49,14 +50,14 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *const modi | |||
| // num_repeats_ is originally 4, after repeat2(2), num_repeats_ becomes 8 (2*4), | |||
| // meaning repeat2 and map op should be set to read 8 times (2*4). | |||
| // Then, after repeat1(3), num_repeats_ becomes 24 (3*2*4), meaning repeat1 and tfreader op should repeat 24 times. | |||
| num_repeats_ *= node->num_repeats(); | |||
| num_repeats_ *= node->Count(); | |||
| return Status::OK(); | |||
| } | |||
| // Identifies the subtree below this node as being in a repeated path of the tree. | |||
| Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const modified) { | |||
| Status RepeatPass::Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) { | |||
| // Get the total number of epochs from the EpochCtrlOp parameter | |||
| num_epochs_ = node->num_repeats(); | |||
| num_epochs_ = node->Count(); | |||
| // Every node below this EpochCtrlOp should be repeated for num_epochs_ times. | |||
| // For example: tfreader --> epoch ctrl(3) | |||
| // num_repeats_ is originally 1 (default initialization), after this epoch ctrl(3), num_repeats_ becomes 3 (1*3), | |||
| @@ -65,115 +66,108 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const m | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| // Identifies the subtree below this node as being in a cache merge path | |||
| Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *const modified) { | |||
| Status RepeatPass::Visit(std::shared_ptr<CacheMergeNode> node, bool *const modified) { | |||
| // Turn on the flag that we're under a merge op | |||
| is_merge_ = true; | |||
| return Status::OK(); | |||
| } | |||
| // Identifies the subtree below this node as being cached | |||
| Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) { | |||
| Status RepeatPass::Visit(std::shared_ptr<CacheNode> node, bool *const modified) { | |||
| // Turn on the flag that we're under a merge op | |||
| is_cached_ = true; | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| // Hooks up any identified eoe nodes under this repeat. | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) { | |||
| Status RepeatPass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) { | |||
| // We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up | |||
| // and set its total repeats. It is important that the op is removed from the save area, | |||
| // because the merge op above us may also take action on it later for a different case when | |||
| // there is no repeat in the merge leg. | |||
| if (is_merge_ && cache_lookup_) { | |||
| cache_lookup_->set_total_repeats(num_repeats_); | |||
| cache_lookup_->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||
| cache_lookup_->SetTotalRepeats(num_repeats_); | |||
| cache_lookup_->SetNumEpochs(num_epochs_); | |||
| cache_lookup_.reset(); | |||
| } | |||
| if (is_cached_) { | |||
| AddToCachedOpStack(node); | |||
| AddToCachedNodeStack(node); | |||
| } | |||
| node->set_total_repeats(num_repeats_); | |||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||
| node->SetTotalRepeats(num_repeats_); | |||
| node->SetNumEpochs(num_epochs_); | |||
| // We finish the walk of this RepeatOp's descendent nodes. | |||
| // The total repeats of nodes above this Repeat(n) have nothing to do with this RepeatOp's parameter n. | |||
| // But num_repeats_ has been multiplied by n during this Repeat(n)'s PreRunOnNode, | |||
| // so we divide num_repeats_ by n to be able to correctly set total repeats for nodes above this RepeatOp. | |||
| num_repeats_ /= node->num_repeats(); | |||
| num_repeats_ /= node->Count(); | |||
| return Status::OK(); | |||
| } | |||
| // Hooks up any identified eoe nodes under this repeat. | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const modified) { | |||
| node->set_total_repeats(num_repeats_); | |||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||
| Status RepeatPass::VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) { | |||
| node->SetTotalRepeats(num_repeats_); | |||
| node->SetNumEpochs(num_epochs_); | |||
| // We finish the walk of this EpochCtrl's descendent nodes. | |||
| num_repeats_ /= node->num_repeats(); | |||
| num_repeats_ /= node->Count(); | |||
| return Status::OK(); | |||
| } | |||
| // All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up | |||
| // for use with a controlling repeat above it. | |||
| Status RepeatPass::VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) { | |||
| // If we are under a cache op, then save ourselves to the cached op stack. | |||
| if (is_cached_) { | |||
| AddToCachedNodeStack(node); | |||
| } | |||
| // Set total repeats and total epochs for the node | |||
| node->SetTotalRepeats(num_repeats_); | |||
| node->SetNumEpochs(num_epochs_); | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| // CacheOp removes previous leaf ops and replaces them with itself | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) { | |||
| Status RepeatPass::VisitAfter(std::shared_ptr<CacheNode> node, bool *const modified) { | |||
| is_cached_ = false; | |||
| // if we are a cache within a repeat path of the tree, then adjust the total repeats and total epochs for cached ops. | |||
| // So that those cached nodes become 1-time use (up to eoe), never repeated. Instead | |||
| // the repeating behaviours shall be invoked against the cache op. | |||
| std::shared_ptr<DatasetOp> cached_op = PopFromCachedOpStack(); | |||
| while (cached_op != nullptr) { | |||
| int32_t cached_op_total_repeats = cached_op->op_total_repeats() / num_repeats_; | |||
| cached_op->set_total_repeats(cached_op_total_repeats); | |||
| std::shared_ptr<DatasetNode> cached_node = PopFromCachedNodeStack(); | |||
| while (cached_node != nullptr) { | |||
| int32_t cached_op_total_repeats = cached_node->GetTotalRepeats() / num_repeats_; | |||
| cached_node->SetTotalRepeats(cached_op_total_repeats); | |||
| // Cached ops will only be executed on the first epoch, therefore, num_epochs_ = 1 | |||
| cached_op->set_num_repeats_per_epoch(cached_op_total_repeats); | |||
| cached_op = PopFromCachedOpStack(); | |||
| cached_node->SetNumEpochs(1); | |||
| cached_node = PopFromCachedNodeStack(); | |||
| } | |||
| node->set_total_repeats(num_repeats_); | |||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||
| return Status::OK(); | |||
| } | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) { | |||
| // If we are under a cache op, then save ourselves to the cached op stack. | |||
| if (is_cached_) { | |||
| AddToCachedOpStack(node); | |||
| } | |||
| // Set total repeats and total epochs for the node | |||
| node->set_total_repeats(num_repeats_); | |||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||
| return Status::OK(); | |||
| } | |||
| // All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up | |||
| // for use with a controlling repeat above it. | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *const modified) { | |||
| // If we are under a cache op, then save ourselves to the cached op stack. | |||
| if (is_cached_) { | |||
| AddToCachedOpStack(node); | |||
| } | |||
| // Set total repeats and total epochs for the node | |||
| node->set_total_repeats(num_repeats_); | |||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||
| node->SetTotalRepeats(num_repeats_); | |||
| node->SetNumEpochs(num_epochs_); | |||
| return Status::OK(); | |||
| } | |||
| // Turns off the tracking for operations under merge op | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *const modified) { | |||
| Status RepeatPass::VisitAfter(std::shared_ptr<CacheMergeNode> node, bool *const modified) { | |||
| // If there was not any repeat in the merge cache miss leg, then the cache_lookup | |||
| // would not have been consumed yet. In that case, we need to set its total repeats for it. | |||
| if (cache_lookup_) { | |||
| cache_lookup_->set_total_repeats(num_repeats_); | |||
| cache_lookup_->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||
| cache_lookup_->SetTotalRepeats(num_repeats_); | |||
| cache_lookup_->SetNumEpochs(num_epochs_); | |||
| } | |||
| node->set_total_repeats(num_repeats_); | |||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||
| node->SetTotalRepeats(num_repeats_); | |||
| node->SetNumEpochs(num_epochs_); | |||
| cache_lookup_.reset(); // If we are not repeated then the saved lookup is no longer needed or used | |||
| is_merge_ = false; | |||
| return Status::OK(); | |||
| } | |||
| // Saves the lookup up in case it needs to be referenced by a repeat | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *const modified) { | |||
| Status RepeatPass::VisitAfter(std::shared_ptr<CacheLookupNode> node, bool *const modified) { | |||
| if (!node->IsLeaf()) { | |||
| // By definition, the CacheLookup must be a leaf op. Make that clear here. | |||
| RETURN_STATUS_UNEXPECTED("CacheLookupOp must be a leaf node!"); | |||
| @@ -184,29 +178,30 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *const mo | |||
| // into the pass so that the decision can be made during the processing of the cache miss leg of the merge. | |||
| // Further, if there's a repeat above the merge but no repeat in the cache miss leg, then the merge op will | |||
| // add the lookup to the eoe stack | |||
| cache_lookup_ = std::static_pointer_cast<DatasetOp>(node); | |||
| cache_lookup_ = std::static_pointer_cast<DatasetNode>(node); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *const modified) { | |||
| // Set total repeats and total epochs for the DeviceQueueOp | |||
| node->set_total_repeats(num_epochs_); | |||
| node->set_num_repeats_per_epoch(1); | |||
| Status RepeatPass::VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified) { | |||
| // Set total repeats and total epochs for the TransferNode | |||
| node->SetTotalRepeats(num_epochs_); | |||
| node->SetNumEpochs(num_epochs_); | |||
| return Status::OK(); | |||
| } | |||
| // Adds an operator to the cached operator stack save area | |||
| void RepeatPass::AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op) { cached_op_stacks_.push(dataset_op); } | |||
| void RepeatPass::AddToCachedNodeStack(std::shared_ptr<DatasetNode> node) { cached_node_stacks_.push(node); } | |||
| // Pops an operator from the cached operator stack save area | |||
| std::shared_ptr<DatasetOp> RepeatPass::PopFromCachedOpStack() { | |||
| std::shared_ptr<DatasetOp> top_op = nullptr; | |||
| if (!cached_op_stacks_.empty()) { | |||
| top_op = cached_op_stacks_.top(); | |||
| cached_op_stacks_.pop(); | |||
| std::shared_ptr<DatasetNode> RepeatPass::PopFromCachedNodeStack() { | |||
| std::shared_ptr<DatasetNode> top_node = nullptr; | |||
| if (!cached_node_stacks_.empty()) { | |||
| top_node = cached_node_stacks_.top(); | |||
| cached_node_stacks_.pop(); | |||
| } | |||
| return top_op; | |||
| return top_node; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -25,12 +25,11 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \class RepeatPass repeat_pass.h | |||
| /// \brief This is a NodePass who's job is to perform setup actions for RepeatOps. A RepeatOp needs to have references | |||
| /// to the eoe-producing (typically leaf) nodes underneath it. | |||
| class RepeatPass : public NodePass { | |||
| /// \class RepeatPass | |||
| /// \brief This is a post pass that calculate the number of repeats the pipeline needs to fetch the data. | |||
| class RepeatPass : public IRNodePass { | |||
| public: | |||
| using op_stack = std::stack<std::shared_ptr<DatasetOp>>; | |||
| using op_stack = std::stack<std::shared_ptr<DatasetNode>>; | |||
| /// \brief Constructor | |||
| RepeatPass(); | |||
| @@ -40,93 +39,91 @@ class RepeatPass : public NodePass { | |||
| /// \brief Identifies the subtree below this node as being in a repeated path of the tree. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \param[in,out] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) override; | |||
| Status Visit(std::shared_ptr<RepeatNode> node, bool *const modified) override; | |||
| /// \brief Identifies the subtree below this node as being in a repeated path of the tree. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \param[in,out] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const modified) override; | |||
| Status Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) override; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Identifies the subtree below this node as being in a cache merge path | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \param[in,out] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *const modified) override; | |||
| Status Visit(std::shared_ptr<CacheMergeNode> node, bool *const modified) override; | |||
| /// \brief Identifies the subtree below this node as being cached | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \param[in,out] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override; | |||
| Status Visit(std::shared_ptr<CacheNode> node, bool *const modified) override; | |||
| #endif | |||
| /// \brief Hooks up any identified eoe nodes under this repeat. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \param[in,out] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) override; | |||
| Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) override; | |||
| /// \brief Hooks up any identified eoe nodes under this repeat. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \param[in,out] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const modified) override; | |||
| Status VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) override; | |||
| /// \brief CacheOp removes previous leaf ops and replaces them with itself | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief CacheNode removes previous leaf ops and replaces them with itself | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \param[in,out] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override; | |||
| Status VisitAfter(std::shared_ptr<CacheNode> node, bool *const modified) override; | |||
| /// \brief Turns of the tracking for operations under merge op | |||
| /// \brief Turns off the tracking for operations under merge op | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \param[in,out] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *const modified) override; | |||
| Status VisitAfter(std::shared_ptr<CacheMergeNode> node, bool *const modified) override; | |||
| /// \brief Saves the lookup up in case it needs to be referenced by a repeat | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \param[in,out] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *const modified) override; | |||
| Status VisitAfter(std::shared_ptr<CacheLookupNode> node, bool *const modified) override; | |||
| #endif | |||
| /// \brief Set the epoch count for DeviceQueue | |||
| /// \brief Sets the epoch count for TransferNode | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \param[in,out] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *const modified) override; | |||
| /// \brief Special case for GeneratorOp | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) override; | |||
| Status VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified) override; | |||
| /// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up | |||
| /// for use with a controlling repeat above it. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \param[in,out] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *const modified) override; | |||
| Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) override; | |||
| private: | |||
| /// \brief Adds an operator to the cached operator stack save area | |||
| /// \param op - The dataset op to work add to cached stack | |||
| /// \brief Adds an operator to the cached stack save area | |||
| /// \param node - The dataset node to add to cached stack | |||
| /// \return Status The status code returned | |||
| void AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op); | |||
| /// \brief Pops an operator from the cached operator stack save area | |||
| /// \return shared_ptr to the popped operator | |||
| std::shared_ptr<DatasetOp> PopFromCachedOpStack(); | |||
| bool is_merge_; // T/F if we are processing under a cache merge op | |||
| bool is_cached_; // T/F is we are processing under a cache op | |||
| int32_t num_repeats_; // A multiplier to the total number of repeats | |||
| int32_t num_epochs_; // To save the total number of epochs | |||
| op_stack cached_op_stacks_; // A save area for ops under a cache op | |||
| std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op | |||
| void AddToCachedNodeStack(std::shared_ptr<DatasetNode> node); | |||
| /// \brief Pops an operator from the cached stack save area | |||
| /// \return shared_ptr to the popped dataset node | |||
| std::shared_ptr<DatasetNode> PopFromCachedNodeStack(); | |||
| bool is_merge_; // T/F if we are processing under a cache merge node | |||
| bool is_cached_; // T/F is we are processing under a cache node | |||
| int32_t num_repeats_; // A multiplier to the total number of repeats | |||
| int32_t num_epochs_; // To save the total number of epochs | |||
| op_stack cached_node_stacks_; // A save area for operators under a cache node | |||
| std::shared_ptr<DatasetNode> cache_lookup_; // A save area for a cache lookup node | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -1,189 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||
| #include "minddata/dataset/engine/datasetops/zip_op.h" | |||
| #include "minddata/dataset/engine/datasetops/map_op/map_op.h" | |||
| #include "minddata/dataset/engine/opt/pre/cache_error_pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor | |||
| CacheErrorPass::CacheErrorPass() : is_cached_(false), is_mappable_(false) {} | |||
| // Identifies the subtree below this node as being cached | |||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) { | |||
| // Turn on the flag that we're under a merge op | |||
| is_cached_ = true; | |||
| return Status::OK(); | |||
| } | |||
| // Returns an error if ZipOp exists under a cache | |||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<ZipOp> node, bool *const modified) { | |||
| if (is_cached_) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||
| "ZipOp is currently not supported as a descendant operator under a cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Returns an error if MapOp with non-deterministic TensorOps exists under a cache | |||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<MapOp> node, bool *const modified) { | |||
| if (is_cached_) { | |||
| auto tfuncs = node->TFuncs(); | |||
| for (size_t i = 0; i < tfuncs.size(); i++) { | |||
| if (!tfuncs[i]->Deterministic()) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||
| "MapOp with non-deterministic TensorOps is currently not supported as a descendant of cache."); | |||
| } | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Returns an error if ConcatOp exists under a cache | |||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *const modified) { | |||
| if (is_cached_) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||
| "ConcatOp is currently not supported as a descendant operator under a cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Returns an error if TakeOp exists under a cache | |||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<TakeOp> node, bool *const modified) { | |||
| if (is_cached_) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||
| "TakeOp/SplitOp is currently not supported as a descendant operator under a cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Returns an error if SkipOp exists under a cache | |||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<SkipOp> node, bool *const modified) { | |||
| if (is_cached_) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||
| "SkipOp is currently not supported as a descendant operator under a cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Returns an error if SkipOp exists under a cache | |||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<BatchOp> node, bool *const modified) { | |||
| if (is_cached_) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||
| "BatchOp is currently not supported as a descendant operator under a cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| #ifdef ENABLE_PYTHON | |||
| // Returns an error if FilterOp exists under a cache | |||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<FilterOp> node, bool *const modified) { | |||
| if (is_cached_) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||
| "FilterOp is currently not supported as a descendant operator under a cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *const modified) { | |||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||
| is_mappable_ = true; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *const modified) { | |||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||
| is_mappable_ = true; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<MnistOp> node, bool *const modified) { | |||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||
| is_mappable_ = true; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<CifarOp> node, bool *const modified) { | |||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||
| is_mappable_ = true; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<CocoOp> node, bool *const modified) { | |||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||
| is_mappable_ = true; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *const modified) { | |||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||
| is_mappable_ = true; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *const modified) { | |||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||
| is_mappable_ = true; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<VOCOp> node, bool *const modified) { | |||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||
| is_mappable_ = true; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *const modified) { | |||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||
| is_mappable_ = true; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) { | |||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||
| is_mappable_ = true; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) { | |||
| // Turn off the flag that we're under a merge op | |||
| is_cached_ = false; | |||
| return Status::OK(); | |||
| } | |||
| // Currently, returns an error if RepeatOp exists under a cache | |||
| // Because there is no operator in the cache hit stream to consume eoes, caching above repeat causes problem. | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) { | |||
| if (is_cached_ && is_mappable_) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||
| "Repeat is not supported as a descendant operator under a mappable cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -1,167 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_ERROR_PASS_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_ERROR_PASS_ | |||
| #include <memory> | |||
| #include <stack> | |||
| #include <utility> | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \class CacheErrorPass cache_error_pass.h | |||
| /// \brief This is a NodePass who's job is to catch invalid tree configurations related to cache and generate failures. | |||
| class CacheErrorPass : public NodePass { | |||
| public: | |||
| /// \brief Constructor | |||
| CacheErrorPass(); | |||
| /// \brief Destructor | |||
| ~CacheErrorPass() = default; | |||
| /// \brief Identifies the subtree below this node as being cached | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override; | |||
| /// \brief Returns an error if ZipOp exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *const modified) override; | |||
| /// \brief Returns an error if MapOp with non-deterministic TensorOps exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *const modified) override; | |||
| /// \brief Returns an error if ConcatOp exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *const modified) override; | |||
| /// \brief Returns an error if TakeOp exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<TakeOp> node, bool *const modified) override; | |||
| /// \brief Returns an error if SkipOp exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<SkipOp> node, bool *const modified) override; | |||
| /// \brief Returns an error if SkipOp exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<BatchOp> node, bool *const modified) override; | |||
| #ifdef ENABLE_PYTHON | |||
| /// \brief Returns an error if FilterOp exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *const modified) override; | |||
| #endif | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *const modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *const modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<MnistOp> node, bool *const modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CifarOp> node, bool *const modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CocoOp> node, bool *const modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *const modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *const modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<VOCOp> node, bool *const modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *const modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) override; | |||
| /// \brief Identifies the subtree above this node as not being cached | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override; | |||
| /// \brief Identifies and block repeat under cache scenarios | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) override; | |||
| private: | |||
| bool is_cached_; | |||
| bool is_mappable_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_POST_CACHE_ERROR_PASS_ | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -25,7 +25,6 @@ | |||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/take_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||
| #include "minddata/dataset/include/transforms.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -114,11 +113,18 @@ Status CacheValidationPass::Visit(std::shared_ptr<MapNode> node, bool *const mod | |||
| } | |||
| // If Map is created to be cached, set the flag indicating we found an operation with a cache. | |||
| is_cached_ = true; | |||
| // This is temporary code. | |||
| // Because the randomness of its tensor operations is not known in TensorOperation form until we convert them | |||
| // to TensorOp, we need to check the randomness in MapNode::Build(). | |||
| // By setting this MapNode is under a cache, we will check the randomness of its tensor operations without the need | |||
| // to walk the IR tree again. | |||
| node->Cached(); | |||
| auto tfuncs = node->TensorOperations(); | |||
| for (size_t i = 0; i < tfuncs.size(); i++) { | |||
| if (tfuncs[i]->IsRandomOp()) { | |||
| RETURN_STATUS_UNEXPECTED( | |||
| "MapNode with non-deterministic operations is not supported as a descendant of cache."); | |||
| RETURN_STATUS_UNEXPECTED("MapNode containing random operation is not supported as a descendant of cache."); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,78 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h" | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | |||
| #include "minddata/dataset/engine/datasetops/device_queue_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // constructor | |||
| EpochInjectionPass::InjectionFinder::InjectionFinder(std::shared_ptr<DatasetOp> node) : injection_point_(node) {} | |||
| #ifndef ENABLE_ANDROID | |||
| // Performs finder work for BuildVocabOp that has special rules about epoch control injection | |||
| Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *const modified) { | |||
| injection_point_ = nullptr; | |||
| return Status::OK(); | |||
| } | |||
| // Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection | |||
| Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, | |||
| bool *const modified) { | |||
| injection_point_ = nullptr; | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| Status EpochInjectionPass::InjectionFinder::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *const modified) { | |||
| // Assumption: There is only one DeviceQueueOp in a pipeline. This assumption is not validated here. | |||
| injection_point_ = node->child(0); | |||
| return Status::OK(); | |||
| } | |||
| // constructor | |||
| EpochInjectionPass::EpochInjectionPass() {} | |||
| // Runs an injection pass to inject in operators needed at the pre pass stage | |||
| Status EpochInjectionPass::RunOnTree(ExecutionTree *tree, bool *const modified) { | |||
| MS_LOG(INFO) << "Pre pass: Injection pass started."; | |||
| // First, run the finder to perform any injection info before we can go ahead to drive the op injection work. | |||
| // The finder can make updates to the EpochInjectionPass object. | |||
| EpochInjectionPass::InjectionFinder finder(tree->root()); | |||
| RETURN_IF_NOT_OK(finder.Run(tree, modified)); | |||
| // The first injection logic is to check if we should inject the epoch control op as the root node. | |||
| // Do not inject the op if the number of epochs is 1. | |||
| int32_t num_epochs = tree->num_epochs(); | |||
| std::shared_ptr<DatasetOp> epoch_inject_node = finder.injection_point(); | |||
| if (num_epochs != 1 && epoch_inject_node != nullptr) { | |||
| std::shared_ptr<EpochCtrlOp> epoch_ctrl_op; | |||
| RETURN_IF_NOT_OK(EpochCtrlOp::Builder(num_epochs).Build(&epoch_ctrl_op)); | |||
| RETURN_IF_NOT_OK(tree->AssociateNode(epoch_ctrl_op)); | |||
| RETURN_IF_NOT_OK(epoch_inject_node->InsertAsParent(epoch_ctrl_op)); | |||
| } | |||
| MS_LOG(INFO) << "Pre pass: Injection pass complete."; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -1,88 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ | |||
| #define DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class DatasetOp; | |||
| /// \class EpochInjectionPass epoch_injection_pass.h | |||
| /// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api | |||
| /// parsing. | |||
| class EpochInjectionPass : public TreePass { | |||
| /// \class InjectionFinder | |||
| /// \brief This is a nested node pass class who's job is to parse the tree and perform any identification logic for | |||
| /// operators that need to be injected. It is run first by the main injection pass to find out what operators | |||
| /// it may need to inject. | |||
| class InjectionFinder : public NodePass { | |||
| public: | |||
| /// \brief Constructor | |||
| explicit InjectionFinder(std::shared_ptr<DatasetOp> node); | |||
| /// \brief Destructor | |||
| ~InjectionFinder() = default; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *const modified) override; | |||
| /// \brief Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *const modified) override; | |||
| #endif | |||
| /// \brief Register the DeviceQueueOp for further action. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *const modified) override; | |||
| /// \brief Getter | |||
| std::shared_ptr<DatasetOp> injection_point() { return injection_point_; } | |||
| private: | |||
| std::shared_ptr<DatasetOp> injection_point_; | |||
| }; | |||
| public: | |||
| /// \brief Constructor | |||
| EpochInjectionPass(); | |||
| /// \brief Destructor | |||
| ~EpochInjectionPass() = default; | |||
| /// \brief Runs an injection pass to inject in operators needed at the pre pass stage | |||
| /// \param[inout] tree The tree to operate on. | |||
| /// \param[inout] Indicate of the tree was modified. | |||
| /// \return Status The status code returned | |||
| Status RunOnTree(ExecutionTree *tree, bool *const modified) override; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -15,7 +15,6 @@ | |||
| */ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/include/datasets.h" | |||
| #include "minddata/dataset/engine/opt/pre/input_validation_pass.h" | |||
| @@ -1,75 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "minddata/dataset/engine/opt/pre/removal_pass.h" | |||
| #include "minddata/dataset/engine/datasetops/shuffle_op.h" | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| RemovalPass::RemovalNodes::RemovalNodes() : is_caching_(false) {} | |||
| #ifndef ENABLE_ANDROID | |||
| // Identifies the subtree below this node as a cached descendant tree. | |||
| Status RemovalPass::RemovalNodes::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) { | |||
| *modified = false; | |||
| MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree."; | |||
| is_caching_ = true; | |||
| return Status::OK(); | |||
| } | |||
| // Resets the tracking of the cache within the tree | |||
| Status RemovalPass::RemovalNodes::RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) { | |||
| *modified = false; | |||
| MS_LOG(INFO) << "Removal pass: cache descendant tree complete."; | |||
| is_caching_ = false; | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| // Perform ShuffleOp removal check. | |||
| Status RemovalPass::RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *const modified) { | |||
| *modified = false; | |||
| // If we are in a cache descendant tree, then this shuffle op needs to be removed | |||
| if (is_caching_) { | |||
| MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)"; | |||
| nodes_to_remove_.push_back(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // constructor | |||
| RemovalPass::RemovalPass() {} | |||
| // Walk the tree to collect the nodes to remove, then removes them. | |||
| Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *const modified) { | |||
| MS_LOG(INFO) << "Pre pass: removal pass started."; | |||
| // Create the removal node pass which can identify which nodes need to be removed. | |||
| std::unique_ptr<RemovalPass::RemovalNodes> removal_nodes = std::make_unique<RemovalPass::RemovalNodes>(); | |||
| RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified)); | |||
| // Then, execute the removal of any nodes that were set up for removal | |||
| for (auto node : removal_nodes->nodes_to_remove()) { | |||
| RETURN_IF_NOT_OK(node->Remove()); | |||
| } | |||
| MS_LOG(INFO) << "Pre pass: removal pass complete."; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -1,90 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class DatasetOp; | |||
| /// \class RemovalPass removal_pass.h | |||
| /// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which | |||
| /// nodes should be removed, and then removes them. | |||
| class RemovalPass : public TreePass { | |||
| /// \class RemovalNodes | |||
| /// \brief This is a NodePass who's job is to identify which nodes should be removed. | |||
| /// It works in conjunction with the removal_pass. | |||
| class RemovalNodes : public NodePass { | |||
| public: | |||
| /// \brief Constructor | |||
| /// \param[in] removal_pass Raw pointer back to controlling tree pass | |||
| RemovalNodes(); | |||
| /// \brief Destructor | |||
| ~RemovalNodes() = default; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Identifies the subtree below this node as a cached descendant tree. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override; | |||
| /// \brief Resets the tracking of the cache within the tree | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override; | |||
| #endif | |||
| /// \brief Perform ShuffleOp removal check | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *const modified) override; | |||
| /// \brief Getter | |||
| /// \return All the nodes to be removed | |||
| std::vector<std::shared_ptr<DatasetOp>> nodes_to_remove() { return nodes_to_remove_; } | |||
| private: | |||
| bool is_caching_; | |||
| std::vector<std::shared_ptr<DatasetOp>> nodes_to_remove_; | |||
| }; | |||
| public: | |||
| /// \brief Constructor | |||
| RemovalPass(); | |||
| /// \brief Destructor | |||
| ~RemovalPass() = default; | |||
| /// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them. | |||
| /// \param[inout] tree The tree to operate on. | |||
| /// \param[inout] Indicate of the tree was modified. | |||
| /// \return Status The status code returned | |||
| Status RunOnTree(ExecutionTree *tree, bool *const modified) override; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ | |||
| @@ -1,121 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include "minddata/dataset/engine/opt/util/printer_pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *const modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting DatasetOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<BatchOp> node, bool *const modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting BatchOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<MapOp> node, bool *const modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting MapOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<ProjectOp> node, bool *const modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting ProjectOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<RenameOp> node, bool *const modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting RenameOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<SkipOp> node, bool *const modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting SkipOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *const modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting ShuffleOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *const modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting MindRecordOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *const modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting TFReaderOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| #ifdef ENABLE_PYTHON | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<FilterOp> node, bool *const modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting FilterOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting GeneratorOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<TakeOp> node, bool *const modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting TakeOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<ZipOp> node, bool *const modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting ZipOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *const modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting DeviceQueueOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *const modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting ImageFolderOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| Status PrinterPass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *const modified) { | |||
| *modified = false; | |||
| std::cout << "Visiting ImageFolderOp" << '\n'; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -1,68 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H | |||
| #include <memory> | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
/// \brief A debug NodePass that prints the concrete type of every node it visits to stdout.
///     Each visitor leaves the tree untouched and sets *modified to false.
class PrinterPass : public NodePass {
 public:
  // One RunOnNode overload per operator type; each prints the operator's name.
  Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *const modified) override;
  Status RunOnNode(std::shared_ptr<BatchOp> node, bool *const modified) override;
  Status RunOnNode(std::shared_ptr<MapOp> node, bool *const modified) override;
  Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *const modified) override;
  Status RunOnNode(std::shared_ptr<RenameOp> node, bool *const modified) override;
  Status RunOnNode(std::shared_ptr<SkipOp> node, bool *const modified) override;
  Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *const modified) override;
#ifndef ENABLE_ANDROID
  // These operator types are only compiled for non-Android builds.
  Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *const modified) override;
  Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *const modified) override;
#endif
#ifdef ENABLE_PYTHON
  // These operator types are only available when Python support is enabled.
  Status RunOnNode(std::shared_ptr<FilterOp> node, bool *const modified) override;
  Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) override;
#endif
  Status RunOnNode(std::shared_ptr<TakeOp> node, bool *const modified) override;
  Status RunOnNode(std::shared_ptr<ZipOp> node, bool *const modified) override;
  Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *const modified) override;
  Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *const modified) override;
  Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *const modified) override;
};
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -21,6 +21,7 @@ | |||
| #include "minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/engine/opt/post/auto_worker_pass.h" | |||
| #include "minddata/dataset/engine/opt/post/repeat_pass.h" | |||
| #ifdef ENABLE_PYTHON | |||
| #include "minddata/dataset/engine/opt/post/generator_node_pass.h" | |||
| #endif | |||
| @@ -94,6 +95,7 @@ Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) { | |||
| #ifdef ENABLE_PYTHON | |||
| actions.emplace_back(std::make_unique<GeneratorNodePass>()); | |||
| #endif | |||
| actions.emplace_back(std::make_unique<RepeatPass>()); | |||
| // We will gradually move RepeatPass from ExecutionTree::PrepareTreePostAction to here. | |||
| @@ -133,7 +135,7 @@ Status TreeAdapter::BuildExecutionTreeRecur(std::shared_ptr<DatasetNode> ir, std | |||
| return Status::OK(); | |||
| } | |||
| Status TreeAdapter::Build(std::shared_ptr<DatasetNode> root_ir, int32_t num_epochs) { | |||
| Status TreeAdapter::Build(std::shared_ptr<DatasetNode> root_ir) { | |||
| // This will evolve in the long run | |||
| tree_ = std::make_unique<ExecutionTree>(); | |||
| // disable profiling if this is only a getter pass | |||
| @@ -146,7 +148,7 @@ Status TreeAdapter::Build(std::shared_ptr<DatasetNode> root_ir, int32_t num_epoc | |||
| // Note: We will gradually move the pre pass, optimizer pass, and post pass | |||
| // on ExecutionTree to perform on IR tree. | |||
| // Prepare the tree | |||
| RETURN_IF_NOT_OK(tree_->Prepare(num_epochs, true)); | |||
| RETURN_IF_NOT_OK(tree_->Prepare()); | |||
| // After the tree is prepared, the col_name_id_map can safely be obtained | |||
| column_name_map_ = tree_->root()->column_name_id_map(); | |||
| @@ -192,7 +194,7 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_e | |||
| // Remember the root node | |||
| root_ir_ = root_ir; | |||
| RETURN_IF_NOT_OK(Build(root_ir_, num_epochs)); | |||
| RETURN_IF_NOT_OK(Build(root_ir_)); | |||
| tree_state_ = kCompileStateReady; | |||
| return Status::OK(); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -83,7 +83,7 @@ class TreeAdapter { | |||
| Status PostPass(std::shared_ptr<DatasetNode> ir); | |||
| // Build an Execution tree | |||
| Status Build(std::shared_ptr<DatasetNode> root_ir, int32_t num_epochs); | |||
| Status Build(std::shared_ptr<DatasetNode> root_ir); | |||
| // This RECURSIVE function walks the (optimized) IR tree in DFS to build its corresponding Execution tree. | |||
| Status BuildExecutionTreeRecur(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -13,22 +13,12 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <fstream> | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "common/common.h" | |||
| #include "minddata/dataset/core/client.h" | |||
| #include "minddata/dataset/core/global_context.h" | |||
| #include "minddata/dataset/engine/datasetops/source/album_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" | |||
| #include "minddata/dataset/util/path.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "gtest/gtest.h" | |||
| #include "utils/log_adapter.h" | |||
| @@ -89,7 +79,11 @@ TEST_F(MindDataTestAlbum, TestSequentialAlbumWithSchema) { | |||
| std::string folder_path = datasets_root_path_ + "/testAlbum/images"; | |||
| std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json"; | |||
| std::vector<std::string> column_names = {"image", "label", "id"}; | |||
| auto tree = Build({AlbumSchema(16, 2, 32, folder_path, schema_file, column_names, false), Repeat(2)}); | |||
| auto op1 = AlbumSchema(16, 2, 32, folder_path, schema_file, column_names, false); | |||
| auto op2 = Repeat(2); | |||
| op1->set_total_repeats(2); | |||
| op1->set_num_repeats_per_epoch(2); | |||
| auto tree = Build({op1, op2}); | |||
| ASSERT_OK(tree->Prepare()); | |||
| ASSERT_OK(tree->Launch()); | |||
| DatasetIterator di(tree); | |||
| @@ -111,7 +105,11 @@ TEST_F(MindDataTestAlbum, TestSequentialAlbumWithSchema) { | |||
| TEST_F(MindDataTestAlbum, TestSequentialAlbumWithSchemaNoOrder) { | |||
| std::string folder_path = datasets_root_path_ + "/testAlbum/images"; | |||
| std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json"; | |||
| auto tree = Build({AlbumSchema(16, 2, 32, folder_path, schema_file), Repeat(2)}); | |||
| auto op1 = AlbumSchema(16, 2, 32, folder_path, schema_file); | |||
| auto op2 = Repeat(2); | |||
| op1->set_total_repeats(2); | |||
| op1->set_num_repeats_per_epoch(2); | |||
| auto tree = Build({op1, op2}); | |||
| ASSERT_OK(tree->Prepare()); | |||
| ASSERT_OK(tree->Launch()); | |||
| DatasetIterator di(tree); | |||
| @@ -134,7 +132,11 @@ TEST_F(MindDataTestAlbum, TestSequentialAlbumWithSchemaFloat) { | |||
| std::string folder_path = datasets_root_path_ + "/testAlbum/images"; | |||
| // add the priority column | |||
| std::string schema_file = datasets_root_path_ + "/testAlbum/floatSchema.json"; | |||
| auto tree = Build({AlbumSchema(16, 2, 32, folder_path, schema_file), Repeat(2)}); | |||
| auto op1 = AlbumSchema(16, 2, 32, folder_path, schema_file); | |||
| auto op2 = Repeat(2); | |||
| op1->set_total_repeats(2); | |||
| op1->set_num_repeats_per_epoch(2); | |||
| auto tree = Build({op1, op2}); | |||
| tree->Prepare(); | |||
| ASSERT_OK(tree->Launch()); | |||
| DatasetIterator di(tree); | |||
| @@ -159,7 +161,11 @@ TEST_F(MindDataTestAlbum, TestSequentialAlbumWithFullSchema) { | |||
| std::string folder_path = datasets_root_path_ + "/testAlbum/images"; | |||
| // add the priority column | |||
| std::string schema_file = datasets_root_path_ + "/testAlbum/fullSchema.json"; | |||
| auto tree = Build({AlbumSchema(16, 2, 32, folder_path, schema_file), Repeat(2)}); | |||
| auto op1 = AlbumSchema(16, 2, 32, folder_path, schema_file); | |||
| auto op2 = Repeat(2); | |||
| op1->set_total_repeats(2); | |||
| op1->set_num_repeats_per_epoch(2); | |||
| auto tree = Build({op1, op2}); | |||
| ASSERT_OK(tree->Prepare()); | |||
| ASSERT_OK(tree->Launch()); | |||
| DatasetIterator di(tree); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -13,14 +13,11 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "minddata/dataset/core/client.h" | |||
| #include "common/common.h" | |||
| #include "utils/ms_utils.h" | |||
| #include "gtest/gtest.h" | |||
| #include "minddata/dataset/core/global_context.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "securec.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| @@ -112,7 +109,12 @@ TEST_F(MindDataTestBatchOp, TestSimpleBatch) { | |||
| TEST_F(MindDataTestBatchOp, TestRepeatBatchDropTrue) { | |||
| std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data"; | |||
| bool success = false; | |||
| auto tree = Build({TFReader(schema_file), Repeat(2), Batch(7, true, 99)}); | |||
| auto op1 = TFReader(schema_file); | |||
| auto op2 = Repeat(2); | |||
| auto op3 = Batch(7, true, 99); | |||
| op1->set_total_repeats(2); | |||
| op1->set_num_repeats_per_epoch(2); | |||
| auto tree = Build({op1, op2, op3}); | |||
| tree->Prepare(); | |||
| Status rc = tree->Launch(); | |||
| if (rc.IsError()) { | |||
| @@ -157,7 +159,12 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropTrue) { | |||
| TEST_F(MindDataTestBatchOp, TestRepeatBatchDropFalse) { | |||
| std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data"; | |||
| bool success = false; | |||
| auto tree = Build({TFReader(schema_file), Repeat(2), Batch(7, false, 99)}); | |||
| auto op1 = TFReader(schema_file); | |||
| auto op2 = Repeat(2); | |||
| auto op3 = Batch(7, false, 99); | |||
| op1->set_total_repeats(2); | |||
| op1->set_num_repeats_per_epoch(2); | |||
| auto tree = Build({op1, op2, op3}); | |||
| tree->Prepare(); | |||
| Status rc = tree->Launch(); | |||
| if (rc.IsError()) { | |||
| @@ -209,7 +216,14 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropFalse) { | |||
| TEST_F(MindDataTestBatchOp, TestBatchDropFalseRepeat) { | |||
| std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data"; | |||
| bool success = false; | |||
| auto tree = Build({TFReader(schema_file), Batch(7, false, 99), Repeat(2)}); | |||
| auto op1 = TFReader(schema_file); | |||
| auto op2 = Batch(7, false, 99); | |||
| auto op3 = Repeat(2); | |||
| op1->set_total_repeats(2); | |||
| op1->set_num_repeats_per_epoch(2); | |||
| op2->set_total_repeats(2); | |||
| op2->set_num_repeats_per_epoch(2); | |||
| auto tree = Build({op1, op2, op3}); | |||
| tree->Prepare(); | |||
| Status rc = tree->Launch(); | |||
| if (rc.IsError()) { | |||
| @@ -255,7 +269,14 @@ TEST_F(MindDataTestBatchOp, TestBatchDropFalseRepeat) { | |||
| TEST_F(MindDataTestBatchOp, TestBatchDropTrueRepeat) { | |||
| std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data"; | |||
| bool success = false; | |||
| auto tree = Build({TFReader(schema_file), Batch(5, true, 99), Repeat(2)}); | |||
| auto op1 = TFReader(schema_file); | |||
| auto op2 = Batch(5, true, 99); | |||
| auto op3 = Repeat(2); | |||
| op1->set_total_repeats(2); | |||
| op1->set_num_repeats_per_epoch(2); | |||
| op2->set_total_repeats(2); | |||
| op2->set_num_repeats_per_epoch(2); | |||
| auto tree = Build({op1, op2, op3}); | |||
| tree->Prepare(); | |||
| Status rc = tree->Launch(); | |||
| if (rc.IsError()) { | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -293,15 +293,20 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) { | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Assign tree relations and root | |||
| myCacheOp->set_total_repeats(numRepeats); | |||
| myCacheOp->set_num_repeats_per_epoch(numRepeats); | |||
| rc = myRepeatOp->AddChild(myCacheOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Always set to 1 under a CacheOp because we read from it only once. The CacheOp is the one that repeats. | |||
| myRandomDataOp->set_total_repeats(1); | |||
| myRandomDataOp->set_num_repeats_per_epoch(1); | |||
| rc = myCacheOp->AddChild(myRandomDataOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssignRoot(myRepeatOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| MS_LOG(INFO) << "Launching tree and begin iteration"; | |||
| rc = myTree->Prepare(1); | |||
| rc = myTree->Prepare(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // quick check to see what tree looks like | |||
| @@ -412,15 +417,20 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) { | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Assign tree relations and root | |||
| myCacheOp->set_total_repeats(numRepeats); | |||
| myCacheOp->set_num_repeats_per_epoch(numRepeats); | |||
| rc = myRepeatOp->AddChild(myCacheOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Always set to 1 under a CacheOp because we read from it only once. The CacheOp is the one that repeats. | |||
| myRandomDataOp->set_total_repeats(1); | |||
| myRandomDataOp->set_num_repeats_per_epoch(1); | |||
| rc = myCacheOp->AddChild(myRandomDataOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssignRoot(myRepeatOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| MS_LOG(INFO) << "Launching tree and begin iteration"; | |||
| rc = myTree->Prepare(1); | |||
| rc = myTree->Prepare(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| std::cout << *myClient << std::endl; | |||
| @@ -502,14 +512,20 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { | |||
| rc = myTree->AssignRoot(myRepeatOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| myMergeOp->set_total_repeats(numRepeats); | |||
| myMergeOp->set_num_repeats_per_epoch(numRepeats); | |||
| rc = myRepeatOp->AddChild(myMergeOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| myLookupOp->set_total_repeats(numRepeats); | |||
| myLookupOp->set_num_repeats_per_epoch(numRepeats); | |||
| rc = myMergeOp->AddChild(myLookupOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| so->set_total_repeats(numRepeats); | |||
| so->set_num_repeats_per_epoch(numRepeats); | |||
| rc = myMergeOp->AddChild(so); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->Prepare(1); | |||
| rc = myTree->Prepare(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->Launch(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -13,14 +13,11 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <fstream> | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "common/common.h" | |||
| #include "minddata/dataset/core/client.h" | |||
| #include "minddata/dataset/core/global_context.h" | |||
| #include "minddata/dataset/engine/datasetops/source/celeba_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| @@ -98,7 +95,11 @@ TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) { | |||
| {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}, | |||
| {0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}}; | |||
| uint32_t count = 0; | |||
| auto tree = Build({Celeba(16, 2, 32, dir), Repeat(2)}); | |||
| auto op1 = Celeba(16, 2, 32, dir); | |||
| auto op2 = Repeat(2); | |||
| auto tree = Build({op1, op2}); | |||
| op1->set_total_repeats(2); | |||
| op1->set_num_repeats_per_epoch(2); | |||
| tree->Prepare(); | |||
| Status rc = tree->Launch(); | |||
| if (rc.IsError()) { | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -39,8 +39,6 @@ using mindspore::MsLogLevel::ERROR; | |||
| using mindspore::ExceptionType::NoExceptionType; | |||
| using mindspore::LogStream; | |||
| std::shared_ptr<RepeatOp> Repeat(int repeatCnt); | |||
| std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops); | |||
| std::shared_ptr<CifarOp> Cifarop(uint64_t num_works, uint64_t rows, uint64_t conns, std::string path, | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -45,8 +45,6 @@ using mindspore::LogStream; | |||
| std::shared_ptr<BatchOp> Batch(int batch_size = 1, bool drop = false, int rows_per_buf = 2); | |||
| std::shared_ptr<RepeatOp> Repeat(int repeat_cnt); | |||
| std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops); | |||
| class MindDataTestCocoOp : public UT::DatasetOpTesting { | |||
| @@ -261,4 +259,4 @@ TEST_F(MindDataTestCocoOp, TestCocoPanoptic) { | |||
| } | |||
| ASSERT_EQ(row_count, 2); | |||
| } | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -13,7 +13,6 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <fstream> | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <string> | |||
| @@ -29,7 +28,6 @@ | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" | |||
| #include "minddata/dataset/util/path.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "gtest/gtest.h" | |||
| #include "utils/log_adapter.h" | |||
| @@ -82,7 +80,11 @@ class MindDataTestImageFolderSampler : public UT::DatasetOpTesting { | |||
| TEST_F(MindDataTestImageFolderSampler, TestSequentialImageFolderWithRepeat) { | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data"; | |||
| auto tree = Build({ImageFolder(16, 2, 32, folder_path, false), Repeat(2)}); | |||
| auto op1 = ImageFolder(16, 2, 32, folder_path, false); | |||
| auto op2 = Repeat(2); | |||
| op1->set_total_repeats(2); | |||
| op1->set_num_repeats_per_epoch(2); | |||
| auto tree = Build({op1, op2}); | |||
| tree->Prepare(); | |||
| int32_t res[] = {0, 1, 2, 3}; | |||
| Status rc = tree->Launch(); | |||
| @@ -166,7 +168,12 @@ TEST_F(MindDataTestImageFolderSampler, TestRandomSamplerImageFolder) { | |||
| TEST_F(MindDataTestImageFolderSampler, TestSequentialImageFolderWithRepeatBatch) { | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data"; | |||
| auto tree = Build({ImageFolder(16, 2, 32, folder_path, false), Repeat(2), Batch(11)}); | |||
| auto op1 = ImageFolder(16, 2, 32, folder_path, false); | |||
| auto op2 = Repeat(2); | |||
| auto op3 = Batch(11); | |||
| op1->set_total_repeats(2); | |||
| op1->set_num_repeats_per_epoch(2); | |||
| auto tree = Build({op1, op2, op3}); | |||
| tree->Prepare(); | |||
| int32_t res[4][11] = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, | |||
| {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, | |||
| @@ -297,7 +304,11 @@ TEST_F(MindDataTestImageFolderSampler, TestDistributedSampler) { | |||
| int64_t num_samples = 0; | |||
| std::shared_ptr<SamplerRT> sampler = std::make_shared<DistributedSamplerRT>(num_samples, 11, 10, false); | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data"; | |||
| auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler)), Repeat(4)}); | |||
| auto op1 = ImageFolder(16, 2, 32, folder_path, false, std::move(sampler)); | |||
| auto op2 = Repeat(4); | |||
| op1->set_total_repeats(4); | |||
| op1->set_num_repeats_per_epoch(4); | |||
| auto tree = Build({op1, op2}); | |||
| tree->Prepare(); | |||
| Status rc = tree->Launch(); | |||
| if (rc.IsError()) { | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -20,6 +20,7 @@ | |||
| #include "common/common.h" | |||
| #include "minddata/dataset/callback/ds_callback.h" | |||
| #include "minddata/dataset/core/client.h" | |||
| #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/random_data_op.h" | |||
| #include "minddata/dataset/engine/tree_adapter.h" | |||
| #include "minddata/dataset/include/datasets.h" | |||
| @@ -166,6 +167,10 @@ TEST_F(MindDataTestCallback, TestBasicCallback) { | |||
| std::shared_ptr<RepeatOp> repeat_op; | |||
| rc = RepeatOp::Builder(2).Build(&repeat_op); | |||
| // start build then launch tree | |||
| leaf->set_total_repeats(2); | |||
| leaf->set_num_repeats_per_epoch(2); | |||
| map_op->set_total_repeats(2); | |||
| map_op->set_num_repeats_per_epoch(2); | |||
| std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op}); | |||
| rc = tree->Prepare(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| @@ -213,8 +218,15 @@ TEST_F(MindDataTestCallback, TestMultiEpochCallback) { | |||
| // config RepeatOp | |||
| std::shared_ptr<RepeatOp> repeat_op; | |||
| rc = RepeatOp::Builder(2).Build(&repeat_op); | |||
| // config EpochCtrlOp | |||
| std::shared_ptr<EpochCtrlOp> epoch_ctrl_op; | |||
| rc = EpochCtrlOp::Builder(-1).Build(&epoch_ctrl_op); | |||
| // start build then launch tree | |||
| std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op}); | |||
| leaf->set_total_repeats(-2); | |||
| leaf->set_num_repeats_per_epoch(2); | |||
| map_op->set_total_repeats(-2); | |||
| map_op->set_num_repeats_per_epoch(2); | |||
| std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op, epoch_ctrl_op}); | |||
| rc = tree->Prepare(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = tree->Launch(); | |||
| @@ -271,8 +283,15 @@ TEST_F(MindDataTestCallback, TestSelectedCallback) { | |||
| // config RepeatOp | |||
| std::shared_ptr<RepeatOp> repeat_op; | |||
| rc = RepeatOp::Builder(2).Build(&repeat_op); | |||
| // config EpochCtrlOp | |||
| std::shared_ptr<EpochCtrlOp> epoch_ctrl_op; | |||
| rc = EpochCtrlOp::Builder(-1).Build(&epoch_ctrl_op); | |||
| // start build then launch tree | |||
| std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op}); | |||
| leaf->set_total_repeats(-2); | |||
| leaf->set_num_repeats_per_epoch(2); | |||
| map_op->set_total_repeats(-2); | |||
| map_op->set_num_repeats_per_epoch(2); | |||
| std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op, epoch_ctrl_op}); | |||
| rc = tree->Prepare(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = tree->Launch(); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -58,7 +58,11 @@ class MindDataTestManifest : public UT::DatasetOpTesting { | |||
| TEST_F(MindDataTestManifest, TestSequentialManifestWithRepeat) { | |||
| std::string file = datasets_root_path_ + "/testManifestData/cpp.json"; | |||
| auto tree = Build({Manifest(16, 2, 32, file), Repeat(2)}); | |||
| auto op1 = Manifest(16, 2, 32, file); | |||
| auto op2 = Repeat(2); | |||
| op1->set_total_repeats(2); | |||
| op1->set_num_repeats_per_epoch(2); | |||
| auto tree = Build({op1, op2}); | |||
| tree->Prepare(); | |||
| uint32_t res[] = {0, 1, 0, 1}; | |||
| Status rc = tree->Launch(); | |||
| @@ -148,7 +152,11 @@ TEST_F(MindDataTestManifest, MindDataTestManifestNumSamples) { | |||
| int64_t num_samples = 1; | |||
| int64_t start_index = 0; | |||
| auto seq_sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index); | |||
| auto tree = Build({Manifest(16, 2, 32, file, "train", std::move(seq_sampler), {}), Repeat(4)}); | |||
| auto op1 = Manifest(16, 2, 32, file, "train", std::move(seq_sampler), {}); | |||
| auto op2 = Repeat(4); | |||
| op1->set_total_repeats(4); | |||
| op1->set_num_repeats_per_epoch(4); | |||
| auto tree = Build({op1, op2}); | |||
| tree->Prepare(); | |||
| Status rc = tree->Launch(); | |||
| if (rc.IsError()) { | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -17,7 +17,6 @@ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "common/common.h" | |||
| #include "minddata/dataset/core/client.h" | |||
| #include "minddata/dataset/core/tensor.h" | |||
| @@ -416,6 +415,8 @@ TEST_F(MindDataTestMapOp, TestTFReaderRepeatMap) { | |||
| rc = my_map_op->AddChild(my_repeat_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| my_tfreader_op->set_total_repeats(num_repeats); | |||
| my_tfreader_op->set_num_repeats_per_epoch(num_repeats); | |||
| rc = my_repeat_op->AddChild(my_tfreader_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| @@ -471,9 +472,13 @@ TEST_F(MindDataTestMapOp, TestTFReaderMapRepeat) { | |||
| rc = my_tree_->AssociateNode(my_map_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| my_map_op->set_total_repeats(num_repeats); | |||
| my_map_op->set_num_repeats_per_epoch(num_repeats); | |||
| rc = my_repeat_op->AddChild(my_map_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| my_tfreader_op->set_total_repeats(num_repeats); | |||
| my_tfreader_op->set_num_repeats_per_epoch(num_repeats); | |||
| rc = my_map_op->AddChild(my_tfreader_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| @@ -548,9 +553,13 @@ TEST_F(MindDataTestMapOp, TFReader_Decode_Repeat_Resize) { | |||
| rc = my_tree_->AssociateNode(my_map_resize_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| my_tfreader_op->set_total_repeats(num_repeats); | |||
| my_tfreader_op->set_num_repeats_per_epoch(num_repeats); | |||
| rc = my_map_decode_op->AddChild(my_tfreader_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| my_map_decode_op->set_total_repeats(num_repeats); | |||
| my_map_decode_op->set_num_repeats_per_epoch(num_repeats); | |||
| rc = my_repeat_op->AddChild(my_map_decode_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| @@ -611,7 +620,12 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize) { | |||
| rc = map_resize_builder.Build(&map_resize_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| my_tree_ = Build({ImageFolder(16, 2, 32, folder_path, false), map_decode_map, repeat_op, map_resize_op}); | |||
| auto image_folder_op = ImageFolder(16, 2, 32, folder_path, false); | |||
| image_folder_op->set_total_repeats(num_repeats); | |||
| image_folder_op->set_num_repeats_per_epoch(num_repeats); | |||
| map_decode_map->set_total_repeats(num_repeats); | |||
| map_decode_map->set_num_repeats_per_epoch(num_repeats); | |||
| my_tree_ = Build({image_folder_op, map_decode_map, repeat_op, map_resize_op}); | |||
| rc = my_tree_->Prepare(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = my_tree_->Launch(); | |||
| @@ -656,7 +670,12 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize) { | |||
| rc = map_resize_builder.Build(&map_resize_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| auto my_tree_2 = Build({ImageFolder(16, 2, 32, folder_path, false), map_decode_map, repeat_op, map_resize_op}); | |||
| image_folder_op = ImageFolder(16, 2, 32, folder_path, false); | |||
| image_folder_op->set_total_repeats(num_repeats); | |||
| image_folder_op->set_num_repeats_per_epoch(num_repeats); | |||
| map_decode_map->set_total_repeats(num_repeats); | |||
| map_decode_map->set_num_repeats_per_epoch(num_repeats); | |||
| auto my_tree_2 = Build({image_folder_op, map_decode_map, repeat_op, map_resize_op}); | |||
| rc = my_tree_2->Prepare(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| @@ -714,7 +733,12 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize_NoInputColumns) { | |||
| rc = map_resize_builder.Build(&map_resize_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| my_tree_ = Build({ImageFolder(16, 2, 32, folder_path, false), map_decode_map, repeat_op, map_resize_op}); | |||
| auto image_folder_op = ImageFolder(16, 2, 32, folder_path, false); | |||
| image_folder_op->set_total_repeats(num_repeats); | |||
| image_folder_op->set_num_repeats_per_epoch(num_repeats); | |||
| map_decode_map->set_total_repeats(num_repeats); | |||
| map_decode_map->set_num_repeats_per_epoch(num_repeats); | |||
| my_tree_ = Build({image_folder_op, map_decode_map, repeat_op, map_resize_op}); | |||
| rc = my_tree_->Prepare(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = my_tree_->Launch(); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -370,9 +370,12 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordRepeat) { | |||
| rc = my_tree->AssociateNode(my_repeat_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| my_mindrecord_op->set_total_repeats(num_repeats); | |||
| my_mindrecord_op->set_num_repeats_per_epoch(num_repeats); | |||
| rc = my_repeat_op->AddChild(my_mindrecord_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // Set children/root layout. | |||
| rc = my_tree->AssignRoot(my_repeat_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| @@ -452,6 +455,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBlockReaderRepeat) { | |||
| rc = my_tree->AssociateNode(my_repeat_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| my_mindrecord_op->set_total_repeats(num_repeats); | |||
| my_mindrecord_op->set_num_repeats_per_epoch(num_repeats); | |||
| rc = my_repeat_op->AddChild(my_mindrecord_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -78,7 +78,11 @@ TEST_F(MindDataTestMnistSampler, TestSequentialMnistWithRepeat) { | |||
| int64_t num_samples = 10; | |||
| int64_t start_index = 0; | |||
| auto seq_sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index); | |||
| auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)), Repeat(2)}); | |||
| auto op1 = CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)); | |||
| auto op2 = Repeat(2); | |||
| op1->set_total_repeats(2); | |||
| op1->set_num_repeats_per_epoch(2); | |||
| auto tree = Build({op1, op2}); | |||
| tree->Prepare(); | |||
| uint32_t res[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; | |||
| Status rc = tree->Launch(); | |||
| @@ -108,7 +112,12 @@ TEST_F(MindDataTestMnistSampler, TestSequentialImageFolderWithRepeatBatch) { | |||
| int64_t num_samples = 10; | |||
| int64_t start_index = 0; | |||
| auto seq_sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index); | |||
| auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)), Repeat(2), Batch(5)}); | |||
| auto op1 = CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)); | |||
| auto op2 = Repeat(2); | |||
| auto op3 = Batch(5); | |||
| op1->set_total_repeats(2); | |||
| op1->set_num_repeats_per_epoch(2); | |||
| auto tree = Build({op1, op2, op3}); | |||
| tree->Prepare(); | |||
| uint32_t res[4][5] = { {0, 0, 0, 0, 0 }, | |||
| {0, 0, 0, 0, 0 }, | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -35,7 +35,7 @@ class MindDataTestRandomDataOp : public UT::DatasetOpTesting { | |||
| // Test info: | |||
| // - Simple test with a user-provided schema generated purely from DataSchema C API | |||
| // - has an interation loop | |||
| // - has an interaction loop | |||
| // | |||
| // Tree: single node tree with RandomDataOp | |||
| // | |||
| @@ -213,7 +213,7 @@ TEST_F(MindDataTestRandomDataOp, RandomDataOpBasic3) { | |||
| // Test info: | |||
| // - json schema input it's a fairly simple one | |||
| // - has an interation loop | |||
| // - has an interaction loop | |||
| // | |||
| // Tree: RepeatOp over RandomDataOp | |||
| // | |||
| @@ -253,6 +253,8 @@ TEST_F(MindDataTestRandomDataOp, RandomDataOpBasic4) { | |||
| rc = myTree->AssociateNode(myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| myRandomDataOp->set_total_repeats(numRepeats); | |||
| myRandomDataOp->set_num_repeats_per_epoch(numRepeats); | |||
| rc = myRepeatOp->AddChild(myRandomDataOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| @@ -290,7 +292,7 @@ TEST_F(MindDataTestRandomDataOp, RandomDataOpBasic4) { | |||
| // Test info: | |||
| // - json schema input it's a fairly simple one | |||
| // - has an interation loop | |||
| // - has an interaction loop | |||
| // - same as MindDataTestRandomDataOpBasic4 except that this one will have parallel workers | |||
| // | |||
| // Tree: RepeatOp over RandomDataOp | |||
| @@ -331,6 +333,8 @@ TEST_F(MindDataTestRandomDataOp, RandomDataOpBasic5) { | |||
| rc = myTree->AssociateNode(myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| myRandomDataOp->set_total_repeats(numRepeats); | |||
| myRandomDataOp->set_num_repeats_per_epoch(numRepeats); | |||
| rc = myRepeatOp->AddChild(myRandomDataOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| @@ -418,9 +422,13 @@ TEST_F(MindDataTestRandomDataOp, RandomDataOpTree1) { | |||
| rc = myTree->AssociateNode(myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| myShuffleOp->set_total_repeats(numRepeats); | |||
| myShuffleOp->set_num_repeats_per_epoch(numRepeats); | |||
| rc = myRepeatOp->AddChild(myShuffleOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| myRandomDataOp->set_total_repeats(numRepeats); | |||
| myRandomDataOp->set_num_repeats_per_epoch(numRepeats); | |||
| rc = myShuffleOp->AddChild(myRandomDataOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| @@ -75,6 +75,8 @@ TEST_F(MindDataTestSentencePieceVocabOp, TestSentencePieceFromDatasetFuntions) { | |||
| rc = spv_op->AddChild(file_op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| file_op->set_total_repeats(1); | |||
| file_op->set_num_repeats_per_epoch(1); | |||
| rc = tree->AssignRoot(spv_op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = tree->Prepare(); | |||
| @@ -147,6 +149,8 @@ TEST_F(MindDataTestSentencePieceVocabOp, TestSentencePieceTokenizerFuntions) { | |||
| rc = spv_op->AddChild(file_op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| file_op->set_total_repeats(1); | |||
| file_op->set_num_repeats_per_epoch(1); | |||
| rc = tree->AssignRoot(spv_op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = tree->Prepare(); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -300,8 +300,12 @@ TEST_F(MindDataTestShuffleOp, TestRepeatShuffle) { | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // Set children/root layout. | |||
| my_shuffle_op->set_total_repeats(numRepeats); | |||
| my_shuffle_op->set_num_repeats_per_epoch(numRepeats); | |||
| rc = my_repeat_op->AddChild(my_shuffle_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| my_tfreader_op->set_total_repeats(numRepeats); | |||
| my_tfreader_op->set_num_repeats_per_epoch(numRepeats); | |||
| rc = my_shuffle_op->AddChild(my_tfreader_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = my_tree->AssignRoot(my_repeat_op); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -20,7 +20,6 @@ | |||
| #include "minddata/dataset/core/client.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "common/common.h" | |||
| #include "utils/ms_utils.h" | |||
| #include "gtest/gtest.h" | |||
| #include "utils/log_adapter.h" | |||
| @@ -330,11 +329,14 @@ TEST_F(MindDataTestTFReaderOp, TestTFReaderRepeat) { | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // RepeatOp | |||
| std::shared_ptr<RepeatOp> my_repeat_op = std::make_shared<RepeatOp>(3); | |||
| uint32_t num_repeats = 3; | |||
| std::shared_ptr<RepeatOp> my_repeat_op = std::make_shared<RepeatOp>(num_repeats); | |||
| rc = my_tree->AssociateNode(my_repeat_op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Set children/root layout. | |||
| my_tfreader_op->set_total_repeats(num_repeats); | |||
| my_tfreader_op->set_num_repeats_per_epoch(num_repeats); | |||
| rc = my_repeat_op->AddChild(my_tfreader_op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = my_tree->AssignRoot(my_repeat_op); | |||
| @@ -705,7 +707,7 @@ TEST_F(MindDataTestTFReaderOp, TestTFReaderInvalidFiles) { | |||
| std::string valid_file = datasets_root_path_ + "/testTFTestAllTypes/test.data"; | |||
| std::string schema_file = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json"; | |||
| std::string invalid_file = datasets_root_path_ + "/testTFTestAllTypes/invalidFile.txt"; | |||
| std::string nonexistent_file = "this/file/doesnt/exist"; | |||
| std::string nonexistent_file = "this/file/not/exist"; | |||
| std::shared_ptr<TFReaderOp> my_tfreader_op; | |||
| TFReaderOp::Builder builder; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -45,8 +45,6 @@ using mindspore::LogStream; | |||
| std::shared_ptr<BatchOp> Batch(int batch_size = 1, bool drop = false, int rows_per_buf = 2); | |||
| std::shared_ptr<RepeatOp> Repeat(int repeat_cnt); | |||
| std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops); | |||
| class MindDataTestVOCOp : public UT::DatasetOpTesting { | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -141,6 +141,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpRepeat) { | |||
| MS_LOG(INFO) << "UT test TestZipRepeat."; | |||
| auto my_tree = std::make_shared<ExecutionTree>(); | |||
| uint32_t num_repeats = 3; | |||
| std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data"; | |||
| std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data"; | |||
| std::shared_ptr<TFReaderOp> my_tfreader_op; | |||
| @@ -169,17 +170,23 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpRepeat) { | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = my_tree->AssociateNode(zip_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| my_tfreader_op->set_total_repeats(num_repeats); | |||
| my_tfreader_op->set_num_repeats_per_epoch(num_repeats); | |||
| rc = zip_op->AddChild(std::move(my_tfreader_op)); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| my_tfreader_op2->set_total_repeats(num_repeats); | |||
| my_tfreader_op2->set_num_repeats_per_epoch(num_repeats); | |||
| rc = zip_op->AddChild(std::move(my_tfreader_op2)); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // Builder(num_of_repeats) | |||
| std::shared_ptr<RepeatOp> my_repeat_op; | |||
| rc = RepeatOp::Builder(3).Build(&my_repeat_op); | |||
| rc = RepeatOp::Builder(num_repeats).Build(&my_repeat_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = my_tree->AssociateNode(my_repeat_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| zip_op->set_total_repeats(num_repeats); | |||
| zip_op->set_num_repeats_per_epoch(num_repeats); | |||
| rc = my_repeat_op->AddChild(zip_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = my_tree->AssignRoot(my_repeat_op); | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -138,7 +138,7 @@ def test_cache_map_basic3(): | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_basic4(): | |||
| """ | |||
| Test Map with non-deterministic TensorOps above cache | |||
| Test Map containing random operation above cache | |||
| repeat | |||
| | | |||
| @@ -374,7 +374,7 @@ def test_cache_map_failure4(): | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_failure5(): | |||
| """ | |||
| Test Map with non-deterministic TensorOps under cache (failure) | |||
| Test Map containing random operation under cache (failure) | |||
| repeat | |||
| | | |||
| @@ -406,7 +406,7 @@ def test_cache_map_failure5(): | |||
| num_iter = 0 | |||
| for _ in data.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert "MapNode with non-deterministic operations is not supported as a descendant of cache" in str(e.value) | |||
| assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value) | |||
| assert num_iter == 0 | |||
| logger.info('test_cache_failure5 Ended.\n') | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -2087,7 +2087,7 @@ def test_cache_nomap_failure4(): | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_nomap_failure5(): | |||
| """ | |||
| Test Map with non-deterministic TensorOps under cache (failure) | |||
| Test Map containing random operation under cache (failure) | |||
| repeat | |||
| | | |||
| @@ -2118,7 +2118,7 @@ def test_cache_nomap_failure5(): | |||
| num_iter = 0 | |||
| for _ in data.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert "MapNode with non-deterministic operations is not supported as a descendant of cache" in str(e.value) | |||
| assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value) | |||
| assert num_iter == 0 | |||
| logger.info('test_cache_nomap_failure5 Ended.\n') | |||