Browse Source

gpu floatstatus add type check

tags/v0.3.0-alpha
VectorSL 5 years ago
parent
commit
2b51199054
1 changed files with 1 additions and 0 deletions
  1. +1
    -0
      mindspore/ops/operations/math_ops.py

+ 1
- 0
mindspore/ops/operations/math_ops.py View File

@@ -1712,6 +1712,7 @@ class FloatStatus(PrimitiveWithInfer):
return [1]

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32, mstype.float16], self.name)
return x_dtype

class NPUAllocFloatStatus(PrimitiveWithInfer):


Loading…
Cancel
Save