|
|
|
@@ -31,7 +31,7 @@ class TransformedDistribution(Distribution): |
|
|
|
|
|
|
|
Args: |
|
|
|
bijector (Bijector): The transformation to perform. |
|
|
|
distribution (Distribution): The original distribution. Must has dtype of mindspore.float_type. |
|
|
|
distribution (Distribution): The original distribution. Must has a float dtype. |
|
|
|
seed (int): The seed is used in sampling. The global seed is used if it is None. Default:None. |
|
|
|
If this seed is given when a TransformedDistribution object is initialised, the object's sampling function |
|
|
|
will use this seed; elsewise, the underlying distribution's seed will be used. |
|
|
|
@@ -42,7 +42,7 @@ class TransformedDistribution(Distribution): |
|
|
|
|
|
|
|
Note: |
|
|
|
The arguments used to initialize the original distribution cannot be None. |
|
|
|
For example, mynormal = nn.Normal(dtype=dtyple.float32) cannot be used to initialized a |
|
|
|
For example, mynormal = msd.Normal(dtype=mindspore.float32) cannot be used to initialized a |
|
|
|
TransformedDistribution since `mean` and `sd` are not specified. |
|
|
|
`batch_shape` is the batch_shape of the original distribution. |
|
|
|
`broadcast_shape` is the broadcast shape between the original distribution and bijector. |
|
|
|
|