
test_nccl_lenet.py 3.9 kB

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import datetime

import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.communication.management import init, NCCL_WORLD_COMM_GROUP, get_rank, get_group_size
from mindspore.nn import Dense, TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
# Initialize the NCCL communication backend; every process launched for this
# test joins the default NCCL world group.
init('nccl')

epoch = 2
total = 5000
batch_size = 32
mini_batch = total // batch_size
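# With total = 5000 synthetic samples and batch_size = 32 per device, each
# process runs mini_batch = 5000 // 32 = 156 steps per epoch.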
class LeNet(nn.Cell):
    def __init__(self):
        super(LeNet, self).__init__()
        self.relu = P.ReLU()
        self.batch_size = 32

        weight1 = Tensor(np.ones([6, 3, 5, 5]).astype(np.float32) * 0.01)
        weight2 = Tensor(np.ones([16, 6, 5, 5]).astype(np.float32) * 0.01)
        self.conv1 = nn.Conv2d(3, 6, (5, 5), weight_init=weight1, stride=1, padding=0, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, (5, 5), weight_init=weight2, pad_mode='valid', stride=1, padding=0)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="valid")
        self.reshape = P.Reshape()

        weight1 = Tensor(np.ones([120, 400]).astype(np.float32) * 0.01)
        self.fc1 = Dense(400, 120, weight_init=weight1)

        weight2 = Tensor(np.ones([84, 120]).astype(np.float32) * 0.01)
        self.fc2 = Dense(120, 84, weight_init=weight2)

        weight3 = Tensor(np.ones([10, 84]).astype(np.float32) * 0.01)
        self.fc3 = Dense(84, 10, weight_init=weight3)

    def construct(self, input_x):
        output = self.conv1(input_x)
        output = self.relu(output)
        output = self.pool(output)
        output = self.conv2(output)
        output = self.relu(output)
        output = self.pool(output)
        output = self.reshape(output, (self.batch_size, -1))
        output = self.fc1(output)
        output = self.fc2(output)
        output = self.fc3(output)
        return output
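# Shape check for the Dense layers above: a 3x32x32 input becomes 6x28x28 after
# conv1 (5x5, valid), 6x14x14 after pooling, 16x10x10 after conv2, and 16x5x5
# after the second pooling, i.e. 16 * 5 * 5 = 400 features per sample, which
# matches the 400 input channels of fc1.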
def test_lenet_nccl():
    net = LeNet()
    net.set_train()

    learning_rate = 0.01
    momentum = 0.9
    mom_optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
    criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    net_with_criterion = WithLossCell(net, criterion)
    # Data-parallel training: mirror_mean=True averages gradients across all
    # devices in the NCCL group before the optimizer step.
    context.set_auto_parallel_context(parallel_mode="data_parallel", mirror_mean=True, device_num=get_group_size())
    train_network = TrainOneStepCell(net_with_criterion, mom_optimizer)
    train_network.set_train()

    losses = []
    data = Tensor(np.ones([net.batch_size, 3, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([net.batch_size]).astype(np.int32))
    start = datetime.datetime.now()
    for _ in range(epoch):
        for _ in range(mini_batch):
            loss = train_network(data, label)
            losses.append(loss.asnumpy())
    end = datetime.datetime.now()
    with open("ms_time.txt", "w") as fo1:
        fo1.write("time:")
        fo1.write(str(end - start))
    with open("ms_loss.txt", "w") as fo2:
        fo2.write("loss:")
        fo2.write(str(losses[-5:]))
    assert losses[-1] < 0.01
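The test expects one process per GPU, all joined into the same NCCL group at import time. A minimal sketch of an entry point follows; the `__main__` guard and the `mpirun` command shown in its comment are illustrative assumptions, not part of the original file.

# Hypothetical launcher hook (assumption, not in the original test file): it
# lets the script be started directly by a multi-process launcher such as
#   mpirun -n 8 python test_nccl_lenet.py
# so that every launched process runs init('nccl') at import time and then
# executes the same training loop on its own GPU.
if __name__ == "__main__":
    test_lenet_nccl()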