You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

common_utils_test.cc 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <vector>
  17. #include "common/common_test.h"
  18. #include "backend/kernel_compiler/common_utils.h"
  19. namespace mindspore {
  20. namespace kernel {
  21. class CommonUtilTest : public UT::Common {
  22. public:
  23. CommonUtilTest() = default;
  24. };
  25. TEST_F(CommonUtilTest, BucketReduceSparseGradient1) {
  26. // The indices is a vector and the grad is a tensor with shape (6, 2)
  27. /* 0
  28. * 0
  29. * 1
  30. * 1
  31. * 0
  32. * 3
  33. */
  34. std::vector<int> indices{0, 0, 1, 1, 0, 3};
  35. /* 0 1
  36. * 2 3
  37. * 4 5
  38. * 6 7
  39. * 8 9
  40. * 10 11
  41. */
  42. std::vector<float> grad;
  43. for (int i = 0; i < 6 * 2; i++) {
  44. grad.push_back(i);
  45. }
  46. std::vector<int> unique_indices(6);
  47. std::vector<float> summed_grad(12);
  48. std::vector<int> tmp_indices(6);
  49. std::vector<float> tmp_grad(12);
  50. SparseGradient unique_grad({summed_grad.data(), unique_indices.data(), 6});
  51. SparseGradient workspace_grad({tmp_grad.data(), tmp_indices.data(), 6});
  52. SparseGradient input_grad({grad.data(), indices.data(), 6});
  53. ReduceSparseGradientParam param;
  54. param.input_grad_ = &input_grad;
  55. param.workspace_grad_ = &workspace_grad;
  56. param.output_grad_ = &unique_grad;
  57. param.max_index_ = 6;
  58. param.value_stride_ = 2;
  59. BucketReduceSparseGradient(param);
  60. EXPECT_EQ(unique_grad.indices_size_, 3);
  61. std::vector<int> expect_indices({0, 1, 3});
  62. for (size_t i = 0; i < unique_grad.indices_size_; ++i) {
  63. EXPECT_EQ(unique_grad.indices_[i], expect_indices[i]);
  64. }
  65. /* 10 13
  66. * 10 12
  67. * 10 11
  68. */
  69. std::vector<int> expect_value({10, 13, 10, 12, 10, 11});
  70. for (size_t i = 0; i < unique_grad.indices_size_ * 2; ++i) {
  71. EXPECT_EQ(unique_grad.value_[i], expect_value[i]);
  72. }
  73. }
  74. TEST_F(CommonUtilTest, BucketReduceSparseGradient2) {
  75. // The indices is a vector and the grad is a tensor with shape (6, 2)
  76. /* 0
  77. * 0
  78. * 1
  79. * 1
  80. * 0
  81. * 6
  82. */
  83. std::vector<int> indices{0, 0, 1, 1, 0, 6};
  84. /* 0 1
  85. * 2 3
  86. * 4 5
  87. * 6 7
  88. * 8 9
  89. * 10 11
  90. */
  91. std::vector<float> grad;
  92. for (int i = 0; i < 6 * 2; i++) {
  93. grad.push_back(i);
  94. }
  95. std::vector<int> unique_indices(6);
  96. std::vector<float> summed_grad(12);
  97. std::vector<int> tmp_indices(6);
  98. std::vector<float> tmp_grad(12);
  99. SparseGradient unique_grad({summed_grad.data(), unique_indices.data(), 6});
  100. SparseGradient workspace_grad({tmp_grad.data(), tmp_indices.data(), 6});
  101. SparseGradient input_grad({grad.data(), indices.data(), 6});
  102. ReduceSparseGradientParam param;
  103. param.input_grad_ = &input_grad;
  104. param.workspace_grad_ = &workspace_grad;
  105. param.output_grad_ = &unique_grad;
  106. param.max_index_ = 6;
  107. param.value_stride_ = 2;
  108. BucketReduceSparseGradient(param);
  109. EXPECT_EQ(unique_grad.indices_size_, 2);
  110. std::vector<int> expect_indices({0, 1});
  111. for (size_t i = 0; i < unique_grad.indices_size_; ++i) {
  112. EXPECT_EQ(unique_grad.indices_[i], expect_indices[i]);
  113. }
  114. /* 10 13
  115. * 10 12
  116. */
  117. std::vector<int> expect_value({10, 13, 10, 12});
  118. for (size_t i = 0; i < unique_grad.indices_size_ * 2; ++i) {
  119. EXPECT_EQ(unique_grad.value_[i], expect_value[i]);
  120. }
  121. }
  122. } // namespace kernel
  123. } // namespace mindspore