You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_parallel_optimizer.py 7.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test adam """
  16. import numpy as np
  17. import pytest
  18. import mindspore.nn as nn
  19. from mindspore import Tensor, Parameter
  20. from mindspore.common.api import _executor
  21. from mindspore.nn import TrainOneStepCell, WithLossCell
  22. from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
  23. from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb, Momentum
  24. from mindspore.ops import operations as P
  25. from mindspore import context
  26. class Net(nn.Cell):
  27. """Net definition"""
  28. def __init__(self):
  29. super(Net, self).__init__()
  30. self.fc1 = nn.Dense(128, 768, activation='relu')
  31. self.fc2 = nn.Dense(128, 768, activation='relu')
  32. self.fc3 = nn.Dense(128, 768, activation='relu')
  33. self.fc4 = nn.Dense(768, 768, activation='relu')
  34. self.relu4 = nn.ReLU()
  35. self.relu5 = nn.ReLU()
  36. self.transpose = P.Transpose()
  37. self.matmul1 = P.MatMul()
  38. self.matmul2 = P.MatMul()
  39. def construct(self, x):
  40. q = self.fc1(x)
  41. k = self.fc2(x)
  42. v = self.fc3(x)
  43. k = self.transpose(k, (1, 0))
  44. c = self.relu4(self.matmul1(q, k))
  45. s = self.relu5(self.matmul2(c, v))
  46. s = self.fc4(s)
  47. return s
  48. class Net2(nn.Cell):
  49. """Net definition"""
  50. def __init__(self, strategy1, strategy2):
  51. super(Net2, self).__init__()
  52. self.fc1 = P.MatMul().shard(strategy=strategy1)
  53. self.fc2 = P.MatMul().shard(strategy=strategy2)
  54. self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1")
  55. self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)), name="weight2")
  56. def construct(self, x, y):
  57. x = self.fc1(x, self.p1)
  58. x = self.fc2(x, self.p2)
  59. return x - y
  60. class Net3(nn.Cell):
  61. """Net definition"""
  62. def __init__(self, strategy1, strategy2):
  63. super(Net3, self).__init__()
  64. self.fc1 = P.MatMul().shard(strategy=strategy1)
  65. self.fc2 = P.MatMul().shard(strategy=strategy2)
  66. self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1")
  67. self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)), name="weight2", parallel_optimizer=False)
  68. def construct(self, x, y):
  69. x = self.fc1(x, self.p1)
  70. x = self.fc2(x, self.p2)
  71. return x - y
  72. def auto_parallel_compile_net(mode, dev_num, net, strategy1=None, strategy2=None):
  73. context.set_context(mode=context.GRAPH_MODE)
  74. context.set_auto_parallel_context(parallel_mode=mode, device_num=dev_num, enable_parallel_optimizer=True)
  75. inputs = Tensor(np.ones([32, 48]).astype(np.float32))
  76. label = Tensor(np.zeros([32, 16]).astype(np.float32))
  77. net = net(strategy1, strategy2)
  78. net = _VirtualDatasetCell(net)
  79. optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  80. train_network = TrainOneStepCell(net, optimizer).set_comm_fusion(4)
  81. train_network.set_auto_parallel()
  82. train_network.set_train()
  83. _executor.compile(train_network, inputs, label, phase="train", auto_parallel_mode=True)
  84. context.reset_auto_parallel_context()
  85. return train_network
  86. def test_auto_parallel_momentum_1():
  87. auto_parallel_compile_net("auto_parallel", 8, Net2)
  88. def test_auto_parallel_momentum_2():
  89. # data parallel case
  90. auto_parallel_compile_net("auto_parallel", 8, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)))
  91. def test_auto_parallel_momentum_3():
  92. # hybrid parallel case
  93. # weight1 could not be shard and weight2 is repeated
  94. train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
  95. param_dict = train_network.parameter_layout_dict
  96. # validate opt_shard_group
  97. assert not param_dict["weight1"][5]
  98. assert param_dict["weight2"][5].startswith("4")
  99. def test_auto_parallel_momentum_4():
  100. # hybrid parallel cases
  101. # devices are repeatedly used
  102. auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 4), (4, 1)), ((4, 4), (4, 2)))
  103. def test_auto_parallel_momentum_5():
  104. # test parallel optimizer filter
  105. train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net3, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
  106. param_dict = train_network.parameter_layout_dict
  107. # validate opt_shard_group
  108. assert not param_dict["weight1"][5]
  109. assert not param_dict["weight2"][5]
  110. def test_AdamWeightDecay():
  111. """ test_AdamWeightDecay """
  112. context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
  113. inputs = Tensor(np.ones([32, 128]).astype(np.float32))
  114. label = Tensor(np.zeros([32, 768]).astype(np.float32))
  115. net = Net()
  116. net.set_train()
  117. loss = nn.SoftmaxCrossEntropyWithLogits()
  118. optimizer = AdamWeightDecay(net.trainable_params(), learning_rate=0.1)
  119. net_with_loss = WithLossCell(net, loss)
  120. train_network = TrainOneStepCell(net_with_loss, optimizer)
  121. _executor.compile(train_network, inputs, label)
  122. context.reset_auto_parallel_context()
  123. def test_lamb_compile():
  124. """ test_Lamb_compile """
  125. context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
  126. inputs = Tensor(np.ones([32, 128]).astype(np.float32))
  127. label = Tensor(np.zeros([32, 768]).astype(np.float32))
  128. net = Net()
  129. net.set_train()
  130. loss = nn.SoftmaxCrossEntropyWithLogits()
  131. optimizer = Lamb(net.trainable_params(), learning_rate=0.1)
  132. net_with_loss = WithLossCell(net, loss)
  133. train_network = TrainOneStepCell(net_with_loss, optimizer)
  134. _executor.compile(train_network, inputs, label)
  135. context.reset_auto_parallel_context()
  136. def test_lamb_split_fusion():
  137. """ test_Lamb_split_fusion """
  138. context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True,
  139. all_reduce_fusion_config=[2, 4, 6, 8])
  140. inputs = Tensor(np.ones([32, 128]).astype(np.float32))
  141. label = Tensor(np.zeros([32, 768]).astype(np.float32))
  142. net = Net()
  143. net.set_train()
  144. loss = nn.SoftmaxCrossEntropyWithLogits()
  145. optimizer = Lamb(net.trainable_params(), learning_rate=0.1)
  146. net_with_loss = WithLossCell(net, loss)
  147. train_network = TrainOneStepCell(net_with_loss, optimizer)
  148. _executor.compile(train_network, inputs, label)
  149. context.reset_auto_parallel_context()
  150. def test_edge_case():
  151. """ test_edge_case """
  152. context.set_auto_parallel_context(enable_parallel_optimizer=True)
  153. net = Net()
  154. with pytest.raises(RuntimeError):
  155. context.set_auto_parallel_context(parallel_mode="stand_alone")
  156. Lamb(net.trainable_params(), learning_rate=0.1)
  157. with pytest.raises(RuntimeError):
  158. context.set_context(device_target="GPU")
  159. context.set_auto_parallel_context(parallel_mode="data_parallel")
  160. Lamb(net.trainable_params(), learning_rate=0.1)
  161. with pytest.raises(RuntimeError):
  162. context.set_context(device_target="Ascend")
  163. context.set_auto_parallel_context(parallel_mode="data_parallel")
  164. Adam(net.trainable_params(), learning_rate=0.1)
  165. with pytest.raises(RuntimeError):
  166. context.set_auto_parallel_context(device_num=16)
  167. Lamb(net.trainable_params(), learning_rate=0.1)
  168. context.reset_auto_parallel_context()