Browse Source

fix parameter cast issue

tags/v0.5.0-beta
kingfo 5 years ago
parent
commit
396c18b92b
1 changed files with 6 additions and 1 deletions
  1. +6
    -1
      mindspore/ops/operations/array_ops.py

+ 6
- 1
mindspore/ops/operations/array_ops.py View File

@@ -28,6 +28,7 @@ from ..._checkparam import Validator as validator
from ..._checkparam import Rel from ..._checkparam import Rel
from ...common import dtype as mstype from ...common import dtype as mstype
from ...common.tensor import Tensor from ...common.tensor import Tensor
from ...common.parameter import Parameter
from ..operations.math_ops import _infer_shape_reduce from ..operations.math_ops import _infer_shape_reduce
from .._utils import get_concat_offset from .._utils import get_concat_offset
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_op from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_op
@@ -186,11 +187,15 @@ class Cast(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output']) self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])


def check_elim(self, x, dtype): def check_elim(self, x, dtype):
if isinstance(x, (Tensor, numbers.Number)):
if isinstance(x, (Tensor, numbers.Number, Parameter)):
if isinstance(x, Tensor) and x.dtype == dtype: if isinstance(x, Tensor) and x.dtype == dtype:
return (True, x) return (True, x)
if isinstance(x, numbers.Number): if isinstance(x, numbers.Number):
return (True, Tensor(x, dtype=dtype)) return (True, Tensor(x, dtype=dtype))
if isinstance(x, Parameter):
data = x.default_input
if data.dtype == dtype:
return (True, x)
return (False, None) return (False, None)
raise ValueError(f"Expecting (Tensor, dtype), got : ({x}, {dtype})") raise ValueError(f"Expecting (Tensor, dtype), got : ({x}, {dtype})")




Loading…
Cancel
Save