|
|
|
@@ -15,6 +15,9 @@ |
|
|
|
|
|
|
|
"""inner_ops""" |
|
|
|
|
|
|
|
import numbers |
|
|
|
from ..._checkparam import Validator as validator |
|
|
|
from ..._checkparam import Rel |
|
|
|
from ...common.dtype import tensor, dtype_to_pytype |
|
|
|
from ..primitive import prim_attr_register, PrimitiveWithInfer |
|
|
|
|
|
|
|
@@ -40,8 +43,10 @@ class ScalarCast(PrimitiveWithInfer): |
|
|
|
pass |
|
|
|
|
|
|
|
def __infer__(self, x, t): |
|
|
|
validator.check_integer('x shape', len(x['shape']), 0, Rel.EQ, self.name) |
|
|
|
value, to = x['value'], t['value'] |
|
|
|
if value is not None: |
|
|
|
validator.check_value_type("value", value, [numbers.Number, bool], self.name) |
|
|
|
if isinstance(to, type(tensor)): |
|
|
|
to = to.element_type() |
|
|
|
np_type = dtype_to_pytype(to) |
|
|
|
|