| @@ -153,6 +153,16 @@ class Bernoulli(Distribution): | |||
| """ | |||
| return self._probs | |||
| def _get_dist_type(self): | |||
| return "Bernoulli" | |||
| def _get_dist_args(self, probs1=None): | |||
| if probs1 is not None: | |||
| self.checktensor(probs1, 'probs') | |||
| else: | |||
| probs1 = self.probs | |||
| return (probs1,) | |||
| def _mean(self, probs1=None): | |||
| r""" | |||
| .. math:: | |||
| @@ -169,6 +169,16 @@ class Categorical(Distribution): | |||
| """ | |||
| return self._probs | |||
| def _get_dist_type(self): | |||
| return "Categorical" | |||
| def _get_dist_args(self, probs=None): | |||
| if probs is not None: | |||
| self.checktensor(probs, 'probs') | |||
| else: | |||
| probs = self.probs | |||
| return (probs,) | |||
| def _mean(self, probs=None): | |||
| r""" | |||
| .. math:: | |||
| @@ -344,6 +344,33 @@ class Distribution(Cell): | |||
| else: | |||
| self._call_cross_entropy = self._raise_not_implemented_error('cross_entropy') | |||
| def _get_dist_args(self, *args, **kwargs): | |||
| return raise_not_implemented_util('get_dist_args', self.name, *args, **kwargs) | |||
| def get_dist_args(self, *args, **kwargs): | |||
| """ | |||
| Check the availability and validity of default parameters and `dist_spec_args`. | |||
| Args: | |||
| *args (list): the list of positional arguments forwarded to subclasses. | |||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||
| Note: | |||
| `dist_spec_args` must be passed in through list or dictionary. The order of `dist_spec_args` | |||
| should follow the initialization order of default parameters through `_add_parameter`. | |||
| If some `dist_spec_args` is None, the corresponding default parameter is returned. | |||
| """ | |||
| return self._get_dist_args(*args, **kwargs) | |||
| def _get_dist_type(self, *args, **kwargs): | |||
| return raise_not_implemented_util('get_dist_type', self.name, *args, **kwargs) | |||
| def get_dist_type(self, *args, **kwargs): | |||
| """ | |||
| Return the type of the distribution. | |||
| """ | |||
| return self._get_dist_type(*args, **kwargs) | |||
| def _raise_not_implemented_error(self, func_name): | |||
| name = self.name | |||
| def raise_error(*args, **kwargs): | |||
| @@ -721,4 +748,8 @@ class Distribution(Cell): | |||
| return self._call_cross_entropy(*args, **kwargs) | |||
| if name == 'sample': | |||
| return self._sample(*args, **kwargs) | |||
| if name == 'get_dist_args': | |||
| return self._get_dist_args(*args, **kwargs) | |||
| if name == 'get_dist_type': | |||
| return self._get_dist_type(*args, **kwargs) | |||
| return raise_not_implemented_util(name, self.name, *args, **kwargs) | |||
| @@ -157,6 +157,16 @@ class Exponential(Distribution): | |||
| """ | |||
| return self._rate | |||
| def _get_dist_type(self): | |||
| return "Exponential" | |||
| def _get_dist_args(self, rate=None): | |||
| if rate is not None: | |||
| self.checktensor(rate, 'rate') | |||
| else: | |||
| rate = self.rate | |||
| return (rate,) | |||
| def _mean(self, rate=None): | |||
| r""" | |||
| .. math:: | |||
| @@ -162,6 +162,16 @@ class Geometric(Distribution): | |||
| """ | |||
| return self._probs | |||
| def _get_dist_type(self): | |||
| return "Geometric" | |||
| def _get_dist_args(self, probs1=None): | |||
| if probs1 is not None: | |||
| self.checktensor(probs1, 'probs') | |||
| else: | |||
| probs1 = self.probs | |||
| return (probs1,) | |||
| def _mean(self, probs1=None): | |||
| r""" | |||
| .. math:: | |||
| @@ -109,7 +109,7 @@ class Gumbel(TransformedDistribution): | |||
| bijector=msb.Invert(gumbel_cdf), | |||
| seed=seed, name=name) | |||
| self._parameter_type = gumbel_cdf.parameter_type | |||
| self.parameter_type = gumbel_cdf.parameter_type | |||
| self._broadcast_shape = gumbel_cdf.event_shape | |||
| if self._broadcast_shape != (): | |||
| self._is_scalar_batch = False | |||
| @@ -146,6 +146,20 @@ class Gumbel(TransformedDistribution): | |||
| str_info = f'batch_shape = {self._broadcast_shape}' | |||
| return str_info | |||
| def _get_dist_type(self): | |||
| return "Gumbel" | |||
| def _get_dist_args(self, loc=None, scale=None): | |||
| if loc is not None: | |||
| self.checktensor(loc, 'loc') | |||
| else: | |||
| loc = self.loc | |||
| if scale is not None: | |||
| self.checktensor(scale, 'scale') | |||
| else: | |||
| scale = self.scale | |||
| return loc, scale | |||
| def _mean(self): | |||
| r""" | |||
| The mean of the distribution. | |||
| @@ -161,6 +161,20 @@ class LogNormal(msd.TransformedDistribution): | |||
| """Distribution parameter for the pre-transformed standard deviation.""" | |||
| return self.distribution("sd") | |||
| def _get_dist_type(self): | |||
| return "LogNormal" | |||
| def _get_dist_args(self, loc=None, scale=None): | |||
| if loc is not None: | |||
| self.checktensor(loc, 'loc') | |||
| else: | |||
| loc = self.distribution("mean") | |||
| if scale is not None: | |||
| self.checktensor(scale, 'scale') | |||
| else: | |||
| scale = self.distribution("sd") | |||
| return loc, scale | |||
| def extend_repr(self): | |||
| if self.is_scalar_batch: | |||
| s = f'loc = {self._mean_value}, scale = {self._sd_value}' | |||
| @@ -175,6 +175,20 @@ class Logistic(Distribution): | |||
| """ | |||
| return self._scale | |||
| def _get_dist_type(self): | |||
| return "Logistic" | |||
| def _get_dist_args(self, loc=None, scale=None): | |||
| if loc is not None: | |||
| self.checktensor(loc, 'loc') | |||
| else: | |||
| loc = self.loc | |||
| if scale is not None: | |||
| self.checktensor(scale, 'scale') | |||
| else: | |||
| scale = self.scale | |||
| return loc, scale | |||
| def _mean(self, loc=None, scale=None): | |||
| """ | |||
| The mean of the distribution. | |||
| @@ -154,6 +154,20 @@ class Normal(Distribution): | |||
| s = f'batch_shape = {self._broadcast_shape}' | |||
| return s | |||
| def _get_dist_type(self): | |||
| return "Normal" | |||
| def _get_dist_args(self, mean=None, sd=None): | |||
| if mean is not None: | |||
| self.checktensor(mean, 'mean') | |||
| else: | |||
| mean = self._mean_value | |||
| if sd is not None: | |||
| self.checktensor(sd, 'sd') | |||
| else: | |||
| sd = self._sd_value | |||
| return mean, sd | |||
| def _mean(self, mean=None, sd=None): | |||
| """ | |||
| The mean of the distribution. | |||
| @@ -173,6 +173,20 @@ class Uniform(Distribution): | |||
| """ | |||
| return self._high | |||
| def _get_dist_type(self): | |||
| return "Uniform" | |||
| def _get_dist_args(self, low=None, high=None): | |||
| if low is not None: | |||
| self.checktensor(low, 'low') | |||
| else: | |||
| low = self.low | |||
| if high is not None: | |||
| self.checktensor(high, 'high') | |||
| else: | |||
| high = self.high | |||
| return high, low | |||
| def _range(self, low=None, high=None): | |||
| r""" | |||
| Return the range of the distribution. | |||
| @@ -0,0 +1,101 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test cases for Normal distribution""" | |||
| import numpy as np | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| import mindspore.nn.probability.distribution as msd | |||
| from mindspore import Tensor | |||
| from mindspore import dtype | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| class Net1(nn.Cell): | |||
| """ | |||
| Test class: Normal distribution. `dist_spec_args` are `mean`, `sd`. | |||
| """ | |||
| def __init__(self): | |||
| super(Net1, self).__init__() | |||
| self.normal = msd.Normal(dtype=dtype.float32) | |||
| self.normal1 = msd.Normal(0.0, 1.0, dtype=dtype.float32) | |||
| self.normal2 = msd.Normal(3.0, 4.0, dtype=dtype.float32) | |||
| def construct(self, value, mean, sd, mean_a, sd_a): | |||
| args_list = self.normal.get_dist_args(mean, sd) | |||
| prob = self.normal1.prob(value, *args_list) | |||
| args_list1 = self.normal.get_dist_args() | |||
| prob1 = self.normal2.prob(value, *args_list1) | |||
| args_list2 = self.normal1.get_dist_args() | |||
| dist_type = self.normal1.get_dist_type() | |||
| kl_loss = self.normal2.kl_loss(dist_type, *args_list2) | |||
| args_list3 = self.normal.get_dist_args(mean_a, sd_a) | |||
| dist_type = self.normal1.get_dist_type() | |||
| kl_loss1 = self.normal2.kl_loss(dist_type, *args_list3) | |||
| return prob, prob1, kl_loss, kl_loss1 | |||
| def test1(): | |||
| """ | |||
| Test Normal with two `dist_spec_args`. | |||
| """ | |||
| net = Net1() | |||
| mean = Tensor(3.0, dtype=dtype.float32) | |||
| sd = Tensor(4.0, dtype=dtype.float32) | |||
| mean_a = Tensor(0.0, dtype=dtype.float32) | |||
| sd_a = Tensor(1.0, dtype=dtype.float32) | |||
| value = Tensor([-2.0, -1.0, 0.0, 1.0, 2.0]) | |||
| ans, expected, ans1, expected1 = net(value, mean, sd, mean_a, sd_a) | |||
| tol = 1e-6 | |||
| assert (np.abs(ans.asnumpy() - expected.asnumpy()) < tol).all() | |||
| assert (np.abs(ans1.asnumpy() - expected1.asnumpy()) < tol).all() | |||
| class Net2(nn.Cell): | |||
| """ | |||
| Test class: Exponential distribution. `dist_spec_args` is `rate`. | |||
| """ | |||
| def __init__(self): | |||
| super(Net2, self).__init__() | |||
| self.expon = msd.Exponential(dtype=dtype.float32) | |||
| self.expon1 = msd.Exponential(1.0, dtype=dtype.float32) | |||
| self.expon2 = msd.Exponential(2.0, dtype=dtype.float32) | |||
| def construct(self, value, rate, rate1): | |||
| args_list = self.expon.get_dist_args(rate) | |||
| prob = self.expon1.prob(value, *args_list) | |||
| args_list1 = self.expon.get_dist_args() | |||
| prob1 = self.expon2.prob(value, *args_list1) | |||
| args_list2 = self.expon1.get_dist_args() | |||
| dist_type = self.expon1.get_dist_type() | |||
| kl_loss = self.expon2.kl_loss(dist_type, *args_list2) | |||
| args_list3 = self.expon.get_dist_args(rate1) | |||
| dist_type = self.expon.get_dist_type() | |||
| kl_loss1 = self.expon2.kl_loss(dist_type, *args_list3) | |||
| return prob, prob1, kl_loss, kl_loss1 | |||
| def test2(): | |||
| """ | |||
| Test Expomential with single `dist_spec_args`. | |||
| """ | |||
| net = Net2() | |||
| rate = Tensor(2.0, dtype=dtype.float32) | |||
| rate1 = Tensor(1.0, dtype=dtype.float32) | |||
| value = Tensor([-2.0, -1.0, 0.0, 1.0, 2.0]) | |||
| ans, expected, ans1, expected1 = net(value, rate, rate1) | |||
| tol = 1e-6 | |||
| assert (np.abs(ans.asnumpy() - expected.asnumpy()) < tol).all() | |||
| assert (np.abs(ans1.asnumpy() - expected1.asnumpy()) < tol).all() | |||
| @@ -98,6 +98,8 @@ def test_kl_cross_entropy(): | |||
| """ | |||
| Test kl_loss and cross_entropy. | |||
| """ | |||
| from mindspore import context | |||
| context.set_context(device_target="Ascend") | |||
| net = KL() | |||
| loc_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) | |||
| scale_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) | |||