
test_sync_wait.py

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np

import mindspore.dataset as ds
from mindspore import log as logger


def gen():
    for i in range(100):
        yield (np.array(i),)


class Augment:
    def __init__(self, loss):
        self.loss = loss

    def preprocess(self, input_):
        return input_

    def update(self, data):
        self.loss = data["loss"]


def test_simple_sync_wait():
    """
    Test simple sync wait: test sync in dataset pipeline
    """
    logger.info("test_simple_sync_wait")
    batch_size = 4
    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    # sync_wait blocks the pipeline until sync_update is called for the same condition_name
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.batch(batch_size)

    count = 0
    for data in dataset.create_dict_iterator():
        assert data["input"][0] == count
        count += batch_size
        data = {"loss": count}
        dataset.sync_update(condition_name="policy", data=data)


def test_simple_shuffle_sync():
    """
    Test simple shuffle sync: test shuffle before sync
    """
    logger.info("test_simple_shuffle_sync")
    shuffle_size = 4
    batch_size = 10

    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    dataset = dataset.shuffle(shuffle_size)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.batch(batch_size)

    count = 0
    for data in dataset.create_dict_iterator():
        count += 1
        data = {"loss": count}
        dataset.sync_update(condition_name="policy", data=data)


def test_two_sync():
    """
    Test two sync: dataset pipeline with two sync operators
    """
    logger.info("test_two_sync")
    batch_size = 6

    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    # notice that with our design, we need to have step_size = shuffle size
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.sync_wait(num_batch=2, condition_name="every 2 batches")
    dataset = dataset.batch(batch_size)

    count = 0
    for data in dataset.create_dict_iterator():
        count += 1
        data = {"loss": count}
        dataset.sync_update(condition_name="every batch", data=data)
        if count % 2 == 0:
            dataset.sync_update(condition_name="every 2 batches")


def test_sync_epoch():
    """
    Test sync wait with epochs: test sync with epochs in dataset pipeline
    """
    logger.info("test_sync_epoch")
    batch_size = 30
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.batch(batch_size, drop_remainder=True)

    for _ in range(3):
        aug.update({"loss": 0})
        count = 0
        for data in dataset.create_dict_iterator():
            assert data["input"][0] == count
            count += batch_size
            data = {"loss": count}
            dataset.sync_update(condition_name="policy", data=data)


def test_multiple_iterators():
    """
    Test sync wait with multiple iterators: will start multiple iterators on separate pipelines
    """
    logger.info("test_multiple_iterators")
    batch_size = 30
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # 2nd dataset
    dataset2 = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    dataset2 = dataset2.sync_wait(condition_name="policy", callback=aug.update)
    dataset2 = dataset2.map(input_columns=["input"], operations=[aug.preprocess])
    dataset2 = dataset2.batch(batch_size, drop_remainder=True)

    for item1, item2 in zip(dataset.create_dict_iterator(), dataset2.create_dict_iterator()):
        assert item1["input"][0] == item2["input"][0]
        data1 = {"loss": item1["input"][0]}
        data2 = {"loss": item2["input"][0]}
        dataset.sync_update(condition_name="policy", data=data1)
        dataset2.sync_update(condition_name="policy", data=data2)


def test_sync_exception_01():
    """
    Test sync: with shuffle in sync mode
    """
    logger.info("test_sync_exception_01")
    shuffle_size = 4
    batch_size = 10

    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])

    try:
        dataset = dataset.shuffle(shuffle_size)
    except Exception as e:
        assert "shuffle" in str(e)
    dataset = dataset.batch(batch_size)


def test_sync_exception_02():
    """
    Test sync: with duplicated condition name
    """
    logger.info("test_sync_exception_02")
    batch_size = 6

    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])

    try:
        dataset = dataset.sync_wait(num_batch=2, condition_name="every batch")
    except Exception as e:
        assert "name" in str(e)
    dataset = dataset.batch(batch_size)


def test_sync_exception_03():
    """
    Test sync: with negative num_batch in sync_wait
    """
    logger.info("test_sync_exception_03")
    batch_size = 6

    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    # try to create a sync_wait with num_batch < 0
    try:
        dataset = dataset.sync_wait(condition_name="every batch", num_batch=-1, callback=aug.update)
    except Exception as e:
        assert "num_batch" in str(e)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])


def test_sync_exception_04():
    """
    Test sync: with negative num_batch in sync_update
    """
    logger.info("test_sync_exception_04")
    batch_size = 6

    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])

    count = 0
    # try to call sync_update with num_batch < 0
    try:
        for item in dataset.create_dict_iterator():
            count += 1
            data = {"loss": count}
            # dataset.disable_sync()
            dataset.sync_update(condition_name="every batch", num_batch=-1, data=data)
    except Exception as e:
        assert "batch" in str(e)


def test_sync_exception_05():
    """
    Test sync: with wrong condition name in sync_update
    """
    logger.info("test_sync_exception_05")
    batch_size = 6
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    count = 0
    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])

    # try to call sync_update with a condition name that was never registered
    try:
        for item in dataset.create_dict_iterator():
            dataset.disable_sync()
            count += 1
            data = {"loss": count}
            dataset.disable_sync()
            dataset.sync_update(condition_name="every", data=data)
    except Exception as e:
        assert "name" in str(e)


if __name__ == "__main__":
    test_simple_sync_wait()
    test_simple_shuffle_sync()
    test_two_sync()
    test_sync_exception_01()
    test_sync_exception_02()
    test_sync_exception_03()
    test_sync_exception_04()
    test_sync_exception_05()
    test_sync_epoch()
    test_multiple_iterators()