You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_parallel_optimizer.py 6.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test adam """
  16. import numpy as np
  17. import pytest
  18. import mindspore.nn as nn
  19. from mindspore import Tensor, Parameter
  20. from mindspore.common.api import _executor
  21. from mindspore.nn import TrainOneStepCell, WithLossCell
  22. from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
  23. from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb, Momentum
  24. from mindspore.ops import operations as P
  25. from mindspore import context
  26. class Net(nn.Cell):
  27. """Net definition"""
  28. def __init__(self):
  29. super(Net, self).__init__()
  30. self.fc1 = nn.Dense(128, 768, activation='relu')
  31. self.fc2 = nn.Dense(128, 768, activation='relu')
  32. self.fc3 = nn.Dense(128, 768, activation='relu')
  33. self.fc4 = nn.Dense(768, 768, activation='relu')
  34. self.relu4 = nn.ReLU()
  35. self.relu5 = nn.ReLU()
  36. self.transpose = P.Transpose()
  37. self.matmul1 = P.MatMul()
  38. self.matmul2 = P.MatMul()
  39. def construct(self, x):
  40. q = self.fc1(x)
  41. k = self.fc2(x)
  42. v = self.fc3(x)
  43. k = self.transpose(k, (1, 0))
  44. c = self.relu4(self.matmul1(q, k))
  45. s = self.relu5(self.matmul2(c, v))
  46. s = self.fc4(s)
  47. return s
  48. class Net2(nn.Cell):
  49. """Net definition"""
  50. def __init__(self, strategy1, strategy2):
  51. super(Net2, self).__init__()
  52. self.fc1 = P.MatMul().shard(strategy=strategy1)
  53. self.fc2 = P.MatMul().shard(strategy=strategy2)
  54. self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1")
  55. self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)), name="weight2")
  56. def construct(self, x, y):
  57. x = self.fc1(x, self.p1)
  58. x = self.fc2(x, self.p2)
  59. return x - y
  60. def auto_parallel_compile_net(mode, dev_num, strategy1=None, strategy2=None):
  61. context.set_context(mode=context.GRAPH_MODE)
  62. context.set_auto_parallel_context(parallel_mode=mode, device_num=dev_num, enable_parallel_optimizer=True)
  63. inputs = Tensor(np.ones([32, 48]).astype(np.float32))
  64. label = Tensor(np.zeros([32, 16]).astype(np.float32))
  65. net = Net2(strategy1, strategy2)
  66. net = _VirtualDatasetCell(net)
  67. optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  68. train_network = TrainOneStepCell(net, optimizer)
  69. train_network.set_auto_parallel()
  70. train_network.set_train()
  71. _executor.compile(train_network, inputs, label, phase="train", auto_parallel_mode=True)
  72. context.reset_auto_parallel_context()
  73. return train_network
  74. def test_auto_parallel_momentum_1():
  75. auto_parallel_compile_net("auto_parallel", 8)
  76. def test_auto_parallel_momentum_2():
  77. # data parallel case
  78. auto_parallel_compile_net("auto_parallel", 8, ((8, 1), (1, 1)), ((8, 1), (1, 1)))
  79. def test_auto_parallel_momentum_3():
  80. # hybrid parallel case
  81. # weight1 could not be shard and weight2 is repeated
  82. train_network = auto_parallel_compile_net("semi_auto_parallel", 32, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
  83. param_dict = train_network.parameter_layout_dict
  84. # validate opt_shard_group
  85. assert not param_dict["weight1"][5]
  86. assert param_dict["weight2"][5].startswith("4")
  87. def test_auto_parallel_momentum_4():
  88. # hybrid parallel cases
  89. # devices are repeatedly used
  90. auto_parallel_compile_net("semi_auto_parallel", 32, ((4, 4), (4, 1)), ((4, 4), (4, 2)))
  91. def test_AdamWeightDecay():
  92. """ test_AdamWeightDecay """
  93. context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
  94. inputs = Tensor(np.ones([32, 128]).astype(np.float32))
  95. label = Tensor(np.zeros([32, 768]).astype(np.float32))
  96. net = Net()
  97. net.set_train()
  98. loss = nn.SoftmaxCrossEntropyWithLogits()
  99. optimizer = AdamWeightDecay(net.trainable_params(), learning_rate=0.1)
  100. net_with_loss = WithLossCell(net, loss)
  101. train_network = TrainOneStepCell(net_with_loss, optimizer)
  102. _executor.compile(train_network, inputs, label)
  103. context.reset_auto_parallel_context()
  104. def test_lamb_compile():
  105. """ test_Lamb_compile """
  106. context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
  107. inputs = Tensor(np.ones([32, 128]).astype(np.float32))
  108. label = Tensor(np.zeros([32, 768]).astype(np.float32))
  109. net = Net()
  110. net.set_train()
  111. loss = nn.SoftmaxCrossEntropyWithLogits()
  112. optimizer = Lamb(net.trainable_params(), learning_rate=0.1)
  113. net_with_loss = WithLossCell(net, loss)
  114. train_network = TrainOneStepCell(net_with_loss, optimizer)
  115. _executor.compile(train_network, inputs, label)
  116. context.reset_auto_parallel_context()
  117. def test_lamb_split_fusion():
  118. """ test_Lamb_split_fusion """
  119. context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True,
  120. all_reduce_fusion_config=[2, 4, 6, 8])
  121. inputs = Tensor(np.ones([32, 128]).astype(np.float32))
  122. label = Tensor(np.zeros([32, 768]).astype(np.float32))
  123. net = Net()
  124. net.set_train()
  125. loss = nn.SoftmaxCrossEntropyWithLogits()
  126. optimizer = Lamb(net.trainable_params(), learning_rate=0.1)
  127. net_with_loss = WithLossCell(net, loss)
  128. train_network = TrainOneStepCell(net_with_loss, optimizer)
  129. _executor.compile(train_network, inputs, label)
  130. context.reset_auto_parallel_context()
  131. def test_edge_case():
  132. """ test_edge_case """
  133. context.set_auto_parallel_context(enable_parallel_optimizer=True)
  134. net = Net()
  135. with pytest.raises(RuntimeError):
  136. context.set_auto_parallel_context(parallel_mode="stand_alone")
  137. Lamb(net.trainable_params(), learning_rate=0.1)
  138. with pytest.raises(RuntimeError):
  139. Adam(net.trainable_params(), learning_rate=0.1)
  140. with pytest.raises(RuntimeError):
  141. context.set_auto_parallel_context(device_num=16)
  142. Lamb(net.trainable_params(), learning_rate=0.1)
  143. context.reset_auto_parallel_context()