You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_callback.py 18 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
6 years ago
6 years ago
4 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """test callback function."""
  16. import os
  17. import platform
  18. import stat
  19. import secrets
  20. from unittest import mock
  21. import numpy as np
  22. import pytest
  23. import mindspore.common.dtype as mstype
  24. import mindspore.nn as nn
  25. from mindspore.common.api import ms_function
  26. from mindspore.common.tensor import Tensor
  27. from mindspore.nn import TrainOneStepCell, WithLossCell
  28. from mindspore.nn.optim import Momentum
  29. from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \
  30. _CallbackManager, Callback, CheckpointConfig, _set_cur_net, _checkpoint_cb_for_save_op, History, LambdaCallback
  31. from mindspore.train.callback._checkpoint import _chg_ckpt_file_name_if_same_exist
  class Net(nn.Cell):
      """Conv2d -> BatchNorm2d -> ReLU -> Flatten -> Dense network used as a fixture by the tests."""

      def __init__(self):
          super().__init__()
          # Layers are created in the same order as they are applied in construct().
          self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
          self.bn = nn.BatchNorm2d(64)
          self.relu = nn.ReLU()
          self.flatten = nn.Flatten()
          self.fc = nn.Dense(64 * 222 * 222, 3)

      @ms_function
      def construct(self, x):
          """Apply the layers in sequence to `x` and return the dense layer's output."""
          features = self.flatten(self.relu(self.bn(self.conv(x))))
          return self.fc(features)
  class LossNet(nn.Cell):
      """Same layer stack as a plain conv classifier, with a softmax cross-entropy loss appended."""

      def __init__(self):
          super().__init__()
          self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal', pad_mode='valid')
          self.bn = nn.BatchNorm2d(64)
          self.relu = nn.ReLU()
          self.flatten = nn.Flatten()
          # 'valid' padding (padding=0): presumably a 224x224 input shrinks to 222x222 after the 3x3 conv.
          self.fc = nn.Dense(64 * 222 * 222, 3)
          self.loss = nn.SoftmaxCrossEntropyWithLogits()

      @ms_function
      def construct(self, x, y):
          """Run `x` through the layers and return the loss of the result against label `y`."""
          logits = self.fc(self.flatten(self.relu(self.bn(self.conv(x)))))
          return self.loss(logits, y)
  def test_model_checkpoint_prefix_invalid():
      """
      Feature: callback
      Description: Build ModelCheckpoint with an integer prefix and a string config, then with
          several accepted argument combinations.
      Expectation: ValueError for the integer prefix, TypeError for the string config, no error otherwise.
      """
      # An integer is not an acceptable prefix.
      with pytest.raises(ValueError):
          ModelCheckpoint(123)

      ModelCheckpoint(directory="./")

      # The config argument must not be a plain string.
      with pytest.raises(TypeError):
          ModelCheckpoint(config='type_error')

      default_config = CheckpointConfig()
      ModelCheckpoint(config=default_config)
      ModelCheckpoint(prefix="ckpt_2", directory="./test_files")
  def test_loss_monitor_sink_mode():
      """
      Feature: callback
      Description: Drive a LossMonitor through all hooks via _CallbackManager with a two-step batch.
      Expectation: Every hook runs without raising.
      """
      params = _InternalCallbackParam()
      params.cur_epoch_num = 4
      params.epoch_num = 4
      params.cur_step_num = 2
      params.batch_num = 2
      params.net_outputs = Tensor(2.0)
      run_context = RunContext(params)

      # Hooks are invoked in the order a training loop would call them.
      hooks = ("begin", "epoch_begin", "step_begin", "step_end", "epoch_end", "end")
      with _CallbackManager([LossMonitor(1)]) as manager:
          for hook in hooks:
              getattr(manager, hook)(run_context)
  def test_loss_monitor_normal_mode():
      """
      Feature: callback
      Description: Call each LossMonitor hook directly (no _CallbackManager) for a one-step batch.
      Expectation: Every hook runs without raising.
      """
      params = _InternalCallbackParam()
      run_context = RunContext(params)
      monitor = LossMonitor(1)

      # The parameters are filled in after the object was handed to RunContext.
      params.cur_epoch_num = 4
      params.epoch_num = 4
      params.cur_step_num = 1
      params.batch_num = 1
      params.net_outputs = Tensor(2.0)

      for hook in ("begin", "epoch_begin", "step_begin", "step_end", "epoch_end", "end"):
          getattr(monitor, hook)(run_context)
  def test_loss_monitor_args():
      """
      Feature: callback
      Description: Construct LossMonitor with negative per_print_times and negative has_trained_epoch.
      Expectation: Each construction raises ValueError.
      """
      illegal_kwargs = (
          {"per_print_times": -1},
          {"has_trained_epoch": -100},
      )
      for kwargs in illegal_kwargs:
          with pytest.raises(ValueError):
              LossMonitor(**kwargs)
  def test_loss_monitor_has_trained_epoch():
      """
      Feature: callback
      Description: Run all LossMonitor hooks on a monitor created with has_trained_epoch=10.
      Expectation: Every hook runs without raising.
      """
      params = _InternalCallbackParam()
      run_context = RunContext(params)
      monitor = LossMonitor(has_trained_epoch=10)

      # The parameters are filled in after the object was handed to RunContext.
      params.cur_epoch_num = 4
      params.cur_step_num = 1
      params.batch_num = 1
      params.net_outputs = Tensor(2.0)
      params.epoch_num = 4

      for hook in ("begin", "epoch_begin", "step_begin", "step_end", "epoch_end", "end"):
          getattr(monitor, hook)(run_context)
  def test_save_ckpt_and_test_chg_ckpt_file_name_if_same_exist():
      """
      Feature: Save checkpoint and check if there is a file with the same name.
      Description: Run ModelCheckpoint begin/step_end on a training network, remove a leftover model
          file if one exists, then call _chg_ckpt_file_name_if_same_exist on the output directory.
      Expectation: All calls complete without raising.
      """
      train_config = CheckpointConfig(save_checkpoint_steps=16,
                                      save_checkpoint_seconds=0,
                                      keep_checkpoint_max=5,
                                      keep_checkpoint_per_n_minutes=0)
      params = _InternalCallbackParam()

      # Assemble net -> loss wrapper -> one-step training cell.
      net = Net()
      loss_fn = nn.SoftmaxCrossEntropyWithLogits()
      optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
      params.train_network = TrainOneStepCell(WithLossCell(net, loss_fn), optimizer)

      params.epoch_num = 10
      params.cur_epoch_num = 5
      params.cur_step_num = 0
      params.batch_num = 32

      checkpoint_cb = ModelCheckpoint(prefix="test_ckpt", directory='./test_files', config=train_config)
      run_context = RunContext(params)
      checkpoint_cb.begin(run_context)
      checkpoint_cb.step_end(run_context)

      # Clean up the model file if it exists; it is made writable first so removal can succeed.
      model_file = './test_files/test_ckpt-model.pkl'
      if os.path.exists(model_file):
          os.chmod(model_file, stat.S_IWRITE)
          os.remove(model_file)

      _chg_ckpt_file_name_if_same_exist(directory="./test_files", prefix="ckpt")
  def test_checkpoint_cb_for_save_op():
      """
      Feature: callback
      Description: Pass a single-entry parameter list to _checkpoint_cb_for_save_op.
      Expectation: The call completes without raising.
      """
      weight = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)
      params_to_save = [{'name': "conv1.weight", 'data': weight}]
      _checkpoint_cb_for_save_op(params_to_save)
  def test_checkpoint_cb_for_save_op_update_net():
      """
      Feature: callback
      Description: Register a Net via _set_cur_net, then pass an all-ones "conv.weight" entry to
          _checkpoint_cb_for_save_op.
      Expectation: The first element of the net's conv weight equals 1 afterwards.
      """
      ones_weight = Tensor(np.ones(shape=(64, 3, 3, 3)), dtype=mstype.float32)
      params_to_save = [{'name': "conv.weight", 'data': ones_weight}]

      net = Net()
      _set_cur_net(net)
      _checkpoint_cb_for_save_op(params_to_save)

      conv_weight = net.conv.weight.data.asnumpy()
      assert conv_weight[0][0][0][0] == 1
  def test_internal_callback_param():
      """
      Feature: callback
      Description: Assign two arbitrary attributes on _InternalCallbackParam and read them back.
      Expectation: The values read back equal the values assigned.
      """
      params = _InternalCallbackParam()
      params.member1 = 1
      params.member2 = "abc"

      stored = (params.member1, params.member2)
      assert stored[0] == 1
      assert stored[1] == "abc"
  def test_checkpoint_save_ckpt_steps():
      """
      Feature: callback
      Description: Run begin/step_end on two ModelCheckpoint callbacks configured with
          save_checkpoint_steps=16, first at step 160 and then at step 15.
      Expectation: All calls complete without raising.
      """
      train_config = CheckpointConfig(save_checkpoint_steps=16,
                                      save_checkpoint_seconds=0,
                                      keep_checkpoint_max=5,
                                      keep_checkpoint_per_n_minutes=0)
      first_cb = ModelCheckpoint(config=train_config)
      params = _InternalCallbackParam()

      # Assemble net -> loss wrapper -> one-step training cell.
      net = Net()
      loss_fn = nn.SoftmaxCrossEntropyWithLogits()
      optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
      params.train_network = TrainOneStepCell(WithLossCell(net, loss_fn), optimizer)

      params.epoch_num = 10
      params.cur_epoch_num = 5
      params.cur_step_num = 160
      params.batch_num = 32
      run_context = RunContext(params)
      first_cb.begin(run_context)
      first_cb.step_end(run_context)

      # A second callback reuses the same run context after the step counters are changed.
      second_cb = ModelCheckpoint(config=train_config)
      params.cur_epoch_num = 1
      params.cur_step_num = 15
      second_cb.begin(run_context)
      second_cb.step_end(run_context)
  def test_checkpoint_save_ckpt_seconds():
      """
      Feature: callback
      Description: Run begin/step_end on two ModelCheckpoint callbacks configured with
          save_checkpoint_seconds=100, first at step 128 and then at step 16.
      Expectation: All calls complete without raising.
      """
      train_config = CheckpointConfig(save_checkpoint_steps=16,
                                      save_checkpoint_seconds=100,
                                      keep_checkpoint_max=0,
                                      keep_checkpoint_per_n_minutes=1)
      first_cb = ModelCheckpoint(config=train_config)
      params = _InternalCallbackParam()

      # Assemble net -> loss wrapper -> one-step training cell.
      net = Net()
      loss_fn = nn.SoftmaxCrossEntropyWithLogits()
      optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
      params.train_network = TrainOneStepCell(WithLossCell(net, loss_fn), optimizer)

      params.epoch_num = 10
      params.cur_epoch_num = 4
      params.cur_step_num = 128
      params.batch_num = 32
      run_context = RunContext(params)
      first_cb.begin(run_context)
      first_cb.step_end(run_context)

      # A second callback reuses the same run context after the step counters are changed.
      second_cb = ModelCheckpoint(config=train_config)
      params.cur_epoch_num = 1
      params.cur_step_num = 16
      second_cb.begin(run_context)
      second_cb.step_end(run_context)
  def test_checkpoint_save_ckpt_with_encryption():
      """
      Feature: callback
      Description: Run ModelCheckpoint callbacks whose config carries a random 16-byte enc_key and
          enc_mode "AES-GCM", first at step 160 and then at step 15.
      Expectation: On Windows the second callback raises NotImplementedError; elsewhere it runs cleanly.
      """
      train_config = CheckpointConfig(save_checkpoint_steps=16,
                                      save_checkpoint_seconds=0,
                                      keep_checkpoint_max=5,
                                      keep_checkpoint_per_n_minutes=0,
                                      enc_key=secrets.token_bytes(16),
                                      enc_mode="AES-GCM")
      first_cb = ModelCheckpoint(config=train_config)
      params = _InternalCallbackParam()

      # Assemble net -> loss wrapper -> one-step training cell.
      net = Net()
      loss_fn = nn.SoftmaxCrossEntropyWithLogits()
      optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
      params.train_network = TrainOneStepCell(WithLossCell(net, loss_fn), optimizer)

      params.epoch_num = 10
      params.cur_epoch_num = 5
      params.cur_step_num = 160
      params.batch_num = 32
      run_context = RunContext(params)
      first_cb.begin(run_context)
      first_cb.step_end(run_context)

      second_cb = ModelCheckpoint(config=train_config)
      params.cur_epoch_num = 1
      params.cur_step_num = 15

      if platform.system().lower() != "windows":
          second_cb.begin(run_context)
          second_cb.step_end(run_context)
          return

      # NOTE(review): both calls share one pytest.raises block, so step_end is skipped
      # if begin is the call that raises.
      with pytest.raises(NotImplementedError):
          second_cb.begin(run_context)
          second_cb.step_end(run_context)
  def test_CallbackManager():
      """
      Feature: callback
      Description: Build _CallbackManager from lists that contain non-callback entries
          (None, a string, or a mix of valid callbacks with both).
      Expectation: Each construction raises TypeError.
      """
      ck_obj = ModelCheckpoint()
      loss_cb_1 = LossMonitor(1)

      invalid_lists = (
          [None],
          ['Error'],
          [ck_obj, loss_cb_1, 'Error', None],
      )
      for callbacks in invalid_lists:
          with pytest.raises(TypeError):
              _CallbackManager(callbacks)
  def test_CallbackManager_exit_called():
      """
      Feature: callback
      Description: Patch Callback.__exit__ and leave a _CallbackManager context normally.
      Expectation: __exit__ is called once per callback with (None, None, None) as exception info.
      """
      with mock.patch.object(Callback, '__exit__', return_value=None) as mock_exit:
          managed = [Callback(), Callback()]
          with _CallbackManager(managed):
              pass

      # No exception occurred, so the three exception arguments must all be None.
      expected_call = mock.call(mock.ANY, None, None, None)
      for recorded in mock_exit.call_args_list:
          assert recorded == expected_call
      assert mock_exit.call_count == 2
  def test_CallbackManager_exit_called_when_raises():
      """
      Feature: callback
      Description: Patch Callback.__exit__ and raise ValueError inside a _CallbackManager context.
      Expectation: The ValueError propagates and __exit__ is called once per callback with four arguments.
      """
      with mock.patch.object(Callback, '__exit__', return_value=None) as mock_exit:
          managed = [Callback(), Callback()]
          with pytest.raises(ValueError):
              with _CallbackManager(managed):
                  raise ValueError()

      # Only the argument count is checked here; each of the four values may be anything.
      expected_call = mock.call(mock.ANY, mock.ANY, mock.ANY, mock.ANY)
      for recorded in mock_exit.call_args_list:
          assert recorded == expected_call
      assert mock_exit.call_count == 2
  def test_CallbackManager_begin_called():
      """
      Feature: callback
      Description: Patch Callback.begin and call begin on a _CallbackManager holding two callbacks.
      Expectation: begin is called once per callback with the context that was passed to the manager.
      """
      context = {}
      with mock.patch.object(Callback, 'begin', return_value=None) as mock_begin:
          managed = [Callback(), Callback()]
          with _CallbackManager(managed) as manager:
              manager.begin(context)

      expected_call = mock.call(context)
      for recorded in mock_begin.call_args_list:
          assert recorded == expected_call
      assert mock_begin.call_count == 2
  def test_RunContext():
      """
      Feature: callback
      Description: Build RunContext from an integer, then from _InternalCallbackParam, and request a stop.
      Expectation: The integer raises TypeError; the stored members are unchanged and the stop flag is set.
      """
      invalid_args = 666
      with pytest.raises(TypeError):
          RunContext(invalid_args)

      params = _InternalCallbackParam()
      params.member1 = 1
      params.member2 = "abc"

      run_context = RunContext(params)
      run_context.original_args()
      assert params.member1 == 1
      assert params.member2 == "abc"

      run_context.request_stop()
      assert run_context.get_stop_requested()
  def test_Checkpoint_Config():
      """
      Feature: callback
      Description: Construct CheckpointConfig with two rejected positional argument combinations.
      Expectation: Each construction raises ValueError.
      """
      rejected_args = (
          (0, 0, 0, 0, True),
          (0, None, 0, 0, True),
      )
      for args in rejected_args:
          with pytest.raises(ValueError):
              CheckpointConfig(*args)
  def test_step_end_save_graph():
      """
      Feature: callback
      Description: Run ModelCheckpoint step_end on an executed LossNet, delete the graph meta file,
          and run step_end a second time.
      Expectation: The graph meta file exists after the first step_end and is not recreated by the second.
      """
      train_config = CheckpointConfig(save_checkpoint_steps=16,
                                      save_checkpoint_seconds=0,
                                      keep_checkpoint_max=5,
                                      keep_checkpoint_per_n_minutes=0)
      params = _InternalCallbackParam()

      # Execute the network once on random data before it is handed to the callback.
      net = LossNet()
      images = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
      labels = Tensor(np.random.randint(0, 3, [1, 3]).astype(np.float32))
      net(images, labels)

      params.train_network = net
      params.epoch_num = 10
      params.cur_epoch_num = 5
      params.cur_step_num = 0
      params.batch_num = 32

      checkpoint_cb = ModelCheckpoint(prefix="test", directory='./test_files', config=train_config)
      run_context = RunContext(params)
      checkpoint_cb.begin(run_context)
      checkpoint_cb.step_end(run_context)

      graph_meta = './test_files/test-graph.meta'
      assert os.path.exists(graph_meta)
      if os.path.exists(graph_meta):
          # Make the file writable before removing it.
          os.chmod(graph_meta, stat.S_IWRITE)
          os.remove(graph_meta)

      # A second step_end must not write the graph file again.
      checkpoint_cb.step_end(run_context)
      assert not os.path.exists(graph_meta)
  def test_history():
      """
      Feature: callback.
      Description: Drive a History callback through all hooks via _CallbackManager, then print its
          epoch and history properties.
      Expectation: Every hook runs without raising and both properties can be read.
      """
      params = _InternalCallbackParam()
      params.cur_epoch_num = 4
      params.epoch_num = 4
      params.cur_step_num = 2
      params.batch_num = 2
      params.net_outputs = Tensor(2.0)
      params.metrics = {'mae': 6.343789100646973, 'mse': 59.03999710083008}
      run_context = RunContext(params)

      history_cb = History()
      with _CallbackManager([history_cb]) as manager:
          for hook in ("begin", "epoch_begin", "step_begin", "step_end", "epoch_end", "end"):
              getattr(manager, hook)(run_context)

      print(history_cb.epoch)
      print(history_cb.history)
  def test_lambda():
      """
      Feature: callback.
      Description: Drive a LambdaCallback, whose epoch_end hook prints the network output, through
          all hooks via _CallbackManager.
      Expectation: Every hook runs without raising.
      """
      def _print_loss(run_context):
          # Reads net_outputs from the arguments stored on the run context.
          print("loss result: ", run_context.original_args().net_outputs)

      params = _InternalCallbackParam()
      params.cur_epoch_num = 4
      params.epoch_num = 4
      params.cur_step_num = 2
      params.batch_num = 2
      params.net_outputs = Tensor(2.0)
      run_context = RunContext(params)

      lambda_cb = LambdaCallback(epoch_end=_print_loss)
      with _CallbackManager([lambda_cb]) as manager:
          for hook in ("begin", "epoch_begin", "step_begin", "step_end", "epoch_end", "end"):
              getattr(manager, hook)(run_context)