# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
- """Utitly functions to help distribution class."""
- import numpy as np
- from mindspore import context
- from mindspore._checkparam import Validator as validator
- from mindspore.common.tensor import Tensor
- from mindspore.common.parameter import Parameter
- from mindspore.common import dtype as mstype
- from mindspore.ops import composite as C
- from mindspore.ops import operations as P
- from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
- import mindspore.nn as nn
-
def cast_to_tensor(t, hint_type=mstype.float32):
    """
    Cast a user input value into a Tensor of `hint_type`.
    If the input `t` is of type Parameter, `t` is returned directly as a Parameter.

    Args:
        t (int, float, list, numpy.ndarray, Tensor, Parameter): object to be cast to Tensor.
        hint_type (mindspore.dtype): dtype of the Tensor. Default: mstype.float32.

    Raises:
        ValueError: if `t` is None.
        TypeError: if `t` is a bool or cannot be cast to Tensor.

    Returns:
        Tensor.
    """
    if t is None:
        raise ValueError('Input cannot be None in cast_to_tensor')
    if isinstance(t, Parameter):
        return t
    if isinstance(t, bool):
        raise TypeError('Input cannot be of type bool')
    if isinstance(t, (Tensor, np.ndarray, list, int, float)):
        return Tensor(t, dtype=hint_type)
    invalid_type = type(t)
    raise TypeError(
        f"Unable to convert input of type {invalid_type} to a Tensor of type {hint_type}")

def cast_type_for_device(dtype):
    """
    Use the alternative dtype supported by the current device.

    Args:
        dtype (mindspore.dtype): input dtype.

    Returns:
        mindspore.dtype.
    """
    if context.get_context("device_target") == "GPU":
        if dtype in mstype.uint_type or dtype == mstype.int8:
            return mstype.int16
        if dtype == mstype.int64:
            return mstype.int32
        if dtype == mstype.float64:
            return mstype.float32
    return dtype
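
# A hedged sketch of the GPU dtype degradation above (other device targets
# return the dtype unchanged):
#
#     # assuming context.set_context(device_target="GPU") has been called
#     cast_type_for_device(mstype.int64)    # -> mstype.int32
#     cast_type_for_device(mstype.float64)  # -> mstype.float32
#     cast_type_for_device(mstype.uint8)    # -> mstype.int16
#     cast_type_for_device(mstype.float32)  # -> mstype.float32, unchanged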


def check_greater_equal_zero(value, name):
    """
    Check if the given Tensor is greater than or equal to zero.

    Args:
        value (Tensor, Parameter): value to be checked.
        name (str): name of the value.

    Raises:
        ValueError: if the input value is less than zero.

    """
    if isinstance(value, Parameter):
        if not isinstance(value.data, Tensor):
            return
        value = value.data
    comp = np.less(value.asnumpy(), np.zeros(value.shape))
    if comp.any():
        raise ValueError(f'{name} should be greater than or equal to zero.')


def check_greater_zero(value, name):
    """
    Check if the given Tensor is strictly greater than zero.

    Args:
        value (Tensor, Parameter): value to be checked.
        name (str): name of the value.

    Raises:
        ValueError: if the input value is less than or equal to zero.

    """
    if value is None:
        raise ValueError('Input value cannot be None in check_greater_zero')
    if isinstance(value, Parameter):
        if not isinstance(value.data, Tensor):
            return
        value = value.data
    comp = np.less(np.zeros(value.shape), value.asnumpy())
    if not comp.all():
        raise ValueError(f'{name} should be greater than zero.')
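
# Illustrative calls for the two zero checks above (values are made up):
#
#     check_greater_equal_zero(Tensor([0.0, 1.0], mstype.float32), 'sd')  # passes, zero allowed
#     check_greater_zero(Tensor([0.5, 2.0], mstype.float32), 'rate')      # passes
#     check_greater_zero(Tensor([0.0], mstype.float32), 'rate')           # raises ValueError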


def check_greater(a, b, name_a, name_b):
    """
    Check if Tensor `b` is strictly greater than Tensor `a`, i.e. a < b elementwise.

    Args:
        a (Tensor, Parameter): input tensor a.
        b (Tensor, Parameter): input tensor b.
        name_a (str): name of tensor a.
        name_b (str): name of tensor b.

    Raises:
        ValueError: if any element of `b` is less than or equal to
            the corresponding element of `a`.
    """
    if a is None or b is None:
        raise ValueError('Input value cannot be None in check_greater')
    if isinstance(a, Parameter) or isinstance(b, Parameter):
        return
    comp = np.less(a.asnumpy(), b.asnumpy())
    if not comp.all():
        raise ValueError(f'{name_a} should be less than {name_b}')
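
# Illustrative calls for `check_greater` (values are made up); the check is
# elementwise and skipped when either argument is a Parameter:
#
#     low = Tensor([0.0, 1.0], mstype.float32)
#     high = Tensor([1.0, 2.0], mstype.float32)
#     check_greater(low, high, 'low', 'high')   # passes, low < high elementwise
#     check_greater(high, low, 'high', 'low')   # raises ValueError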


def check_prob(p):
    """
    Check if `p` is a proper probability, i.e. 0 < p < 1.

    Args:
        p (Tensor, Parameter): value to be checked.

    Raises:
        ValueError: if `p` is not a proper probability.
    """
    if p is None:
        raise ValueError('Input value cannot be None in check_prob')
    if isinstance(p, Parameter):
        if not isinstance(p.data, Tensor):
            return
        p = p.data
    comp = np.less(np.zeros(p.shape), p.asnumpy())
    if not comp.all():
        raise ValueError('Probabilities should be greater than zero')
    comp = np.greater(np.ones(p.shape), p.asnumpy())
    if not comp.all():
        raise ValueError('Probabilities should be less than one')
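
# Illustrative calls for `check_prob` (values are made up); both bounds are
# excluded:
#
#     check_prob(Tensor([0.3, 0.7], mstype.float32))   # passes, 0 < p < 1
#     check_prob(Tensor([0.0, 1.0], mstype.float32))   # raises ValueError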

def check_sum_equal_one(probs):
    """
    Check that probabilities along the last axis sum to one.

    Args:
        probs (Tensor): probabilities of a Categorical distribution.

    Raises:
        ValueError: if the probabilities for any category do not sum to one.
    """
    prob_sum = np.sum(probs.asnumpy(), axis=-1)
    # Compare with a tolerance: exact float equality would spuriously fail.
    if not np.allclose(prob_sum, np.ones(prob_sum.shape)):
        raise ValueError('Probabilities for each category should sum to one for Categorical distribution.')
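
# Illustrative calls for `check_sum_equal_one` (values are made up); the sum is
# taken over the last axis, one row per category vector:
#
#     check_sum_equal_one(Tensor([[0.2, 0.8], [0.5, 0.5]], mstype.float32))  # passes
#     check_sum_equal_one(Tensor([0.2, 0.9], mstype.float32))                # raises ValueError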

def check_rank(probs):
    """
    Used in the Categorical distribution. Check that `probs` has rank >= 1.

    Args:
        probs (Tensor): probabilities of a Categorical distribution.

    Raises:
        ValueError: if `probs` is a scalar (rank 0).
    """
    if probs.asnumpy().ndim == 0:
        raise ValueError('probs for Categorical distribution must have rank >= 1.')

def logits_to_probs(logits, is_binary=False):
    """
    Convert logits into probabilities.

    Args:
        logits (Tensor): the input logits.
        is_binary (bool): if True, apply sigmoid for a binary distribution;
            otherwise apply softmax over the last axis. Default: False.
    """
    if is_binary:
        return nn.Sigmoid()(logits)
    return nn.Softmax(axis=-1)(logits)
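
# A sketch of the two branches of `logits_to_probs` (values are made up):
#
#     logits = Tensor([0.0, 1.0, -1.0], mstype.float32)
#     logits_to_probs(logits, is_binary=True)   # sigmoid: elementwise in (0, 1)
#     logits_to_probs(logits)                   # softmax: sums to one over the last axis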


def clamp_probs(probs):
    """
    Clamp probabilities away from 0 and 1 to keep downstream logs finite.

    Args:
        probs (Tensor): probabilities to be clamped into (eps, 1 - eps).
    """
    eps = P.Eps()(probs)
    return C.clip_by_value(probs, eps, 1 - eps)


def probs_to_logits(probs, is_binary=False):
    """
    Convert probabilities into logits.

    Args:
        probs (Tensor): the input probabilities.
        is_binary (bool): if True, return the binary log-odds
            log(p) - log(1 - p); otherwise return log(p). Default: False.
    """
    ps_clamped = clamp_probs(probs)
    if is_binary:
        return P.Log()(ps_clamped) - P.Log()(1 - ps_clamped)
    return P.Log()(ps_clamped)
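
# A sketch of how `clamp_probs` and `probs_to_logits` compose (values are made
# up); clamping keeps the logs finite even for probabilities of exactly 0 or 1:
#
#     probs = Tensor([0.25, 0.75], mstype.float32)
#     logits = probs_to_logits(probs, is_binary=True)  # log(p) - log(1 - p)
#     logits_to_probs(logits, is_binary=True)          # recovers probs, up to clamping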


@constexpr
def raise_none_error(name):
    raise TypeError(f"The type of {name} should be a subclass of Tensor."
                    f" It should not be None since it is not specified during initialization.")

@constexpr
def raise_probs_logits_error():
    raise TypeError("Either 'probs' or 'logits' must be specified, but not both.")

@constexpr
def raise_broadcast_error(shape_a, shape_b):
    raise ValueError(f"Shapes {shape_a} and {shape_b} are not broadcastable.")

@constexpr
def raise_not_impl_error(name):
    raise ValueError(
        f"{name} function should be implemented for non-linear transformation")

@constexpr
def raise_not_implemented_util(func_name, obj, *args, **kwargs):
    raise NotImplementedError(
        f"{func_name} is not implemented for {obj} distribution.")


@constexpr
def check_distribution_name(name, expected_name):
    if name is None:
        raise ValueError("Input dist should be a constant which is not None.")
    if name != expected_name:
        raise ValueError(
            f"Expected dist input is {expected_name}, but got {name}.")


class CheckTuple(PrimitiveWithInfer):
    """
    Check if the input is a tuple.
    """
    @prim_attr_register
    def __init__(self):
        super(CheckTuple, self).__init__("CheckTuple")
        self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output'])

    def __infer__(self, x, name):
        if not isinstance(x['dtype'], tuple):
            raise TypeError(
                f"For {name['value']}, input type should be a tuple.")

        out = {'shape': None,
               'dtype': None,
               'value': x["value"]}
        return out

    def __call__(self, x, name):
        # Graph mode: the check happens in __infer__, so return the value directly.
        if context.get_context("mode") == context.GRAPH_MODE:
            return x["value"]
        # PyNative mode: check eagerly.
        if isinstance(x, tuple):
            return x
        raise TypeError(f"For {name}, input type should be a tuple.")


class CheckTensor(PrimitiveWithInfer):
    """
    Check if the input is a Tensor.
    """
    @prim_attr_register
    def __init__(self):
        super(CheckTensor, self).__init__("CheckTensor")
        self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output'])

    def __infer__(self, x, name):
        src_type = x['dtype']
        validator.check_subclass(
            "input", src_type, [mstype.tensor], name["value"])

        out = {'shape': None,
               'dtype': None,
               'value': None}
        return out

    def __call__(self, x, name):
        # PyNative mode: check eagerly.
        if isinstance(x, Tensor):
            return x
        raise TypeError(f"For {name}, input type should be a Tensor or Parameter.")

def set_param_type(args, hint_type):
    """
    Find the common dtype among arguments.

    Args:
        args (dict): dictionary of arguments, {'name': value}.
        hint_type (mindspore.dtype): hint type to return if no common dtype is found.

    Raises:
        TypeError: if tensors in args are not all of the same dtype.

    Returns:
        mindspore.dtype.
    """
    int_type = mstype.int_type + mstype.uint_type
    if hint_type in int_type:
        hint_type = mstype.float32
    common_dtype = None
    for name, arg in args.items():
        if hasattr(arg, 'dtype'):
            if isinstance(arg, np.ndarray):
                cur_dtype = mstype.pytype_to_dtype(arg.dtype)
            else:
                cur_dtype = arg.dtype
            if common_dtype is None:
                common_dtype = cur_dtype
            elif cur_dtype != common_dtype:
                raise TypeError(f"{name} should have the same dtype as other arguments.")
    # Integer and float64 dtypes degrade to float32 for distribution parameters.
    if common_dtype in int_type or common_dtype == mstype.float64:
        return mstype.float32
    return hint_type if common_dtype is None else common_dtype
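
# A sketch of `set_param_type` (argument names and values are made up):
#
#     args = {'mean': Tensor([0.0], mstype.float32), 'sd': Tensor([1.0], mstype.float32)}
#     set_param_type(args, mstype.float32)            # -> mstype.float32, the common dtype
#     set_param_type({'mean': None}, mstype.float16)  # no dtype found -> the hint, mstype.float16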