You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_parallel_optimizer.py 11 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. # Copyright 2020-2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test adam """
  16. import numpy as np
  17. import pytest
  18. import mindspore.nn as nn
  19. from mindspore import Tensor, Parameter
  20. from mindspore.common.api import _cell_graph_executor
  21. from mindspore.nn import TrainOneStepCell, WithLossCell
  22. from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
  23. from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb, Momentum
  24. from mindspore.ops import operations as P
  25. from mindspore import context
  class Net(nn.Cell):
      """Dense network used by the data-parallel optimizer compile tests.

      Three ``Dense(128, 768)`` layers project the same input; the second
      projection is transposed and multiplied with the first, the ReLU of that
      product is multiplied with the third projection, and the ReLU of the
      result is fed through a final ``Dense(768, 768)`` layer.
      """

      def __init__(self):
          super().__init__()
          # Attribute names and creation order are kept as-is: they determine
          # the parameter names and ordering seen by the optimizer.
          self.fc1 = nn.Dense(128, 768, activation='relu')
          self.fc2 = nn.Dense(128, 768, activation='relu')
          self.fc3 = nn.Dense(128, 768, activation='relu')
          self.fc4 = nn.Dense(768, 768, activation='relu')
          self.relu4 = nn.ReLU()
          self.relu5 = nn.ReLU()
          self.transpose = P.Transpose()
          self.matmul1 = P.MatMul()
          self.matmul2 = P.MatMul()

      def construct(self, x):
          """Run the forward pass on ``x`` and return the output of ``fc4``."""
          query = self.fc1(x)
          key = self.fc2(x)
          value = self.fc3(x)
          key_t = self.transpose(key, (1, 0))
          scores = self.matmul1(query, key_t)
          scores = self.relu4(scores)
          mixed = self.matmul2(scores, value)
          mixed = self.relu5(mixed)
          return self.fc4(mixed)
  class Net2(nn.Cell):
      """Two chained sharded MatMuls against all-ones weights, minus a label.

      Args:
          strategy1: Shard strategy applied to the first ``MatMul``.
          strategy2: Shard strategy applied to the second ``MatMul``.
      """

      def __init__(self, strategy1, strategy2):
          super().__init__()
          self.fc1 = P.MatMul().shard(strategy1)
          self.fc2 = P.MatMul().shard(strategy2)
          weight1_data = np.ones([48, 64], dtype=np.float32)
          weight2_data = np.ones([64, 16], dtype=np.float32)
          self.p1 = Parameter(Tensor(weight1_data), name="weight1")
          self.p2 = Parameter(Tensor(weight2_data), name="weight2")

      def construct(self, x, y):
          """Return ``(x @ weight1 @ weight2) - y``."""
          hidden = self.fc1(x, self.p1)
          output = self.fc2(hidden, self.p2)
          return output - y
  class Net3(nn.Cell):
      """Same computation as two chained sharded MatMuls minus a label, but
      ``weight2`` is created with ``parallel_optimizer=False``.

      Args:
          strategy1: Shard strategy applied to the first ``MatMul``.
          strategy2: Shard strategy applied to the second ``MatMul``.
      """

      def __init__(self, strategy1, strategy2):
          super().__init__()
          self.fc1 = P.MatMul().shard(strategy1)
          self.fc2 = P.MatMul().shard(strategy2)
          weight1_data = np.ones([48, 64], dtype=np.float32)
          weight2_data = np.ones([64, 16], dtype=np.float32)
          self.p1 = Parameter(Tensor(weight1_data), name="weight1")
          # weight2 is explicitly excluded from the parallel optimizer.
          self.p2 = Parameter(Tensor(weight2_data), name="weight2", parallel_optimizer=False)

      def construct(self, x, y):
          """Return ``(x @ weight1 @ weight2) - y``."""
          hidden = self.fc1(x, self.p1)
          output = self.fc2(hidden, self.p2)
          return output - y
  class Net4(nn.Cell):
      """Two chained sharded MatMuls with larger weights (48x1152 and 1152x16),
      minus a label.

      Args:
          strategy1: Shard strategy applied to the first ``MatMul``.
          strategy2: Shard strategy applied to the second ``MatMul``.
      """

      def __init__(self, strategy1, strategy2):
          super().__init__()
          self.fc1 = P.MatMul().shard(strategy1)
          self.fc2 = P.MatMul().shard(strategy2)
          weight1_data = np.ones([48, 1152], dtype=np.float32)
          weight2_data = np.ones([1152, 16], dtype=np.float32)
          self.p1 = Parameter(Tensor(weight1_data), name="weight1")
          self.p2 = Parameter(Tensor(weight2_data), name="weight2")

      def construct(self, x, y):
          """Return ``(x @ weight1 @ weight2) - y``."""
          hidden = self.fc1(x, self.p1)
          output = self.fc2(hidden, self.p2)
          return output - y
  def auto_parallel_compile_net(mode, dev_num, net, strategy1=None, strategy2=None):
      """Build and compile a Momentum training step with the parallel optimizer enabled.

      Args:
          mode: Value passed as ``parallel_mode`` to the auto-parallel context.
          dev_num: Value passed as ``device_num`` to the auto-parallel context.
          net: Callable invoked as ``net(strategy1, strategy2)`` to build the network.
          strategy1: First strategy forwarded to ``net``. Default: None.
          strategy2: Second strategy forwarded to ``net``. Default: None.

      Returns:
          The compiled ``TrainOneStepCell``.
      """
      context.set_context(mode=context.GRAPH_MODE)
      context.set_auto_parallel_context(parallel_mode=mode, device_num=dev_num, enable_parallel_optimizer=True)
      try:
          inputs = Tensor(np.ones([32, 48]).astype(np.float32))
          label = Tensor(np.zeros([32, 16]).astype(np.float32))
          network = _VirtualDatasetCell(net(strategy1, strategy2))
          optimizer = Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
          train_network = TrainOneStepCell(network, optimizer).set_comm_fusion(4)
          train_network.set_auto_parallel()
          train_network.set_train()
          _cell_graph_executor.compile(train_network, inputs, label, phase="train", auto_parallel_mode=True)
      finally:
          # Always restore the auto-parallel context, even when building or
          # compiling fails, so one failing test cannot leak its parallel
          # settings into the tests that run after it.
          context.reset_auto_parallel_context()
      return train_network
  def test_auto_parallel_momentum_1():
      """
      Feature: parallel optimizer with Momentum
      Description: compile Net2 in auto_parallel mode on 8 devices without explicit strategies.
      Expectation: compile success.
      """
      auto_parallel_compile_net(mode="auto_parallel", dev_num=8, net=Net2)
  def test_auto_parallel_momentum_2():
      """
      Feature: parallel optimizer with Momentum
      Description: data parallel case; compile Net2 in auto_parallel mode on 8 devices with
                   strategy ((8, 1), (1, 1)) for both MatMuls.
      Expectation: compile success.
      """
      data_parallel_strategy = ((8, 1), (1, 1))
      auto_parallel_compile_net("auto_parallel", 8, Net2, data_parallel_strategy, data_parallel_strategy)
  def test_auto_parallel_momentum_3():
      """
      Feature: parallel optimizer with Momentum
      Description: hybrid parallel case; compile Net2 in semi_auto_parallel mode on 32 devices
                   with the parallel optimizer threshold set to 1.
      Expectation: weight1 has no opt_shard_group and the opt_shard_group of weight2 starts with dp.
      """
      dp = 4
      matmul1_strategy = ((dp, 8), (8, 1))
      matmul2_strategy = ((dp, 4), (4, 2))
      context.set_auto_parallel_context(parallel_optimizer_config={"parallel_optimizer_threshold": 1})
      compiled = auto_parallel_compile_net("semi_auto_parallel", 32, Net2, matmul1_strategy, matmul2_strategy)
      layouts = compiled.parameter_layout_dict
      # Entry 5 of each layout is the opt_shard_group being validated.
      assert not layouts["weight1"][5]
      assert layouts["weight2"][5].startswith(str(dp))
  def test_auto_parallel_momentum_4():
      """
      Feature: parallel optimizer with Momentum
      Description: hybrid parallel case where devices are repeatedly used; compile Net2 in
                   semi_auto_parallel mode on 32 devices.
      Expectation: compile success.
      """
      matmul1_strategy = ((4, 4), (4, 1))
      matmul2_strategy = ((4, 4), (4, 2))
      auto_parallel_compile_net("semi_auto_parallel", 32, Net2, matmul1_strategy, matmul2_strategy)
  def test_auto_parallel_momentum_5():
      """
      Feature: parallel optimizer filter
      Description: compile Net3 (weight2 created with parallel_optimizer=False) in semi_auto_parallel
                   mode on 32 devices with the parallel optimizer threshold set to 1.
      Expectation: neither weight1 nor weight2 has an opt_shard_group.
      """
      matmul1_strategy = ((4, 8), (8, 1))
      matmul2_strategy = ((4, 4), (4, 2))
      context.set_auto_parallel_context(parallel_optimizer_config={"parallel_optimizer_threshold": 1})
      compiled = auto_parallel_compile_net("semi_auto_parallel", 32, Net3, matmul1_strategy, matmul2_strategy)
      layouts = compiled.parameter_layout_dict
      # Entry 5 of each layout is the opt_shard_group being validated.
      assert not layouts["weight1"][5]
      assert not layouts["weight2"][5]
  def test_auto_parallel_momentum_6():
      """
      Feature: parallel optimizer with optimizer_weight_shard_size
      Description: not fully use the parallel optimizer; compile Net2 in semi_auto_parallel mode on
                   32 devices with optimizer_weight_shard_size=2 and the threshold set to 1.
      Expectation: the opt_shard_group of both weight1 and weight2 starts with the shard size.
      """
      param_shard_group_size = 2
      expected_prefix = str(param_shard_group_size)
      context.set_auto_parallel_context(optimizer_weight_shard_size=param_shard_group_size)
      context.set_auto_parallel_context(parallel_optimizer_config={"parallel_optimizer_threshold": 1})
      compiled = auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
      layouts = compiled.parameter_layout_dict
      # Entry 5 of each layout is the opt_shard_group being validated.
      assert layouts["weight1"][5].startswith(expected_prefix)
      assert layouts["weight2"][5].startswith(expected_prefix)
  def test_default_threshold():
      """
      Feature: auto-parallel-optimizer(I4S85V)
      Description: the memory size of weight2(72KB) is higher than the threshold(64KB).
      Expectation: weight2 being sharded, i.e. its opt_shard_group is not empty.
      """
      dp = 4
      matmul1_strategy = ((dp, 8), (8, 1))
      matmul2_strategy = ((dp, 4), (4, 2))
      compiled = auto_parallel_compile_net("semi_auto_parallel", 32, Net4, matmul1_strategy, matmul2_strategy)
      layouts = compiled.parameter_layout_dict
      # Entry 5 of the layout is the opt_shard_group being validated.
      assert layouts["weight2"][5]
  def test_user_define_threshold():
      """
      Feature: auto-parallel-optimizer(I4S85V)
      Description: the memory size of weight2(72KB) is lower than the user-defined threshold(100KB).
      Expectation: weight2 being not sharded, i.e. its opt_shard_group is empty.
      """
      dp = 4
      matmul1_strategy = ((dp, 8), (8, 1))
      matmul2_strategy = ((dp, 4), (4, 2))
      context.set_auto_parallel_context(parallel_optimizer_config={"parallel_optimizer_threshold": 100})
      compiled = auto_parallel_compile_net("semi_auto_parallel", 32, Net4, matmul1_strategy, matmul2_strategy)
      layouts = compiled.parameter_layout_dict
      # Entry 5 of the layout is the opt_shard_group being validated.
      assert not layouts["weight2"][5]
  def test_AdamWeightDecay():
      """
      Feature: parallel optimizer with AdamWeightDecay
      Description: compile Net with AdamWeightDecay in data_parallel mode on 2 devices with the
                   parallel optimizer threshold set to 1.
      Expectation: compile success.
      """
      context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True,
                                        parallel_optimizer_config={"parallel_optimizer_threshold": 1})
      try:
          inputs = Tensor(np.ones([32, 128]).astype(np.float32))
          label = Tensor(np.zeros([32, 768]).astype(np.float32))
          net = Net()
          net.set_train()
          loss = nn.SoftmaxCrossEntropyWithLogits()
          optimizer = AdamWeightDecay(net.trainable_params(), learning_rate=0.1)
          net_with_loss = WithLossCell(net, loss)
          train_network = TrainOneStepCell(net_with_loss, optimizer)
          _cell_graph_executor.compile(train_network, inputs, label)
      finally:
          # Reset even when compilation fails so the parallel settings do not
          # leak into the tests that run afterwards.
          context.reset_auto_parallel_context()
  def test_lamb_compile():
      """
      Feature: parallel optimizer with Lamb
      Description: compile Net with Lamb in data_parallel mode on 2 devices with the parallel
                   optimizer threshold set to 2.
      Expectation: compile success.
      """
      context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True,
                                        parallel_optimizer_config={"parallel_optimizer_threshold": 2})
      try:
          inputs = Tensor(np.ones([32, 128]).astype(np.float32))
          label = Tensor(np.zeros([32, 768]).astype(np.float32))
          net = Net()
          net.set_train()
          loss = nn.SoftmaxCrossEntropyWithLogits()
          optimizer = Lamb(net.trainable_params(), learning_rate=0.1)
          net_with_loss = WithLossCell(net, loss)
          train_network = TrainOneStepCell(net_with_loss, optimizer)
          _cell_graph_executor.compile(train_network, inputs, label)
      finally:
          # Reset even when compilation fails so the parallel settings do not
          # leak into the tests that run afterwards.
          context.reset_auto_parallel_context()
  def test_lamb_split_fusion():
      """
      Feature: parallel optimizer with Lamb and all-reduce fusion
      Description: compile Net with Lamb in data_parallel mode on 2 devices with
                   all_reduce_fusion_config=[2, 4, 6, 8] and the parallel optimizer threshold set to 1.
      Expectation: compile success.
      """
      context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True,
                                        all_reduce_fusion_config=[2, 4, 6, 8],
                                        parallel_optimizer_config={"parallel_optimizer_threshold": 1})
      try:
          inputs = Tensor(np.ones([32, 128]).astype(np.float32))
          label = Tensor(np.zeros([32, 768]).astype(np.float32))
          net = Net()
          net.set_train()
          loss = nn.SoftmaxCrossEntropyWithLogits()
          optimizer = Lamb(net.trainable_params(), learning_rate=0.1)
          net_with_loss = WithLossCell(net, loss)
          train_network = TrainOneStepCell(net_with_loss, optimizer)
          _cell_graph_executor.compile(train_network, inputs, label)
      finally:
          # Reset even when compilation fails so the parallel settings do not
          # leak into the tests that run afterwards.
          context.reset_auto_parallel_context()
  def test_edge_case():
      """
      Feature: parallel optimizer argument and environment checks
      Description: with enable_parallel_optimizer=True, configure a series of unsupported
                   settings (stand_alone mode, GPU target with Lamb, Ascend target with Adam,
                   device_num=16, a negative parallel_optimizer_threshold) and build an optimizer.
      Expectation: RuntimeError for the first four cases and ValueError for the negative threshold.
      """
      context.set_auto_parallel_context(enable_parallel_optimizer=True)
      try:
          net = Net()
          # In each case the expected exception may come from either the
          # context call or the optimizer constructor; both are kept inside
          # the ``pytest.raises`` block exactly as before.
          with pytest.raises(RuntimeError):
              context.set_auto_parallel_context(parallel_mode="stand_alone")
              Lamb(net.trainable_params(), learning_rate=0.1)
          with pytest.raises(RuntimeError):
              context.set_context(device_target="GPU")
              context.set_auto_parallel_context(parallel_mode="data_parallel")
              Lamb(net.trainable_params(), learning_rate=0.1)
          with pytest.raises(RuntimeError):
              context.set_context(device_target="Ascend")
              context.set_auto_parallel_context(parallel_mode="data_parallel")
              Adam(net.trainable_params(), learning_rate=0.1)
          with pytest.raises(RuntimeError):
              context.set_auto_parallel_context(device_num=16)
              Lamb(net.trainable_params(), learning_rate=0.1)
          with pytest.raises(ValueError):
              context.set_auto_parallel_context(parallel_optimizer_config={"parallel_optimizer_threshold": -1})
              Lamb(net.trainable_params(), learning_rate=0.1)
      finally:
          # Reset even when one of the expectations above fails, so the
          # parallel settings do not leak into the tests that run afterwards.
          # NOTE(review): device_target is changed above and is not restored
          # here (same as before) - confirm whether later tests rely on it.
          context.reset_auto_parallel_context()