
test_pynative_lenet.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import time

import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import context, Tensor, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.common.initializer import TruncatedNormal
from mindspore.nn.optim import Momentum
from mindspore.nn.wrap.cell_wrapper import WithLossCell
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P

np.random.seed(1)
grad_by_list = C.GradOperation(get_by_list=True)
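# GradOperation(get_by_list=True) is a higher-order op: given a network and a
# ParameterTuple, it returns a callable with the network's own signature that
# computes gradients with respect to those parameters rather than the inputs.
# GradWrap below uses it to differentiate the loss-wrapped network.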


def weight_variable():
    """Return a truncated-normal weight initializer (sigma=0.02)."""
    return TruncatedNormal(0.02)


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Build a conv layer with truncated-normal weight init and no bias."""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")
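# With pad_mode="valid" no padding is applied, so each 5x5 convolution below
# shrinks the spatial dimensions by kernel_size - 1 = 4 (the `padding` value
# is presumably only honored when pad_mode="pad").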


def fc_with_initialize(input_channels, out_channels):
    """Build a fully connected layer with truncated-normal weight and bias init."""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


class LeNet(nn.Cell):
    """
    LeNet-5 network.

    Args:
        num_class (int): Number of output classes. Default: 10.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> LeNet(num_class=10)
    """

    def __init__(self, num_class=10):
        super(LeNet, self).__init__()
        self.num_class = num_class
        self.batch_size = 32
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.reshape(x, (self.batch_size, -1))
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
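# Shape walkthrough for the (32, 1, 32, 32) test input used below:
# conv1 (5x5, valid) -> (32, 6, 28, 28); pool 2x2/2 -> (32, 6, 14, 14);
# conv2 (5x5, valid) -> (32, 16, 10, 10); pool -> (32, 16, 5, 5);
# reshape -> (32, 400), matching fc1's 16 * 5 * 5 input width.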


class CrossEntropyLoss(nn.Cell):
    """
    Softmax cross-entropy loss averaged over the batch.
    """

    def __init__(self):
        super(CrossEntropyLoss, self).__init__()
        self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.num = Tensor(32.0, mstype.float32)

    def construct(self, logits, label):
        label = self.one_hot(label, F.shape(logits)[1], self.on_value, self.off_value)
        loss = self.cross_entropy(logits, label)[0]
        loss = P.RealDiv()(P.ReduceSum()(loss, -1), self.num)
        return loss
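# Note on construct: P.SoftmaxCrossEntropyWithLogits returns a
# (per-sample loss, logits gradient) pair, so [0] selects the per-sample
# losses; summing them and dividing by self.num (the batch size, 32)
# computes the batch mean by hand.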


class GradWrap(nn.Cell):
    """
    Wrap a network and compute gradients of its output with respect to its
    trainable parameters.
    """

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network
        self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))

    def construct(self, x, label):
        weights = self.weights
        return grad_by_list(self.network, weights)(x, label)
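# Calling grad_by_list(self.network, weights) builds a function with the
# network's signature; invoking it yields one gradient per parameter, in
# ParameterTuple order, which is exactly the structure passed to
# optimizer(grads) in the training loop below.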


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_pynative_lenet():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

    epoch_size = 20
    batch_size = 32
    inputs = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32))
    labels = Tensor(np.ones([batch_size]).astype(np.int32))

    net = LeNet()
    criterion = CrossEntropyLoss()
    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
    net_with_criterion = WithLossCell(net, criterion)
    train_network = GradWrap(net_with_criterion)
    train_network.set_train()

    total_time = 0
    for epoch in range(0, epoch_size):
        start_time = time.time()
        fw_output = net(inputs)
        loss_output = criterion(fw_output, labels)
        grads = train_network(inputs, labels)
        optimizer(grads)
        end_time = time.time()
        cost_time = end_time - start_time
        total_time = total_time + cost_time
        print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
    assert loss_output.asnumpy() < 0.004
    assert loss_output.asnumpy() > 0.003
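# Example invocation (assuming an Ascend device and the pytest markers
# registered by the MindSpore test suite):
#   pytest -v test_pynative_lenet.py::test_ascend_pynative_lenet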