From 2f89b75b2d5bc0485910ffdd19ef7bc613172b88 Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Fri, 24 Apr 2020 15:02:24 +0800 Subject: [PATCH] remove batchnorm grad 6 output to 5 output --- mindspore/ops/operations/_grad_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 07857ca27b..782784ca00 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -68,11 +68,11 @@ class BatchNormGrad(PrimitiveWithInfer): self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) self.add_prim_attr('data_format', "NCHW") - def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape, reserve_3_shape): + def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape): validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape) return (x_shape, scale_shape, scale_shape, reserve_1_shape, reserve_2_shape) - def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type, reserve_3_type): + def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type): return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type)