You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_serialize.py 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ut for model serialize(save/load)"""
  16. import os
  17. import platform
  18. import stat
  19. import time
  20. import numpy as np
  21. import pytest
  22. import mindspore.common.dtype as mstype
  23. import mindspore.nn as nn
  24. from mindspore import context
  25. from mindspore.common.parameter import Parameter
  26. from mindspore.common.tensor import Tensor
  27. from mindspore.nn import SoftmaxCrossEntropyWithLogits
  28. from mindspore.nn import WithLossCell, TrainOneStepCell
  29. from mindspore.nn.optim.momentum import Momentum
  30. from mindspore.ops import operations as P
  31. from mindspore.train.callback import _CheckpointManager
  32. from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \
  33. export, _save_graph
  34. from ..ut_filter import non_graph_engine
  35. context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
  36. class Net(nn.Cell):
  37. """Net definition."""
  38. def __init__(self, num_classes=10):
  39. super(Net, self).__init__()
  40. self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros")
  41. self.bn1 = nn.BatchNorm2d(64)
  42. self.relu = nn.ReLU()
  43. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
  44. self.flatten = nn.Flatten()
  45. self.fc = nn.Dense(int(224 * 224 * 64 / 16), num_classes)
  46. def construct(self, x):
  47. x = self.conv1(x)
  48. x = self.bn1(x)
  49. x = self.relu(x)
  50. x = self.maxpool(x)
  51. x = self.flatten(x)
  52. x = self.fc(x)
  53. return x
  54. _input_x = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
  55. _cur_dir = os.path.dirname(os.path.realpath(__file__))
  56. def setup_module():
  57. import shutil
  58. if os.path.exists('./test_files'):
  59. shutil.rmtree('./test_files')
  60. def test_save_graph():
  61. """ test_exec_save_graph """
  62. class Net1(nn.Cell):
  63. def __init__(self):
  64. super(Net1, self).__init__()
  65. self.add = P.Add()
  66. def construct(self, x, y):
  67. z = self.add(x, y)
  68. return z
  69. net = Net1()
  70. net.set_train()
  71. out_me_list = []
  72. x = Tensor(np.random.rand(2, 1, 2, 3).astype(np.float32))
  73. y = Tensor(np.array([1.2]).astype(np.float32))
  74. out_put = net(x, y)
  75. output_file = "net-graph.meta"
  76. _save_graph(network=net, file_name=output_file)
  77. out_me_list.append(out_put)
  78. assert os.path.exists(output_file)
  79. os.chmod(output_file, stat.S_IWRITE)
  80. os.remove(output_file)
  81. def test_save_checkpoint_for_list():
  82. """ test save_checkpoint for list"""
  83. parameter_list = []
  84. one_param = {}
  85. param1 = {}
  86. param2 = {}
  87. one_param['name'] = "param_test"
  88. one_param['data'] = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)
  89. param1['name'] = "param"
  90. param1['data'] = Tensor(np.random.randint(0, 255, [12, 1024]), dtype=mstype.float32)
  91. param2['name'] = "new_param"
  92. param2['data'] = Tensor(np.random.randint(0, 255, [12, 1024, 1]), dtype=mstype.float32)
  93. parameter_list.append(one_param)
  94. parameter_list.append(param1)
  95. parameter_list.append(param2)
  96. if os.path.exists('./parameters.ckpt'):
  97. os.chmod('./parameters.ckpt', stat.S_IWRITE)
  98. os.remove('./parameters.ckpt')
  99. ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
  100. save_checkpoint(parameter_list, ckpt_file_name)
  101. def test_load_checkpoint_error_filename():
  102. ckpt_file_name = 1
  103. with pytest.raises(ValueError):
  104. load_checkpoint(ckpt_file_name)
  105. def test_load_checkpoint():
  106. ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
  107. par_dict = load_checkpoint(ckpt_file_name)
  108. assert len(par_dict) == 3
  109. assert par_dict['param_test'].name == 'param_test'
  110. assert par_dict['param_test'].data.dtype == mstype.float32
  111. assert par_dict['param_test'].data.shape == (1, 3, 224, 224)
  112. assert isinstance(par_dict, dict)
  113. def test_checkpoint_manager():
  114. """ test_checkpoint_manager """
  115. ckp_mgr = _CheckpointManager()
  116. ckpt_file_name = os.path.join(_cur_dir, './test1.ckpt')
  117. with open(ckpt_file_name, 'w'):
  118. os.chmod(ckpt_file_name, stat.S_IWUSR | stat.S_IRUSR)
  119. ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
  120. assert ckp_mgr.ckpoint_num == 1
  121. ckp_mgr.remove_ckpoint_file(ckpt_file_name)
  122. ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
  123. assert ckp_mgr.ckpoint_num == 0
  124. assert not os.path.exists(ckpt_file_name)
  125. another_file_name = os.path.join(_cur_dir, './test2.ckpt')
  126. another_file_name = os.path.realpath(another_file_name)
  127. with open(another_file_name, 'w'):
  128. os.chmod(another_file_name, stat.S_IWUSR | stat.S_IRUSR)
  129. ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
  130. assert ckp_mgr.ckpoint_num == 1
  131. ckp_mgr.remove_oldest_ckpoint_file()
  132. ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
  133. assert ckp_mgr.ckpoint_num == 0
  134. assert not os.path.exists(another_file_name)
  135. # test keep_one_ckpoint_per_minutes
  136. file1 = os.path.realpath(os.path.join(_cur_dir, './time_file1.ckpt'))
  137. file2 = os.path.realpath(os.path.join(_cur_dir, './time_file2.ckpt'))
  138. file3 = os.path.realpath(os.path.join(_cur_dir, './time_file3.ckpt'))
  139. with open(file1, 'w'):
  140. os.chmod(file1, stat.S_IWUSR | stat.S_IRUSR)
  141. with open(file2, 'w'):
  142. os.chmod(file2, stat.S_IWUSR | stat.S_IRUSR)
  143. with open(file3, 'w'):
  144. os.chmod(file3, stat.S_IWUSR | stat.S_IRUSR)
  145. time1 = time.time()
  146. ckp_mgr.update_ckpoint_filelist(_cur_dir, "time_file")
  147. assert ckp_mgr.ckpoint_num == 3
  148. ckp_mgr.keep_one_ckpoint_per_minutes(1, time1)
  149. ckp_mgr.update_ckpoint_filelist(_cur_dir, "time_file")
  150. assert ckp_mgr.ckpoint_num == 1
  151. if os.path.exists(_cur_dir + '/time_file1.ckpt'):
  152. os.chmod(_cur_dir + '/time_file1.ckpt', stat.S_IWRITE)
  153. os.remove(_cur_dir + '/time_file1.ckpt')
  154. def test_load_param_into_net_error_net():
  155. parameter_dict = {}
  156. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
  157. name="conv1.weight")
  158. parameter_dict["conv1.weight"] = one_param
  159. with pytest.raises(TypeError):
  160. load_param_into_net('', parameter_dict)
  161. def test_load_param_into_net_error_dict():
  162. net = Net(10)
  163. with pytest.raises(TypeError):
  164. load_param_into_net(net, '')
  165. def test_load_param_into_net_erro_dict_param():
  166. net = Net(10)
  167. net.init_parameters_data()
  168. assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0
  169. parameter_dict = {}
  170. one_param = ''
  171. parameter_dict["conv1.weight"] = one_param
  172. with pytest.raises(TypeError):
  173. load_param_into_net(net, parameter_dict)
  174. def test_load_param_into_net_has_more_param():
  175. """ test_load_param_into_net_has_more_param """
  176. net = Net(10)
  177. net.init_parameters_data()
  178. assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0
  179. parameter_dict = {}
  180. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
  181. name="conv1.weight")
  182. parameter_dict["conv1.weight"] = one_param
  183. two_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
  184. name="conv1.weight")
  185. parameter_dict["conv1.w"] = two_param
  186. load_param_into_net(net, parameter_dict)
  187. assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 1
  188. def test_load_param_into_net_param_type_and_shape_error():
  189. net = Net(10)
  190. net.init_parameters_data()
  191. assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0
  192. parameter_dict = {}
  193. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.int32), name="conv1.weight")
  194. parameter_dict["conv1.weight"] = one_param
  195. with pytest.raises(RuntimeError):
  196. load_param_into_net(net, parameter_dict)
  197. def test_load_param_into_net_param_type_error():
  198. net = Net(10)
  199. net.init_parameters_data()
  200. assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0
  201. parameter_dict = {}
  202. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.int32),
  203. name="conv1.weight")
  204. parameter_dict["conv1.weight"] = one_param
  205. with pytest.raises(RuntimeError):
  206. load_param_into_net(net, parameter_dict)
  207. def test_load_param_into_net_param_shape_error():
  208. net = Net(10)
  209. net.init_parameters_data()
  210. assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0
  211. parameter_dict = {}
  212. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7,)), dtype=mstype.int32),
  213. name="conv1.weight")
  214. parameter_dict["conv1.weight"] = one_param
  215. with pytest.raises(RuntimeError):
  216. load_param_into_net(net, parameter_dict)
  217. def test_load_param_into_net():
  218. net = Net(10)
  219. net.init_parameters_data()
  220. assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0
  221. parameter_dict = {}
  222. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
  223. name="conv1.weight")
  224. parameter_dict["conv1.weight"] = one_param
  225. load_param_into_net(net, parameter_dict)
  226. assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 1
  227. def test_save_checkpoint_for_network():
  228. """ test save_checkpoint for network"""
  229. net = Net()
  230. loss = SoftmaxCrossEntropyWithLogits(sparse=True)
  231. opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024)
  232. loss_net = WithLossCell(net, loss)
  233. train_network = TrainOneStepCell(loss_net, opt)
  234. save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt")
  235. load_checkpoint("new_ckpt.ckpt")
  236. def test_load_checkpoint_empty_file():
  237. os.mknod("empty.ckpt")
  238. with pytest.raises(ValueError):
  239. load_checkpoint("empty.ckpt")
  240. def test_save_and_load_checkpoint_for_network_with_encryption():
  241. """ test save and checkpoint for network with encryption"""
  242. net = Net()
  243. loss = SoftmaxCrossEntropyWithLogits(sparse=True)
  244. opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024)
  245. loss_net = WithLossCell(net, loss)
  246. train_network = TrainOneStepCell(loss_net, opt)
  247. key = os.urandom(16)
  248. mode = "AES-GCM"
  249. ckpt_path = "./encrypt_ckpt.ckpt"
  250. if platform.system().lower() == "windows":
  251. with pytest.raises(NotImplementedError):
  252. save_checkpoint(train_network, ckpt_file_name=ckpt_path, enc_key=key, enc_mode=mode)
  253. param_dict = load_checkpoint(ckpt_path, dec_key=key, dec_mode="AES-GCM")
  254. load_param_into_net(net, param_dict)
  255. else:
  256. save_checkpoint(train_network, ckpt_file_name=ckpt_path, enc_key=key, enc_mode=mode)
  257. param_dict = load_checkpoint(ckpt_path, dec_key=key, dec_mode="AES-GCM")
  258. load_param_into_net(net, param_dict)
  259. if os.path.exists(ckpt_path):
  260. os.remove(ckpt_path)
  261. class MYNET(nn.Cell):
  262. """ NET definition """
  263. def __init__(self):
  264. super(MYNET, self).__init__()
  265. self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal', pad_mode='valid')
  266. self.bn = nn.BatchNorm2d(64)
  267. self.relu = nn.ReLU()
  268. self.flatten = nn.Flatten()
  269. self.fc = nn.Dense(64 * 222 * 222, 3) # padding=0
  270. def construct(self, x):
  271. x = self.conv(x)
  272. x = self.bn(x)
  273. x = self.relu(x)
  274. x = self.flatten(x)
  275. out = self.fc(x)
  276. return out
  277. @non_graph_engine
  278. def test_export():
  279. net = MYNET()
  280. input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
  281. with pytest.raises(ValueError):
  282. export(net, input_data, file_name="./me_export.pb", file_format="AIR")
  283. @non_graph_engine
  284. def test_mindir_export():
  285. net = MYNET()
  286. input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
  287. export(net, input_data, file_name="./me_binary_export", file_format="MINDIR")
  288. class PrintNet(nn.Cell):
  289. def __init__(self):
  290. super(PrintNet, self).__init__()
  291. self.print = P.Print()
  292. def construct(self, int8, uint8, int16, uint16, int32, uint32, int64, uint64, flt16, flt32, flt64, bool_,
  293. scale1, scale2):
  294. self.print('============tensor int8:==============', int8)
  295. self.print('============tensor int8:==============', int8)
  296. self.print('============tensor uint8:==============', uint8)
  297. self.print('============tensor int16:==============', int16)
  298. self.print('============tensor uint16:==============', uint16)
  299. self.print('============tensor int32:==============', int32)
  300. self.print('============tensor uint32:==============', uint32)
  301. self.print('============tensor int64:==============', int64)
  302. self.print('============tensor uint64:==============', uint64)
  303. self.print('============tensor float16:==============', flt16)
  304. self.print('============tensor float32:==============', flt32)
  305. self.print('============tensor float64:==============', flt64)
  306. self.print('============tensor bool:==============', bool_)
  307. self.print('============tensor scale1:==============', scale1)
  308. self.print('============tensor scale2:==============', scale2)
  309. return int8, uint8, int16, uint16, int32, uint32, int64, uint64, flt16, flt32, flt64, bool_, scale1, scale2
  310. def test_print():
  311. print_net = PrintNet()
  312. int8 = Tensor(np.random.randint(100, size=(10, 10), dtype="int8"))
  313. uint8 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint8"))
  314. int16 = Tensor(np.random.randint(100, size=(10, 10), dtype="int16"))
  315. uint16 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint16"))
  316. int32 = Tensor(np.random.randint(100, size=(10, 10), dtype="int32"))
  317. uint32 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint32"))
  318. int64 = Tensor(np.random.randint(100, size=(10, 10), dtype="int64"))
  319. uint64 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint64"))
  320. float16 = Tensor(np.random.rand(224, 224).astype(np.float16))
  321. float32 = Tensor(np.random.rand(224, 224).astype(np.float32))
  322. float64 = Tensor(np.random.rand(224, 224).astype(np.float64))
  323. bool_ = Tensor(np.arange(-10, 10, 2).astype(np.bool_))
  324. scale1 = Tensor(np.array(1))
  325. scale2 = Tensor(np.array(0.1))
  326. print_net(int8, uint8, int16, uint16, int32, uint32, int64, uint64, float16, float32, float64, bool_, scale1,
  327. scale2)
  328. def teardown_module():
  329. files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt']
  330. for item in files:
  331. file_name = './' + item
  332. if not os.path.exists(file_name):
  333. continue
  334. os.chmod(file_name, stat.S_IWRITE)
  335. os.remove(file_name)
  336. import shutil
  337. if os.path.exists('./print'):
  338. shutil.rmtree('./print')