You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

weighted_random_sampler_test.cc 9.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/common.h"
  17. #include "gtest/gtest.h"
  18. #include "dataset/core/constants.h"
  19. #include "dataset/core/tensor.h"
  20. #include "dataset/engine/data_buffer.h"
  21. #include "dataset/engine/datasetops/source/sampler/sampler.h"
  22. #include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
  23. #include "utils/log_adapter.h"
  24. #include <vector>
  25. #include <unordered_set>
  26. using namespace mindspore::dataset;
  27. using mindspore::MsLogLevel::INFO;
  28. using mindspore::ExceptionType::NoExceptionType;
  29. using mindspore::LogStream;
  30. class MindDataTestWeightedRandomSampler : public UT::Common {
  31. public:
  32. class DummyRandomAccessOp : public RandomAccessOp {
  33. public:
  34. DummyRandomAccessOp(uint64_t num_rows) : num_rows_(num_rows) {};
  35. Status GetNumSamples(int64_t *num) const {
  36. *num = num_rows_;
  37. return Status::OK();
  38. }
  39. Status GetNumRowsInDataset(int64_t *num) const {
  40. *num = num_rows_;
  41. return Status::OK();
  42. }
  43. private:
  44. uint64_t num_rows_;
  45. };
  46. };
  47. TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
  48. // num samples to draw.
  49. uint64_t num_samples = 100;
  50. uint64_t total_samples = 1000;
  51. std::vector<double> weights(total_samples, std::rand() % 100);
  52. std::vector<uint64_t> freq(total_samples, 0);
  53. // create sampler with replacement = true
  54. WeightedRandomSampler m_sampler(weights, num_samples, true);
  55. DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  56. m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
  57. std::unique_ptr<DataBuffer> db;
  58. TensorRow row;
  59. std::vector<uint64_t> out;
  60. ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
  61. db->PopRow(&row);
  62. for (const auto &t : row) {
  63. for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
  64. out.push_back(*it);
  65. freq[*it]++;
  66. }
  67. }
  68. ASSERT_EQ(num_samples, out.size());
  69. ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
  70. ASSERT_EQ(db->eoe(), true);
  71. }
  72. TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
  73. // num samples to draw.
  74. uint64_t num_samples = 100;
  75. uint64_t total_samples = 1000;
  76. std::vector<double> weights(total_samples, std::rand() % 100);
  77. std::vector<uint64_t> freq(total_samples, 0);
  78. // create sampler with replacement = replacement
  79. WeightedRandomSampler m_sampler(weights, num_samples, false);
  80. DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  81. m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
  82. std::unique_ptr<DataBuffer> db;
  83. TensorRow row;
  84. std::vector<uint64_t> out;
  85. ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
  86. db->PopRow(&row);
  87. for (const auto &t : row) {
  88. for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
  89. out.push_back(*it);
  90. freq[*it]++;
  91. }
  92. }
  93. ASSERT_EQ(num_samples, out.size());
  94. // Without replacement, each sample only drawn once.
  95. for (int i = 0; i < total_samples; i++) {
  96. if (freq[i]) {
  97. ASSERT_EQ(freq[i], 1);
  98. }
  99. }
  100. ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
  101. ASSERT_EQ(db->eoe(), true);
  102. }
  103. TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) {
  104. // num samples to draw.
  105. uint64_t num_samples = 100;
  106. uint64_t total_samples = 1000;
  107. uint64_t samples_per_buffer = 10;
  108. std::vector<double> weights(total_samples, std::rand() % 100);
  109. // create sampler with replacement = replacement
  110. WeightedRandomSampler m_sampler(weights, num_samples, true, samples_per_buffer);
  111. DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  112. m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
  113. std::unique_ptr<DataBuffer> db;
  114. TensorRow row;
  115. std::vector<uint64_t> out;
  116. ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
  117. int epoch = 0;
  118. while (!db->eoe()) {
  119. epoch++;
  120. db->PopRow(&row);
  121. for (const auto &t : row) {
  122. for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
  123. out.push_back(*it);
  124. }
  125. }
  126. db.reset();
  127. ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
  128. }
  129. ASSERT_EQ(epoch, (num_samples + samples_per_buffer - 1) / samples_per_buffer);
  130. ASSERT_EQ(num_samples, out.size());
  131. }
  132. TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) {
  133. // num samples to draw.
  134. uint64_t num_samples = 100;
  135. uint64_t total_samples = 100;
  136. uint64_t samples_per_buffer = 10;
  137. std::vector<double> weights(total_samples, std::rand() % 100);
  138. weights[1] = 0;
  139. weights[2] = 0;
  140. std::vector<uint64_t> freq(total_samples, 0);
  141. // create sampler with replacement = replacement
  142. WeightedRandomSampler m_sampler(weights, num_samples, false, samples_per_buffer);
  143. DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  144. m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
  145. std::unique_ptr<DataBuffer> db;
  146. TensorRow row;
  147. std::vector<uint64_t> out;
  148. ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
  149. int epoch = 0;
  150. while (!db->eoe()) {
  151. epoch++;
  152. db->PopRow(&row);
  153. for (const auto &t : row) {
  154. for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
  155. out.push_back(*it);
  156. freq[*it]++;
  157. }
  158. }
  159. db.reset();
  160. ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
  161. }
  162. // Without replacement, each sample only drawn once.
  163. for (int i = 0; i < total_samples; i++) {
  164. if (freq[i]) {
  165. ASSERT_EQ(freq[i], 1);
  166. }
  167. }
  168. ASSERT_EQ(epoch, (num_samples + samples_per_buffer - 1) / samples_per_buffer);
  169. ASSERT_EQ(num_samples, out.size());
  170. }
  171. TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
  172. // num samples to draw.
  173. uint64_t num_samples = 1000000;
  174. uint64_t total_samples = 1000000;
  175. std::vector<double> weights(total_samples, std::rand() % 100);
  176. std::vector<uint64_t> freq(total_samples, 0);
  177. // create sampler with replacement = true
  178. WeightedRandomSampler m_sampler(weights, num_samples, true);
  179. DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  180. m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
  181. std::unique_ptr<DataBuffer> db;
  182. TensorRow row;
  183. std::vector<uint64_t> out;
  184. ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
  185. db->PopRow(&row);
  186. for (const auto &t : row) {
  187. for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
  188. out.push_back(*it);
  189. freq[*it]++;
  190. }
  191. }
  192. ASSERT_EQ(num_samples, out.size());
  193. ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
  194. ASSERT_EQ(db->eoe(), true);
  195. m_sampler.Reset();
  196. out.clear();
  197. ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
  198. db->PopRow(&row);
  199. for (const auto &t : row) {
  200. for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
  201. out.push_back(*it);
  202. freq[*it]++;
  203. }
  204. }
  205. ASSERT_EQ(num_samples, out.size());
  206. ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
  207. ASSERT_EQ(db->eoe(), true);
  208. }
  209. TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
  210. // num samples to draw.
  211. uint64_t num_samples = 1000000;
  212. uint64_t total_samples = 1000000;
  213. std::vector<double> weights(total_samples, std::rand() % 100);
  214. std::vector<uint64_t> freq(total_samples, 0);
  215. // create sampler with replacement = true
  216. WeightedRandomSampler m_sampler(weights, num_samples, false);
  217. DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  218. m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
  219. std::unique_ptr<DataBuffer> db;
  220. TensorRow row;
  221. std::vector<uint64_t> out;
  222. ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
  223. db->PopRow(&row);
  224. for (const auto &t : row) {
  225. for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
  226. out.push_back(*it);
  227. freq[*it]++;
  228. }
  229. }
  230. ASSERT_EQ(num_samples, out.size());
  231. ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
  232. ASSERT_EQ(db->eoe(), true);
  233. m_sampler.Reset();
  234. out.clear();
  235. freq.clear();
  236. freq.resize(total_samples, 0);
  237. MS_LOG(INFO) << "Resetting sampler";
  238. ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
  239. db->PopRow(&row);
  240. for (const auto &t : row) {
  241. for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
  242. out.push_back(*it);
  243. freq[*it]++;
  244. }
  245. }
  246. ASSERT_EQ(num_samples, out.size());
  247. // Without replacement, each sample only drawn once.
  248. for (int i = 0; i < total_samples; i++) {
  249. if (freq[i]) {
  250. ASSERT_EQ(freq[i], 1);
  251. }
  252. }
  253. ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
  254. ASSERT_EQ(db->eoe(), true);
  255. }