|
|
|
@@ -23,7 +23,7 @@ from mindspore.common.initializer import initializer |
|
|
|
from hccl_test.manage.api import Hccl |
|
|
|
|
|
|
|
|
|
|
|
def test_initializer_weight_slice(): |
|
|
|
def check_initializer_weight_slice(init_name="Uniform"): |
|
|
|
class Net(nn.Cell): |
|
|
|
def __init__(self, strategy1, strategy2, weight): |
|
|
|
super().__init__() |
|
|
|
@@ -49,7 +49,7 @@ def test_initializer_weight_slice(): |
|
|
|
exe = me._executor |
|
|
|
|
|
|
|
x = Tensor(np.ones([32, 32]), dtype=ms.float32) |
|
|
|
weight = initializer("Uniform", [64, 32], ms.float32) |
|
|
|
weight = initializer(init_name, [64, 32], ms.float32) |
|
|
|
net = Net(strategy1, strategy2, weight) |
|
|
|
net.set_auto_parallel() |
|
|
|
exe.compile(net, x, auto_parallel_mode=True, phase='train') |
|
|
|
@@ -68,8 +68,14 @@ def test_initializer_weight_slice(): |
|
|
|
|
|
|
|
assert expect_slice_shape == slice_shape |
|
|
|
assert all(slice0 == slice4) |
|
|
|
assert any(slice0 != slice1) |
|
|
|
if init_name not in ["One", "Zero"]: |
|
|
|
assert any(slice0 != slice1) |
|
|
|
|
|
|
|
initializers = ["Uniform", "Normal", "TruncatedNormal", "HeUniform", "HeNormal", "XavierUniform", "One", "Zero"] |
|
|
|
|
|
|
|
def test_initializer_weight_slice(): |
|
|
|
for init_name in initializers: |
|
|
|
check_initializer_weight_slice(init_name) |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
test_initializer_weight_slice() |