|
|
|
@@ -1091,7 +1091,7 @@ class Fill(PrimitiveWithInfer): |
|
|
|
for i, item in enumerate(dims['value']): |
|
|
|
validator.check_positive_int(item, f'dims[{i}]', self.name) |
|
|
|
valid_dtypes = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, |
|
|
|
mstype.uint8, mstype.uint32, mstype.uint64, |
|
|
|
mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64, |
|
|
|
mstype.float16, mstype.float32, mstype.float64] |
|
|
|
validator.check_types_same_and_valid({"value": dtype['value']}, valid_dtypes, self.name) |
|
|
|
x_nptype = mstype.dtype_to_nptype(dtype['value']) |
|
|
|
@@ -1144,7 +1144,7 @@ class Ones(PrimitiveWithInfer): |
|
|
|
for i, item in enumerate(shape): |
|
|
|
validator.check_non_negative_int(item, shape[i], self.name) |
|
|
|
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, |
|
|
|
mstype.uint8, mstype.uint32, mstype.uint64, |
|
|
|
mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64, |
|
|
|
mstype.float16, mstype.float32, mstype.float64] |
|
|
|
validator.check_types_same_and_valid({"value": dtype['value']}, valid_types, self.name) |
|
|
|
x_nptype = mstype.dtype_to_nptype(dtype['value']) |
|
|
|
@@ -1198,7 +1198,7 @@ class Zeros(PrimitiveWithInfer): |
|
|
|
for i, item in enumerate(shape): |
|
|
|
validator.check_non_negative_int(item, shape[i], self.name) |
|
|
|
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, |
|
|
|
mstype.uint8, mstype.uint32, mstype.uint64, |
|
|
|
mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64, |
|
|
|
mstype.float16, mstype.float32, mstype.float64] |
|
|
|
validator.check_types_same_and_valid({"value": dtype['value']}, valid_types, self.name) |
|
|
|
x_nptype = mstype.dtype_to_nptype(dtype['value']) |
|
|
|
@@ -1253,7 +1253,7 @@ class SequenceMask(PrimitiveWithInfer): |
|
|
|
def __infer__(self, lengths, dtype, max_length=None): |
|
|
|
validator.check_value_type("shape", lengths['value'], [tuple, list], self.name) |
|
|
|
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, |
|
|
|
mstype.uint8, mstype.uint32, mstype.uint64, |
|
|
|
mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64, |
|
|
|
mstype.float16, mstype.float32, mstype.float64] |
|
|
|
validator.check_subclass("dtype", dtype['value'], valid_types, self.name) |
|
|
|
nptype = mstype.dtype_to_nptype(dtype['value']) |
|
|
|
|