|
|
|
@@ -1146,7 +1146,7 @@ class Ones(PrimitiveWithInfer): |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self): |
|
|
|
"""Initialize Fill""" |
|
|
|
"""Initialize Ones""" |
|
|
|
|
|
|
|
def __infer__(self, dims, dtype): |
|
|
|
if isinstance(dims['value'], int): |
|
|
|
@@ -1197,7 +1197,7 @@ class Zeros(PrimitiveWithInfer): |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self): |
|
|
|
"""Initialize Fill""" |
|
|
|
"""Initialize Zeros""" |
|
|
|
|
|
|
|
def __infer__(self, dims, dtype): |
|
|
|
if isinstance(dims['value'], int): |
|
|
|
@@ -1221,6 +1221,65 @@ class Zeros(PrimitiveWithInfer): |
|
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
class SequenceMask(PrimitiveWithInfer): |
|
|
|
r""" |
|
|
|
Generates sequence mask according to input lengths. |
|
|
|
|
|
|
|
Creates a mask tensor which retains the first N elements in tensor by setting the values |
|
|
|
to be True or one. The rest values in mask are set to False or zero. |
|
|
|
|
|
|
|
Args: |
|
|
|
max_length (int): Nonnegative integer, size of the last dimension in mask. Default: None. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **lengths** (Union[tuple[int], list[int]]) - Defines the first N elements that are retained. |
|
|
|
Only constant value is allowed. |
|
|
|
- **dtype** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tensor. |
|
|
|
If max_length is set, the shape of the output is (lengths.shape, max_length). |
|
|
|
If max_length is not set and the biggest value in lengths is x. Then, the shape of |
|
|
|
the output is (lengths.shape, x). |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> from mindspore.ops import operations as P |
|
|
|
>>> sequence_mask = P.SequenceMask() |
|
|
|
>>> mask = sequence_mask([2, 2, 4], mindspore.int32) |
|
|
|
>>> print(mask) |
|
|
|
[[1, 1, 0, 0], |
|
|
|
[1, 1, 0, 0], |
|
|
|
[1, 1, 1, 1]] |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self): |
|
|
|
"""Initialize SequenceMask""" |
|
|
|
|
|
|
|
def __infer__(self, lengths, dtype, max_length=None): |
|
|
|
validator.check_value_type("shape", lengths['value'], [tuple, list], self.name) |
|
|
|
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, |
|
|
|
mstype.uint8, mstype.uint32, mstype.uint64, |
|
|
|
mstype.float16, mstype.float32, mstype.float64] |
|
|
|
validator.check_subclass("dtype", dtype['value'], valid_types, self.name) |
|
|
|
nptype = mstype.dtype_to_nptype(dtype['value']) |
|
|
|
if max_length is None: |
|
|
|
max_length = np.max(lengths['value']) |
|
|
|
else: |
|
|
|
validator.check_non_negative_int(max_length['value']) |
|
|
|
max_length = max_length['value'] |
|
|
|
row_vector = np.arange(0, max_length) |
|
|
|
col_matrix = np.expand_dims(lengths['value'], -1) |
|
|
|
result = (row_vector < col_matrix).astype(nptype) |
|
|
|
out = { |
|
|
|
'value': Tensor(result), |
|
|
|
'shape': result.shape, |
|
|
|
'dtype': dtype['value'] |
|
|
|
} |
|
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
class OnesLike(PrimitiveWithInfer): |
|
|
|
""" |
|
|
|
Creates a new tensor. The values of all elements are 1. |
|
|
|
|