You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_gpu_lenet.py 3.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import pytest
  16. import numpy as np
  17. import mindspore.nn as nn
  18. import mindspore.context as context
  19. from mindspore import Tensor
  20. from mindspore.nn.optim import Momentum
  21. from mindspore.ops import operations as P
  22. from mindspore.nn import TrainOneStepCell, WithLossCell
  23. from mindspore.nn import Dense
  24. from mindspore.common.initializer import initializer
  25. from mindspore.common import dtype as mstype
  26. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  27. class LeNet(nn.Cell):
  28. def __init__(self):
  29. super(LeNet, self).__init__()
  30. self.relu = P.ReLU()
  31. self.batch_size = 1
  32. weight1 = Tensor(np.ones([6, 3, 5, 5]).astype(np.float32) * 0.01)
  33. weight2 = Tensor(np.ones([16, 6, 5, 5]).astype(np.float32) * 0.01)
  34. self.conv1 = nn.Conv2d(3, 6, (5, 5), weight_init=weight1, stride=1, padding=0, pad_mode='valid')
  35. self.conv2 = nn.Conv2d(6, 16, (5, 5), weight_init=weight2, pad_mode='valid', stride=1, padding=0)
  36. self.pool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="valid")
  37. self.reshape = P.Reshape()
  38. self.reshape1 = P.Reshape()
  39. self.fc1 = Dense(400, 120)
  40. self.fc2 = Dense(120, 84)
  41. self.fc3 = Dense(84, 10)
  42. def construct(self, input_x):
  43. output = self.conv1(input_x)
  44. output = self.relu(output)
  45. output = self.pool(output)
  46. output = self.conv2(output)
  47. output = self.relu(output)
  48. output = self.pool(output)
  49. output = self.reshape(output, (self.batch_size, -1))
  50. output = self.fc1(output)
  51. output = self.fc2(output)
  52. output = self.fc3(output)
  53. return output
  54. def multisteplr(total_steps, gap, base_lr=0.9, gamma=0.1, dtype=mstype.float32):
  55. lr = []
  56. for step in range(total_steps):
  57. lr_ = base_lr * gamma ** (step//gap)
  58. lr.append(lr_)
  59. return Tensor(np.array(lr), dtype)
  60. @pytest.mark.level0
  61. @pytest.mark.platform_x86_gpu_training
  62. @pytest.mark.env_onecard
  63. def test_train_lenet():
  64. epoch = 100
  65. net = LeNet()
  66. momentum = initializer(Tensor(np.array([0.9]).astype(np.float32)), [1])
  67. learning_rate = multisteplr(epoch, 30)
  68. optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
  69. criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
  70. net_with_criterion = WithLossCell(net, criterion)
  71. train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
  72. train_network.set_train()
  73. losses = []
  74. for i in range(epoch):
  75. data = Tensor(np.ones([net.batch_size, 3, 32, 32]).astype(np.float32) * 0.01)
  76. label = Tensor(np.ones([net.batch_size]).astype(np.int32))
  77. loss = train_network(data, label)
  78. losses.append(loss)
  79. print(losses)