|
|
|
@@ -134,3 +134,19 @@ def test_check_str_by_regular(): |
|
|
|
_check_str_by_regular(str5) |
|
|
|
with pytest.raises(ValueError): |
|
|
|
_check_str_by_regular(str6) |
|
|
|
|
|
|
|
def test_parameter_lazy_init(): |
|
|
|
# Call init_data() without set default_input. |
|
|
|
para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test1') |
|
|
|
assert not isinstance(para.default_input, Tensor) |
|
|
|
para.init_data() |
|
|
|
assert isinstance(para.default_input, Tensor) |
|
|
|
assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3))) |
|
|
|
|
|
|
|
# Call init_data() after default_input is set. |
|
|
|
para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test2') |
|
|
|
assert not isinstance(para.default_input, Tensor) |
|
|
|
para.default_input = Tensor(np.zeros((1, 2, 3))) |
|
|
|
assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2, 3))) |
|
|
|
para.init_data() # expect no effect. |
|
|
|
assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2, 3))) |