Browse Source

!1853 Fix initializer

Merge pull request !1853 from amongo/FixInitializer
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
bd34c6ec8b
2 changed files with 21 additions and 7 deletions
  1. +9
    -6
      mindspore/common/initializer.py
  2. +12
    -1
      tests/ut/python/utils/test_initializer.py

+ 9
- 6
mindspore/common/initializer.py View File

@@ -338,12 +338,6 @@ def initializer(init, shape=None, dtype=mstype.float32):
"the variable shape {}.".format(list(init.shape()), shape))
return init

if isinstance(init, str):
init_obj = _INITIALIZER_ALIAS[init.lower()]()
if init_obj is None:
raise ValueError("The class corresponding to '{}' was not found.".format(init))
init = init_obj

if isinstance(shape, list):
shape = tuple(shape)
elif isinstance(shape, numbers.Number):
@@ -354,6 +348,15 @@ def initializer(init, shape=None, dtype=mstype.float32):
raise ValueError("Error shape={}".format(shape))

if isinstance(init, Initializer):
init.shape = init.shape if init.shape is not None else shape
init.dtype = init.dtype if init.dtype is not None else dtype
return init

if isinstance(init, str):
init_obj = _INITIALIZER_ALIAS[init.lower()]()
if init_obj is None:
raise ValueError("The class corresponding to '{}' was not found.".format(init))
init = init_obj
init.shape = shape
init.dtype = dtype
return init


+ 12
- 1
tests/ut/python/utils/test_initializer.py View File

@@ -141,7 +141,18 @@ def test_init_abnormal():
with py.raises(TypeError):
init.initializer([''], [5, 4], ms.float32)


def test_initializer_reinit():
weights = init.initializer("XavierUniform", shape=(10, 1, 10, 10), dtype=ms.float16)
assert weights.dtype == ms.float16
assert weights.shape == (10, 1, 10, 10)
weights = init.initializer(weights)
assert weights.dtype == ms.float16
assert weights.shape == (10, 1, 10, 10)
weights.shape = None
weights = init.initializer(weights, (10, 1))
assert weights.dtype == ms.float16
assert weights.shape == (10, 1)
def test_init_xavier_uniform():
""" test_init_xavier_uniform """
gain = 1.2


Loading…
Cancel
Save