
!13432 fix format error of GradOperation, gamma, poisson, etc.

From: @mind-lh
Reviewed-by: @liangchenghui,@wuxuejian
Signed-off-by: @liangchenghui
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
commit a4edcfa268
7 changed files with 111 additions and 86 deletions
  1. mindspore/ops/composite/base.py (+37, -37)
  2. mindspore/ops/composite/random_ops.py (+16, -0)
  3. mindspore/ops/operations/array_ops.py (+17, -16)
  4. mindspore/ops/operations/math_ops.py (+13, -13)
  5. mindspore/ops/operations/nn_ops.py (+19, -20)
  6. mindspore/ops/operations/other_ops.py (+7, -0)
  7. mindspore/ops/operations/random_ops.py (+2, -0)

mindspore/ops/composite/base.py (+37, -37)

@@ -114,72 +114,72 @@ class GradOperation(GradOperation_):
To generate a gradient function that returns gradients with respect to the first input
(see `GradNetWrtX` in Examples).

1. Construct a `GradOperation` higher-order function with default arguments:
`grad_op = GradOperation()`.

2. Call it with input function as argument to get the gradient function: `gradient_function = grad_op(net)`.

3. Call the gradient function with input function's inputs to get the gradients with respect to the first input:
`grad_op(net)(x, y)`.

To generate a gradient function that returns gradients with respect to all inputs (see `GradNetWrtXY` in Examples).

1. Construct a `GradOperation` higher-order function with `get_all=True`, which
indicates getting gradients with respect to all inputs; they are `x` and `y` in the example function `Net()`:
`grad_op = GradOperation(get_all=True)`.

2. Call it with input function as argument to get the gradient function: `gradient_function = grad_op(net)`.

3. Call the gradient function with input function's inputs to get the gradients with respect to all inputs:
`gradient_function(x, y)`.

To generate a gradient function that returns gradients with respect to given parameters
(see `GradNetWithWrtParams` in Examples).

1. Construct a `GradOperation` higher-order function with `get_by_list=True`:
`grad_op = GradOperation(get_by_list=True)`.

2. Construct a `ParameterTuple` that will be passed to the input function when constructing
the `GradOperation` higher-order function; it will be used as a parameter filter that determines
which gradients to return: `params = ParameterTuple(net.trainable_params())`.

3. Call it with input function and `params` as arguments to get the gradient function:
`gradient_function = grad_op(net, params)`.

4. Call the gradient function with input function's inputs to get the gradients with
respect to given parameters: `gradient_function(x, y)`.

To generate a gradient function that returns gradients with respect to all inputs and given parameters
in the format of ((dx, dy), (dz)) (see `GradNetWrtInputsAndParams` in Examples).

1. Construct a `GradOperation` higher-order function with `get_all=True` and `get_by_list=True`:
`grad_op = GradOperation(get_all=True, get_by_list=True)`.

2. Construct a `ParameterTuple` that will be passed along with the input function when constructing
the `GradOperation` higher-order function: `params = ParameterTuple(net.trainable_params())`.

3. Call it with input function and `params` as arguments to get the gradient function:
`gradient_function = grad_op(net, params)`.

4. Call the gradient function with input function's inputs
to get the gradients with respect to all inputs and given parameters: `gradient_function(x, y)`.

We can configure the sensitivity (gradient with respect to output) by setting `sens_param` to True and
passing an extra sensitivity input to the gradient function. The sensitivity input must have the
same shape and type as the input function's output (see `GradNetWrtXYWithSensParam` in Examples).

1. Construct a `GradOperation` higher-order function with `get_all=True` and `sens_param=True`:
`grad_op = GradOperation(get_all=True, sens_param=True)`.

2. Define `grad_wrt_output` as `sens_param` which works as the gradient with respect to output:
`grad_wrt_output = Tensor(np.ones([2, 2]).astype(np.float32))`.

3. Call it with input function as argument to get the gradient function:
`gradient_function = grad_op(net)`.

4. Call the gradient function with input function's inputs and `sens_param` to
get the gradients with respect to all inputs:
`gradient_function(x, y, grad_wrt_output)`.
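The usage patterns above share one shape: `GradOperation(...)` builds a higher-order function, calling it on a net yields a gradient function, and calling that yields gradients. A minimal framework-free sketch of that pattern, where central finite differences stand in for automatic differentiation (names and signatures are illustrative, not the MindSpore API):

```python
# Sketch of the GradOperation call pattern: grad_op(net) returns a new
# function mapping the net's inputs to gradients. Central finite
# differences replace automatic differentiation for illustration only.
def grad_op(net, get_all=False, eps=1e-5):
    def gradient_function(x, y):
        dx = (net(x + eps, y) - net(x - eps, y)) / (2 * eps)
        if not get_all:
            return dx  # gradient w.r.t. the first input only
        dy = (net(x, y + eps) - net(x, y - eps)) / (2 * eps)
        return dx, dy  # gradients w.r.t. all inputs
    return gradient_function

net = lambda x, y: x * x * y           # f(x, y) = x^2 * y
gradient_function = grad_op(net)
dx = gradient_function(3.0, 2.0)       # d/dx (x^2 * y) = 2*x*y
```

With `get_all=True` the same pattern returns a tuple of gradients, mirroring step 1 of the second scenario above.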

Args:
get_all (bool): If True, get all the gradients with respect to inputs. Default: False.


mindspore/ops/composite/random_ops.py (+16, -0)

@@ -192,6 +192,12 @@ def gamma(shape, alpha, beta, seed=None):
of `alpha` and `beta`.
The dtype is float32.

Raises:
TypeError: If `shape` is not a tuple.
TypeError: If neither `alpha` nor `beta` is a Tensor.
TypeError: If `seed` is not an int.
TypeError: If dtype of `alpha` and `beta` is not float32.

Supported Platforms:
``Ascend`` ``CPU``
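The shape rule stated above (output shape is the broadcast of `shape` with the shapes of `alpha` and `beta`) can be sketched with NumPy as an assumed stand-in; note that `numpy.random.Generator.gamma` uses the shape/scale parameterization, so `beta` enters as `1/beta`:

```python
import numpy as np

def gamma_sample(shape, alpha, beta, seed=None):
    # Output shape is the broadcast of the requested `shape` with the
    # shapes of `alpha` and `beta`, matching the documented rule.
    out_shape = np.broadcast_shapes(shape, np.shape(alpha), np.shape(beta))
    rng = np.random.default_rng(seed)
    return rng.gamma(np.broadcast_to(alpha, out_shape),
                     1.0 / np.broadcast_to(beta, out_shape))

sample = gamma_sample((3, 2), alpha=np.array([1.0, 2.0]), beta=1.5, seed=0)
```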

@@ -228,6 +234,11 @@ def poisson(shape, mean, seed=None):
Tensor. The shape should be equal to the broadcasted shape between the input `shape` and the shape of `mean`.
The dtype is float32.

Raises:
TypeError: If `shape` is not a tuple.
TypeError: If `mean` is not a Tensor whose dtype is float32.
TypeError: If `seed` is not an int.

Supported Platforms:
``Ascend`` ``CPU``

@@ -266,6 +277,11 @@ def multinomial(inputs, num_sample, replacement=True, seed=None):
Tensor, has the same number of rows as the input. The number of sampled indices in each row is `num_sample`.
The dtype is float32.

Raises:
TypeError: If `inputs` is not a Tensor whose dtype is float32.
TypeError: If `num_sample` is not an int.
TypeError: If `seed` is neither an int nor None.

Supported Platforms:
``GPU``
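A NumPy sketch of the sampling behavior described above (hypothetical helper, not the MindSpore implementation): each row supplies non-negative weights and yields `num_sample` sampled column indices:

```python
import numpy as np

def multinomial(inputs, num_sample, replacement=True, seed=None):
    # Each row of `inputs` holds (unnormalized) weights; the result has the
    # same number of rows, with `num_sample` sampled column indices per row.
    rng = np.random.default_rng(seed)
    probs = np.atleast_2d(np.asarray(inputs, dtype=float))
    probs = probs / probs.sum(axis=1, keepdims=True)
    return np.stack([rng.choice(probs.shape[1], size=num_sample,
                                replace=replacement, p=row)
                     for row in probs])

idx = multinomial([[1.0, 2.0, 3.0]], num_sample=4, seed=0)
```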



mindspore/ops/operations/array_ops.py (+17, -16)

@@ -4165,13 +4165,13 @@ class SpaceToBatch(PrimitiveWithInfer):
:math:`block\_size` and :math:`paddings`. The shape of the output tensor will be :math:`(n', c', h', w')`,
where

:math:`n' = n*(block\_size*block\_size)`

:math:`c' = c`

:math:`h' = (h+paddings[0][0]+paddings[0][1])//block\_size`

:math:`w' = (w+paddings[1][0]+paddings[1][1])//block\_size`

Raises:
TypeError: If `block_size` is not an int.
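The four formulas above amount to this shape computation (illustrative helper, not part of the operator's API):

```python
def space_to_batch_shape(n, c, h, w, block_size, paddings):
    # paddings = [[pad_top, pad_bottom], [pad_left, pad_right]]
    return (n * block_size * block_size,
            c,
            (h + paddings[0][0] + paddings[0][1]) // block_size,
            (w + paddings[1][0] + paddings[1][1]) // block_size)

out = space_to_batch_shape(1, 3, 4, 4, block_size=2, paddings=[[0, 0], [0, 0]])
```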
@@ -4248,13 +4248,13 @@ class BatchToSpace(PrimitiveWithInfer):
Tensor, the output tensor with the same type as input. Assume input shape is (n, c, h, w) with block_size
and crops. The output shape will be (n', c', h', w'), where

:math:`n' = n//(block\_size*block\_size)`

:math:`c' = c`

:math:`h' = h*block\_size-crops[0][0]-crops[0][1]`

:math:`w' = w*block\_size-crops[1][0]-crops[1][1]`

Raises:
TypeError: If `block_size` or element of `crops` is not an int.
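These formulas amount to this inverse shape computation (illustrative helper, not part of the operator's API):

```python
def batch_to_space_shape(n, c, h, w, block_size, crops):
    # crops = [[crop_top, crop_bottom], [crop_left, crop_right]]
    return (n // (block_size * block_size),
            c,
            h * block_size - crops[0][0] - crops[0][1],
            w * block_size - crops[1][0] - crops[1][1])

out = batch_to_space_shape(4, 3, 2, 2, block_size=2, crops=[[0, 0], [0, 0]])
```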
@@ -4333,18 +4333,19 @@ class SpaceToBatchND(PrimitiveWithInfer):

Inputs:
- **input_x** (Tensor) - The input tensor. It must be a 4-D tensor.

Outputs:
Tensor, the output tensor with the same data type as input. Assume input shape is :math:`(n, c, h, w)` with
:math:`block\_shape` and :math:`paddings`. The shape of the output tensor will be :math:`(n', c', h', w')`,
where

:math:`n' = n*(block\_shape[0]*block\_shape[1])`

:math:`c' = c`

:math:`h' = (h+paddings[0][0]+paddings[0][1])//block\_shape[0]`

:math:`w' = (w+paddings[1][0]+paddings[1][1])//block\_shape[1]`

Raises:
TypeError: If `block_shape` is not one of list, tuple, int.
@@ -4440,13 +4441,13 @@ class BatchToSpaceND(PrimitiveWithInfer):
Tensor, the output tensor with the same type as input. Assume input shape is (n, c, h, w) with block_shape
and crops. The output shape will be (n', c', h', w'), where

:math:`n' = n//(block\_shape[0]*block\_shape[1])`

:math:`c' = c`

:math:`h' = h*block\_shape[0]-crops[0][0]-crops[0][1]`

:math:`w' = w*block\_shape[1]-crops[1][0]-crops[1][1]`

Raises:
TypeError: If `block_shape` is not one of list, tuple, int.


mindspore/ops/operations/math_ops.py (+13, -13)

@@ -141,7 +141,7 @@ class Add(_MathBinaryOp):
and the data type is the one with higher precision or higher digits among the two inputs.

Raises:
- TypeError: If neither `input_x` nor `input_y` is one of the following: Tensor, Number, bool.
+ TypeError: If `input_x` and `input_y` is not one of the following: Tensor, Number, bool.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
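The "higher precision or higher digits" rule described above behaves like ordinary numeric type promotion; NumPy is used below as a stand-in to illustrate it (not the op itself):

```python
import numpy as np

# Adding float16 and float32 operands: the result takes the
# higher-precision dtype, and shapes follow normal broadcasting.
a = np.ones((2, 1), dtype=np.float16)
b = np.ones((1, 3), dtype=np.float32)
c = a + b
```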
@@ -1337,7 +1337,7 @@ class Sub(_MathBinaryOp):
and the data type is the one with higher precision or higher digits among the two inputs.

Raises:
- TypeError: If neither `input_x` nor `input_y` is a Number or a bool or a Tensor.
+ TypeError: If `input_x` and `input_y` is not a Number or a bool or a Tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@@ -1387,7 +1387,7 @@ class Mul(_MathBinaryOp):
and the data type is the one with higher precision or higher digits among the two inputs.

Raises:
- TypeError: If neither `input_x` nor `input_y` is one of the following: Tensor, Number, bool.
+ TypeError: If `input_x` and `input_y` is not one of the following: Tensor, Number, bool.
ValueError: If `input_x` and `input_y` are not the same shape.

Supported Platforms:
@@ -1434,7 +1434,7 @@ class SquaredDifference(_MathBinaryOp):
and the data type is the one with higher precision or higher digits among the two inputs.

Raises:
- TypeError: if neither `input_x` nor `input_y` is a Number or a bool or a Tensor.
+ TypeError: if `input_x` and `input_y` is not a Number or a bool or a Tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@@ -1653,7 +1653,7 @@ class Pow(_MathBinaryOp):
and the data type is the one with higher precision or higher digits among the two inputs.

Raises:
- TypeError: If neither `input_x` nor `input_y` is one of the following: Tensor, Number, bool.
+ TypeError: If `input_x` and `input_y` is not one of the following: Tensor, Number, bool.
ValueError: If `input_x` and `input_y` are not the same shape.

Supported Platforms:
@@ -2013,7 +2013,7 @@ class Minimum(_MathBinaryOp):
and the data type is the one with higher precision or higher digits among the two inputs.

Raises:
- TypeError: If neither `input_x` nor `input_y` is one of the following: Tensor, Number, bool.
+ TypeError: If `input_x` and `input_y` is not one of the following: Tensor, Number, bool.
ValueError: If `input_x` and `input_y` are not the same shape.

Supported Platforms:
@@ -2060,7 +2060,7 @@ class Maximum(_MathBinaryOp):
and the data type is the one with higher precision or higher digits among the two inputs.

Raises:
- TypeError: If neither `input_x` nor `input_y` is one of the following: Tensor, Number, bool.
+ TypeError: If `input_x` and `input_y` is not one of the following: Tensor, Number, bool.
ValueError: If `input_x` and `input_y` are not the same shape.

Supported Platforms:
@@ -2107,7 +2107,7 @@ class RealDiv(_MathBinaryOp):
and the data type is the one with higher precision or higher digits among the two inputs.

Raises:
- TypeError: If neither `input_x` nor `input_y` is one of the following: Tensor, Number, bool.
+ TypeError: If `input_x` and `input_y` is not one of the following: Tensor, Number, bool.
ValueError: If `input_x` and `input_y` are not the same shape.

Supported Platforms:
@@ -2346,7 +2346,7 @@ class TruncateDiv(_MathBinaryOp):
and the data type is the one with higher precision or higher digits among the two inputs.

Raises:
- TypeError: If neither `input_x` nor `input_y` is one of the following: Tensor, Number, bool.
+ TypeError: If `input_x` and `input_y` is not one of the following: Tensor, Number, bool.

Supported Platforms:
``Ascend``
@@ -2579,7 +2579,7 @@ class Xdivy(_MathBinaryOp):
and the data type is the one with higher precision or higher digits among the two inputs.

Raises:
- TypeError: If neither `input_x` nor `input_y` is one of the following: Tensor, Number, bool.
+ TypeError: If `input_x` and `input_y` is not one of the following: Tensor, Number, bool.

Supported Platforms:
``Ascend``
@@ -2621,7 +2621,7 @@ class Xlogy(_MathBinaryOp):
and the data type is the one with higher precision or higher digits among the two inputs.

Raises:
- TypeError: If neither `input_x` nor `input_y` is one of the following: Tensor, Number, bool.
+ TypeError: If `input_x` and `input_y` is not one of the following: Tensor, Number, bool.

Supported Platforms:
``Ascend``
@@ -2975,7 +2975,7 @@ class NotEqual(_LogicBinaryOp):
Tensor, the shape is the same as the one after broadcasting, and the data type is bool.

Raises:
- TypeError: If neither `input_x` nor `input_y` is one of the following: Tensor, Number, bool.
+ TypeError: If `input_x` and `input_y` is not one of the following: Tensor, Number, bool.
TypeError: If neither `input_x` nor `input_y` is a Tensor.

Supported Platforms:
@@ -3109,7 +3109,7 @@ class Less(_LogicBinaryOp):
Tensor, the shape is the same as the one after broadcasting, and the data type is bool.

Raises:
- TypeError: If neither `input_x` nor `input_y` is one of the following: Tensor, Number, bool.
+ TypeError: If `input_x` and `input_y` is not one of the following: Tensor, Number, bool.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``


mindspore/ops/operations/nn_ops.py (+19, -20)

@@ -521,7 +521,7 @@ class ReLUV2(PrimitiveWithInfer):

Raises:
TypeError: If `input_x`, `output` or `mask` is not a Tensor.
- TypeError: If dtype of `output` is not same as `input_x.
+ TypeError: If dtype of `output` is not same as `input_x`.
TypeError: If dtype of `mask` is not uint8.

Supported Platforms:
@@ -856,8 +856,8 @@ class FusedBatchNorm(Primitive):
Raises:
TypeError: If `mode` is not an int.
TypeError: If `epsilon` or `momentum` is not a float.
TypeError: If `output_x`, `updated_scale`, `updated_bias`, `updated_moving_mean` or `updated_moving_variance` is
not a Tensor.

Supported Platforms:
``CPU``
@@ -2894,7 +2894,7 @@ class SGD(PrimitiveWithCheck):
TypeError: If `nesterov` is not a bool.
TypeError: If `parameters`, `gradient`, `learning_rate`, `accum`, `momentum` or `stat` is not a Tensor.
TypeError: If dtype of `parameters`, `gradient`, `learning_rate`, `accum`, `momentum` or `stat` is neither
float16 nor float32.

Supported Platforms:
``Ascend`` ``GPU``
@@ -4413,7 +4413,7 @@ class Adam(PrimitiveWithInfer):
Raises:
TypeError: If neither `use_locking` nor `use_nesterov` is a bool.
TypeError: If `var`, `m` or `v` is not a Tensor.
- TypeError: If `beta1_power`, `beta2_power1, `lr`, `beta1`, `beta2`, `epsilon` or `gradient` is not a Tensor.
+ TypeError: If `beta1_power`, `beta2_power1`, `lr`, `beta1`, `beta2`, `epsilon` or `gradient` is not a Tensor.

Supported Platforms:
``Ascend`` ``GPU``
@@ -4523,8 +4523,8 @@ class AdamNoUpdateParam(PrimitiveWithInfer):

Raises:
TypeError: If neither `use_locking` nor `use_nesterov` is a bool.
- TypeError: If `m`, `v`, `beta1_power`, `beta2_power1, `lr`, `beta1`, `beta2`, `epsilon` or `gradient` is a Tensor.
+ TypeError: If `m`, `v`, `beta1_power`, `beta2_power1`, `lr`, `beta1`, `beta2`, `epsilon` or `gradient` is not a Tensor.

Supported Platforms:
``CPU``
@@ -4644,7 +4644,7 @@ class FusedSparseAdam(PrimitiveWithInfer):
Raises:
TypeError: If neither `use_locking` nor `use_nesterov` is a bool.
TypeError: If dtype of `var`, `m`, `v`, `beta1_power`, `beta2_power`, `lr`, `beta1`, `beta2`, `epsilon`,
`gradient` or `indices` is not float32.

Supported Platforms:
``CPU``
@@ -4788,7 +4788,7 @@ class FusedSparseLazyAdam(PrimitiveWithInfer):
Raises:
TypeError: If neither `use_locking` nor `use_nesterov` is a bool.
TypeError: If dtype of `var`, `m`, `v`, `beta1_power`, `beta2_power`, `lr`, `beta1`, `beta2`, `epsilon` or
gradient is not float32.
TypeError: If dtype of `indices` is not int32.

Supported Platforms:
@@ -5316,8 +5316,8 @@ class ApplyAdaMax(PrimitiveWithInfer):
- **v** (Tensor) - The same shape and data type as `v`.

Raises:
TypeError: If dtype of `var`, `m`, `v`, `beta_power`, `lr`, `beta1`, `beta2`, `epsilon` or `grad` is neither
float16 nor float32.
TypeError: If `beta_power`, `lr`, `beta1`, `beta2` or `epsilon` is neither a Number nor a Tensor.
TypeError: If `grad` is not a Tensor.

@@ -5456,8 +5456,8 @@ class ApplyAdadelta(PrimitiveWithInfer):
- **accum_update** (Tensor) - The same shape and data type as `accum_update`.

Raises:
TypeError: If dtype of `var`, `accum`, `accum_update`, `lr`, `rho`, `epsilon` or `grad` is neither float16 nor
float32.
TypeError: If `accum_update`, `lr`, `rho` or `epsilon` is neither a Number nor a Tensor.

Supported Platforms:
@@ -6109,8 +6109,7 @@ class SparseApplyProximalAdagrad(PrimitiveWithCheck):

Raises:
TypeError: If `use_locking` is not a bool.
TypeError: If dtype of `var`, `accum`, `lr`, `l1`, `l2`, `scalar` or `grad` is neither float16 nor float32.
TypeError: If dtype of `indices` is neither int32 nor int64.

Supported Platforms:
@@ -6940,7 +6939,7 @@ class SparseApplyFtrlV2(PrimitiveWithInfer):
- **linear** (Tensor) - Tensor, has the same shape and data type as `linear`.

Raises:
- TypeError: If `lr`, `l1`, `l2`, `lr_power` or`use_locking` is not a float.
+ TypeError: If `lr`, `l1`, `l2`, `lr_power` or `use_locking` is not a float.
TypeError: If `use_locking` is not a bool.
TypeError: If dtype of `var`, `accum`, `linear` or `grad` is neither float16 nor float32.
TypeError: If dtype of `indices` is not int32.
@@ -7215,8 +7214,8 @@ class CTCLoss(PrimitiveWithInfer):
- **gradient** (Tensor) - The gradient of `loss`, has the same type and shape with `inputs`.

Raises:
TypeError: If `preprocess_collapse_repeated`, `ctc_merge_repeated` or `ignore_longer_outputs_than_inputs` is not
a bool.
TypeError: If `inputs`, `labels_indices`, `labels_values` or `sequence_length` is not a Tensor.
TypeError: If dtype of `inputs` is not one of the following: float16, float32 or float64.
TypeError: If dtype of `labels_indices` is not int64.
@@ -7708,8 +7707,8 @@ class DynamicGRUV2(PrimitiveWithInfer):
TypeError: If `cell_depth` or `num_proj` is not an int.
TypeError: If `keep_prob` or `cell_clip` is not a float.
TypeError: If `time_major`, `reset_after` or `is_training` is not a bool.
TypeError: If `x`, `weight_input`, `weight_hidden`, `bias_input`, `bias_hidden`, `seq_length` or `init_h` is not
a Tensor.
TypeError: If dtype of `x`, `weight_input` or `weight_hidden` is not float16.
TypeError: If dtype of `init_h` is neither float16 nor float32.



mindspore/ops/operations/other_ops.py (+7, -0)

@@ -38,6 +38,10 @@ class Assign(PrimitiveWithCheck):
Outputs:
Tensor, has the same type as original `variable`.

Raises:
TypeError: If `variable` is not a Parameter.
TypeError: If `value` is not a Tensor.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -630,6 +634,9 @@ class PopulationCount(PrimitiveWithInfer):
Outputs:
Tensor, with the same shape as the input.

Raises:
TypeError: If `input` is not a Tensor.

Supported Platforms:
``Ascend``
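Population count is the number of 1 bits in each element; a one-line Python reference (illustrative, not the operator itself):

```python
def population_count(values):
    # Per-element number of set bits, the quantity PopulationCount computes.
    return [bin(v).count("1") for v in values]

counts = population_count([0, 1, 3, 255])
```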



mindspore/ops/operations/random_ops.py (+2, -0)

@@ -529,9 +529,11 @@ class Multinomial(PrimitiveWithInfer):
Note:
The rows of input do not need to sum to one (in which case we use the values as weights),
but must be non-negative, finite and have a non-zero sum.

Args:
seed (int): Random seed, must be non-negative. Default: 0.
seed2 (int): Random seed2, must be non-negative. Default: 0.

Inputs:
- **input** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must have 1 or 2
dimensions.
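Given a row of cumulative probabilities like the input described above, sampling reduces to a binary search over the cumsum; a NumPy sketch with a hypothetical helper name:

```python
import numpy as np

def sample_from_cumsum(cumsum_row, num_sample, seed=0):
    # `cumsum_row` is a running sum of non-negative weights; draw uniforms
    # up to the total and locate each with a right-bisecting binary search.
    rng = np.random.default_rng(seed)
    u = rng.uniform(0.0, cumsum_row[-1], size=num_sample)
    return np.searchsorted(cumsum_row, u, side="right")

idx = sample_from_cumsum(np.array([0.1, 0.4, 1.0]), num_sample=5)
```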

