
test_sync_wait.py

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
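"""Tests for dataset pipeline synchronization.

ds.Dataset.sync_wait inserts a named blocking point into the pipeline: after
every num_batch batches (default 1) iteration stalls there until the consumer
calls sync_update with the matching condition_name, which also forwards its
data argument to the callback registered at the blocking point.
"""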
import time

import numpy as np

import mindspore.dataset as ds
from mindspore import log as logger


def gen():
    """Yield 100 single-column rows holding 0, 1, ..., 99."""
    for i in range(100):
        yield (np.array(i),)


class Augment:
    """Toy per-row op whose state is updated with feedback from the consumer."""

    def __init__(self, loss):
        self.loss = loss

    def preprocess(self, input_):
        return input_

    def update(self, data):
        self.loss = data["loss"]
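

# Each test below follows the same pattern: the pipeline blocks at sync_wait
# until the consuming loop calls sync_update, which releases the barrier and
# passes its data argument to the registered callback (Augment.update here).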
def test_simple_sync_wait():
    """
    Test simple sync wait: test sync in dataset pipeline
    """
    logger.info("test_simple_sync_wait")
    batch_size = 4
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.batch(batch_size)
    count = 0
    for data in dataset.create_dict_iterator():
        assert data["input"][0] == count
        count += batch_size
        data = {"loss": count}
        dataset.sync_update(condition_name="policy", data=data)


def test_simple_shuffle_sync():
    """
    Test simple shuffle sync: test shuffle before sync
    """
    logger.info("test_simple_shuffle_sync")
    shuffle_size = 4
    batch_size = 10
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    dataset = dataset.shuffle(shuffle_size)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.batch(batch_size)
    count = 0
    for data in dataset.create_dict_iterator():
        count += 1
        # time.sleep(0.5)
        data = {"loss": count}
        dataset.sync_update(condition_name="policy", data=data)


def test_two_sync():
    """
    Test two sync: dataset pipeline with two sync_wait operators
    """
    logger.info("test_two_sync")
    batch_size = 6
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    # notice that with our design, we need to have step_size = shuffle size
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    # the second barrier is only released once every two batches
    dataset = dataset.sync_wait(num_batch=2, condition_name="every 2 batches")
    dataset = dataset.batch(batch_size)
    count = 0
    for data in dataset.create_dict_iterator():
        count += 1
        data = {"loss": count}
        dataset.sync_update(condition_name="every batch", data=data)
        if count % 2 == 0:
            dataset.sync_update(condition_name="every 2 batches")


def test_sync_epoch():
    """
    Test sync wait with epochs: test sync with epochs in dataset pipeline
    """
    logger.info("test_sync_epoch")
    batch_size = 30
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.batch(batch_size, drop_remainder=True)
    for _ in range(3):
        # reset the feedback state at the start of each epoch
        aug.update({"loss": 0})
        count = 0
        for data in dataset.create_dict_iterator():
            assert data["input"][0] == count
            count += batch_size
            data = {"loss": count}
            dataset.sync_update(condition_name="policy", data=data)


def test_sync_exception_01():
    """
    Test sync: with shuffle in sync mode
    """
    logger.info("test_sync_exception_01")
    shuffle_size = 4
    batch_size = 10
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    try:
        # shuffling downstream of a sync point is not supported and should raise
        dataset = dataset.shuffle(shuffle_size)
    except BaseException as e:
        assert "shuffle" in str(e)
    dataset = dataset.batch(batch_size)


def test_sync_exception_02():
    """
    Test sync: with duplicated condition name
    """
    logger.info("test_sync_exception_02")
    batch_size = 6
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    try:
        # reusing a condition_name within one pipeline should raise
        dataset = dataset.sync_wait(num_batch=2, condition_name="every batch")
    except BaseException as e:
        assert "name" in str(e)
    dataset = dataset.batch(batch_size)


if __name__ == "__main__":
    test_simple_sync_wait()
    test_simple_shuffle_sync()
    test_two_sync()
    test_sync_exception_01()
    test_sync_exception_02()
    test_sync_epoch()
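
# NOTE: in the MindSpore repository these tests are normally collected by
# pytest; the __main__ block above is just a convenience for direct runs.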