You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ir_node_test.cc 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <memory>
  17. #include <string>
  18. #include "common/common.h"
  19. #include "gtest/gtest.h"
  20. #include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
  21. #include "minddata/dataset/engine/opt/pre/getter_pass.h"
  22. using namespace mindspore::dataset;
  23. class MindDataTestIRNodes : public UT::DatasetOpTesting {
  24. public:
  25. MindDataTestIRNodes() = default;
  26. // compare the ptr of the nodes in two trees, used to test the deep copy of nodes, will return error code
  27. // if (ptr1 == ptr2) does not equal to flag or the two tree has different structures (or node names are not the same)
  28. Status CompareTwoTrees(std::shared_ptr<DatasetNode> root1, std::shared_ptr<DatasetNode> root2, bool flag) {
  29. CHECK_FAIL_RETURN_UNEXPECTED(root1 != nullptr && root2 != nullptr, "Error in Compare, nullptr.");
  30. if (((root1.get() == root2.get()) != flag) || (root1->Name() != root2->Name())) {
  31. std::string err_msg =
  32. "Expect node ptr " + root1->Name() + (flag ? "==" : "!=") + root2->Name() + " but they aren't!";
  33. RETURN_STATUS_UNEXPECTED(err_msg);
  34. }
  35. size_t num_child = root1->Children().size();
  36. CHECK_FAIL_RETURN_UNEXPECTED(num_child == root2->Children().size(),
  37. root1->Name() + " has " + std::to_string(num_child) + "child, node #2 has " +
  38. std::to_string(root2->Children().size()) + " child.");
  39. for (size_t ind = 0; ind < num_child; ind++) {
  40. RETURN_IF_NOT_OK(CompareTwoTrees(root1->Children()[ind], root2->Children()[ind], flag));
  41. }
  42. return Status::OK();
  43. }
  44. // print the node's name in post order
  45. Status PostOrderPrintTree(std::shared_ptr<DatasetNode> ir, std::string &names) {
  46. RETURN_UNEXPECTED_IF_NULL(ir);
  47. for (auto child : ir->Children()) {
  48. RETURN_IF_NOT_OK(PostOrderPrintTree(child, names));
  49. }
  50. names += (ir->Name() + "->");
  51. return Status::OK();
  52. }
  53. };
  54. TEST_F(MindDataTestIRNodes, MindDataTestSimpleDeepCopy) {
  55. MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestSimpleDeepCopy.";
  56. auto tree1 = RandomData(44)->Repeat(2)->Project({"label"})->Shuffle(10)->Batch(2)->IRNode();
  57. auto tree2 = tree1->DeepCopy();
  58. std::string tree_1_names, tree_2_names;
  59. ASSERT_OK(PostOrderPrintTree(tree1, tree_1_names));
  60. ASSERT_OK(PostOrderPrintTree(tree2, tree_2_names));
  61. // expected output for the 2 names:
  62. // RandomDataset->Repeat->Project->Shuffle->Batch->
  63. EXPECT_EQ(tree_1_names, tree_2_names);
  64. ASSERT_OK(CompareTwoTrees(tree1, tree1, true));
  65. ASSERT_OK(CompareTwoTrees(tree1, tree2, false));
  66. // verify compare function is correct
  67. EXPECT_TRUE(CompareTwoTrees(tree2, tree2, false).IsError());
  68. }
  69. TEST_F(MindDataTestIRNodes, MindDataTestZipDeepCopy) {
  70. MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestZipDeepCopy.";
  71. auto branch1 = RandomData(44)->Project({"label"});
  72. auto branch2 = RandomData(44)->Shuffle(10);
  73. auto tree1 = Zip({branch1, branch2})->Batch(2)->IRNode();
  74. auto tree2 = tree1->DeepCopy();
  75. std::string tree_1_names, tree_2_names;
  76. ASSERT_OK(PostOrderPrintTree(tree1, tree_1_names));
  77. ASSERT_OK(PostOrderPrintTree(tree2, tree_2_names));
  78. // expected output for the 2 names:
  79. // RandomDataset->Project->RandomDataset->Shuffle->Zip->Batch->
  80. EXPECT_EQ(tree_1_names, tree_2_names);
  81. // verify the pointer within the same tree are the same
  82. ASSERT_OK(CompareTwoTrees(tree1, tree1, true));
  83. // verify two trees
  84. ASSERT_OK(CompareTwoTrees(tree1, tree2, false));
  85. }
  86. TEST_F(MindDataTestIRNodes, MindDataTestNodeRemove) {
  87. MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestNodeRemove.";
  88. auto branch1 = RandomData(44)->Project({"label"});
  89. auto branch2 = ImageFolder("path");
  90. auto tree = Zip({branch1, branch2})->IRNode();
  91. /***
  92. tree looks like this, we will remove node and test its functionalities
  93. Zip
  94. / \
  95. Project ImageFolder
  96. /
  97. RandomData
  98. ***/
  99. auto tree_copy_1 = tree->DeepCopy();
  100. ASSERT_EQ(tree_copy_1->Children().size(), 2);
  101. // remove the project in the tree and test
  102. ASSERT_OK(tree_copy_1->Children()[0]->Remove()); // remove Project from tree
  103. ASSERT_OK(CompareTwoTrees(tree_copy_1, Zip({RandomData(44), ImageFolder("path")})->IRNode(), false));
  104. // remove the ImageFolder, a leaf node from the tree
  105. std::string tree_1_names, tree_2_names;
  106. ASSERT_OK(PostOrderPrintTree(tree_copy_1, tree_1_names));
  107. EXPECT_EQ(tree_1_names, "RandomDataset->ImageFolderDataset->Zip->");
  108. auto tree_copy_2 = tree->DeepCopy();
  109. ASSERT_EQ(tree_copy_2->Children().size(), 2);
  110. tree_copy_2->Children()[1]->Remove();
  111. ASSERT_OK(PostOrderPrintTree(tree_copy_2, tree_2_names));
  112. EXPECT_EQ(tree_2_names, "RandomDataset->Project->Zip->");
  113. }