|
|
@@ -108,8 +108,13 @@ def repeat_elements(x, rep, axis=0): |
|
|
@constexpr |
|
|
@constexpr |
|
|
def _check_sequence_mask_input_len(input_shape): |
|
|
def _check_sequence_mask_input_len(input_shape): |
|
|
if not input_shape: |
|
|
if not input_shape: |
|
|
raise ValueError(f"sequence_mask input lengths_shape should be > 0. " |
|
|
|
|
|
f"current lengths_shape is {input_shape}.") |
|
|
|
|
|
|
|
|
raise ValueError(f"Sequence_mask lengths_shape should be > 0. " |
|
|
|
|
|
f"Current lengths_shape is {input_shape}.") |
|
|
|
|
|
# broadcast only supports 7d shape |
|
|
|
|
|
shape_size = len(input_shape) |
|
|
|
|
|
if shape_size >= 7: |
|
|
|
|
|
raise ValueError(f"Sequence_mask lengths_shape's size only support a value less than 7. " |
|
|
|
|
|
f"Current lengths_shape is {shape_size}d.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sequence_mask(lengths, maxlen=None): |
|
|
def sequence_mask(lengths, maxlen=None): |
|
|
|