You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_resizebilinear.py 9.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. '''ResizeBilinear and ResizeNearestNeighbor unit tests'''
  15. import numpy as np
  16. import pytest
  17. import mindspore as ms
  18. from mindspore import context, Tensor, Parameter
  19. from mindspore.common.api import _cell_graph_executor
  20. from mindspore.nn import Cell, TrainOneStepCell, Momentum
  21. from mindspore.ops import operations as P
  22. class Net(Cell):
  23. '''
  24. create the test Net
  25. '''
  26. def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
  27. strategy1=None, strategy2=None):
  28. super(Net, self).__init__()
  29. self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
  30. pad_mode=pad_mode, stride=stride).shard(strategy1)
  31. self.conv2d_weight = Parameter(conv2d_weight, "w1")
  32. self.resize_bilinear = P.ResizeBilinear((16, 16)).shard(strategy2)
  33. def construct(self, x):
  34. out = self.conv2d(x, self.conv2d_weight)
  35. out = self.resize_bilinear(out)
  36. return out
  37. class Net2(Cell):
  38. '''
  39. create the test Net
  40. '''
  41. def __init__(self, conv2d_weight, mul_weight, out_channel, kernel_size, pad_mode, stride, align_corners=False,
  42. strategy1=None, strategy2=None, out_strategy=None):
  43. super(Net2, self).__init__()
  44. self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
  45. pad_mode=pad_mode, stride=stride).shard(strategy1)
  46. self.conv2d_weight = Parameter(conv2d_weight, "w1")
  47. self.resize_neighbor = P.ResizeNearestNeighbor((16, 16), align_corners).shard(strategy2, out_strategy)
  48. self.mul = P.Mul()
  49. self.mul_weight = Parameter(mul_weight, "w2")
  50. def construct(self, x):
  51. out = self.conv2d(x, self.conv2d_weight)
  52. out = self.resize_neighbor(out)
  53. out = self.mul(out, self.mul_weight)
  54. return out
  55. class Net3(Cell):
  56. '''
  57. create the test Net
  58. '''
  59. def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
  60. strategy1=None):
  61. super(Net3, self).__init__()
  62. self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
  63. pad_mode=pad_mode, stride=stride).shard(strategy1)
  64. self.conv2d_weight = Parameter(conv2d_weight, "w1")
  65. self.resize_bilinear = P.ResizeBilinear((16, 16))
  66. def construct(self, x):
  67. out = self.conv2d(x, self.conv2d_weight)
  68. out = self.resize_bilinear(out)
  69. return out
  70. _x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
  71. _w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
  72. _w2 = Tensor(np.ones([32, 8, 16, 16]), dtype=ms.float32)
  73. def compile_net(net, inputs=_x):
  74. optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  75. train_net = TrainOneStepCell(net, optimizer)
  76. train_net.set_auto_parallel()
  77. train_net.set_train()
  78. _cell_graph_executor.compile(train_net, inputs)
  79. context.reset_auto_parallel_context()
  80. def test_bililear_data_parallel():
  81. """
  82. Feature: test ResizeBilinear data parallel strategy
  83. Description: only shard batch dimension
  84. Expectation: compile success
  85. """
  86. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  87. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  88. strategy2 = ((8, 1, 1, 1),)
  89. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
  90. strategy1=strategy1, strategy2=strategy2)
  91. compile_net(net)
  92. def test_bilinear_model_parallel1():
  93. """
  94. Feature: test ResizeBilinear model parallel strategy
  95. Description: shard N/C
  96. Expectation: compile success
  97. """
  98. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  99. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  100. strategy2 = ((4, 2, 1, 1),)
  101. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
  102. strategy1=strategy1, strategy2=strategy2)
  103. compile_net(net)
  104. def test_bilinear_repeated_calc():
  105. """
  106. Feature: test ResizeBilinear repeated calculation parallel strategy
  107. Description: only shard batch dimension, but shard num smaller than device num
  108. Expectation: compile success
  109. """
  110. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  111. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  112. strategy2 = ((2, 1, 1, 1),)
  113. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
  114. strategy1=strategy1, strategy2=strategy2)
  115. compile_net(net)
  116. def test_bilinear_auto_parallel():
  117. """
  118. Feature: test ResizeBilinear auto parallel
  119. Description:
  120. Expectation: compile success
  121. """
  122. context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
  123. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
  124. compile_net(net)
  125. def test_bilinear_no_strategy():
  126. """
  127. Feature: test ResizeBilinear semi auto parallel, and has not set strategy for it
  128. Description:
  129. Expectation: compile success
  130. """
  131. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  132. net = Net3(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
  133. compile_net(net)
  134. def test_neighbor_data_parallel():
  135. """
  136. Feature: test ResizeNearestNeighbor data parallel strategy
  137. Description: only shard batch dimension
  138. Expectation: compile success
  139. """
  140. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  141. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  142. strategy2 = ((8, 1, 1, 1),)
  143. net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
  144. strategy1=strategy1, strategy2=strategy2)
  145. compile_net(net)
  146. def test_neighbor_model_parallel_align_corners_shard_HW():
  147. """
  148. Feature: test ResizeNearestNeighbor model parallel strategy
  149. Description: the align_corners is True, and shard N/C/H/W
  150. Expectation: compile failed
  151. """
  152. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
  153. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  154. strategy2 = ((2, 2, 2, 2),)
  155. net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1, align_corners=True,
  156. strategy1=strategy1, strategy2=strategy2)
  157. with pytest.raises(RuntimeError):
  158. compile_net(net)
  159. def test_neighbor_out_strategy():
  160. """
  161. Feature: test ResizeNearestNeighbor to set output parallel strategy
  162. Description:
  163. Expectation: compile failed
  164. """
  165. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
  166. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  167. strategy2 = ((2, 2, 2, 2),)
  168. out_strategy = ((2, 2, 2, 2),)
  169. net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
  170. strategy1=strategy1, strategy2=strategy2, out_strategy=out_strategy)
  171. with pytest.raises(RuntimeError):
  172. compile_net(net)
  173. def test_neighbor_model_parallel_align_corners_shard_NC():
  174. """
  175. Feature: test ResizeNearestNeighbor model parallel strategy
  176. Description: the align_corners is True, and shard N/C
  177. Expectation: compile success
  178. """
  179. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
  180. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  181. strategy2 = ((4, 4, 1, 1),)
  182. net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1, align_corners=True,
  183. strategy1=strategy1, strategy2=strategy2)
  184. compile_net(net)
  185. def test_neighbor_model_parallel_align_corners_is_false():
  186. """
  187. Feature: test ResizeNearestNeighbor model parallel strategy
  188. Description: the align_corners is False, and shard N/C/H/W
  189. Expectation: compile success
  190. """
  191. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
  192. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  193. strategy2 = ((2, 2, 2, 2),)
  194. net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
  195. strategy1=strategy1, strategy2=strategy2)
  196. compile_net(net)
  197. def test_neighbor_auto_parallel():
  198. """
  199. Feature: test ResizeNearestNeighbor auto parallel
  200. Description:
  201. Expectation: compile success
  202. """
  203. context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
  204. net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
  205. compile_net(net)