You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_pack.py 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import mindspore as ms
  17. import mindspore.context as context
  18. from mindspore import Tensor, Parameter
  19. import mindspore.nn as nn
  20. from mindspore.common.api import _executor
  21. from mindspore.nn import TrainOneStepCell, Momentum
  22. from mindspore.ops import operations as P
  23. class Net(nn.Cell):
  24. def __init__(self, weight1, weight2, axis=0, strategy1=None, strategy2=None, is_parameter=True):
  25. super(Net, self).__init__()
  26. self.pack = P.Pack(axis=axis).shard(strategy1)
  27. self.mul = P.Mul().shard(strategy2)
  28. if is_parameter:
  29. self.weight1 = Parameter(weight1, "w1")
  30. else:
  31. self.weight1 = weight1
  32. self.weight2 = Parameter(weight2, "w2")
  33. def construct(self, x):
  34. out = self.pack([self.weight1, self.weight2])
  35. out = self.mul(x, out)
  36. return out
  37. class Net1(nn.Cell):
  38. def __init__(self, weight1, weight2, axis=0, strategy1=None, strategy2=None):
  39. super(Net1, self).__init__()
  40. self.pack = P.Pack(axis=axis).shard(strategy1)
  41. self.mul = P.Mul().shard(strategy2)
  42. self.weight1 = Parameter(weight1, "w1")
  43. self.weight2 = Parameter(weight2, "w2")
  44. def construct(self, x):
  45. out = self.mul(x, self.weight1)
  46. out = self.pack([out, self.weight2])
  47. return out
  48. class Net2(nn.Cell):
  49. def __init__(self, weight1, weight2, weight3, axis=0, strategy1=None, strategy2=None, is_parameter=True):
  50. super(Net2, self).__init__()
  51. self.pack = P.Pack(axis=axis).shard(strategy1)
  52. self.mul = P.Mul().shard(strategy2)
  53. if is_parameter:
  54. self.weight1 = Parameter(weight1, "w1")
  55. else:
  56. self.weight1 = weight1
  57. self.weight2 = Parameter(weight2, "w2")
  58. self.weight3 = Parameter(weight2, "w3")
  59. def construct(self, x):
  60. out = self.pack([self.weight1, self.weight2, self.weight3])
  61. out = self.mul(x, out)
  62. return out
  63. _w1 = Tensor(np.ones([48, 64]), dtype=ms.float32)
  64. _w2 = Tensor(np.ones([48, 64]), dtype=ms.float32)
  65. _w3 = Tensor(np.ones([48, 64]), dtype=ms.float32)
  66. _x = Tensor(np.ones([2, 48, 64]), dtype=ms.float32)
  67. _x1 = Tensor(np.ones([48, 64]), dtype=ms.float32)
  68. _x2 = Tensor(np.ones([3, 48, 64]), dtype=ms.float32)
  69. def compile_net(net):
  70. context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
  71. optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  72. train_net = TrainOneStepCell(net, optimizer)
  73. train_net.set_auto_parallel()
  74. _executor.compile(train_net, _x)
  75. context.reset_auto_parallel_context()
  76. def compile_net1(net):
  77. context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
  78. optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  79. train_net = TrainOneStepCell(net, optimizer)
  80. train_net.set_auto_parallel()
  81. _executor.compile(train_net, _x1)
  82. context.reset_auto_parallel_context()
  83. def compile_net2(net):
  84. context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
  85. optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  86. train_net = TrainOneStepCell(net, optimizer)
  87. train_net.set_auto_parallel()
  88. _executor.compile(train_net, _x2)
  89. context.reset_auto_parallel_context()
  90. def test_pack_parameter():
  91. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  92. strategy1 = ((4, 2), (4, 2))
  93. strategy2 = ((1, 4, 2), (1, 4, 2))
  94. net = Net(_w1, _w2, 0, strategy1, strategy2)
  95. compile_net(net)
  96. def test_pack_parameter_no_full_split():
  97. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  98. strategy1 = ((2, 2), (2, 2))
  99. strategy2 = ((1, 4, 2), (1, 4, 2))
  100. net = Net(_w1, _w2, 0, strategy1, strategy2)
  101. compile_net(net)
  102. def test_pack_tensor_and_parameter():
  103. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  104. strategy1 = ((4, 2), (4, 2))
  105. strategy2 = ((1, 4, 2), (1, 4, 2))
  106. net = Net(_w1, _w2, 0, strategy1, strategy2, False)
  107. compile_net(net)
  108. def test_pack_output():
  109. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  110. strategy1 = ((4, 2), (4, 2))
  111. strategy2 = ((4, 2), (4, 2))
  112. net = Net1(_w1, _w2, 0, strategy1, strategy2)
  113. compile_net1(net)
  114. def test_pack_output_axis1():
  115. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  116. strategy1 = ((4, 2), (4, 2))
  117. strategy2 = ((4, 2), (4, 2))
  118. net = Net1(_w1, _w2, 1, strategy1, strategy2)
  119. compile_net1(net)
  120. def test_pack_output_no_full_split():
  121. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  122. strategy1 = ((2, 2), (2, 2))
  123. strategy2 = ((4, 2), (4, 2))
  124. net = Net1(_w1, _w2, 0, strategy1, strategy2)
  125. compile_net1(net)
  126. def test_pack_no_strategy():
  127. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  128. strategy1 = None
  129. strategy2 = ((4, 2), (4, 2))
  130. net = Net1(_w1, _w2, 0, strategy1, strategy2)
  131. compile_net1(net)
  132. def test_pack_no_strategy_axis1():
  133. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  134. strategy1 = None
  135. strategy2 = ((4, 2), (4, 2))
  136. net = Net1(_w1, _w2, 1, strategy1, strategy2)
  137. compile_net1(net)
  138. def test_pack_auto_parallel():
  139. context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
  140. net = Net1(_w1, _w2, 0)
  141. compile_net1(net)
  142. def test_pack_auto_parallel_axis1():
  143. context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
  144. net = Net1(_w1, _w2, 1)
  145. compile_net1(net)
  146. def test_pack_auto_parallel_3_tensor():
  147. context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
  148. net = Net2(_w1, _w2, _w3)
  149. compile_net2(net)