You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_callbacks.py 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. from builtins import range, super
  16. import time
  17. import pytest
  18. from mindspore import context
  19. from mindspore import log as logger
  20. from mindspore.dataset.callback import DSCallback, WaitedDSCallback
  21. from mindspore.train import Model
  22. from mindspore.train.callback import Callback
  23. import mindspore.dataset as ds
  24. import mindspore.nn as nn
# Module-level setup: select graph mode with the CPU device target before any test runs.
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
  26. class BaseCallback(DSCallback):
  27. def __init__(self, step_size=1, events=None, cb_id=0):
  28. super().__init__(step_size)
  29. self.events = events
  30. self.cb_id = cb_id
  31. def append(self, event_name, ds_run_context):
  32. event = [event_name, ds_run_context.cur_epoch_num,
  33. ds_run_context.cur_step_num_in_epoch, ds_run_context.cur_step_num]
  34. event = '_'.join([str(e) for e in event])
  35. index = -1
  36. for i, e in enumerate(self.events):
  37. if e[0] == event:
  38. index = i
  39. break
  40. if index != -1:
  41. self.events[index][1].append(self.cb_id)
  42. else:
  43. self.events.append((event, [self.cb_id]))
  44. class Begin(BaseCallback):
  45. def ds_begin(self, ds_run_context):
  46. self.append("begin", ds_run_context)
  47. class EpochBegin(BaseCallback):
  48. def ds_epoch_begin(self, ds_run_context):
  49. self.append("epoch_begin", ds_run_context)
  50. class EpochEnd(BaseCallback):
  51. def ds_epoch_end(self, ds_run_context):
  52. self.append("epoch_end", ds_run_context)
  53. class StepBegin(BaseCallback):
  54. def ds_step_begin(self, ds_run_context):
  55. self.append("step_begin", ds_run_context)
  56. class StepEnd(BaseCallback):
  57. def ds_step_end(self, ds_run_context):
  58. self.append("step_end", ds_run_context)
  59. class MyDSCallback(Begin, EpochBegin, EpochEnd, StepBegin, StepEnd):
  60. pass
  61. def verify_events(events, epoch_num, step_num, step_size=1, map_num=1, repeat=1):
  62. '''
  63. Make sure that the events are in correct order.
  64. * begin is the first
  65. * epoch x begin before epoch x end
  66. * epoch x end before epoch x+1 begin
  67. * step x begin before step x end
  68. * step x begin before step x+1 begin
  69. * step x end before step x+1 end
  70. '''
  71. assert events[0][0] == "begin_0_0_0"
  72. epochs = list(filter(lambda e: 'epoch' in e[0], events))
  73. i = 0
  74. while i < len(epochs):
  75. epoch_num = epochs[i][0].split('_')[2]
  76. e_type = epochs[i][0].split('_')[1]
  77. assert str(i // 2 + 1) == epoch_num
  78. assert e_type == "begin"
  79. i += 1
  80. epoch_num = epochs[i][0].split('_')[2]
  81. e_type = epochs[i][0].split('_')[1]
  82. assert str(i // 2 + 1) == epoch_num
  83. assert e_type == "end"
  84. i += 1
  85. steps = list(filter(lambda e: 'step' in e[0], events))
  86. steps = [(s[0].split('_')[1], s[0].split('_')[-1]) for s in steps]
  87. steps_map = {}
  88. max_step = 0
  89. for s in steps:
  90. if s[1] in steps_map:
  91. assert steps_map[s[1]] == 'begin'
  92. assert s[0] == 'end'
  93. else:
  94. assert s[0] == 'begin'
  95. steps_map[s[1]] = 'begin'
  96. assert int(s[1]) > max_step
  97. max_step = max(max_step, int(s[1]))
  98. def generate_expected(epoch_num, step_num, step_size=1, map_num=1, repeat=1):
  99. events = []
  100. cb_id = list(range(map_num))
  101. def append(name, e, s):
  102. event = [name, e + 1, s + 1, e * step_num * repeat + s + 1]
  103. event = '_'.join([str(ev) for ev in event])
  104. events.append((event, cb_id))
  105. events.append(("begin_0_0_0", cb_id))
  106. for e in range(epoch_num):
  107. append("epoch_begin", e, -1)
  108. for s in range(step_num * repeat):
  109. if s % step_size == 0:
  110. append("step_begin", e, s)
  111. append("step_end", e, s)
  112. append("epoch_end", e, step_num * repeat - 1)
  113. return events
  114. def build_test_case_1cb(epochs, steps, step_size=1, repeat=1):
  115. events = []
  116. arr = list(range(1, steps + 1))
  117. data = ds.NumpySlicesDataset(arr, shuffle=False)
  118. my_cb = MyDSCallback(step_size=step_size, events=events)
  119. data = data.map(operations=(lambda x: x), callbacks=my_cb)
  120. if repeat != 1:
  121. if repeat % 2 == 0 and repeat != 2:
  122. data = data.repeat(2)
  123. data = data.map(operations=(lambda x: x))
  124. data = data.repeat(repeat // 2)
  125. else:
  126. data = data.repeat(repeat)
  127. itr = data.create_tuple_iterator(num_epochs=epochs)
  128. for _ in range(epochs):
  129. for _ in itr:
  130. pass
  131. expected_events = generate_expected(epochs, steps, step_size, 1, repeat)
  132. expected_events = [e[0] for e in expected_events]
  133. verify_events(events, epochs, steps, step_size, repeat)
  134. events = [e[0] for e in events]
  135. expected_events.sort()
  136. events.sort()
  137. assert expected_events == events
  138. def build_test_case_2cbs(epochs, steps):
  139. events1 = []
  140. events2 = []
  141. my_cb1 = MyDSCallback(events=events1)
  142. my_cb2 = MyDSCallback(events=events2)
  143. arr = list(range(1, steps + 1))
  144. data = ds.NumpySlicesDataset(arr, shuffle=False)
  145. data = data.map(operations=(lambda x: x), callbacks=[my_cb1, my_cb2])
  146. itr = data.create_tuple_iterator(num_epochs=epochs)
  147. for _ in range(epochs):
  148. for _ in itr:
  149. pass
  150. expected_events = generate_expected(epochs, steps)
  151. expected_events.sort()
  152. events1.sort()
  153. events2.sort()
  154. assert expected_events == events1
  155. assert expected_events == events2
  156. def build_test_case_2maps(epochs, steps):
  157. events = []
  158. my_cb1 = MyDSCallback(events=events, cb_id=0)
  159. my_cb2 = MyDSCallback(events=events, cb_id=1)
  160. arr = list(range(1, steps + 1))
  161. data = ds.NumpySlicesDataset(arr, shuffle=False)
  162. data = data.map(operations=(lambda x: x), callbacks=my_cb1)
  163. data = data.map(operations=(lambda x: x), callbacks=my_cb2)
  164. itr = data.create_tuple_iterator(num_epochs=epochs)
  165. for _ in range(epochs):
  166. for _ in itr:
  167. pass
  168. expected_events = generate_expected(epochs, steps, map_num=2)
  169. assert expected_events[1:] == events[1:]
  170. for event in events:
  171. assert len(event) == 2
  172. event, cb_ids = event
  173. if event != "begin_0_0_0":
  174. assert cb_ids[0] == 0
  175. assert cb_ids[1] == 1
  176. def test_callbacks_all_methods():
  177. logger.info("test_callbacks_all_methods")
  178. build_test_case_1cb(1, 1)
  179. build_test_case_1cb(1, 2)
  180. build_test_case_1cb(1, 3)
  181. build_test_case_1cb(1, 4)
  182. build_test_case_1cb(2, 1)
  183. build_test_case_1cb(2, 2)
  184. build_test_case_1cb(2, 3)
  185. build_test_case_1cb(2, 4)
  186. build_test_case_1cb(3, 1)
  187. build_test_case_1cb(3, 2)
  188. build_test_case_1cb(3, 3)
  189. build_test_case_1cb(3, 4)
  190. def test_callbacks_var_step_size():
  191. logger.info("test_callbacks_var_step_size")
  192. build_test_case_1cb(1, 2, 2)
  193. build_test_case_1cb(1, 3, 2)
  194. build_test_case_1cb(1, 4, 2)
  195. build_test_case_1cb(2, 2, 2)
  196. build_test_case_1cb(2, 3, 2)
  197. build_test_case_1cb(2, 4, 2)
  198. build_test_case_1cb(3, 2, 2)
  199. build_test_case_1cb(3, 3, 2)
  200. build_test_case_1cb(3, 4, 2)
  201. def test_callbacks_all_2cbs():
  202. logger.info("test_callbacks_all_2cbs")
  203. build_test_case_2cbs(4, 1)
  204. build_test_case_2cbs(4, 2)
  205. build_test_case_2cbs(4, 3)
  206. build_test_case_2cbs(4, 4)
  207. class MyWaitedCallback(WaitedDSCallback):
  208. def __init__(self, events, step_size=1):
  209. super().__init__(step_size)
  210. self.events = events
  211. def sync_epoch_begin(self, train_run_context, ds_run_context):
  212. event = f"ds_epoch_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
  213. self.events.append(event)
  214. def sync_step_begin(self, train_run_context, ds_run_context):
  215. event = f"ds_step_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
  216. self.events.append(event)
  217. class MyMSCallback(Callback):
  218. def __init__(self, events):
  219. self.events = events
  220. def epoch_end(self, run_context):
  221. cb_params = run_context.original_args()
  222. event = f"ms_epoch_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
  223. self.events.append(event)
  224. def step_end(self, run_context):
  225. cb_params = run_context.original_args()
  226. event = f"ms_step_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
  227. self.events.append(event)
  228. class Net(nn.Cell):
  229. def construct(self, x, y):
  230. return x
  231. def test_callbacks_non_sink():
  232. logger.info("test_callbacks_non_sink")
  233. events = []
  234. my_cb1 = MyWaitedCallback(events, 1)
  235. my_cb2 = MyMSCallback(events)
  236. arr = [1, 2, 3, 4]
  237. data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
  238. data = data.map(operations=(lambda x: x), callbacks=my_cb1)
  239. net = Net()
  240. model = Model(net)
  241. model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])
  242. expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_2', 'ms_step_end_1_2', 'ds_step_begin_1_3',
  243. 'ms_step_end_1_3', 'ds_step_begin_1_4', 'ms_step_end_1_4',
  244. 'ms_epoch_end_1_4', 'ds_epoch_begin_2_4',
  245. 'ds_step_begin_2_5', 'ms_step_end_2_5', 'ds_step_begin_2_6',
  246. 'ms_step_end_2_6', 'ds_step_begin_2_7', 'ms_step_end_2_7', 'ds_step_begin_2_8',
  247. 'ms_step_end_2_8', 'ms_epoch_end_2_8']
  248. assert events[:18] == expected_synced_events
  249. def test_callbacks_non_sink_batch_size2():
  250. logger.info("test_callbacks_non_sink_batch_size2")
  251. events = []
  252. my_cb1 = MyWaitedCallback(events, 2)
  253. my_cb2 = MyMSCallback(events)
  254. arr = [1, 2, 3, 4]
  255. data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
  256. data = data.map(operations=(lambda x: x), callbacks=my_cb1)
  257. data = data.batch(2)
  258. net = Net()
  259. model = Model(net)
  260. model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])
  261. expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_3',
  262. 'ms_step_end_1_2',
  263. 'ms_epoch_end_1_2', 'ds_epoch_begin_2_4',
  264. 'ds_step_begin_2_5', 'ms_step_end_2_3', 'ds_step_begin_2_7',
  265. 'ms_step_end_2_4', 'ms_epoch_end_2_4']
  266. assert events[:10] == expected_synced_events
  267. def test_callbacks_non_sink_mismatch_size():
  268. logger.info("test_callbacks_non_sink_mismatch_size")
  269. default_timeout = ds.config.get_callback_timeout()
  270. ds.config.set_callback_timeout(1)
  271. events = []
  272. my_cb1 = MyWaitedCallback(events, 2)
  273. my_cb2 = MyMSCallback(events)
  274. arr = [1, 2, 3, 4]
  275. data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
  276. data = data.map(operations=(lambda x: x), callbacks=my_cb1)
  277. data = data.batch(3)
  278. net = Net()
  279. model = Model(net)
  280. with pytest.raises(Exception) as err:
  281. model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])
  282. assert "RuntimeError: ds_step_begin timed out after 1 second(s)" in str(err.value)
  283. ds.config.set_callback_timeout(default_timeout)
  284. def test_callbacks_validations():
  285. logger.info("test_callbacks_validations")
  286. with pytest.raises(Exception) as err:
  287. data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
  288. data.map(operations=(lambda x: x), callbacks=0)
  289. assert "Argument callbacks with value 0 is not " in str(err.value)
  290. with pytest.raises(Exception) as err:
  291. my_cb1 = MyDSCallback()
  292. data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
  293. data.map(operations=(lambda x: x), callbacks=[my_cb1, 0])
  294. assert "Argument callbacks[1] with value 0 is not " in str(err.value)
  295. with pytest.raises(Exception) as err:
  296. class BadCB(DSCallback):
  297. pass
  298. my_cb = BadCB()
  299. data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
  300. data = data.map(operations=(lambda x: x), callbacks=my_cb)
  301. for _ in data:
  302. pass
  303. assert "Provided Callback class did not override any of the 6 callback methods." in str(err.value)
  304. def test_callbacks_sink_simulation():
  305. logger.info("test_callback_sink_simulation")
  306. events = []
  307. epochs = 2
  308. my_cb = MyWaitedCallback(events, 1)
  309. data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
  310. data = data.map(operations=(lambda x: x), callbacks=my_cb)
  311. data = data.to_device()
  312. data.send(num_epochs=epochs)
  313. for e in range(epochs):
  314. for s in range(4):
  315. time.sleep(0.5)
  316. events.append(f"ms_step_end_{e + 1}_{e * 4 + s + 1}")
  317. my_cb.step_end(run_context=0)
  318. events.append(f"ms_epoch_end_{e + 1}_{(e + 1) * 4}")
  319. my_cb.epoch_end(run_context=0)
  320. expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_2', 'ms_step_end_1_2', 'ds_step_begin_1_3',
  321. 'ms_step_end_1_3', 'ds_step_begin_1_4', 'ms_step_end_1_4',
  322. 'ms_epoch_end_1_4', 'ds_epoch_begin_2_4',
  323. 'ds_step_begin_2_5', 'ms_step_end_2_5', 'ds_step_begin_2_6',
  324. 'ms_step_end_2_6', 'ds_step_begin_2_7', 'ms_step_end_2_7', 'ds_step_begin_2_8',
  325. 'ms_step_end_2_8', 'ms_epoch_end_2_8']
  326. assert events == expected_synced_events
  327. def test_callbacks_repeat():
  328. logger.info("test_callbacks_repeat")
  329. build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=2)
  330. build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=3)
  331. build_test_case_1cb(epochs=2, steps=2, step_size=2, repeat=3)
  332. build_test_case_1cb(epochs=3, steps=2, step_size=4, repeat=3)
  333. build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=2)
  334. build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=4)
  335. build_test_case_1cb(epochs=2, steps=2, step_size=2, repeat=8)
  336. build_test_case_1cb(epochs=3, steps=2, step_size=4, repeat=16)
  337. def test_callbacks_exceptions():
  338. logger.info("test_callbacks_exceptions")
  339. class BadCB(DSCallback):
  340. def ds_begin(self, ds_run_context):
  341. raise RuntimeError("Bad begin")
  342. with pytest.raises(Exception) as err:
  343. data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
  344. data = data.map(operations=(lambda x: x), callbacks=BadCB())
  345. for _ in data:
  346. pass
  347. assert "RuntimeError: Bad begin" in str(err.value)
  348. def test_callbacks_train_end():
  349. logger.info("test_callback_sink_simulation")
  350. # No asserts are needed, just test there is no deadlock or exceptions
  351. events = []
  352. epochs = 2
  353. my_cb = MyWaitedCallback(events, 1)
  354. data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
  355. data = data.map(operations=(lambda x: x), callbacks=[my_cb])
  356. data = data.to_device()
  357. data.send(num_epochs=epochs)
  358. time.sleep(0.5)
  359. my_cb.end(run_context={})
  360. time.sleep(0.5)
  361. def test_callbacks_one_cb():
  362. logger.info("test_callbacks_one_cb")
  363. data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
  364. events1 = []
  365. events2 = []
  366. events3 = []
  367. my_begin = Begin(events=events1, cb_id=1)
  368. my_epoch_begin = EpochBegin(events=events2, cb_id=2)
  369. my_epoch_end = EpochEnd(events=events3, cb_id=3)
  370. my_step_begin = StepBegin(events=events3, cb_id=3)
  371. my_step_end = StepEnd(events=events2, cb_id=2)
  372. data = data.map(operations=(lambda x: x), callbacks=my_begin)
  373. data = data.map(operations=(lambda x: x), callbacks=[my_epoch_begin, my_step_end])
  374. data = data.map(operations=(lambda x: x), callbacks=[my_epoch_end, my_step_begin])
  375. itr = data.create_tuple_iterator(num_epochs=2)
  376. for _ in range(2):
  377. for _ in itr:
  378. pass
  379. expected_events1 = [('begin_0_0_0', [1])]
  380. expected_events2 = [('epoch_begin_1_0_0', [2]), ('step_end_1_1_1', [2]), ('step_end_1_2_2', [2]),
  381. ('step_end_1_3_3', [2]), ('step_end_1_4_4', [2]), ('epoch_begin_2_0_4', [2]),
  382. ('step_end_2_1_5', [2]), ('step_end_2_2_6', [2]), ('step_end_2_3_7', [2]),
  383. ('step_end_2_4_8', [2])]
  384. expected_events3 = [('step_begin_1_1_1', [3]), ('step_begin_1_2_2', [3]), ('step_begin_1_3_3', [3]),
  385. ('step_begin_1_4_4', [3]), ('epoch_end_1_4_4', [3]), ('step_begin_2_1_5', [3]),
  386. ('step_begin_2_2_6', [3]), ('step_begin_2_3_7', [3]), ('step_begin_2_4_8', [3]),
  387. ('epoch_end_2_4_8', [3])]
  388. events1.sort()
  389. events2.sort()
  390. events3.sort()
  391. expected_events1.sort()
  392. expected_events2.sort()
  393. expected_events3.sort()
  394. assert events1 == expected_events1
  395. assert events2 == expected_events2
  396. assert events3 == expected_events3
  397. def test_clear_callback():
  398. logger.info("test_clear_callback")
  399. # this test case will test that callback is removed for get_dataset_size and output_shape/type
  400. class FlagCallback(DSCallback):
  401. def __init__(self):
  402. super().__init__(step_size=1)
  403. self.flag = False
  404. self.row_cnt = 0
  405. def ds_begin(self, ds_run_context):
  406. # if callback isn't removed in getter pass, this function will be called
  407. self.flag = True
  408. def ds_step_begin(self, ds_run_context):
  409. self.row_cnt += 1
  410. data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
  411. cb = FlagCallback()
  412. # make sure variables are properly initialized before testing
  413. assert not cb.flag and cb.row_cnt == 0
  414. data = data.map(operations=(lambda x: x), callbacks=cb)
  415. assert data.get_dataset_size() == 4
  416. assert data.output_shapes() == [[]]
  417. # make sure callback is never called by checking flag and row_cnt
  418. assert not cb.flag and cb.row_cnt == 0
  419. for _ in data.create_dict_iterator(num_epochs=1):
  420. pass
  421. # this ensure that callback is indeed called
  422. assert cb.flag and cb.row_cnt == 4
  423. if __name__ == '__main__':
  424. test_callbacks_all_2cbs()
  425. test_callbacks_all_methods()
  426. test_callbacks_exceptions()
  427. test_callbacks_repeat()
  428. test_callbacks_sink_simulation()
  429. test_callbacks_validations()
  430. test_callbacks_var_step_size()
  431. test_callbacks_non_sink_batch_size2()
  432. test_callbacks_non_sink()
  433. test_callbacks_one_cb()
  434. test_callbacks_non_sink_mismatch_size()
  435. test_callbacks_train_end()
  436. test_clear_callback()