Browse Source

!10463 SequenceMask doc fix and input check

From: @peilin-wang
Reviewed-by: @liangchenghui,@wuxuejian
Signed-off-by: @liangchenghui
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
45da2d7aba
2 changed files with 5 additions and 1 deletions
  1. +4
    -0
      mindspore/core/abstract/prim_arrays.cc
  2. +1
    -1
      mindspore/ops/composite/array_ops.py

+ 4
- 0
mindspore/core/abstract/prim_arrays.cc View File

@@ -886,6 +886,10 @@ AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const Primitive
maxlen_value = *static_cast<int64_t *>(maxlen_tensor->data_c());
}

if (maxlen_value <= 0) {
MS_LOG(EXCEPTION) << "maxlen must be positive, but got: " << maxlen_value;
}

ShapeVector lengths_shape = lengths->shape()->shape();
ShapeVector lengths_shape_min = lengths->shape()->min_shape();
if (lengths_shape_min.empty()) {


+ 1
- 1
mindspore/ops/composite/array_ops.py View File

@@ -114,7 +114,7 @@ def sequence_mask(lengths, maxlen):

Args:
length (Tensor): Tensor to calculate the mask for. All values in this tensor must be
less than `maxlen`. Must be type int32 or int64.
less than or equal to `maxlen`. Must be type int32 or int64.

maxlen (int): size of the last dimension of returned tensor. Must be positive and same
type as elements in `lengths`.


Loading…
Cancel
Save