You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_export.py 3.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import os
  2. import numpy as np
  3. import mindspore.nn as nn
  4. from mindspore import context
  5. from mindspore.common.tensor import Tensor
  6. from mindspore.common.initializer import TruncatedNormal
  7. from mindspore.common.parameter import ParameterTuple
  8. from mindspore.ops import operations as P
  9. from mindspore.ops import composite as C
  10. from mindspore.train.serialization import export
  11. def weight_variable():
  12. return TruncatedNormal(0.02)
  13. def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
  14. weight = weight_variable()
  15. return nn.Conv2d(in_channels, out_channels,
  16. kernel_size=kernel_size, stride=stride, padding=padding,
  17. weight_init=weight, has_bias=False, pad_mode="valid")
  18. def fc_with_initialize(input_channels, out_channels):
  19. weight = weight_variable()
  20. bias = weight_variable()
  21. return nn.Dense(input_channels, out_channels, weight, bias)
  22. class LeNet5(nn.Cell):
  23. def __init__(self):
  24. super(LeNet5, self).__init__()
  25. self.batch_size = 32
  26. self.conv1 = conv(1, 6, 5)
  27. self.conv2 = conv(6, 16, 5)
  28. self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
  29. self.fc2 = fc_with_initialize(120, 84)
  30. self.fc3 = fc_with_initialize(84, 10)
  31. self.relu = nn.ReLU()
  32. self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
  33. self.reshape = P.Reshape()
  34. def construct(self, x):
  35. x = self.conv1(x)
  36. x = self.relu(x)
  37. x = self.max_pool2d(x)
  38. x = self.conv2(x)
  39. x = self.relu(x)
  40. x = self.max_pool2d(x)
  41. x = self.reshape(x, (self.batch_size, -1))
  42. x = self.fc1(x)
  43. x = self.relu(x)
  44. x = self.fc2(x)
  45. x = self.relu(x)
  46. x = self.fc3(x)
  47. return x
  48. class WithLossCell(nn.Cell):
  49. def __init__(self, network):
  50. super(WithLossCell, self).__init__(auto_prefix=False)
  51. self.loss = nn.SoftmaxCrossEntropyWithLogits()
  52. self.network = network
  53. def construct(self, x, label):
  54. predict = self.network(x)
  55. return self.loss(predict, label)
  56. class TrainOneStepCell(nn.Cell):
  57. def __init__(self, network):
  58. super(TrainOneStepCell, self).__init__(auto_prefix=False)
  59. self.network = network
  60. self.network.set_train()
  61. self.weights = ParameterTuple(network.trainable_params())
  62. self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
  63. self.hyper_map = C.HyperMap()
  64. self.grad = C.GradOperation(get_by_list=True)
  65. def construct(self, x, label):
  66. weights = self.weights
  67. grads = self.grad(self.network, weights)(x, label)
  68. return self.optimizer(grads)
  69. def test_export_lenet_grad_mindir():
  70. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  71. network = LeNet5()
  72. network.set_train()
  73. predict = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
  74. label = Tensor(np.zeros([32, 10]).astype(np.float32))
  75. net = TrainOneStepCell(WithLossCell(network))
  76. file_name = "lenet_grad"
  77. export(net, predict, label, file_name=file_name, file_format='MINDIR')
  78. verify_name = file_name + ".mindir"
  79. assert os.path.exists(verify_name)
  80. os.remove(verify_name)