From 7eced5d59ed697fa65392509f81aec46a535b15f Mon Sep 17 00:00:00 2001 From: TFbunny Date: Thu, 5 Nov 2020 13:52:10 -0500 Subject: [PATCH] add dynamic shape support to Zeroslike --- mindspore/core/abstract/prim_arrays.cc | 8 ++++++++ mindspore/core/abstract/prim_nn.cc | 7 ------- mindspore/ops/operations/array_ops.py | 8 ++------ 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc index 7d7a011b45..ccfa608d4e 100644 --- a/mindspore/core/abstract/prim_arrays.cc +++ b/mindspore/core/abstract/prim_arrays.cc @@ -377,5 +377,13 @@ AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const Primitive return tensor->ToAbstract(); } + +AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a tensor. + CheckArgsSize(primitive->name(), args_spec_list, 1); + return args_spec_list[0]->Broaden(); +} + } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc index 62ce8c69d4..0b2dc22710 100644 --- a/mindspore/core/abstract/prim_nn.cc +++ b/mindspore/core/abstract/prim_nn.cc @@ -296,13 +296,6 @@ AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &pri return args_spec_list[0]->Broaden(); } -AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tensor. - CheckArgsSize(primitive->name(), args_spec_list, 1); - return args_spec_list[0]->Broaden(); -} - AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // Inputs: a tensor. diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index a6d670d1a8..7bba2b20c2 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1034,7 +1034,7 @@ class OnesLike(PrimitiveWithInfer): return x_dtype -class ZerosLike(PrimitiveWithInfer): +class ZerosLike(PrimitiveWithCheck): """ Creates a new tensor. All elements value are 0. @@ -1059,12 +1059,8 @@ class ZerosLike(PrimitiveWithInfer): """Initialize ZerosLike""" self.init_prim_io_names(inputs=['x'], outputs=['y']) - def infer_shape(self, x_shape): - return x_shape - - def infer_dtype(self, x_dtype): + def check_dtype(self, x_dtype): validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name) - return x_dtype class TupleToArray(PrimitiveWithInfer):