Browse Source

!12604 add shape check for sequence_mask

From: @TFbunny
Reviewed-by: @tom__chen,@robingrosman
Signed-off-by: @robingrosman
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
c8bca27340
1 changed files with 7 additions and 2 deletions
  1. +7
    -2
      mindspore/ops/composite/array_ops.py

+ 7
- 2
mindspore/ops/composite/array_ops.py View File

@@ -108,8 +108,13 @@ def repeat_elements(x, rep, axis=0):
@constexpr @constexpr
def _check_sequence_mask_input_len(input_shape): def _check_sequence_mask_input_len(input_shape):
if not input_shape: if not input_shape:
raise ValueError(f"sequence_mask input lengths_shape should be > 0. "
f"current lengths_shape is {input_shape}.")
raise ValueError(f"Sequence_mask lengths_shape should be > 0. "
f"Current lengths_shape is {input_shape}.")
# broadcast only supports 7d shape
shape_size = len(input_shape)
if shape_size >= 7:
raise ValueError(f"Sequence_mask lengths_shape's size only support a value less than 7. "
f"Current lengths_shape is {shape_size}d.")




def sequence_mask(lengths, maxlen=None): def sequence_mask(lengths, maxlen=None):


Loading…
Cancel
Save