You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_predict_save_model.py 3.3 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Function:
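  17. Export a LeNet network (loading checkpoint weights if the checkpoint exists) to a MindSpore Lite model.
  18. Usage:
  19. python test_predict_save_model.py --path ./lenet_model.ms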
  20. """
  21. import argparse
  22. import os
  23. import numpy as np
  24. import mindspore.context as context
  25. import mindspore.nn as nn
  26. import mindspore.ops.operations as P
  27. from mindspore.common.tensor import Tensor
  28. from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
  29. class LeNet(nn.Cell):
  30. def __init__(self):
  31. super(LeNet, self).__init__()
  32. self.relu = P.ReLU()
  33. self.batch_size = 32
  34. self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
  35. self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
  36. self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
  37. self.reshape = P.Reshape()
  38. self.fc1 = nn.Dense(400, 120)
  39. self.fc2 = nn.Dense(120, 84)
  40. self.fc3 = nn.Dense(84, 10)
  41. def construct(self, input_x):
  42. output = self.conv1(input_x)
  43. output = self.relu(output)
  44. output = self.pool(output)
  45. output = self.conv2(output)
  46. output = self.relu(output)
  47. output = self.pool(output)
  48. output = self.reshape(output, (self.batch_size, -1))
  49. output = self.fc1(output)
  50. output = self.relu(output)
  51. output = self.fc2(output)
  52. output = self.relu(output)
  53. output = self.fc3(output)
  54. return output
  55. parser = argparse.ArgumentParser(description='MindSpore Model Save')
  56. parser.add_argument('--path', default='./lenet_model.ms', type=str, help='model save path')
  57. if __name__ == '__main__':
  58. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  59. print("test lenet predict start")
  60. seed = 0
  61. np.random.seed(seed)
  62. batch = 32
  63. channel = 1
  64. input_h = 32
  65. input_w = 32
  66. origin_data = np.random.uniform(low=0, high=255, size=(batch, channel, input_h, input_w)).astype(np.float32)
  67. origin_data.tofile("lenet_input_data.bin")
  68. input_data = Tensor(origin_data)
  69. print(input_data.asnumpy())
  70. net = LeNet()
  71. ckpt_file_path = "./tests/ut/python/predict/checkpoint_lenet.ckpt"
  72. predict_args = parser.parse_args()
  73. model_path_name = predict_args.path
  74. is_ckpt_exist = os.path.exists(ckpt_file_path)
  75. if is_ckpt_exist:
  76. param_dict = load_checkpoint(ckpt_file_name=ckpt_file_path)
  77. load_param_into_net(net, param_dict)
  78. export(net, input_data, file_name=model_path_name, file_format='LITE')
  79. print("test lenet predict success.")
  80. else:
  81. print("checkpoint file is not exist.")