You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

random_ops.py 2.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Operators for random."""
  16. from ..._checkparam import Validator as validator
  17. from ..._checkparam import Rel
  18. from ...common import dtype as mstype
  19. from ..primitive import PrimitiveWithInfer, prim_attr_register
  class RandomChoiceWithMask(PrimitiveWithInfer):
      """
      Generates a random sample as an index tensor, together with a mask tensor, from a given bool tensor.

      The input must be a tensor of rank >= 1. If its rank is >= 2, the first dimension specifies the
      number of samples. The index tensor and the mask tensor have fixed shapes that depend only on
      `count` and the rank of the input. The index tensor denotes the indices of the nonzero samples,
      while the mask tensor denotes which rows of the index tensor are valid.

      Args:
          count (int): Number of items expected to get; must be greater than 0. Default: 256.
          seed (int): Random seed. Default: 0.
          seed2 (int): Random seed2. Default: 0.

      Inputs:
          - **input_x** (Tensor[bool]) - The input tensor, of rank >= 1.

      Outputs:
          Two tensors, the first one is the index tensor and the other one is the mask tensor.

          - **index** (Tensor[int32]) - 2-D tensor of shape (count, rank of `input_x`).
          - **mask** (Tensor[bool]) - 1-D tensor of shape (count,).

      Examples:
          >>> rnd_choice_mask = P.RandomChoiceWithMask()
          >>> input_x = Tensor(np.ones(shape=[240000, 4]).astype(np.bool_))
          >>> output_y, output_mask = rnd_choice_mask(input_x)
      """

      @prim_attr_register
      def __init__(self, count=256, seed=0, seed2=0):
          """Init RandomChoiceWithMask: validate that count is a positive int and seeds are ints."""
          validator.check_value_type("count", count, [int], self.name)
          validator.check_integer("count", count, 0, Rel.GT, self.name)
          validator.check_value_type('seed', seed, [int], self.name)
          validator.check_value_type('seed2', seed2, [int], self.name)

      def infer_shape(self, x_shape):
          """
          Infer the output shapes from the input shape.

          Args:
              x_shape (list[int]): Shape of `input_x`; its length (the rank) must be >= 1.

          Returns:
              tuple: (index shape [count, rank], mask shape [count]).
          """
          validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name)
          # self.count is not assigned in __init__; presumably it is stored from the
          # constructor argument by @prim_attr_register -- confirm against the decorator.
          return ([self.count, len(x_shape)], [self.count])

      def infer_dtype(self, x_dtype):
          """
          Infer the output dtypes after checking that `input_x` is a bool tensor.

          Returns:
              tuple: (int32 for the index tensor, bool for the mask tensor).
          """
          validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name)
          return (mstype.int32, mstype.bool_)