Browse Source

!3525 Fix a bug for Parameter

Merge pull request !3525 from hewei/fix_parameter_bug
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
5f7d2ba396
2 changed files with 18 additions and 2 deletions
  1. +2
    -2
      mindspore/common/parameter.py
  2. +16
    -0
      tests/ut/python/nn/test_parameter.py

+ 2
- 2
mindspore/common/parameter.py View File

@@ -210,7 +210,6 @@ class Parameter:

def set_parameter_data(self, data):
"""Set `default_input` of current `Parameter`."""
self.init_mode = None
if isinstance(data, bool):
raise ValueError('Parameter data can not be `bool`')
if isinstance(data, Tensor):
@@ -243,7 +242,8 @@ class Parameter:
set_sliced (bool): True if should set parameter sliced after init the data of initializer.
Default: False.
"""
if self.init_mode is None:
if isinstance(self.default_input, Tensor):
# skip if data already initialized.
return
if layout is not None:
if not isinstance(layout, list):


+ 16
- 0
tests/ut/python/nn/test_parameter.py View File

@@ -134,3 +134,19 @@ def test_check_str_by_regular():
_check_str_by_regular(str5)
with pytest.raises(ValueError):
_check_str_by_regular(str6)

def test_parameter_lazy_init():
# Call init_data() without set default_input.
para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test1')
assert not isinstance(para.default_input, Tensor)
para.init_data()
assert isinstance(para.default_input, Tensor)
assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3)))

# Call init_data() after default_input is set.
para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test2')
assert not isinstance(para.default_input, Tensor)
para.default_input = Tensor(np.zeros((1, 2, 3)))
assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2, 3)))
para.init_data() # expect no effect.
assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2, 3)))

Loading…
Cancel
Save