| @@ -23,7 +23,7 @@ from mindspore.common.initializer import initializer | |||||
| from hccl_test.manage.api import Hccl | from hccl_test.manage.api import Hccl | ||||
| def test_initializer_weight_slice(): | |||||
| def check_initializer_weight_slice(init_name="Uniform"): | |||||
| class Net(nn.Cell): | class Net(nn.Cell): | ||||
| def __init__(self, strategy1, strategy2, weight): | def __init__(self, strategy1, strategy2, weight): | ||||
| super().__init__() | super().__init__() | ||||
| @@ -49,7 +49,7 @@ def test_initializer_weight_slice(): | |||||
| exe = me._executor | exe = me._executor | ||||
| x = Tensor(np.ones([32, 32]), dtype=ms.float32) | x = Tensor(np.ones([32, 32]), dtype=ms.float32) | ||||
| weight = initializer("Uniform", [64, 32], ms.float32) | |||||
| weight = initializer(init_name, [64, 32], ms.float32) | |||||
| net = Net(strategy1, strategy2, weight) | net = Net(strategy1, strategy2, weight) | ||||
| net.set_auto_parallel() | net.set_auto_parallel() | ||||
| exe.compile(net, x, auto_parallel_mode=True, phase='train') | exe.compile(net, x, auto_parallel_mode=True, phase='train') | ||||
| @@ -68,8 +68,14 @@ def test_initializer_weight_slice(): | |||||
| assert expect_slice_shape == slice_shape | assert expect_slice_shape == slice_shape | ||||
| assert all(slice0 == slice4) | assert all(slice0 == slice4) | ||||
| assert any(slice0 != slice1) | |||||
| if init_name not in ["One", "Zero"]: | |||||
| assert any(slice0 != slice1) | |||||
| initializers = ["Uniform", "Normal", "TruncatedNormal", "HeUniform", "HeNormal", "XavierUniform", "One", "Zero"] | |||||
| def test_initializer_weight_slice(): | |||||
| for init_name in initializers: | |||||
| check_initializer_weight_slice(init_name) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_initializer_weight_slice() | test_initializer_weight_slice() | ||||