| @@ -377,5 +377,13 @@ AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const Primitive | |||||
| return tensor->ToAbstract(); | return tensor->ToAbstract(); | ||||
| } | } | ||||
| AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // Inputs: a tensor. | |||||
| CheckArgsSize(primitive->name(), args_spec_list, 1); | |||||
| return args_spec_list[0]->Broaden(); | |||||
| } | |||||
| } // namespace abstract | } // namespace abstract | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -296,13 +296,6 @@ AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &pri | |||||
| return args_spec_list[0]->Broaden(); | return args_spec_list[0]->Broaden(); | ||||
| } | } | ||||
| AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // Inputs: a tensor. | |||||
| CheckArgsSize(primitive->name(), args_spec_list, 1); | |||||
| return args_spec_list[0]->Broaden(); | |||||
| } | |||||
| AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list) { | const AbstractBasePtrList &args_spec_list) { | ||||
| // Inputs: a tensor. | // Inputs: a tensor. | ||||
| @@ -1034,7 +1034,7 @@ class OnesLike(PrimitiveWithInfer): | |||||
| return x_dtype | return x_dtype | ||||
| class ZerosLike(PrimitiveWithInfer): | |||||
| class ZerosLike(PrimitiveWithCheck): | |||||
| """ | """ | ||||
| Creates a new tensor. All elements value are 0. | Creates a new tensor. All elements value are 0. | ||||
| @@ -1059,12 +1059,8 @@ class ZerosLike(PrimitiveWithInfer): | |||||
| """Initialize ZerosLike""" | """Initialize ZerosLike""" | ||||
| self.init_prim_io_names(inputs=['x'], outputs=['y']) | self.init_prim_io_names(inputs=['x'], outputs=['y']) | ||||
| def infer_shape(self, x_shape): | |||||
| return x_shape | |||||
| def infer_dtype(self, x_dtype): | |||||
| def check_dtype(self, x_dtype): | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name) | validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name) | ||||
| return x_dtype | |||||
| class TupleToArray(PrimitiveWithInfer): | class TupleToArray(PrimitiveWithInfer): | ||||