|
- import numpy as np
- import pytest
-
- from mindspore import Tensor, context
- from mindspore.nn import Cell
- import mindspore.ops as ops
-
- from parallel.utils.utils import compile_net
-
# Operands shared by every Lerp test case below; all tensors are float32 and
# filled with samples from numpy's normal distribution.
# 'start' operand: rank-3 tensor of shape (8, 8, 8).
input_start_ = Tensor(np.random.normal(size=(8, 8, 8)).astype(np.float32))
# 'end' operand: rank-1 tensor of shape (8,).
input_end_ = Tensor(np.random.normal(size=(8,)).astype(np.float32))
# 'weight' operand in its tensor form: rank-2 tensor of shape (8, 8).
input_weight_tensor_ = Tensor(np.random.normal(size=(8, 8)).astype(np.float32))
# 'weight' operand in its scalar form: a plain Python float.
input_weight_float_ = 0.5
-
-
class Net(Cell):
    """Cell holding a single Lerp primitive with an optional shard strategy.

    Args:
        strategy: value forwarded unchanged to ``ops.Lerp().shard``; defaults
            to ``None``.
    """

    def __init__(self, strategy=None):
        super().__init__()
        lerp_op = ops.Lerp()
        self.lerp = lerp_op.shard(strategy)

    def construct(self, *inputs):
        """Forward all positional inputs to the Lerp primitive and return its result."""
        return self.lerp(*inputs)
-
-
def test_lerp_auto_parallel_with_weight_tensor():
    """
    Feature: test Lerp auto parallel
    Description: compile Lerp under auto parallel mode with a tensor 'weight'
    Expectation: compile success
    """
    context.set_auto_parallel_context(
        parallel_mode="auto_parallel",
        device_num=8,
        global_rank=0,
    )
    # No shard strategy is passed, so Net falls back to its default of None.
    lerp_net = Net()
    operands = (input_start_, input_end_, input_weight_tensor_)
    compile_net(lerp_net, *operands)
-
-
def test_lerp_auto_parallel_with_weight_float():
    """
    Feature: test Lerp auto parallel
    Description: compile Lerp under auto parallel mode with a float 'weight'
    Expectation: compile success
    """
    context.set_auto_parallel_context(
        parallel_mode="auto_parallel",
        device_num=8,
        global_rank=0,
    )
    # No shard strategy is passed, so Net falls back to its default of None.
    lerp_net = Net()
    operands = (input_start_, input_end_, input_weight_float_)
    compile_net(lerp_net, *operands)
-
-
def test_lerp_model_parallel_with_weight_tensor():
    """
    Feature: test Lerp model parallel
    Description: compile Lerp with an explicit shard strategy and a tensor 'weight'
    Expectation: compile success
    """
    # One strategy tuple per operand, in the order (start, end, weight).
    shard_strategy = ((2, 2, 2), (2,), (2, 2))
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel",
        device_num=8,
        global_rank=0,
    )
    lerp_net = Net(shard_strategy)
    compile_net(lerp_net, input_start_, input_end_, input_weight_tensor_)
-
-
def test_lerp_model_parallel_with_weight_float():
    """
    Feature: test Lerp model parallel
    Description: compile Lerp with an explicit shard strategy and a float 'weight'
    Expectation: compile success
    """
    # Only the two tensor operands (start, end) get a strategy entry here.
    shard_strategy = ((2, 2, 2), (2,))
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel",
        device_num=8,
        global_rank=0,
    )
    lerp_net = Net(shard_strategy)
    compile_net(lerp_net, input_start_, input_end_, input_weight_float_)
-
-
def test_lerp_model_parallel_repeated_cal_with_weight_tensor():
    """
    Feature: test Lerp model parallel with repeated calculation
    Description: compile Lerp with a (1, 2, 2) 'start' strategy and a tensor 'weight'
    Expectation: compile success
    """
    parallel_config = {
        "parallel_mode": "semi_auto_parallel",
        "device_num": 8,
        "global_rank": 0,
    }
    context.set_auto_parallel_context(**parallel_config)
    # One strategy tuple per operand, in the order (start, end, weight).
    shard_strategy = ((1, 2, 2), (2,), (2, 2))
    lerp_net = Net(shard_strategy)
    compile_net(lerp_net, input_start_, input_end_, input_weight_tensor_)
-
-
def test_lerp_model_parallel_repeated_cal_with_weight_float():
    """
    Feature: test Lerp model parallel with repeated calculation
    Description: compile Lerp with a (1, 2, 2) 'start' strategy and a float 'weight'
    Expectation: compile success
    """
    parallel_config = {
        "parallel_mode": "semi_auto_parallel",
        "device_num": 8,
        "global_rank": 0,
    }
    context.set_auto_parallel_context(**parallel_config)
    # Only the two tensor operands (start, end) get a strategy entry here.
    shard_strategy = ((1, 2, 2), (2,))
    lerp_net = Net(shard_strategy)
    compile_net(lerp_net, input_start_, input_end_, input_weight_float_)
-
-
def test_lerp_data_parallel_with_weight_tensor():
    """
    Feature: test Lerp data parallel
    Description: compile Lerp with only the first 'start' dimension split, tensor 'weight'
    Expectation: compile success
    """
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel",
        device_num=8,
        global_rank=0,
    )
    # Strategy entries are ordered (start, end, weight).
    start_split, end_split, weight_split = (8, 1, 1), (1,), (1, 1)
    lerp_net = Net((start_split, end_split, weight_split))
    compile_net(lerp_net, input_start_, input_end_, input_weight_tensor_)
-
-
def test_lerp_data_parallel_with_weight_float():
    """
    Feature: test Lerp data parallel
    Description: compile Lerp with only the first 'start' dimension split, float 'weight'
    Expectation: compile success
    """
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel",
        device_num=8,
        global_rank=0,
    )
    # Strategy entries are ordered (start, end); the float weight has none.
    start_split, end_split = (8, 1, 1), (1,)
    lerp_net = Net((start_split, end_split))
    compile_net(lerp_net, input_start_, input_end_, input_weight_float_)
-
-
def test_lerp_strategy_error_with_weight_tensor():
    """
    Feature: test invalid strategy for Lerp
    Description: compile Lerp with an illegal shard strategy and a tensor 'weight'
    Expectation: raise RuntimeError
    """
    illegal_strategy = ((4, 2, 1), (1,), (1, 2))
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel",
        device_num=8,
        global_rank=0,
    )
    # The network is built outside the raises-block so that only the compile
    # step is checked for the RuntimeError.
    lerp_net = Net(illegal_strategy)
    operands = (input_start_, input_end_, input_weight_tensor_)
    with pytest.raises(RuntimeError):
        compile_net(lerp_net, *operands)
-
-
def test_lerp_strategy_error_with_weight_float():
    """
    Feature: test invalid strategy for Lerp
    Description: compile Lerp with an illegal shard strategy and a float 'weight'
    Expectation: raise RuntimeError
    """
    illegal_strategy = ((4, 1, 2), (1,))
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel",
        device_num=8,
        global_rank=0,
    )
    # The network is built outside the raises-block so that only the compile
    # step is checked for the RuntimeError.
    lerp_net = Net(illegal_strategy)
    operands = (input_start_, input_end_, input_weight_float_)
    with pytest.raises(RuntimeError):
        compile_net(lerp_net, *operands)
|