Browse Source

fixed ScalarCast

tags/v0.5.0-beta
jiangjinsheng 5 years ago
parent
commit
0ac47f2f71
1 changed files with 5 additions and 0 deletions
  1. +5
    -0
      mindspore/ops/operations/inner_ops.py

+ 5
- 0
mindspore/ops/operations/inner_ops.py View File

@@ -15,6 +15,9 @@

"""inner_ops"""

import numbers
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common.dtype import tensor, dtype_to_pytype
from ..primitive import prim_attr_register, PrimitiveWithInfer

@@ -40,8 +43,10 @@ class ScalarCast(PrimitiveWithInfer):
pass

def __infer__(self, x, t):
validator.check_integer('x shape', len(x['shape']), 0, Rel.EQ, self.name)
value, to = x['value'], t['value']
if value is not None:
validator.check_value_type("value", value, [numbers.Number, bool], self.name)
if isinstance(to, type(tensor)):
to = to.element_type()
np_type = dtype_to_pytype(to)


Loading…
Cancel
Save