| @@ -233,7 +233,7 @@ std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) { | |||
| return shared_from_this(); | |||
| } | |||
| DatasetNode::DatasetNode() : cache_(nullptr), parent_(nullptr), children_({}) { | |||
| DatasetNode::DatasetNode() : cache_(nullptr), parent_({}), children_({}) { | |||
| // Fetch some default value from config manager | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| num_workers_ = cfg->num_parallel_workers(); | |||
| @@ -294,30 +294,33 @@ void DatasetNode::PrintNode(std::ostream &out, int *level) const { | |||
| // Add a node as a child of this node. A nullptr child is silently skipped. | |||
| // If the child already has a parent it is still appended (with a warning); the | |||
| // resulting pipeline is then a DAG rather than a tree, which is rejected later | |||
| // when the execution tree is built. | |||
| void DatasetNode::AddChild(std::shared_ptr<DatasetNode> child) { | |||
| if (child != nullptr && child->parent_ == nullptr) { | |||
| if (child != nullptr && child->parent_.empty()) { | |||
| children_.push_back(child); | |||
| child->parent_ = this; | |||
| child->parent_.push_back(this); | |||
| } else if (child != nullptr) { | |||
| MS_LOG(WARNING) << "DatasetNode::AddChild() Fail" + child->Name() + "'s parent isn't a nullptr."; | |||
| MS_LOG(WARNING) << "DatasetNode::AddChild() " + child->Name() + " already has a parent; adding it again makes the pipeline a non-tree."; | |||
| children_.push_back(child); | |||
| child->parent_.push_back(this); | |||
| } | |||
| } | |||
| // Remove this node from its parent. Add the child of this node to its parent. | |||
| // for now, this remove is limited to node with a single child or no child | |||
| Status DatasetNode::Remove() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "Cannot remove root or a node without parent."); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(parent_.size() != 0, "Cannot remove root or a node without parent."); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(children_.size() < 2, "Cannot remove node with more than 1 child."); | |||
| if (children_.empty()) { // I am a leaf node, remove me from my parent's children list | |||
| parent_->children_.erase(std::remove(parent_->children_.begin(), parent_->children_.end(), shared_from_this()), | |||
| parent_->children_.end()); // removal using "erase remove idiom" | |||
| } else { // replace my position in my parent's children list with my single child | |||
| auto itr = std::find(parent_->children_.begin(), parent_->children_.end(), shared_from_this()); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(itr != parent_->children_.end(), "I am not in my parent's children list."); | |||
| parent_[0]->children_.erase( | |||
| std::remove(parent_[0]->children_.begin(), parent_[0]->children_.end(), shared_from_this()), | |||
| parent_[0]->children_.end()); // removal using "erase remove idiom" | |||
| } else { // replace my position in my parent's children list with my single child | |||
| auto itr = std::find(parent_[0]->children_.begin(), parent_[0]->children_.end(), shared_from_this()); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(itr != parent_[0]->children_.end(), "I am not in my parent's children list."); | |||
| children_[0]->parent_ = parent_; // set my single child's parent ptr to my parent | |||
| *itr = std::move(children_[0]); // replace me in my parent's children list with my single child | |||
| children_.clear(); // release my single child from my children list | |||
| } | |||
| parent_ = nullptr; | |||
| parent_.clear(); | |||
| return Status::OK(); | |||
| } | |||
| @@ -183,6 +183,10 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||
| /// \return Child nodes | |||
| const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children_; } | |||
| /// \brief Getter function for parents nodes | |||
| /// \return Parent nodes | |||
| const std::vector<DatasetNode *> Parent() const { return parent_; } | |||
| /// \brief Establish the parent-child relationship between this node and its child. | |||
| void AddChild(std::shared_ptr<DatasetNode> child); | |||
| @@ -233,7 +237,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||
| protected: | |||
| std::vector<std::shared_ptr<DatasetNode>> children_; | |||
| DatasetNode *parent_; | |||
| std::vector<DatasetNode *> parent_; | |||
| std::shared_ptr<DatasetCache> cache_; | |||
| int64_t dataset_size_ = -1; | |||
| int32_t num_workers_; | |||
| @@ -52,7 +52,7 @@ Status RootNode::ValidateParams() { | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (parent_ != nullptr) { | |||
| if (parent_.size() != 0) { | |||
| std::string err_msg = "Internal error: root node should not have a parent"; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| @@ -96,6 +96,10 @@ Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) { | |||
| } | |||
| Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op) { | |||
| // Check if pipeline is valid or not | |||
| CHECK_FAIL_RETURN_UNEXPECTED(ir->Parent().size() <= 1, | |||
| "The data pipeline is not a tree (i.e. one node has two consumers)"); | |||
| // Build the DatasetOp ExecutionTree from the optimized IR tree | |||
| std::vector<std::shared_ptr<DatasetOp>> ops = ir->Build(); | |||
| RETURN_IF_NOT_OK(ir->BuildStatus()); // remove me after changing return val of Build() | |||
| @@ -4511,7 +4511,7 @@ class Schema: | |||
| Parse the columns and add it to self. | |||
| Args: | |||
| columns (Union[dict, list[dict]]): Dataset attribute information, decoded from schema file. | |||
| columns (Union[dict, list[dict], tuple[dict]]): Dataset attribute information, decoded from schema file. | |||
| - list[dict], 'name' and 'type' must be in keys, 'shape' optional. | |||
| @@ -4519,7 +4519,6 @@ class Schema: | |||
| Raises: | |||
| RuntimeError: If failed to parse columns. | |||
| RuntimeError: If unknown items in columns. | |||
| RuntimeError: If column's name field is missing. | |||
| RuntimeError: If column's type field is missing. | |||
| @@ -415,6 +415,35 @@ TEST_F(MindDataTestPipeline, TestConcatFail4) { | |||
| EXPECT_EQ(iter, nullptr); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestConcatFail5) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConcatFail5."; | |||
| // This case is expected to fail because the dataset concat itself which causes ProjectNode has two parent nodes | |||
| // Create an ImageFolder Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | |||
| std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, true, RandomSampler(false, 10)); | |||
| EXPECT_NE(ds1, nullptr); | |||
| std::shared_ptr<Dataset> ds2 = ImageFolder(folder_path, true, RandomSampler(false, 10)); | |||
| EXPECT_NE(ds2, nullptr); | |||
| // Create a Project operation on ds | |||
| ds1 = ds1->Project({"image"}); | |||
| EXPECT_NE(ds1, nullptr); | |||
| ds2 = ds2->Project({"image"}); | |||
| EXPECT_NE(ds2, nullptr); | |||
| // Create a Concat operation on the ds | |||
| // Input dataset is the dataset itself | |||
| ds1 = ds1 + ds1 + ds2; | |||
| EXPECT_NE(ds1, nullptr); | |||
| // Create an iterator over the result of the above dataset | |||
| std::shared_ptr<Iterator> iter = ds1->CreateIterator(); | |||
| // Expect failure: The data pipeline is not a tree | |||
| EXPECT_EQ(iter, nullptr); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestConcatSuccess) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConcatSuccess."; | |||