You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tree_adapter_test.cc 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "minddata/dataset/engine/tree_adapter.h"
  17. #include "common/common.h"
  18. #include "minddata/dataset/core/tensor_row.h"
  19. #include "minddata/dataset/include/datasets.h"
  20. #include "minddata/dataset/include/transforms.h"
  21. using namespace mindspore::dataset;
  22. using mindspore::dataset::Tensor;
  23. class MindDataTestTreeAdapter : public UT::DatasetOpTesting {
  24. protected:
  25. };
  26. TEST_F(MindDataTestTreeAdapter, TestSimpleTreeAdapter) {
  27. MS_LOG(INFO) << "Doing MindDataTestTreeAdapter-TestSimpleTreeAdapter.";
  28. // Create a Mnist Dataset
  29. std::string folder_path = datasets_root_path_ + "/testMnistData/";
  30. std::shared_ptr<api::Dataset> ds = Mnist(folder_path, "all", api::SequentialSampler(0, 4));
  31. EXPECT_NE(ds, nullptr);
  32. ds = ds->Batch(2);
  33. EXPECT_NE(ds, nullptr);
  34. mindspore::dataset::TreeAdapter tree_adapter;
  35. Status rc = tree_adapter.BuildAndPrepare(ds, 1);
  36. EXPECT_TRUE(rc.IsOk());
  37. const std::unordered_map<std::string, int32_t> map = {{"label", 1}, {"image", 0}};
  38. EXPECT_EQ(tree_adapter.GetColumnNameMap(), map);
  39. std::vector<size_t> row_sizes = {2, 2, 0, 0};
  40. TensorRow row;
  41. for (size_t sz : row_sizes) {
  42. rc = tree_adapter.GetNext(&row);
  43. EXPECT_TRUE(rc.IsOk());
  44. EXPECT_EQ(row.size(), sz);
  45. }
  46. rc = tree_adapter.GetNext(&row);
  47. EXPECT_TRUE(rc.IsError());
  48. const std::string err_msg = rc.ToString();
  49. EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos);
  50. }
  51. TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) {
  52. MS_LOG(INFO) << "Doing MindDataTestTreeAdapter-TestTreeAdapterWithRepeat.";
  53. // Create a Mnist Dataset
  54. std::string folder_path = datasets_root_path_ + "/testMnistData/";
  55. std::shared_ptr<api::Dataset> ds = Mnist(folder_path, "all", api::SequentialSampler(0, 3));
  56. EXPECT_NE(ds, nullptr);
  57. ds = ds->Batch(2, false);
  58. EXPECT_NE(ds, nullptr);
  59. mindspore::dataset::TreeAdapter tree_adapter;
  60. Status rc = tree_adapter.BuildAndPrepare(ds, 2);
  61. EXPECT_TRUE(rc.IsOk());
  62. const std::unordered_map<std::string, int32_t> map = tree_adapter.GetColumnNameMap();
  63. EXPECT_EQ(tree_adapter.GetColumnNameMap(), map);
  64. std::vector<size_t> row_sizes = {2, 2, 0, 2, 2, 0, 0};
  65. TensorRow row;
  66. for (size_t sz : row_sizes) {
  67. rc = tree_adapter.GetNext(&row);
  68. EXPECT_TRUE(rc.IsOk());
  69. EXPECT_EQ(row.size(), sz);
  70. }
  71. rc = tree_adapter.GetNext(&row);
  72. const std::string err_msg = rc.ToString();
  73. EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos);
  74. }
  75. TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) {
  76. MS_LOG(INFO) << "Doing MindDataTestPipeline-TestProjectMap.";
  77. // Create an ImageFolder Dataset
  78. std::string folder_path = datasets_root_path_ + "/testPK/data/";
  79. std::shared_ptr<api::Dataset> ds = ImageFolder(folder_path, true, api::SequentialSampler(0, 2));
  80. EXPECT_NE(ds, nullptr);
  81. // Create objects for the tensor ops
  82. std::shared_ptr<api::TensorOperation> one_hot = api::transforms::OneHot(10);
  83. EXPECT_NE(one_hot, nullptr);
  84. // Create a Map operation, this will automatically add a project after map
  85. ds = ds->Map({one_hot}, {"label"}, {"label"}, {"label"});
  86. EXPECT_NE(ds, nullptr);
  87. mindspore::dataset::TreeAdapter tree_adapter;
  88. Status rc = tree_adapter.BuildAndPrepare(ds, 2);
  89. EXPECT_TRUE(rc.IsOk());
  90. const std::unordered_map<std::string, int32_t> map = {{"label", 0}};
  91. EXPECT_EQ(tree_adapter.GetColumnNameMap(), map);
  92. std::vector<size_t> row_sizes = {1, 1, 0, 1, 1, 0, 0};
  93. TensorRow row;
  94. for (size_t sz : row_sizes) {
  95. rc = tree_adapter.GetNext(&row);
  96. EXPECT_TRUE(rc.IsOk());
  97. EXPECT_EQ(row.size(), sz);
  98. }
  99. rc = tree_adapter.GetNext(&row);
  100. const std::string err_msg = rc.ToString();
  101. EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos);
  102. }