Browse Source

Fixed some doc issues and removed context checking in CheckTensor

tags/v1.1.0
Xun Deng 5 years ago
parent
commit
b527a058f1
7 changed files with 15 additions and 19 deletions
  1. +2
    -4
      mindspore/nn/probability/bijector/bijector.py
  2. +3
    -2
      mindspore/nn/probability/bijector/scalar_affine.py
  3. +1
    -1
      mindspore/nn/probability/distribution/bernoulli.py
  4. +3
    -3
      mindspore/nn/probability/distribution/categorical.py
  5. +3
    -6
      mindspore/nn/probability/distribution/distribution.py
  6. +2
    -2
      mindspore/nn/probability/distribution/logistic.py
  7. +1
    -1
      mindspore/nn/probability/distribution/transformed_distribution.py

+ 2
- 4
mindspore/nn/probability/bijector/bijector.py View File

@@ -91,10 +91,8 @@ class Bijector(Cell):
"""
Check availability of `value` as a Tensor.
"""
if self.context_mode == 0:
self.checktensor(value, name)
return value
return self.checktensor(value, name)
self.checktensor(value, name)
return value

def cast_param_by_value(self, value, para):
local = self.cast_base(para, self.dtype_base(value))


+ 3
- 2
mindspore/nn/probability/bijector/scalar_affine.py View File

@@ -65,8 +65,6 @@ class ScalarAffine(Bijector):
'scale', scale, [int, float], type(self).__name__)
validator.check_value_type(
'shift', shift, [int, float], type(self).__name__)
self._scale = cast_to_tensor(scale)
self._shift = cast_to_tensor(shift)
super(ScalarAffine, self).__init__(
is_constant_jacobian=True,
is_injective=True,
@@ -74,6 +72,9 @@ class ScalarAffine(Bijector):
dtype=None,
param=param)

self._scale = cast_to_tensor(scale)
self._shift = cast_to_tensor(shift)

self.abs = P.Abs()
self.oneslike = P.OnesLike()
self.dtypeop = P.DType()


+ 1
- 1
mindspore/nn/probability/distribution/bernoulli.py View File

@@ -273,7 +273,7 @@ class Bernoulli(Distribution):

Args:
shape (tuple): The shape of the sample. Default: ().
probs1 (Tensor, Number): `probs1` of the samples. Default: self.probs.
probs1 (Tensor): `probs1` of the samples. Default: self.probs.

Returns:
Tensor, shape is shape + batch_shape.


+ 3
- 3
mindspore/nn/probability/distribution/categorical.py View File

@@ -47,7 +47,7 @@ class Categorical(Distribution):
>>> def __init__(self, probs):
>>> super(net, self).__init__():
>>> self.ca = msd.Categorical(probs=[0.2, 0.8], dtype=mstype.int32)
>>> self.ca1 = msd.Categorical(probs=[0.2, 0.8], dtype=mstype.int32)
>>> self.ca1 = msd.Categorical(dtype=mstype.int32)
>>>
>>> # All the following calls in construct are valid
>>> def construct(self, value):
@@ -95,8 +95,8 @@ class Categorical(Distribution):
>>> # probs (Tensor): event probabilities. Default: self.probs.
>>> ans = self.ca.sample()
>>> ans = self.ca.sample((2,3))
>>> ans = self.b1.sample((2,3), probs_b)
>>> ans = self.b2.sample((2,3), probs_a)
>>> ans = self.ca.sample((2,3), probs_b)
>>> ans = self.ca1.sample((2,3), probs_a)
"""
def __init__(self,


+ 3
- 6
mindspore/nn/probability/distribution/distribution.py View File

@@ -165,10 +165,7 @@ class Distribution(Cell):
for arg, name, default in zip(args, self.parameter_names, self.default_parameters):
# check if the argument is a Tensor
if arg is not None:
if self.context_mode == 0:
self.checktensor(arg, name)
else:
arg = self.checktensor(arg, name)
self.checktensor(arg, name)
else:
arg = default if default is not None else raise_none_error(
name)
@@ -687,8 +684,8 @@ class Distribution(Cell):

Note:
Names of supported functions include:
'prob', 'log_prob', 'cdf', 'log_cdf', 'survival_function', 'log_survival'
'var', 'sd', 'entropy', 'kl_loss', 'cross_entropy', and 'sample'.
'prob', 'log_prob', 'cdf', 'log_cdf', 'survival_function', 'log_survival',
'var', 'sd', 'mode', 'mean', 'entropy', 'kl_loss', 'cross_entropy', and 'sample'.

Args:
name (str): The name of the function.


+ 2
- 2
mindspore/nn/probability/distribution/logistic.py View File

@@ -94,8 +94,8 @@ class Logistic(Distribution):
>>> # scale (Tensor): the scale of the distribution. Default: self.scale.
>>> ans = self.l1.sample()
>>> ans = self.l1.sample((2,3))
>>> ans = self.l1.sample((2,3), scale_b, scale_b)
>>> ans = self.l2.sample((2,3), scale_a, scale_a)
>>> ans = self.l1.sample((2,3), loc_b, scale_b)
>>> ans = self.l2.sample((2,3), loc_a, scale_a)
"""

def __init__(self,


+ 1
- 1
mindspore/nn/probability/distribution/transformed_distribution.py View File

@@ -31,7 +31,7 @@ class TransformedDistribution(Distribution):
Args:
bijector (Bijector): The transformation to perform.
distribution (Distribution): The original distribution.
seed (int): The seed is used in sampling. The global seed is used if it is None.
seed (int): The seed is used in sampling. The global seed is used if it is None. Default:None.
If this seed is given when a TransformedDistribution object is initialised, the object's sampling function
will use this seed; elsewise, the underlying distribution's seed will be used.
name (str): The name of the transformed distribution. Default: 'transformed_distribution'.


Loading…
Cancel
Save