You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_serialize.py 19 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ut for model serialize(save/load)"""
  16. import os
  17. import platform
  18. import stat
  19. import time
  20. import secrets
  21. import numpy as np
  22. import pytest
  23. import mindspore.common.dtype as mstype
  24. import mindspore.nn as nn
  25. from mindspore import context
  26. from mindspore.common.parameter import Parameter
  27. from mindspore.common.tensor import Tensor
  28. from mindspore.nn import SoftmaxCrossEntropyWithLogits
  29. from mindspore.nn import WithLossCell, TrainOneStepCell
  30. from mindspore.nn.optim.momentum import Momentum
  31. from mindspore.ops import operations as P
  32. from mindspore.train.callback import _CheckpointManager
  33. from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \
  34. export, _save_graph, load
  35. from tests.security_utils import security_off_wrap
  36. from ..ut_filter import non_graph_engine
  37. class Net(nn.Cell):
  38. """Net definition."""
  39. def __init__(self, num_classes=10):
  40. super(Net, self).__init__()
  41. self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros")
  42. self.bn1 = nn.BatchNorm2d(64)
  43. self.relu = nn.ReLU()
  44. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
  45. self.flatten = nn.Flatten()
  46. self.fc = nn.Dense(int(224 * 224 * 64 / 16), num_classes)
  47. def construct(self, x):
  48. x = self.conv1(x)
  49. x = self.bn1(x)
  50. x = self.relu(x)
  51. x = self.maxpool(x)
  52. x = self.flatten(x)
  53. x = self.fc(x)
  54. return x
  55. _input_x = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
  56. _cur_dir = os.path.dirname(os.path.realpath(__file__))
  57. def setup_module():
  58. import shutil
  59. if os.path.exists('./test_files'):
  60. shutil.rmtree('./test_files')
  61. def test_save_graph():
  62. """ test_exec_save_graph """
  63. context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
  64. class Net1(nn.Cell):
  65. def __init__(self):
  66. super(Net1, self).__init__()
  67. self.add = P.Add()
  68. def construct(self, x, y):
  69. z = self.add(x, y)
  70. return z
  71. net = Net1()
  72. net.set_train()
  73. out_me_list = []
  74. x = Tensor(np.random.rand(2, 1, 2, 3).astype(np.float32))
  75. y = Tensor(np.array([1.2]).astype(np.float32))
  76. out_put = net(x, y)
  77. output_file = "net-graph.meta"
  78. _save_graph(network=net, file_name=output_file)
  79. out_me_list.append(out_put)
  80. assert os.path.exists(output_file)
  81. os.chmod(output_file, stat.S_IWRITE)
  82. os.remove(output_file)
  83. def test_save_checkpoint_for_list():
  84. """ test save_checkpoint for list"""
  85. context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
  86. parameter_list = []
  87. one_param = {}
  88. param1 = {}
  89. param2 = {}
  90. one_param['name'] = "param_test"
  91. one_param['data'] = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)
  92. param1['name'] = "param"
  93. param1['data'] = Tensor(np.random.randint(0, 255, [12, 1024]), dtype=mstype.float32)
  94. param2['name'] = "new_param"
  95. param2['data'] = Tensor(np.random.randint(0, 255, [12, 1024, 1]), dtype=mstype.float32)
  96. parameter_list.append(one_param)
  97. parameter_list.append(param1)
  98. parameter_list.append(param2)
  99. if os.path.exists('./parameters.ckpt'):
  100. os.chmod('./parameters.ckpt', stat.S_IWRITE)
  101. os.remove('./parameters.ckpt')
  102. ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
  103. save_checkpoint(parameter_list, ckpt_file_name)
  104. def test_load_checkpoint_error_filename():
  105. """
  106. Feature: Load checkpoint.
  107. Description: Load checkpoint with error filename.
  108. Expectation: Raise value error for error filename.
  109. """
  110. context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
  111. ckpt_file_name = 1
  112. with pytest.raises(TypeError):
  113. load_checkpoint(ckpt_file_name)
  114. def test_save_checkpoint_for_list_append_info_and_load_checkpoint():
  115. """
  116. Feature: Save checkpoint for list append info and load checkpoint.
  117. Description: Save checkpoint for list append info and load checkpoint with list append info.
  118. Expectation: Checkpoint for list append info can be saved and reloaded.
  119. """
  120. context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
  121. parameter_list = []
  122. one_param = {}
  123. param1 = {}
  124. param2 = {}
  125. one_param['name'] = "param_test"
  126. one_param['data'] = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)
  127. param1['name'] = "param"
  128. param1['data'] = Tensor(np.random.randint(0, 255, [12, 1024]), dtype=mstype.float32)
  129. param2['name'] = "new_param"
  130. param2['data'] = Tensor(np.random.randint(0, 255, [12, 1024, 1]), dtype=mstype.float32)
  131. parameter_list.append(one_param)
  132. parameter_list.append(param1)
  133. parameter_list.append(param2)
  134. append_dict = {"lr": 0.01, "epoch": 20, "train": True}
  135. if os.path.exists('./parameters.ckpt'):
  136. os.chmod('./parameters.ckpt', stat.S_IWRITE)
  137. os.remove('./parameters.ckpt')
  138. ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
  139. save_checkpoint(parameter_list, ckpt_file_name, append_dict=append_dict)
  140. par_dict = load_checkpoint(ckpt_file_name)
  141. assert len(par_dict) == 6
  142. assert par_dict['param_test'].name == 'param_test'
  143. assert par_dict['param_test'].data.dtype == mstype.float32
  144. assert par_dict['param_test'].data.shape == (1, 3, 224, 224)
  145. assert isinstance(par_dict, dict)
  146. def test_checkpoint_manager():
  147. """ test_checkpoint_manager """
  148. context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
  149. ckp_mgr = _CheckpointManager()
  150. ckpt_file_name = os.path.join(_cur_dir, './test-1_1.ckpt')
  151. with open(ckpt_file_name, 'w'):
  152. os.chmod(ckpt_file_name, stat.S_IWUSR | stat.S_IRUSR)
  153. ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
  154. assert ckp_mgr.ckpoint_num == 1
  155. ckp_mgr.remove_ckpoint_file(ckpt_file_name)
  156. ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
  157. assert ckp_mgr.ckpoint_num == 0
  158. assert not os.path.exists(ckpt_file_name)
  159. another_file_name = os.path.join(_cur_dir, './test-2_1.ckpt')
  160. another_file_name = os.path.realpath(another_file_name)
  161. with open(another_file_name, 'w'):
  162. os.chmod(another_file_name, stat.S_IWUSR | stat.S_IRUSR)
  163. ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
  164. assert ckp_mgr.ckpoint_num == 1
  165. ckp_mgr.remove_oldest_ckpoint_file()
  166. ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
  167. assert ckp_mgr.ckpoint_num == 0
  168. assert not os.path.exists(another_file_name)
  169. # test keep_one_ckpoint_per_minutes
  170. file1 = os.path.realpath(os.path.join(_cur_dir, './time_file-1_1.ckpt'))
  171. file2 = os.path.realpath(os.path.join(_cur_dir, './time_file-2_1.ckpt'))
  172. file3 = os.path.realpath(os.path.join(_cur_dir, './time_file-3_1.ckpt'))
  173. with open(file1, 'w'):
  174. os.chmod(file1, stat.S_IWUSR | stat.S_IRUSR)
  175. with open(file2, 'w'):
  176. os.chmod(file2, stat.S_IWUSR | stat.S_IRUSR)
  177. with open(file3, 'w'):
  178. os.chmod(file3, stat.S_IWUSR | stat.S_IRUSR)
  179. time1 = time.time()
  180. ckp_mgr.update_ckpoint_filelist(_cur_dir, "time_file")
  181. assert ckp_mgr.ckpoint_num == 3
  182. ckp_mgr.keep_one_ckpoint_per_minutes(1, time1)
  183. ckp_mgr.update_ckpoint_filelist(_cur_dir, "time_file")
  184. assert ckp_mgr.ckpoint_num == 1
  185. if os.path.exists(_cur_dir + '/time_file-1_1.ckpt'):
  186. os.chmod(_cur_dir + '/time_file-1_1.ckpt', stat.S_IWRITE)
  187. os.remove(_cur_dir + '/time_file-1_1.ckpt')
  188. def test_load_param_into_net_error_net():
  189. context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
  190. parameter_dict = {}
  191. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
  192. name="conv1.weight")
  193. parameter_dict["conv1.weight"] = one_param
  194. with pytest.raises(TypeError):
  195. load_param_into_net('', parameter_dict)
  196. def test_load_param_into_net_error_dict():
  197. context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
  198. net = Net(10)
  199. with pytest.raises(TypeError):
  200. load_param_into_net(net, '')
  201. def test_load_param_into_net_erro_dict_param():
  202. context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
  203. net = Net(10)
  204. net.init_parameters_data()
  205. assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0
  206. parameter_dict = {}
  207. one_param = ''
  208. parameter_dict["conv1.weight"] = one_param
  209. with pytest.raises(TypeError):
  210. load_param_into_net(net, parameter_dict)
  211. def test_load_param_into_net_has_more_param():
  212. """ test_load_param_into_net_has_more_param """
  213. context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
  214. net = Net(10)
  215. net.init_parameters_data()
  216. assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0
  217. parameter_dict = {}
  218. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
  219. name="conv1.weight")
  220. parameter_dict["conv1.weight"] = one_param
  221. two_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
  222. name="conv1.weight")
  223. parameter_dict["conv1.w"] = two_param
  224. load_param_into_net(net, parameter_dict)
  225. assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 1
  226. def test_load_param_into_net_param_type_and_shape_error():
  227. context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
  228. net = Net(10)
  229. net.init_parameters_data()
  230. assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0
  231. parameter_dict = {}
  232. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.int32), name="conv1.weight")
  233. parameter_dict["conv1.weight"] = one_param
  234. with pytest.raises(RuntimeError):
  235. load_param_into_net(net, parameter_dict)
  236. def test_load_param_into_net_param_type_error():
  237. context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
  238. net = Net(10)
  239. net.init_parameters_data()
  240. assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0
  241. parameter_dict = {}
  242. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.int32),
  243. name="conv1.weight")
  244. parameter_dict["conv1.weight"] = one_param
  245. with pytest.raises(RuntimeError):
  246. load_param_into_net(net, parameter_dict)
  247. def test_load_param_into_net_param_shape_error():
  248. context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
  249. net = Net(10)
  250. net.init_parameters_data()
  251. assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0
  252. parameter_dict = {}
  253. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7,)), dtype=mstype.int32),
  254. name="conv1.weight")
  255. parameter_dict["conv1.weight"] = one_param
  256. with pytest.raises(RuntimeError):
  257. load_param_into_net(net, parameter_dict)
  258. def test_load_param_into_net():
  259. context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
  260. net = Net(10)
  261. net.init_parameters_data()
  262. assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0
  263. parameter_dict = {}
  264. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
  265. name="conv1.weight")
  266. parameter_dict["conv1.weight"] = one_param
  267. load_param_into_net(net, parameter_dict)
  268. assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 1
  269. def test_save_checkpoint_for_network():
  270. """ test save_checkpoint for network"""
  271. context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
  272. net = Net()
  273. loss = SoftmaxCrossEntropyWithLogits(sparse=True)
  274. opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024)
  275. loss_net = WithLossCell(net, loss)
  276. train_network = TrainOneStepCell(loss_net, opt)
  277. save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt")
  278. load_checkpoint("new_ckpt.ckpt")
  279. def test_load_checkpoint_empty_file():
  280. context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
  281. os.mknod("empty.ckpt")
  282. with pytest.raises(ValueError):
  283. load_checkpoint("empty.ckpt")
  284. def test_save_and_load_checkpoint_for_network_with_encryption():
  285. """ test save and checkpoint for network with encryption"""
  286. context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
  287. net = Net()
  288. loss = SoftmaxCrossEntropyWithLogits(sparse=True)
  289. opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024)
  290. loss_net = WithLossCell(net, loss)
  291. train_network = TrainOneStepCell(loss_net, opt)
  292. key = secrets.token_bytes(16)
  293. mode = "AES-GCM"
  294. ckpt_path = "./encrypt_ckpt.ckpt"
  295. if platform.system().lower() == "windows":
  296. with pytest.raises(NotImplementedError):
  297. save_checkpoint(train_network, ckpt_file_name=ckpt_path, enc_key=key, enc_mode=mode)
  298. param_dict = load_checkpoint(ckpt_path, dec_key=key, dec_mode="AES-GCM")
  299. load_param_into_net(net, param_dict)
  300. else:
  301. save_checkpoint(train_network, ckpt_file_name=ckpt_path, enc_key=key, enc_mode=mode)
  302. param_dict = load_checkpoint(ckpt_path, dec_key=key, dec_mode="AES-GCM")
  303. load_param_into_net(net, param_dict)
  304. if os.path.exists(ckpt_path):
  305. os.remove(ckpt_path)
  306. class MYNET(nn.Cell):
  307. """ NET definition """
  308. def __init__(self):
  309. super(MYNET, self).__init__()
  310. self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal', pad_mode='valid')
  311. self.bn = nn.BatchNorm2d(64)
  312. self.relu = nn.ReLU()
  313. self.flatten = nn.Flatten()
  314. self.fc = nn.Dense(64 * 222 * 222, 3) # padding=0
  315. def construct(self, x):
  316. x = self.conv(x)
  317. x = self.bn(x)
  318. x = self.relu(x)
  319. x = self.flatten(x)
  320. out = self.fc(x)
  321. return out
  322. @non_graph_engine
  323. def test_export():
  324. context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
  325. net = MYNET()
  326. input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
  327. with pytest.raises(ValueError):
  328. export(net, input_data, file_name="./me_export.pb", file_format="AIR")
  329. @non_graph_engine
  330. def test_mindir_export():
  331. context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
  332. net = MYNET()
  333. input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
  334. export(net, input_data, file_name="./me_binary_export", file_format="MINDIR")
  335. @non_graph_engine
  336. def test_mindir_export_and_load_with_encryption():
  337. context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
  338. net = MYNET()
  339. input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
  340. key = secrets.token_bytes(16)
  341. export(net, input_data, file_name="./me_cipher_binary_export.mindir", file_format="MINDIR", enc_key=key)
  342. load("./me_cipher_binary_export.mindir", dec_key=key)
  343. class PrintNet(nn.Cell):
  344. def __init__(self):
  345. super(PrintNet, self).__init__()
  346. self.print = P.Print()
  347. def construct(self, int8, uint8, int16, uint16, int32, uint32, int64, uint64, flt16, flt32, flt64, bool_,
  348. scale1, scale2):
  349. self.print('============tensor int8:==============', int8)
  350. self.print('============tensor int8:==============', int8)
  351. self.print('============tensor uint8:==============', uint8)
  352. self.print('============tensor int16:==============', int16)
  353. self.print('============tensor uint16:==============', uint16)
  354. self.print('============tensor int32:==============', int32)
  355. self.print('============tensor uint32:==============', uint32)
  356. self.print('============tensor int64:==============', int64)
  357. self.print('============tensor uint64:==============', uint64)
  358. self.print('============tensor float16:==============', flt16)
  359. self.print('============tensor float32:==============', flt32)
  360. self.print('============tensor float64:==============', flt64)
  361. self.print('============tensor bool:==============', bool_)
  362. self.print('============tensor scale1:==============', scale1)
  363. self.print('============tensor scale2:==============', scale2)
  364. return int8, uint8, int16, uint16, int32, uint32, int64, uint64, flt16, flt32, flt64, bool_, scale1, scale2
  365. @security_off_wrap
  366. def test_print():
  367. context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
  368. print_net = PrintNet()
  369. int8 = Tensor(np.random.randint(100, size=(10, 10), dtype="int8"))
  370. uint8 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint8"))
  371. int16 = Tensor(np.random.randint(100, size=(10, 10), dtype="int16"))
  372. uint16 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint16"))
  373. int32 = Tensor(np.random.randint(100, size=(10, 10), dtype="int32"))
  374. uint32 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint32"))
  375. int64 = Tensor(np.random.randint(100, size=(10, 10), dtype="int64"))
  376. uint64 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint64"))
  377. float16 = Tensor(np.random.rand(224, 224).astype(np.float16))
  378. float32 = Tensor(np.random.rand(224, 224).astype(np.float32))
  379. float64 = Tensor(np.random.rand(224, 224).astype(np.float64))
  380. bool_ = Tensor(np.arange(-10, 10, 2).astype(np.bool_))
  381. scale1 = Tensor(np.array(1))
  382. scale2 = Tensor(np.array(0.1))
  383. print_net(int8, uint8, int16, uint16, int32, uint32, int64, uint64, float16, float32, float64, bool_, scale1,
  384. scale2)
  385. def teardown_module():
  386. files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt']
  387. for item in files:
  388. file_name = './' + item
  389. if not os.path.exists(file_name):
  390. continue
  391. os.chmod(file_name, stat.S_IWRITE)
  392. os.remove(file_name)
  393. import shutil
  394. if os.path.exists('./print'):
  395. shutil.rmtree('./print')