
test_gpu_lenet.py 9.0 kB

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import numpy as np
import pytest
import mindspore.context as context
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import Tensor, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.dataset.vision import Inter
from mindspore.nn import Dense, TrainOneStepCell, WithLossCell, ForwardValueAndGrad
from mindspore.nn.metrics import Accuracy
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.train import Model
from mindspore.train.callback import LossMonitor
from mindspore.common.initializer import TruncatedNormal

# Run all tests in graph mode on the GPU backend.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Build a conv layer with truncated-normal weight initialization."""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    """Build a fully connected layer with truncated-normal weight and bias initialization."""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


def weight_variable():
    """Return a TruncatedNormal initializer with sigma=0.02."""
    return TruncatedNormal(0.02)
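

# Shape note: with a 32x32 input, two valid 5x5 convolutions and two 2x2 max
# poolings give 32 -> 28 -> 14 -> 10 -> 5, so the flattened feature map feeding
# the first dense layer has 16 * 5 * 5 = 400 elements.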
class LeNet5(nn.Cell):
    def __init__(self, num_class=10, channel=1):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.conv1 = conv(channel, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
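

# LeNet: a 3-channel variant of the same topology built from raw primitives
# (P.ReLU, P.Reshape) instead of nn layers; the batch size is hard-wired to 1
# via self.batch_size in the reshape.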
class LeNet(nn.Cell):
    def __init__(self):
        super(LeNet, self).__init__()
        self.relu = P.ReLU()
        self.batch_size = 1
        weight1 = Tensor(np.ones([6, 3, 5, 5]).astype(np.float32) * 0.01)
        weight2 = Tensor(np.ones([16, 6, 5, 5]).astype(np.float32) * 0.01)
        self.conv1 = nn.Conv2d(3, 6, (5, 5), weight_init=weight1, stride=1, padding=0, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, (5, 5), weight_init=weight2, pad_mode='valid', stride=1, padding=0)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="valid")
        self.reshape = P.Reshape()
        self.reshape1 = P.Reshape()
        self.fc1 = Dense(400, 120)
        self.fc2 = Dense(120, 84)
        self.fc3 = Dense(84, 10)

    def construct(self, input_x):
        output = self.conv1(input_x)
        output = self.relu(output)
        output = self.pool(output)
        output = self.conv2(output)
        output = self.relu(output)
        output = self.pool(output)
        output = self.reshape(output, (self.batch_size, -1))
        output = self.fc1(output)
        output = self.fc2(output)
        output = self.fc3(output)
        return output
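

# Multi-step learning-rate schedule: the rate decays by `gamma` every `gap`
# steps. For example, multisteplr(100, 30) yields 0.9 for steps 0-29, 0.09 for
# 30-59, 0.009 for 60-89, and 0.0009 for 90-99.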
def multisteplr(total_steps, gap, base_lr=0.9, gamma=0.1, dtype=mstype.float32):
    lr = []
    for step in range(total_steps):
        lr_ = base_lr * gamma ** (step // gap)
        lr.append(lr_)
    return Tensor(np.array(lr), dtype)
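

# Smoke test: train LeNet on a constant synthetic batch for 100 steps and
# check that the loss converges below 0.01.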
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_train_lenet():
    epoch = 100
    net = LeNet()
    momentum = 0.9
    learning_rate = multisteplr(epoch, 30)
    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_with_criterion = WithLossCell(net, criterion)
    train_network = TrainOneStepCell(net_with_criterion, optimizer)
    train_network.set_train()
    losses = []
    for i in range(epoch):
        data = Tensor(np.ones([net.batch_size, 3, 32, 32]).astype(np.float32) * 0.01)
        label = Tensor(np.ones([net.batch_size]).astype(np.int32))
        loss = train_network(data, label).asnumpy()
        losses.append(loss)
    assert losses[-1] < 0.01


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Create the MNIST dataset pipeline for training or evaluation."""
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    # Together with the 1/255 rescale above, this applies (x - 0.1307) / 0.3081,
    # the usual MNIST mean/std normalization.
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # bilinear mode
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # apply map operations on images
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)
    return mnist_ds
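

# Illustrative usage (assumes MNIST is unpacked under the path used by the
# tests below); each batch is a float32 [32, 1, 32, 32] image tensor plus an
# int32 [32] label vector:
#   ds_train = create_dataset("/home/workspace/mindspore_dataset/mnist/train")
#   for batch in ds_train.create_dict_iterator():
#       images, labels = batch["image"], batch["label"]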


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_train_and_eval_lenet():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    network = LeNet5(10)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    print("============== Starting Training ==============")
    ds_train = create_dataset(os.path.join('/home/workspace/mindspore_dataset/mnist', "train"), 32, 1)
    model.train(1, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=True)

    print("============== Starting Testing ==============")
    ds_eval = create_dataset(os.path.join('/home/workspace/mindspore_dataset/mnist', "test"), 32, 1)
    acc = model.eval(ds_eval, dataset_sink_mode=True)
    print("============== {} ==============".format(acc))
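

# The test below exercises the ForwardValueAndGrad wrapper: a single call
# returns both the loss and the gradients, and the optimizer is then applied
# manually instead of going through TrainOneStepCell.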
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_train_lenet_with_new_interface(num_classes=10, epoch=20, batch_size=32):
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    network = LeNet5(num_classes)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_with_criterion = WithLossCell(network, criterion)
    net_with_criterion.set_train()

    weights = ParameterTuple(network.trainable_params())
    optimizer = nn.Momentum(weights, 0.1, 0.9)

    train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True)
    losses = []
    for i in range(0, epoch):
        data = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
        label = Tensor(np.ones([batch_size]).astype(np.int32))
        sens = Tensor(np.ones([1]).astype(np.float32))  # sensitivity (initial gradient) fed into backprop
        loss, grads = train_network(data, label, sens)
        grads = F.identity(grads)
        optimizer(grads)
        losses.append(loss)
    assert losses[-1].asnumpy() < 0.008
    assert losses[-1].asnumpy() > 0.001