test_pad_batch.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import time
import numpy as np
import mindspore.dataset as ds

# This UT tests the following cases:
# 1. padding: input_shape=[x], output_shape=[y] where y > x
# 2. pad in one dimension and truncate in the other: input_shape=[x1, x2],
#    output_shape=[y1, y2] where y1 > x1 and y2 < x2
# 3. automatic padding for a specific column
# 4. default setting for all columns
# 5. None in different places of pad_info


# This generator function yields two columns:
# col1d: [0], [1], [2], [3]
# col2d: [[100], [200]], [[101], [201]], [[102], [202]], [[103], [203]]
def gen_2cols(num):
    for i in range(num):
        yield (np.array([i]), np.array([[i + 100], [i + 200]]))


# This generator function yields one column of variable shapes:
# col: [0], [0, 1], [0, 1, 2], [0, 1, 2, 3]
def gen_var_col(num):
    for i in range(num):
        yield (np.array([j for j in range(i + 1)]),)


# This generator function yields two columns of variable shapes:
# col1: [0], [0, 1], [0, 1, 2], [0, 1, 2, 3]
# col2: [100], [100, 101], [100, 101, 102], [100, 101, 102, 103]
def gen_var_cols(num):
    for i in range(num):
        yield (np.array([j for j in range(i + 1)]), np.array([100 + j for j in range(i + 1)]))


# This generator function yields two columns of variable shapes, each nested in an extra dimension:
# col1: [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]
# col2: [[100]], [[100, 101]], [[100, 101, 102]], [[100, 101, 102, 103]]
def gen_var_cols_2d(num):
    for i in range(num):
        yield (np.array([[j for j in range(i + 1)]]), np.array([[100 + j for j in range(i + 1)]]))
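

# Case 1: pad both columns to fixed shapes larger than the input, with explicit pad values.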
def test_batch_padding_01():
    data1 = ds.GeneratorDataset((lambda: gen_2cols(2)), ["col1d", "col2d"])
    data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={"col2d": ([2, 2], -2), "col1d": ([2], -1)})
    data1 = data1.repeat(2)
    for data in data1.create_dict_iterator(num_epochs=1):
        np.testing.assert_array_equal([[0, -1], [1, -1]], data["col1d"])
        np.testing.assert_array_equal([[[100, -2], [200, -2]], [[101, -2], [201, -2]]], data["col2d"])
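

# Case 2: pad col2d in one dimension and truncate it in the other ((2, 1) -> (1, 2)); col1d is left unpadded.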
def test_batch_padding_02():
    data1 = ds.GeneratorDataset((lambda: gen_2cols(2)), ["col1d", "col2d"])
    data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={"col2d": ([1, 2], -2)})
    data1 = data1.repeat(2)
    for data in data1.create_dict_iterator(num_epochs=1):
        np.testing.assert_array_equal([[0], [1]], data["col1d"])
        np.testing.assert_array_equal([[[100, -2]], [[101, -2]]], data["col2d"])
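

# Case 3: automatic padding; a pad shape of None pads every sample to the largest shape in its batch.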
def test_batch_padding_03():
    data1 = ds.GeneratorDataset((lambda: gen_var_col(4)), ["col"])
    data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={"col": (None, -1)})  # pad automatically
    data1 = data1.repeat(2)
    res = dict()
    for ind, data in enumerate(data1.create_dict_iterator(num_epochs=1)):
        res[ind] = data["col"].copy()
    np.testing.assert_array_equal(res[0], [[0, -1], [0, 1]])
    np.testing.assert_array_equal(res[1], [[0, 1, 2, -1], [0, 1, 2, 3]])
    np.testing.assert_array_equal(res[2], [[0, -1], [0, 1]])
    np.testing.assert_array_equal(res[3], [[0, 1, 2, -1], [0, 1, 2, 3]])
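

# Case 4: an empty pad_info pads all columns automatically, using the default pad value 0.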
def test_batch_padding_04():
    data1 = ds.GeneratorDataset((lambda: gen_var_cols(2)), ["col1", "col2"])
    data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={})  # pad automatically
    data1 = data1.repeat(2)
    for data in data1.create_dict_iterator(num_epochs=1):
        np.testing.assert_array_equal(data["col1"], [[0, 0], [0, 1]])
        np.testing.assert_array_equal(data["col2"], [[100, 0], [100, 101]])
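

# Case 5: None in different places; a None dimension inside an explicit pad shape is inferred per batch.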
def test_batch_padding_05():
    data1 = ds.GeneratorDataset((lambda: gen_var_cols_2d(3)), ["col1", "col2"])
    data1 = data1.batch(batch_size=3, drop_remainder=False,
                        pad_info={"col2": ([2, None], -2), "col1": (None, -1)})  # pad automatically
    for data in data1.create_dict_iterator(num_epochs=1):
        np.testing.assert_array_equal(data["col1"], [[[0, -1, -1]], [[0, 1, -1]], [[0, 1, 2]]])
        np.testing.assert_array_equal(data["col2"], [[[100, -2, -2], [-2, -2, -2]], [[100, 101, -2], [-2, -2, -2]],
                                                     [[100, 101, 102], [-2, -2, -2]]])
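

# Performance helper (not run by default): time pad_info-based batching of 3d CIFAR-10 images.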
def batch_padding_performance_3d():
    cifar10_dir = "../data/dataset/testCifar10Data"
    data1 = ds.Cifar10Dataset(cifar10_dir, shuffle=False)  # shape = [32, 32, 3]
    data1 = data1.repeat(24)
    pad_info = {"image": ([36, 36, 3], 0)}
    # pad_info = None
    data1 = data1.batch(batch_size=24, drop_remainder=True, pad_info=pad_info)
    start_time = time.time()
    num_batches = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_batches += 1
    res = "total number of batches: " + str(num_batches) + " time elapsed: " + str(time.time() - start_time)
    # print(res)
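

# Performance helper (not run by default): time pad_info-based batching after flattening the images to 1d.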
def batch_padding_performance_1d():
    cifar10_dir = "../data/dataset/testCifar10Data"
    data1 = ds.Cifar10Dataset(cifar10_dir, shuffle=False)  # shape = [32, 32, 3]
    data1 = data1.repeat(24)
    data1 = data1.map(operations=(lambda x: x.reshape(-1)), input_columns="image")
    pad_info = {"image": ([3888], 0)}  # 3888 = 36 * 36 * 3
    # pad_info = None
    data1 = data1.batch(batch_size=24, drop_remainder=True, pad_info=pad_info)
    start_time = time.time()
    num_batches = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_batches += 1
    res = "total number of batches: " + str(num_batches) + " time elapsed: " + str(time.time() - start_time)
    # print(res)
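

# Performance helper (not run by default): pad 3d images via a map pyfunc instead of pad_info, for comparison.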
def batch_pyfunc_padding_3d():
    cifar10_dir = "../data/dataset/testCifar10Data"
    data1 = ds.Cifar10Dataset(cifar10_dir, shuffle=False)  # shape = [32, 32, 3]
    data1 = data1.repeat(24)
    # pad_info = {"image": ([36, 36, 3], 0)}
    data1 = data1.map(operations=(lambda x: np.pad(x, ((0, 4), (0, 4), (0, 0)))), input_columns="image",
                      python_multiprocessing=False)
    data1 = data1.batch(batch_size=24, drop_remainder=True)
    start_time = time.time()
    num_batches = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_batches += 1
    res = "total number of batches: " + str(num_batches) + " time elapsed: " + str(time.time() - start_time)
    # print(res)
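

# Performance helper (not run by default): pad flattened 1d images via a map pyfunc instead of pad_info.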
def batch_pyfunc_padding_1d():
    cifar10_dir = "../data/dataset/testCifar10Data"
    data1 = ds.Cifar10Dataset(cifar10_dir, shuffle=False)  # shape = [32, 32, 3]
    data1 = data1.repeat(24)
    data1 = data1.map(operations=(lambda x: x.reshape(-1)), input_columns="image")
    data1 = data1.map(operations=(lambda x: np.pad(x, (0, 816))), input_columns="image", python_multiprocessing=False)
    data1 = data1.batch(batch_size=24, drop_remainder=True)
    start_time = time.time()
    num_batches = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_batches += 1
    res = "total number of batches: " + str(num_batches) + " time elapsed: " + str(time.time() - start_time)
    # print(res)


# This function runs pad_batch and np.pad, then compares their results.
def test_pad_via_map():
    cifar10_dir = "../data/dataset/testCifar10Data"

    def pad_map_config():
        data1 = ds.Cifar10Dataset(cifar10_dir, shuffle=False, num_samples=1000)  # shape = [32, 32, 3]
        data1 = data1.map(operations=(lambda x: x.reshape(-1)), input_columns="image")  # reshape to 1d
        data1 = data1.map(operations=(lambda x: np.pad(x, (0, 816))), input_columns="image")
        data1 = data1.batch(batch_size=25, drop_remainder=True)
        res = []
        for data in data1.create_dict_iterator(num_epochs=1):
            res.append(data["image"])
        return res

    def pad_batch_config():
        data2 = ds.Cifar10Dataset(cifar10_dir, shuffle=False, num_samples=1000)  # shape = [32, 32, 3]
        data2 = data2.map(operations=(lambda x: x.reshape(-1)), input_columns="image")  # reshape to 1d
        data2 = data2.batch(batch_size=25, drop_remainder=True, pad_info={"image": ([3888], 0)})
        res = []
        for data in data2.create_dict_iterator(num_epochs=1):
            res.append(data["image"])
        return res

    res_from_map = pad_map_config()
    res_from_batch = pad_batch_config()
    assert len(res_from_map) == len(res_from_batch)
    for i, _ in enumerate(res_from_map):
        np.testing.assert_array_equal(res_from_map[i], res_from_batch[i])


if __name__ == '__main__':
    test_batch_padding_01()
    test_batch_padding_02()
    test_batch_padding_03()
    test_batch_padding_04()
    test_batch_padding_05()
    # batch_padding_performance_3d()
    # batch_padding_performance_1d()
    # batch_pyfunc_padding_3d()
    # batch_pyfunc_padding_1d()
    test_pad_via_map()