|
|
|
@@ -1034,7 +1034,7 @@ class OnesLike(PrimitiveWithInfer): |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
|
|
|
|
class ZerosLike(PrimitiveWithInfer): |
|
|
|
class ZerosLike(PrimitiveWithCheck): |
|
|
|
""" |
|
|
|
Creates a new tensor. All elements value are 0. |
|
|
|
|
|
|
|
@@ -1059,12 +1059,8 @@ class ZerosLike(PrimitiveWithInfer): |
|
|
|
"""Initialize ZerosLike""" |
|
|
|
self.init_prim_io_names(inputs=['x'], outputs=['y']) |
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
return x_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
def check_dtype(self, x_dtype): |
|
|
|
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name) |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
|
|
|
|
class TupleToArray(PrimitiveWithInfer): |
|
|
|
|