Browse Source

!54 Interface change :`_Constant` to `Constant`

Merge pull request !54 from ghzl/initializer-change-_Constant-to-Constant
tags/v0.2.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
9e529f9dd0
2 changed files with 22 additions and 8 deletions
  1. +15
    -6
      mindspore/common/initializer.py
  2. +7
    -2
      tests/ut/python/utils/test_initializer.py

+ 15
- 6
mindspore/common/initializer.py View File

@@ -180,18 +180,18 @@ class HeUniform(Initializer):
_assignment(arr, data) _assignment(arr, data)




class _Constant(Initializer):
class Constant(Initializer):
""" """
Initialize a constant. Initialize a constant.


Args: Args:
value (int or numpy.ndarray): The value to initialize.
value (Union[int, numpy.ndarray]): The value to initialize.


Returns: Returns:
Array, initialize array. Array, initialize array.
""" """
def __init__(self, value): def __init__(self, value):
super(_Constant, self).__init__(value=value)
super(Constant, self).__init__(value=value)
self.value = value self.value = value


def _initialize(self, arr): def _initialize(self, arr):
@@ -266,8 +266,16 @@ def initializer(init, shape=None, dtype=mstype.float32):


Args: Args:
init (Union[Tensor, str, Initializer, numbers.Number]): Initialize value. init (Union[Tensor, str, Initializer, numbers.Number]): Initialize value.

- `str`: The `init` should be the alias of the class inheriting from `Initializer` and the corresponding
class will be called.

- `Initializer`: The `init` should be the class inheriting from `Initializer` to initialize tensor.

- `numbers.Number`: The `Constant` will be called to initialize tensor.

shape (Union[tuple, list, int]): A list of integers, a tuple of integers or an integer as the shape of shape (Union[tuple, list, int]): A list of integers, a tuple of integers or an integer as the shape of
output. Default: None.
output. Default: None.
dtype (:class:`mindspore.dtype`): The type of data in initialized tensor. Default: mstype.float32. dtype (:class:`mindspore.dtype`): The type of data in initialized tensor. Default: mstype.float32.


Returns: Returns:
@@ -295,7 +303,7 @@ def initializer(init, shape=None, dtype=mstype.float32):
raise ValueError(msg) raise ValueError(msg)


if isinstance(init, numbers.Number): if isinstance(init, numbers.Number):
init_obj = _Constant(init)
init_obj = Constant(init)
elif isinstance(init, str): elif isinstance(init, str):
init_obj = _INITIALIZER_ALIAS[init.lower()]() init_obj = _INITIALIZER_ALIAS[init.lower()]()
else: else:
@@ -314,4 +322,5 @@ __all__ = [
'HeUniform', 'HeUniform',
'XavierUniform', 'XavierUniform',
'One', 'One',
'Zero']
'Zero',
'Constant']

+ 7
- 2
tests/ut/python/utils/test_initializer.py View File

@@ -37,8 +37,8 @@ def _check_value(tensor, value_min, value_max):
for ele in nd.flatten(): for ele in nd.flatten():
if value_min <= ele <= value_max: if value_min <= ele <= value_max:
continue continue
raise TypeError('value_min = %d, ele = %d, value_max = %d'
% (value_min, ele, value_max))
raise ValueError('value_min = %d, ele = %d, value_max = %d'
% (value_min, ele, value_max))




def _check_uniform(tensor, boundary_a, boundary_b): def _check_uniform(tensor, boundary_a, boundary_b):
@@ -92,6 +92,11 @@ def test_init_one_alias():
_check_value(tensor, 1, 1) _check_value(tensor, 1, 1)




def test_init_constant():
tensor = init.initializer(init.Constant(1), [2, 2], ms.float32)
_check_value(tensor, 1, 1)


def test_init_uniform(): def test_init_uniform():
scale = 10 scale = 10
tensor = init.initializer(init.Uniform(scale=scale), [5, 4], ms.float32) tensor = init.initializer(init.Uniform(scale=scale), [5, 4], ms.float32)


Loading…
Cancel
Save