You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_export.py 5.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. import os
  2. import numpy as np
  3. import mindspore.nn as nn
  4. import mindspore.dataset as ds
  5. import mindspore.dataset.vision.c_transforms as CV
  6. import mindspore.dataset.transforms.c_transforms as CT
  7. from mindspore.dataset.vision import Inter
  8. from mindspore import context
  9. from mindspore.common import dtype as mstype
  10. from mindspore.common.tensor import Tensor
  11. from mindspore.common.initializer import TruncatedNormal
  12. from mindspore.common.parameter import ParameterTuple
  13. from mindspore.ops import operations as P
  14. from mindspore.ops import composite as C
  15. from mindspore.train.serialization import export
  16. def weight_variable():
  17. return TruncatedNormal(0.02)
  18. def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
  19. weight = weight_variable()
  20. return nn.Conv2d(in_channels, out_channels,
  21. kernel_size=kernel_size, stride=stride, padding=padding,
  22. weight_init=weight, has_bias=False, pad_mode="valid")
  23. def fc_with_initialize(input_channels, out_channels):
  24. weight = weight_variable()
  25. bias = weight_variable()
  26. return nn.Dense(input_channels, out_channels, weight, bias)
  27. def create_dataset():
  28. # define dataset
  29. mnist_ds = ds.MnistDataset("../data/dataset/testMnistData")
  30. resize_height, resize_width = 32, 32
  31. rescale = 1.0 / 255.0
  32. shift = 0.0
  33. rescale_nml = 1 / 0.3081
  34. shift_nml = -1 * 0.1307 / 0.3081
  35. # define map operations
  36. resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
  37. rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
  38. rescale_op = CV.Rescale(rescale, shift)
  39. hwc2chw_op = CV.HWC2CHW()
  40. type_cast_op = CT.TypeCast(mstype.int32)
  41. # apply map operations on images
  42. mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label")
  43. mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image")
  44. mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image")
  45. mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image")
  46. mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image")
  47. # apply DatasetOps
  48. mnist_ds = mnist_ds.batch(batch_size=32, drop_remainder=True)
  49. return mnist_ds
  50. class LeNet5(nn.Cell):
  51. def __init__(self):
  52. super(LeNet5, self).__init__()
  53. self.batch_size = 32
  54. self.conv1 = conv(1, 6, 5)
  55. self.conv2 = conv(6, 16, 5)
  56. self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
  57. self.fc2 = fc_with_initialize(120, 84)
  58. self.fc3 = fc_with_initialize(84, 10)
  59. self.relu = nn.ReLU()
  60. self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
  61. self.reshape = P.Reshape()
  62. def construct(self, x):
  63. x = self.conv1(x)
  64. x = self.relu(x)
  65. x = self.max_pool2d(x)
  66. x = self.conv2(x)
  67. x = self.relu(x)
  68. x = self.max_pool2d(x)
  69. x = self.reshape(x, (self.batch_size, -1))
  70. x = self.fc1(x)
  71. x = self.relu(x)
  72. x = self.fc2(x)
  73. x = self.relu(x)
  74. x = self.fc3(x)
  75. return x
  76. class WithLossCell(nn.Cell):
  77. def __init__(self, network):
  78. super(WithLossCell, self).__init__(auto_prefix=False)
  79. self.loss = nn.SoftmaxCrossEntropyWithLogits()
  80. self.network = network
  81. def construct(self, x, label):
  82. predict = self.network(x)
  83. return self.loss(predict, label)
  84. class TrainOneStepCell(nn.Cell):
  85. def __init__(self, network):
  86. super(TrainOneStepCell, self).__init__(auto_prefix=False)
  87. self.network = network
  88. self.network.set_train()
  89. self.weights = ParameterTuple(network.trainable_params())
  90. self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
  91. self.hyper_map = C.HyperMap()
  92. self.grad = C.GradOperation(get_by_list=True)
  93. def construct(self, x, label):
  94. weights = self.weights
  95. grads = self.grad(self.network, weights)(x, label)
  96. return self.optimizer(grads)
  97. def test_export_lenet_grad_mindir():
  98. """
  99. Feature: Export LeNet to MindIR
  100. Description: Test export API to save network into MindIR
  101. Expectation: save successfully
  102. """
  103. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  104. network = LeNet5()
  105. network.set_train()
  106. predict = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
  107. label = Tensor(np.zeros([32, 10]).astype(np.float32))
  108. net = TrainOneStepCell(WithLossCell(network))
  109. file_name = "lenet_grad"
  110. export(net, predict, label, file_name=file_name, file_format='MINDIR')
  111. verify_name = file_name + ".mindir"
  112. assert os.path.exists(verify_name)
  113. os.remove(verify_name)
  114. def test_export_lenet_with_dataset():
  115. """
  116. Feature: Export LeNet with data preprocess to MindIR
  117. Description: Test export API to save network and dataset into MindIR
  118. Expectation: save successfully
  119. """
  120. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  121. network = LeNet5()
  122. network.set_train()
  123. dataset = create_dataset()
  124. file_name = "lenet_preprocess"
  125. export(network, dataset, file_name=file_name, file_format='MINDIR')
  126. verify_name = file_name + ".mindir"
  127. assert os.path.exists(verify_name)
  128. os.remove(verify_name)