You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_auto_parallel_reshape.py 10 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import mindspore as ms
  16. import mindspore.nn as nn
  17. from mindspore import Tensor
  18. from mindspore import context
  19. from mindspore.common.api import _executor
  20. from mindspore.common.parameter import Parameter
  21. from mindspore.ops import composite as C
  22. from mindspore.ops import operations as P
  23. from tests.ut.python.ops.test_math_ops import VirtualLoss
  24. grad_all = C.GradOperation(get_all=True)
  25. class NetWithLoss(nn.Cell):
  26. def __init__(self, network):
  27. super(NetWithLoss, self).__init__()
  28. self.loss = VirtualLoss()
  29. self.network = network
  30. def construct(self, x):
  31. predict = self.network(x)
  32. return self.loss(predict)
  33. class GradWrap(nn.Cell):
  34. def __init__(self, network):
  35. super(GradWrap, self).__init__()
  36. self.network = network
  37. def construct(self, x):
  38. return grad_all(self.network)(x)
  39. def test_reshape_matmul():
  40. class Net(nn.Cell):
  41. def __init__(self):
  42. super().__init__()
  43. self.reshape = P.Reshape()
  44. self.matmul = P.MatMul()
  45. self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")
  46. def construct(self, x):
  47. out = self.reshape(x, (64, 28))
  48. out = self.matmul(out, self.matmul_weight)
  49. return out
  50. size = 8
  51. context.set_auto_parallel_context(device_num=size, global_rank=0)
  52. x = Tensor(np.ones([8 * size, 28, 1, 1]), dtype=ms.float32)
  53. net = GradWrap(NetWithLoss(Net()))
  54. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  55. net.set_auto_parallel()
  56. _executor.compile(net, x)
  57. def test_reshape_reshape():
  58. class Net(nn.Cell):
  59. def __init__(self):
  60. super().__init__()
  61. self.reshape = P.Reshape()
  62. self.relu = P.ReLU()
  63. def construct(self, x):
  64. x = self.relu(x)
  65. out = self.reshape(x, (64, 28))
  66. out = self.reshape(out, (64, 28, 1))
  67. return out
  68. size = 8
  69. context.set_auto_parallel_context(device_num=size, global_rank=0)
  70. x = Tensor(np.ones([8 * size, 28, 1, 1]), dtype=ms.float32)
  71. net = GradWrap(NetWithLoss(Net()))
  72. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  73. net.set_auto_parallel()
  74. _executor.compile(net, x)
  75. def test_reshape_auto_1():
  76. class Net(nn.Cell):
  77. def __init__(self):
  78. super().__init__()
  79. self.relu = P.ReLU()
  80. self.reshape = P.Reshape()
  81. self.matmul = P.MatMul()
  82. self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")
  83. def construct(self, x):
  84. out = self.relu(x)
  85. out = self.reshape(out, (64, 28))
  86. out = self.matmul(out, self.matmul_weight)
  87. return out
  88. size = 8
  89. context.set_auto_parallel_context(device_num=size, global_rank=0)
  90. x = Tensor(np.ones([8 * size, 28, 1, 1]), dtype=ms.float32)
  91. net = GradWrap(NetWithLoss(Net()))
  92. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  93. net.set_auto_parallel()
  94. _executor.compile(net, x)
  95. def test_reshape_auto_2():
  96. class Net(nn.Cell):
  97. def __init__(self):
  98. super().__init__()
  99. self.relu = P.ReLU()
  100. self.reshape = P.Reshape()
  101. self.matmul = P.MatMul()
  102. self.add_weight = Parameter(Tensor(np.ones([128, 32]), dtype=ms.float32), name="weight1")
  103. self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")
  104. def construct(self, x):
  105. out = self.relu(x)
  106. out = self.reshape(out, (64, 28))
  107. out = self.matmul(out, self.matmul_weight)
  108. out = self.reshape(out, (128, 32))
  109. out = out + self.add_weight
  110. return out
  111. size = 8
  112. context.set_auto_parallel_context(device_num=size, global_rank=0)
  113. x = Tensor(np.ones([8 * size, 28, 1, 1]), dtype=ms.float32)
  114. net = GradWrap(NetWithLoss(Net()))
  115. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  116. net.set_auto_parallel()
  117. _executor.compile(net, x)
  118. def test_reshape_auto_3():
  119. class Net(nn.Cell):
  120. def __init__(self):
  121. super().__init__()
  122. self.relu = P.ReLU()
  123. self.reshape = P.Reshape()
  124. self.matmul = P.MatMul()
  125. self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")
  126. def construct(self, x):
  127. out = self.relu(x)
  128. out = self.matmul(out, self.matmul_weight)
  129. out = self.reshape(out, (8, 8, 8, 8))
  130. return out
  131. size = 8
  132. context.set_auto_parallel_context(device_num=size, global_rank=0)
  133. x = Tensor(np.ones([8 * size, 28]), dtype=ms.float32)
  134. net = GradWrap(NetWithLoss(Net()))
  135. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  136. net.set_auto_parallel()
  137. _executor.compile(net, x)
  138. def test_reshape_auto_4():
  139. class Net(nn.Cell):
  140. def __init__(self):
  141. super().__init__()
  142. self.relu = P.ReLU()
  143. self.reshape = P.Reshape()
  144. self.matmul = P.MatMul()
  145. self.matmul_weight = Parameter(Tensor(np.ones([28 * 64]), dtype=ms.float32), name="weight")
  146. def construct(self, x):
  147. out = self.relu(x)
  148. out = self.reshape(out, (64, 28))
  149. w = self.reshape(self.matmul_weight, (28, 64))
  150. out = self.matmul(out, w)
  151. return out
  152. size = 8
  153. context.set_auto_parallel_context(device_num=size, global_rank=0)
  154. x = Tensor(np.ones([8 * size, 28, 1, 1]), dtype=ms.float32)
  155. net = GradWrap(NetWithLoss(Net()))
  156. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  157. net.set_auto_parallel()
  158. _executor.compile(net, x)
  159. def test_reshape_auto_5():
  160. class NetWithLoss5(nn.Cell):
  161. def __init__(self, network):
  162. super(NetWithLoss5, self).__init__()
  163. self.loss = VirtualLoss()
  164. self.network = network
  165. def construct(self, x, y):
  166. predict = self.network(x, y)
  167. return self.loss(predict)
  168. class GradWrap5(nn.Cell):
  169. def __init__(self, network):
  170. super(GradWrap5, self).__init__()
  171. self.network = network
  172. def construct(self, x, y):
  173. return grad_all(self.network)(x, y)
  174. class Net(nn.Cell):
  175. def __init__(self):
  176. super().__init__()
  177. self.relu = P.ReLU()
  178. self.mul = P.Mul()
  179. self.reshape = P.Reshape()
  180. self.reduce_sum = P.ReduceSum()
  181. self.wide_w = Parameter(Tensor(np.ones([4, 1024 * 8, 64]), dtype=ms.float32), name="weight")
  182. def construct(self, x, y):
  183. mask = self.reshape(y, (4, 1024 * 8, 1))
  184. w_id = self.relu(x)
  185. wx = self.mul(w_id, mask)
  186. wide_out = self.reshape(self.reduce_sum(wx, 1), (-1, 1))
  187. deep_id = x + self.wide_w
  188. vx = self.mul(deep_id, mask)
  189. deep_in = self.reshape(vx, (-1, 1024 * 8 * 64))
  190. out = wide_out + deep_in
  191. return out
  192. size = 8
  193. context.set_auto_parallel_context(device_num=size, global_rank=0)
  194. x = Tensor(np.ones([4, 1024 * size, 1]), dtype=ms.float32)
  195. y = Tensor(np.ones([4, 1024 * size,]), dtype=ms.float32)
  196. net = GradWrap5(NetWithLoss5(Net()))
  197. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  198. net.set_auto_parallel()
  199. _executor.compile(net, x, y)
  200. def test_reshape_auto_6():
  201. class NetWithLoss6(nn.Cell):
  202. def __init__(self, network):
  203. super(NetWithLoss6, self).__init__()
  204. self.loss = VirtualLoss()
  205. self.network = network
  206. def construct(self, x, y):
  207. predict = self.network(x, y)
  208. return self.loss(predict)
  209. class GradWrap6(nn.Cell):
  210. def __init__(self, network):
  211. super(GradWrap6, self).__init__()
  212. self.network = network
  213. def construct(self, x, y):
  214. return grad_all(self.network)(x, y)
  215. class Net(nn.Cell):
  216. def __init__(self):
  217. super().__init__()
  218. self.relu = P.ReLU()
  219. self.mul = P.Mul()
  220. self.reshape = P.Reshape()
  221. self.reduce_mean = P.ReduceMean()
  222. self.wide_w = Parameter(Tensor(np.ones([4, 1024, 1]), dtype=ms.float32), name="weight")
  223. def construct(self, x, y):
  224. out1 = x + self.wide_w
  225. w = self.reshape(self.wide_w, (4, 1024))
  226. out1 = self.reduce_mean(out1, 1)
  227. out1 = out1 - w
  228. out2 = self.mul(y, w)
  229. out = out1 + out2
  230. return out
  231. size = 8
  232. context.set_auto_parallel_context(device_num=size, global_rank=0)
  233. x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
  234. y = Tensor(np.ones([4, 1024,]), dtype=ms.float32)
  235. net = GradWrap6(NetWithLoss6(Net()))
  236. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  237. net.set_auto_parallel()
  238. _executor.compile(net, x, y)
  239. def test_reshape_auto_7():
  240. class Net(nn.Cell):
  241. def __init__(self):
  242. super().__init__()
  243. self.reshape = P.Reshape()
  244. self.mul = P.Mul().shard(((1, 2, 4), (2, 4)))
  245. self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
  246. def construct(self, x):
  247. weight = self.reshape(self.mul_weight, (1, 128, 96))
  248. out = self.mul(weight, self.mul_weight)
  249. return out
  250. size = 8
  251. context.set_auto_parallel_context(device_num=size, global_rank=0)
  252. x = Tensor(np.ones([128, 28]), dtype=ms.float32)
  253. net = GradWrap(NetWithLoss(Net()))
  254. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  255. net.set_auto_parallel()
  256. _executor.compile(net, x)