You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_reshape_unexpand.py 9.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import mindspore as ms
  16. import mindspore.nn as nn
  17. from mindspore import Tensor
  18. from mindspore import context
  19. from mindspore.common.api import _executor
  20. from mindspore.common.parameter import Parameter
  21. from mindspore.ops import composite as C
  22. from mindspore.ops import operations as P
  23. from tests.ut.python.ops.test_math_ops import VirtualLoss
# Module-level gradient operator shared by GradWrap. get_all=True requests
# gradients with respect to all inputs of the wrapped network.
grad_all = C.GradOperation(get_all=True)
  25. class NetWithLoss(nn.Cell):
  26. def __init__(self, network):
  27. super(NetWithLoss, self).__init__()
  28. self.loss = VirtualLoss()
  29. self.network = network
  30. def construct(self, x):
  31. predict = self.network(x)
  32. return self.loss(predict)
  33. class GradWrap(nn.Cell):
  34. def __init__(self, network):
  35. super(GradWrap, self).__init__()
  36. self.network = network
  37. def construct(self, x):
  38. return grad_all(self.network)(x)
  39. def test_reshape_unexpand():
  40. class Net(nn.Cell):
  41. def __init__(self):
  42. super().__init__()
  43. self.reshape = P.Reshape()
  44. self.mul = P.Mul().shard(((1, 8), (1, 1, 8)))
  45. self.mul_weight = Parameter(Tensor(np.ones([96, 128]), dtype=ms.float32), name="weight")
  46. def construct(self, x):
  47. weight = self.reshape(self.mul_weight, (1, 128, 96))
  48. out = self.mul(x, weight)
  49. return out
  50. size = 8
  51. context.set_auto_parallel_context(device_num=size, global_rank=0)
  52. x = Tensor(np.ones([128, 96]), dtype=ms.float32)
  53. net = GradWrap(NetWithLoss(Net()))
  54. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  55. net.set_auto_parallel()
  56. net.set_train()
  57. _executor.compile(net, x)
  58. def test_reshape_unexpand_1():
  59. class Net(nn.Cell):
  60. def __init__(self):
  61. super().__init__()
  62. self.reshape = P.Reshape()
  63. self.mul = P.Mul().shard(((1, 1, 8), (1, 8)))
  64. self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
  65. def construct(self, data):
  66. x = self.reshape(self.mul_weight, (1, 128, 96))
  67. out = self.mul(x, self.mul_weight)
  68. return out
  69. size = 8
  70. context.set_auto_parallel_context(device_num=size, global_rank=0)
  71. x = Tensor(np.ones([128, 96]), dtype=ms.float32)
  72. net = GradWrap(NetWithLoss(Net()))
  73. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  74. net.set_auto_parallel()
  75. net.set_train()
  76. _executor.compile(net, x)
  77. def test_reshape_unexpand_2():
  78. class Net(nn.Cell):
  79. def __init__(self):
  80. super().__init__()
  81. self.reshape = P.Reshape()
  82. self.mul = P.Mul().shard(((1, 4, 2), (4, 2)))
  83. self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
  84. def construct(self, data):
  85. x = self.reshape(self.mul_weight, (1, 128, 96))
  86. out = self.mul(x, self.mul_weight)
  87. return out
  88. size = 8
  89. context.set_auto_parallel_context(device_num=size, global_rank=0)
  90. x = Tensor(np.ones([128, 96]), dtype=ms.float32)
  91. net = GradWrap(NetWithLoss(Net()))
  92. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  93. net.set_auto_parallel()
  94. net.set_train()
  95. _executor.compile(net, x)
  96. def test_reshape_unexpand_3():
  97. class Net(nn.Cell):
  98. def __init__(self):
  99. super().__init__()
  100. self.reshape = P.Reshape()
  101. self.relu1 = P.ReLU().shard(((4, 1),))
  102. self.relu2 = P.ReLU().shard(((1, 4),))
  103. def construct(self, data):
  104. x = self.relu1(data)
  105. x = self.reshape(x, (3, 4))
  106. x = self.relu2(x)
  107. return x
  108. size = 4
  109. context.set_auto_parallel_context(device_num=size, global_rank=0)
  110. x = Tensor(np.ones([4, 3]), dtype=ms.float32)
  111. net = GradWrap(NetWithLoss(Net()))
  112. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  113. net.set_auto_parallel()
  114. net.set_train()
  115. _executor.compile(net, x)
  116. def test_reshape_unexpand_4():
  117. class Net(nn.Cell):
  118. def __init__(self):
  119. super().__init__()
  120. self.reshape = P.Reshape()
  121. self.relu1 = P.ReLU().shard(((4, 1),))
  122. self.relu2 = P.ReLU().shard(((1, 2, 2),))
  123. def construct(self, data):
  124. x = self.relu1(data)
  125. x = self.reshape(x, (3, 2, 2))
  126. x = self.relu2(x)
  127. return x
  128. size = 4
  129. context.set_auto_parallel_context(device_num=size, global_rank=0)
  130. x = Tensor(np.ones([4, 3]), dtype=ms.float32)
  131. net = GradWrap(NetWithLoss(Net()))
  132. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  133. net.set_auto_parallel()
  134. net.set_train()
  135. _executor.compile(net, x)
  136. def test_reshape_unexpand_5():
  137. class Net(nn.Cell):
  138. def __init__(self):
  139. super().__init__()
  140. self.reshape = P.Reshape()
  141. self.relu1 = P.ReLU().shard(((2, 2, 1),))
  142. self.relu2 = P.ReLU().shard(((1, 4),))
  143. def construct(self, data):
  144. x = self.relu1(data)
  145. x = self.reshape(x, (3, 4))
  146. x = self.relu2(x)
  147. return x
  148. size = 4
  149. context.set_auto_parallel_context(device_num=size, global_rank=0)
  150. x = Tensor(np.ones([2, 2, 3]), dtype=ms.float32)
  151. net = GradWrap(NetWithLoss(Net()))
  152. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  153. net.set_auto_parallel()
  154. net.set_train()
  155. _executor.compile(net, x)
  156. def test_reshape_unexpand_6():
  157. class Net(nn.Cell):
  158. def __init__(self):
  159. super().__init__()
  160. self.reshape = P.Reshape()
  161. self.relu1 = P.ReLU().shard(((2, 1),))
  162. self.relu2 = P.ReLU().shard(((1, 1, 4),))
  163. def construct(self, data):
  164. x = self.relu1(data)
  165. x = self.reshape(x, (1, 3, 4))
  166. x = self.relu2(x)
  167. return x
  168. size = 4
  169. context.set_auto_parallel_context(device_num=size, global_rank=0)
  170. x = Tensor(np.ones([4, 3]), dtype=ms.float32)
  171. net = GradWrap(NetWithLoss(Net()))
  172. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  173. net.set_auto_parallel()
  174. net.set_train()
  175. _executor.compile(net, x)
  176. def test_reshape_unexpand_7():
  177. class Net(nn.Cell):
  178. def __init__(self, in_channel=3, out_channel=8, axis=1, input_shape=(32, 4, 110, -1),
  179. mul_size=(32, 1, 220, 220)):
  180. super().__init__()
  181. mul_np = np.full(mul_size, 0.5, dtype=np.float32)
  182. self.mul_weight = Parameter(Tensor(mul_np), name="mul_weight")
  183. self.mul = P.Mul()
  184. self.conv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
  185. kernel_size=5, has_bias=True, weight_init='ones',
  186. bias_init='ones', pad_mode='valid')
  187. self.softmax = nn.Softmax(axis=axis)
  188. self.relu = nn.ReLU()
  189. self.reshape = P.Reshape()
  190. self.input_shape = input_shape
  191. def construct(self, inputs):
  192. x = self.conv(inputs)
  193. x = self.softmax(x)
  194. x = self.relu(x)
  195. x = self.mul(x, self.mul_weight)
  196. x = self.reshape(x, self.input_shape)
  197. return x
  198. size = 8
  199. context.set_auto_parallel_context(device_num=size, global_rank=0)
  200. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  201. x = Tensor(np.ones([32, 3, 224, 224]), dtype=ms.float32)
  202. net = GradWrap(NetWithLoss(Net()))
  203. net.set_auto_parallel()
  204. net.set_train()
  205. _executor.compile(net, x)
  206. def test_reshape_unexpand_8():
  207. class Net(nn.Cell):
  208. def __init__(self):
  209. super().__init__()
  210. self.reshape = P.Reshape()
  211. self.mul = P.Mul().shard(((1, 4, 2), (4, 2)))
  212. self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
  213. def construct(self, data):
  214. x = self.reshape(self.mul_weight, (1, 128, 96))
  215. out = self.mul(x, self.mul_weight)
  216. return out
  217. size = 8
  218. context.set_auto_parallel_context(device_num=size, global_rank=0)
  219. x = Tensor(np.ones([128, 96]), dtype=ms.float32)
  220. net = GradWrap(NetWithLoss(Net()))
  221. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  222. net.set_auto_parallel()
  223. net.set_train()
  224. _executor.compile(net, x)