| @@ -16,8 +16,10 @@ | |||
| from mindspore import context | |||
| from mindspore.nn.cell import Cell | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore._checkparam import Validator as validator | |||
| from ..distribution._utils.utils import CheckTensor, cast_to_tensor | |||
| from ..distribution._utils.utils import CheckTensor, cast_to_tensor, raise_type_error | |||
| from ..distribution import Distribution | |||
| from ..distribution import TransformedDistribution | |||
| @@ -32,6 +34,17 @@ class Bijector(Cell): | |||
| name (str): The name of the Bijector. Default: None. | |||
| dtype (mindspore.dtype): The type of the distributions that the Bijector can operate on. Default: None. | |||
| param (dict): The parameters used to initialize the Bijector. Default: None. | |||
| Note: | |||
| `dtype` of bijector represents the type of the distributions that the bijector could operate on. | |||
| When `dtype` is None, there is no enforcement on the type of input value except that the input value | |||
| has to be float type. During initilization, when `dtype` is None, there is no enforcement on the dtype | |||
| of the parameters. All parameters should have the same float type, otherwise a TypeError will be raised. | |||
| Specifically, the parameter type will follow the dtype of the input value, i.e. parameters of the bijector | |||
| will be casted into the same type as input value when `dtype`is None. | |||
| When `dtype` is specified, it is forcing the parameters and input value to be the same dtype as `dtype`. | |||
| When the type of parameters or the type of the input value is not the same as `dtype`, a TypeError will be | |||
| raised. Only subtype of mindspore.float_type can be used to specify bijector's `dtype`. | |||
| """ | |||
| def __init__(self, | |||
| @@ -48,6 +61,8 @@ class Bijector(Cell): | |||
| validator.check_value_type( | |||
| 'is_constant_jacobian', is_constant_jacobian, [bool], name) | |||
| validator.check_value_type('is_injective', is_injective, [bool], name) | |||
| if dtype is not None: | |||
| validator.check_type_name("dtype", dtype, mstype.float_type, type(self).__name__) | |||
| self._name = name | |||
| self._dtype = dtype | |||
| self._parameters = {} | |||
| @@ -57,6 +72,12 @@ class Bijector(Cell): | |||
| continue | |||
| if not(k == 'self' or k.startswith('_')): | |||
| self._parameters[k] = param[k] | |||
| # if no bijector is used as an argument during initilization | |||
| if 'bijector' not in param.keys(): | |||
| self._batch_shape = self._calc_batch_shape() | |||
| self._is_scalar_batch = self._check_is_scalar_batch() | |||
| self._is_constant_jacobian = is_constant_jacobian | |||
| self._is_injective = is_injective | |||
| @@ -68,6 +89,8 @@ class Bijector(Cell): | |||
| self.dtype_base = P.DType() | |||
| self.shape_base = P.Shape() | |||
| self.fill_base = P.Fill() | |||
| self.sametypeshape_base = P.SameTypeShape() | |||
| self.issubclass_base = P.IsSubClass() | |||
| @property | |||
| def name(self): | |||
| @@ -89,6 +112,38 @@ class Bijector(Cell): | |||
| def is_injective(self): | |||
| return self._is_injective | |||
| @property | |||
| def batch_shape(self): | |||
| return self._batch_shape | |||
| @property | |||
| def is_scalar_batch(self): | |||
| return self._is_scalar_batch | |||
| def _check_value_dtype(self, value): | |||
| """ | |||
| Firstly check if the input value is Tensor. Then, if `self.dtype` is None, check | |||
| if the input tensor is or can be directly cast into a float tensor. | |||
| If `self.dtype` is not None, check if the input tensor's dtype is `self.dtype`. | |||
| """ | |||
| self.checktensor(value, 'input value of bijector') | |||
| value_type = self.dtype_base(value) | |||
| if self.dtype is None: | |||
| if self.issubclass_base(value_type, mstype.float_): | |||
| return value | |||
| return raise_type_error('input value of bijector', value_type, mstype.float_) | |||
| dtype_tensor = self.fill_base(self.dtype, self.shape_base(value), 0.0) | |||
| self.sametypeshape_base(value, dtype_tensor) | |||
| return value | |||
| def _shape_mapping(self, shape): | |||
| shape_tensor = self.fill_base(self.parameter_type, shape, 0.0) | |||
| dist_shape_tensor = self.fill_base(self.parameter_type, self.batch_shape, 0.0) | |||
| return (shape_tensor + dist_shape_tensor).shape | |||
| def shape_mapping(self, shape): | |||
| return self._shape_mapping(shape) | |||
| def _add_parameter(self, value, name): | |||
| """ | |||
| Cast `value` to a tensor and add it to `self.default_parameters`. | |||
| @@ -98,26 +153,51 @@ class Bijector(Cell): | |||
| if not hasattr(self, 'default_parameters'): | |||
| self.default_parameters = [] | |||
| self.parameter_names = [] | |||
| self.common_dtype = None | |||
| # cast value to a tensor if it is not None | |||
| value_t = None if value is None else cast_to_tensor(value, self.parameter_type) | |||
| self.default_parameters += [value_t,] | |||
| if isinstance(value, bool) or value is None: | |||
| raise TypeError(f"{name} cannot be type {type(value)}") | |||
| value_t = Tensor(value) | |||
| # if the bijector's dtype is not specified | |||
| if self.dtype is None: | |||
| if self.common_dtype is None: | |||
| self.common_dtype = value_t.dtype | |||
| elif value_t.dtype != self.common_dtype: | |||
| raise TypeError(f"{name} should have the same dtype as other arguments.") | |||
| # check if the dtype of the input_parameter agrees with the bijector's dtype | |||
| elif value_t.dtype != self.dtype: | |||
| raise TypeError(f"{name} should have the same dtype as the bijector's dtype.") | |||
| self.default_parameters += [value,] | |||
| self.parameter_names += [name,] | |||
| return value_t | |||
| def _calc_event_shape(self): | |||
| def _calc_batch_shape(self): | |||
| """ | |||
| Calculate event_shape based on parameters. | |||
| Calculate batch_shape based on parameters. | |||
| """ | |||
| broadcast_shape = None | |||
| for param in self.default_parameters: | |||
| if broadcast_shape is None: | |||
| broadcast_shape = self.shape_base(param) | |||
| broadcast_shape_tensor = self.fill_base(self.parameter_type, broadcast_shape, 0.0) | |||
| param_dict = self.parameters['param_dict'] | |||
| broadcast_shape_tensor = None | |||
| for value in param_dict.values(): | |||
| if value is None: | |||
| return None | |||
| if broadcast_shape_tensor is None: | |||
| broadcast_shape_tensor = cast_to_tensor(value) | |||
| else: | |||
| broadcast_shape = self.shape_base(param + broadcast_shape_tensor) | |||
| broadcast_shape_tensor = self.fill_base(self.parameter_type, broadcast_shape, 0.0) | |||
| return broadcast_shape | |||
| value = cast_to_tensor(value) | |||
| broadcast_shape_tensor = (value + broadcast_shape_tensor) | |||
| return broadcast_shape_tensor.shape | |||
| def _check_is_scalar_batch(self): | |||
| """ | |||
| Check if the parameters used during initialization are scalars. | |||
| """ | |||
| param_dict = self.parameters['param_dict'] | |||
| for value in param_dict.values(): | |||
| if value is None: | |||
| continue | |||
| if not isinstance(value, (int, float)): | |||
| return False | |||
| return True | |||
| def _check_value(self, value, name): | |||
| """ | |||
| @@ -127,32 +207,35 @@ class Bijector(Cell): | |||
| return value | |||
| def cast_param_by_value(self, value, para): | |||
| """ | |||
| Cast the parameter(s) of the bijector to be the same type of input_value. | |||
| """ | |||
| local = self.cast_base(para, self.dtype_base(value)) | |||
| return local | |||
| def forward(self, *args, **kwargs): | |||
| def forward(self, value, *args, **kwargs): | |||
| """ | |||
| Forward transformation: transform the input value to another distribution. | |||
| """ | |||
| return self._forward(*args, **kwargs) | |||
| return self._forward(value, *args, **kwargs) | |||
| def inverse(self, *args, **kwargs): | |||
| def inverse(self, value, *args, **kwargs): | |||
| """ | |||
| Inverse transformation: transform the input value back to the original distribution. | |||
| """ | |||
| return self._inverse(*args, **kwargs) | |||
| return self._inverse(value, *args, **kwargs) | |||
| def forward_log_jacobian(self, *args, **kwargs): | |||
| def forward_log_jacobian(self, value, *args, **kwargs): | |||
| """ | |||
| Logarithm of the derivative of the forward transformation. | |||
| """ | |||
| return self._forward_log_jacobian(*args, **kwargs) | |||
| return self._forward_log_jacobian(value, *args, **kwargs) | |||
| def inverse_log_jacobian(self, *args, **kwargs): | |||
| def inverse_log_jacobian(self, value, *args, **kwargs): | |||
| """ | |||
| Logarithm of the derivative of the inverse transformation. | |||
| """ | |||
| return self._inverse_log_jacobian(*args, **kwargs) | |||
| return self._inverse_log_jacobian(value, *args, **kwargs) | |||
| def __call__(self, *args, **kwargs): | |||
| """ | |||
| @@ -167,7 +250,7 @@ class Bijector(Cell): | |||
| *args: args[0] shall be either a distribution or the name of a Bijector function. | |||
| """ | |||
| if isinstance(args[0], Distribution): | |||
| return TransformedDistribution(self, args[0], self.distribution.dtype) | |||
| return TransformedDistribution(self, args[0]) | |||
| return super(Bijector, self).__call__(*args, **kwargs) | |||
| def construct(self, name, *args, **kwargs): | |||
| @@ -13,10 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """GumbelCDF Bijector""" | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore._checkparam import Validator | |||
| from mindspore.ops import operations as P | |||
| from ..distribution._utils.utils import check_greater_zero, set_param_type | |||
| from ..distribution._utils.utils import check_greater_zero | |||
| from ..distribution._utils.custom_ops import exp_generic, log_generic | |||
| from .bijector import Bijector | |||
| @@ -30,12 +28,11 @@ class GumbelCDF(Bijector): | |||
| Y = \exp(-\exp(\frac{-(X - loc)}{scale})) | |||
| Note: | |||
| For `reverse` and `reverse_log_jacobian`, input should be in range of (0, 1). | |||
| For `inverse` and `inverse_log_jacobian`, input should be in range of (0, 1). | |||
| Args: | |||
| loc (int, float, list, numpy.ndarray, Tensor): The location. Default: 0.. | |||
| scale (int, float, list, numpy.ndarray, Tensor): The scale. Default: 1.0. | |||
| dtype (mindspore.dtype): Type of the distribution which the bijector operates on. Default: float32. | |||
| loc (float, list, numpy.ndarray, Tensor): The location. Default: 0.. | |||
| scale (float, list, numpy.ndarray, Tensor): The scale. Default: 1.0. | |||
| name (str): The name of the Bijector. Default: 'Gumbel_CDF'. | |||
| Examples: | |||
| @@ -61,22 +58,18 @@ class GumbelCDF(Bijector): | |||
| def __init__(self, | |||
| loc=0.0, | |||
| scale=1.0, | |||
| dtype=mstype.float32, | |||
| name='GumbelCDF'): | |||
| """ | |||
| Constructor of GumbelCDF Bijector. | |||
| """ | |||
| param = dict(locals()) | |||
| valid_dtype = mstype.float_type + mstype.int_type + mstype.uint_type | |||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||
| parameter_type = set_param_type({'loc': loc, "scale": scale}, dtype) | |||
| super(GumbelCDF, self).__init__(name=name, dtype=dtype, param=param) | |||
| param['param_dict'] = {'loc': loc, 'scale': scale} | |||
| super(GumbelCDF, self).__init__(name=name, param=param) | |||
| self._parameter_type = parameter_type | |||
| self._loc = self._add_parameter(loc, 'loc') | |||
| self._scale = self._add_parameter(scale, 'scale') | |||
| check_greater_zero(self._scale, "scale") | |||
| self._event_shape = self._calc_event_shape() | |||
| self.cast = P.Cast() | |||
| self.exp = exp_generic | |||
| @@ -91,38 +84,34 @@ class GumbelCDF(Bijector): | |||
| def scale(self): | |||
| return self._scale | |||
| @property | |||
| def event_shape(self): | |||
| return self._event_shape | |||
| @property | |||
| def parameter_type(self): | |||
| return self._parameter_type | |||
| def extend_repr(self): | |||
| return f'loc = {self.loc}, scale = {self.scale}' | |||
| def shape_mapping(self, shape): | |||
| return shape | |||
| if self.is_scalar_batch: | |||
| str_info = f'loc = {self.loc}, scale = {self.scale}' | |||
| else: | |||
| str_info = f'batch_shape = {self.batch_shape}' | |||
| return str_info | |||
| def _forward(self, x): | |||
| x = self._check_value(x, 'value') | |||
| x = self.cast(x, self.parameter_type) | |||
| z = (x - self.loc) / self.scale | |||
| x = self._check_value_dtype(x) | |||
| loc_local = self.cast_param_by_value(x, self.loc) | |||
| scale_local = self.cast_param_by_value(x, self.scale) | |||
| z = (x - loc_local) / scale_local | |||
| return self.exp(-self.exp(-z)) | |||
| def _inverse(self, y): | |||
| y = self._check_value(y, 'value') | |||
| y = self.cast(y, self.parameter_type) | |||
| return self.loc - self.scale * self.log(-self.log(y)) | |||
| y = self._check_value_dtype(y) | |||
| loc_local = self.cast_param_by_value(y, self.loc) | |||
| scale_local = self.cast_param_by_value(y, self.scale) | |||
| return loc_local - scale_local * self.log(-self.log(y)) | |||
| def _forward_log_jacobian(self, x): | |||
| x = self._check_value(x, 'value') | |||
| x = self.cast(x, self.parameter_type) | |||
| z = (x - self.loc) / self.scale | |||
| return -z - self.exp(-z) - self.log(self.scale) | |||
| x = self._check_value_dtype(x) | |||
| loc_local = self.cast_param_by_value(x, self.loc) | |||
| scale_local = self.cast_param_by_value(x, self.scale) | |||
| z = (x - loc_local) / scale_local | |||
| return -z - self.exp(-z) - self.log(scale_local) | |||
| def _inverse_log_jacobian(self, y): | |||
| y = self._check_value(y, 'value') | |||
| y = self.cast(y, self.parameter_type) | |||
| return self.log(self.scale / (-1. * y * self.log(y))) | |||
| y = self._check_value_dtype(y) | |||
| scale_local = self.cast_param_by_value(y, self.scale) | |||
| return self.log(scale_local / (-1. * y * self.log(y))) | |||
| @@ -53,23 +53,17 @@ class Invert(Bijector): | |||
| name = (name + bijector.name) if name == 'Invert' else name | |||
| super(Invert, self).__init__(is_constant_jacobian=bijector.is_constant_jacobian, | |||
| is_injective=bijector.is_injective, | |||
| dtype=bijector.dtype, | |||
| name=name, | |||
| dtype=bijector.dtype, | |||
| param=param) | |||
| self._bijector = bijector | |||
| if hasattr(self._bijector, 'event_shape'): | |||
| self._event_shape = self.bijector.event_shape | |||
| else: | |||
| self._event_shape = () | |||
| self._batch_shape = self.bijector.batch_shape | |||
| self._is_scalar_batch = self.bijector.is_scalar_batch | |||
| @property | |||
| def bijector(self): | |||
| return self._bijector | |||
| @property | |||
| def event_shape(self): | |||
| return self._event_shape | |||
| def inverse(self, y): | |||
| return self.bijector("forward", y) | |||
| @@ -14,8 +14,7 @@ | |||
| # ============================================================================ | |||
| """Power Bijector""" | |||
| from mindspore.ops import operations as P | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore._checkparam import Rel | |||
| from ..distribution._utils.utils import check_greater_equal_zero | |||
| from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic | |||
| from .bijector import Bijector | |||
| @@ -37,7 +36,7 @@ class PowerTransform(Bijector): | |||
| ValueError: When the power is less than 0 or is not known statically. | |||
| Args: | |||
| power (int or float): The scale factor. Default: 0. | |||
| power (float, list, numpy.ndarray, Tensor): The scale factor. Default: 0. | |||
| name (str): The name of the bijector. Default: 'PowerTransform'. | |||
| Examples: | |||
| @@ -64,10 +63,11 @@ class PowerTransform(Bijector): | |||
| power=0, | |||
| name='PowerTransform'): | |||
| param = dict(locals()) | |||
| param['param_dict'] = {'power': power} | |||
| super(PowerTransform, self).__init__(name=name, param=param) | |||
| validator.check_value_type('power', power, [int, float], self.name) | |||
| validator.check_number("power", power, 0, Rel.GE, self.name) | |||
| self._power = power | |||
| self._power = self._add_parameter(power, 'power') | |||
| check_greater_equal_zero(self._power, 'Power') | |||
| self.pow = P.Pow() | |||
| self.dtypeop = P.DType() | |||
| self.cast = P.Cast() | |||
| @@ -81,13 +81,15 @@ class PowerTransform(Bijector): | |||
| return self._power | |||
| def extend_repr(self): | |||
| return f'power = {self.power}' | |||
| if self.is_scalar_batch: | |||
| str_info = f'power = {self.power}' | |||
| else: | |||
| str_info = f'batch_shape = {self.batch_shape}' | |||
| return str_info | |||
| def shape_mapping(self, shape): | |||
| return shape | |||
| def _forward(self, x): | |||
| x = self._check_value(x, 'value') | |||
| x = self._check_value_dtype(x) | |||
| power_local = self.cast_param_by_value(x, self.power) | |||
| if power_local == 0: | |||
| forward_v = self.exp(x) | |||
| @@ -96,7 +98,7 @@ class PowerTransform(Bijector): | |||
| return forward_v | |||
| def _inverse(self, y): | |||
| y = self._check_value(y, 'value') | |||
| y = self._check_value_dtype(y) | |||
| power_local = self.cast_param_by_value(y, self.power) | |||
| if power_local == 0: | |||
| inverse_v = self.log(y) | |||
| @@ -116,7 +118,7 @@ class PowerTransform(Bijector): | |||
| f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1} | |||
| \log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1) | |||
| """ | |||
| x = self._check_value(x, 'value') | |||
| x = self._check_value_dtype(x) | |||
| power_local = self.cast_param_by_value(x, self.power) | |||
| if power_local == 0: | |||
| forward_log_j = x | |||
| @@ -136,7 +138,7 @@ class PowerTransform(Bijector): | |||
| f'(x) = \frac{e^c\log(y)}{y} | |||
| \log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y) | |||
| """ | |||
| y = self._check_value(y, 'value') | |||
| y = self._check_value_dtype(y) | |||
| power_local = self.cast_param_by_value(y, self.power) | |||
| inverse_log_j = (power_local - 1) * self.log(y) | |||
| return inverse_log_j | |||
| @@ -14,8 +14,6 @@ | |||
| # ============================================================================ | |||
| """Scalar Affine Bijector""" | |||
| from mindspore.ops import operations as P | |||
| from mindspore._checkparam import Validator as validator | |||
| from ..distribution._utils.utils import cast_to_tensor | |||
| from ..distribution._utils.custom_ops import log_generic | |||
| from .bijector import Bijector | |||
| @@ -30,10 +28,14 @@ class ScalarAffine(Bijector): | |||
| where a is the scale factor and b is the shift factor. | |||
| Args: | |||
| scale (float): The scale factor. Default: 1.0. | |||
| shift (float): The shift factor. Default: 0.0. | |||
| scale (float, list, numpy.ndarray, Tensor): The scale factor. Default: 1.0. | |||
| shift (float, list, numpy.ndarray, Tensor): The shift factor. Default: 0.0. | |||
| name (str): The name of the bijector. Default: 'ScalarAffine'. | |||
| Note: | |||
| If `shift`, `scale` are passed in as numpy.ndarray or tensor, they have to have | |||
| the same dtype otherwise an error will be raised. | |||
| Examples: | |||
| >>> # To initialize a ScalarAffine bijector of scale 1 and shift 2. | |||
| >>> scalaraffine = nn.probability.bijector.ScalarAffine(1, 2) | |||
| @@ -61,10 +63,7 @@ class ScalarAffine(Bijector): | |||
| Constructor of ScalarAffine Bijector. | |||
| """ | |||
| param = dict(locals()) | |||
| validator.check_value_type( | |||
| 'scale', scale, [int, float], type(self).__name__) | |||
| validator.check_value_type( | |||
| 'shift', shift, [int, float], type(self).__name__) | |||
| param['param_dict'] = {'scale': scale, 'shift': shift} | |||
| super(ScalarAffine, self).__init__( | |||
| is_constant_jacobian=True, | |||
| is_injective=True, | |||
| @@ -72,8 +71,8 @@ class ScalarAffine(Bijector): | |||
| dtype=None, | |||
| param=param) | |||
| self._scale = cast_to_tensor(scale) | |||
| self._shift = cast_to_tensor(shift) | |||
| self._scale = self._add_parameter(scale, 'scale') | |||
| self._shift = self._add_parameter(shift, 'shift') | |||
| self.abs = P.Abs() | |||
| self.oneslike = P.OnesLike() | |||
| @@ -90,17 +89,18 @@ class ScalarAffine(Bijector): | |||
| return self._shift | |||
| def extend_repr(self): | |||
| return f'scale = {self.scale}, shift = {self.shift}' | |||
| def shape_mapping(self, shape): | |||
| return shape | |||
| if self.is_scalar_batch: | |||
| str_info = f'scale = {self.scale}, shift = {self.shift}' | |||
| else: | |||
| str_info = f'batch_shape = {self.batch_shape}' | |||
| return str_info | |||
| def _forward(self, x): | |||
| r""" | |||
| .. math:: | |||
| f(x) = a * x + b | |||
| """ | |||
| x = self._check_value(x, 'value') | |||
| x = self._check_value_dtype(x) | |||
| scale_local = self.cast_param_by_value(x, self.scale) | |||
| shift_local = self.cast_param_by_value(x, self.shift) | |||
| forward_v = scale_local * x + shift_local * self.oneslike(x) | |||
| @@ -111,7 +111,7 @@ class ScalarAffine(Bijector): | |||
| .. math:: | |||
| f(y) = \frac{y - b}{a} | |||
| """ | |||
| y = self._check_value(y, 'value') | |||
| y = self._check_value_dtype(y) | |||
| scale_local = self.cast_param_by_value(y, self.scale) | |||
| shift_local = self.cast_param_by_value(y, self.shift) | |||
| inverse_v = (y - shift_local) / scale_local | |||
| @@ -124,7 +124,7 @@ class ScalarAffine(Bijector): | |||
| f'(x) = a | |||
| \log(f'(x)) = \log(a) | |||
| """ | |||
| x = self._check_value(x, 'value') | |||
| x = self._check_value_dtype(x) | |||
| scale_local = self.cast_param_by_value(x, self.scale) | |||
| forward_log_j = self.log(self.abs(scale_local)) | |||
| return forward_log_j | |||
| @@ -136,7 +136,7 @@ class ScalarAffine(Bijector): | |||
| f'(x) = \frac{1.0}{a} | |||
| \log(f'(x)) = - \log(a) | |||
| """ | |||
| y = self._check_value(y, 'value') | |||
| y = self._check_value_dtype(y) | |||
| scale_local = self.cast_param_by_value(y, self.scale) | |||
| inverse_log_j = -1. * self.log(self.abs(scale_local)) | |||
| return inverse_log_j | |||
| @@ -16,8 +16,6 @@ | |||
| import numpy as np | |||
| from mindspore.ops import operations as P | |||
| from mindspore.nn.layer.activation import LogSigmoid | |||
| from mindspore._checkparam import Validator as validator | |||
| from ..distribution._utils.utils import cast_to_tensor | |||
| from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic | |||
| from .bijector import Bijector | |||
| @@ -32,7 +30,7 @@ class Softplus(Bijector): | |||
| where k is the sharpness factor. | |||
| Args: | |||
| sharpness (float): The scale factor. Default: 1.0. | |||
| sharpness (float, list, numpy.ndarray, Tensor): The scale factor. Default: 1.0. | |||
| name (str): The name of the Bijector. Default: 'Softplus'. | |||
| Examples: | |||
| @@ -61,10 +59,9 @@ class Softplus(Bijector): | |||
| Constructor of Softplus Bijector. | |||
| """ | |||
| param = dict(locals()) | |||
| validator.check_value_type('sharpness', sharpness, | |||
| [int, float], type(self).__name__) | |||
| super(Softplus, self).__init__(name=name, param=param) | |||
| self._sharpness = cast_to_tensor(sharpness) | |||
| param['param_dict'] = {'sharpness': sharpness} | |||
| super(Softplus, self).__init__(name=name, dtype=None, param=param) | |||
| self._sharpness = self._add_parameter(sharpness, 'sharpness') | |||
| self.exp = exp_generic | |||
| self.log = log_generic | |||
| @@ -118,13 +115,14 @@ class Softplus(Bijector): | |||
| return self._sharpness | |||
| def extend_repr(self): | |||
| return f'sharpness = {self.sharpness}' | |||
| def shape_mapping(self, shape): | |||
| return shape | |||
| if self.is_scalar_batch: | |||
| str_info = f'sharpness = {self.sharpness}' | |||
| else: | |||
| str_info = f'batch_shape = {self.batch_shape}' | |||
| return str_info | |||
| def _forward(self, x): | |||
| x = self._check_value(x, 'value') | |||
| x = self._check_value_dtype(x) | |||
| sharpness_local = self.cast_param_by_value(x, self.sharpness) | |||
| scaled_value = sharpness_local * x | |||
| forward_v = self.softplus(scaled_value) / sharpness_local | |||
| @@ -136,7 +134,7 @@ class Softplus(Bijector): | |||
| f(x) = \frac{\log(1 + e^{kx}))}{k} | |||
| f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k} | |||
| """ | |||
| y = self._check_value(y, 'value') | |||
| y = self._check_value_dtype(y) | |||
| sharpness_local = self.cast_param_by_value(y, self.sharpness) | |||
| scaled_value = sharpness_local * y | |||
| inverse_v = self.inverse_softplus(scaled_value) / sharpness_local | |||
| @@ -149,7 +147,7 @@ class Softplus(Bijector): | |||
| f'(x) = \frac{e^{kx}}{ 1 + e^{kx}} | |||
| \log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx) | |||
| """ | |||
| x = self._check_value(x, 'value') | |||
| x = self._check_value_dtype(x) | |||
| sharpness_local = self.cast_param_by_value(x, self.sharpness) | |||
| scaled_value = sharpness_local * x | |||
| forward_log_j = self.log_sigmoid(scaled_value) | |||
| @@ -162,7 +160,7 @@ class Softplus(Bijector): | |||
| f'(y) = \frac{e^{ky}}{e^{ky} - 1} | |||
| \log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky) | |||
| """ | |||
| y = self._check_value(y, 'value') | |||
| y = self._check_value_dtype(y) | |||
| sharpness_local = self.cast_param_by_value(y, self.sharpness) | |||
| scaled_value = sharpness_local * y | |||
| inverse_log_j = scaled_value - self.inverse_softplus(scaled_value) | |||
| @@ -229,6 +229,11 @@ def raise_not_implemented_util(func_name, obj, *args, **kwargs): | |||
| raise NotImplementedError( | |||
| f"{func_name} is not implemented for {obj} distribution.") | |||
| @constexpr | |||
| def raise_type_error(name, cur_type, required_type): | |||
| raise TypeError( | |||
| f"For {name} , the type should be or be subclass of {required_type}, but got {cur_type}") | |||
| @constexpr | |||
| def check_distribution_name(name, expected_name): | |||
| @@ -304,7 +309,7 @@ def set_param_type(args, hint_type): | |||
| TypeError: if tensors in args are not the same dtype. | |||
| """ | |||
| int_type = mstype.int_type + mstype.uint_type | |||
| if hint_type in int_type: | |||
| if hint_type in int_type or hint_type is None: | |||
| hint_type = mstype.float32 | |||
| common_dtype = None | |||
| for name, arg in args.items(): | |||
| @@ -72,13 +72,12 @@ class Distribution(Cell): | |||
| if not(k == 'self' or k.startswith('_')): | |||
| self._parameters[k] = param[k] | |||
| # some attributes | |||
| if 'distribution' in self.parameters.keys(): | |||
| self.parameter_type = self.parameters['distribution'].parameter_type | |||
| else: | |||
| # if not a transformed distribution, set the following attribute | |||
| if 'distribution' not in self.parameters.keys(): | |||
| self.parameter_type = set_param_type(self.parameters['param_dict'], dtype) | |||
| self._broadcast_shape = self._calc_broadcast_shape() | |||
| self._is_scalar_batch = self._check_is_scalar_batch() | |||
| self._batch_shape = self._calc_batch_shape() | |||
| self._is_scalar_batch = self._check_is_scalar_batch() | |||
| self._broadcast_shape = self._batch_shape | |||
| # set the function to call according to the derived class's attributes | |||
| self._set_prob() | |||
| @@ -128,6 +127,10 @@ class Distribution(Cell): | |||
| def is_scalar_batch(self): | |||
| return self._is_scalar_batch | |||
| @property | |||
| def batch_shape(self): | |||
| return self._batch_shape | |||
| @property | |||
| def broadcast_shape(self): | |||
| return self._broadcast_shape | |||
| @@ -208,8 +211,6 @@ class Distribution(Cell): | |||
| """ | |||
| Check if the parameters used during initialization are scalars. | |||
| """ | |||
| if 'distribution' in self.parameters.keys(): | |||
| return self.parameters['distribution'].is_scalar_batch | |||
| param_dict = self.parameters['param_dict'] | |||
| for value in param_dict.values(): | |||
| if value is None: | |||
| @@ -218,12 +219,10 @@ class Distribution(Cell): | |||
| return False | |||
| return True | |||
| def _calc_broadcast_shape(self): | |||
| def _calc_batch_shape(self): | |||
| """ | |||
| Calculate the broadcast shape of the parameters used during initialization. | |||
| """ | |||
| if 'distribution' in self.parameters.keys(): | |||
| return self.parameters['distribution'].broadcast_shape | |||
| param_dict = self.parameters['param_dict'] | |||
| broadcast_shape_tensor = None | |||
| for value in param_dict.values(): | |||
| @@ -362,14 +361,14 @@ class Distribution(Cell): | |||
| """ | |||
| return self._get_dist_args(*args, **kwargs) | |||
| def _get_dist_type(self, *args, **kwargs): | |||
| return raise_not_implemented_util('get_dist_type', self.name, *args, **kwargs) | |||
| def _get_dist_type(self): | |||
| return raise_not_implemented_util('get_dist_type', self.name) | |||
| def get_dist_type(self, *args, **kwargs): | |||
| def get_dist_type(self): | |||
| """ | |||
| Return the type of the distribution. | |||
| """ | |||
| return self._get_dist_type(*args, **kwargs) | |||
| return self._get_dist_type() | |||
| def _raise_not_implemented_error(self, func_name): | |||
| name = self.name | |||
| @@ -751,5 +750,5 @@ class Distribution(Cell): | |||
| if name == 'get_dist_args': | |||
| return self._get_dist_args(*args, **kwargs) | |||
| if name == 'get_dist_type': | |||
| return self._get_dist_type(*args, **kwargs) | |||
| return self._get_dist_type() | |||
| return raise_not_implemented_util(name, self.name, *args, **kwargs) | |||
| @@ -103,17 +103,12 @@ class Gumbel(TransformedDistribution): | |||
| """ | |||
| valid_dtype = mstype.float_type | |||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||
| gumbel_cdf = msb.GumbelCDF(loc, scale, dtype) | |||
| gumbel_cdf = msb.GumbelCDF(loc, scale) | |||
| super(Gumbel, self).__init__( | |||
| distribution=msd.Uniform(0.0, 1.0, dtype=dtype), | |||
| bijector=msb.Invert(gumbel_cdf), | |||
| seed=seed, name=name) | |||
| self.parameter_type = gumbel_cdf.parameter_type | |||
| self._broadcast_shape = gumbel_cdf.event_shape | |||
| if self._broadcast_shape != (): | |||
| self._is_scalar_batch = False | |||
| # overwrite default_parameters and parameter_names | |||
| self._reset_parameters() | |||
| self._loc = self._add_parameter(loc, 'loc') | |||
| @@ -202,6 +197,7 @@ class Gumbel(TransformedDistribution): | |||
| where z = \frac{x - loc}{scale} | |||
| """ | |||
| value = self._check_value(value, 'value') | |||
| value = self.cast(value, self.dtype) | |||
| z = (value - self.loc) / self.scale | |||
| return -(z + self.exp(-z)) - self.log(self.scale) | |||
| @@ -210,6 +206,8 @@ class Gumbel(TransformedDistribution): | |||
| .. math:: | |||
| cdf_pdf(X) = \exp(-\exp(-\frac{x - loc}{scale}) | |||
| """ | |||
| value = self._check_value(value, 'value') | |||
| value = self.cast(value, self.dtype) | |||
| return self._gumbel_bijector("forward", value) | |||
| def _cross_entropy(self, dist, loc_b, scale_b): | |||
| @@ -251,12 +249,14 @@ class Gumbel(TransformedDistribution): | |||
| self.expm1((loc_b - self.loc) / scale_b + self.lgamma(self.scale / scale_b + 1.)) | |||
| def _sample(self, shape=()): | |||
| shape = self.checktuple(shape, 'shape') | |||
| origin_shape = shape + self._broadcast_shape | |||
| if origin_shape == (): | |||
| sample_shape = (1,) | |||
| else: | |||
| sample_shape = origin_shape | |||
| org_sample = self.distribution("sample", sample_shape) | |||
| org_sample = self.cast(org_sample, self.dtype) | |||
| value = self.bijector("forward", org_sample) | |||
| if origin_shape == (): | |||
| value = self.squeeze(value) | |||
| @@ -137,6 +137,11 @@ class LogNormal(msd.TransformedDistribution): | |||
| bijector=msb.Exp(), | |||
| seed=seed, name=name) | |||
| # overwrite default_parameters and parameter_names | |||
| self._reset_parameters() | |||
| self._loc = self._add_parameter(loc, 'loc') | |||
| self._scale = self._add_parameter(scale, 'scale') | |||
| self.log_2pi = np.log(2 * np.pi) | |||
| #ops needed for the class | |||
| @@ -154,12 +159,12 @@ class LogNormal(msd.TransformedDistribution): | |||
| @property | |||
| def loc(self): | |||
| """Distribution parameter for the pre-transformed mean.""" | |||
| return self.distribution("mean") | |||
| return self._loc | |||
| @property | |||
| def scale(self): | |||
| """Distribution parameter for the pre-transformed standard deviation.""" | |||
| return self.distribution("sd") | |||
| return self._scale | |||
| def _get_dist_type(self): | |||
| return "LogNormal" | |||
| @@ -168,18 +173,18 @@ class LogNormal(msd.TransformedDistribution): | |||
| if loc is not None: | |||
| self.checktensor(loc, 'loc') | |||
| else: | |||
| loc = self.distribution("mean") | |||
| loc = self.loc | |||
| if scale is not None: | |||
| self.checktensor(scale, 'scale') | |||
| else: | |||
| scale = self.distribution("sd") | |||
| scale = self.scale | |||
| return loc, scale | |||
| def extend_repr(self): | |||
| if self.is_scalar_batch: | |||
| s = f'loc = {self._mean_value}, scale = {self._sd_value}' | |||
| s = f'loc = {self.loc}, scale = {self.scale}' | |||
| else: | |||
| s = f'batch_shape = {self._broadcast_shape}' | |||
| s = f'batch_shape = {self.broadcast_shape}' | |||
| return s | |||
| def _mean(self, loc=None, scale=None): | |||
| @@ -16,6 +16,7 @@ | |||
| import numpy as np | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common import dtype as mstype | |||
| import mindspore.nn as nn | |||
| from .distribution import Distribution | |||
| from ._utils.utils import raise_not_impl_error | |||
| @@ -30,7 +31,7 @@ class TransformedDistribution(Distribution): | |||
| Args: | |||
| bijector (Bijector): The transformation to perform. | |||
| distribution (Distribution): The original distribution. | |||
| distribution (Distribution): The original distribution. Must has dtype of mindspore.float_type. | |||
| seed (int): The seed is used in sampling. The global seed is used if it is None. Default:None. | |||
| If this seed is given when a TransformedDistribution object is initialised, the object's sampling function | |||
| will use this seed; elsewise, the underlying distribution's seed will be used. | |||
| @@ -40,6 +41,12 @@ class TransformedDistribution(Distribution): | |||
| The arguments used to initialize the original distribution cannot be None. | |||
| For example, mynormal = nn.Normal(dtype=dtyple.float32) cannot be used to initialized a | |||
| TransformedDistribution since `mean` and `sd` are not specified. | |||
| `batch_shape` is the batch_shape of the original distribution. | |||
| `broadcast_shape` is the broadcast shape between the original distribution and bijector. | |||
| `is_scalar_batch` is only true if both the original distribution and the bijector are scalar batches. | |||
| `default_parameters`, `parameter_names` and `parameter_type` are set to be consistent with the original | |||
| distribution. Derived class can overwrite `default_parameters` and `parameter_names` by calling | |||
| `reset_parameters` followed by `add_parameter`. | |||
| Examples: | |||
| >>> # To initialize a transformed distribution, e.g. a lognormal distribution, | |||
| @@ -75,28 +82,34 @@ class TransformedDistribution(Distribution): | |||
| [nn.probability.bijector.Bijector], type(self).__name__) | |||
| validator.check_value_type('distribution', distribution, | |||
| [Distribution], type(self).__name__) | |||
| validator.check_type_name("dtype", distribution.dtype, mstype.float_type, type(self).__name__) | |||
| super(TransformedDistribution, self).__init__(seed, distribution.dtype, name, param) | |||
| self._bijector = bijector | |||
| self._distribution = distribution | |||
| self._is_linear_transformation = bijector.is_constant_jacobian | |||
| self.default_parameters = distribution.default_parameters | |||
| self.parameter_names = distribution.parameter_names | |||
| # set attributes | |||
| self._is_linear_transformation = self.bijector.is_constant_jacobian | |||
| self._dtype = self.distribution.dtype | |||
| self._is_scalar_batch = self.distribution.is_scalar_batch and self.bijector.is_scalar_batch | |||
| self._batch_shape = self.distribution.batch_shape | |||
| self.default_parameters = self.distribution.default_parameters | |||
| self.parameter_names = self.distribution.parameter_names | |||
| # by default, set the parameter_type to be the distribution's parameter_type | |||
| self.parameter_type = self.distribution.parameter_type | |||
| self.exp = exp_generic | |||
| self.log = log_generic | |||
| self.isnan = P.IsNan() | |||
| self.cast_base = P.Cast() | |||
| self.equal_base = P.Equal() | |||
| self.select_base = P.Select() | |||
| self.fill = P.Fill() | |||
| self.fill_base = P.Fill() | |||
| # broadcast bijector batch_shape and distribution batch_shape | |||
| self._broadcast_shape = self._broadcast_bijector_dist() | |||
| # check if batch shape of the distribution and event shape is broadcastable | |||
| if hasattr(self.bijector, 'event_shape'): | |||
| event_shape_tensor = self.fill(self.dtype, self.bijector.event_shape, 0.0) | |||
| broadcast_shape_tensor = self.fill(self.dtype, self.broadcast_shape, 0.0) | |||
| self._batch_event = (event_shape_tensor + broadcast_shape_tensor).shape | |||
| else: | |||
| self._batch_event = self.broadcast_shape | |||
| @property | |||
| def bijector(self): | |||
| @@ -108,12 +121,22 @@ class TransformedDistribution(Distribution): | |||
| @property | |||
| def dtype(self): | |||
| return self.distribution.dtype | |||
| return self._dtype | |||
| @property | |||
| def is_linear_transformation(self): | |||
| return self._is_linear_transformation | |||
| def _broadcast_bijector_dist(self): | |||
| """ | |||
| check if the batch shape of base distribution and the bijector is broadcastable. | |||
| """ | |||
| if self.batch_shape is None or self.bijector.batch_shape is None: | |||
| return None | |||
| bijector_shape_tensor = self.fill_base(self.dtype, self.bijector.batch_shape, 0.0) | |||
| dist_shape_tensor = self.fill_base(self.dtype, self.batch_shape, 0.0) | |||
| return (bijector_shape_tensor + dist_shape_tensor).shape | |||
| def _cdf(self, value, *args, **kwargs): | |||
| r""" | |||
| .. math:: | |||
| @@ -0,0 +1,191 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test cases for exp""" | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| import mindspore.nn.probability.bijector as msb | |||
| from mindspore import Tensor | |||
| from mindspore import dtype | |||
| class MyBijector(msb.Bijector): | |||
| """ | |||
| Customized bijector class with dtype not specified. | |||
| """ | |||
| def __init__(self, param1, param2): | |||
| param = dict(locals()) | |||
| param['param_dict'] = {'param1': param1, 'param2': param2} | |||
| super(MyBijector, self).__init__(name='MyBijector', dtype=None, param=param) | |||
| self._param1 = self._add_parameter(param1, 'param1') | |||
| self._param2 = self._add_parameter(param2, 'param2') | |||
| @property | |||
| def param1(self): | |||
| return self._param1 | |||
| @property | |||
| def param2(self): | |||
| return self._param2 | |||
| def _forward(self, value): | |||
| value = self._check_value_dtype(value) | |||
| param1_local = self.cast_param_by_value(value, self.param1) | |||
| param2_local = self.cast_param_by_value(value, self.param2) | |||
| return value * param1_local + param2_local | |||
| class MySecondBijector(msb.Bijector): | |||
| """ | |||
| Customized bijector class with dtype specified. | |||
| """ | |||
| def __init__(self, param1, param2): | |||
| param = dict(locals()) | |||
| param['param_dict'] = {'param1': param1, 'param2': param2} | |||
| super(MySecondBijector, self).__init__(name='MySecondBijector', dtype=dtype.float32, param=param) | |||
| self._param1 = self._add_parameter(param1, 'param1') | |||
| self._param2 = self._add_parameter(param2, 'param2') | |||
| @property | |||
| def param1(self): | |||
| return self._param1 | |||
| @property | |||
| def param2(self): | |||
| return self._param2 | |||
| def _forward(self, value): | |||
| value = self._check_value_dtype(value) | |||
| param1_local = self.cast_param_by_value(value, self.param1) | |||
| param2_local = self.cast_param_by_value(value, self.param2) | |||
| return value * param1_local + param2_local | |||
| def test_arguments_same_type(): | |||
| """ | |||
| Test bijector initializations. | |||
| """ | |||
| param1_1 = np.array(1.0).astype(np.float16) | |||
| param2_1 = np.array(2.0).astype(np.float32) | |||
| with pytest.raises(TypeError): | |||
| MyBijector(param1_1, param2_1) | |||
| param1_2 = Tensor(1.0, dtype=dtype.float16) | |||
| param2_2 = Tensor(2.0, dtype=dtype.float32) | |||
| with pytest.raises(TypeError): | |||
| MyBijector(param1_2, param2_2) | |||
| with pytest.raises(TypeError): | |||
| MyBijector(True, param2_2) | |||
| with pytest.raises(TypeError): | |||
| MyBijector(None, param2_2) | |||
| param1_3 = Tensor(1.0, dtype=dtype.float32) | |||
| param2_3 = Tensor(2.0, dtype=dtype.float32) | |||
| bijector = MyBijector(param1_3, param2_3) | |||
| assert isinstance(bijector, msb.Bijector) | |||
| param1_4 = np.array([1.0, 2.0]).astype(np.float16) | |||
| param2_4 = np.array([1.0, 2.0]).astype(np.float16) | |||
| bijector = MyBijector(param1_4, param2_4) | |||
| assert isinstance(bijector, msb.Bijector) | |||
| bijector = MyBijector(1.0, 2.0) | |||
| assert isinstance(bijector, msb.Bijector) | |||
| def test_arguments_with_dtype_specified(): | |||
| """ | |||
| Customized bijector class with dtype not specified. | |||
| """ | |||
| param1_1 = np.array(1.0).astype(np.float16) | |||
| param2_1 = np.array(2.0).astype(np.float16) | |||
| with pytest.raises(TypeError): | |||
| MySecondBijector(param1_1, param2_1) | |||
| param1_2 = Tensor(1.0, dtype=dtype.float16) | |||
| param2_2 = Tensor(2.0, dtype=dtype.float32) | |||
| with pytest.raises(TypeError): | |||
| MySecondBijector(param1_2, param2_2) | |||
| with pytest.raises(TypeError): | |||
| MySecondBijector(True, param2_2) | |||
| with pytest.raises(TypeError): | |||
| MySecondBijector(None, param2_2) | |||
| param1_3 = Tensor(1.0, dtype=dtype.float32) | |||
| param2_3 = Tensor(2.0, dtype=dtype.float32) | |||
| bijector = MyBijector(param1_3, param2_3) | |||
| assert isinstance(bijector, msb.Bijector) | |||
| param1_4 = np.array(2.0).astype(np.float32) | |||
| param2_4 = np.array(1.0).astype(np.float32) | |||
| bijector = MyBijector(param1_4, param2_4) | |||
| assert isinstance(bijector, msb.Bijector) | |||
| class Net1(nn.Cell): | |||
| """ | |||
| Test input value when bijector's dtype is not specified. | |||
| """ | |||
| def __init__(self): | |||
| super(Net1, self).__init__() | |||
| self.bijector = MyBijector(np.array(1.0).astype(np.float32), np.array(2.0).astype(np.float32)) | |||
| def construct(self, value): | |||
| return self.bijector.forward(value) | |||
| class Net2(nn.Cell): | |||
| """ | |||
| Test input value when bijector's dtype is specified. | |||
| """ | |||
| def __init__(self): | |||
| super(Net2, self).__init__() | |||
| self.bijector = MySecondBijector(np.array(1.0).astype(np.float32), np.array(2.0).astype(np.float32)) | |||
| def construct(self, value): | |||
| return self.bijector.forward(value) | |||
| def test_input_value(): | |||
| """ | |||
| Test validity of input value. | |||
| """ | |||
| net = Net1() | |||
| value = None | |||
| with pytest.raises(TypeError): | |||
| ans = net(value) | |||
| value = 1.0 | |||
| with pytest.raises(TypeError): | |||
| ans = net(value) | |||
| value = Tensor(1.0, dtype=dtype.int32) | |||
| with pytest.raises(TypeError): | |||
| ans = net(value) | |||
| value = Tensor(1.0, dtype=dtype.float32) | |||
| ans = net(value) | |||
| assert ans.dtype == dtype.float32 | |||
| value = Tensor(1.0, dtype=dtype.float16) | |||
| ans = net(value) | |||
| assert ans.dtype == dtype.float16 | |||
| def test_input_value2(): | |||
| """ | |||
| Test validity of input value. | |||
| """ | |||
| net = Net2() | |||
| value = None | |||
| with pytest.raises(TypeError): | |||
| ans = net(value) | |||
| value = 1.0 | |||
| with pytest.raises(TypeError): | |||
| ans = net(value) | |||
| value = Tensor(1.0, dtype=dtype.int32) | |||
| with pytest.raises(TypeError): | |||
| ans = net(value) | |||
| value = Tensor(1.0, dtype=dtype.float16) | |||
| with pytest.raises(TypeError): | |||
| ans = net(value) | |||
| value = Tensor(1.0, dtype=dtype.float32) | |||
| ans = net(value) | |||
| assert ans.dtype == dtype.float32 | |||
| @@ -1,3 +1,7 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||