| @@ -15,6 +15,9 @@ | |||||
| """inner_ops""" | """inner_ops""" | ||||
| import numbers | |||||
| from ..._checkparam import Validator as validator | |||||
| from ..._checkparam import Rel | |||||
| from ...common.dtype import tensor, dtype_to_pytype | from ...common.dtype import tensor, dtype_to_pytype | ||||
| from ..primitive import prim_attr_register, PrimitiveWithInfer | from ..primitive import prim_attr_register, PrimitiveWithInfer | ||||
| @@ -40,8 +43,10 @@ class ScalarCast(PrimitiveWithInfer): | |||||
| pass | pass | ||||
| def __infer__(self, x, t): | def __infer__(self, x, t): | ||||
| validator.check_integer('x shape', len(x['shape']), 0, Rel.EQ, self.name) | |||||
| value, to = x['value'], t['value'] | value, to = x['value'], t['value'] | ||||
| if value is not None: | if value is not None: | ||||
| validator.check_value_type("value", value, [numbers.Number, bool], self.name) | |||||
| if isinstance(to, type(tensor)): | if isinstance(to, type(tensor)): | ||||
| to = to.element_type() | to = to.element_type() | ||||
| np_type = dtype_to_pytype(to) | np_type = dtype_to_pytype(to) | ||||