|
|
|
@@ -141,7 +141,18 @@ def test_init_abnormal(): |
|
|
|
with py.raises(TypeError): |
|
|
|
init.initializer([''], [5, 4], ms.float32) |
|
|
|
|
|
|
|
|
|
|
|
def test_initializer_reinit(): |
|
|
|
weights = init.initializer("XavierUniform", shape=(10, 1, 10, 10), dtype=ms.float16) |
|
|
|
assert weights.dtype == ms.float16 |
|
|
|
assert weights.shape == (10, 1, 10, 10) |
|
|
|
weights = init.initializer(weights) |
|
|
|
assert weights.dtype == ms.float16 |
|
|
|
assert weights.shape == (10, 1, 10, 10) |
|
|
|
weights.shape = None |
|
|
|
weights = init.initializer(weights, (10, 1)) |
|
|
|
assert weights.dtype == ms.float16 |
|
|
|
assert weights.shape == (10, 1) |
|
|
|
|
|
|
|
def test_init_xavier_uniform(): |
|
|
|
""" test_init_xavier_uniform """ |
|
|
|
gain = 1.2 |
|
|
|
|