| @@ -1745,11 +1745,14 @@ class NMSWithMask(PrimitiveWithInfer): | |||||
| self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask']) | self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask']) | ||||
| def infer_shape(self, bboxes_shape): | def infer_shape(self, bboxes_shape): | ||||
| validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ) | |||||
| validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT) | |||||
| validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ) | |||||
| num = bboxes_shape[0] | num = bboxes_shape[0] | ||||
| validator.check_integer("bboxes_shape[0]", num, 0, Rel.GT) | |||||
| return (bboxes_shape, (num,), (num,)) | return (bboxes_shape, (num,), (num,)) | ||||
| def infer_dtype(self, bboxes_dtype): | def infer_dtype(self, bboxes_dtype): | ||||
| validator.check_subclass("bboxes_dtype", bboxes_dtype, mstype.tensor) | |||||
| validator.check_typename("bboxes_dtype", bboxes_dtype, [mstype.float16, mstype.float32]) | validator.check_typename("bboxes_dtype", bboxes_dtype, [mstype.float16, mstype.float32]) | ||||
| return (bboxes_dtype, mstype.int32, mstype.bool_) | return (bboxes_dtype, mstype.int32, mstype.bool_) | ||||