You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

parameter.py 7.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Parameter for cell."""
  16. from copy import copy, deepcopy
  17. from .initializer import initializer
  18. from .tensor import Tensor
  19. from .._checkparam import _check_str_by_regular
  20. from ..parallel._utils import _set_clone_info, _CloneInfo
  # Names exported by this module.
  __all__ = ['Parameter', 'ParameterTuple']
  # Name assigned by the `Parameter.name` setter when the given name is None or blank.
  PARAMETER_NAME_DEFAULT = "Parameter"
  # Longest name the `Parameter.name` setter accepts; longer names raise ValueError.
  PARAMETER_NAME_PREFIX_MAX_LEN = 1024
  def _check_type(x):
      """
      Validate that `x` is a `Parameter` instance.

      Args:
          x (object): The candidate element to validate.

      Returns:
          bool, always True when the check passes.

      Raises:
          ValueError: If `x` is not an instance of `Parameter`.
      """
      if isinstance(x, Parameter):
          return True
      raise ValueError("Should be `Parameter` collection.")
  class Parameter:
      """
      Parameter types of cell models.

      Note:
          Each parameter of Cell is represented by Parameter class.

      Args:
          default_input (Tensor): A parameter tensor. A `Tensor` is copied; any other value is wrapped
              with `Tensor(...)`. `bool` values are rejected.
          name (str): Name of the child parameter.
          requires_grad (bool): True if the parameter requires gradient. Default: True.
          layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in
              parallel mode, broadcast and gradients communication would not be applied on parameters.
              Default: False.
      """

      def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False):
          """Store the data and validate every attribute through its setter."""
          self.set_parameter_data(default_input)
          self.name = name
          self.requires_grad = requires_grad
          self.layerwise_parallel = layerwise_parallel
          self._is_init = False
          self.clone_info = _CloneInfo()

      def __repr__(self):
          format_str = 'Parameter (name={name})'
          return format_str.format(name=self._name)

      def __parameter__(self):
          """For parse check."""

      @property
      def name(self):
          """Get the name of the parameter."""
          return self._name

      @name.setter
      def name(self, name_):
          """
          Define a name for the parameter.

          Args:
              name_ (`str` or `None`): The name of the parameter. When the parameter is None or an empty string,
                  the default value `PARAMETER_NAME_DEFAULT` is used.

          Raises:
              ValueError: If `name_` is neither `str` nor `None`, or if the stripped name is longer than
                  `PARAMETER_NAME_PREFIX_MAX_LEN`.
          """
          if name_ is None:
              name_ = PARAMETER_NAME_DEFAULT
          elif isinstance(name_, str):
              # Surrounding whitespace is not part of the name; a blank name falls back to the default.
              name_ = name_.strip()
              if name_ == '':
                  name_ = PARAMETER_NAME_DEFAULT
              if len(name_) > PARAMETER_NAME_PREFIX_MAX_LEN:
                  raise ValueError("The length of the '{}' name should be less than {}.".
                                   format(name_, PARAMETER_NAME_PREFIX_MAX_LEN))
          else:
              raise ValueError("The type of the name should be `str` or `None`.")
          self._name = name_

      @property
      def is_init(self):
          """Get init status of the parameter."""
          return self._is_init

      @is_init.setter
      def is_init(self, is_init_):
          """
          Set init status of the parameter.

          Args:
              is_init_ (bool): The init status of the parameter.
          """
          self._is_init = is_init_

      def clone(self, prefix, init='same'):
          """
          Clone the parameter.

          Args:
              prefix (str): Namespace of parameter.
              init (Union[Tensor, str, Initializer, numbers.Number]): Initialize the shape of the parameter.
                  Default: 'same'.

          Returns:
              Parameter, a new parameter.
          """
          _check_str_by_regular(prefix)
          # Shallow copy: until re-initialized below, the clone shares `default_input` with this parameter.
          x = copy(self)
          x.name = prefix + '.' + x.name
          x.is_init = False
          # Only the literal string 'same' keeps the current data. Checking the type first avoids
          # evaluating `init != 'same'` on array-like values, whose comparison may not yield a plain bool.
          keep_data = isinstance(init, str) and init == 'same'
          if not keep_data:
              shape = self.default_input.shape()
              dtype = self.default_input.dtype()
              x.default_input = initializer(init, shape=shape, dtype=dtype)
          x.clone_info = copy(self.clone_info)
          _set_clone_info(self.clone_info, x.clone_info)
          return x

      @property
      def layerwise_parallel(self):
          """Return whether the parameter is flagged as layerwise parallel."""
          return self._layerwise_parallel

      @layerwise_parallel.setter
      def layerwise_parallel(self, value=True):
          """
          Set the layerwise parallel flag.

          Raises:
              TypeError: If `value` is not a bool.
          """
          if not isinstance(value, bool):
              raise TypeError("`layerwise_parallel` parameter must be bool type")
          self._layerwise_parallel = value

      @property
      def requires_grad(self):
          """Return whether the parameter requires gradient."""
          return self._requires_grad

      @requires_grad.setter
      def requires_grad(self, value=True):
          """
          Set whether the parameter requires gradient.

          Raises:
              TypeError: If `value` is not a bool.
          """
          if not isinstance(value, bool):
              raise TypeError("`requires_grad` parameter must be bool type")
          self._requires_grad = value

      @property
      def data(self):
          """Return the current `default_input` of the parameter."""
          return self.default_input

      def __add__(self, other):
          """Return a deep copy of this parameter whose data is `default_input + other`."""
          res = deepcopy(self)
          res.default_input = res.default_input + other
          return res

      def __sub__(self, other):
          """Return a deep copy of this parameter whose data is `default_input - other`."""
          res = deepcopy(self)
          res.default_input = res.default_input - other
          return res

      def __mul__(self, other):
          """Return a deep copy of this parameter whose data is `default_input * other`."""
          res = deepcopy(self)
          res.default_input = res.default_input * other
          return res

      def __truediv__(self, other):
          """Return a deep copy of this parameter whose data is `default_input / other`."""
          res = deepcopy(self)
          res.default_input = res.default_input / other
          return res

      def set_parameter_data(self, data):
          """
          Set `default_input` of current `Parameter`.

          Args:
              data (Tensor or any value accepted by `Tensor(...)`): The new parameter data.

          Raises:
              ValueError: If `data` is a `bool`.
          """
          # Reject bool before anything else so it is never wrapped as tensor data.
          if isinstance(data, bool):
              raise ValueError('Parameter data can not be `bool`')
          if isinstance(data, Tensor):
              # make a copy of Tensor to init the parameter
              data = Tensor(data.asnumpy().copy())
          else:
              data = Tensor(data)
          self.default_input = data
  class ParameterTuple(tuple):
      """
      Class for storing tuple of parameters.

      Note:
          Used to store the parameters of the network into the parameter tuple collection.
      """

      def __new__(cls, iterable):
          """
          Create instance object of ParameterTuple.

          Args:
              iterable (Iterable[Parameter]): The parameters to store.

          Raises:
              ValueError: Propagated from `_check_type` when an element is not a `Parameter`.
          """
          # `_check_type` returns True or raises, so this filter never drops an element;
          # it only validates each one while the tuple is being built.
          checked = (x for x in iterable if _check_type(x))
          # Build with `cls` (not a hard-coded class) so that subclasses construct their own type.
          return tuple.__new__(cls, checked)

      def clone(self, prefix, init='same'):
          """
          Clone the parameter.

          Args:
              prefix (str): Namespace of parameter.
              init (str): Initialize the shape of the parameter. Default: 'same'.

          Returns:
              Tuple, the new Parameter tuple.
          """
          # Validate up front so that an invalid prefix is rejected even for an empty tuple.
          _check_str_by_regular(prefix)
          return ParameterTuple([x.clone(prefix, init) for x in self])

      def __parameter_tuple__(self):
          """For parse check."""