You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

subset_random_sampler_test.cc 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/common.h"
  17. #include "gtest/gtest.h"
  18. #include "dataset/core/constants.h"
  19. #include "dataset/core/tensor.h"
  20. #include "dataset/engine/data_buffer.h"
  21. #include "dataset/engine/datasetops/source/sampler/sampler.h"
  22. #include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
  23. #include <vector>
  24. #include <unordered_set>
  25. using namespace mindspore::dataset;
  26. class MindDataTestSubsetRandomSampler : public UT::Common {
  27. public:
  28. class DummyRandomAccessOp : public RandomAccessOp {
  29. public:
  30. DummyRandomAccessOp(int64_t num_rows) : num_rows_(num_rows) {};
  31. Status GetNumSamples(int64_t *num) const {
  32. *num = num_rows_;
  33. return Status::OK();
  34. }
  35. Status GetNumRowsInDataset(int64_t *num) const {
  36. *num = num_rows_;
  37. return Status::OK();
  38. }
  39. private:
  40. int64_t num_rows_;
  41. };
  42. };
  43. TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
  44. std::vector<int64_t> in({0, 1, 2, 3, 4});
  45. std::unordered_set<int64_t> in_set(in.begin(), in.end());
  46. SubsetRandomSampler sampler(in);
  47. DummyRandomAccessOp dummyRandomAccessOp(5);
  48. sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
  49. std::unique_ptr<DataBuffer> db;
  50. TensorRow row;
  51. std::vector<int64_t> out;
  52. ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK());
  53. db->PopRow(&row);
  54. for (const auto &t : row) {
  55. for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
  56. out.push_back(*it);
  57. }
  58. }
  59. ASSERT_EQ(in.size(), out.size());
  60. for (int i = 0; i < in.size(); i++) {
  61. ASSERT_NE(in_set.find(out[i]), in_set.end());
  62. }
  63. ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK());
  64. ASSERT_EQ(db->eoe(), true);
  65. }
  66. TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) {
  67. int64_t total_samples = 100000 - 5;
  68. int64_t samples_per_buffer = 10;
  69. std::vector<int64_t> input(total_samples, 1);
  70. SubsetRandomSampler sampler(input, samples_per_buffer);
  71. DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  72. sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
  73. std::unique_ptr<DataBuffer> db;
  74. TensorRow row;
  75. std::vector<int64_t> out;
  76. ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK());
  77. int epoch = 0;
  78. while (!db->eoe()) {
  79. epoch++;
  80. db->PopRow(&row);
  81. for (const auto &t : row) {
  82. for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
  83. out.push_back(*it);
  84. }
  85. }
  86. db.reset();
  87. ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK());
  88. }
  89. ASSERT_EQ(epoch, (total_samples + samples_per_buffer - 1) / samples_per_buffer);
  90. ASSERT_EQ(input.size(), out.size());
  91. }
  92. TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
  93. std::vector<int64_t> in({0, 1, 2, 3, 4});
  94. std::unordered_set<int64_t> in_set(in.begin(), in.end());
  95. SubsetRandomSampler sampler(in);
  96. DummyRandomAccessOp dummyRandomAccessOp(5);
  97. sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
  98. std::unique_ptr<DataBuffer> db;
  99. TensorRow row;
  100. std::vector<int64_t> out;
  101. ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK());
  102. db->PopRow(&row);
  103. for (const auto &t : row) {
  104. for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
  105. out.push_back(*it);
  106. }
  107. }
  108. ASSERT_EQ(in.size(), out.size());
  109. for (int i = 0; i < in.size(); i++) {
  110. ASSERT_NE(in_set.find(out[i]), in_set.end());
  111. }
  112. sampler.Reset();
  113. ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK());
  114. ASSERT_EQ(db->eoe(), false);
  115. db->PopRow(&row);
  116. out.clear();
  117. for (const auto &t : row) {
  118. for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
  119. out.push_back(*it);
  120. }
  121. }
  122. ASSERT_EQ(in.size(), out.size());
  123. for (int i = 0; i < in.size(); i++) {
  124. ASSERT_NE(in_set.find(out[i]), in_set.end());
  125. }
  126. ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK());
  127. ASSERT_EQ(db->eoe(), true);
  128. }