You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

weighted_random_sampler_test.cc 9.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/common.h"
  17. #include "gtest/gtest.h"
  18. #include "minddata/dataset/core/constants.h"
  19. #include "minddata/dataset/core/tensor.h"
  20. #include "minddata/dataset/engine/data_buffer.h"
  21. #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
  22. #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
  23. #include "utils/log_adapter.h"
  24. #include <vector>
  25. #include <unordered_set>
  26. using namespace mindspore::dataset;
  27. using mindspore::MsLogLevel::INFO;
  28. using mindspore::ExceptionType::NoExceptionType;
  29. using mindspore::LogStream;
  30. class MindDataTestWeightedRandomSampler : public UT::Common {
  31. public:
  32. class DummyRandomAccessOp : public RandomAccessOp {
  33. public:
  34. DummyRandomAccessOp(uint64_t num_rows) {
  35. // row count is in base class as protected member
  36. // GetNumRowsInDataset does not need an override, the default from base class is fine.
  37. num_rows_ = num_rows;
  38. }
  39. };
  40. };
  41. TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
  42. // num samples to draw.
  43. uint64_t num_samples = 100;
  44. uint64_t total_samples = 1000;
  45. std::vector<double> weights(total_samples, std::rand() % 100);
  46. std::vector<uint64_t> freq(total_samples, 0);
  47. // create sampler with replacement = true
  48. WeightedRandomSampler m_sampler(num_samples, weights, true);
  49. DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  50. m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
  51. std::unique_ptr<DataBuffer> db;
  52. TensorRow row;
  53. std::vector<uint64_t> out;
  54. ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
  55. db->PopRow(&row);
  56. for (const auto &t : row) {
  57. for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
  58. out.push_back(*it);
  59. freq[*it]++;
  60. }
  61. }
  62. ASSERT_EQ(num_samples, out.size());
  63. ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
  64. ASSERT_EQ(db->eoe(), true);
  65. }
  66. TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
  67. // num samples to draw.
  68. uint64_t num_samples = 100;
  69. uint64_t total_samples = 1000;
  70. std::vector<double> weights(total_samples, std::rand() % 100);
  71. std::vector<uint64_t> freq(total_samples, 0);
  72. // create sampler with replacement = replacement
  73. WeightedRandomSampler m_sampler(num_samples, weights, false);
  74. DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  75. m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
  76. std::unique_ptr<DataBuffer> db;
  77. TensorRow row;
  78. std::vector<uint64_t> out;
  79. ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
  80. db->PopRow(&row);
  81. for (const auto &t : row) {
  82. for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
  83. out.push_back(*it);
  84. freq[*it]++;
  85. }
  86. }
  87. ASSERT_EQ(num_samples, out.size());
  88. // Without replacement, each sample only drawn once.
  89. for (int i = 0; i < total_samples; i++) {
  90. if (freq[i]) {
  91. ASSERT_EQ(freq[i], 1);
  92. }
  93. }
  94. ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
  95. ASSERT_EQ(db->eoe(), true);
  96. }
  97. TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) {
  98. // num samples to draw.
  99. uint64_t num_samples = 100;
  100. uint64_t total_samples = 1000;
  101. uint64_t samples_per_buffer = 10;
  102. std::vector<double> weights(total_samples, std::rand() % 100);
  103. // create sampler with replacement = replacement
  104. WeightedRandomSampler m_sampler(num_samples, weights, true, samples_per_buffer);
  105. DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  106. m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
  107. std::unique_ptr<DataBuffer> db;
  108. TensorRow row;
  109. std::vector<uint64_t> out;
  110. ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
  111. int epoch = 0;
  112. while (!db->eoe()) {
  113. epoch++;
  114. db->PopRow(&row);
  115. for (const auto &t : row) {
  116. for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
  117. out.push_back(*it);
  118. }
  119. }
  120. db.reset();
  121. ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
  122. }
  123. ASSERT_EQ(epoch, (num_samples + samples_per_buffer - 1) / samples_per_buffer);
  124. ASSERT_EQ(num_samples, out.size());
  125. }
  126. TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) {
  127. // num samples to draw.
  128. uint64_t num_samples = 100;
  129. uint64_t total_samples = 100;
  130. uint64_t samples_per_buffer = 10;
  131. std::vector<double> weights(total_samples, std::rand() % 100);
  132. weights[1] = 0;
  133. weights[2] = 0;
  134. std::vector<uint64_t> freq(total_samples, 0);
  135. // create sampler with replacement = replacement
  136. WeightedRandomSampler m_sampler(num_samples, weights, false, samples_per_buffer);
  137. DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  138. m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
  139. std::unique_ptr<DataBuffer> db;
  140. TensorRow row;
  141. std::vector<uint64_t> out;
  142. ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
  143. int epoch = 0;
  144. while (!db->eoe()) {
  145. epoch++;
  146. db->PopRow(&row);
  147. for (const auto &t : row) {
  148. for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
  149. out.push_back(*it);
  150. freq[*it]++;
  151. }
  152. }
  153. db.reset();
  154. ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
  155. }
  156. // Without replacement, each sample only drawn once.
  157. for (int i = 0; i < total_samples; i++) {
  158. if (freq[i]) {
  159. ASSERT_EQ(freq[i], 1);
  160. }
  161. }
  162. ASSERT_EQ(epoch, (num_samples + samples_per_buffer - 1) / samples_per_buffer);
  163. ASSERT_EQ(num_samples, out.size());
  164. }
  165. TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
  166. // num samples to draw.
  167. uint64_t num_samples = 1000000;
  168. uint64_t total_samples = 1000000;
  169. std::vector<double> weights(total_samples, std::rand() % 100);
  170. std::vector<uint64_t> freq(total_samples, 0);
  171. // create sampler with replacement = true
  172. WeightedRandomSampler m_sampler(num_samples, weights, true);
  173. DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  174. m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
  175. std::unique_ptr<DataBuffer> db;
  176. TensorRow row;
  177. std::vector<uint64_t> out;
  178. ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
  179. db->PopRow(&row);
  180. for (const auto &t : row) {
  181. for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
  182. out.push_back(*it);
  183. freq[*it]++;
  184. }
  185. }
  186. ASSERT_EQ(num_samples, out.size());
  187. ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
  188. ASSERT_EQ(db->eoe(), true);
  189. m_sampler.ResetSampler();
  190. out.clear();
  191. ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
  192. db->PopRow(&row);
  193. for (const auto &t : row) {
  194. for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
  195. out.push_back(*it);
  196. freq[*it]++;
  197. }
  198. }
  199. ASSERT_EQ(num_samples, out.size());
  200. ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
  201. ASSERT_EQ(db->eoe(), true);
  202. }
  203. TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
  204. // num samples to draw.
  205. uint64_t num_samples = 1000000;
  206. uint64_t total_samples = 1000000;
  207. std::vector<double> weights(total_samples, std::rand() % 100);
  208. std::vector<uint64_t> freq(total_samples, 0);
  209. // create sampler with replacement = true
  210. WeightedRandomSampler m_sampler(num_samples, weights, false);
  211. DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  212. m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
  213. std::unique_ptr<DataBuffer> db;
  214. TensorRow row;
  215. std::vector<uint64_t> out;
  216. ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
  217. db->PopRow(&row);
  218. for (const auto &t : row) {
  219. for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
  220. out.push_back(*it);
  221. freq[*it]++;
  222. }
  223. }
  224. ASSERT_EQ(num_samples, out.size());
  225. ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
  226. ASSERT_EQ(db->eoe(), true);
  227. m_sampler.ResetSampler();
  228. out.clear();
  229. freq.clear();
  230. freq.resize(total_samples, 0);
  231. MS_LOG(INFO) << "Resetting sampler";
  232. ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
  233. db->PopRow(&row);
  234. for (const auto &t : row) {
  235. for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
  236. out.push_back(*it);
  237. freq[*it]++;
  238. }
  239. }
  240. ASSERT_EQ(num_samples, out.size());
  241. // Without replacement, each sample only drawn once.
  242. for (int i = 0; i < total_samples; i++) {
  243. if (freq[i]) {
  244. ASSERT_EQ(freq[i], 1);
  245. }
  246. }
  247. ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
  248. ASSERT_EQ(db->eoe(), true);
  249. }