You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_splitv.py 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore as ms
  18. import mindspore.context as context
  19. from mindspore import Tensor, Parameter
  20. import mindspore.nn as nn
  21. from mindspore.ops import operations as P
  22. from parallel.utils.utils import compile_net
# Shared fixtures for the SplitV parallel test cases below.
input_x_tensor_ = Tensor(np.ones([8, 8, 8]), ms.float32)  # plain tensor input
input_x_parameter_ = Parameter(Tensor(np.ones([8, 8, 8]), ms.float32), "input_x")  # parameter input
# SplitV configuration: split one length-8 axis into 3 pieces of sizes 3, 3 and 2.
SIZE_SPLIT = [3, 3, 2]
NUM_SPLIT = 3
  27. class Net(nn.Cell):
  28. def __init__(self, size_split, split_dim, num_split, strategy1=None, strategy2=None):
  29. super(Net, self).__init__()
  30. self.splitv = P.SplitV(size_split, split_dim, num_split).shard(strategy1)
  31. self.mul = P.Mul().shard(strategy2)
  32. self.weight = Parameter(np.array([1.0]), name="mul_weight")
  33. def construct(self, x):
  34. out = self.splitv(x)
  35. out = self.mul(out[0], self.weight)
  36. return out
  37. class NetWithParameter(nn.Cell):
  38. def __init__(self, size_split, split_dim, num_split, strategy1=None, strategy2=None):
  39. super(NetWithParameter, self).__init__()
  40. self.splitv = P.SplitV(size_split, split_dim, num_split).shard(strategy1)
  41. self.mul = P.Mul().shard(strategy2)
  42. self.weight = input_x_parameter_
  43. def construct(self, x):
  44. out = self.splitv(self.weight)
  45. out = self.mul(x, out[0])
  46. return out
  47. def test_splitv_auto_parallel():
  48. """
  49. Feature: test SplitV auto parallel
  50. Description: auto parallel
  51. Expectation: compile success
  52. """
  53. context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
  54. net = Net(SIZE_SPLIT, 0, NUM_SPLIT)
  55. compile_net(net, input_x_tensor_)
  56. def test_splitv_auto_parallel_with_parameter():
  57. """
  58. Feature: test SplitV auto parallel with parameter input
  59. Description: auto parallel with parameter input
  60. Expectation: compile success
  61. """
  62. context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
  63. net = NetWithParameter(SIZE_SPLIT, 2, NUM_SPLIT)
  64. x = Tensor(np.ones([8, 8, 3]), ms.float32)
  65. compile_net(net, x)
  66. def test_splitv_data_parallel():
  67. """
  68. Feature: test SplitV data parallel
  69. Description: data parallel
  70. Expectation: compile success
  71. """
  72. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  73. net = Net(SIZE_SPLIT, 1, NUM_SPLIT)
  74. compile_net(net, input_x_tensor_)
  75. def test_splitv_model_parallel():
  76. """
  77. Feature: test SplitV model parallel
  78. Description: model parallel
  79. Expectation: compile success
  80. """
  81. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  82. strategy1 = ((1, 2, 2),)
  83. strategy2 = ((1, 2, 2), (1,))
  84. net = Net(SIZE_SPLIT, 0, NUM_SPLIT, strategy1, strategy2)
  85. compile_net(net, input_x_tensor_)
  86. def test_splitv_strategy_error():
  87. """
  88. Feature: test SplitV parallel with invalid strategy
  89. Description: config invalid strategy
  90. Expectation: raise RuntimeError
  91. """
  92. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  93. strategy1 = ((2, 2, 2),)
  94. strategy2 = ((8, 1, 1),)
  95. net = Net(SIZE_SPLIT, 0, NUM_SPLIT, strategy1, strategy2)
  96. with pytest.raises(RuntimeError):
  97. compile_net(net, input_x_tensor_)
  98. context.reset_auto_parallel_context()