Browse Source

fix cast check elim

tags/v0.5.0-beta
BowenK 5 years ago
parent
commit
35a57e076d
1 changed files with 5 additions and 3 deletions
  1. +5
    -3
      mindspore/ops/operations/array_ops.py

+ 5
- 3
mindspore/ops/operations/array_ops.py View File

@@ -186,11 +186,13 @@ class Cast(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])

def check_elim(self, x, dtype):
if isinstance(x, Tensor):
if x.dtype == dtype:
if isinstance(x, (Tensor, numbers.Number)):
if isinstance(x, Tensor) and x.dtype == dtype:
return (True, x)
if isinstance(x, numbers.Number):
return (True, Tensor(x, dtype=dtype))
return (False, None)
raise ValueError("Expecting (Tensor, dtype), got : {}".format(inputs))
raise ValueError(f"Expecting (Tensor, dtype), got : ({x}, {dtype})")

def __infer__(self, x, t):
src_type = x['dtype']


Loading…
Cancel
Save