You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 10 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Utitly functions to help distribution class."""
  16. import numpy as np
  17. from mindspore import context
  18. from mindspore._checkparam import Validator as validator
  19. from mindspore.common.tensor import Tensor
  20. from mindspore.common.parameter import Parameter
  21. from mindspore.common import dtype as mstype
  22. from mindspore.ops import composite as C
  23. from mindspore.ops import operations as P
  24. from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
  25. import mindspore.nn as nn
  26. def cast_to_tensor(t, hint_type=mstype.float32):
  27. """
  28. Cast an user input value into a Tensor of dtype.
  29. If the input t is of type Parameter, t is directly returned as a Parameter.
  30. Args:
  31. t (int, float, list, numpy.ndarray, Tensor, Parameter): object to be cast to Tensor.
  32. dtype (mindspore.dtype): dtype of the Tensor. Default: mstype.float32.
  33. Raises:
  34. RuntimeError: if t cannot be cast to Tensor.
  35. Returns:
  36. Tensor.
  37. """
  38. if t is None:
  39. raise ValueError(f'Input cannot be None in cast_to_tensor')
  40. if isinstance(t, Parameter):
  41. return t
  42. if isinstance(t, bool):
  43. raise TypeError(f'Input cannot be Type Bool')
  44. if isinstance(t, (Tensor, np.ndarray, list, int, float)):
  45. return Tensor(t, dtype=hint_type)
  46. invalid_type = type(t)
  47. raise TypeError(
  48. f"Unable to convert input of type {invalid_type} to a Tensor of type {hint_type}")
  49. def cast_type_for_device(dtype):
  50. """
  51. use the alternative dtype supported by the device.
  52. Args:
  53. dtype (mindspore.dtype): input dtype.
  54. Returns:
  55. mindspore.dtype.
  56. """
  57. if context.get_context("device_target") == "GPU":
  58. if dtype in mstype.uint_type or dtype == mstype.int8:
  59. return mstype.int16
  60. if dtype == mstype.int64:
  61. return mstype.int32
  62. if dtype == mstype.float64:
  63. return mstype.float32
  64. return dtype
  65. def check_greater_equal_zero(value, name):
  66. """
  67. Check if the given Tensor is greater zero.
  68. Args:
  69. value (Tensor, Parameter): value to be checked.
  70. name (str) : name of the value.
  71. Raises:
  72. ValueError: if the input value is less than zero.
  73. """
  74. if isinstance(value, Parameter):
  75. if not isinstance(value.data, Tensor):
  76. return
  77. value = value.data
  78. comp = np.less(value.asnumpy(), np.zeros(value.shape))
  79. if comp.any():
  80. raise ValueError(f'{name} should be greater than ot equal to zero.')
  81. def check_greater_zero(value, name):
  82. """
  83. Check if the given Tensor is strictly greater than zero.
  84. Args:
  85. value (Tensor, Parameter): value to be checked.
  86. name (str) : name of the value.
  87. Raises:
  88. ValueError: if the input value is less than or equal to zero.
  89. """
  90. if value is None:
  91. raise ValueError(f'input value cannot be None in check_greater_zero')
  92. if isinstance(value, Parameter):
  93. if not isinstance(value.data, Tensor):
  94. return
  95. value = value.data
  96. comp = np.less(np.zeros(value.shape), value.asnumpy())
  97. if not comp.all():
  98. raise ValueError(f'{name} should be greater than zero.')
  99. def check_greater(a, b, name_a, name_b):
  100. """
  101. Check if Tensor b is strictly greater than Tensor a.
  102. Args:
  103. a (Tensor, Parameter): input tensor a.
  104. b (Tensor, Parameter): input tensor b.
  105. name_a (str): name of Tensor_a.
  106. name_b (str): name of Tensor_b.
  107. Raises:
  108. ValueError: if b is less than or equal to a
  109. """
  110. if a is None or b is None:
  111. raise ValueError(f'input value cannot be None in check_greater')
  112. if isinstance(a, Parameter) or isinstance(b, Parameter):
  113. return
  114. comp = np.less(a.asnumpy(), b.asnumpy())
  115. if not comp.all():
  116. raise ValueError(f'{name_a} should be less than {name_b}')
  117. def check_prob(p):
  118. """
  119. Check if p is a proper probability, i.e. 0 < p <1.
  120. Args:
  121. p (Tensor, Parameter): value to be checked.
  122. Raises:
  123. ValueError: if p is not a proper probability.
  124. """
  125. if p is None:
  126. raise ValueError(f'input value cannot be None in check_greater_zero')
  127. if isinstance(p, Parameter):
  128. if not isinstance(p.data, Tensor):
  129. return
  130. p = p.data
  131. comp = np.less(np.zeros(p.shape), p.asnumpy())
  132. if not comp.all():
  133. raise ValueError('Probabilities should be greater than zero')
  134. comp = np.greater(np.ones(p.shape), p.asnumpy())
  135. if not comp.all():
  136. raise ValueError('Probabilities should be less than one')
  137. def check_sum_equal_one(probs):
  138. prob_sum = np.sum(probs.asnumpy(), axis=-1)
  139. comp = np.equal(np.ones(prob_sum.shape), prob_sum)
  140. if not comp.all():
  141. raise ValueError('Probabilities for each category should sum to one for Categorical distribution.')
  142. def check_rank(probs):
  143. """
  144. Used in categorical distribution. check Rank >=1.
  145. """
  146. if probs.asnumpy().ndim == 0:
  147. raise ValueError('probs for Categorical distribution must have rank >= 1.')
  148. def logits_to_probs(logits, is_binary=False):
  149. """
  150. converts logits into probabilities.
  151. Args:
  152. logits (Tensor)
  153. is_binary (bool)
  154. """
  155. if is_binary:
  156. return nn.Sigmoid()(logits)
  157. return nn.Softmax(axis=-1)(logits)
  158. def clamp_probs(probs):
  159. """
  160. clamp probs boundary
  161. Args:
  162. probs (Tensor)
  163. """
  164. eps = P.Eps()(probs)
  165. return C.clip_by_value(probs, eps, 1-eps)
  166. def probs_to_logits(probs, is_binary=False):
  167. """
  168. converts probabilities into logits.
  169. Args:
  170. probs (Tensor)
  171. is_binary (bool)
  172. """
  173. ps_clamped = clamp_probs(probs)
  174. if is_binary:
  175. return P.Log()(ps_clamped) - P.Log()(1-ps_clamped)
  176. return P.Log()(ps_clamped)
  177. @constexpr
  178. def raise_none_error(name):
  179. raise TypeError(f"the type {name} should be subclass of Tensor."
  180. f" It should not be None since it is not specified during initialization.")
  181. @constexpr
  182. def raise_probs_logits_error():
  183. raise TypeError("Either 'probs' or 'logits' must be specified, but not both.")
  184. @constexpr
  185. def raise_broadcast_error(shape_a, shape_b):
  186. raise ValueError(f"Shape {shape_a} and {shape_b} is not broadcastable.")
  187. @constexpr
  188. def raise_not_impl_error(name):
  189. raise ValueError(
  190. f"{name} function should be implemented for non-linear transformation")
  191. @constexpr
  192. def raise_not_implemented_util(func_name, obj, *args, **kwargs):
  193. raise NotImplementedError(
  194. f"{func_name} is not implemented for {obj} distribution.")
  195. @constexpr
  196. def check_distribution_name(name, expected_name):
  197. if name is None:
  198. raise ValueError(
  199. f"Input dist should be a constant which is not None.")
  200. if name != expected_name:
  201. raise ValueError(
  202. f"Expected dist input is {expected_name}, but got {name}.")
  203. class CheckTuple(PrimitiveWithInfer):
  204. """
  205. Check if input is a tuple.
  206. """
  207. @prim_attr_register
  208. def __init__(self):
  209. super(CheckTuple, self).__init__("CheckTuple")
  210. self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output'])
  211. def __infer__(self, x, name):
  212. if not isinstance(x['dtype'], tuple):
  213. raise TypeError(
  214. f"For {name['value']}, Input type should b a tuple.")
  215. out = {'shape': None,
  216. 'dtype': None,
  217. 'value': x["value"]}
  218. return out
  219. def __call__(self, x, name):
  220. if context.get_context("mode") == 0:
  221. return x["value"]
  222. # Pynative mode
  223. if isinstance(x, tuple):
  224. return x
  225. raise TypeError(f"For {name}, input type should be a tuple.")
  226. class CheckTensor(PrimitiveWithInfer):
  227. """
  228. Check if input is a Tensor.
  229. """
  230. @prim_attr_register
  231. def __init__(self):
  232. super(CheckTensor, self).__init__("CheckTensor")
  233. self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output'])
  234. def __infer__(self, x, name):
  235. src_type = x['dtype']
  236. validator.check_subclass(
  237. "input", src_type, [mstype.tensor], name["value"])
  238. out = {'shape': None,
  239. 'dtype': None,
  240. 'value': None}
  241. return out
  242. def __call__(self, x, name):
  243. if isinstance(x, Tensor):
  244. return x
  245. raise TypeError(f"For {name}, input type should be a Tensor or Parameter.")
  246. def set_param_type(args, hint_type):
  247. """
  248. Find the common type among arguments.
  249. Args:
  250. args (dict): dictionary of arguments, {'name':value}.
  251. hint_type (mindspore.dtype): hint type to return.
  252. Raises:
  253. TypeError: if tensors in args are not the same dtype.
  254. """
  255. int_type = mstype.int_type + mstype.uint_type
  256. if hint_type in int_type:
  257. hint_type = mstype.float32
  258. common_dtype = None
  259. for name, arg in args.items():
  260. if hasattr(arg, 'dtype'):
  261. if isinstance(arg, np.ndarray):
  262. cur_dtype = mstype.pytype_to_dtype(arg.dtype)
  263. else:
  264. cur_dtype = arg.dtype
  265. if common_dtype is None:
  266. common_dtype = cur_dtype
  267. elif cur_dtype != common_dtype:
  268. raise TypeError(f"{name} should have the same dtype as other arguments.")
  269. if common_dtype in int_type or common_dtype == mstype.float64:
  270. return mstype.float32
  271. return hint_type if common_dtype is None else common_dtype