From 0f7ead5f1433f4a71266c1ccadb689a6c6e23347 Mon Sep 17 00:00:00 2001 From: Yi Huaijie Date: Tue, 18 Aug 2020 09:46:42 +0800 Subject: [PATCH] parameter slice init test all initializers --- .../python/parallel/test_initializer_weight_slice.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/ut/python/parallel/test_initializer_weight_slice.py b/tests/ut/python/parallel/test_initializer_weight_slice.py index 7f63c5a650..2d78e1d4a0 100644 --- a/tests/ut/python/parallel/test_initializer_weight_slice.py +++ b/tests/ut/python/parallel/test_initializer_weight_slice.py @@ -23,7 +23,7 @@ from mindspore.common.initializer import initializer from hccl_test.manage.api import Hccl -def test_initializer_weight_slice(): +def check_initializer_weight_slice(init_name="Uniform"): class Net(nn.Cell): def __init__(self, strategy1, strategy2, weight): super().__init__() @@ -49,7 +49,7 @@ def test_initializer_weight_slice(): exe = me._executor x = Tensor(np.ones([32, 32]), dtype=ms.float32) - weight = initializer("Uniform", [64, 32], ms.float32) + weight = initializer(init_name, [64, 32], ms.float32) net = Net(strategy1, strategy2, weight) net.set_auto_parallel() exe.compile(net, x, auto_parallel_mode=True, phase='train') @@ -68,8 +68,14 @@ def test_initializer_weight_slice(): assert expect_slice_shape == slice_shape assert all(slice0 == slice4) - assert any(slice0 != slice1) + if init_name not in ["One", "Zero"]: + assert any(slice0 != slice1) + +initializers = ["Uniform", "Normal", "TruncatedNormal", "HeUniform", "HeNormal", "XavierUniform", "One", "Zero"] +def test_initializer_weight_slice(): + for init_name in initializers: + check_initializer_weight_slice(init_name) if __name__ == '__main__': test_initializer_weight_slice()