You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

subset_sampler_test.cc 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/common.h"
  17. #include "gtest/gtest.h"
  18. #include "minddata/dataset/include/constants.h"
  19. #include "minddata/dataset/core/tensor.h"
  20. #include "minddata/dataset/engine/data_buffer.h"
  21. #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
  22. #include "minddata/dataset/engine/datasetops/source/sampler/subset_sampler.h"
  23. #include <vector>
  24. #include <unordered_set>
  25. using namespace mindspore::dataset;
  26. class MindDataTestSubsetSampler : public UT::Common {
  27. public:
  28. class DummyRandomAccessOp : public RandomAccessOp {
  29. public:
  30. DummyRandomAccessOp(int64_t num_rows) {
  31. num_rows_ = num_rows; // base class
  32. };
  33. };
  34. };
  35. TEST_F(MindDataTestSubsetSampler, TestAllAtOnce) {
  36. std::vector<int64_t> in({3, 1, 4, 0, 1});
  37. std::unordered_set<int64_t> in_set(in.begin(), in.end());
  38. int64_t num_samples = 0;
  39. SubsetSamplerRT sampler(num_samples, in);
  40. DummyRandomAccessOp dummyRandomAccessOp(5);
  41. sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
  42. std::unique_ptr<DataBuffer> db;
  43. TensorRow row;
  44. std::vector<int64_t> out;
  45. ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
  46. db->PopRow(&row);
  47. for (const auto &t : row) {
  48. for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
  49. out.push_back(*it);
  50. }
  51. }
  52. ASSERT_EQ(in.size(), out.size());
  53. for (int i = 0; i < in.size(); i++) {
  54. ASSERT_EQ(in[i], out[i]);
  55. }
  56. ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
  57. ASSERT_EQ(db->eoe(), true);
  58. }
  59. TEST_F(MindDataTestSubsetSampler, TestGetNextBuffer) {
  60. int64_t total_samples = 100000 - 5;
  61. int64_t samples_per_buffer = 10;
  62. int64_t num_samples = 0;
  63. std::vector<int64_t> input(total_samples, 1);
  64. SubsetSamplerRT sampler(num_samples, input, samples_per_buffer);
  65. DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  66. sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
  67. std::unique_ptr<DataBuffer> db;
  68. TensorRow row;
  69. std::vector<int64_t> out;
  70. ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
  71. int epoch = 0;
  72. while (!db->eoe()) {
  73. epoch++;
  74. db->PopRow(&row);
  75. for (const auto &t : row) {
  76. for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
  77. out.push_back(*it);
  78. }
  79. }
  80. db.reset();
  81. ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
  82. }
  83. ASSERT_EQ(epoch, (total_samples + samples_per_buffer - 1) / samples_per_buffer);
  84. ASSERT_EQ(input.size(), out.size());
  85. }
  86. TEST_F(MindDataTestSubsetSampler, TestReset) {
  87. std::vector<int64_t> in({0, 1, 2, 3, 4});
  88. std::unordered_set<int64_t> in_set(in.begin(), in.end());
  89. int64_t num_samples = 0;
  90. SubsetSamplerRT sampler(num_samples, in);
  91. DummyRandomAccessOp dummyRandomAccessOp(5);
  92. sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
  93. std::unique_ptr<DataBuffer> db;
  94. TensorRow row;
  95. std::vector<int64_t> out;
  96. ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
  97. db->PopRow(&row);
  98. for (const auto &t : row) {
  99. for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
  100. out.push_back(*it);
  101. }
  102. }
  103. ASSERT_EQ(in.size(), out.size());
  104. for (int i = 0; i < in.size(); i++) {
  105. ASSERT_EQ(in[i], out[i]);
  106. }
  107. sampler.ResetSampler();
  108. ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
  109. ASSERT_EQ(db->eoe(), false);
  110. db->PopRow(&row);
  111. out.clear();
  112. for (const auto &t : row) {
  113. for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
  114. out.push_back(*it);
  115. }
  116. }
  117. ASSERT_EQ(in.size(), out.size());
  118. for (int i = 0; i < in.size(); i++) {
  119. ASSERT_EQ(in[i], out[i]);
  120. }
  121. ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
  122. ASSERT_EQ(db->eoe(), true);
  123. }