You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_callback.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """test callback function."""
  16. import os
  17. import stat
  18. import numpy as np
  19. import pytest
  20. import mindspore.nn as nn
  21. import mindspore.common.dtype as mstype
  22. from mindspore import context
  23. from mindspore.common.tensor import Tensor
  24. from mindspore.nn.optim import Momentum
  25. from mindspore.nn import TrainOneStepCell, WithLossCell
  26. from mindspore.train.callback import ModelCheckpoint, _check_file_name_prefix, RunContext,_checkpoint_cb_for_save_op,\
  27. LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist,\
  28. _build_callbacks, CheckpointConfig, _set_cur_net
  29. from mindspore.common.api import ms_function
  class Net(nn.Cell):
      """Convolutional network used as the model under test.

      The forward pass applies conv -> batch norm -> ReLU -> flatten -> dense.
      """

      def __init__(self):
          super().__init__()
          # 3 input channels -> 64 output channels, 3x3 kernel, no bias.
          self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
          self.bn = nn.BatchNorm2d(64)
          self.relu = nn.ReLU()
          self.flatten = nn.Flatten()
          # NOTE(review): the dense input size presumes a 64 x 222 x 222 feature
          # map after flattening -- confirm against the conv padding in use.
          self.fc = nn.Dense(64 * 222 * 222, 3)

      @ms_function
      def construct(self, x):
          # Feature extraction first, then the dense head.
          activated = self.relu(self.bn(self.conv(x)))
          return self.fc(self.flatten(activated))
  class LossNet(nn.Cell):
      """Network that bundles the forward pass with its loss computation.

      ``construct`` takes the input and the labels and returns the result of
      ``SoftmaxCrossEntropyWithLogits`` applied to the dense layer output.
      """

      def __init__(self):
          super().__init__()
          self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal', pad_mode='valid')
          self.bn = nn.BatchNorm2d(64)
          self.relu = nn.ReLU()
          self.flatten = nn.Flatten()
          # Original note on this layer: padding=0.
          # NOTE(review): 64 * 222 * 222 presumably matches a 224 x 224 input
          # convolved with a 3x3 kernel and no padding -- confirm.
          self.fc = nn.Dense(64 * 222 * 222, 3)
          self.loss = nn.SoftmaxCrossEntropyWithLogits()

      @ms_function
      def construct(self, x, y):
          # Compute the logits, then score them against the labels ``y``.
          activated = self.relu(self.bn(self.conv(x)))
          logits = self.fc(self.flatten(activated))
          return self.loss(logits, y)
  def test_Model_Checkpoint_prefix_invalid():
      """Check which ModelCheckpoint constructor arguments are rejected.

      Passing the integer 123 as the first positional argument must raise
      ValueError, and passing a string as ``config`` must raise TypeError.
      The remaining constructions are expected to succeed without raising.
      """
      pytest.raises(ValueError, ModelCheckpoint, 123)
      ModelCheckpoint(directory="./")

      pytest.raises(TypeError, ModelCheckpoint, config='type_error')
      ModelCheckpoint(config=CheckpointConfig())
      ModelCheckpoint(prefix="ckpt_2", directory="./test_files")
  def test_save_checkpoint():
      """Drive ModelCheckpoint through begin() and step_end() on a train network.

      Smoke test: it contains no assertions, so it passes as long as none of
      the calls raises. Any './test_files/test_ckpt-model.pkl' left behind is
      made writable and removed at the end.
      """
      ckpt_config = CheckpointConfig(save_checkpoint_steps=16,
                                     save_checkpoint_seconds=0,
                                     keep_checkpoint_max=5,
                                     keep_checkpoint_per_n_minutes=0)
      params = _InternalCallbackParam()

      # Build the network the callback will be asked to checkpoint.
      model = Net()
      loss_fn = nn.SoftmaxCrossEntropyWithLogits()
      optimizer = Momentum(model.trainable_params(), learning_rate=0.1, momentum=0.9)
      params.train_network = TrainOneStepCell(WithLossCell(model, loss_fn), optimizer)

      # Training progress reported to the callback.
      params.epoch_num = 10
      params.cur_epoch_num = 5
      params.cur_step_num = 0
      params.batch_num = 32

      checkpoint_cb = ModelCheckpoint(prefix="test_ckpt", directory='./test_files', config=ckpt_config)
      ctx = RunContext(params)
      checkpoint_cb.begin(ctx)
      checkpoint_cb.step_end(ctx)

      # Clean up the artefact if the callback produced one.
      pkl_path = './test_files/test_ckpt-model.pkl'
      if os.path.exists(pkl_path):
          os.chmod(pkl_path, stat.S_IWRITE)
          os.remove(pkl_path)
  def test_loss_monitor_graph_model():
      """Run a LossMonitor through every hook via the built callback list.

      Smoke test: the LossMonitor is wrapped with ``_build_callbacks`` and each
      hook is invoked once, in training order; the test passes if none raises.
      """
      params = _InternalCallbackParam()
      params.cur_epoch_num = 4
      params.cur_step_num = 2
      params.batch_num = 2
      params.net_outputs = Tensor(2.0)
      ctx = RunContext(params)

      callbacklist = _build_callbacks([LossMonitor(1)])
      for hook_name in ("begin", "epoch_begin", "step_begin",
                        "step_end", "epoch_end", "end"):
          getattr(callbacklist, hook_name)(ctx)
  def test_Loss_Monitor_feed_feed_model():
      """Call every LossMonitor hook directly, without a callback list.

      Smoke test: the callback parameters are filled in after the RunContext
      is created, and the test passes if none of the hooks raises.
      """
      params = _InternalCallbackParam()
      ctx = RunContext(params)
      monitor = LossMonitor(1)

      params.cur_epoch_num = 4
      params.cur_step_num = 1
      params.batch_num = 1
      params.net_outputs = Tensor(2.0)

      monitor.begin(ctx)
      monitor.epoch_begin(ctx)
      monitor.step_begin(ctx)
      monitor.step_end(ctx)
      monitor.epoch_end(ctx)
      monitor.end(ctx)
  def test_check_file_name_not_str():
      """A non-string prefix (the integer 1) must yield a falsy check result."""
      assert not _check_file_name_prefix(1)
  def test_check_file_name_back_err():
      """The prefix 'abc.' (trailing dot) must yield a truthy check result."""
      assert _check_file_name_prefix('abc.')
  def test_check_file_name_one_alpha():
      """Single-character prefixes 'a' and '_' must yield a truthy check result."""
      assert _check_file_name_prefix('a')
      assert _check_file_name_prefix('_')
  def test_check_file_name_err():
      """The prefix '_123' must yield a truthy check result."""
      assert _check_file_name_prefix('_123')
  def test_chg_ckpt_file_name_if_same_exist():
      """Smoke-test the checkpoint renaming helper.

      Calls ``_chg_ckpt_file_name_if_same_exist`` for prefix "ckpt" in
      "./test_files"; the result is not inspected, so the test only verifies
      that the call does not raise.
      """
      _chg_ckpt_file_name_if_same_exist(prefix="ckpt", directory="./test_files")
  def test_checkpoint_cb_for_save_op():
      """Smoke-test ``_checkpoint_cb_for_save_op`` with one parameter entry.

      The entry is named "conv1.weight" and carries a random float32 tensor of
      shape [1, 3, 224, 224]. Nothing is asserted; the call must simply not
      raise.
      """
      random_values = np.random.randint(0, 255, [1, 3, 224, 224])
      parameter_list = [
          {
              'name': "conv1.weight",
              'data': Tensor(random_values, dtype=mstype.float32),
          },
      ]
      _checkpoint_cb_for_save_op(parameter_list)
  def test_checkpoint_cb_for_save_op_update_net():
      """Check that the save-op callback writes parameter data into the current net.

      A "conv.weight" entry filled with ones is passed to
      ``_checkpoint_cb_for_save_op`` after a ``Net`` has been registered through
      ``_set_cur_net``; afterwards the first element of the net's conv weight
      must equal 1.
      """
      parameter_list = [
          {
              'name': "conv.weight",
              'data': Tensor(np.ones(shape=(64, 3, 3, 3)), dtype=mstype.float32),
          },
      ]

      model = Net()
      _set_cur_net(model)
      _checkpoint_cb_for_save_op(parameter_list)

      conv_weight = model.conv.weight.default_input.asnumpy()
      assert conv_weight[0][0][0][0] == 1
  def test_internal_callback_param():
      """Attributes assigned on ``_InternalCallbackParam`` must read back unchanged."""
      params = _InternalCallbackParam()
      params.member1, params.member2 = 1, "abc"

      assert params.member1 == 1
      assert params.member2 == "abc"
  def test_checkpoint_save_ckpt_steps():
      """Exercise step-based checkpointing with ``save_checkpoint_steps=16``.

      A first ModelCheckpoint sees ``step_end`` at step 160 of epoch 5; a second
      one, built from the same config and sharing the run context, sees
      ``step_end`` at step 15 of epoch 1. Smoke test: nothing is asserted.
      """
      ckpt_config = CheckpointConfig(save_checkpoint_steps=16,
                                     save_checkpoint_seconds=0,
                                     keep_checkpoint_max=5,
                                     keep_checkpoint_per_n_minutes=0)
      first_cb = ModelCheckpoint(config=ckpt_config)
      params = _InternalCallbackParam()

      # Network handed to the callbacks.
      model = Net()
      loss_fn = nn.SoftmaxCrossEntropyWithLogits()
      optimizer = Momentum(model.trainable_params(), learning_rate=0.1, momentum=0.9)
      params.train_network = TrainOneStepCell(WithLossCell(model, loss_fn), optimizer)

      params.epoch_num = 10
      params.cur_epoch_num = 5
      params.cur_step_num = 160
      params.batch_num = 32
      ctx = RunContext(params)
      first_cb.begin(ctx)
      first_cb.step_end(ctx)

      # Second callback, driven at a different epoch/step position.
      second_cb = ModelCheckpoint(config=ckpt_config)
      params.cur_epoch_num = 1
      params.cur_step_num = 15
      second_cb.begin(ctx)
      second_cb.step_end(ctx)
  def test_checkpoint_save_ckpt_seconds():
      """Exercise checkpointing with ``save_checkpoint_seconds=100`` configured.

      The config also sets ``keep_checkpoint_max=0`` and
      ``keep_checkpoint_per_n_minutes=1``. A first ModelCheckpoint sees
      ``step_end`` at step 128 of epoch 4; a second one, built from the same
      config and sharing the run context, sees it at step 16 of epoch 1.
      Smoke test: nothing is asserted.
      """
      ckpt_config = CheckpointConfig(save_checkpoint_steps=16,
                                     save_checkpoint_seconds=100,
                                     keep_checkpoint_max=0,
                                     keep_checkpoint_per_n_minutes=1)
      first_cb = ModelCheckpoint(config=ckpt_config)
      params = _InternalCallbackParam()

      # Network handed to the callbacks.
      model = Net()
      loss_fn = nn.SoftmaxCrossEntropyWithLogits()
      optimizer = Momentum(model.trainable_params(), learning_rate=0.1, momentum=0.9)
      params.train_network = TrainOneStepCell(WithLossCell(model, loss_fn), optimizer)

      params.epoch_num = 10
      params.cur_epoch_num = 4
      params.cur_step_num = 128
      params.batch_num = 32
      ctx = RunContext(params)
      first_cb.begin(ctx)
      first_cb.step_end(ctx)

      # Second callback, driven at a different epoch/step position.
      second_cb = ModelCheckpoint(config=ckpt_config)
      params.cur_epoch_num = 1
      params.cur_step_num = 16
      second_cb.begin(ctx)
      second_cb.step_end(ctx)
  def test_build_callbacks():
      """Check that ``_build_callbacks`` rejects lists containing non-callbacks.

      Each of the following inputs must raise TypeError: a list holding only
      ``None``, a list holding only a string, and a list mixing valid callbacks
      with a string and ``None``.
      """
      ck_obj = ModelCheckpoint()
      loss_cb_1 = LossMonitor(1)

      invalid_inputs = (
          [None],
          ['Error'],
          [ck_obj, loss_cb_1, 'Error', None],
      )
      for invalid in invalid_inputs:
          # The call is expected to raise, so its result is deliberately not bound.
          with pytest.raises(TypeError):
              _build_callbacks(invalid)
  def test_RunContext():
      """Check RunContext construction and its stop-request flag.

      Constructing a RunContext from the integer 666 must raise TypeError.
      For a context built from an ``_InternalCallbackParam``, the parameters
      must keep their values after ``original_args()`` is called, and
      ``get_stop_requested()`` must be truthy once ``request_stop()`` has been
      called.
      """
      context_err = 666
      # The result is not bound: the call must raise, and a local named
      # ``context`` would shadow the module-level ``mindspore.context`` import.
      with pytest.raises(TypeError):
          RunContext(context_err)

      cb_params = _InternalCallbackParam()
      cb_params.member1 = 1
      cb_params.member2 = "abc"

      run_context = RunContext(cb_params)
      run_context.original_args()
      assert cb_params.member1 == 1
      assert cb_params.member2 == "abc"

      run_context.request_stop()
      assert run_context.get_stop_requested()
  def test_Checkpoint_Config():
      """CheckpointConfig must raise ValueError when every argument is 0 or None."""
      for bad_args in ((0, 0, 0, 0), (0, None, 0, 0)):
          with pytest.raises(ValueError):
              CheckpointConfig(*bad_args)
  def test_step_end_save_graph():
      """Check that ModelCheckpoint writes the graph file on the first step_end only.

      After the first ``step_end`` the file './test_files/test-graph.meta' must
      exist. The file is then removed, and a second ``step_end`` on the same
      callback must not create it again.
      """
      graph_file = './test_files/test-graph.meta'
      train_config = CheckpointConfig(
          save_checkpoint_steps=16,
          save_checkpoint_seconds=0,
          keep_checkpoint_max=5,
          keep_checkpoint_per_n_minutes=0)
      cb_params = _InternalCallbackParam()

      # Run the network once on random data before handing it to the callback.
      net = LossNet()
      input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
      input_label = Tensor(np.random.randint(0, 3, [1, 3]).astype(np.float32))
      net(input_data, input_label)

      cb_params.train_network = net
      cb_params.epoch_num = 10
      cb_params.cur_epoch_num = 5
      cb_params.cur_step_num = 0
      cb_params.batch_num = 32
      ckpoint_cb = ModelCheckpoint(prefix="test", directory='./test_files', config=train_config)
      run_context = RunContext(cb_params)
      ckpoint_cb.begin(run_context)

      ckpoint_cb.step_end(run_context)
      assert os.path.exists(graph_file)

      # The assert above guarantees the file exists; make it writable before
      # deleting so the removal cannot fail on a read-only file.
      os.chmod(graph_file, stat.S_IWRITE)
      os.remove(graph_file)

      ckpoint_cb.step_end(run_context)
      assert not os.path.exists(graph_file)