You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

subset_sampler_test.cc 4.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/common.h"
  17. #include "gtest/gtest.h"
  18. #include "minddata/dataset/include/constants.h"
  19. #include "minddata/dataset/core/tensor.h"
  20. #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
  21. #include "minddata/dataset/engine/datasetops/source/sampler/subset_sampler.h"
  22. #include <vector>
  23. #include <unordered_set>
  24. using namespace mindspore::dataset;
  25. class MindDataTestSubsetSampler : public UT::Common {
  26. public:
  27. class DummyRandomAccessOp : public RandomAccessOp {
  28. public:
  29. DummyRandomAccessOp(int64_t num_rows) {
  30. num_rows_ = num_rows; // base class
  31. };
  32. };
  33. };
  34. TEST_F(MindDataTestSubsetSampler, TestAllAtOnce) {
  35. std::vector<int64_t> in({3, 1, 4, 0, 1});
  36. std::unordered_set<int64_t> in_set(in.begin(), in.end());
  37. int64_t num_samples = 0;
  38. SubsetSamplerRT sampler(num_samples, in);
  39. DummyRandomAccessOp dummyRandomAccessOp(5);
  40. sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
  41. TensorRow row;
  42. std::vector<int64_t> out;
  43. ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  44. for (const auto &t : row) {
  45. for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
  46. out.push_back(*it);
  47. }
  48. }
  49. ASSERT_EQ(in.size(), out.size());
  50. for (int i = 0; i < in.size(); i++) {
  51. ASSERT_EQ(in[i], out[i]);
  52. }
  53. ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  54. ASSERT_EQ(row.eoe(), true);
  55. }
  56. TEST_F(MindDataTestSubsetSampler, TestGetNextSample) {
  57. int64_t total_samples = 100000 - 5;
  58. int64_t samples_per_tensor = 10;
  59. int64_t num_samples = 0;
  60. std::vector<int64_t> input(total_samples, 1);
  61. SubsetSamplerRT sampler(num_samples, input, samples_per_tensor);
  62. DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  63. sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
  64. TensorRow row;
  65. std::vector<int64_t> out;
  66. ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  67. int epoch = 0;
  68. while (!row.eoe()) {
  69. epoch++;
  70. for (const auto &t : row) {
  71. for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
  72. out.push_back(*it);
  73. }
  74. }
  75. ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  76. }
  77. ASSERT_EQ(epoch, (total_samples + samples_per_tensor - 1) / samples_per_tensor);
  78. ASSERT_EQ(input.size(), out.size());
  79. }
  80. TEST_F(MindDataTestSubsetSampler, TestReset) {
  81. std::vector<int64_t> in({0, 1, 2, 3, 4});
  82. std::unordered_set<int64_t> in_set(in.begin(), in.end());
  83. int64_t num_samples = 0;
  84. SubsetSamplerRT sampler(num_samples, in);
  85. DummyRandomAccessOp dummyRandomAccessOp(5);
  86. sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
  87. TensorRow row;
  88. std::vector<int64_t> out;
  89. ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  90. for (const auto &t : row) {
  91. for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
  92. out.push_back(*it);
  93. }
  94. }
  95. ASSERT_EQ(in.size(), out.size());
  96. for (int i = 0; i < in.size(); i++) {
  97. ASSERT_EQ(in[i], out[i]);
  98. }
  99. sampler.ResetSampler();
  100. ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  101. ASSERT_EQ(row.eoe(), false);
  102. out.clear();
  103. for (const auto &t : row) {
  104. for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
  105. out.push_back(*it);
  106. }
  107. }
  108. ASSERT_EQ(in.size(), out.size());
  109. for (int i = 0; i < in.size(); i++) {
  110. ASSERT_EQ(in[i], out[i]);
  111. }
  112. ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  113. ASSERT_EQ(row.eoe(), true);
  114. }