You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

trans_test.cc 6.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include <cstdint>
  #include <cstring>
  #include <iterator>
  #include <string>
  #include <vector>
  #include "common/common_test.h"
  #include "common/trans.h"
  #include "utils/utils.h"
  using namespace std;
  21. namespace mindspore {
  22. namespace trans {
  23. class FormatTransTest : public UT::Common {
  24. public:
  25. FormatTransTest() = default;
  26. void SetUp() override {}
  27. void TearDown() override {}
  28. };
  29. TEST_F(FormatTransTest, nchw_to_hwcn) {
  30. uint16_t data[2 * 2 * 2 * 2] = {12581, 14220, 14937, 14302, 15004, 14951, 14694, 14564,
  31. 14069, 14554, 10507, 14787, 13016, 15263, 14872, 10838};
  32. uint16_t res[2 * 2 * 2 * 2] = {12581, 14069, 15004, 13016, 14220, 14554, 14951, 15263,
  33. 14937, 10507, 14694, 14872, 14302, 14787, 14564, 10838};
  34. size_t device_size = 32;
  35. auto trans_tmp = std::vector<uint8_t>(device_size);
  36. FormatArgs format_args{data, device_size, kOpFormat_NCHW, kOpFormat_HWCN,
  37. {2, 2, 2, 2}, {2, 2, 2, 2}, kNumberTypeFloat16};
  38. EXPECT_EQ(trans::TransFormat(format_args, trans_tmp.data()), true);
  39. for (size_t i = 0; i < sizeof(res) / sizeof(res[0]); i++) {
  40. EXPECT_EQ((reinterpret_cast<uint16_t *>(trans_tmp.data()))[i], res[i]);
  41. }
  42. }
  43. TEST_F(FormatTransTest, hwcn_to_nchw) {
  44. uint16_t data[2 * 2 * 2 * 2] = {12581, 14069, 15004, 13016, 14220, 14554, 14951, 15263,
  45. 14937, 10507, 14694, 14872, 14302, 14787, 14564, 10838};
  46. uint16_t res[2 * 2 * 2 * 2] = {12581, 14220, 14937, 14302, 15004, 14951, 14694, 14564,
  47. 14069, 14554, 10507, 14787, 13016, 15263, 14872, 10838};
  48. size_t device_size = 32;
  49. auto trans_tmp = std::vector<uint8_t>(device_size);
  50. FormatArgs format_args{data, device_size, kOpFormat_NCHW, kOpFormat_HWCN,
  51. {2, 2, 2, 2}, {2, 2, 2, 2}, kNumberTypeFloat16};
  52. EXPECT_EQ(trans::TransFormatFromDeviceToHost(format_args, trans_tmp.data()), true);
  53. for (size_t i = 0; i < sizeof(res) / sizeof(res[0]); i++) {
  54. EXPECT_EQ((reinterpret_cast<uint16_t *>(trans_tmp.data()))[i], res[i]);
  55. }
  56. }
  57. TEST_F(FormatTransTest, nchw_to_nhwc) {
  58. uint16_t data[2 * 2 * 2 * 2] = {11750, 13778, 15007, 15321, 15163, 13446, 15063, 14467,
  59. 15056, 13284, 15219, 14797, 12684, 14288, 14855, 14799};
  60. uint16_t res[2 * 2 * 2 * 2] = {11750, 15163, 13778, 13446, 15007, 15063, 15321, 14467,
  61. 15056, 12684, 13284, 14288, 15219, 14855, 14797, 14799};
  62. size_t device_size = 32;
  63. auto trans_tmp = std::vector<uint8_t>(device_size);
  64. FormatArgs format_args{data, device_size, kOpFormat_NCHW, kOpFormat_NHWC,
  65. {2, 2, 2, 2}, {2, 2, 2, 2}, kNumberTypeFloat16};
  66. EXPECT_EQ(trans::TransFormat(format_args, trans_tmp.data()), true);
  67. for (size_t i = 0; i < sizeof(res) / sizeof(res[0]); i++) {
  68. EXPECT_EQ((reinterpret_cast<uint16_t *>(trans_tmp.data()))[i], res[i]);
  69. }
  70. }
  71. TEST_F(FormatTransTest, nhwc_to_nchw) {
  72. uint16_t data[2 * 2 * 2 * 2] = {11750, 15163, 13778, 13446, 15007, 15063, 15321, 14467,
  73. 15056, 12684, 13284, 14288, 15219, 14855, 14797, 14799};
  74. uint16_t res[2 * 2 * 2 * 2] = {11750, 13778, 15007, 15321, 15163, 13446, 15063, 14467,
  75. 15056, 13284, 15219, 14797, 12684, 14288, 14855, 14799};
  76. size_t device_size = 32;
  77. auto trans_tmp = std::vector<uint8_t>(device_size);
  78. FormatArgs format_args{data, device_size, kOpFormat_NCHW, kOpFormat_NHWC,
  79. {2, 2, 2, 2}, {2, 2, 2, 2}, kNumberTypeFloat16};
  80. EXPECT_EQ(trans::TransFormatFromDeviceToHost(format_args, trans_tmp.data()), true);
  81. for (size_t i = 0; i < sizeof(res) / sizeof(res[0]); i++) {
  82. EXPECT_EQ((reinterpret_cast<uint16_t *>(trans_tmp.data()))[i], res[i]);
  83. }
  84. }
  85. class ShapeTransTest : public UT::Common {
  86. public:
  87. ShapeTransTest() = default;
  88. void SetUp() override {}
  89. void TearDown() override {}
  90. };
  91. TEST_F(ShapeTransTest, fraczn_rnn_device_shape) {
  92. std::vector<size_t> host_shape = {43, 120};
  93. std::string format = kOpFormat_FRACTAL_ZN_RNN;
  94. std::vector<int64_t> input_hidden_size = {13, 30};
  95. auto trans_shape = trans::TransShapeToDevice(host_shape, format, 1, input_hidden_size);
  96. const std::vector<size_t> expect_shape = {3, 8, 16, 16};
  97. EXPECT_EQ(trans_shape.size(), expect_shape.size());
  98. for (size_t i = 0; i < expect_shape.size(); i++) {
  99. EXPECT_EQ(trans_shape[i], expect_shape[i]);
  100. }
  101. }
  102. TEST_F(ShapeTransTest, nd_rnn_bias_device_shape) {
  103. std::vector<size_t> host_shape = {120};
  104. std::string format = kOpFormat_ND_RNN_BIAS;
  105. std::vector<int64_t> input_hidden_size = {13, 30};
  106. auto trans_shape = trans::TransShapeToDevice(host_shape, format, 1, input_hidden_size);
  107. std::vector<size_t> expect_shape = {128};
  108. EXPECT_EQ(trans_shape.size(), expect_shape.size());
  109. for (size_t i = 0; i < expect_shape.size(); i++) {
  110. EXPECT_EQ(trans_shape[i], expect_shape[i]);
  111. }
  112. }
  113. TEST_F(ShapeTransTest, fraczn_rnn_dynamic_device_shape) {
  114. std::vector<int64_t> host_shape = {-1, -1};
  115. std::string format = kOpFormat_FRACTAL_ZN_RNN;
  116. std::vector<int64_t> input_hidden_size = {13, 30};
  117. auto trans_shape = trans::TransShapeToDevice(host_shape, format, 1, input_hidden_size);
  118. const std::vector<int64_t> expect_shape = {-1, -1, 16, 16};
  119. EXPECT_EQ(trans_shape.size(), expect_shape.size());
  120. for (size_t i = 0; i < expect_shape.size(); i++) {
  121. EXPECT_EQ(trans_shape[i], expect_shape[i]);
  122. }
  123. }
  124. TEST_F(ShapeTransTest, nd_rnn_bias_dynamic_device_shape) {
  125. std::vector<int64_t> host_shape = {-1};
  126. std::string format = kOpFormat_ND_RNN_BIAS;
  127. std::vector<int64_t> input_hidden_size = {13, 30};
  128. auto trans_shape = trans::TransShapeToDevice(host_shape, format, 1, input_hidden_size);
  129. std::vector<int64_t> expect_shape = {-1};
  130. EXPECT_EQ(trans_shape.size(), expect_shape.size());
  131. for (size_t i = 0; i < expect_shape.size(); i++) {
  132. EXPECT_EQ(trans_shape[i], expect_shape[i]);
  133. }
  134. }
  135. } // namespace trans
  136. } // namespace mindspore