# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger

DATA_DIR = "../data/dataset/testPK/data"
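
# Note: testPK/data is assumed to be an ImageFolder-style layout (one
# sub-directory per class with image files inside), which is what
# ds.ImageFolderDataset expects; the path is relative to the test working
# directory and is kept as-is from the original test.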


# Generate 1-d int numpy arrays holding the values 0 .. 63.
def generator_1d():
    for i in range(64):
        yield (np.array([i]),)
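
# A minimal sketch of the semantics the tests below rely on (assuming the
# documented behaviour of Dataset.apply): data.apply(fn) simply returns
# fn(data), so the two pipelines here should produce identical rows.
#
#     out1 = data.apply(lambda d: d.batch(4))   # via apply
#     out2 = data.batch(4)                      # written out directly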


def test_apply_generator_case():
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data2 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds_):
        ds_ = ds_.repeat(2)
        return ds_.batch(4)

    data1 = data1.apply(dataset_fn)
    data2 = data2.repeat(2)
    data2 = data2.batch(4)

    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1["data"], item2["data"])
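
# With 64 generated rows, repeat(2) gives 128 rows and batch(4) packs them
# into 32 batches of shape (4, 1); output_numpy=True makes the iterators
# yield plain numpy arrays, so np.testing can compare the two pipelines
# element for element.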


def test_apply_imagefolder_case():
    # apply dataset map operations
    data1 = ds.ImageFolderDataset(DATA_DIR, num_shards=4, shard_id=3)
    data2 = ds.ImageFolderDataset(DATA_DIR, num_shards=4, shard_id=3)

    decode_op = vision.Decode()
    normalize_op = vision.Normalize([121.0, 115.0, 100.0], [70.0, 68.0, 71.0])

    def dataset_fn(ds_):
        ds_ = ds_.map(operations=decode_op)
        ds_ = ds_.map(operations=normalize_op)
        ds_ = ds_.repeat(2)
        return ds_

    data1 = data1.apply(dataset_fn)

    data2 = data2.map(operations=decode_op)
    data2 = data2.map(operations=normalize_op)
    data2 = data2.repeat(2)

    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1["image"], item2["image"])
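
# The Normalize mean/std values above are in the raw 0-255 pixel range,
# matching the uint8 HWC images that Decode() produces; no rescaling to
# [0, 1] happens before the map, so means around 121 are expected.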


def test_apply_flow_case_0(id_=0):
    # apply control flow operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds_):
        if id_ == 0:
            ds_ = ds_.batch(4)
        elif id_ == 1:
            ds_ = ds_.repeat(2)
        elif id_ == 2:
            ds_ = ds_.batch(4)
            ds_ = ds_.repeat(2)
        else:
            ds_ = ds_.shuffle(buffer_size=4)
        return ds_

    data1 = data1.apply(dataset_fn)

    num_iter = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_iter = num_iter + 1

    if id_ == 0:
        assert num_iter == 16
    elif id_ == 1:
        assert num_iter == 128
    elif id_ == 2:
        assert num_iter == 32
    else:
        assert num_iter == 64
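
# Expected row counts for the four flow cases, given 64 generated rows:
# batch(4) -> 16 batches; repeat(2) -> 128 rows; batch(4) then repeat(2)
# -> 32 batches; shuffle(buffer_size=4) only reorders, so 64 rows remain.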


def test_apply_flow_case_1(id_=1):
    # apply control flow operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds_):
        if id_ == 0:
            ds_ = ds_.batch(4)
        elif id_ == 1:
            ds_ = ds_.repeat(2)
        elif id_ == 2:
            ds_ = ds_.batch(4)
            ds_ = ds_.repeat(2)
        else:
            ds_ = ds_.shuffle(buffer_size=4)
        return ds_

    data1 = data1.apply(dataset_fn)

    num_iter = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_iter = num_iter + 1

    if id_ == 0:
        assert num_iter == 16
    elif id_ == 1:
        assert num_iter == 128
    elif id_ == 2:
        assert num_iter == 32
    else:
        assert num_iter == 64


def test_apply_flow_case_2(id_=2):
    # apply control flow operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds_):
        if id_ == 0:
            ds_ = ds_.batch(4)
        elif id_ == 1:
            ds_ = ds_.repeat(2)
        elif id_ == 2:
            ds_ = ds_.batch(4)
            ds_ = ds_.repeat(2)
        else:
            ds_ = ds_.shuffle(buffer_size=4)
        return ds_

    data1 = data1.apply(dataset_fn)

    num_iter = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_iter = num_iter + 1

    if id_ == 0:
        assert num_iter == 16
    elif id_ == 1:
        assert num_iter == 128
    elif id_ == 2:
        assert num_iter == 32
    else:
        assert num_iter == 64


def test_apply_flow_case_3(id_=3):
    # apply control flow operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds_):
        if id_ == 0:
            ds_ = ds_.batch(4)
        elif id_ == 1:
            ds_ = ds_.repeat(2)
        elif id_ == 2:
            ds_ = ds_.batch(4)
            ds_ = ds_.repeat(2)
        else:
            ds_ = ds_.shuffle(buffer_size=4)
        return ds_

    data1 = data1.apply(dataset_fn)

    num_iter = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_iter = num_iter + 1

    if id_ == 0:
        assert num_iter == 16
    elif id_ == 1:
        assert num_iter == 128
    elif id_ == 2:
        assert num_iter == 32
    else:
        assert num_iter == 64


def test_apply_exception_case():
    # apply exception operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds_):
        ds_ = ds_.repeat(2)
        return ds_.batch(4)

    def exception_fn():
        return np.array([[0], [1], [3], [4], [5]])

    try:
        data1 = data1.apply("123")
        for _ in data1.create_dict_iterator(num_epochs=1):
            pass
        assert False
    except TypeError:
        pass

    try:
        data1 = data1.apply(exception_fn)
        for _ in data1.create_dict_iterator(num_epochs=1):
            pass
        assert False
    except TypeError:
        pass

    try:
        data2 = data1.apply(dataset_fn)
        _ = data1.apply(dataset_fn)
        for _, _ in zip(data1.create_dict_iterator(num_epochs=1), data2.create_dict_iterator(num_epochs=1)):
            pass
        assert False
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
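
# The last case presumably fails because apply() may only be invoked once
# per dataset node: calling data1.apply(dataset_fn) a second time raises
# ValueError (this reading is inferred from the test itself, not from the
# library documentation).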


if __name__ == '__main__':
    logger.info("Running test_apply.py test_apply_generator_case() function")
    test_apply_generator_case()

    logger.info("Running test_apply.py test_apply_imagefolder_case() function")
    test_apply_imagefolder_case()

    logger.info("Running test_apply.py test_apply_flow_case(id) function")
    test_apply_flow_case_0()
    test_apply_flow_case_1()
    test_apply_flow_case_2()
    test_apply_flow_case_3()

    logger.info("Running test_apply.py test_apply_exception_case() function")
    test_apply_exception_case()
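
# The test_* functions can also be collected by pytest (assuming the usual
# MindSpore test layout), e.g.:
#
#     pytest test_apply.py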