Merge pull request !3294 from nsyca/removal_passtags/v0.7.0-beta
| @@ -23,7 +23,7 @@ | |||||
| #include "minddata/dataset/engine/opt/pre/removal_pass.h" | #include "minddata/dataset/engine/opt/pre/removal_pass.h" | ||||
| #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" | #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" | ||||
| #include "minddata/dataset/engine/opt/post/repeat_pass.h" | #include "minddata/dataset/engine/opt/post/repeat_pass.h" | ||||
| #include "minddata/dataset/engine/opt/pre/injection_pass.h" | |||||
| #include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h" | |||||
| #include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" | #include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" | ||||
| #include "minddata/dataset/engine/perf/profiling.h" | #include "minddata/dataset/engine/perf/profiling.h" | ||||
| #include "minddata/dataset/engine/perf/monitor.h" | #include "minddata/dataset/engine/perf/monitor.h" | ||||
| @@ -225,7 +225,7 @@ Status ExecutionTree::PrepareTreePreAction() { | |||||
| std::vector<std::unique_ptr<Pass>> pre_actions; | std::vector<std::unique_ptr<Pass>> pre_actions; | ||||
| // Construct pre actions | // Construct pre actions | ||||
| MS_LOG(INFO) << "Running pre pass loops."; | MS_LOG(INFO) << "Running pre pass loops."; | ||||
| pre_actions.push_back(std::make_unique<InjectionPass>()); | |||||
| pre_actions.push_back(std::make_unique<EpochInjectionPass>()); | |||||
| pre_actions.push_back(std::make_unique<RemovalPass>()); | pre_actions.push_back(std::make_unique<RemovalPass>()); | ||||
| pre_actions.push_back(std::make_unique<CacheTransformPass>()); | pre_actions.push_back(std::make_unique<CacheTransformPass>()); | ||||
| // Apply pre action passes | // Apply pre action passes | ||||
| @@ -3,10 +3,8 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE | |||||
| add_library(engine-opt OBJECT | add_library(engine-opt OBJECT | ||||
| pass.cc | pass.cc | ||||
| post/repeat_pass.cc | post/repeat_pass.cc | ||||
| pre/cache_pass.cc | |||||
| pre/cache_transform_pass.cc | pre/cache_transform_pass.cc | ||||
| pre/injection_pass.cc | |||||
| pre/removal_nodes.cc | |||||
| pre/epoch_injection_pass.cc | |||||
| pre/removal_pass.cc | pre/removal_pass.cc | ||||
| optional/tensor_op_fusion_pass.cc | optional/tensor_op_fusion_pass.cc | ||||
| util/printer_pass.cc | util/printer_pass.cc | ||||
| @@ -1,181 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <memory> | |||||
| #include "minddata/dataset/engine/opt/pre/cache_pass.h" | |||||
| #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" | |||||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/celeba_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/generator_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/manifest_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/mnist_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/voc_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/cifar_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/coco_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/image_folder_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/random_data_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| // Constructor | |||||
| CachePass::CachePass(CacheTransformPass *transform_pass) | |||||
| : transform_pass_(transform_pass), is_caching_(false), leaf_op_(nullptr) {} | |||||
| // Identifies the subtree below this node as a cached descendant tree. | |||||
| Status CachePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree."; | |||||
| if (is_caching_) { | |||||
| RETURN_STATUS_UNEXPECTED("Nested cache operations is not supported!"); | |||||
| } | |||||
| is_caching_ = true; | |||||
| return Status::OK(); | |||||
| } | |||||
| // Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache | |||||
| // transformation | |||||
| Status CachePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| is_caching_ = false; // We a no longer in a cache subtree. clear the flag. | |||||
| if (leaf_op_) { | |||||
| MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache."; | |||||
| // Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op, | |||||
| // using base class pointers. | |||||
| transform_pass_->AddMappableCacheOperators(std::move(leaf_op_), node); | |||||
| } else { | |||||
| // If there was no leaf_op set, then this is a non-mappable scenario. | |||||
| if (sampler_) { | |||||
| // Grab the sampler that was saved from the leaf and plug it into the cache op | |||||
| node->SetSampler(std::move(sampler_)); | |||||
| MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf."; | |||||
| } else { | |||||
| // We're a cache op but no sampler was saved from leaf, so create a default sampler | |||||
| const int64_t num_samples = 0; | |||||
| const int64_t start_index = 0; | |||||
| sampler_ = std::make_shared<SequentialSampler>(num_samples, start_index); | |||||
| node->SetSampler(std::move(sampler_)); | |||||
| MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op."; | |||||
| } | |||||
| // Get the computed check sum from all ops in our cache path below us and ask the cache op to create it's cache | |||||
| uint32_t cache_crc = DatasetOp::GenerateCRC(node); | |||||
| RETURN_IF_NOT_OK(node->CreateCache(cache_crc)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Common code for mappable leaf setup. | |||||
| Status CachePass::MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) { | |||||
| // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. | |||||
| if (is_caching_ && leaf_op_) { | |||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); | |||||
| } | |||||
| // If we are a leaf in the caching path, then save this leaf. | |||||
| if (is_caching_) { | |||||
| MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected"; | |||||
| leaf_op_ = std::move(leaf_op); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Common code for non mappable leaf setup. | |||||
| Status CachePass::NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) { | |||||
| // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. | |||||
| if (is_caching_ && leaf_op_) { | |||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); | |||||
| } | |||||
| // Sampler for non mapable dataset only works if there is a downstream cache. Remove it from the leaf | |||||
| // as save it for use by cache op in ascendant tree. | |||||
| if (is_caching_) { | |||||
| RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_)); | |||||
| MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected"; | |||||
| } else { | |||||
| // If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can | |||||
| // remove it here. The leaf itself will provide it's own methods of fetching the data (not sampler-based) | |||||
| std::shared_ptr<Sampler> sampler_from_leaf; | |||||
| RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Perform leaf node cache tranform identifications | |||||
| Status CachePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) { | |||||
| if (is_caching_) { | |||||
| // If we are a TF Reader in a caching tree, then change our config so that it becomes a basic | |||||
| // TF reader that parses all files. Selection of data will come from the sampler on the cache instead. | |||||
| node->MakeSimpleProducer(); | |||||
| } | |||||
| return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Perform leaf node cache tranform identifications | |||||
| Status CachePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) { | |||||
| return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Perform leaf node cache tranform identifications | |||||
| Status CachePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) { | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Perform leaf node cache tranform identifications | |||||
| Status CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) { | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Perform leaf node cache tranform identifications | |||||
| Status CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Perform leaf node cache tranform identifications | |||||
| Status CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) { | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Perform leaf node cache tranform identifications | |||||
| Status CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) { | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Perform leaf node cache tranform identifications | |||||
| Status CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) { | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Perform leaf node cache tranform identifications | |||||
| Status CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) { | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Perform leaf node cache tranform identifications | |||||
| Status CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) { | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Perform leaf node cache tranform identifications | |||||
| Status CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -1,141 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include "minddata/dataset/engine/opt/pass.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| class CacheTransformPass; | |||||
| /// \class CachePass cache_pass.h | |||||
| /// \brief This is a NodePass who's job is to identify and set up the nodes that will be involved in a cache | |||||
| /// transformation. It works in conjunction with the CacheTransformPass | |||||
| class CachePass : public NodePass { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| /// \param[in] transform_pass Raw pointer back to controlling tree pass | |||||
| explicit CachePass(CacheTransformPass *transform_pass); | |||||
| /// \brief Destructor | |||||
| ~CachePass() = default; | |||||
| /// \brief Identifies the subtree below this node as a cached descendant tree. | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||||
| /// \brief Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache | |||||
| /// transformation | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override; | |||||
| private: | |||||
| /// \brief Common code for mappable leaf setup. | |||||
| /// \param[in] node The leaf node performing setup work. | |||||
| /// \return Status The error code return | |||||
| Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op); | |||||
| /// \brief Common code for non-mappable leaf setup. | |||||
| /// \param[in] node The leaf node performing setup work. | |||||
| /// \return Status The error code return | |||||
| Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op); | |||||
| bool is_caching_; | |||||
| std::shared_ptr<DatasetOp> leaf_op_; | |||||
| std::shared_ptr<Sampler> sampler_; | |||||
| CacheTransformPass *transform_pass_; // Back pointer to the owning transform pass | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_ | |||||
| @@ -15,17 +15,177 @@ | |||||
| */ | */ | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/opt/pre/cache_pass.h" | |||||
| #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" | #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" | ||||
| #include "minddata/dataset/engine/execution_tree.h" | #include "minddata/dataset/engine/execution_tree.h" | ||||
| #include "minddata/dataset/engine/cache/cache_client.h" | #include "minddata/dataset/engine/cache/cache_client.h" | ||||
| #include "minddata/dataset/engine/datasetops/cache_lookup_op.h" | #include "minddata/dataset/engine/datasetops/cache_lookup_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/cache_merge_op.h" | #include "minddata/dataset/engine/datasetops/cache_merge_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | #include "minddata/dataset/engine/datasetops/cache_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/celeba_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/cifar_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/coco_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/generator_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/image_folder_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/manifest_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/mnist_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/random_data_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/voc_op.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| // Constructor | |||||
| CacheTransformPass::CachePass::CachePass() : is_caching_(false), leaf_op_(nullptr) {} | |||||
| // Identifies the subtree below this node as a cached descendant tree. | |||||
| Status CacheTransformPass::CachePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree."; | |||||
| if (is_caching_) { | |||||
| RETURN_STATUS_UNEXPECTED("Nested cache operations is not supported!"); | |||||
| } | |||||
| is_caching_ = true; | |||||
| return Status::OK(); | |||||
| } | |||||
| // Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache | |||||
| // transformation | |||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| is_caching_ = false; // We a no longer in a cache subtree. clear the flag. | |||||
| if (leaf_op_) { | |||||
| MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache."; | |||||
| // Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op, | |||||
| // using base class pointers. | |||||
| AddMappableCacheOperators(std::move(leaf_op_), node); | |||||
| } else { | |||||
| // If there was no leaf_op set, then this is a non-mappable scenario. | |||||
| if (sampler_) { | |||||
| // Grab the sampler that was saved from the leaf and plug it into the cache op | |||||
| node->SetSampler(std::move(sampler_)); | |||||
| MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf."; | |||||
| } else { | |||||
| // We're a cache op but no sampler was saved from leaf, so create a default sampler | |||||
| int64_t num_samples = 0; | |||||
| int64_t start_index = 0; | |||||
| sampler_ = std::make_shared<SequentialSampler>(num_samples, start_index); | |||||
| node->SetSampler(std::move(sampler_)); | |||||
| MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op."; | |||||
| } | |||||
| // Get the computed check sum from all ops in our cache path below us and ask the cache op to create it's cache | |||||
| uint32_t cache_crc = DatasetOp::GenerateCRC(node); | |||||
| RETURN_IF_NOT_OK(node->CreateCache(cache_crc)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Common code for mappable leaf setup. | |||||
| Status CacheTransformPass::CachePass::MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) { | |||||
| // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. | |||||
| if (is_caching_ && leaf_op_) { | |||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); | |||||
| } | |||||
| // If we are a leaf in the caching path, then save this leaf. | |||||
| if (is_caching_) { | |||||
| MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected"; | |||||
| leaf_op_ = std::move(leaf_op); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Common code for non mappable leaf setup. | |||||
| Status CacheTransformPass::CachePass::NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) { | |||||
| // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. | |||||
| if (is_caching_ && leaf_op_) { | |||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); | |||||
| } | |||||
| // Sampler for non mapable dataset only works if there is a downstream cache. Remove it from the leaf | |||||
| // as save it for use by cache op in ascendant tree. | |||||
| if (is_caching_) { | |||||
| RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_)); | |||||
| MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected"; | |||||
| } else { | |||||
| // If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can | |||||
| // remove it here. The leaf itself will provide it's own methods of fetching the data (not sampler-based) | |||||
| std::shared_ptr<Sampler> sampler_from_leaf; | |||||
| RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Perform leaf node cache tranform identifications | |||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) { | |||||
| if (is_caching_) { | |||||
| // If we are a TF Reader in a caching tree, then change our config so that it becomes a basic | |||||
| // TF reader that parses all files. Selection of data will come from the sampler on the cache instead. | |||||
| node->MakeSimpleProducer(); | |||||
| } | |||||
| return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Perform leaf node cache tranform identifications | |||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) { | |||||
| return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Perform leaf node cache tranform identifications | |||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) { | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Perform leaf node cache tranform identifications | |||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) { | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Perform leaf node cache tranform identifications | |||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Perform leaf node cache tranform identifications | |||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) { | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Perform leaf node cache tranform identifications | |||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) { | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Perform leaf node cache tranform identifications | |||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) { | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Perform leaf node cache tranform identifications | |||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) { | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Perform leaf node cache tranform identifications | |||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) { | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Perform leaf node cache tranform identifications | |||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Assigns the leaf and cache operators that are involved in a cache transformation | |||||
| void CacheTransformPass::CachePass::AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, | |||||
| std::shared_ptr<CacheOp> cache_op) { | |||||
| cache_pairs_.push_back(std::make_pair(leaf_op, cache_op)); | |||||
| } | |||||
| // constructor | // constructor | ||||
| CacheTransformPass::CacheTransformPass() {} | CacheTransformPass::CacheTransformPass() {} | ||||
| @@ -34,11 +194,11 @@ Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *modified) { | |||||
| MS_LOG(INFO) << "Pre pass: Cache transform pass started."; | MS_LOG(INFO) << "Pre pass: Cache transform pass started."; | ||||
| // Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will | // Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will | ||||
| // use to execute a transform. | // use to execute a transform. | ||||
| std::unique_ptr<Pass> cache_pass = std::make_unique<CachePass>(this); | |||||
| RETURN_IF_NOT_OK(cache_pass->Run(tree, modified)); | |||||
| CachePass cache_pass = CachePass(); | |||||
| RETURN_IF_NOT_OK(cache_pass.Run(tree, modified)); | |||||
| // Then, execute the transform for each pair | // Then, execute the transform for each pair | ||||
| for (auto cache_pair : cache_pairs_) { | |||||
| for (auto cache_pair : cache_pass.cache_pairs()) { | |||||
| MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform."; | MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform."; | ||||
| ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client()); | ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client()); | ||||
| } | } | ||||
| @@ -98,11 +258,5 @@ Status CacheTransformPass::ExecuteCacheTransform(ExecutionTree *tree, std::share | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Assigns the leaf and cache operators that are involved in a cache transformation | |||||
| void CacheTransformPass::AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, | |||||
| std::shared_ptr<CacheOp> cache_op) { | |||||
| cache_pairs_.push_back(std::make_pair(leaf_op, cache_op)); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -33,6 +33,123 @@ class CacheClient; | |||||
| /// \brief This is a tree pass that will invoke a tree transformation to inject the correct operators for caching | /// \brief This is a tree pass that will invoke a tree transformation to inject the correct operators for caching | ||||
| /// operations | /// operations | ||||
| class CacheTransformPass : public TreePass { | class CacheTransformPass : public TreePass { | ||||
| /// \class CachePass | |||||
| /// \brief This is a NodePass who's job is to identify and set up the nodes that will be involved in a cache | |||||
| /// transformation. It works in conjunction with the CacheTransformPass | |||||
| class CachePass : public NodePass { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| /// \param[in] transform_pass Raw pointer back to controlling tree pass | |||||
| CachePass(); | |||||
| /// \brief Destructor | |||||
| ~CachePass() = default; | |||||
| /// \brief Identifies the subtree below this node as a cached descendant tree. | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||||
| /// \brief Resets the tracking of the cache within the tree and assigns the operators that | |||||
| /// will be involved in a cache transformation | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override; | |||||
| /// \brief Getter | |||||
| std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs() { return cache_pairs_; } | |||||
| private: | |||||
| /// \brief Common code for mappable leaf setup. | |||||
| /// \param[in] node The leaf node performing setup work. | |||||
| /// \return Status The error code return | |||||
| Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op); | |||||
| /// \brief Common code for non-mappable leaf setup. | |||||
| /// \param[in] node The leaf node performing setup work. | |||||
| /// \return Status The error code return | |||||
| Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op); | |||||
| /// \brief Assigns the leaf and cache operators that are involved in a cache transformation | |||||
| /// \param[in] leaf_op The leaf operator involved in the cache transform | |||||
| /// \param[in] cache_op The cache operator involved in the cache transform | |||||
| void AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, std::shared_ptr<CacheOp> cache_op); | |||||
| bool is_caching_; | |||||
| std::shared_ptr<DatasetOp> leaf_op_; | |||||
| std::shared_ptr<Sampler> sampler_; | |||||
| // The two operators that work together to establish the cache transform | |||||
| std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs_; | |||||
| }; | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| CacheTransformPass(); | CacheTransformPass(); | ||||
| @@ -46,11 +163,6 @@ class CacheTransformPass : public TreePass { | |||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status RunOnTree(ExecutionTree *tree, bool *modified) override; | Status RunOnTree(ExecutionTree *tree, bool *modified) override; | ||||
| /// \brief Assigns the leaf and cache operators that are involved in a cache transformation | |||||
| /// \param[in] leaf_op The leaf operator involved in the cache transform | |||||
| /// \param[in] cache_op The cache operator involved in the cache transform | |||||
| void AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, std::shared_ptr<CacheOp> cache_op); | |||||
| private: | private: | ||||
| /// \brief Helper function to execute the cache transformation. | /// \brief Helper function to execute the cache transformation. | ||||
| /// | /// | ||||
| @@ -72,9 +184,6 @@ class CacheTransformPass : public TreePass { | |||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op, | Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op, | ||||
| std::shared_ptr<DatasetOp> cache_op, std::shared_ptr<CacheClient> cache_client); | std::shared_ptr<DatasetOp> cache_op, std::shared_ptr<CacheClient> cache_client); | ||||
| // The two operators that work together to establish the cache transform | |||||
| std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs_; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,7 +16,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "minddata/dataset/engine/opt/pre/injection_pass.h" | |||||
| #include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h" | |||||
| #include "minddata/dataset/engine/execution_tree.h" | #include "minddata/dataset/engine/execution_tree.h" | ||||
| #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/device_queue_op.h" | #include "minddata/dataset/engine/datasetops/device_queue_op.h" | ||||
| @@ -25,64 +25,55 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| // constructor | // constructor | ||||
| InjectionPass::InjectionFinder::InjectionFinder(InjectionPass *injection_pass) : injection_pass_(injection_pass) {} | |||||
| EpochInjectionPass::InjectionFinder::InjectionFinder(std::shared_ptr<DatasetOp> node) : injection_point_(node) {} | |||||
| // Performs finder work for BuildVocabOp that has special rules about epoch control injection | // Performs finder work for BuildVocabOp that has special rules about epoch control injection | ||||
| Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) { | |||||
| if (injection_pass_) { | |||||
| injection_pass_->epoch_ctrl_bypass_ = true; | |||||
| return Status::OK(); | |||||
| } else { | |||||
| RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!"); | |||||
| } | |||||
| Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) { | |||||
| injection_point_ = nullptr; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| // Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection | // Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection | ||||
| Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) { | |||||
| if (injection_pass_) { | |||||
| injection_pass_->epoch_ctrl_bypass_ = true; | |||||
| return Status::OK(); | |||||
| } else { | |||||
| RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!"); | |||||
| } | |||||
| Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, | |||||
| bool *modified) { | |||||
| injection_point_ = nullptr; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| // Temporary code to prevent the injection of epoch control when cache op is present | // Temporary code to prevent the injection of epoch control when cache op is present | ||||
| // Remove this code in cache op phase 2 | // Remove this code in cache op phase 2 | ||||
| Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||||
| if (injection_pass_) { | |||||
| injection_pass_->epoch_ctrl_bypass_ = true; | |||||
| return Status::OK(); | |||||
| } else { | |||||
| RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!"); | |||||
| } | |||||
| Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||||
| injection_point_ = nullptr; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status EpochInjectionPass::InjectionFinder::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) { | |||||
| // Assumption: There is only one DeviceQueueOp in a pipeline. This assumption is not validated here. | |||||
| injection_point_ = node->child(0); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| // constructor | // constructor | ||||
| InjectionPass::InjectionPass() : epoch_ctrl_bypass_(false) {} | |||||
| EpochInjectionPass::EpochInjectionPass() {} | |||||
| // Runs an injection pass to inject in operators needed at the pre pass stage | // Runs an injection pass to inject in operators needed at the pre pass stage | ||||
| Status InjectionPass::RunOnTree(ExecutionTree *tree, bool *modified) { | |||||
| Status EpochInjectionPass::RunOnTree(ExecutionTree *tree, bool *modified) { | |||||
| MS_LOG(INFO) << "Pre pass: Injection pass started."; | MS_LOG(INFO) << "Pre pass: Injection pass started."; | ||||
| // First, run the finder to perform any injection info before we can go ahead to drive the op injection work. | // First, run the finder to perform any injection info before we can go ahead to drive the op injection work. | ||||
| // The finder can make updates to the InjectionPass object. | |||||
| InjectionPass::InjectionFinder finder(this); | |||||
| finder.Run(tree, modified); | |||||
| // The finder can make updates to the EpochInjectionPass object. | |||||
| EpochInjectionPass::InjectionFinder finder(tree->root()); | |||||
| RETURN_IF_NOT_OK(finder.Run(tree, modified)); | |||||
| // The first injection logic is to check if we should inject the epoch control op as the root node. | // The first injection logic is to check if we should inject the epoch control op as the root node. | ||||
| // Do not inject the op if the number of epochs is 1. | // Do not inject the op if the number of epochs is 1. | ||||
| int32_t num_epochs = tree->num_epochs(); | int32_t num_epochs = tree->num_epochs(); | ||||
| if (num_epochs != 1 && !epoch_ctrl_bypass_) { | |||||
| std::shared_ptr<DatasetOp> epoch_inject_node = finder.injection_point(); | |||||
| if (num_epochs != 1 && epoch_inject_node != nullptr) { | |||||
| std::shared_ptr<EpochCtrlOp> epoch_ctrl_op; | std::shared_ptr<EpochCtrlOp> epoch_ctrl_op; | ||||
| RETURN_IF_NOT_OK(EpochCtrlOp::Builder(num_epochs).Build(&epoch_ctrl_op)); | RETURN_IF_NOT_OK(EpochCtrlOp::Builder(num_epochs).Build(&epoch_ctrl_op)); | ||||
| RETURN_IF_NOT_OK(tree->AssociateNode(epoch_ctrl_op)); | RETURN_IF_NOT_OK(tree->AssociateNode(epoch_ctrl_op)); | ||||
| std::shared_ptr<DatasetOp> node = tree->root(); | |||||
| if (std::dynamic_pointer_cast<DeviceQueueOp>(node) == nullptr) { | |||||
| tree->root()->InsertAsParent(epoch_ctrl_op); | |||||
| } else { | |||||
| tree->root()->child(0)->InsertAsParent(epoch_ctrl_op); | |||||
| } | |||||
| epoch_inject_node->InsertAsParent(epoch_ctrl_op); | |||||
| } | } | ||||
| MS_LOG(INFO) << "Pre pass: Injection pass complete."; | MS_LOG(INFO) << "Pre pass: Injection pass complete."; | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_ | |||||
| #define DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_ | |||||
| #ifndef DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ | |||||
| #define DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| @@ -26,10 +26,10 @@ namespace dataset { | |||||
| class DatasetOp; | class DatasetOp; | ||||
| /// \class InjectionPass injection_pass.h | |||||
| /// \class EpochInjectionPass epoch_injection_pass.h | |||||
| /// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api | /// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api | ||||
| /// parsing. | /// parsing. | ||||
| class InjectionPass : public TreePass { | |||||
| class EpochInjectionPass : public TreePass { | |||||
| /// \class InjectionFinder | /// \class InjectionFinder | ||||
| /// \brief This is a nested node pass class who's job is to parse the tree and perform any identification logic for | /// \brief This is a nested node pass class who's job is to parse the tree and perform any identification logic for | ||||
| /// operators that need to be injected. It is run first by the main injection pass to find out what operators | /// operators that need to be injected. It is run first by the main injection pass to find out what operators | ||||
| @@ -37,7 +37,10 @@ class InjectionPass : public TreePass { | |||||
| class InjectionFinder : public NodePass { | class InjectionFinder : public NodePass { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| explicit InjectionFinder(InjectionPass *injection_pass); | |||||
| explicit InjectionFinder(std::shared_ptr<DatasetOp> node); | |||||
| /// \brief Destructor | |||||
| ~InjectionFinder() = default; | |||||
| /// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection. | /// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection. | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| @@ -58,24 +61,30 @@ class InjectionPass : public TreePass { | |||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | ||||
| /// \brief Register the DeviceQueueOp for further action. | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override; | |||||
| /// \brief Getter | |||||
| std::shared_ptr<DatasetOp> injection_point() { return injection_point_; } | |||||
| private: | private: | ||||
| InjectionPass *injection_pass_; | |||||
| std::shared_ptr<DatasetOp> injection_point_; | |||||
| }; | }; | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| InjectionPass(); | |||||
| EpochInjectionPass(); | |||||
| /// \brief Runs an injection pass to inject in operators needed at the pre pass stage | /// \brief Runs an injection pass to inject in operators needed at the pre pass stage | ||||
| /// \param[inout] tree The tree to operate on. | /// \param[inout] tree The tree to operate on. | ||||
| /// \param[inout] Indicate of the tree was modified. | /// \param[inout] Indicate of the tree was modified. | ||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status RunOnTree(ExecutionTree *tree, bool *modified) override; | Status RunOnTree(ExecutionTree *tree, bool *modified) override; | ||||
| private: | |||||
| bool epoch_ctrl_bypass_; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_ | |||||
| #endif // DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ | |||||
| @@ -1,58 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <memory> | |||||
| #include "minddata/dataset/engine/opt/pre/removal_nodes.h" | |||||
| #include "minddata/dataset/engine/opt/pre/removal_pass.h" | |||||
| #include "minddata/dataset/engine/datasetops/shuffle_op.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {} | |||||
| // Identifies the subtree below this node as a cached descendant tree. | |||||
| Status RemovalNodes::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree."; | |||||
| is_caching_ = true; | |||||
| return Status::OK(); | |||||
| } | |||||
| // Resets the tracking of the cache within the tree | |||||
| Status RemovalNodes::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| MS_LOG(INFO) << "Removal pass: cache descendant tree complete."; | |||||
| is_caching_ = false; | |||||
| return Status::OK(); | |||||
| } | |||||
| // Perform ShuffleOp removal check. | |||||
| Status RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| // If we are in a cache descendant tree, then this shuffle op needs to be removed | |||||
| if (is_caching_) { | |||||
| MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)"; | |||||
| if (removal_pass_) { | |||||
| removal_pass_->AddToRemovalList(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } else { | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Back reference to removal pass is missing!"); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -1,64 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_ | |||||
| #include <memory> | |||||
| #include "minddata/dataset/engine/opt/pass.h" | |||||
| #include "minddata/dataset/engine/opt/pre/removal_pass.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| /// \class RemovalNodes removal_nodes.h | |||||
| /// \brief This is a NodePass who's job is to identify which nodes should be removed. | |||||
| /// It works in conjunction with the removal_pass. | |||||
| class RemovalNodes : public NodePass { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| /// \param[in] removal_pass Raw pointer back to controlling tree pass | |||||
| explicit RemovalNodes(RemovalPass *removal_pass); | |||||
| /// \brief Identifies the subtree below this node as a cached descendant tree. | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||||
| /// \brief Resets the tracking of the cache within the tree | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||||
| /// \brief Destructor | |||||
| ~RemovalNodes() = default; | |||||
| /// \brief Perform ShuffleOp removal check | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override; | |||||
| private: | |||||
| bool is_caching_; | |||||
| RemovalPass *removal_pass_; // Back pointer to the owning removal pass | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_ | |||||
| @@ -16,32 +16,58 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "minddata/dataset/engine/opt/pre/removal_nodes.h" | |||||
| #include "minddata/dataset/engine/opt/pre/removal_pass.h" | #include "minddata/dataset/engine/opt/pre/removal_pass.h" | ||||
| #include "minddata/dataset/engine/datasetops/shuffle_op.h" | |||||
| #include "minddata/dataset/engine/execution_tree.h" | #include "minddata/dataset/engine/execution_tree.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| RemovalPass::RemovalNodes::RemovalNodes() : is_caching_(false) {} | |||||
| // Identifies the subtree below this node as a cached descendant tree. | |||||
| Status RemovalPass::RemovalNodes::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree."; | |||||
| is_caching_ = true; | |||||
| return Status::OK(); | |||||
| } | |||||
| // Resets the tracking of the cache within the tree | |||||
| Status RemovalPass::RemovalNodes::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| MS_LOG(INFO) << "Removal pass: cache descendant tree complete."; | |||||
| is_caching_ = false; | |||||
| return Status::OK(); | |||||
| } | |||||
| // Perform ShuffleOp removal check. | |||||
| Status RemovalPass::RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| // If we are in a cache descendant tree, then this shuffle op needs to be removed | |||||
| if (is_caching_) { | |||||
| MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)"; | |||||
| nodes_to_remove_.push_back(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // constructor | // constructor | ||||
| RemovalPass::RemovalPass() {} | RemovalPass::RemovalPass() {} | ||||
| // Runs a removal_nodes pass first to find out which nodes to remove, then removes them. | |||||
| // Walk the tree to collect the nodes to remove, then removes them. | |||||
| Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) { | Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) { | ||||
| MS_LOG(INFO) << "Pre pass: removal pass started."; | MS_LOG(INFO) << "Pre pass: removal pass started."; | ||||
| // Create the removal node pass which can identify which nodes need to be removed. | // Create the removal node pass which can identify which nodes need to be removed. | ||||
| std::unique_ptr<Pass> removal_nodes = std::make_unique<RemovalNodes>(this); | |||||
| std::unique_ptr<RemovalPass::RemovalNodes> removal_nodes = std::make_unique<RemovalPass::RemovalNodes>(); | |||||
| RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified)); | RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified)); | ||||
| // Then, execute the removal of any nodes that were set up for removal | // Then, execute the removal of any nodes that were set up for removal | ||||
| for (auto node : removal_nodes_) { | |||||
| for (auto node : removal_nodes->nodes_to_remove()) { | |||||
| node->Remove(); | node->Remove(); | ||||
| } | } | ||||
| MS_LOG(INFO) << "Pre pass: removal pass complete."; | MS_LOG(INFO) << "Pre pass: removal pass complete."; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Adds an operator to the list of operators to be removed | |||||
| void RemovalPass::AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op) { removal_nodes_.push_back(dataset_op); } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -30,6 +30,45 @@ class DatasetOp; | |||||
| /// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which | /// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which | ||||
| /// nodes should be removed, and then removes them. | /// nodes should be removed, and then removes them. | ||||
| class RemovalPass : public TreePass { | class RemovalPass : public TreePass { | ||||
| /// \class RemovalNodes | |||||
| /// \brief This is a NodePass who's job is to identify which nodes should be removed. | |||||
| /// It works in conjunction with the removal_pass. | |||||
| class RemovalNodes : public NodePass { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| /// \param[in] removal_pass Raw pointer back to controlling tree pass | |||||
| RemovalNodes(); | |||||
| /// \brief Destructor | |||||
| ~RemovalNodes() = default; | |||||
| /// \brief Identifies the subtree below this node as a cached descendant tree. | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||||
| /// \brief Resets the tracking of the cache within the tree | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||||
| /// \brief Perform ShuffleOp removal check | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override; | |||||
| /// \brief Getter | |||||
| /// \return All the nodes to be removed | |||||
| std::vector<std::shared_ptr<DatasetOp>> nodes_to_remove() { return nodes_to_remove_; } | |||||
| private: | |||||
| bool is_caching_; | |||||
| std::vector<std::shared_ptr<DatasetOp>> nodes_to_remove_; | |||||
| }; | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| RemovalPass(); | RemovalPass(); | ||||
| @@ -42,13 +81,6 @@ class RemovalPass : public TreePass { | |||||
| /// \param[inout] Indicate of the tree was modified. | /// \param[inout] Indicate of the tree was modified. | ||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status RunOnTree(ExecutionTree *tree, bool *modified) override; | Status RunOnTree(ExecutionTree *tree, bool *modified) override; | ||||
| /// \brief Adds an operator to the list of operators to be removed | |||||
| /// \param[in] dataset_op The operator to add to the removal list | |||||
| void AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op); | |||||
| private: | |||||
| std::vector<std::shared_ptr<DatasetOp>> removal_nodes_; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -189,7 +189,7 @@ def test_minddataset_invalidate_num_shards(): | |||||
| num_iter = 0 | num_iter = 0 | ||||
| for _ in data_set.create_dict_iterator(): | for _ in data_set.create_dict_iterator(): | ||||
| num_iter += 1 | num_iter += 1 | ||||
| assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info) | |||||
| assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info.value) | |||||
| os.remove(CV_FILE_NAME) | os.remove(CV_FILE_NAME) | ||||
| os.remove("{}.db".format(CV_FILE_NAME)) | os.remove("{}.db".format(CV_FILE_NAME)) | ||||
| @@ -203,7 +203,7 @@ def test_minddataset_invalidate_shard_id(): | |||||
| num_iter = 0 | num_iter = 0 | ||||
| for _ in data_set.create_dict_iterator(): | for _ in data_set.create_dict_iterator(): | ||||
| num_iter += 1 | num_iter += 1 | ||||
| assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info) | |||||
| assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info.value) | |||||
| os.remove(CV_FILE_NAME) | os.remove(CV_FILE_NAME) | ||||
| os.remove("{}.db".format(CV_FILE_NAME)) | os.remove("{}.db".format(CV_FILE_NAME)) | ||||
| @@ -217,14 +217,14 @@ def test_minddataset_shard_id_bigger_than_num_shard(): | |||||
| num_iter = 0 | num_iter = 0 | ||||
| for _ in data_set.create_dict_iterator(): | for _ in data_set.create_dict_iterator(): | ||||
| num_iter += 1 | num_iter += 1 | ||||
| assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info) | |||||
| assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info.value) | |||||
| with pytest.raises(Exception) as error_info: | with pytest.raises(Exception) as error_info: | ||||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5) | data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5) | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for _ in data_set.create_dict_iterator(): | for _ in data_set.create_dict_iterator(): | ||||
| num_iter += 1 | num_iter += 1 | ||||
| assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info) | |||||
| assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info.value) | |||||
| os.remove(CV_FILE_NAME) | os.remove(CV_FILE_NAME) | ||||
| os.remove("{}.db".format(CV_FILE_NAME)) | os.remove("{}.db".format(CV_FILE_NAME)) | ||||
| @@ -245,7 +245,7 @@ def test_cv_minddataset_partition_num_samples_equals_0(): | |||||
| num_iter += 1 | num_iter += 1 | ||||
| with pytest.raises(Exception) as error_info: | with pytest.raises(Exception) as error_info: | ||||
| partitions(5) | partitions(5) | ||||
| assert 'num_samples should be a positive integer value, but got num_samples=0' in str(error_info) | |||||
| assert 'num_samples should be a positive integer value, but got num_samples=0' in str(error_info.value) | |||||
| os.remove(CV_FILE_NAME) | os.remove(CV_FILE_NAME) | ||||
| os.remove("{}.db".format(CV_FILE_NAME)) | os.remove("{}.db".format(CV_FILE_NAME)) | ||||