You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_allreduce_fusion.py 14 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import pytest
  16. import mindspore as ms
  17. import mindspore.nn as nn
  18. from mindspore import Tensor, context
  19. from mindspore.common.api import _executor
  20. from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
  21. from mindspore.nn.optim.momentum import Momentum
  22. from mindspore.parallel import _cost_model_context as cost_model_context
  23. from mindspore.parallel._auto_parallel_context import auto_parallel_context
  24. from mindspore.train import Model
  25. from mindspore.context import ParallelMode
  26. from tests.dataset_mock import MindData
  27. class Dataset(MindData):
  28. def __init__(self, predict, label, length=3):
  29. super(Dataset, self).__init__(size=length)
  30. self.predict = predict
  31. self.label = label
  32. self.index = 0
  33. self.length = length
  34. def __iter__(self):
  35. return self
  36. def __next__(self):
  37. if self.index >= self.length:
  38. raise StopIteration
  39. self.index += 1
  40. return self.predict, self.label
  41. def reset(self):
  42. self.index = 0
  43. class DenseNet1(nn.Cell):
  44. def __init__(self, has_bias=True, activation='relu'):
  45. super(DenseNet1, self).__init__()
  46. self.fc1 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  47. self.fc2 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  48. self.fc3 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  49. self.fc4 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  50. def construct(self, x):
  51. q = self.fc1(x)
  52. k = self.fc2(q)
  53. v = self.fc3(k)
  54. s = self.fc4(v)
  55. return s
  56. class DenseNet2(nn.Cell):
  57. def __init__(self, has_bias=True, activation='relu'):
  58. super(DenseNet2, self).__init__()
  59. self.fc1 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  60. self.fc2 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  61. self.fc3 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  62. self.fc4 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  63. self.fc5 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  64. self.fc6 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  65. self.fc7 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  66. self.fc8 = nn.Dense(128, 128, has_bias=has_bias, activation=activation)
  67. def construct(self, x):
  68. q = self.fc1(x)
  69. k = self.fc2(q)
  70. v = self.fc3(k)
  71. s = self.fc4(v)
  72. t = self.fc5(s)
  73. u = self.fc6(t)
  74. w = self.fc7(u)
  75. z = self.fc8(w)
  76. return z
  77. class SimpleDMLNet(nn.Cell):
  78. def __init__(self, net1, net2):
  79. super(SimpleDMLNet, self).__init__()
  80. self.backbone1 = net1
  81. self.backbone2 = net2
  82. def construct(self, x):
  83. x1 = self.backbone1(x)
  84. x2 = self.backbone2(x)
  85. return x1 + x2
  86. def train_common(net):
  87. batch_size = 32
  88. learning_rate = 0.1
  89. momentum = 0.9
  90. epoch_size = 2
  91. device_num = 4
  92. auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True)
  93. context.set_auto_parallel_context(device_num=device_num, parameter_broadcast=False)
  94. context.set_context(mode=context.GRAPH_MODE)
  95. predict = Tensor(np.ones([batch_size, 128]), dtype=ms.float32)
  96. label = Tensor(np.ones([batch_size]), dtype=ms.int32)
  97. dataset = Dataset(predict, label, 2)
  98. loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
  99. opt = Momentum(net.trainable_params(), learning_rate, momentum)
  100. model = Model(net, loss, opt)
  101. model.train(epoch_size, dataset, dataset_sink_mode=False)
  102. allreduce_fusion_dict = _executor._get_allreduce_fusion(model._train_network)
  103. print(allreduce_fusion_dict)
  104. return allreduce_fusion_dict
  105. @pytest.mark.skip(reason="depreciated feature")
  106. def test_allreduce_fusion_parameters():
  107. cost_model_context.reset_cost_model_context()
  108. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
  109. algorithm = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_algorithm')
  110. assert algorithm == 2
  111. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
  112. algorithm = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_algorithm')
  113. assert algorithm == 1
  114. cost_model_context.reset_cost_model_context()
  115. algorithm = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_algorithm')
  116. assert algorithm == 0
  117. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
  118. fusion_times = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_times')
  119. assert fusion_times == 2
  120. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.2)
  121. tail_percent = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_percent')
  122. assert tail_percent == 0.2
  123. cost_model_context.reset_cost_model_context()
  124. tail_percent = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_percent')
  125. assert tail_percent == 0.1
  126. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.2)
  127. tail_time = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_time')
  128. assert tail_time == 0.2
  129. cost_model_context.reset_cost_model_context()
  130. tail_time = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_time')
  131. assert tail_time == 0.1
  132. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_inherent_time=0.2)
  133. allreduce_inherent_time = cost_model_context.get_cost_model_context(
  134. 'costmodel_allreduce_fusion_allreduce_inherent_time')
  135. assert allreduce_inherent_time == 0.2
  136. cost_model_context.reset_cost_model_context()
  137. allreduce_inherent_time = cost_model_context.get_cost_model_context(
  138. 'costmodel_allreduce_fusion_allreduce_inherent_time')
  139. assert allreduce_inherent_time == 0.1
  140. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_bandwidth=0.2)
  141. allreduce_bandwidth = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_allreduce_bandwidth')
  142. assert allreduce_bandwidth == 0.2
  143. cost_model_context.reset_cost_model_context()
  144. allreduce_bandwidth = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_allreduce_bandwidth')
  145. assert allreduce_bandwidth == 0.1
  146. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_computation_time_parameter=0.2)
  147. computation_time_parameter = cost_model_context.get_cost_model_context(
  148. 'costmodel_allreduce_fusion_computation_time_parameter')
  149. assert computation_time_parameter == 0.2
  150. cost_model_context.reset_cost_model_context()
  151. computation_time_parameter = cost_model_context.get_cost_model_context(
  152. 'costmodel_allreduce_fusion_computation_time_parameter')
  153. assert computation_time_parameter == 0.1
  154. @pytest.mark.skip(reason="depreciated feature")
  155. def test_allreduce_fusion1():
  156. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
  157. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
  158. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
  159. context.reset_auto_parallel_context()
  160. context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  161. net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
  162. allreduce_fusion_dict = train_common(net)
  163. expect_dict = {'backbone2.fc8.weight': 2,
  164. 'backbone2.fc7.weight': 2,
  165. 'backbone2.fc6.weight': 2,
  166. 'backbone1.fc4.weight': 2,
  167. 'backbone1.fc3.weight': 2,
  168. 'backbone1.fc2.weight': 2,
  169. 'backbone2.fc5.weight': 1,
  170. 'backbone2.fc4.weight': 1,
  171. 'backbone2.fc3.weight': 1,
  172. 'backbone2.fc2.weight': 1,
  173. 'backbone2.fc1.weight': 1,
  174. 'backbone1.fc1.weight': 1}
  175. assert allreduce_fusion_dict == expect_dict
  176. cost_model_context.reset_cost_model_context()
  177. @pytest.mark.skip(reason="depreciated feature")
  178. # reset_cost_model_context is called, the default value of costmodel_allreduce_fusion_times is 0, step_allreduce_fusion
  179. # is bypassed.
  180. def test_allreduce_fusion2():
  181. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
  182. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
  183. cost_model_context.reset_cost_model_context()
  184. context.reset_auto_parallel_context()
  185. context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  186. net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
  187. allreduce_fusion_dict = train_common(net)
  188. expect_dict = {}
  189. assert allreduce_fusion_dict == expect_dict
  190. cost_model_context.reset_cost_model_context()
  191. @pytest.mark.skip(reason="depreciated feature")
  192. def test_allreduce_fusion3():
  193. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
  194. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=3)
  195. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.3333333)
  196. context.reset_auto_parallel_context()
  197. context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  198. net = SimpleDMLNet(DenseNet1(has_bias=True, activation='relu'), DenseNet2(has_bias=False, activation='relu'))
  199. allreduce_fusion_dict = train_common(net)
  200. expect_dict = {'backbone2.fc8.weight': 3,
  201. 'backbone2.fc7.weight': 3,
  202. 'backbone2.fc6.weight': 2,
  203. 'backbone2.fc5.weight': 2,
  204. 'backbone2.fc4.weight': 2,
  205. 'backbone2.fc3.weight': 1,
  206. 'backbone2.fc2.weight': 1,
  207. 'backbone2.fc1.weight': 1,
  208. 'backbone1.fc4.bias': 3,
  209. 'backbone1.fc4.weight': 3,
  210. 'backbone1.fc3.bias': 3,
  211. 'backbone1.fc3.weight': 2,
  212. 'backbone1.fc2.bias': 2,
  213. 'backbone1.fc2.weight': 2,
  214. 'backbone1.fc1.bias': 2,
  215. 'backbone1.fc1.weight': 2}
  216. assert allreduce_fusion_dict == expect_dict
  217. cost_model_context.reset_cost_model_context()
  218. @pytest.mark.skip(reason="depreciated feature")
  219. def test_allreduce_fusion4():
  220. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
  221. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
  222. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
  223. context.reset_auto_parallel_context()
  224. context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  225. net = SimpleDMLNet(DenseNet2(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
  226. allreduce_fusion_dict = train_common(net)
  227. expect_dict = {'backbone2.fc8.weight': 2,
  228. 'backbone2.fc7.weight': 2,
  229. 'backbone2.fc6.weight': 2,
  230. 'backbone1.fc8.weight': 2,
  231. 'backbone1.fc7.weight': 2,
  232. 'backbone1.fc6.weight': 2,
  233. 'backbone2.fc5.weight': 1,
  234. 'backbone2.fc4.weight': 1,
  235. 'backbone2.fc3.weight': 1,
  236. 'backbone2.fc2.weight': 1,
  237. 'backbone2.fc1.weight': 1,
  238. 'backbone1.fc5.weight': 1,
  239. 'backbone1.fc4.weight': 1,
  240. 'backbone1.fc3.weight': 1,
  241. 'backbone1.fc2.weight': 1,
  242. 'backbone1.fc1.weight': 1}
  243. assert allreduce_fusion_dict == expect_dict
  244. cost_model_context.reset_cost_model_context()
  245. @pytest.mark.skip(reason="depreciated feature")
  246. def test_allreduce_fusion5():
  247. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
  248. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.1)
  249. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_inherent_time=0.05)
  250. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_bandwidth=0.000001)
  251. cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_computation_time_parameter=0.0000015)
  252. context.reset_auto_parallel_context()
  253. context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  254. net = SimpleDMLNet(DenseNet2(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
  255. allreduce_fusion_dict = train_common(net)
  256. expect_dict = {'backbone2.fc8.weight': 3,
  257. 'backbone2.fc7.weight': 3,
  258. 'backbone2.fc6.weight': 3,
  259. 'backbone2.fc5.weight': 3,
  260. 'backbone2.fc4.weight': 2,
  261. 'backbone2.fc3.weight': 2,
  262. 'backbone2.fc2.weight': 1,
  263. 'backbone2.fc1.weight': 1,
  264. 'backbone1.fc8.weight': 3,
  265. 'backbone1.fc7.weight': 3,
  266. 'backbone1.fc6.weight': 3,
  267. 'backbone1.fc5.weight': 3,
  268. 'backbone1.fc4.weight': 2,
  269. 'backbone1.fc3.weight': 2,
  270. 'backbone1.fc2.weight': 1,
  271. 'backbone1.fc1.weight': 1,}
  272. assert allreduce_fusion_dict == expect_dict
  273. cost_model_context.reset_cost_model_context()