You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_lerp.py 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. import numpy as np
  2. import pytest
  3. from mindspore import Tensor, context
  4. from mindspore.nn import Cell
  5. import mindspore.ops as ops
  6. from parallel.utils.utils import compile_net
  7. input_start_ = Tensor(np.random.normal(size=[8, 8, 8]).astype(np.float32))
  8. input_end_ = Tensor(np.random.normal(size=[8]).astype(np.float32))
  9. input_weight_tensor_ = Tensor(np.random.normal(size=[8, 8]).astype(np.float32))
  10. input_weight_float_ = 0.5
  11. class Net(Cell):
  12. def __init__(self, strategy=None):
  13. super(Net, self).__init__()
  14. self.lerp = ops.Lerp().shard(strategy)
  15. def construct(self, *inputs):
  16. output = self.lerp(*inputs)
  17. return output
  18. def test_lerp_auto_parallel_with_weight_tensor():
  19. """
  20. Feature: test Lerp auto parallel
  21. Description: auto parallel when 'weight' is tensor
  22. Expectation: compile success
  23. """
  24. context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
  25. net = Net()
  26. compile_net(net, input_start_, input_end_, input_weight_tensor_)
  27. def test_lerp_auto_parallel_with_weight_float():
  28. """
  29. Feature: test Lerp auto parallel
  30. Description: auto parallel when 'weight' is float
  31. Expectation: compile success
  32. """
  33. context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
  34. net = Net()
  35. compile_net(net, input_start_, input_end_, input_weight_float_)
  36. def test_lerp_model_parallel_with_weight_tensor():
  37. """
  38. Feature: test Lerp model parallel
  39. Description: model parallel when 'weight' is tensor
  40. Expectation: compile success
  41. """
  42. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  43. strategy = ((2, 2, 2), (2,), (2, 2))
  44. net = Net(strategy)
  45. compile_net(net, input_start_, input_end_, input_weight_tensor_)
  46. def test_lerp_model_parallel_with_weight_float():
  47. """
  48. Feature: test Lerp model parallel
  49. Description: model parallel when 'weight' is float
  50. Expectation: compile success
  51. """
  52. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  53. strategy = ((2, 2, 2), (2,))
  54. net = Net(strategy)
  55. compile_net(net, input_start_, input_end_, input_weight_float_)
  56. def test_lerp_model_parallel_repeated_cal_with_weight_tensor():
  57. """
  58. Feature: test Lerp model parallel with repeated calculation
  59. Description: model parallel when 'weight' is tensor
  60. Expectation: compile success
  61. """
  62. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  63. strategy = ((1, 2, 2), (2,), (2, 2))
  64. net = Net(strategy)
  65. compile_net(net, input_start_, input_end_, input_weight_tensor_)
  66. def test_lerp_model_parallel_repeated_cal_with_weight_float():
  67. """
  68. Feature: test Lerp model parallel with repeated calculation
  69. Description: model parallel when 'weight' is float
  70. Expectation: compile success
  71. """
  72. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  73. strategy = ((1, 2, 2), (2,))
  74. net = Net(strategy)
  75. compile_net(net, input_start_, input_end_, input_weight_float_)
  76. def test_lerp_data_parallel_with_weight_tensor():
  77. """
  78. Feature: test Lerp data parallel
  79. Description: data parallel when 'weight' is tensor
  80. Expectation: compile success
  81. """
  82. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  83. strategy = ((8, 1, 1), (1,), (1, 1))
  84. net = Net(strategy)
  85. compile_net(net, input_start_, input_end_, input_weight_tensor_)
  86. def test_lerp_data_parallel_with_weight_float():
  87. """
  88. Feature: test Lerp data parallel
  89. Description: data parallel when 'weight' is float
  90. Expectation: compile success
  91. """
  92. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  93. strategy = ((8, 1, 1), (1,))
  94. net = Net(strategy)
  95. compile_net(net, input_start_, input_end_, input_weight_float_)
  96. def test_lerp_strategy_error_with_weight_tensor():
  97. """
  98. Feature: test invalid strategy for Lerp
  99. Description: illegal strategy when 'weight' is tensor
  100. Expectation: raise RuntimeError
  101. """
  102. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  103. strategy = ((4, 2, 1), (1,), (1, 2))
  104. net = Net(strategy)
  105. with pytest.raises(RuntimeError):
  106. compile_net(net, input_start_, input_end_, input_weight_tensor_)
  107. def test_lerp_strategy_error_with_weight_float():
  108. """
  109. Feature: test invalid strategy for Lerp
  110. Description: illegal strategy when 'weight' is float
  111. Expectation: raise RuntimeError
  112. """
  113. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  114. strategy = ((4, 1, 2), (1,))
  115. net = Net(strategy)
  116. with pytest.raises(RuntimeError):
  117. compile_net(net, input_start_, input_end_, input_weight_float_)