You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_callbacks.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. from builtins import range, super
  16. import time
  17. import pytest
  18. from mindspore import context
  19. from mindspore import log as logger
  20. from mindspore.dataset.callback import DSCallback, WaitedDSCallback
  21. from mindspore.train import Model
  22. from mindspore.train.callback import Callback
  23. import mindspore.dataset as ds
  24. import mindspore.nn as nn
  25. context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
  26. class MyDSCallback(DSCallback):
  27. def __init__(self, step_size=1, events=None, cb_id=0):
  28. super().__init__(step_size)
  29. self.events = events
  30. self.cb_id = cb_id
  31. def append(self, event_name, ds_run_context):
  32. event = [event_name, ds_run_context.cur_epoch_num,
  33. ds_run_context.cur_step_num_in_epoch, ds_run_context.cur_step_num]
  34. event = '_'.join([str(e) for e in event])
  35. index = -1
  36. for i, e in enumerate(self.events):
  37. if e[0] == event:
  38. index = i
  39. break
  40. if index != -1:
  41. self.events[index][1].append(self.cb_id)
  42. else:
  43. self.events.append((event, [self.cb_id]))
  44. def ds_begin(self, ds_run_context):
  45. self.append("begin", ds_run_context)
  46. def ds_end(self, ds_run_context):
  47. self.append("end", ds_run_context)
  48. def ds_epoch_begin(self, ds_run_context):
  49. self.append("epoch_begin", ds_run_context)
  50. def ds_epoch_end(self, ds_run_context):
  51. self.append("epoch_end", ds_run_context)
  52. def ds_step_begin(self, ds_run_context):
  53. self.append("step_begin", ds_run_context)
  54. def ds_step_end(self, ds_run_context):
  55. self.append("step_end", ds_run_context)
  56. def generate_expected(epoch_num, step_num, step_size=1, map_num=1, repeat=1):
  57. events = []
  58. cb_id = list(range(map_num))
  59. def append(name, e, s):
  60. event = [name, e + 1, s + 1, e * step_num * repeat + s + 1]
  61. event = '_'.join([str(ev) for ev in event])
  62. events.append((event, cb_id))
  63. events.append(("begin_0_0_0", cb_id))
  64. for e in range(epoch_num):
  65. append("epoch_begin", e, -1)
  66. for s in range(step_num * repeat):
  67. if s % step_size == 0:
  68. append("step_begin", e, s)
  69. append("step_end", e, s)
  70. append("epoch_end", e, step_num * repeat - 1)
  71. return events
  72. def build_test_case_1cb(epochs, steps, step_size=1, repeat=1):
  73. events = []
  74. arr = list(range(1, steps + 1))
  75. data = ds.NumpySlicesDataset(arr, shuffle=False)
  76. my_cb = MyDSCallback(step_size=step_size, events=events)
  77. data = data.map(operations=(lambda x: x), callbacks=my_cb)
  78. if repeat != 1:
  79. data = data.repeat(repeat)
  80. itr = data.create_tuple_iterator(num_epochs=epochs)
  81. for _ in range(epochs):
  82. for _ in itr:
  83. pass
  84. expected_events = generate_expected(epochs, steps, step_size, 1, repeat)
  85. assert expected_events == events
  86. def build_test_case_2cbs(epochs, steps):
  87. events1 = []
  88. events2 = []
  89. my_cb1 = MyDSCallback(events=events1)
  90. my_cb2 = MyDSCallback(events=events2)
  91. arr = list(range(1, steps + 1))
  92. data = ds.NumpySlicesDataset(arr, shuffle=False)
  93. data = data.map(operations=(lambda x: x), callbacks=[my_cb1, my_cb2])
  94. itr = data.create_tuple_iterator(num_epochs=epochs)
  95. for _ in range(epochs):
  96. for _ in itr:
  97. pass
  98. expected_events = generate_expected(epochs, steps)
  99. assert expected_events == events1
  100. assert expected_events == events2
  101. def build_test_case_2maps(epochs, steps):
  102. events = []
  103. my_cb1 = MyDSCallback(events=events, cb_id=0)
  104. my_cb2 = MyDSCallback(events=events, cb_id=1)
  105. arr = list(range(1, steps + 1))
  106. data = ds.NumpySlicesDataset(arr, shuffle=False)
  107. data = data.map(operations=(lambda x: x), callbacks=my_cb1)
  108. data = data.map(operations=(lambda x: x), callbacks=my_cb2)
  109. itr = data.create_tuple_iterator(num_epochs=epochs)
  110. for _ in range(epochs):
  111. for _ in itr:
  112. pass
  113. expected_events = generate_expected(epochs, steps, map_num=2)
  114. assert expected_events[1:] == events[1:]
  115. for event in events:
  116. assert len(event) == 2
  117. event, cb_ids = event
  118. if event != "begin_0_0_0":
  119. assert cb_ids[0] == 0
  120. assert cb_ids[1] == 1
  121. def test_callbacks_all_methods():
  122. logger.info("test_callbacks_all_methods")
  123. build_test_case_1cb(1, 1)
  124. build_test_case_1cb(1, 2)
  125. build_test_case_1cb(1, 3)
  126. build_test_case_1cb(1, 4)
  127. build_test_case_1cb(2, 1)
  128. build_test_case_1cb(2, 2)
  129. build_test_case_1cb(2, 3)
  130. build_test_case_1cb(2, 4)
  131. build_test_case_1cb(3, 1)
  132. build_test_case_1cb(3, 2)
  133. build_test_case_1cb(3, 3)
  134. build_test_case_1cb(3, 4)
  135. def test_callbacks_var_step_size():
  136. logger.info("test_callbacks_var_step_size")
  137. build_test_case_1cb(1, 2, 2)
  138. build_test_case_1cb(1, 3, 2)
  139. build_test_case_1cb(1, 4, 2)
  140. build_test_case_1cb(2, 2, 2)
  141. build_test_case_1cb(2, 3, 2)
  142. build_test_case_1cb(2, 4, 2)
  143. build_test_case_1cb(3, 2, 2)
  144. build_test_case_1cb(3, 3, 2)
  145. build_test_case_1cb(3, 4, 2)
  146. def test_callbacks_all_2cbs():
  147. logger.info("test_callbacks_all_2cbs")
  148. build_test_case_2cbs(4, 1)
  149. build_test_case_2cbs(4, 2)
  150. build_test_case_2cbs(4, 3)
  151. build_test_case_2cbs(4, 4)
  152. def test_callbacks_2maps():
  153. logger.info("test_callbacks_2maps")
  154. build_test_case_2maps(5, 10)
  155. build_test_case_2maps(6, 9)
  156. class MyWaitedCallback(WaitedDSCallback):
  157. def __init__(self, events, step_size=1):
  158. super().__init__(step_size)
  159. self.events = events
  160. def sync_epoch_begin(self, train_run_context, ds_run_context):
  161. event = f"ds_epoch_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
  162. self.events.append(event)
  163. def sync_step_begin(self, train_run_context, ds_run_context):
  164. event = f"ds_step_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
  165. self.events.append(event)
  166. class MyMSCallback(Callback):
  167. def __init__(self, events):
  168. self.events = events
  169. def epoch_end(self, run_context):
  170. cb_params = run_context.original_args()
  171. event = f"ms_epoch_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
  172. self.events.append(event)
  173. def step_end(self, run_context):
  174. cb_params = run_context.original_args()
  175. event = f"ms_step_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
  176. self.events.append(event)
  177. class Net(nn.Cell):
  178. def construct(self, x, y):
  179. return x
  180. def test_train_non_sink():
  181. logger.info("test_train_non_sink")
  182. events = []
  183. my_cb1 = MyWaitedCallback(events, 1)
  184. my_cb2 = MyMSCallback(events)
  185. arr = [1, 2, 3, 4]
  186. data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
  187. data = data.map(operations=(lambda x: x), callbacks=my_cb1)
  188. net = Net()
  189. model = Model(net)
  190. model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])
  191. expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_2', 'ms_step_end_1_2', 'ds_step_begin_1_3',
  192. 'ms_step_end_1_3', 'ds_step_begin_1_4', 'ms_step_end_1_4',
  193. 'ms_epoch_end_1_4', 'ds_epoch_begin_2_4',
  194. 'ds_step_begin_2_5', 'ms_step_end_2_5', 'ds_step_begin_2_6',
  195. 'ms_step_end_2_6', 'ds_step_begin_2_7', 'ms_step_end_2_7', 'ds_step_begin_2_8',
  196. 'ms_step_end_2_8', 'ms_epoch_end_2_8']
  197. assert events == expected_synced_events
  198. def test_train_batch_size2():
  199. logger.info("test_train_batch_size2")
  200. events = []
  201. my_cb1 = MyWaitedCallback(events, 2)
  202. my_cb2 = MyMSCallback(events)
  203. arr = [1, 2, 3, 4]
  204. data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
  205. data = data.map(operations=(lambda x: x), callbacks=my_cb1)
  206. data = data.batch(2)
  207. net = Net()
  208. model = Model(net)
  209. model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])
  210. expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_3',
  211. 'ms_step_end_1_2',
  212. 'ms_epoch_end_1_2', 'ds_epoch_begin_2_4',
  213. 'ds_step_begin_2_5', 'ms_step_end_2_3', 'ds_step_begin_2_7',
  214. 'ms_step_end_2_4', 'ms_epoch_end_2_4']
  215. assert events == expected_synced_events
  216. def test_callbacks_validations():
  217. logger.info("test_callbacks_validations")
  218. with pytest.raises(Exception) as err:
  219. data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
  220. data.map(operations=(lambda x: x), callbacks=0)
  221. assert "Argument callbacks with value 0 is not " in str(err.value)
  222. with pytest.raises(Exception) as err:
  223. my_cb1 = MyDSCallback()
  224. data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
  225. data.map(operations=(lambda x: x), callbacks=[my_cb1, 0])
  226. assert "Argument callbacks[1] with value 0 is not " in str(err.value)
  227. with pytest.raises(Exception) as err:
  228. class BadCB(DSCallback):
  229. pass
  230. my_cb = BadCB()
  231. data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
  232. data = data.map(operations=(lambda x: x), callbacks=my_cb)
  233. for _ in data:
  234. pass
  235. assert "Provided Callback class did not override any of the 6 callback methods." in str(err.value)
  236. def test_callback_sink_simulation():
  237. logger.info("test_callback_sink_simulation")
  238. events = []
  239. epochs = 2
  240. my_cb = MyWaitedCallback(events, 1)
  241. data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
  242. data = data.map(operations=(lambda x: x), callbacks=my_cb)
  243. data = data.to_device()
  244. data.send(num_epochs=epochs)
  245. for e in range(epochs):
  246. for s in range(4):
  247. time.sleep(0.5)
  248. events.append(f"ms_step_end_{e + 1}_{e * 4 + s + 1}")
  249. my_cb.step_end(run_context=0)
  250. events.append(f"ms_epoch_end_{e + 1}_{(e + 1) * 4}")
  251. my_cb.epoch_end(run_context=0)
  252. expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_2', 'ms_step_end_1_2', 'ds_step_begin_1_3',
  253. 'ms_step_end_1_3', 'ds_step_begin_1_4', 'ms_step_end_1_4',
  254. 'ms_epoch_end_1_4', 'ds_epoch_begin_2_4',
  255. 'ds_step_begin_2_5', 'ms_step_end_2_5', 'ds_step_begin_2_6',
  256. 'ms_step_end_2_6', 'ds_step_begin_2_7', 'ms_step_end_2_7', 'ds_step_begin_2_8',
  257. 'ms_step_end_2_8', 'ms_epoch_end_2_8']
  258. assert events == expected_synced_events
  259. def test_callbacks_repeat():
  260. logger.info("test_callbacks_repeat")
  261. build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=2)
  262. build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=3)
  263. build_test_case_1cb(epochs=2, steps=2, step_size=2, repeat=3)
  264. build_test_case_1cb(epochs=3, steps=2, step_size=4, repeat=3)
  265. if __name__ == '__main__':
  266. test_callbacks_all_methods()
  267. test_callbacks_all_2cbs()
  268. test_callbacks_2maps()
  269. test_callbacks_validations()
  270. test_callbacks_var_step_size()
  271. test_train_batch_size2()
  272. test_callback_sink_simulation()
  273. test_callbacks_repeat()