|
|
|
@@ -972,7 +972,7 @@ class CumSum(PrimitiveWithInfer): |
|
|
|
if axis['value'] is None: |
|
|
|
raise ValueError(f"For {self.name}, axis must be const.") |
|
|
|
validator.check_value_type('axis', axis['value'], [int], cls_name) |
|
|
|
valid_dtypes = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32] |
|
|
|
valid_dtypes = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32, mstype.float64] |
|
|
|
validator.check_tensor_dtype_valid('x', x['dtype'], valid_dtypes, cls_name) |
|
|
|
return {'shape': x_shp, |
|
|
|
'dtype': x['dtype'], |
|
|
|
|