|
|
|
@@ -68,11 +68,11 @@ class BatchNormGrad(PrimitiveWithInfer): |
|
|
|
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) |
|
|
|
self.add_prim_attr('data_format', "NCHW") |
|
|
|
|
|
|
|
def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape, reserve_3_shape): |
|
|
|
def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape): |
|
|
|
validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape) |
|
|
|
return (x_shape, scale_shape, scale_shape, reserve_1_shape, reserve_2_shape) |
|
|
|
|
|
|
|
def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type, reserve_3_type): |
|
|
|
def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type): |
|
|
|
return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type) |
|
|
|
|
|
|
|
|
|
|
|
|