Browse Source

change cumsum op python to allow for float64 input

pull/13648/head
Peilin Wang 4 years ago
parent
commit
c0ebc36c78
1 changed files with 1 additions and 1 deletions
  1. +1
    -1
      mindspore/ops/operations/math_ops.py

+ 1
- 1
mindspore/ops/operations/math_ops.py View File

@@ -972,7 +972,7 @@ class CumSum(PrimitiveWithInfer):
if axis['value'] is None:
raise ValueError(f"For {self.name}, axis must be const.")
validator.check_value_type('axis', axis['value'], [int], cls_name)
valid_dtypes = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32]
valid_dtypes = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32, mstype.float64]
validator.check_tensor_dtype_valid('x', x['dtype'], valid_dtypes, cls_name)
return {'shape': x_shp,
'dtype': x['dtype'],


Loading…
Cancel
Save