You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_callback.py 13 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """test callback function."""
  16. import os
  17. import platform
  18. import stat
  19. import secrets
  20. from unittest import mock
  21. import numpy as np
  22. import pytest
  23. import mindspore.common.dtype as mstype
  24. import mindspore.nn as nn
  25. from mindspore.common.api import ms_function
  26. from mindspore.common.tensor import Tensor
  27. from mindspore.nn import TrainOneStepCell, WithLossCell
  28. from mindspore.nn.optim import Momentum
  29. from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \
  30. _CallbackManager, Callback, CheckpointConfig, _set_cur_net, _checkpoint_cb_for_save_op
  31. from mindspore.train.callback._checkpoint import _chg_ckpt_file_name_if_same_exist
  32. class Net(nn.Cell):
  33. """Net definition."""
  34. def __init__(self):
  35. super(Net, self).__init__()
  36. self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
  37. self.bn = nn.BatchNorm2d(64)
  38. self.relu = nn.ReLU()
  39. self.flatten = nn.Flatten()
  40. self.fc = nn.Dense(64 * 222 * 222, 3)
  41. @ms_function
  42. def construct(self, x):
  43. x = self.conv(x)
  44. x = self.bn(x)
  45. x = self.relu(x)
  46. x = self.flatten(x)
  47. out = self.fc(x)
  48. return out
  49. class LossNet(nn.Cell):
  50. """ LossNet definition """
  51. def __init__(self):
  52. super(LossNet, self).__init__()
  53. self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal', pad_mode='valid')
  54. self.bn = nn.BatchNorm2d(64)
  55. self.relu = nn.ReLU()
  56. self.flatten = nn.Flatten()
  57. self.fc = nn.Dense(64 * 222 * 222, 3) # padding=0
  58. self.loss = nn.SoftmaxCrossEntropyWithLogits()
  59. @ms_function
  60. def construct(self, x, y):
  61. x = self.conv(x)
  62. x = self.bn(x)
  63. x = self.relu(x)
  64. x = self.flatten(x)
  65. x = self.fc(x)
  66. out = self.loss(x, y)
  67. return out
  68. def test_model_checkpoint_prefix_invalid():
  69. """Test ModelCheckpoint prefix invalid."""
  70. with pytest.raises(ValueError):
  71. ModelCheckpoint(123)
  72. ModelCheckpoint(directory="./")
  73. with pytest.raises(TypeError):
  74. ModelCheckpoint(config='type_error')
  75. ModelCheckpoint(config=CheckpointConfig())
  76. ModelCheckpoint(prefix="ckpt_2", directory="./test_files")
  77. def test_save_checkpoint():
  78. """Test save checkpoint."""
  79. train_config = CheckpointConfig(
  80. save_checkpoint_steps=16,
  81. save_checkpoint_seconds=0,
  82. keep_checkpoint_max=5,
  83. keep_checkpoint_per_n_minutes=0)
  84. cb_params = _InternalCallbackParam()
  85. net = Net()
  86. loss = nn.SoftmaxCrossEntropyWithLogits()
  87. optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  88. network_ = WithLossCell(net, loss)
  89. _train_network = TrainOneStepCell(network_, optim)
  90. cb_params.train_network = _train_network
  91. cb_params.epoch_num = 10
  92. cb_params.cur_epoch_num = 5
  93. cb_params.cur_step_num = 0
  94. cb_params.batch_num = 32
  95. ckpoint_cb = ModelCheckpoint(prefix="test_ckpt", directory='./test_files', config=train_config)
  96. run_context = RunContext(cb_params)
  97. ckpoint_cb.begin(run_context)
  98. ckpoint_cb.step_end(run_context)
  99. if os.path.exists('./test_files/test_ckpt-model.pkl'):
  100. os.chmod('./test_files/test_ckpt-model.pkl', stat.S_IWRITE)
  101. os.remove('./test_files/test_ckpt-model.pkl')
  102. def test_loss_monitor_sink_mode():
  103. """Test loss monitor sink mode."""
  104. cb_params = _InternalCallbackParam()
  105. cb_params.cur_epoch_num = 4
  106. cb_params.epoch_num = 4
  107. cb_params.cur_step_num = 2
  108. cb_params.batch_num = 2
  109. cb_params.net_outputs = Tensor(2.0)
  110. run_context = RunContext(cb_params)
  111. loss_cb = LossMonitor(1)
  112. callbacks = [loss_cb]
  113. with _CallbackManager(callbacks) as callbacklist:
  114. callbacklist.begin(run_context)
  115. callbacklist.epoch_begin(run_context)
  116. callbacklist.step_begin(run_context)
  117. callbacklist.step_end(run_context)
  118. callbacklist.epoch_end(run_context)
  119. callbacklist.end(run_context)
  120. def test_loss_monitor_normal_mode():
  121. """Test loss monitor normal(non-sink) mode."""
  122. cb_params = _InternalCallbackParam()
  123. run_context = RunContext(cb_params)
  124. loss_cb = LossMonitor(1)
  125. cb_params.cur_epoch_num = 4
  126. cb_params.epoch_num = 4
  127. cb_params.cur_step_num = 1
  128. cb_params.batch_num = 1
  129. cb_params.net_outputs = Tensor(2.0)
  130. loss_cb.begin(run_context)
  131. loss_cb.epoch_begin(run_context)
  132. loss_cb.step_begin(run_context)
  133. loss_cb.step_end(run_context)
  134. loss_cb.epoch_end(run_context)
  135. loss_cb.end(run_context)
  136. def test_chg_ckpt_file_name_if_same_exist():
  137. """Test chg ckpt file name if same exist."""
  138. _chg_ckpt_file_name_if_same_exist(directory="./test_files", prefix="ckpt")
  139. def test_checkpoint_cb_for_save_op():
  140. """Test checkpoint cb for save op."""
  141. parameter_list = []
  142. one_param = {}
  143. one_param['name'] = "conv1.weight"
  144. one_param['data'] = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)
  145. parameter_list.append(one_param)
  146. _checkpoint_cb_for_save_op(parameter_list)
  147. def test_checkpoint_cb_for_save_op_update_net():
  148. """Test checkpoint cb for save op."""
  149. parameter_list = []
  150. one_param = {}
  151. one_param['name'] = "conv.weight"
  152. one_param['data'] = Tensor(np.ones(shape=(64, 3, 3, 3)), dtype=mstype.float32)
  153. parameter_list.append(one_param)
  154. net = Net()
  155. _set_cur_net(net)
  156. _checkpoint_cb_for_save_op(parameter_list)
  157. assert net.conv.weight.data.asnumpy()[0][0][0][0] == 1
  158. def test_internal_callback_param():
  159. """Test Internal CallbackParam."""
  160. cb_params = _InternalCallbackParam()
  161. cb_params.member1 = 1
  162. cb_params.member2 = "abc"
  163. assert cb_params.member1 == 1
  164. assert cb_params.member2 == "abc"
  165. def test_checkpoint_save_ckpt_steps():
  166. """Test checkpoint save ckpt steps."""
  167. train_config = CheckpointConfig(
  168. save_checkpoint_steps=16,
  169. save_checkpoint_seconds=0,
  170. keep_checkpoint_max=5,
  171. keep_checkpoint_per_n_minutes=0)
  172. ckpt_cb = ModelCheckpoint(config=train_config)
  173. cb_params = _InternalCallbackParam()
  174. net = Net()
  175. loss = nn.SoftmaxCrossEntropyWithLogits()
  176. optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  177. network_ = WithLossCell(net, loss)
  178. _train_network = TrainOneStepCell(network_, optim)
  179. cb_params.train_network = _train_network
  180. cb_params.epoch_num = 10
  181. cb_params.cur_epoch_num = 5
  182. cb_params.cur_step_num = 160
  183. cb_params.batch_num = 32
  184. run_context = RunContext(cb_params)
  185. ckpt_cb.begin(run_context)
  186. ckpt_cb.step_end(run_context)
  187. ckpt_cb2 = ModelCheckpoint(config=train_config)
  188. cb_params.cur_epoch_num = 1
  189. cb_params.cur_step_num = 15
  190. ckpt_cb2.begin(run_context)
  191. ckpt_cb2.step_end(run_context)
  192. def test_checkpoint_save_ckpt_seconds():
  193. """Test checkpoint save ckpt seconds."""
  194. train_config = CheckpointConfig(
  195. save_checkpoint_steps=16,
  196. save_checkpoint_seconds=100,
  197. keep_checkpoint_max=0,
  198. keep_checkpoint_per_n_minutes=1)
  199. ckpt_cb = ModelCheckpoint(config=train_config)
  200. cb_params = _InternalCallbackParam()
  201. net = Net()
  202. loss = nn.SoftmaxCrossEntropyWithLogits()
  203. optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  204. network_ = WithLossCell(net, loss)
  205. _train_network = TrainOneStepCell(network_, optim)
  206. cb_params.train_network = _train_network
  207. cb_params.epoch_num = 10
  208. cb_params.cur_epoch_num = 4
  209. cb_params.cur_step_num = 128
  210. cb_params.batch_num = 32
  211. run_context = RunContext(cb_params)
  212. ckpt_cb.begin(run_context)
  213. ckpt_cb.step_end(run_context)
  214. ckpt_cb2 = ModelCheckpoint(config=train_config)
  215. cb_params.cur_epoch_num = 1
  216. cb_params.cur_step_num = 16
  217. ckpt_cb2.begin(run_context)
  218. ckpt_cb2.step_end(run_context)
  219. def test_checkpoint_save_ckpt_with_encryption():
  220. """Test checkpoint save ckpt with encryption."""
  221. train_config = CheckpointConfig(
  222. save_checkpoint_steps=16,
  223. save_checkpoint_seconds=0,
  224. keep_checkpoint_max=5,
  225. keep_checkpoint_per_n_minutes=0,
  226. enc_key=secrets.token_bytes(16),
  227. enc_mode="AES-GCM")
  228. ckpt_cb = ModelCheckpoint(config=train_config)
  229. cb_params = _InternalCallbackParam()
  230. net = Net()
  231. loss = nn.SoftmaxCrossEntropyWithLogits()
  232. optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  233. network_ = WithLossCell(net, loss)
  234. _train_network = TrainOneStepCell(network_, optim)
  235. cb_params.train_network = _train_network
  236. cb_params.epoch_num = 10
  237. cb_params.cur_epoch_num = 5
  238. cb_params.cur_step_num = 160
  239. cb_params.batch_num = 32
  240. run_context = RunContext(cb_params)
  241. ckpt_cb.begin(run_context)
  242. ckpt_cb.step_end(run_context)
  243. ckpt_cb2 = ModelCheckpoint(config=train_config)
  244. cb_params.cur_epoch_num = 1
  245. cb_params.cur_step_num = 15
  246. if platform.system().lower() == "windows":
  247. with pytest.raises(NotImplementedError):
  248. ckpt_cb2.begin(run_context)
  249. ckpt_cb2.step_end(run_context)
  250. else:
  251. ckpt_cb2.begin(run_context)
  252. ckpt_cb2.step_end(run_context)
  253. def test_CallbackManager():
  254. """TestCallbackManager."""
  255. ck_obj = ModelCheckpoint()
  256. loss_cb_1 = LossMonitor(1)
  257. callbacks = [None]
  258. with pytest.raises(TypeError):
  259. _CallbackManager(callbacks)
  260. callbacks = ['Error']
  261. with pytest.raises(TypeError):
  262. _CallbackManager(callbacks)
  263. callbacks = [ck_obj, loss_cb_1, 'Error', None]
  264. with pytest.raises(TypeError):
  265. _CallbackManager(callbacks)
  266. def test_CallbackManager_exit_called():
  267. with mock.patch.object(Callback, '__exit__', return_value=None) as mock_exit:
  268. cb1, cb2 = Callback(), Callback()
  269. with _CallbackManager([cb1, cb2]):
  270. pass
  271. for call_args in mock_exit.call_args_list:
  272. assert call_args == mock.call(mock.ANY, None, None, None)
  273. assert mock_exit.call_count == 2
  274. def test_CallbackManager_exit_called_when_raises():
  275. with mock.patch.object(Callback, '__exit__', return_value=None) as mock_exit:
  276. cb1, cb2 = Callback(), Callback()
  277. with pytest.raises(ValueError):
  278. with _CallbackManager([cb1, cb2]):
  279. raise ValueError()
  280. for call_args in mock_exit.call_args_list:
  281. assert call_args == mock.call(*[mock.ANY] * 4)
  282. assert mock_exit.call_count == 2
  283. def test_CallbackManager_begin_called():
  284. context = dict()
  285. with mock.patch.object(Callback, 'begin', return_value=None) as mock_begin:
  286. cb1, cb2 = Callback(), Callback()
  287. with _CallbackManager([cb1, cb2]) as cm:
  288. cm.begin(context)
  289. for call_args in mock_begin.call_args_list:
  290. assert call_args == mock.call(context)
  291. assert mock_begin.call_count == 2
  292. def test_RunContext():
  293. """Test RunContext."""
  294. context_err = 666
  295. with pytest.raises(TypeError):
  296. RunContext(context_err)
  297. cb_params = _InternalCallbackParam()
  298. cb_params.member1 = 1
  299. cb_params.member2 = "abc"
  300. run_context = RunContext(cb_params)
  301. run_context.original_args()
  302. assert cb_params.member1 == 1
  303. assert cb_params.member2 == "abc"
  304. run_context.request_stop()
  305. should_stop = run_context.get_stop_requested()
  306. assert should_stop
  307. def test_Checkpoint_Config():
  308. """Test CheckpointConfig all None or 0."""
  309. with pytest.raises(ValueError):
  310. CheckpointConfig(0, 0, 0, 0, True)
  311. with pytest.raises(ValueError):
  312. CheckpointConfig(0, None, 0, 0, True)
  313. def test_step_end_save_graph():
  314. """Test save checkpoint."""
  315. train_config = CheckpointConfig(
  316. save_checkpoint_steps=16,
  317. save_checkpoint_seconds=0,
  318. keep_checkpoint_max=5,
  319. keep_checkpoint_per_n_minutes=0)
  320. cb_params = _InternalCallbackParam()
  321. net = LossNet()
  322. input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
  323. input_label = Tensor(np.random.randint(0, 3, [1, 3]).astype(np.float32))
  324. net(input_data, input_label)
  325. cb_params.train_network = net
  326. cb_params.epoch_num = 10
  327. cb_params.cur_epoch_num = 5
  328. cb_params.cur_step_num = 0
  329. cb_params.batch_num = 32
  330. ckpoint_cb = ModelCheckpoint(prefix="test", directory='./test_files', config=train_config)
  331. run_context = RunContext(cb_params)
  332. ckpoint_cb.begin(run_context)
  333. ckpoint_cb.step_end(run_context)
  334. assert os.path.exists('./test_files/test-graph.meta')
  335. if os.path.exists('./test_files/test-graph.meta'):
  336. os.chmod('./test_files/test-graph.meta', stat.S_IWRITE)
  337. os.remove('./test_files/test-graph.meta')
  338. ckpoint_cb.step_end(run_context)
  339. assert not os.path.exists('./test_files/test-graph.meta')