
support parameter update with implicit type conversion

Wei Luning · 5 years ago · commit 77dcdd89ec · tags/v0.7.0-beta
4 changed files with 32 additions and 4 deletions:
  1. mindspore/common/dtype.py (+3 −0)
  2. mindspore/common/parameter.py (+7 −3)
  3. mindspore/common/tensor.py (+1 −1)
  4. tests/ut/python/nn/test_parameter.py (+21 −0)

mindspore/common/dtype.py (+3 −0)

@@ -119,6 +119,9 @@ int_type = (int8, int16, int32, int64,)
 uint_type = (uint8, uint16, uint32, uint64)
 float_type = (float16, float32, float64,)
 
+implicit_conversion_seq = {t: idx for idx, t in enumerate((
+    bool_, int8, uint8, int16, int32, int64, float16, float32, float64))}
+
 _simple_types = {
     list: list_,
     tuple: tuple_,
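The table above assigns each dtype a rank, and implicit conversion is allowed only when it does not move data to a higher rank. A minimal standalone sketch of that rule, using plain strings in place of the real mindspore dtype objects:

# Sketch only: strings stand in for mindspore dtype objects.
implicit_conversion_seq = {t: idx for idx, t in enumerate((
    'bool_', 'int8', 'uint8', 'int16', 'int32', 'int64',
    'float16', 'float32', 'float64'))}

def can_convert_implicitly(current, incoming):
    """Incoming data may be widened to the current dtype, never narrowed."""
    return implicit_conversion_seq[current] >= implicit_conversion_seq[incoming]

assert can_convert_implicitly('float32', 'int16')        # widening: allowed
assert not can_convert_implicitly('float16', 'float32')  # narrowing: rejected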


mindspore/common/parameter.py (+7 −3)

@@ -313,8 +313,9 @@ class Parameter(MetaTensor):
             Parameter, the parameter after set data.
         """
         def raise_type_error(incoming):
-            raise TypeError(f"Can not change the Parameter dtype. Current dtype is {self.set_dtype}"
-                            f", and incoming is {incoming}. Use .set_dtype(xxx) to change the dtype.")
+            raise TypeError(f"Incoming Parameter dtype can not be converted to current dtype implicitly. "
+                            f"Current dtype is {self.dtype}, and incoming is {incoming}. "
+                            f"Use .set_dtype(xxx) to change the dtype.")
 
         if not isinstance(data, (MetaTensor, Initializer, int, float)):
             raise TypeError(f"Parameter data must be [`Initializer`, `int`, `float`] or a kind of `MetaTensor` "
@@ -338,7 +339,10 @@ class Parameter(MetaTensor):
             raise ValueError(f"Can not change the shape of Parameter which has been initialized."
                              f" Current shape is {self.shape}, and incoming is {data.shape}.")
         if self.dtype != data.dtype:
-            raise_type_error(data.dtype)
+            if mstype.implicit_conversion_seq[self.dtype] < mstype.implicit_conversion_seq[data.dtype]:
+                raise_type_error(data.dtype)
+            else:
+                data = Tensor(data, self.dtype)
         if isinstance(data, Initializer):
             # The parameter has been initializered, directly update by the data
             if is_current_tensor:
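The new branch rejects an update only when the incoming dtype ranks strictly higher than the Parameter's dtype; anything ranked lower or equal is cast up via Tensor(data, self.dtype). A self-contained sketch of just this decision, where SEQ and update_dtype are stand-ins for mstype.implicit_conversion_seq and the real setter (which has more branches):

# Stand-ins for mstype.implicit_conversion_seq and the Parameter setter.
SEQ = {t: i for i, t in enumerate((
    'bool_', 'int8', 'uint8', 'int16', 'int32', 'int64',
    'float16', 'float32', 'float64'))}

def update_dtype(param_dtype, data_dtype):
    """Return the dtype the incoming data ends up with, or raise TypeError."""
    if param_dtype == data_dtype:
        return data_dtype
    if SEQ[param_dtype] < SEQ[data_dtype]:
        # Incoming dtype ranks higher: an implicit cast would narrow it.
        raise TypeError(f"Incoming Parameter dtype can not be converted to "
                        f"current dtype implicitly. Current dtype is "
                        f"{param_dtype}, and incoming is {data_dtype}.")
    # Incoming dtype ranks lower or equal: widen, like Tensor(data, self.dtype).
    return param_dtype

assert update_dtype('float32', 'int16') == 'float32'  # widened silently
try:
    update_dtype('float16', 'float32')                 # narrowing: rejected
except TypeError:
    pass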


mindspore/common/tensor.py (+1 −1)

@@ -74,7 +74,7 @@ class Tensor(Tensor_):
         self._virtual_flag = False
 
     def __repr__(self):
-        return str(Tensor_.__str__(self))
+        return Tensor_.__repr__(self)
 
     def __add__(self, other):
         out = tensor_operator_registry.get('__add__')(self, other)
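This one-line fix stops routing repr() through __str__ and delegates to the base class's own __repr__, so the two can differ again. An illustrative, hypothetical snippet (Base and Wrapped are not MindSpore classes) showing the behavior the fixed line restores:

# Hypothetical classes illustrating the delegation pattern.
class Base:
    def __str__(self):
        return "values only"
    def __repr__(self):
        return "Tensor(shape=..., dtype=...)"

class Wrapped(Base):
    def __repr__(self):
        return Base.__repr__(self)  # mirrors the fix: delegate, don't rewrap __str__

print(str(Wrapped()))   # values only
print(repr(Wrapped()))  # Tensor(shape=..., dtype=...), no longer str() in disguise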


tests/ut/python/nn/test_parameter.py (+21 −0)

@@ -157,6 +157,7 @@ def test_parameter_compute():
 
 
 def test_scalar_parameter_update():
+    # float
     fp = Parameter(0.5, 'fp')
     fp.default_input = 0.8
     assert np.array_equal(fp.default_input.asnumpy(), np.array(0.8, np.float32))
@@ -167,6 +168,26 @@ def test_scalar_parameter_update():
     assert np.array_equal(int_.default_input.asnumpy(), np.array(2, np.int32))
     with pytest.raises(TypeError):
         int_.default_input = 1.2
+    # Tensor
+    fp32 = Tensor(0.5, mstype.float32)
+    int32 = Tensor(2, mstype.int32)
+    fp16 = Tensor(0.6, mstype.float16)
+    int16 = Tensor(3, mstype.int16)
+    bool_ = Tensor(np.array(True, dtype=np.bool_))
+    # update by tensor: widening into a float32 Parameter
+    fp32_p = Parameter(fp32, 'fp32')
+    fp32_p.default_input = 0.8
+    fp32_p.default_input = 1
+    fp32_p.default_input = int32
+    fp32_p.default_input = fp32
+    fp32_p.default_input = int16
+    fp32_p.default_input = fp16
+    fp32_p.default_input = bool_
+
+    # update by tensor: narrowing into a float16 Parameter must raise
+    fp16_p = Parameter(fp16, 'fp16')
+    with pytest.raises(TypeError):
+        fp16_p.default_input = fp32
 
 
 def test_parameter_lazy_init():
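Read together, the two new test blocks pin down the contract: data whose dtype ranks at or below the Parameter's dtype in implicit_conversion_seq (bool_, int16, int32, float16 into a float32 Parameter) is cast up silently, while data ranking above it (float32 into a float16 Parameter) raises the TypeError added in parameter.py.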

