You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_ascend_lenet.py 5.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import pytest
  16. import numpy as np
  17. import time, math
  18. import mindspore.nn as nn
  19. from mindspore import context, Tensor, ParameterTuple
  20. from mindspore.ops import operations as P
  21. from mindspore.common.initializer import TruncatedNormal
  22. from mindspore.ops import functional as F
  23. from mindspore.ops import composite as C
  24. from mindspore.common import dtype as mstype
  25. from mindspore.nn.wrap.cell_wrapper import WithLossCell
  26. from mindspore.nn.optim import Momentum
  27. np.random.seed(1)
  28. def weight_variable():
  29. """weight initial"""
  30. return TruncatedNormal(0.02)
  31. def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
  32. """weight initial for conv layer"""
  33. weight = weight_variable()
  34. return nn.Conv2d(in_channels, out_channels,
  35. kernel_size=kernel_size, stride=stride, padding=padding,
  36. weight_init=weight, has_bias=False, pad_mode="valid")
  37. def fc_with_initialize(input_channels, out_channels):
  38. """weight initial for fc layer"""
  39. weight = weight_variable()
  40. bias = weight_variable()
  41. return nn.Dense(input_channels, out_channels, weight, bias)
  42. class LeNet(nn.Cell):
  43. """
  44. Lenet network
  45. Args:
  46. num_class (int): Num classes, Default: 10.
  47. Returns:
  48. Tensor, output tensor
  49. Examples:
  50. >>> LeNet(num_class=10)
  51. """
  52. def __init__(self, num_class=10):
  53. super(LeNet, self).__init__()
  54. self.num_class = num_class
  55. self.batch_size = 32
  56. self.conv1 = conv(1, 6, 5)
  57. self.conv2 = conv(6, 16, 5)
  58. self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
  59. self.fc2 = fc_with_initialize(120, 84)
  60. self.fc3 = fc_with_initialize(84, self.num_class)
  61. self.relu = nn.ReLU()
  62. self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
  63. self.reshape = P.Reshape()
  64. def construct(self, x):
  65. x = self.conv1(x)
  66. x = self.relu(x)
  67. x = self.max_pool2d(x)
  68. x = self.conv2(x)
  69. x = self.relu(x)
  70. x = self.max_pool2d(x)
  71. x = self.reshape(x, (self.batch_size, -1))
  72. x = self.fc1(x)
  73. x = self.relu(x)
  74. x = self.fc2(x)
  75. x = self.relu(x)
  76. x = self.fc3(x)
  77. return x
  78. class CrossEntropyLoss(nn.Cell):
  79. """
  80. Define loss for network
  81. """
  82. def __init__(self):
  83. super(CrossEntropyLoss, self).__init__()
  84. self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
  85. self.mean = P.ReduceMean()
  86. self.one_hot = P.OneHot()
  87. self.on_value = Tensor(1.0, mstype.float32)
  88. self.off_value = Tensor(0.0, mstype.float32)
  89. self.num = Tensor(32.0, mstype.float32)
  90. def construct(self, logits, label):
  91. label = self.one_hot(label, F.shape(logits)[1], self.on_value, self.off_value)
  92. loss = self.cross_entropy(logits, label)[0]
  93. loss = P.RealDiv()(P.ReduceSum()(loss, -1), self.num)
  94. return loss
  95. class GradWrap(nn.Cell):
  96. """
  97. GradWrap definition
  98. """
  99. def __init__(self, network):
  100. super(GradWrap, self).__init__()
  101. self.network = network
  102. self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))
  103. def construct(self, x, label):
  104. weights = self.weights
  105. return C.grad_by_list(self.network, weights)(x, label)
  106. @pytest.mark.level0
  107. @pytest.mark.platform_x86_ascend_training
  108. @pytest.mark.env_single
  109. def test_ascend_pynative_lenet():
  110. context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
  111. epoch_size = 20
  112. batch_size = 32
  113. inputs = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32))
  114. labels = Tensor(np.ones([batch_size]).astype(np.int32))
  115. net = LeNet()
  116. criterion = CrossEntropyLoss()
  117. optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
  118. net_with_criterion = WithLossCell(net, criterion)
  119. train_network = GradWrap(net_with_criterion)
  120. train_network.set_train()
  121. total_time = 0
  122. for epoch in range(0, epoch_size):
  123. start_time = time.time()
  124. fw_output = net(inputs)
  125. loss_output = criterion(fw_output, labels)
  126. grads = train_network(inputs, labels)
  127. success = optimizer(grads)
  128. end_time = time.time()
  129. cost_time = end_time - start_time
  130. total_time = total_time + cost_time
  131. assert(total_time < 20.0)
  132. assert(loss_output.asnumpy() < 0.01)

MindSpore is a new open-source deep learning training/inference framework that can be used in mobile, edge, and cloud scenarios.