@@ -153,6 +153,16 @@ class Bernoulli(Distribution):
         """
         return self._probs
 
+    def _get_dist_type(self):
+        return "Bernoulli"
+
+    def _get_dist_args(self, probs1=None):
+        if probs1 is not None:
+            self.checktensor(probs1, 'probs')
+        else:
+            probs1 = self.probs
+        return (probs1,)
+
     def _mean(self, probs1=None):
         r"""
         .. math::
@@ -169,6 +169,16 @@ class Categorical(Distribution):
         """
         return self._probs
 
+    def _get_dist_type(self):
+        return "Categorical"
+
+    def _get_dist_args(self, probs=None):
+        if probs is not None:
+            self.checktensor(probs, 'probs')
+        else:
+            probs = self.probs
+        return (probs,)
+
     def _mean(self, probs=None):
         r"""
         .. math::
@@ -344,6 +344,33 @@ class Distribution(Cell):
         else:
             self._call_cross_entropy = self._raise_not_implemented_error('cross_entropy')
 
+    def _get_dist_args(self, *args, **kwargs):
+        return raise_not_implemented_util('get_dist_args', self.name, *args, **kwargs)
+
+    def get_dist_args(self, *args, **kwargs):
+        """
+        Check the availability and validity of default parameters and `dist_spec_args`, and return the `dist_spec_args` as a tuple.
+
+        Args:
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.
+
+        Note:
+            `dist_spec_args` must be passed in through `*args` or `**kwargs`. Their order must follow
+            the initialization order of the default parameters through `_add_parameter`. If an entry
+            of `dist_spec_args` is None, the corresponding default parameter is returned instead.
+        """
+        return self._get_dist_args(*args, **kwargs)
+
+    def _get_dist_type(self, *args, **kwargs):
+        return raise_not_implemented_util('get_dist_type', self.name, *args, **kwargs)
+
+    def get_dist_type(self, *args, **kwargs):
+        """
+        Return the type of the distribution.
+        """
+        return self._get_dist_type(*args, **kwargs)
+
     def _raise_not_implemented_error(self, func_name):
         name = self.name
         def raise_error(*args, **kwargs):
@@ -721,4 +748,8 @@ class Distribution(Cell):
             return self._call_cross_entropy(*args, **kwargs)
         if name == 'sample':
             return self._sample(*args, **kwargs)
+        if name == 'get_dist_args':
+            return self._get_dist_args(*args, **kwargs)
+        if name == 'get_dist_type':
+            return self._get_dist_type(*args, **kwargs)
         return raise_not_implemented_util(name, self.name, *args, **kwargs)
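For reference, the two public entry points added above route through the same `construct` dispatch as the other distribution functions. They can be exercised roughly as follows; this is a minimal sketch, not part of the patch, and it assumes the `Normal` subclass from later in this change together with its default parameters:

    import mindspore.nn.probability.distribution as msd
    from mindspore import Tensor, dtype

    n = msd.Normal(3.0, 4.0, dtype=dtype.float32)
    print(n.get_dist_type())      # "Normal"
    mean, sd = n.get_dist_args()  # no dist_spec_args passed, so the defaults (3.0, 4.0) come back
    # Explicit dist_spec_args are validated via checktensor and returned in _add_parameter order.
    mean, sd = n.get_dist_args(Tensor(0.0, dtype=dtype.float32), Tensor(1.0, dtype=dtype.float32))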
@@ -157,6 +157,16 @@ class Exponential(Distribution):
         """
         return self._rate
 
+    def _get_dist_type(self):
+        return "Exponential"
+
+    def _get_dist_args(self, rate=None):
+        if rate is not None:
+            self.checktensor(rate, 'rate')
+        else:
+            rate = self.rate
+        return (rate,)
+
     def _mean(self, rate=None):
         r"""
         .. math::
@@ -162,6 +162,16 @@ class Geometric(Distribution):
         """
         return self._probs
 
+    def _get_dist_type(self):
+        return "Geometric"
+
+    def _get_dist_args(self, probs1=None):
+        if probs1 is not None:
+            self.checktensor(probs1, 'probs')
+        else:
+            probs1 = self.probs
+        return (probs1,)
+
     def _mean(self, probs1=None):
         r"""
         .. math::
@@ -109,7 +109,7 @@ class Gumbel(TransformedDistribution):
                                      bijector=msb.Invert(gumbel_cdf),
                                      seed=seed, name=name)
-        self._parameter_type = gumbel_cdf.parameter_type
+        self.parameter_type = gumbel_cdf.parameter_type
         self._broadcast_shape = gumbel_cdf.event_shape
         if self._broadcast_shape != ():
             self._is_scalar_batch = False
@@ -146,6 +146,20 @@ class Gumbel(TransformedDistribution):
         str_info = f'batch_shape = {self._broadcast_shape}'
         return str_info
 
+    def _get_dist_type(self):
+        return "Gumbel"
+
+    def _get_dist_args(self, loc=None, scale=None):
+        if loc is not None:
+            self.checktensor(loc, 'loc')
+        else:
+            loc = self.loc
+        if scale is not None:
+            self.checktensor(scale, 'scale')
+        else:
+            scale = self.scale
+        return loc, scale
+
     def _mean(self):
         r"""
         The mean of the distribution.
@@ -161,6 +161,20 @@ class LogNormal(msd.TransformedDistribution):
         """Distribution parameter for the pre-transformed standard deviation."""
         return self.distribution("sd")
 
+    def _get_dist_type(self):
+        return "LogNormal"
+
+    def _get_dist_args(self, loc=None, scale=None):
+        if loc is not None:
+            self.checktensor(loc, 'loc')
+        else:
+            loc = self.distribution("mean")
+        if scale is not None:
+            self.checktensor(scale, 'scale')
+        else:
+            scale = self.distribution("sd")
+        return loc, scale
+
     def extend_repr(self):
         if self.is_scalar_batch:
             s = f'loc = {self._mean_value}, scale = {self._sd_value}'
@@ -175,6 +175,20 @@ class Logistic(Distribution):
         """
         return self._scale
 
+    def _get_dist_type(self):
+        return "Logistic"
+
+    def _get_dist_args(self, loc=None, scale=None):
+        if loc is not None:
+            self.checktensor(loc, 'loc')
+        else:
+            loc = self.loc
+        if scale is not None:
+            self.checktensor(scale, 'scale')
+        else:
+            scale = self.scale
+        return loc, scale
+
     def _mean(self, loc=None, scale=None):
         """
         The mean of the distribution.
@@ -154,6 +154,20 @@ class Normal(Distribution):
         s = f'batch_shape = {self._broadcast_shape}'
         return s
 
+    def _get_dist_type(self):
+        return "Normal"
+
+    def _get_dist_args(self, mean=None, sd=None):
+        if mean is not None:
+            self.checktensor(mean, 'mean')
+        else:
+            mean = self._mean_value
+        if sd is not None:
+            self.checktensor(sd, 'sd')
+        else:
+            sd = self._sd_value
+        return mean, sd
+
     def _mean(self, mean=None, sd=None):
         """
         The mean of the distribution.
@@ -173,6 +173,20 @@ class Uniform(Distribution):
         """
         return self._high
 
+    def _get_dist_type(self):
+        return "Uniform"
+
+    def _get_dist_args(self, low=None, high=None):
+        if low is not None:
+            self.checktensor(low, 'low')
+        else:
+            low = self.low
+        if high is not None:
+            self.checktensor(high, 'high')
+        else:
+            high = self.high
+        return low, high
+
     def _range(self, low=None, high=None):
         r"""
         Return the range of the distribution.
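Note that the tuple order matters downstream: functions such as `kl_loss` and `cross_entropy` consume the returned `dist_spec_args` positionally, which is why each `_get_dist_args` above returns parameters in their `_add_parameter` initialization order (`low` before `high` for Uniform). A hedged sketch of the intended flow, using only names from this patch and assuming `kl_loss` is defined between the two uniforms:

    u1 = msd.Uniform(0.0, 1.0, dtype=dtype.float32)
    u2 = msd.Uniform(0.2, 0.8, dtype=dtype.float32)
    args = u1.get_dist_args()                   # comes back as (low, high) = (0.0, 1.0)
    kl = u2.kl_loss(u1.get_dist_type(), *args)  # low lands in low_b, high in high_b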
@@ -0,0 +1,101 @@
+# Copyright 2019 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test cases for `get_dist_args` and `get_dist_type` of Normal and Exponential"""
+import numpy as np
+import mindspore.context as context
+import mindspore.nn as nn
+import mindspore.nn.probability.distribution as msd
+from mindspore import Tensor
+from mindspore import dtype
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class Net1(nn.Cell):
+    """
+    Test class: Normal distribution. `dist_spec_args` are `mean` and `sd`.
+    """
+    def __init__(self):
+        super(Net1, self).__init__()
+        self.normal = msd.Normal(dtype=dtype.float32)
+        self.normal1 = msd.Normal(0.0, 1.0, dtype=dtype.float32)
+        self.normal2 = msd.Normal(3.0, 4.0, dtype=dtype.float32)
+
+    def construct(self, value, mean, sd, mean_a, sd_a):
+        args_list = self.normal.get_dist_args(mean, sd)
+        prob = self.normal1.prob(value, *args_list)
+        args_list1 = self.normal.get_dist_args()
+        prob1 = self.normal2.prob(value, *args_list1)
+        args_list2 = self.normal1.get_dist_args()
+        dist_type = self.normal1.get_dist_type()
+        kl_loss = self.normal2.kl_loss(dist_type, *args_list2)
+        args_list3 = self.normal.get_dist_args(mean_a, sd_a)
+        dist_type = self.normal1.get_dist_type()
+        kl_loss1 = self.normal2.kl_loss(dist_type, *args_list3)
+        return prob, prob1, kl_loss, kl_loss1
+
+
+def test1():
+    """
+    Test Normal with two `dist_spec_args`.
+    """
+    net = Net1()
+    mean = Tensor(3.0, dtype=dtype.float32)
+    sd = Tensor(4.0, dtype=dtype.float32)
+    mean_a = Tensor(0.0, dtype=dtype.float32)
+    sd_a = Tensor(1.0, dtype=dtype.float32)
+    value = Tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
+    ans, expected, ans1, expected1 = net(value, mean, sd, mean_a, sd_a)
+    tol = 1e-6
+    assert (np.abs(ans.asnumpy() - expected.asnumpy()) < tol).all()
+    assert (np.abs(ans1.asnumpy() - expected1.asnumpy()) < tol).all()
+
+
+class Net2(nn.Cell):
+    """
+    Test class: Exponential distribution. `dist_spec_args` is `rate`.
+    """
+    def __init__(self):
+        super(Net2, self).__init__()
+        self.expon = msd.Exponential(dtype=dtype.float32)
+        self.expon1 = msd.Exponential(1.0, dtype=dtype.float32)
+        self.expon2 = msd.Exponential(2.0, dtype=dtype.float32)
+
+    def construct(self, value, rate, rate1):
+        args_list = self.expon.get_dist_args(rate)
+        prob = self.expon1.prob(value, *args_list)
+        args_list1 = self.expon.get_dist_args()
+        prob1 = self.expon2.prob(value, *args_list1)
+        args_list2 = self.expon1.get_dist_args()
+        dist_type = self.expon1.get_dist_type()
+        kl_loss = self.expon2.kl_loss(dist_type, *args_list2)
+        args_list3 = self.expon.get_dist_args(rate1)
+        dist_type = self.expon.get_dist_type()
+        kl_loss1 = self.expon2.kl_loss(dist_type, *args_list3)
+        return prob, prob1, kl_loss, kl_loss1
+
+
+def test2():
+    """
+    Test Exponential with a single `dist_spec_args`.
+    """
+    net = Net2()
+    rate = Tensor(2.0, dtype=dtype.float32)
+    rate1 = Tensor(1.0, dtype=dtype.float32)
+    value = Tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
+    ans, expected, ans1, expected1 = net(value, rate, rate1)
+    tol = 1e-6
+    assert (np.abs(ans.asnumpy() - expected.asnumpy()) < tol).all()
+    assert (np.abs(ans1.asnumpy() - expected1.asnumpy()) < tol).all()
@@ -98,6 +98,8 @@ def test_kl_cross_entropy():
     """
     Test kl_loss and cross_entropy.
     """
+    from mindspore import context
+    context.set_context(device_target="Ascend")
     net = KL()
     loc_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
     scale_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)