You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_allreduce_fusion.py 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import pytest
  16. import mindspore as ms
  17. import mindspore.nn as nn
  18. from mindspore import Tensor, context
  19. from mindspore.common.api import _cell_graph_executor
  20. from mindspore.nn import TrainOneStepCell, WithLossCell
  21. from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
  22. from mindspore.nn.optim import Lamb
  23. from mindspore.nn.optim.momentum import Momentum
  24. from mindspore.ops import operations as P
  25. from mindspore.parallel import _cost_model_context as cost_model_context
  26. from mindspore.parallel._auto_parallel_context import auto_parallel_context
  27. from mindspore.train import Model
  28. from mindspore.context import ParallelMode
  29. from tests.dataset_mock import MindData
# Default to PyNative mode at import time; train_common() switches to GRAPH_MODE
# before compiling the parallel train network.
context.set_context(mode=context.PYNATIVE_MODE)
  31. class Net(nn.Cell):
  32. """Net definition"""
  33. def __init__(self):
  34. super(Net, self).__init__()
  35. self.fc1 = nn.Dense(128, 768, activation='relu')
  36. self.fc2 = nn.Dense(128, 768, activation='relu')
  37. self.fc3 = nn.Dense(128, 768, activation='relu')
  38. self.fc4 = nn.Dense(768, 768, activation='relu')
  39. self.relu4 = nn.ReLU()
  40. self.relu5 = nn.ReLU()
  41. self.transpose = P.Transpose()
  42. self.matmul1 = P.MatMul()
  43. self.matmul2 = P.MatMul()
  44. def construct(self, x):
  45. q = self.fc1(x)
  46. k = self.fc2(x)
  47. v = self.fc3(x)
  48. k = self.transpose(k, (1, 0))
  49. c = self.relu4(self.matmul1(q, k))
  50. s = self.relu5(self.matmul2(c, v))
  51. s = self.fc4(s)
  52. return s
  53. class Dataset(MindData):
  54. def __init__(self, predict, label, length=3):
  55. super(Dataset, self).__init__(size=length)
  56. self.predict = predict
  57. self.label = label
  58. self.index = 0
  59. self.length = length
  60. def __iter__(self):
  61. return self
  62. def __next__(self):
  63. if self.index >= self.length:
  64. raise StopIteration
  65. self.index += 1
  66. return self.predict, self.label
  67. def reset(self):
  68. self.index = 0
  69. class DenseNet1(nn.Cell):
  70. def __init__(self, has_bias=True, activation='relu'):
  71. super(DenseNet1, self).__init__()
  72. self.fc1 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  73. self.fc2 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  74. self.fc3 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  75. self.fc4 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  76. def construct(self, x):
  77. q = self.fc1(x)
  78. k = self.fc2(q)
  79. v = self.fc3(k)
  80. s = self.fc4(v)
  81. return s
  82. class DenseNet2(nn.Cell):
  83. def __init__(self, has_bias=True, activation='relu'):
  84. super(DenseNet2, self).__init__()
  85. self.fc1 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  86. self.fc2 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  87. self.fc3 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  88. self.fc4 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  89. self.fc5 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  90. self.fc6 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  91. self.fc7 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  92. self.fc8 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  93. def construct(self, x):
  94. q = self.fc1(x)
  95. k = self.fc2(q)
  96. v = self.fc3(k)
  97. s = self.fc4(v)
  98. t = self.fc5(s)
  99. u = self.fc6(t)
  100. w = self.fc7(u)
  101. z = self.fc8(w)
  102. return z
  103. class DenseNet3(nn.Cell):
  104. def __init__(self, has_bias=True, activation='relu'):
  105. super(DenseNet3, self).__init__()
  106. self.fc1 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  107. def construct(self, x):
  108. q = self.fc1(x)
  109. return q
  110. class SimpleDMLNet(nn.Cell):
  111. def __init__(self, net1, net2):
  112. super(SimpleDMLNet, self).__init__()
  113. self.backbone1 = net1
  114. self.backbone2 = net2
  115. def construct(self, x):
  116. x1 = self.backbone1(x)
  117. x2 = self.backbone2(x)
  118. return x1 + x2
  119. def train_common(net):
  120. batch_size = 32
  121. learning_rate = 0.1
  122. momentum = 0.9
  123. epoch_size = 2
  124. device_num = 4
  125. context.set_auto_parallel_context(device_num=device_num, parameter_broadcast=False)
  126. context.set_context(mode=context.GRAPH_MODE)
  127. predict = Tensor(np.ones([batch_size, 128]), dtype=ms.float32)
  128. label = Tensor(np.ones([batch_size]), dtype=ms.int32)
  129. dataset = Dataset(predict, label, 2)
  130. loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
  131. opt = Momentum(net.trainable_params(), learning_rate, momentum)
  132. model = Model(net, loss, opt)
  133. model.train(epoch_size, dataset, dataset_sink_mode=False)
  134. allreduce_fusion_dict = _cell_graph_executor._get_allreduce_fusion(model._train_network)
  135. print(allreduce_fusion_dict)
  136. return allreduce_fusion_dict
  137. def test_allreduce_fusion_auto():
  138. """
  139. Feature: test_allreduce_fusion in auto mode
  140. Description: allreduce fusion in auto mode
  141. Expectation: success
  142. """
  143. comm_fusion_dict = {"allreduce": {"mode": "auto", "config": None}}
  144. context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, comm_fusion=comm_fusion_dict)
  145. net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
  146. allreduce_fusion_dict = train_common(net)
  147. expect_dict = {'backbone2.fc8.weight': 1,
  148. 'backbone2.fc7.weight': 1,
  149. 'backbone2.fc6.weight': 1,
  150. 'backbone1.fc4.weight': 1,
  151. 'backbone1.fc3.weight': 1,
  152. 'backbone1.fc2.weight': 1,
  153. 'backbone2.fc5.weight': 1,
  154. 'backbone2.fc4.weight': 1,
  155. 'backbone2.fc3.weight': 1,
  156. 'backbone2.fc2.weight': 1,
  157. 'backbone2.fc1.weight': 1,
  158. 'backbone1.fc1.weight': 1}
  159. assert allreduce_fusion_dict == expect_dict
  160. def test_allreduce_fusion_size():
  161. """
  162. Feature: test_allreduce_fusion in size mode
  163. Description: allreduce fusion in size mode
  164. Expectation: success
  165. """
  166. comm_fusion_dict = {"allreduce": {"mode": "size", "config": 32}}
  167. context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, comm_fusion=comm_fusion_dict)
  168. net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
  169. allreduce_fusion_dict = train_common(net)
  170. expect_dict = {'backbone2.fc8.weight': 1,
  171. 'backbone2.fc7.weight': 1,
  172. 'backbone2.fc6.weight': 1,
  173. 'backbone1.fc4.weight': 1,
  174. 'backbone1.fc3.weight': 1,
  175. 'backbone1.fc2.weight': 1,
  176. 'backbone2.fc5.weight': 1,
  177. 'backbone2.fc4.weight': 1,
  178. 'backbone2.fc3.weight': 1,
  179. 'backbone2.fc2.weight': 1,
  180. 'backbone2.fc1.weight': 1,
  181. 'backbone1.fc1.weight': 1}
  182. assert allreduce_fusion_dict == expect_dict
  183. cost_model_context.reset_cost_model_context()
  184. comm_fusion = auto_parallel_context().get_comm_fusion()
  185. assert comm_fusion_dict == comm_fusion
  186. def test_lamb_split_fusion_in_index():
  187. """
  188. Feature: test_allreduce_fusion in index mode
  189. Description: allreduce fusion in index mode
  190. Expectation: success
  191. """
  192. comm_fusion_dict = {"allreduce": {"mode": "index", "config": [2, 4, 6, 8]}}
  193. context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True,
  194. comm_fusion=comm_fusion_dict)
  195. inputs = Tensor(np.ones([32, 128]).astype(np.float32))
  196. label = Tensor(np.zeros([32, 768]).astype(np.float32))
  197. net = Net()
  198. net.set_train()
  199. loss = nn.SoftmaxCrossEntropyWithLogits()
  200. optimizer = Lamb(net.trainable_params(), learning_rate=0.1)
  201. net_with_loss = WithLossCell(net, loss)
  202. train_network = TrainOneStepCell(net_with_loss, optimizer)
  203. _cell_graph_executor.compile(train_network, inputs, label)
  204. context.reset_auto_parallel_context()
  205. def test_allreduce_fusion_size_priority():
  206. """
  207. Feature: test priority of "enable_all_reduce_fusion" and "comm_fusion"
  208. Description: test priority of "enable_all_reduce_fusion" and "comm_fusion"
  209. Expectation: success
  210. """
  211. auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=False)
  212. comm_fusion_dict = {"allreduce": {"mode": "size", "config": 32}}
  213. context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, comm_fusion=comm_fusion_dict)
  214. net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
  215. allreduce_fusion_dict = train_common(net)
  216. expect_dict = {}
  217. assert allreduce_fusion_dict == expect_dict
  218. auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True)
  219. allreduce_fusion_dict = train_common(net)
  220. expect_dict = {'backbone2.fc8.weight': 1,
  221. 'backbone2.fc7.weight': 1,
  222. 'backbone2.fc6.weight': 1,
  223. 'backbone1.fc4.weight': 1,
  224. 'backbone1.fc3.weight': 1,
  225. 'backbone1.fc2.weight': 1,
  226. 'backbone2.fc5.weight': 1,
  227. 'backbone2.fc4.weight': 1,
  228. 'backbone2.fc3.weight': 1,
  229. 'backbone2.fc2.weight': 1,
  230. 'backbone2.fc1.weight': 1,
  231. 'backbone1.fc1.weight': 1}
  232. assert allreduce_fusion_dict == expect_dict
  233. def test_allreduce_fusion_size_one_tensor():
  234. """
  235. Feature: test_allreduce_fusion in size mode with one tensor
  236. Description: test_allreduce_fusion in size mode with one tensor
  237. Expectation: success
  238. """
  239. comm_fusion_dict = {"allreduce": {"mode": "size", "config": 32}}
  240. context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, comm_fusion=comm_fusion_dict)
  241. net = DenseNet3(has_bias=False, activation=None)
  242. allreduce_fusion_dict = train_common(net)
  243. expect_dict = {'fc1.weight': 1}
  244. assert allreduce_fusion_dict == expect_dict
  245. def test_fusion_invalid_value_failed():
  246. """
  247. Feature: test_allreduce_fusion with invalid value
  248. Description: test_allreduce_fusion with invalid value
  249. Expectation: throw TypeError
  250. """
  251. with pytest.raises(TypeError):
  252. comm_fusion_dict = {"allreduce": {"mode": "size", "config": "30.12"}}
  253. context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, comm_fusion=comm_fusion_dict)
  254. def test_enable_invalid_value_failed():
  255. """
  256. Feature: enable_all_reduce_fusion with invalid value
  257. Description: enable_all_reduce_fusion with invalid value
  258. Expectation: throw TypeError
  259. """
  260. with pytest.raises(TypeError):
  261. auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion="fusion")