You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_serialize.py 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ut for model serialize(save/load)"""
  16. import os
  17. import stat
  18. import time
  19. import numpy as np
  20. import pytest
  21. import mindspore.common.dtype as mstype
  22. import mindspore.nn as nn
  23. from mindspore import context
  24. from mindspore.common.parameter import Parameter
  25. from mindspore.common.tensor import Tensor
  26. from mindspore.nn import SoftmaxCrossEntropyWithLogits
  27. from mindspore.nn import WithLossCell, TrainOneStepCell
  28. from mindspore.nn.optim.momentum import Momentum
  29. from mindspore.ops import operations as P
  30. from mindspore.train.callback import _CheckpointManager
  31. from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \
  32. _exec_save_checkpoint, export, _save_graph
  33. from ..ut_filter import run_on_onnxruntime, non_graph_engine
  34. context.set_context(mode=context.GRAPH_MODE)
  35. class Net(nn.Cell):
  36. """Net definition."""
  37. def __init__(self, num_classes=10):
  38. super(Net, self).__init__()
  39. self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros")
  40. self.bn1 = nn.BatchNorm2d(64)
  41. self.relu = nn.ReLU()
  42. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
  43. self.flatten = nn.Flatten()
  44. self.fc = nn.Dense(int(224 * 224 * 64 / 16), num_classes)
  45. def construct(self, x):
  46. x = self.conv1(x)
  47. x = self.bn1(x)
  48. x = self.relu(x)
  49. x = self.maxpool(x)
  50. x = self.flatten(x)
  51. x = self.fc(x)
  52. return x
  53. _input_x = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
  54. _cur_dir = os.path.dirname(os.path.realpath(__file__))
  55. def setup_module():
  56. import shutil
  57. if os.path.exists('./test_files'):
  58. shutil.rmtree('./test_files')
  59. def test_save_graph():
  60. """ test_exec_save_graph """
  61. class Net(nn.Cell):
  62. def __init__(self):
  63. super(Net, self).__init__()
  64. self.add = P.TensorAdd()
  65. def construct(self, x, y):
  66. z = self.add(x, y)
  67. return z
  68. net = Net()
  69. net.set_train()
  70. out_me_list = []
  71. x = Tensor(np.random.rand(2, 1, 2, 3).astype(np.float32))
  72. y = Tensor(np.array([1.2]).astype(np.float32))
  73. out_put = net(x, y)
  74. _save_graph(network=net, file_name="net-graph.meta")
  75. out_me_list.append(out_put)
  76. def test_save_checkpoint():
  77. """ test_save_checkpoint """
  78. parameter_list = []
  79. one_param = {}
  80. param1 = {}
  81. param2 = {}
  82. one_param['name'] = "param_test"
  83. one_param['data'] = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)
  84. param1['name'] = "param"
  85. param1['data'] = Tensor(np.random.randint(0, 255, [12, 1024]), dtype=mstype.float32)
  86. param2['name'] = "new_param"
  87. param2['data'] = Tensor(np.random.randint(0, 255, [12, 1024, 1]), dtype=mstype.float32)
  88. parameter_list.append(one_param)
  89. parameter_list.append(param1)
  90. parameter_list.append(param2)
  91. if os.path.exists('./parameters.ckpt'):
  92. os.chmod('./parameters.ckpt', stat.S_IWRITE)
  93. os.remove('./parameters.ckpt')
  94. ckpoint_file_name = os.path.join(_cur_dir, './parameters.ckpt')
  95. save_checkpoint(parameter_list, ckpoint_file_name)
  96. def test_load_checkpoint_error_filename():
  97. ckpoint_file_name = 1
  98. with pytest.raises(ValueError):
  99. load_checkpoint(ckpoint_file_name)
  100. def test_load_checkpoint():
  101. ckpoint_file_name = os.path.join(_cur_dir, './parameters.ckpt')
  102. par_dict = load_checkpoint(ckpoint_file_name)
  103. assert len(par_dict) == 3
  104. assert par_dict['param_test'].name == 'param_test'
  105. assert par_dict['param_test'].data.dtype() == mstype.float32
  106. assert par_dict['param_test'].data.shape() == (1, 3, 224, 224)
  107. assert isinstance(par_dict, dict)
  108. def test_checkpoint_manager():
  109. """ test_checkpoint_manager """
  110. ckp_mgr = _CheckpointManager()
  111. ckpoint_file_name = os.path.join(_cur_dir, './test1.ckpt')
  112. with open(ckpoint_file_name, 'w'):
  113. os.chmod(ckpoint_file_name, stat.S_IWUSR | stat.S_IRUSR)
  114. ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
  115. assert ckp_mgr.ckpoint_num == 1
  116. ckp_mgr.remove_ckpoint_file(ckpoint_file_name)
  117. ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
  118. assert ckp_mgr.ckpoint_num == 0
  119. assert not os.path.exists(ckpoint_file_name)
  120. another_file_name = os.path.join(_cur_dir, './test2.ckpt')
  121. another_file_name = os.path.realpath(another_file_name)
  122. with open(another_file_name, 'w'):
  123. os.chmod(another_file_name, stat.S_IWUSR | stat.S_IRUSR)
  124. ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
  125. assert ckp_mgr.ckpoint_num == 1
  126. ckp_mgr.remove_oldest_ckpoint_file()
  127. ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
  128. assert ckp_mgr.ckpoint_num == 0
  129. assert not os.path.exists(another_file_name)
  130. # test keep_one_ckpoint_per_minutes
  131. file1 = os.path.realpath(os.path.join(_cur_dir, './time_file1.ckpt'))
  132. file2 = os.path.realpath(os.path.join(_cur_dir, './time_file2.ckpt'))
  133. file3 = os.path.realpath(os.path.join(_cur_dir, './time_file3.ckpt'))
  134. with open(file1, 'w'):
  135. os.chmod(file1, stat.S_IWUSR | stat.S_IRUSR)
  136. with open(file2, 'w'):
  137. os.chmod(file2, stat.S_IWUSR | stat.S_IRUSR)
  138. with open(file3, 'w'):
  139. os.chmod(file3, stat.S_IWUSR | stat.S_IRUSR)
  140. time1 = time.time()
  141. ckp_mgr.update_ckpoint_filelist(_cur_dir, "time_file")
  142. assert ckp_mgr.ckpoint_num == 3
  143. ckp_mgr.keep_one_ckpoint_per_minutes(1, time1)
  144. ckp_mgr.update_ckpoint_filelist(_cur_dir, "time_file")
  145. assert ckp_mgr.ckpoint_num == 1
  146. if os.path.exists(_cur_dir + '/time_file1.ckpt'):
  147. os.chmod(_cur_dir + '/time_file1.ckpt', stat.S_IWRITE)
  148. os.remove(_cur_dir + '/time_file1.ckpt')
  149. def test_load_param_into_net_error_net():
  150. parameter_dict = {}
  151. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
  152. name="conv1.weight")
  153. parameter_dict["conv1.weight"] = one_param
  154. with pytest.raises(TypeError):
  155. load_param_into_net('', parameter_dict)
  156. def test_load_param_into_net_error_dict():
  157. net = Net(10)
  158. with pytest.raises(TypeError):
  159. load_param_into_net(net, '')
  160. def test_load_param_into_net_erro_dict_param():
  161. net = Net(10)
  162. net.init_parameters_data()
  163. assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
  164. parameter_dict = {}
  165. one_param = ''
  166. parameter_dict["conv1.weight"] = one_param
  167. with pytest.raises(TypeError):
  168. load_param_into_net(net, parameter_dict)
  169. def test_load_param_into_net_has_more_param():
  170. """ test_load_param_into_net_has_more_param """
  171. net = Net(10)
  172. net.init_parameters_data()
  173. assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
  174. parameter_dict = {}
  175. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
  176. name="conv1.weight")
  177. parameter_dict["conv1.weight"] = one_param
  178. two_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
  179. name="conv1.weight")
  180. parameter_dict["conv1.w"] = two_param
  181. load_param_into_net(net, parameter_dict)
  182. assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 1
  183. def test_load_param_into_net_param_type_and_shape_error():
  184. net = Net(10)
  185. net.init_parameters_data()
  186. assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
  187. parameter_dict = {}
  188. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7))), name="conv1.weight")
  189. parameter_dict["conv1.weight"] = one_param
  190. with pytest.raises(RuntimeError):
  191. load_param_into_net(net, parameter_dict)
  192. def test_load_param_into_net_param_type_error():
  193. net = Net(10)
  194. net.init_parameters_data()
  195. assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
  196. parameter_dict = {}
  197. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.int32),
  198. name="conv1.weight")
  199. parameter_dict["conv1.weight"] = one_param
  200. with pytest.raises(RuntimeError):
  201. load_param_into_net(net, parameter_dict)
  202. def test_load_param_into_net_param_shape_error():
  203. net = Net(10)
  204. net.init_parameters_data()
  205. assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
  206. parameter_dict = {}
  207. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7,)), dtype=mstype.int32),
  208. name="conv1.weight")
  209. parameter_dict["conv1.weight"] = one_param
  210. with pytest.raises(RuntimeError):
  211. load_param_into_net(net, parameter_dict)
  212. def test_load_param_into_net():
  213. net = Net(10)
  214. net.init_parameters_data()
  215. assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
  216. parameter_dict = {}
  217. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
  218. name="conv1.weight")
  219. parameter_dict["conv1.weight"] = one_param
  220. load_param_into_net(net, parameter_dict)
  221. assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 1
  222. def test_exec_save_checkpoint():
  223. net = Net()
  224. loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
  225. opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024)
  226. loss_net = WithLossCell(net, loss)
  227. train_network = TrainOneStepCell(loss_net, opt)
  228. _exec_save_checkpoint(train_network, ckpoint_file_name="./new_ckpt.ckpt")
  229. load_checkpoint("new_ckpt.ckpt")
  230. def test_load_checkpoint_empty_file():
  231. os.mknod("empty.ckpt")
  232. with pytest.raises(ValueError):
  233. load_checkpoint("empty.ckpt")
  234. class MYNET(nn.Cell):
  235. """ NET definition """
  236. def __init__(self):
  237. super(MYNET, self).__init__()
  238. self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal', pad_mode='valid')
  239. self.bn = nn.BatchNorm2d(64)
  240. self.relu = nn.ReLU()
  241. self.flatten = nn.Flatten()
  242. self.fc = nn.Dense(64 * 222 * 222, 3) # padding=0
  243. def construct(self, x):
  244. x = self.conv(x)
  245. x = self.bn(x)
  246. x = self.relu(x)
  247. x = self.flatten(x)
  248. out = self.fc(x)
  249. return out
  250. @non_graph_engine
  251. def test_export():
  252. net = MYNET()
  253. input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
  254. export(net, input_data, file_name="./me_export.pb", file_format="GEIR")
  255. class BatchNormTester(nn.Cell):
  256. "used to test exporting network in training mode in onnx format"
  257. def __init__(self, num_features):
  258. super(BatchNormTester, self).__init__()
  259. self.bn = nn.BatchNorm2d(num_features)
  260. def construct(self, x):
  261. return self.bn(x)
  262. class DepthwiseConv2dAndReLU6(nn.Cell):
  263. "Net for testing DepthwiseConv2d and ReLU6"
  264. def __init__(self, input_channel, kernel_size):
  265. super(DepthwiseConv2dAndReLU6, self).__init__()
  266. weight_shape = [1, input_channel, kernel_size, kernel_size]
  267. from mindspore.common.initializer import initializer
  268. self.weight = Parameter(initializer('ones', weight_shape), name='weight')
  269. self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=(kernel_size, kernel_size))
  270. self.relu6 = nn.ReLU6()
  271. def construct(self, x):
  272. x = self.depthwise_conv(x, self.weight)
  273. x = self.relu6(x)
  274. return x
  275. def test_batchnorm_train_onnx_export():
  276. input = Tensor(np.ones([1, 3, 32, 32]).astype(np.float32) * 0.01)
  277. net = BatchNormTester(3)
  278. net.set_train()
  279. if not net.training:
  280. raise ValueError('netowrk is not in training mode')
  281. export(net, input, file_name='batch_norm.onnx', file_format='ONNX')
  282. if not net.training:
  283. raise ValueError('netowrk is not in training mode')
  284. class LeNet5(nn.Cell):
  285. """LeNet5 definition"""
  286. def __init__(self):
  287. super(LeNet5, self).__init__()
  288. self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
  289. self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
  290. self.fc1 = nn.Dense(16 * 5 * 5, 120)
  291. self.fc2 = nn.Dense(120, 84)
  292. self.fc3 = nn.Dense(84, 10)
  293. self.relu = nn.ReLU()
  294. self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
  295. self.flatten = P.Flatten()
  296. def construct(self, x):
  297. x = self.max_pool2d(self.relu(self.conv1(x)))
  298. x = self.max_pool2d(self.relu(self.conv2(x)))
  299. x = self.flatten(x)
  300. x = self.relu(self.fc1(x))
  301. x = self.relu(self.fc2(x))
  302. x = self.fc3(x)
  303. return x
  304. def test_lenet5_onnx_export():
  305. input = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
  306. net = LeNet5()
  307. export(net, input, file_name='lenet5.onnx', file_format='ONNX')
  308. class DefinedNet(nn.Cell):
  309. """simple Net definition with maxpoolwithargmax."""
  310. def __init__(self, num_classes=10):
  311. super(DefinedNet, self).__init__()
  312. self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros")
  313. self.bn1 = nn.BatchNorm2d(64)
  314. self.relu = nn.ReLU()
  315. self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=2, strides=2)
  316. self.flatten = nn.Flatten()
  317. self.fc = nn.Dense(int(56 * 56 * 64), num_classes)
  318. def construct(self, x):
  319. x = self.conv1(x)
  320. x = self.bn1(x)
  321. x = self.relu(x)
  322. x, argmax = self.maxpool(x)
  323. x = self.flatten(x)
  324. x = self.fc(x)
  325. return x
  326. def test_net_onnx_maxpoolwithargmax_export():
  327. input = Tensor(np.ones([1, 3, 224, 224]).astype(np.float32) * 0.01)
  328. net = DefinedNet()
  329. export(net, input, file_name='definedNet.onnx', file_format='ONNX')
  330. @run_on_onnxruntime
  331. def test_lenet5_onnx_load_run():
  332. onnx_file = 'lenet5.onnx'
  333. input = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
  334. net = LeNet5()
  335. export(net, input, file_name=onnx_file, file_format='ONNX')
  336. import onnx
  337. import onnxruntime as ort
  338. print('--------------------- onnx load ---------------------')
  339. # Load the ONNX model
  340. model = onnx.load(onnx_file)
  341. # Check that the IR is well formed
  342. onnx.checker.check_model(model)
  343. # Print a human readable representation of the graph
  344. g = onnx.helper.printable_graph(model.graph)
  345. print(g)
  346. print('------------------ onnxruntime run ------------------')
  347. ort_session = ort.InferenceSession(onnx_file)
  348. input_map = {'x': input.asnumpy()}
  349. # provide only input x to run model
  350. outputs = ort_session.run(None, input_map)
  351. print(outputs[0])
  352. # overwrite default weight to run model
  353. for item in net.trainable_params():
  354. input_map[item.name] = np.ones(item.default_input.asnumpy().shape, dtype=np.float32)
  355. outputs = ort_session.run(None, input_map)
  356. print(outputs[0])
  357. @run_on_onnxruntime
  358. def test_depthwiseconv_relu6_onnx_load_run():
  359. onnx_file = 'depthwiseconv_relu6.onnx'
  360. input_channel = 3
  361. input = Tensor(np.ones([1, input_channel, 32, 32]).astype(np.float32) * 0.01)
  362. net = DepthwiseConv2dAndReLU6(input_channel, kernel_size=3)
  363. export(net, input, file_name=onnx_file, file_format='ONNX')
  364. import onnx
  365. import onnxruntime as ort
  366. print('--------------------- onnx load ---------------------')
  367. # Load the ONNX model
  368. model = onnx.load(onnx_file)
  369. # Check that the IR is well formed
  370. onnx.checker.check_model(model)
  371. # Print a human readable representation of the graph
  372. g = onnx.helper.printable_graph(model.graph)
  373. print(g)
  374. print('------------------ onnxruntime run ------------------')
  375. ort_session = ort.InferenceSession(onnx_file)
  376. input_map = {'x': input.asnumpy()}
  377. # provide only input x to run model
  378. outputs = ort_session.run(None, input_map)
  379. print(outputs[0])
  380. # overwrite default weight to run model
  381. for item in net.trainable_params():
  382. input_map[item.name] = np.ones(item.default_input.asnumpy().shape, dtype=np.float32)
  383. outputs = ort_session.run(None, input_map)
  384. print(outputs[0])
  385. def teardown_module():
  386. files = ['parameters.ckpt', 'new_ckpt.ckpt', 'lenet5.onnx', 'batch_norm.onnx', 'empty.ckpt']
  387. for item in files:
  388. file_name = './' + item
  389. if not os.path.exists(file_name):
  390. continue
  391. os.chmod(file_name, stat.S_IWRITE)
  392. os.remove(file_name)