| @@ -25,20 +25,23 @@ class RandomChoiceWithMask(PrimitiveWithInfer): | |||||
| """ | """ | ||||
| Generates a random samply as index tensor with a mask tensor from a given tensor. | Generates a random samply as index tensor with a mask tensor from a given tensor. | ||||
| The input must be a tensor of rank >= 2, the first dimension specify the number of sample. | |||||
| The index tensor and the mask tensor have the same and fixed shape. The index tensor denotes the index | |||||
| of the nonzero sample, while the mask tensor denotes which element in the index tensor are valid. | |||||
| The input must be a tensor of rank >= 1. If its rank >= 2, the first dimension specify the number of sample. | |||||
| The index tensor and the mask tensor have the fixed shapes. The index tensor denotes the index of the nonzero | |||||
| sample, while the mask tensor denotes which elements in the index tensor are valid. | |||||
| Args: | Args: | ||||
| count (int): Number of items expected to get. Default: 256. | |||||
| seed (int): Random seed. | |||||
| seed2 (int): Random seed2. | |||||
| count (int): Number of items expected to get and the number should be greater than 0. Default: 256. | |||||
| seed (int): Random seed. Default: 0. | |||||
| seed2 (int): Random seed2. Default: 0. | |||||
| Inputs: | Inputs: | ||||
| - **input_x** (Tensor) - The input tensor. | |||||
| - **input_x** (Tensor[bool]) - The input tensor. | |||||
| Outputs: | Outputs: | ||||
| Tuple, two tensors, the first one is the index tensor and the other one is the mask tensor. | |||||
| Two tensors, the first one is the index tensor and the other one is the mask tensor. | |||||
| - **index** (Tensor) - The output has shape between 2-D and 5-D. | |||||
| - **mask** (Tensor) - The output has shape 1-D. | |||||
| Examples: | Examples: | ||||
| >>> rnd_choice_mask = RandomChoiceWithMask() | >>> rnd_choice_mask = RandomChoiceWithMask() | ||||