You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_pack.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import mindspore as ms
  17. import mindspore.context as context
  18. from mindspore import Tensor, Parameter
  19. import mindspore.nn as nn
  20. from mindspore.common.api import _executor
  21. from mindspore.nn import TrainOneStepCell, Momentum
  22. from mindspore.ops import operations as P
  23. from mindspore.nn import Dense, Flatten
  class Net(nn.Cell):
      """Stack two weights with ``P.Pack`` and multiply the stack by the input.

      ``weight2`` always becomes the Parameter named "w2". ``weight1`` becomes
      the Parameter "w1" when ``is_parameter`` is True; otherwise the raw
      tensor is stored, so Pack can be tested on a plain-tensor operand too.
      ``strategy1`` is forwarded to ``shard`` of the Pack op and ``strategy2``
      to ``shard`` of the Mul op.
      """

      def __init__(self, weight1, weight2, axis=0, strategy1=None, strategy2=None, is_parameter=True):
          super().__init__()
          self.pack = P.Pack(axis=axis).shard(strategy1)
          self.mul = P.Mul().shard(strategy2)
          # Attribute assignment order (w1 before w2) is kept on purpose.
          self.weight1 = Parameter(weight1, "w1") if is_parameter else weight1
          self.weight2 = Parameter(weight2, "w2")

      def construct(self, x):
          stacked = self.pack([self.weight1, self.weight2])
          return self.mul(x, stacked)
  class Net1(nn.Cell):
      """Multiply the input by "w1", then pack that product together with "w2".

      Here the Pack op sits at the network output, after the Mul op.
      ``strategy1`` is forwarded to ``shard`` of Pack and ``strategy2`` to
      ``shard`` of Mul.
      """

      def __init__(self, weight1, weight2, axis=0, strategy1=None, strategy2=None):
          super().__init__()
          self.pack = P.Pack(axis=axis).shard(strategy1)
          self.mul = P.Mul().shard(strategy2)
          self.weight1 = Parameter(weight1, "w1")
          self.weight2 = Parameter(weight2, "w2")

      def construct(self, x):
          scaled = self.mul(x, self.weight1)
          return self.pack([scaled, self.weight2])
  class Net2(nn.Cell):
      """Pack three weights along ``axis`` and multiply the stack by the input.

      Args:
          weight1: First Pack operand; wrapped as Parameter "w1" only when
              ``is_parameter`` is True, otherwise stored as the raw tensor.
          weight2: Second Pack operand, wrapped as Parameter "w2".
          weight3: Third Pack operand, wrapped as Parameter "w3".
          axis: Axis passed to ``P.Pack``.
          strategy1: Forwarded to ``shard`` of the Pack op.
          strategy2: Forwarded to ``shard`` of the Mul op.
          is_parameter: Whether ``weight1`` is wrapped in a Parameter.
      """

      def __init__(self, weight1, weight2, weight3, axis=0, strategy1=None, strategy2=None, is_parameter=True):
          super(Net2, self).__init__()
          self.pack = P.Pack(axis=axis).shard(strategy1)
          self.mul = P.Mul().shard(strategy2)
          if is_parameter:
              self.weight1 = Parameter(weight1, "w1")
          else:
              self.weight1 = weight1
          self.weight2 = Parameter(weight2, "w2")
          # Built from weight3 (not weight2) so the third argument is actually used.
          self.weight3 = Parameter(weight3, "w3")

      def construct(self, x):
          out = self.pack([self.weight1, self.weight2, self.weight3])
          out = self.mul(x, out)
          return out
  class PackConstantNet1(nn.Cell):
      """Pack eight copies of one constant tensor, passed to Pack as a *list*.

      The packed result and the input are each flattened and multiplied, and
      the product goes through a Dense layer whose weight and bias are filled
      with 0.01.

      Args:
          dense_in_channel: Input width of the Dense layer.
          dense_out_channel: Output width of the Dense layer.
          axis: Axis passed to ``P.Pack``.
          shape: Shape of the constant tensor (filled with 0.01, float32).
          strategy: Shard strategy for Pack; no ``shard`` call when None.
      """

      def __init__(self, dense_in_channel, dense_out_channel, axis=0, shape=None, strategy=None):
          super().__init__()
          fill_value = 0.01
          dense_weight = np.full((dense_out_channel, dense_in_channel), fill_value, dtype=np.float32)
          dense_bias = np.full((dense_out_channel,), fill_value, dtype=np.float32)
          self.pack_con = Tensor(np.full(shape, fill_value, dtype=np.float32))
          self.flat = Flatten()
          self.dense = Dense(in_channels=dense_in_channel, out_channels=dense_out_channel,
                             weight_init=Tensor(dense_weight), bias_init=Tensor(dense_bias),
                             has_bias=True)
          self.mul = P.Mul()
          self.pack = P.Pack(axis)
          if strategy is not None:
              self.pack.shard(strategy)

      def construct(self, inputs):
          packed = self.pack([self.pack_con, self.pack_con, self.pack_con, self.pack_con,
                              self.pack_con, self.pack_con, self.pack_con, self.pack_con])
          product = self.mul(self.flat(packed), self.flat(inputs))
          return self.dense(product)
  class PackConstantNet2(nn.Cell):
      """Pack eight copies of one constant tensor, passed to Pack as a *tuple*.

      This is the tuple-input counterpart of the list-input constant-pack test
      network. The flattened packed result is multiplied by the flattened input
      and fed to a Dense layer whose weight and bias are filled with 0.01.

      Args:
          dense_in_channel: Input width of the Dense layer.
          dense_out_channel: Output width of the Dense layer.
          axis: Axis passed to ``P.Pack``.
          shape: Shape of the constant tensor (filled with 0.01, float32).
          strategy: Shard strategy for Pack; no ``shard`` call when None.
      """

      def __init__(self, dense_in_channel, dense_out_channel, axis=0, shape=None, strategy=None):
          super().__init__()
          const_value = 0.01
          w_init = np.full((dense_out_channel, dense_in_channel), const_value, dtype=np.float32)
          b_init = np.full((dense_out_channel,), const_value, dtype=np.float32)
          self.pack_con = Tensor(np.full(shape, const_value, dtype=np.float32))
          self.flat = Flatten()
          self.dense = Dense(in_channels=dense_in_channel, out_channels=dense_out_channel,
                             weight_init=Tensor(w_init), bias_init=Tensor(b_init),
                             has_bias=True)
          self.mul = P.Mul()
          self.pack = P.Pack(axis)
          if strategy is not None:
              self.pack.shard(strategy)

      def construct(self, inputs):
          stacked = self.pack((self.pack_con, self.pack_con, self.pack_con, self.pack_con,
                               self.pack_con, self.pack_con, self.pack_con, self.pack_con))
          flat_stacked = self.flat(stacked)
          flat_inputs = self.flat(inputs)
          return self.dense(self.mul(flat_stacked, flat_inputs))
  # Pack operands handed to Net / Net1 / Net2 by the tests: float32 ones of shape (48, 64).
  _w1 = Tensor(np.ones((48, 64)), dtype=ms.float32)
  _w2 = Tensor(np.ones((48, 64)), dtype=ms.float32)
  _w3 = Tensor(np.ones((48, 64)), dtype=ms.float32)
  # Network inputs, one per compile helper (all float32 ones).
  _x = Tensor(np.ones((2, 48, 64)), dtype=ms.float32)    # compile_net
  _x1 = Tensor(np.ones((48, 64)), dtype=ms.float32)      # compile_net1
  _x2 = Tensor(np.ones((3, 48, 64)), dtype=ms.float32)   # compile_net2
  _x_c = Tensor(np.ones((8, 8, 8)), dtype=ms.float32)    # compile_net_con
  def _compile_train_net(net, inputs):
      """Wrap ``net`` in a Momentum TrainOneStepCell and compile it on ``inputs``.

      Runs in graph mode with IR dumping enabled (``save_graphs=True``). The
      auto-parallel context that the calling test configured is reset in a
      ``finally`` clause, so a failing compilation cannot leak its parallel
      settings into later tests.
      """
      try:
          context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
          optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
          train_net = TrainOneStepCell(net, optimizer)
          train_net.set_auto_parallel()
          train_net.set_train()
          _executor.compile(train_net, inputs)
      finally:
          context.reset_auto_parallel_context()


  def compile_net(net):
      """Compile ``net`` for training on ``_x`` (ones of shape (2, 48, 64))."""
      _compile_train_net(net, _x)


  def compile_net1(net):
      """Compile ``net`` for training on ``_x1`` (ones of shape (48, 64))."""
      _compile_train_net(net, _x1)


  def compile_net2(net):
      """Compile ``net`` for training on ``_x2`` (ones of shape (3, 48, 64))."""
      _compile_train_net(net, _x2)


  def compile_net_con(net):
      """Compile ``net`` for training on ``_x_c`` (ones of shape (8, 8, 8)).

      Like the other helpers, this now puts the train cell into training mode
      via ``set_train()``.
      """
      _compile_train_net(net, _x_c)
  def test_pack_parameter():
      """Semi-auto parallel: pack two Parameters on axis 0 using all 8 devices (4 x 2)."""
      context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
      pack_strategy = ((4, 2), (4, 2))
      mul_strategy = ((1, 4, 2), (1, 4, 2))
      compile_net(Net(_w1, _w2, axis=0, strategy1=pack_strategy, strategy2=mul_strategy))
  def test_pack_parameter_no_full_split():
      """Semi-auto parallel: Pack strategy (2, 2) covers only 4 of the 8 devices."""
      context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
      pack_strategy = ((2, 2), (2, 2))
      mul_strategy = ((1, 4, 2), (1, 4, 2))
      compile_net(Net(_w1, _w2, axis=0, strategy1=pack_strategy, strategy2=mul_strategy))
  def test_pack_tensor_and_parameter():
      """Semi-auto parallel: pack a plain Tensor (is_parameter=False) with a Parameter."""
      context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
      pack_strategy = ((4, 2), (4, 2))
      mul_strategy = ((1, 4, 2), (1, 4, 2))
      compile_net(Net(_w1, _w2, axis=0, strategy1=pack_strategy, strategy2=mul_strategy,
                      is_parameter=False))
  def test_pack_output():
      """Semi-auto parallel: Pack as the network output, stacked on axis 0."""
      context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
      pack_strategy = ((4, 2), (4, 2))
      mul_strategy = ((4, 2), (4, 2))
      compile_net1(Net1(_w1, _w2, axis=0, strategy1=pack_strategy, strategy2=mul_strategy))
  def test_pack_output_axis1():
      """Semi-auto parallel: Pack as the network output, stacked on axis 1."""
      context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
      pack_strategy = ((4, 2), (4, 2))
      mul_strategy = ((4, 2), (4, 2))
      compile_net1(Net1(_w1, _w2, axis=1, strategy1=pack_strategy, strategy2=mul_strategy))
  def test_pack_output_no_full_split():
      """Semi-auto parallel: output Pack with a (2, 2) strategy using 4 of 8 devices."""
      context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
      pack_strategy = ((2, 2), (2, 2))
      mul_strategy = ((4, 2), (4, 2))
      compile_net1(Net1(_w1, _w2, axis=0, strategy1=pack_strategy, strategy2=mul_strategy))
  def test_pack_no_strategy():
      """Semi-auto parallel: Pack receives None as its strategy (axis 0); Mul is sharded."""
      context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
      mul_strategy = ((4, 2), (4, 2))
      compile_net1(Net1(_w1, _w2, axis=0, strategy1=None, strategy2=mul_strategy))
  def test_pack_no_strategy_axis1():
      """Semi-auto parallel: Pack receives None as its strategy (axis 1); Mul is sharded."""
      context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
      mul_strategy = ((4, 2), (4, 2))
      compile_net1(Net1(_w1, _w2, axis=1, strategy1=None, strategy2=mul_strategy))
  def test_pack_auto_parallel():
      """Auto parallel: Net1 packing on axis 0 with no explicit strategies."""
      context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
      compile_net1(Net1(_w1, _w2, axis=0))
  def test_pack_auto_parallel_axis1():
      """Auto parallel: Net1 packing on axis 1 with no explicit strategies."""
      context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
      compile_net1(Net1(_w1, _w2, axis=1))
  def test_pack_auto_parallel_3_tensor():
      """Auto parallel: pack three weights with Net2's default axis and strategies."""
      context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
      three_way = Net2(_w1, _w2, _w3)
      compile_net2(three_way)
  def test_pack_constant1():
      """Semi-auto parallel: pack eight (8, 8) constants given as a list, each split (4, 1)."""
      context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
      pack_strategy = ((4, 1),) * 8
      net = PackConstantNet1(dense_in_channel=64, dense_out_channel=4, axis=0, shape=(8, 8),
                             strategy=pack_strategy)
      compile_net_con(net)
  def test_pack_constant2():
      """Semi-auto parallel: pack eight (8, 8) constants given as a tuple, each split (4, 1)."""
      context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
      pack_strategy = ((4, 1),) * 8
      net = PackConstantNet2(dense_in_channel=64, dense_out_channel=4, axis=0, shape=(8, 8),
                             strategy=pack_strategy)
      compile_net_con(net)
  def test_pack_auto_constant():
      """Auto parallel: list-input constant Pack with an explicit (8, 1) strategy per input."""
      context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
      pack_strategy = ((8, 1),) * 8
      net = PackConstantNet1(dense_in_channel=64, dense_out_channel=4, axis=0, shape=(8, 8),
                             strategy=pack_strategy)
      compile_net_con(net)