You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cell_shard_check.py 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore as ms
  18. from mindspore import nn, context, Tensor
  19. import mindspore.ops as ops
  20. def set_context():
  21. context.set_context(mode=context.PYNATIVE_MODE)
  22. context.reset_auto_parallel_context()
  23. context.set_auto_parallel_context(device_num=8, parallel_mode="auto_parallel", search_mode="sharding_propagation")
  24. class NetMul(nn.Cell):
  25. def __init__(self):
  26. super().__init__()
  27. self.mul = ops.Mul()
  28. def construct(self, x, y):
  29. return self.mul(x, y)
  30. class NetMatMul(nn.Cell):
  31. def __init__(self):
  32. super().__init__()
  33. self.matmul = ops.MatMul()
  34. def construct(self, x, y):
  35. return self.matmul(x, y)
  36. class Net(nn.Cell):
  37. def __init__(self, in_axes, out_axes):
  38. super().__init__()
  39. self.mul_net = NetMul()
  40. self.matmul_net = NetMatMul()
  41. self.mul_net.shard(in_axes=in_axes, out_axes=out_axes)
  42. def construct(self, x, y):
  43. out1 = self.matmul_net(x, y)
  44. out2 = self.matmul_net(x, y)
  45. return self.mul_net(out1, out2)
  46. def cell_shard_execution(in_axes, out_axes, error_log):
  47. net = Net(in_axes, out_axes)
  48. x = Tensor(np.ones([128, 128]), dtype=ms.float32)
  49. y = Tensor(np.ones([128, 128]), dtype=ms.float32)
  50. with pytest.raises(Exception) as err:
  51. _ = net(x, y)
  52. assert error_log in str(err.value)
  53. def test_in_axes_numbers_check():
  54. """
  55. Feature: shard function for cell
  56. Description: inconsistent input number and in_axes number
  57. Expectation: throw an exception indicating inconsistent input number and in_axes number
  58. """
  59. set_context()
  60. in_axes = ((8, 1), None, (1, 8))
  61. out_axes = (None,)
  62. error_log = "Input numbers: 2 is not equal to in_axes numbers: 3"
  63. cell_shard_execution(in_axes, out_axes, error_log)
  64. def test_out_axes_numbers_check():
  65. """
  66. Feature: shard function for cell
  67. Description: inconsistent output number and out_axes number
  68. Expectation: throw an exception indicating inconsistent output number and out_axes number
  69. """
  70. set_context()
  71. in_axes = ((8, 1), None)
  72. out_axes = (None, (8, 1))
  73. error_log = "Output number: 1 is not equal to out_axes number: 2"
  74. cell_shard_execution(in_axes, out_axes, error_log)
  75. def test_in_axes_dimension_check():
  76. """
  77. Feature: shard function for cell
  78. Description: inconsistent input dimension and in_axes dimension
  79. Expectation: throw an exception indicating inconsistent input_dimension and in_axes dimension
  80. """
  81. set_context()
  82. in_axes = ((8, 1, 1), None)
  83. out_axes = (None, (8, 1))
  84. error_log = "Input dimension: 2 is not equal to in_axes dimension: 3 at index 0"
  85. cell_shard_execution(in_axes, out_axes, error_log)
  86. def test_out_axes_dimension_check():
  87. """
  88. Feature: shard function for cell
  89. Description: inconsistent output dimension and out_axes dimension
  90. Expectation: throw an exception indicating inconsistent output_dimension and out_axes dimension
  91. """
  92. set_context()
  93. in_axes = ((8, 1), None)
  94. out_axes = ((8,),)
  95. error_log = "Output dimension: 2 is not equal to out_axes dimension: 1 at index 0"
  96. cell_shard_execution(in_axes, out_axes, error_log)
  97. def test_in_axes_format_check():
  98. """
  99. Feature: shard function for cell
  100. Description: unsupported in_axes format
  101. Expectation: throw an exception indicating an supported in_axes format
  102. """
  103. set_context()
  104. in_axes = ([8, 1], None)
  105. out_axes = (None,)
  106. error_log = "in_axes should be a two-dimension tuple"
  107. cell_shard_execution(in_axes, out_axes, error_log)