Browse Source

Reduce BatchNormGrad from 6 outputs to 5 outputs

tags/v0.3.0-alpha
zhaozhenlong 6 years ago
parent
commit
2f89b75b2d
1 changed files with 2 additions and 2 deletions
  1. +2
    -2
      mindspore/ops/operations/_grad_ops.py

+ 2
- 2
mindspore/ops/operations/_grad_ops.py View File

@@ -68,11 +68,11 @@ class BatchNormGrad(PrimitiveWithInfer):
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
self.add_prim_attr('data_format', "NCHW")

def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape):
    """Infer output shapes for BatchNormGrad.

    Args:
        y_backprop_shape: shape of the incoming gradient; must equal ``x_shape``.
        x_shape: shape of the original input tensor.
        scale_shape: shape of the scale (gamma) parameter.
        reserve_1_shape: shape of the first reserved intermediate from the forward op.
        reserve_2_shape: shape of the second reserved intermediate from the forward op.

    Returns:
        Tuple of five shapes: (dx, dscale, dbias, reserve_1, reserve_2).
        dscale and dbias both share ``scale_shape``.
    """
    # The backprop gradient must match the forward input element-for-element.
    validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape)
    return (x_shape, scale_shape, scale_shape, reserve_1_shape, reserve_2_shape)

def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type):
    """Infer output dtypes for BatchNormGrad.

    Args:
        y_backprop_type: dtype of the incoming gradient (unused in the result).
        x_type: dtype of the original input tensor.
        scale_type: dtype of the scale (gamma) parameter.
        reserve_1_type: dtype of the first reserved intermediate.
        reserve_2_type: dtype of the second reserved intermediate.

    Returns:
        Tuple of five dtypes: (dx, dscale, dbias, reserve_1, reserve_2).
        dscale and dbias both take ``scale_type``.
    """
    return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type)




Loading…
Cancel
Save