Browse Source

!4643 add test cases for parameter slice init using all kind of initializers

Merge pull request !4643 from yihuaijie/dev
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
025e6f23f7
1 changed files with 9 additions and 3 deletions
  1. +9
    -3
      tests/ut/python/parallel/test_initializer_weight_slice.py

+ 9
- 3
tests/ut/python/parallel/test_initializer_weight_slice.py View File

@@ -23,7 +23,7 @@ from mindspore.common.initializer import initializer
from hccl_test.manage.api import Hccl


def test_initializer_weight_slice():
def check_initializer_weight_slice(init_name="Uniform"):
class Net(nn.Cell):
def __init__(self, strategy1, strategy2, weight):
super().__init__()
@@ -49,7 +49,7 @@ def test_initializer_weight_slice():
exe = me._executor

x = Tensor(np.ones([32, 32]), dtype=ms.float32)
weight = initializer("Uniform", [64, 32], ms.float32)
weight = initializer(init_name, [64, 32], ms.float32)
net = Net(strategy1, strategy2, weight)
net.set_auto_parallel()
exe.compile(net, x, auto_parallel_mode=True, phase='train')
@@ -68,8 +68,14 @@ def test_initializer_weight_slice():

assert expect_slice_shape == slice_shape
assert all(slice0 == slice4)
assert any(slice0 != slice1)
if init_name not in ["One", "Zero"]:
assert any(slice0 != slice1)

initializers = ["Uniform", "Normal", "TruncatedNormal", "HeUniform", "HeNormal", "XavierUniform", "One", "Zero"]

def test_initializer_weight_slice():
for init_name in initializers:
check_initializer_weight_slice(init_name)

if __name__ == '__main__':
test_initializer_weight_slice()

Loading…
Cancel
Save