Browse Source

fix bug in checkpoint when save scaler

tags/v0.2.0-alpha
wangnan39@huawei.com 5 years ago
parent
commit
f38d18c665
3 changed files with 62 additions and 4 deletions
  1. +1
    -1
      mindspore/nn/wrap/cell_wrapper.py
  2. +4
    -3
      mindspore/train/serialization.py
  3. +57
    -0
      tests/ut/python/nn/test_parameter.py

+ 1
- 1
mindspore/nn/wrap/cell_wrapper.py View File

@@ -344,5 +344,5 @@ class ParameterUpdate(Cell):
self._param = param self._param = param


def construct(self, x): def construct(self, x):
self._param = x
F.assign(self._param, x)
return x return x

+ 4
- 3
mindspore/train/serialization.py View File

@@ -408,10 +408,11 @@ def _fill_param_into_net(net, parameter_list):
for each_param in parameter_list: for each_param in parameter_list:
param_name = each_param["name"] param_name = each_param["name"]
np_val = each_param["data"].asnumpy() np_val = each_param["data"].asnumpy()
if np_val.shape == (1,): # to scalar
parameter_dict[param_name] = Parameter(np_val[0], name=param_name)
if np_val.shape == (1,):
parameter_dict[param_name] = Parameter(np_val, name=param_name)
elif np_val.shape == (): elif np_val.shape == ():
parameter_dict[param_name] = Parameter(np_val.tolist(), name=param_name)
parameter_dict[param_name] = Parameter(Tensor(np_val.tolist(), mstype.pytype_to_dtype(np_val.dtype)),
name=param_name)
else: else:
parameter_dict[param_name] = Parameter(Tensor(np_val), name=param_name) parameter_dict[param_name] = Parameter(Tensor(np_val), name=param_name)




+ 57
- 0
tests/ut/python/nn/test_parameter.py View File

@@ -52,12 +52,69 @@ def test_parameter_tuple_illegal():




def test_parameter_init_illegal(): def test_parameter_init_illegal():
import numpy as np
dat = np.array([[1, 2, 3], [2, 3, 4]])
tensor = Tensor(dat)
data_none = None
data_bool = True data_bool = True
data_str = "nicai" data_str = "nicai"
data_int = 3
data_list = [1, "2", True]
data_tuple = (1, 2, 3)

# test data
Parameter(tensor, name=data_str)
Parameter(data_int, name=data_str)
Parameter(dat, name=data_str)
with pytest.raises(ValueError): with pytest.raises(ValueError):
Parameter(data_bool, name=data_str) Parameter(data_bool, name=data_str)


# test name
Parameter(tensor, name=data_none)
with pytest.raises(ValueError):
Parameter(tensor, name=dat)
with pytest.raises(ValueError):
Parameter(tensor, name=tensor)
with pytest.raises(ValueError):
Parameter(tensor, name=data_bool)
with pytest.raises(ValueError):
Parameter(tensor, name=data_int)
with pytest.raises(ValueError):
Parameter(tensor, name=data_list)
with pytest.raises(ValueError):
Parameter(tensor, name=data_tuple)

Parameter(tensor, name=data_str, requires_grad=data_bool)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_none)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=dat)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=tensor)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_str)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_int)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_list)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_tuple)


Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_bool)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=dat)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=tensor)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_none)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_str)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_int)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_list)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_tuple)




def test_check_str_by_regular(): def test_check_str_by_regular():


Loading…
Cancel
Save