From: @ziruiwu Reviewed-by: @robingrosman,@nsyca Signed-off-by: @nsycatags/v1.2.0-rc1
| @@ -211,7 +211,7 @@ class ExecutionTree { | |||||
| // Return the pointer to the TaskGroup | // Return the pointer to the TaskGroup | ||||
| // @return raw pointer to the TaskGroup | // @return raw pointer to the TaskGroup | ||||
| TaskGroup *AllTasks() const { return tg_.get(); } | |||||
| TaskGroup *const AllTasks() const { return tg_.get(); } | |||||
| // Return if the ExecutionTree is at end of epoch status | // Return if the ExecutionTree is at end of epoch status | ||||
| // @return bool - true is ExecutionTree is end of epoch status | // @return bool - true is ExecutionTree is end of epoch status | ||||
| @@ -65,7 +65,7 @@ Status CocoNode::ValidateParams() { | |||||
| // Function to build CocoNode | // Function to build CocoNode | ||||
| Status CocoNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | Status CocoNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | ||||
| CocoOp::TaskType task_type; | |||||
| CocoOp::TaskType task_type = CocoOp::TaskType::Detection; | |||||
| if (task_ == "Detection") { | if (task_ == "Detection") { | ||||
| task_type = CocoOp::TaskType::Detection; | task_type = CocoOp::TaskType::Detection; | ||||
| } else if (task_ == "Stuff") { | } else if (task_ == "Stuff") { | ||||
| @@ -74,6 +74,10 @@ Status CocoNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||||
| task_type = CocoOp::TaskType::Keypoint; | task_type = CocoOp::TaskType::Keypoint; | ||||
| } else if (task_ == "Panoptic") { | } else if (task_ == "Panoptic") { | ||||
| task_type = CocoOp::TaskType::Panoptic; | task_type = CocoOp::TaskType::Panoptic; | ||||
| } else { | |||||
| std::string err_msg = "Task type:'" + task_ + "' is not supported."; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } | } | ||||
| std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); | std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); | ||||
| @@ -47,6 +47,9 @@ class AutoWorkerPass : public IRTreePass { | |||||
| max_num_workers_(8), | max_num_workers_(8), | ||||
| thread_cnt_(GlobalContext::Instance()->config_manager()->num_cpu_threads()) {} | thread_cnt_(GlobalContext::Instance()->config_manager()->num_cpu_threads()) {} | ||||
| /// \brief destructor, by doing "= default", compiler will automatically generate the correct destructor | |||||
| ~AutoWorkerPass() = default; | |||||
| Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *) override; | Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *) override; | ||||
| private: | private: | ||||
| @@ -54,6 +57,10 @@ class AutoWorkerPass : public IRTreePass { | |||||
| public: | public: | ||||
| explicit OpWeightPass(const std::map<std::string, float> &weight_profile) | explicit OpWeightPass(const std::map<std::string, float> &weight_profile) | ||||
| : IRNodePass(), weight_sum_(0), weight_profile_(weight_profile) {} | : IRNodePass(), weight_sum_(0), weight_profile_(weight_profile) {} | ||||
| /// \brief destructor, by doing "= default", compiler will automatically generate the correct destructor | |||||
| ~OpWeightPass() = default; | |||||
| // this is the base class function which contains the logic to handle most of the pipeline ops | // this is the base class function which contains the logic to handle most of the pipeline ops | ||||
| // pipeline ops although can't config num_workers it still runs 1 thread they need to be factored into weight | // pipeline ops although can't config num_workers it still runs 1 thread they need to be factored into weight | ||||
| Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override; | Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override; | ||||
| @@ -51,15 +51,15 @@ class TreeAdapter { | |||||
| // 2. GetNext will return empty row when eoe/eof is obtained | // 2. GetNext will return empty row when eoe/eof is obtained | ||||
| Status GetNext(TensorRow *); | Status GetNext(TensorRow *); | ||||
| // This function will return the root of the execution tree. | |||||
| std::weak_ptr<DatasetOp> GetRoot() { return tree_ != nullptr ? tree_->root() : nullptr; } | |||||
| // unique_ptr overloads operator bool(), will return false if it doesn't manage an object | |||||
| std::weak_ptr<DatasetOp> GetRoot() { return tree_ ? tree_->root() : nullptr; } | |||||
| // This function will return the column_name_map once BuildAndPrepare() is called | // This function will return the column_name_map once BuildAndPrepare() is called | ||||
| std::unordered_map<std::string, int32_t> GetColumnNameMap() const { return column_name_map_; } | std::unordered_map<std::string, int32_t> GetColumnNameMap() const { return column_name_map_; } | ||||
| // This function returns the TaskGroup associated with ExeTree. This is needed by DeviceQueueConsumer | // This function returns the TaskGroup associated with ExeTree. This is needed by DeviceQueueConsumer | ||||
| // to be able to launch a thread. BuildAndPrepare needs to be called before this function | // to be able to launch a thread. BuildAndPrepare needs to be called before this function | ||||
| TaskGroup *AllTasks() const { return tree_ != nullptr ? tree_->AllTasks() : nullptr; } | |||||
| TaskGroup *const AllTasks() const { return tree_ ? tree_->AllTasks() : nullptr; } | |||||
| Status Launch() const; | Status Launch() const; | ||||