From: @ziruiwu
Tag: v1.2.0-rc1
@@ -227,6 +227,8 @@ Status PreBuiltOperation::ValidateParams() { return Status::OK(); }
 std::shared_ptr<TensorOp> PreBuiltOperation::Build() { return op_; }

+std::string PreBuiltOperation::Name() const { return op_ ? op_->Name() : kPreBuiltOperation; }
+
 // RandomApplyOperation
 RandomApplyOperation::RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob)
     : TensorOperation(true), transforms_(transforms), prob_(prob) {}
@@ -1264,72 +1264,7 @@ std::shared_ptr<TensorOp> RandomCropOperation::Build() {
 RandomCropDecodeResizeOperation::RandomCropDecodeResizeOperation(std::vector<int32_t> size, std::vector<float> scale,
                                                                  std::vector<float> ratio,
                                                                  InterpolationMode interpolation, int32_t max_attempts)
-    : TensorOperation(true),
-      size_(size),
-      scale_(scale),
-      ratio_(ratio),
-      interpolation_(interpolation),
-      max_attempts_(max_attempts) {}
-
-Status RandomCropDecodeResizeOperation::ValidateParams() {
-  // size
-  if (size_.empty() || size_.size() > 2) {
-    std::string err_msg = "RandomCropDecodeResize: size vector has incorrect size: " + std::to_string(size_.size());
-    MS_LOG(ERROR) << err_msg;
-    RETURN_STATUS_SYNTAX_ERROR(err_msg);
-  }
-  RETURN_IF_NOT_OK(ValidateVectorPositive("RandomCropDecodeResize", size_));
-  // rescale
-  if (scale_.empty() || scale_.size() != 2) {
-    std::string err_msg = "RandomCropDecodeResize: scale vector has incorrect size: " + std::to_string(scale_.size());
-    MS_LOG(ERROR) << err_msg;
-    RETURN_STATUS_SYNTAX_ERROR(err_msg);
-  }
-  if (scale_[0] < 0) {
-    std::string err_msg = "RandomCropDecodeResize: invalid scale, min scale must be greater than or equal to 0, got: " +
-                          std::to_string(scale_[0]);
-    MS_LOG(ERROR) << err_msg;
-    RETURN_STATUS_SYNTAX_ERROR(err_msg);
-  }
-  if (scale_[1] <= 0) {
-    std::string err_msg =
-      "RandomCropDecodeResize: invalid scale, max scale must be greater than 0, got: " + std::to_string(scale_[1]);
-    MS_LOG(ERROR) << err_msg;
-    RETURN_STATUS_SYNTAX_ERROR(err_msg);
-  }
-  if (scale_[0] > scale_[1]) {
-    std::string err_msg = "RandomCropDecodeResize: scale should be in (min,max) format. Got (max,min).";
-    MS_LOG(ERROR) << err_msg;
-    RETURN_STATUS_SYNTAX_ERROR(err_msg);
-  }
-  // ratio
-  if (ratio_.empty() || ratio_.size() != 2) {
-    std::string err_msg = "RandomCropDecodeResize: ratio vector has incorrect size: " + std::to_string(ratio_.size());
-    MS_LOG(ERROR) << err_msg;
-    RETURN_STATUS_SYNTAX_ERROR(err_msg);
-  }
-  for (int32_t i = 0; i < ratio_.size(); ++i) {
-    if (ratio_[i] <= 0) {
-      std::string err_msg =
-        "RandomCropDecodeResize: invalid ratio, ratio must be greater than 0, got: " + std::to_string(ratio_[i]);
-      MS_LOG(ERROR) << err_msg;
-      RETURN_STATUS_SYNTAX_ERROR(err_msg);
-    }
-  }
-  if (ratio_[0] > ratio_[1]) {
-    std::string err_msg = "RandomCropDecodeResize: ratio should be in (min,max) format. Got (max,min).";
-    MS_LOG(ERROR) << err_msg;
-    RETURN_STATUS_SYNTAX_ERROR(err_msg);
-  }
-  // max_attempts
-  if (max_attempts_ < 1) {
-    std::string err_msg =
-      "RandomCropDecodeResize: max_attempts must be greater than or equal to 1, got: " + std::to_string(max_attempts_);
-    MS_LOG(ERROR) << err_msg;
-    RETURN_STATUS_SYNTAX_ERROR(err_msg);
-  }
-  return Status::OK();
-}
+    : RandomResizedCropOperation(size, scale, ratio, interpolation, max_attempts) {}

 std::shared_ptr<TensorOp> RandomCropDecodeResizeOperation::Build() {
   int32_t crop_height = size_[0];
@@ -1352,6 +1287,9 @@ std::shared_ptr<TensorOp> RandomCropDecodeResizeOperation::Build() {
   return tensor_op;
 }

+RandomCropDecodeResizeOperation::RandomCropDecodeResizeOperation(const RandomResizedCropOperation &base)
+    : RandomResizedCropOperation(base) {}
+
 // RandomCropWithBBoxOperation
 RandomCropWithBBoxOperation::RandomCropWithBBoxOperation(std::vector<int32_t> size, std::vector<int32_t> padding,
                                                          bool pad_if_needed, std::vector<uint8_t> fill_value,
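Reviewer note: the two hunks above replace RandomCropDecodeResizeOperation's duplicated members and validation with a delegating constructor plus a copy-from-base constructor, so the fusion pass can promote an already-validated RandomResizedCropOperation into the fused op. A minimal standalone sketch of that pattern (the Model class names here are invented stand-ins, not MindSpore code):

#include <iostream>
#include <utility>
#include <vector>

struct RandomResizedCropModel {  // stand-in for RandomResizedCropOperation
  explicit RandomResizedCropModel(std::vector<int> size) : size_(std::move(size)) {}
  virtual ~RandomResizedCropModel() = default;
  virtual const char *Name() const { return "RandomResizedCrop"; }

 protected:
  std::vector<int> size_;  // protected so derived operations can reuse it
};

struct RandomCropDecodeResizeModel : RandomResizedCropModel {
  // normal construction delegates to the base, as in the hunk above
  explicit RandomCropDecodeResizeModel(std::vector<int> size) : RandomResizedCropModel(std::move(size)) {}
  // fusion path: adopt the parameters of an existing base instance
  explicit RandomCropDecodeResizeModel(const RandomResizedCropModel &base) : RandomResizedCropModel(base) {}
  const char *Name() const override { return "RandomCropDecodeResize"; }
};

int main() {
  RandomResizedCropModel crop({224, 224});
  RandomCropDecodeResizeModel fused(crop);  // what the fusion pass effectively does
  std::cout << fused.Name() << "\n";        // prints RandomCropDecodeResize
}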
@@ -1574,62 +1512,56 @@ RandomResizedCropOperation::RandomResizedCropOperation(std::vector<int32_t> size
 Status RandomResizedCropOperation::ValidateParams() {
   // size
   if (size_.size() != 2 && size_.size() != 1) {
-    std::string err_msg =
-      "RandomResizedCrop: size must be a vector of one or two values, got: " + std::to_string(size_.size());
+    std::string err_msg = Name() + ": size must be a vector of one or two values, got: " + std::to_string(size_.size());
     MS_LOG(ERROR) << err_msg;
     RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
   if (size_[0] <= 0 || (size_.size() == 2 && size_[1] <= 0)) {
-    std::string err_msg = "RandomResizedCrop: size must only contain positive integers.";
-    MS_LOG(ERROR) << "RandomResizedCrop: size must only contain positive integers, got: " << size_;
+    std::string err_msg = Name() + ": size must only contain positive integers.";
+    MS_LOG(ERROR) << Name() + ": size must only contain positive integers, got: " << size_;
     RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
   // scale
   if (scale_.size() != 2) {
-    std::string err_msg =
-      "RandomResizedCrop: scale must be a vector of two values, got: " + std::to_string(scale_.size());
+    std::string err_msg = Name() + ": scale must be a vector of two values, got: " + std::to_string(scale_.size());
     MS_LOG(ERROR) << err_msg;
     RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
   if (scale_[0] < 0) {
-    std::string err_msg = "RandomResizedCrop: min scale must be greater than or equal to 0.";
-    MS_LOG(ERROR) << "RandomResizedCrop: min scale must be greater than or equal to 0, got: " +
-                       std::to_string(scale_[0]);
+    std::string err_msg = Name() + ": min scale must be greater than or equal to 0.";
+    MS_LOG(ERROR) << Name() + ": min scale must be greater than or equal to 0, got: " + std::to_string(scale_[0]);
     RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
   if (scale_[1] <= 0) {
-    std::string err_msg = "RandomResizedCrop: max scale must be greater than 0.";
-    MS_LOG(ERROR) << "RandomResizedCrop: max scale must be greater than 0, got: " + std::to_string(scale_[1]);
+    std::string err_msg = Name() + ": max scale must be greater than 0.";
+    MS_LOG(ERROR) << Name() + ": max scale must be greater than 0, got: " + std::to_string(scale_[1]);
     RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
   if (scale_[1] < scale_[0]) {
-    std::string err_msg = "RandomResizedCrop: scale must have a size of two in the format of (min, max).";
-    MS_LOG(ERROR) << "RandomResizedCrop: scale must have a size of two in the format of (min, max), but got: "
-                  << scale_;
+    std::string err_msg = Name() + ": scale must have a size of two in the format of (min, max).";
+    MS_LOG(ERROR) << Name() + ": scale must have a size of two in the format of (min, max), but got: " << scale_;
     RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
   // ratio
   if (ratio_.size() != 2) {
-    std::string err_msg =
-      "RandomResizedCrop: ratio must be a vector of two values, got: " + std::to_string(ratio_.size());
+    std::string err_msg = Name() + ": ratio must be a vector of two values, got: " + std::to_string(ratio_.size());
     MS_LOG(ERROR) << err_msg;
     RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
   if (ratio_[0] <= 0 || ratio_[1] <= 0) {
-    std::string err_msg = "RandomResizedCrop: ratio must be greater than 0.";
-    MS_LOG(ERROR) << "RandomResizedCrop: ratio must be greater than 0, got: " << ratio_;
+    std::string err_msg = Name() + ": ratio must be greater than 0.";
+    MS_LOG(ERROR) << Name() + ": ratio must be greater than 0, got: " << ratio_;
     RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
   if (ratio_[1] < ratio_[0]) {
-    std::string err_msg = "RandomResizedCrop: ratio must have a size of two in the format of (min, max).";
-    MS_LOG(ERROR) << "RandomResizedCrop: ratio must have a size of two in the format of (min, max), but got: "
-                  << ratio_;
+    std::string err_msg = Name() + ": ratio must have a size of two in the format of (min, max).";
+    MS_LOG(ERROR) << Name() + ": ratio must have a size of two in the format of (min, max), but got: " << ratio_;
     RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
   // max_attempts
   if (max_attempts_ < 1) {
     std::string err_msg =
-      "RandomResizedCrop: max_attempts must be greater than or equal to 1, got: " + std::to_string(max_attempts_);
+      Name() + ": max_attempts must be greater than or equal to 1, got: " + std::to_string(max_attempts_);
     MS_LOG(ERROR) << err_msg;
     RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
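Reviewer note: switching the hard-coded "RandomResizedCrop:" prefix to Name() means the single inherited ValidateParams() reports the dynamic type, so RandomCropDecodeResizeOperation gets correctly labeled errors for free. A standalone sketch of the mechanism (simplified, invented class names):

#include <iostream>
#include <string>

struct BaseOp {
  virtual ~BaseOp() = default;
  virtual std::string Name() const { return "RandomResizedCrop"; }
  // one copy of the validation logic; the message follows the dynamic type
  std::string Validate(int max_attempts) const {
    if (max_attempts < 1) {
      return Name() + ": max_attempts must be greater than or equal to 1, got: " + std::to_string(max_attempts);
    }
    return "OK";
  }
};

struct DerivedOp : BaseOp {
  std::string Name() const override { return "RandomCropDecodeResize"; }
};

int main() {
  DerivedOp op;
  std::cout << op.Validate(0) << "\n";  // prints "RandomCropDecodeResize: max_attempts must be ..."
}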
@@ -515,17 +515,6 @@ Status TreeGetters::GetClassIndexing(std::vector<std::pair<std::string, std::vec
   return Status::OK();
 }

-Status TreeGetters::InternalInit(int8_t type) {
-  if (init_flag_) return Status::OK();
-  tree_adapter_->SetPrePassOverride([&type](OptPass pre) {
-    pre.push_back(std::make_unique<GetterPass>(static_cast<GetterPass::GetterType>(type)));
-    return pre;
-  });
-  Status s = tree_adapter_->Compile(std::move(root_), 1);
-  if (s.IsOk()) init_flag_ = true;
-  return s;
-}
-
 Status TreeGetters::InternalInit() {
   if (init_flag_) return Status::OK();
   Status s = tree_adapter_->Compile(std::move(root_), 1);
@@ -535,7 +524,7 @@ Status TreeGetters::InternalInit() {
 Status TreeGetters::GetFirstRowShapeAndType() {
   RETURN_OK_IF_TRUE(first_row_obtained_);
-  RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType)));
+  RETURN_IF_NOT_OK(InternalInit());
   TensorRow first_row;
   RETURN_IF_NOT_OK(GetRow(&first_row));
   std::transform(first_row.begin(), first_row.end(), std::back_inserter(first_row_type_),
@@ -572,11 +561,6 @@ Status DatasetSizeGetter::Init(std::shared_ptr<DatasetNode> d) {
 Status DatasetSizeGetter::DryRun(std::shared_ptr<DatasetNode> ir_node, int64_t *dataset_size) {
   std::shared_ptr<TreeAdapter> tree_adapter = std::make_shared<TreeAdapter>(TreeAdapter::UsageFlag::kDeGetter);
   tree_adapters_.push_back(tree_adapter);
-  tree_adapter->SetPrePassOverride([](OptPass pre) {
-    pre.push_back(
-      std::make_unique<GetterPass>(static_cast<GetterPass::GetterType>(GetterPass::GetterType::kDatasetSize)));
-    return pre;
-  });
   RETURN_IF_NOT_OK(tree_adapter->Compile(ir_node, 1));
   TensorRow row;
   RETURN_IF_NOT_OK(GetRow(tree_adapter, &row));
@@ -199,7 +199,6 @@ class TreeGetters : public TreeConsumer {
   bool first_row_obtained_;  // whether first row (which could be empty) is obtained by TreeGetter
   bool init_flag_;           // indicate whether the tree has initialized
-  Status InternalInit(int8_t type);
   Status InternalInit();
 };
@@ -40,12 +40,9 @@
 namespace mindspore {
 namespace dataset {
 // Constructor
-ExecutionTree::ExecutionTree() : id_count_(0), pre_pass_override_(nullptr) {
+ExecutionTree::ExecutionTree() : id_count_(0), tree_state_(kDeTStateInit), prepare_flags_(kDePrepNone) {
   tg_ = std::make_unique<TaskGroup>();
-  tree_state_ = kDeTStateInit;
-  prepare_flags_ = kDePrepNone;
   profiling_manager_ = std::make_unique<ProfilingManager>(this);
-  optimize_ = common::GetEnv("OPTIMIZE") == "true" ? true : false;
 #if defined(NUMA_ENABLED) && (defined(ENABLE_GPUQUE) || defined(ENABLE_TDTQUE))
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   rank_id_ = cfg->rank_id();
@@ -275,10 +272,6 @@ Status ExecutionTree::Prepare(int32_t num_epochs, bool partial) {
   // Pre optimization compulsory transformation
   RETURN_IF_NOT_OK(this->PreAction());

-  // If optional optimizations are enabled
-  if (optimize_) {
-    RETURN_IF_NOT_OK(this->Optimize());
-  }
-
   // Post optimization compulsory transformation
   RETURN_IF_NOT_OK(this->PostAction());
@@ -302,14 +295,6 @@ Status ExecutionTree::PreAction() {
     pre_actions.push_back(std::make_unique<RemovalPass>());
   }

-  // this offers a way to override the preset optimization pass with customized ones
-  // this is used when certain nodes are removed for tree getters
-  if (pre_pass_override_) {
-    MS_LOG(INFO) << "Default pre optimization passes is being overridden,"
-                 << " number of passes before the override:" << pre_actions.size() << ".";
-    pre_actions = pre_pass_override_(std::move(pre_actions));
-  }
-
   MS_LOG(INFO) << "Running " << pre_actions.size() << " pre pass loops.";

   // Apply pre action passes
@@ -343,22 +328,6 @@ Status ExecutionTree::PostAction() {
   return Status::OK();
 }

-Status ExecutionTree::Optimize() {
-  // Vector of optimizations, currently only 1, add more as necessary
-  OptPass optimizations;
-#ifndef ENABLE_ANDROID
-  optimizations.push_back(std::make_unique<TensorOpFusionPass>());
-#endif
-  // vector of flags for each optimization
-  std::vector<bool> modified(optimizations.size(), false);
-  for (auto i = 0; i < optimizations.size(); i++) {
-    auto m = false;
-    optimizations[i]->Run(this, &m);
-    modified[i] = m;
-  }
-  return Status::OK();
-}
-
 // The driver of the prepare phase of the execution tree. The prepare phase will recursively
 // walk the tree to perform modifications to the tree or specific nodes within the tree to get
 // it ready for execution.
@@ -192,10 +192,6 @@ class ExecutionTree {
   // @return Status The status code returned
   Status PostAction();

-  // Optimization transformation/action, optional.
-  // @return Status The status code returned
-  Status Optimize();
-
   // The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively
   // walk the tree to perform modifications to the tree or specific nodes within the tree to get
   // it ready for execution.
@@ -240,29 +236,10 @@ class ExecutionTree {
   // Getter for profiling manager, no ownership
   ProfilingManager *GetProfilingManager() { return profiling_manager_.get(); }

-  // Set optional optimization if tree has not been prepared yet
-  Status SetOptimize(bool value) {
-    if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding) {
-      std::string optimize = (optimize_ == true) ? "true" : "false";
-      std::string msg = "Tree has already been prepared with OPTIMIZE set to " + optimize;
-      RETURN_STATUS_UNEXPECTED(msg);
-    } else {
-      optimize_ = value;
-      return Status::OK();
-    }
-  }
-
-  // Optional optimizations status
-  bool OptimizationEnabled() const { return optimize_; }
-
   // Getter function to get the total number of epochs to be run on this tree.
   // @return total number of epochs
   int32_t num_epochs() { return num_epochs_; }

-  // set the function ptr that overrides the pre-pass which allows caller to adjust the existing pre_pass and
-  // introduce new passes. E.g. caller can override the num_epoch in EpochInjectionPass
-  void SetPrePassOverride(std::function<OptPass(OptPass)> pre_pass_override) { pre_pass_override_ = pre_pass_override; }
-
  private:
   // A helper function for doing the recursive printing
   // @param dataset_op - The dataset op to print
@@ -279,8 +256,6 @@ class ExecutionTree {
   TreeState tree_state_;                                 // Tracking the current tree state
   int32_t num_epochs_;                                   // Total number of epochs to run for this tree
   std::unique_ptr<ProfilingManager> profiling_manager_;  // Profiling manager
-  bool optimize_;                                        // Flag to enable optional optimizations
-  std::function<OptPass(OptPass)> pre_pass_override_;    // function ptr that overrides pre pass, called in PrePrepare()
   bool partially_prepare_;  // Temp: during migration to IR, if true, run remaining passes.
 #if defined(NUMA_ENABLED) && (defined(ENABLE_GPUQUE) || defined(ENABLE_TDTQUE))
   // This rank_id is for numa and device_queue, one process works with only one rank_id,
@@ -115,5 +115,10 @@ Status MapNode::AcceptAfter(IRNodePass *const p, bool *modified) {
   // Downcast shared pointer then call visitor
   return p->VisitAfter(shared_from_base<MapNode>(), modified);
 }

+void MapNode::setOperations(const std::vector<std::shared_ptr<TensorOperation>> &operations) {
+  operations_ = operations;
+}
+
+std::vector<std::shared_ptr<TensorOperation>> MapNode::operations() { return operations_; }
 }  // namespace dataset
 }  // namespace mindspore
@@ -74,8 +74,19 @@ class MapNode : public DatasetNode {
   /// \return Status of the node visit
   Status AcceptAfter(IRNodePass *p, bool *modified) override;

+  /// \brief clear all callbacks
+  void ClearCallbacks() { callbacks_.clear(); }
+
+  /// \brief getter to get all tensor operations
+  std::vector<std::shared_ptr<TensorOperation>> operations();
+
+  /// \brief setter to set all tensor operations
+  void setOperations(const std::vector<std::shared_ptr<TensorOperation>> &operations);
+
  private:
   std::vector<std::shared_ptr<TensorOperation>> operations_;
+
+ private:
   std::vector<std::string> input_columns_;
   std::vector<std::string> output_columns_;
   std::vector<std::string> project_columns_;
@@ -13,45 +13,53 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include <algorithm>
 #include <memory>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/engine/ir/datasetops/map_node.h"
 #include "minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
+#include "minddata/dataset/kernels/image/decode_op.h"
+#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
+#include "minddata/dataset/include/transforms.h"
+#include "minddata/dataset/include/vision.h"
+#include "minddata/dataset/include/vision_lite.h"
+#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h"
 #include "minddata/dataset/kernels/image/random_crop_decode_resize_op.h"

 namespace mindspore {
 namespace dataset {

-Status TensorOpFusionPass::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
-  // Most primitive pattern: DecodeOp immediately followed by RandomCropAndResizeOp
-  // Abstract into a more general member function that can find any pattern, expressed
-  // by regular expressions, for instance.
-  // Add a list of optimisation policies. For now, just this lambda
-  auto FindPattern = [](auto &tfuncs) {
-    auto it =
-      std::find_if(tfuncs.begin(), tfuncs.end(), [](const auto &tf) -> bool { return tf->Name() == kDecodeOp; });
-    auto next = it + 1;
-    if (it != tfuncs.end() && next != tfuncs.end() && (*next)->Name() == kRandomCropAndResizeOp) {
-      return it;
-    } else {
-      return tfuncs.end();
-    }
-  };
-
-  auto &tfuncs = node->TFuncs();
-  auto it = FindPattern(tfuncs);
-  if (it != tfuncs.end()) {
-    auto next = it + 1;
-    auto op = static_cast<RandomCropAndResizeOp *>(next->get());
-    *it = std::static_pointer_cast<TensorOp>(std::make_shared<RandomCropDecodeResizeOp>(*op));
-    tfuncs.erase(next);
-  }
-  if (modified != nullptr) {
+Status TensorOpFusionPass::Visit(std::shared_ptr<MapNode> node, bool *modified) {
+  std::vector<std::shared_ptr<TensorOperation>> ops = node->operations();
+
+  // start temporary code, to deal with pre-built TensorOperation
+  std::vector<std::string> pattern = {kDecodeOp, kRandomCropAndResizeOp};
+  auto itr = std::search(ops.begin(), ops.end(), pattern.begin(), pattern.end(),
+                         [](auto op, const std::string &nm) { return op->Name() == nm; });
+  if (itr != ops.end()) {
+    MS_LOG(WARNING) << "Fusing pre-build Decode and RandomCropResize into one pre-build.";
+    auto op = dynamic_cast<RandomCropAndResizeOp *>((*(itr + 1))->Build().get());
+    (*itr) = std::make_shared<transforms::PreBuiltOperation>(std::make_shared<RandomCropDecodeResizeOp>(*op));
+    ops.erase(itr + 1);
+    node->setOperations(ops);
     *modified = true;
-  } else {
-    RETURN_STATUS_UNEXPECTED("modified is nullptr");
-  }
+    return Status::OK();
+  }  // end of temporary code, needs to be deleted when tensorOperation's pybind completes
+
+  // logic below is for non-prebuilt TensorOperation
+  pattern = {vision::kDecodeOperation, vision::kRandomResizedCropOperation};
+  itr = std::search(ops.begin(), ops.end(), pattern.begin(), pattern.end(),
+                    [](auto op, const std::string &nm) { return op->Name() == nm; });
+
+  // return here if no pattern is found
+  RETURN_OK_IF_TRUE(itr == ops.end());
+  auto *op = dynamic_cast<vision::RandomResizedCropOperation *>((itr + 1)->get());
+  RETURN_UNEXPECTED_IF_NULL(op);
+  // fuse the two ops
+  (*itr) = std::make_shared<vision::RandomCropDecodeResizeOperation>(*op);
+  ops.erase(itr + 1);
+  node->setOperations(ops);
+  *modified = true;
   return Status::OK();
 }
 }  // namespace dataset
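Reviewer note: both branches of the new Visit() reduce to the same idea: find an adjacent [Decode, RandomResizedCrop] pair in the MapNode's operation list with std::search, overwrite the first element with the fused operation, and erase the second. A runnable sketch of just that matching-and-splicing step, using plain strings in place of TensorOperation names:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> ops = {"Rescale", "Decode", "RandomResizedCrop", "Normalize"};
  const std::vector<std::string> pattern = {"Decode", "RandomResizedCrop"};

  // std::search finds the first occurrence of the adjacent pair
  auto itr = std::search(ops.begin(), ops.end(), pattern.begin(), pattern.end());
  if (itr != ops.end()) {
    *itr = "RandomCropDecodeResize";  // replace the first op of the pair with the fused op
    ops.erase(itr + 1);               // drop the second op
  }

  for (const auto &op : ops) std::cout << op << " ";  // Rescale RandomCropDecodeResize Normalize
  std::cout << "\n";
}

In the real pass the comparator is a lambda comparing op->Name() to the pattern string, which is why forwarding PreBuiltOperation::Name() to the wrapped TensorOp (the first hunk of this diff) is what lets pre-built pipelines match the same pattern.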
@@ -25,12 +25,12 @@ namespace dataset {
 /// \class TensorOpFusionPass tensor_op_fusion_pass.h
 /// \brief An optional optimization pass identifying and fusing
 ///     tensor ops within MapOp
-class TensorOpFusionPass : public NodePass {
+class TensorOpFusionPass : public IRNodePass {
   /// \brief Identifies and fuses tensor ops within MapOp
   /// \param[in] node The node being visited
   /// \param[inout] *modified indicates whether the node has been visited
   /// \return Status The status code returned
-  Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
+  Status Visit(std::shared_ptr<MapNode> node, bool *modified) override;
 };
 }  // namespace dataset
 }  // namespace mindspore
@@ -15,52 +15,13 @@
  */
 #include "minddata/dataset/engine/opt/pre/getter_pass.h"
-#include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/engine/ir/datasetops/map_node.h"

 namespace mindspore {
 namespace dataset {

-Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
-  nodes_to_remove_.push_back(node);
-  return Status::OK();
-}
-
-Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
-  if (type_ == kOutputShapeAndType) nodes_to_remove_.push_back(node);
-  return Status::OK();
-}
-
-Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) {
-  if (type_ == kOutputShapeAndType) nodes_to_remove_.push_back(node);
-  return Status::OK();
-}
-
-Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) {
-  if (type_ == kOutputShapeAndType) nodes_to_remove_.push_back(node);
-  return Status::OK();
-}
-
-Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
-  nodes_to_clear_callback_.push_back(node);
-  return Status::OK();
-}
-
-#ifdef ENABLE_PYTHON
-Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
-  if (type_ == kOutputShapeAndType) nodes_to_remove_.push_back(node);
+Status GetterPass::Visit(std::shared_ptr<MapNode> node, bool *modified) {
+  node->ClearCallbacks();
   return Status::OK();
 }
-#endif
-
-Status GetterPass::RunOnTree(ExecutionTree *tree, bool *modified) {
-  RETURN_IF_NOT_OK(pass_.Run(tree, modified));
-
-  // currently the getter pass only disables call_back from the execution tree
-  // clear the callback for selected ops (map when its GetOutputType/Shape)
-  for (auto node : pass_.nodes_to_clear_callback_) node->ClearCallbacks();
-  return Status::OK();
-}
 }  // namespace dataset
 }  // namespace mindspore
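Reviewer note: GetterPass collapses from a TreePass plus a nested node-collecting NodePass into a single IR visitor whose only job is clearing MapNode callbacks (node removal is now NodeRemovalPass's concern). A toy model of the visitor dispatch involved, illustrative only and not the real IRNodePass machinery:

#include <iostream>
#include <vector>

struct MapNodeModel;  // forward declaration for the visitor interface

struct PassModel {  // stand-in for IRNodePass
  virtual ~PassModel() = default;
  virtual void Visit(MapNodeModel &node) {}  // default: leave the node alone
};

struct MapNodeModel {  // stand-in for MapNode
  std::vector<int> callbacks{1, 2, 3};
  void ClearCallbacks() { callbacks.clear(); }
  void Accept(PassModel &p) { p.Visit(*this); }  // node dispatches to the pass
};

struct GetterPassModel : PassModel {
  // getters must not fire user callbacks, so the only remaining job is to clear them
  void Visit(MapNodeModel &node) override { node.ClearCallbacks(); }
};

int main() {
  MapNodeModel map_node;
  GetterPassModel pass;
  map_node.Accept(pass);
  std::cout << map_node.callbacks.size() << "\n";  // prints 0
}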
@@ -19,7 +19,6 @@
 #include <memory>
 #include <list>
-#include "minddata/dataset/engine/datasetops/dataset_op.h"
 #include "minddata/dataset/engine/opt/pass.h"

 namespace mindspore {
@@ -28,48 +27,16 @@ namespace dataset {
 class DatasetOp;

 /// \class GetterPass
-/// \brief This is a tree pass that will remove nodes or clears the callback in MapOp
-class GetterPass : public TreePass {
+/// \brief This is a tree pass that will for now only clear the callback in MapOp to prevent hang
+class GetterPass : public IRNodePass {
  public:
-  enum GetterType { kDatasetSize = 1, kOutputShapeAndType = 2 };
-  /// \brief Constructor
-  explicit GetterPass(GetterType tp) : pass_(tp) {}
+  /// \brief Default Constructor
+  GetterPass() = default;

-  /// \brief default copy Constructor
-  explicit GetterPass(const GetterPass &) = default;
-
-  /// \brief Destructor
+  /// \brief Default Destructor
   ~GetterPass() = default;

-  Status RunOnTree(ExecutionTree *tree, bool *modified) override;
-
- private:
-  /// \class GetterNodes, this is a nested class which is owned via composition by the outter class to identify nodes
-  /// \brief This is a NodePass who's job is to identify which nodes should be removed.
-  class GetterNodes : public NodePass {
-   public:
-    /// \brief Constructor
-    explicit GetterNodes(GetterType tp) : type_(tp) {}
-
-    ~GetterNodes() = default;
-
-    Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;
-
-    Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
-    Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override { return Status::OK(); }
-    Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override;
-    Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override;
-    Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
-#ifdef ENABLE_PYTHON
-    Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override;
-#endif
-    GetterType type_;
-    std::list<std::shared_ptr<DatasetOp>> nodes_to_clear_callback_;
-    std::list<std::shared_ptr<DatasetOp>> nodes_to_remove_;
-  };
-
-  // outer class needs only to own the inner class object since it automatically has access to its private variables
-  GetterNodes pass_;
+  Status Visit(std::shared_ptr<MapNode> node, bool *modified) override;
 };
 }  // namespace dataset
 }  // namespace mindspore
@@ -27,7 +27,7 @@ namespace dataset {
 /// \class InputValidationPass
 /// \brief This is a parse pass that validates input parameters of the IR tree.
 class InputValidationPass : public IRNodePass {
-  /// \brief Runs a validatation pass to check input parameters
+  /// \brief Runs a validation pass to check input parameters
   /// \param[in] node The node being visited
   /// \param[inout] *modified indicates whether the node has been visited
   /// \return Status code
@@ -18,11 +18,13 @@
 #include "minddata/dataset/core/client.h"
 #include "minddata/dataset/engine/ir/datasetops/root_node.h"
+#include "minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
 #include "minddata/dataset/engine/opt/pass.h"
 #include "minddata/dataset/engine/opt/post/auto_worker_pass.h"
 #include "minddata/dataset/engine/opt/pre/cache_validation_pass.h"
 #include "minddata/dataset/engine/opt/pre/deep_copy_pass.h"
 #include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h"
+#include "minddata/dataset/engine/opt/pre/getter_pass.h"
 #include "minddata/dataset/engine/opt/pre/input_validation_pass.h"
 #include "minddata/dataset/engine/opt/pre/node_removal_pass.h"
@@ -43,11 +45,11 @@ Status TreeAdapter::PrePass(std::shared_ptr<DatasetNode> ir) {
   std::vector<std::unique_ptr<IRPass>> actions;

   MS_LOG(INFO) << "Running pre pass loops.";
-  actions.push_back(std::make_unique<InputValidationPass>());
-  actions.push_back(std::make_unique<CacheValidationPass>());
-  actions.push_back(std::make_unique<NodeRemovalPass>());
-  actions.push_back(std::make_unique<EpochCtrlPass>());
+  actions.emplace_back(std::make_unique<InputValidationPass>());
+  actions.emplace_back(std::make_unique<CacheValidationPass>());
+  actions.emplace_back(std::make_unique<NodeRemovalPass>());
+  actions.emplace_back(std::make_unique<EpochCtrlPass>());
+  if (usage_ == kDeGetter) actions.emplace_back(std::make_unique<GetterPass>());
   // Vector of flags for each action
   std::vector<bool> modified(actions.size(), false);
   // Apply pre-pass actions
@@ -64,16 +66,11 @@ Status TreeAdapter::Optimize(std::shared_ptr<DatasetNode> ir) {
   // Vector of optimizations
   std::vector<std::unique_ptr<IRNodePass>> optimizations;
   MS_LOG(INFO) << "Running optimization pass loops";
-  // We will gradually move TensorOpFusionPass from ExecutionTree::Optimize to here.
-
-  // Vector of flags for each optimization
-  std::vector<bool> modified(optimizations.size(), false);
+  optimizations.emplace_back(std::make_unique<TensorOpFusionPass>());
   // Apply optimization pass actions
   for (auto i = 0; i < optimizations.size(); i++) {
-    auto m = false;
-    RETURN_IF_NOT_OK(optimizations[i]->Run(ir, &m));
-    modified[i] = m;
+    bool modified = false;
+    RETURN_IF_NOT_OK(optimizations[i]->Run(ir, &modified));
   }
   MS_LOG(INFO) << "Optimization pass complete.";
   return Status::OK();
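Reviewer note: the loop now uses a per-iteration local flag rather than a flags vector sized up front. One reason to prefer the local, assuming that motivated this change: a vector sized from optimizations.size() before any pass is registered is silently empty. A tiny illustration of that sizing trap (not MindSpore code):

#include <iostream>
#include <vector>

int main() {
  std::vector<int> passes;
  std::vector<bool> flags(passes.size(), false);  // sized too early: zero elements
  passes.push_back(1);                            // a pass registered afterwards has no flag
  std::cout << flags.size() << " flags for " << passes.size() << " passes\n";  // 0 flags for 1 passes
}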
@@ -138,8 +135,6 @@ Status TreeAdapter::Build(std::shared_ptr<DatasetNode> root_ir, int32_t num_epoc
   RETURN_IF_NOT_OK(BuildExecutionTreeRecur(root_ir->Children()[0], &root_op));
   RETURN_IF_NOT_OK(tree_->AssignRoot(root_op));

-  if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_);
-
   // Note: We will gradually move the pre pass, optimizer pass, and post pass
   // on ExecutionTree to perform on IR tree.
   // Prepare the tree
@@ -66,9 +66,6 @@ class TreeAdapter {
   // Set optional optimization pass
   void SetOptimize(bool value) { optimize_ = value; }

-  // function to override override the pre-pass
-  void SetPrePassOverride(std::function<OptPass(OptPass)> pre_pass_override) { pre_pass_override_ = pre_pass_override; }
-
   // Optional optimizations status
   bool OptimizationEnabled() const { return optimize_; }
@@ -90,14 +87,13 @@ class TreeAdapter {
   std::unique_ptr<DataBuffer> cur_db_;
   std::unordered_map<std::string, int32_t> column_name_map_;
-  std::unique_ptr<ExecutionTree> tree_;  // current connector capacity of root op, used for profiling
-  bool optimize_;                        // Flag to enable optional optimization pass
-  std::shared_ptr<DatasetIteratorTracing> tracing_;  // trace profiling data
-  int32_t cur_batch_num_;                // current batch number, used for profiling
-  int32_t cur_connector_size_;           // current connector size of root op, used for profiling
-  int32_t cur_connector_capacity_;       // current connector capacity of root op, used for profiling
-  std::function<OptPass(OptPass)> pre_pass_override_;  // function ptr that overrides pre pass, called in PrePrepare()
-  UsageFlag usage_;                      // usage of this tree adapter (type of consumer)
+  std::unique_ptr<ExecutionTree> tree_;  // current connector capacity of root op, used for profiling
+  bool optimize_;                        // Flag to enable optional optimization pass
+  std::shared_ptr<DatasetIteratorTracing> tracing_;  // trace profiling data
+  int32_t cur_batch_num_;                // current batch number, used for profiling
+  int32_t cur_connector_size_;           // current connector size of root op, used for profiling
+  int32_t cur_connector_capacity_;       // current connector capacity of root op, used for profiling
+  UsageFlag usage_;                      // usage of this tree adapter (type of consumer)

   // State flags for the lifecycle of the tree
   enum CompileState {
     kCompileStateInit = 0,  // The freshly initialized state
@@ -204,7 +204,7 @@ class PreBuiltOperation : public TensorOperation {
   Status ValidateParams() override;

-  std::string Name() const override { return kPreBuiltOperation; }
+  std::string Name() const override;

  private:
   std::shared_ptr<TensorOp> op_;
@@ -758,20 +758,25 @@ class RandomCropOperation : public TensorOperation {
   BorderType padding_mode_;
 };

-class RandomCropDecodeResizeOperation : public TensorOperation {
+class RandomResizedCropOperation : public TensorOperation {
  public:
-  RandomCropDecodeResizeOperation(std::vector<int32_t> size, std::vector<float> scale, std::vector<float> ratio,
-                                  InterpolationMode interpolation, int32_t max_attempts);
+  RandomResizedCropOperation(std::vector<int32_t> size, std::vector<float> scale = {0.08, 1.0},
+                             std::vector<float> ratio = {3. / 4., 4. / 3.},
+                             InterpolationMode interpolation = InterpolationMode::kNearestNeighbour,
+                             int32_t max_attempts = 10);

-  ~RandomCropDecodeResizeOperation() = default;
+  /// \brief default copy constructor
+  explicit RandomResizedCropOperation(const RandomResizedCropOperation &) = default;
+
+  ~RandomResizedCropOperation() = default;

   std::shared_ptr<TensorOp> Build() override;

   Status ValidateParams() override;

-  std::string Name() const override { return kRandomCropDecodeResizeOperation; }
+  std::string Name() const override { return kRandomResizedCropOperation; }

- private:
+ protected:
   std::vector<int32_t> size_;
   std::vector<float> scale_;
   std::vector<float> ratio_;
@@ -779,6 +784,20 @@ class RandomCropDecodeResizeOperation : public TensorOperation {
   int32_t max_attempts_;
 };

+class RandomCropDecodeResizeOperation : public RandomResizedCropOperation {
+ public:
+  RandomCropDecodeResizeOperation(std::vector<int32_t> size, std::vector<float> scale, std::vector<float> ratio,
+                                  InterpolationMode interpolation, int32_t max_attempts);
+
+  explicit RandomCropDecodeResizeOperation(const RandomResizedCropOperation &base);
+
+  ~RandomCropDecodeResizeOperation() = default;
+
+  std::shared_ptr<TensorOp> Build() override;
+
+  std::string Name() const override { return kRandomCropDecodeResizeOperation; }
+};
+
 class RandomCropWithBBoxOperation : public TensorOperation {
  public:
   RandomCropWithBBoxOperation(std::vector<int32_t> size, std::vector<int32_t> padding = {0, 0, 0, 0},
@@ -881,29 +900,6 @@ class RandomResizeWithBBoxOperation : public TensorOperation {
   std::vector<int32_t> size_;
 };

-class RandomResizedCropOperation : public TensorOperation {
- public:
-  explicit RandomResizedCropOperation(std::vector<int32_t> size, std::vector<float> scale = {0.08, 1.0},
-                                      std::vector<float> ratio = {3. / 4., 4. / 3.},
-                                      InterpolationMode interpolation = InterpolationMode::kNearestNeighbour,
-                                      int32_t max_attempts = 10);
-
-  ~RandomResizedCropOperation() = default;
-
-  std::shared_ptr<TensorOp> Build() override;
-
-  Status ValidateParams() override;
-
-  std::string Name() const override { return kRandomResizedCropOperation; }
-
- private:
-  std::vector<int32_t> size_;
-  std::vector<float> scale_;
-  std::vector<float> ratio_;
-  InterpolationMode interpolation_;
-  int32_t max_attempts_;
-};
-
 class RandomResizedCropWithBBoxOperation : public TensorOperation {
  public:
   explicit RandomResizedCropWithBBoxOperation(std::vector<int32_t> size, std::vector<float> scale = {0.08, 1.0},
@@ -16,14 +16,17 @@
 #include <memory>
 #include <string>
-#include "minddata/dataset/core/client.h"

 #include "common/common.h"
 #include "gtest/gtest.h"
-#include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/core/client.h"
 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
+#include "minddata/dataset/engine/ir/datasetops/map_node.h"
+#include "minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
 #include "minddata/dataset/engine/opt/post/auto_worker_pass.h"
-#include "minddata/dataset/engine/opt/pre/getter_pass.h"
+#include "minddata/dataset/include/transforms.h"
+#include "minddata/dataset/include/vision.h"
+#include "minddata/dataset/include/vision_lite.h"

 using namespace mindspore::dataset;
 using mindspore::LogStream;
@@ -31,7 +34,6 @@ using mindspore::MsLogLevel::INFO;
 class MindDataTestOptimizationPass : public UT::DatasetOpTesting {};

 TEST_F(MindDataTestOptimizationPass, MindDataTestAutoWorkerPass) {
   MS_LOG(INFO) << "Doing MindDataTestOptimizationPass-MindDataTestAutoWorkerPass.";
@@ -63,3 +65,41 @@ TEST_F(MindDataTestOptimizationPass, MindDataTestAutoWorkerPass) {
   MS_LOG(DEBUG) << batch->IRNode()->Name() << ": num_worker=" << batch->IRNode()->num_workers();
   MS_LOG(DEBUG) << map->IRNode()->Name() << ": num_worker=" << map->IRNode()->num_workers();
 }
+
+TEST_F(MindDataTestOptimizationPass, MindDataTestTensorFusionPass) {
+  MS_LOG(INFO) << "Doing MindDataTestOptimizationPass-MindDataTestTensorFusionPass.";
+  std::string folder_path = datasets_root_path_ + "/testPK/data/";
+  std::shared_ptr<Dataset> root =
+    ImageFolder(folder_path, false)->Map({vision::Decode(), vision::RandomResizedCrop({100})}, {"image"});
+
+  TensorOpFusionPass fusion_pass;
+  bool modified = false;
+  std::shared_ptr<MapNode> map_node = std::dynamic_pointer_cast<MapNode>(root->IRNode());
+  // no deepcopy is performed because this doesn't go through tree_adapter
+  fusion_pass.Run(root->IRNode(), &modified);
+  EXPECT_EQ(modified, true);
+  ASSERT_NE(map_node, nullptr);
+  auto fused_ops = map_node->operations();
+  ASSERT_EQ(fused_ops.size(), 1);
+  ASSERT_EQ(fused_ops[0]->Name(), vision::kRandomCropDecodeResizeOperation);
+}
+
+TEST_F(MindDataTestOptimizationPass, MindDataTestTensorFusionPassPreBuiltTensorOperation) {
+  MS_LOG(INFO) << "Doing MindDataTestOptimizationPass-MindDataTestTensorFusionPassPreBuiltTensorOperation.";
+  std::string folder_path = datasets_root_path_ + "/testPK/data/";
+  // make prebuilt tensor operation
+  auto decode = std::make_shared<transforms::PreBuiltOperation>(vision::Decode()->Build());
+  auto resize = std::make_shared<transforms::PreBuiltOperation>(vision::RandomResizedCrop({100})->Build());
+  std::shared_ptr<Dataset> root = ImageFolder(folder_path, false)->Map({decode, resize}, {"image"});
+
+  TensorOpFusionPass fusion_pass;
+  bool modified = false;
+  std::shared_ptr<MapNode> map_node = std::dynamic_pointer_cast<MapNode>(root->IRNode());
+  // no deepcopy is performed because this doesn't go through tree_adapter
+  fusion_pass.Run(root->IRNode(), &modified);
+  EXPECT_EQ(modified, true);
+  ASSERT_NE(map_node, nullptr);
+  auto fused_ops = map_node->operations();
+  ASSERT_EQ(fused_ops.size(), 1);
+  ASSERT_EQ(fused_ops[0]->Name(), kRandomCropDecodeResizeOp);
+}
@@ -454,6 +454,38 @@ def test_callbacks_one_cb():
     assert events3 == expected_events3


+def test_clear_callback():
+    logger.info("test_clear_callback")
+
+    # this test case will test that callback is removed for get_dataset_size and output_shape/type
+    class FlagCallback(DSCallback):
+        def __init__(self):
+            super().__init__(step_size=1)
+            self.flag = False
+            self.row_cnt = 0
+
+        def ds_begin(self, ds_run_context):
+            # if callback isn't removed in getter pass, this function will be called
+            self.flag = True
+
+        def ds_step_begin(self, ds_run_context):
+            self.row_cnt += 1
+
+    data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
+    cb = FlagCallback()
+    # make sure variables are properly initialized before testing
+    assert not cb.flag and cb.row_cnt == 0
+    data = data.map(operations=(lambda x: x), callbacks=cb)
+    assert data.get_dataset_size() == 4
+    assert data.output_shapes() == [[]]
+    # make sure callback is never called by checking flag and row_cnt
+    assert not cb.flag and cb.row_cnt == 0
+    for _ in data.create_dict_iterator(num_epochs=1):
+        pass
+    # this ensures that callback is indeed called
+    assert cb.flag and cb.row_cnt == 4
+
+
 if __name__ == '__main__':
     test_callbacks_all_2cbs()
     test_callbacks_all_methods()
@@ -467,3 +499,4 @@ if __name__ == '__main__':
     test_callbacks_one_cb()
     test_callbacks_non_sink_mismatch_size()
     test_callbacks_train_end()
+    test_clear_callback()