You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

utils.py 11 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Utility functions to help distribution class."""
  16. import numpy as np
  17. from mindspore import context
  18. from mindspore._checkparam import Validator as validator
  19. from mindspore.common.tensor import Tensor
  20. from mindspore.common.parameter import Parameter
  21. from mindspore.common import dtype as mstype
  22. from mindspore.ops import composite as C
  23. from mindspore.ops import operations as P
  24. from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
  25. import mindspore.nn as nn
  26. def cast_to_tensor(t, hint_type=mstype.float32):
  27. """
  28. Cast an user input value into a Tensor of dtype.
  29. If the input t is of type Parameter, t is directly returned as a Parameter.
  30. Args:
  31. t (int, float, list, numpy.ndarray, Tensor, Parameter): object to be cast to Tensor.
  32. dtype (mindspore.dtype): dtype of the Tensor. Default: mstype.float32.
  33. Raises:
  34. RuntimeError: if t cannot be cast to Tensor.
  35. Returns:
  36. Tensor.
  37. """
  38. if t is None:
  39. raise ValueError(f'Input cannot be None in cast_to_tensor')
  40. if isinstance(t, Parameter):
  41. return t
  42. if isinstance(t, bool):
  43. raise TypeError(f'Input cannot be Type Bool')
  44. if isinstance(t, (Tensor, np.ndarray, list, int, float)):
  45. return Tensor(t, dtype=hint_type)
  46. invalid_type = type(t)
  47. raise TypeError(
  48. f"Unable to convert input of type {invalid_type} to a Tensor of type {hint_type}")
  49. def cast_type_for_device(dtype):
  50. """
  51. use the alternative dtype supported by the device.
  52. Args:
  53. dtype (mindspore.dtype): input dtype.
  54. Returns:
  55. mindspore.dtype.
  56. """
  57. if context.get_context("device_target") == "GPU":
  58. if dtype in mstype.uint_type or dtype == mstype.int8:
  59. return mstype.int16
  60. if dtype == mstype.int64:
  61. return mstype.int32
  62. if dtype == mstype.float64:
  63. return mstype.float32
  64. return dtype
  65. def check_greater_equal_zero(value, name):
  66. """
  67. Check if the given Tensor is greater zero.
  68. Args:
  69. value (Tensor, Parameter): value to be checked.
  70. name (str) : name of the value.
  71. Raises:
  72. ValueError: if the input value is less than zero.
  73. """
  74. if isinstance(value, Parameter):
  75. if not isinstance(value.data, Tensor):
  76. return
  77. value = value.data
  78. comp = np.less(value.asnumpy(), np.zeros(value.shape))
  79. if comp.any():
  80. raise ValueError(f'{name} should be greater than ot equal to zero.')
  81. def check_greater_zero(value, name):
  82. """
  83. Check if the given Tensor is strictly greater than zero.
  84. Args:
  85. value (Tensor, Parameter): value to be checked.
  86. name (str) : name of the value.
  87. Raises:
  88. ValueError: if the input value is less than or equal to zero.
  89. """
  90. if value is None:
  91. raise ValueError(f'input value cannot be None in check_greater_zero')
  92. if isinstance(value, Parameter):
  93. if not isinstance(value.data, Tensor):
  94. return
  95. value = value.data
  96. comp = np.less(np.zeros(value.shape), value.asnumpy())
  97. if not comp.all():
  98. raise ValueError(f'{name} should be greater than zero.')
  99. def check_greater(a, b, name_a, name_b):
  100. """
  101. Check if Tensor b is strictly greater than Tensor a.
  102. Args:
  103. a (Tensor, Parameter): input tensor a.
  104. b (Tensor, Parameter): input tensor b.
  105. name_a (str): name of Tensor_a.
  106. name_b (str): name of Tensor_b.
  107. Raises:
  108. ValueError: if b is less than or equal to a
  109. """
  110. if a is None or b is None:
  111. raise ValueError(f'input value cannot be None in check_greater')
  112. if isinstance(a, Parameter) or isinstance(b, Parameter):
  113. return
  114. comp = np.less(a.asnumpy(), b.asnumpy())
  115. if not comp.all():
  116. raise ValueError(f'{name_a} should be less than {name_b}')
  117. def check_prob(p):
  118. """
  119. Check if p is a proper probability, i.e. 0 < p <1.
  120. Args:
  121. p (Tensor, Parameter): value to be checked.
  122. Raises:
  123. ValueError: if p is not a proper probability.
  124. """
  125. if p is None:
  126. raise ValueError(f'input value cannot be None in check_greater_zero')
  127. if isinstance(p, Parameter):
  128. if not isinstance(p.data, Tensor):
  129. return
  130. p = p.data
  131. comp = np.less(np.zeros(p.shape), p.asnumpy())
  132. if not comp.all():
  133. raise ValueError('Probabilities should be greater than zero')
  134. comp = np.greater(np.ones(p.shape), p.asnumpy())
  135. if not comp.all():
  136. raise ValueError('Probabilities should be less than one')
  137. def check_sum_equal_one(probs):
  138. """
  139. Used in categorical distribution. check if probabilities of each category sum to 1.
  140. """
  141. if probs is None:
  142. raise ValueError(f'input value cannot be None in check_sum_equal_one')
  143. if isinstance(probs, Parameter):
  144. if not isinstance(probs.data, Tensor):
  145. return
  146. probs = probs.data
  147. if isinstance(probs, Tensor):
  148. probs = probs.asnumpy()
  149. prob_sum = np.sum(probs, axis=-1)
  150. # add a small tolerance here to increase numerical stability
  151. comp = np.allclose(prob_sum, np.ones(prob_sum.shape), rtol=1e-14, atol=1e-14)
  152. if not comp:
  153. raise ValueError('Probabilities for each category should sum to one for Categorical distribution.')
  154. def check_rank(probs):
  155. """
  156. Used in categorical distribution. check Rank >=1.
  157. """
  158. if probs is None:
  159. raise ValueError(f'input value cannot be None in check_rank')
  160. if isinstance(probs, Parameter):
  161. if not isinstance(probs.data, Tensor):
  162. return
  163. probs = probs.data
  164. if probs.asnumpy().ndim == 0:
  165. raise ValueError('probs for Categorical distribution must have rank >= 1.')
  166. def logits_to_probs(logits, is_binary=False):
  167. """
  168. converts logits into probabilities.
  169. Args:
  170. logits (Tensor)
  171. is_binary (bool)
  172. """
  173. if is_binary:
  174. return nn.Sigmoid()(logits)
  175. return nn.Softmax(axis=-1)(logits)
  176. def clamp_probs(probs):
  177. """
  178. clamp probs boundary
  179. Args:
  180. probs (Tensor)
  181. """
  182. eps = P.Eps()(probs)
  183. return C.clip_by_value(probs, eps, 1-eps)
  184. def probs_to_logits(probs, is_binary=False):
  185. """
  186. converts probabilities into logits.
  187. Args:
  188. probs (Tensor)
  189. is_binary (bool)
  190. """
  191. ps_clamped = clamp_probs(probs)
  192. if is_binary:
  193. return P.Log()(ps_clamped) - P.Log()(1-ps_clamped)
  194. return P.Log()(ps_clamped)
@constexpr
def raise_none_error(name):
    """Raise at graph-compile time when a required Tensor argument is None."""
    raise TypeError(f"the type {name} should be subclass of Tensor."
                    f" It should not be None since it is not specified during initialization.")
@constexpr
def raise_probs_logits_error():
    """Raise when 'probs'/'logits' are specified inconsistently (see message)."""
    raise TypeError("Either 'probs' or 'logits' must be specified, but not both.")
@constexpr
def raise_broadcast_error(shape_a, shape_b):
    """Raise at graph-compile time when two shapes cannot be broadcast together."""
    raise ValueError(f"Shape {shape_a} and {shape_b} is not broadcastable.")
@constexpr
def raise_not_impl_error(name):
    """Raise when a function required for a non-linear transformation is missing."""
    raise ValueError(
        f"{name} function should be implemented for non-linear transformation")
@constexpr
def raise_not_implemented_util(func_name, obj, *args, **kwargs):
    """Raise when func_name is not implemented for the given distribution.

    Extra *args/**kwargs are accepted (so this can stand in for any
    signature) but are deliberately ignored.
    """
    raise NotImplementedError(
        f"{func_name} is not implemented for {obj} distribution.")
@constexpr
def raise_type_error(name, cur_type, required_type):
    """Raise at graph-compile time when an argument has the wrong type."""
    raise TypeError(
        f"For {name} , the type should be or be subclass of {required_type}, but got {cur_type}")
@constexpr
def raise_not_defined(func_name, obj, *args, **kwargs):
    """Raise when func_name is undefined for the given distribution.

    Extra *args/**kwargs are accepted for signature compatibility but ignored.
    """
    raise ValueError(
        f"{func_name} is undefined for {obj} distribution.")
  221. @constexpr
  222. def check_distribution_name(name, expected_name):
  223. if name is None:
  224. raise ValueError(
  225. f"Input dist should be a constant which is not None.")
  226. if name != expected_name:
  227. raise ValueError(
  228. f"Expected dist input is {expected_name}, but got {name}.")
  229. class CheckTuple(PrimitiveWithInfer):
  230. """
  231. Check if input is a tuple.
  232. """
  233. @prim_attr_register
  234. def __init__(self):
  235. super(CheckTuple, self).__init__("CheckTuple")
  236. self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output'])
  237. def __infer__(self, x, name):
  238. if not isinstance(x['dtype'], tuple):
  239. raise TypeError(
  240. f"For {name['value']}, Input type should b a tuple.")
  241. out = {'shape': None,
  242. 'dtype': None,
  243. 'value': x["value"]}
  244. return out
  245. def __call__(self, x, name):
  246. # The op is not used in a cell
  247. if isinstance(x, tuple):
  248. return x
  249. if context.get_context("mode") == 0:
  250. return x["value"]
  251. raise TypeError(f"For {name}, input type should be a tuple.")
class CheckTensor(PrimitiveWithInfer):
    """
    Check if input is a Tensor.

    Acts as a validation op in graph mode; callable directly in eager mode.
    """
    @prim_attr_register
    def __init__(self):
        super(CheckTensor, self).__init__("CheckTensor")
        self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output'])

    def __infer__(self, x, name):
        # Graph-mode inference: require x's dtype to be (a subclass of) tensor.
        src_type = x['dtype']
        validator.check_subclass(
            "input", src_type, [mstype.tensor], name["value"])
        # No shape/dtype/value is propagated — this op only validates.
        out = {'shape': None,
               'dtype': None,
               'value': None}
        return out

    def __call__(self, x, name):
        # Eager path: pass Tensors through unchanged, fail loudly otherwise.
        # NOTE(review): the message also mentions Parameter — presumably
        # Parameter satisfies isinstance(x, Tensor); confirm in mindspore.
        if isinstance(x, Tensor):
            return x
        raise TypeError(f"For {name}, input type should be a Tensor or Parameter.")
  272. def set_param_type(args, hint_type):
  273. """
  274. Find the common type among arguments.
  275. Args:
  276. args (dict): dictionary of arguments, {'name':value}.
  277. hint_type (mindspore.dtype): hint type to return.
  278. Raises:
  279. TypeError: if tensors in args are not the same dtype.
  280. """
  281. int_type = mstype.int_type + mstype.uint_type
  282. if hint_type in int_type or hint_type is None:
  283. hint_type = mstype.float32
  284. common_dtype = None
  285. for name, arg in args.items():
  286. if hasattr(arg, 'dtype'):
  287. if isinstance(arg, np.ndarray):
  288. cur_dtype = mstype.pytype_to_dtype(arg.dtype)
  289. else:
  290. cur_dtype = arg.dtype
  291. if common_dtype is None:
  292. common_dtype = cur_dtype
  293. elif cur_dtype != common_dtype:
  294. raise TypeError(f"{name} should have the same dtype as other arguments.")
  295. if common_dtype in int_type or common_dtype == mstype.float64:
  296. return mstype.float32
  297. return hint_type if common_dtype is None else common_dtype