from __future__ import absolute_import from .Node import Op import numpy as np from .. import ndarray from .._base import DNNL_LIB from ..cpu_links import batch_norm as cpu_batch_norm from ..cpu_links import batch_norm_inference as cpu_batch_norm_inference from ..cpu_links import batch_norm_gradient as cpu_batch_norm_gradient from ..gpu_links import CuDNN_Batch_Normalization from ..gpu_links import CuDNN_Batch_Normalization_gradient from ..gpu_links import CuDNN_Batch_Normalization_inference import numpy as np class Batch_NormalizationOp(Op): def __init__(self, node_in, bn_scale, bn_bias, momentum=0.99, eps=0.01, ctx=None): super().__init__(Batch_NormalizationOp, [node_in, bn_scale, bn_bias], ctx) self.momentum = momentum self.eps = eps self.save_mean = None self.save_var = None self.running_mean = None self.running_var = None def compute(self, input_vals, output_val, stream_handle=None, inference=False): if inference: if self.on_cpu: if DNNL_LIB['DnnlBatchNorm_Inference']: cpu_batch_norm_inference(input_vals[0], input_vals[1], input_vals[2], output_val, self.save_mean, self.save_var, self.momentum, self.eps) else: output_val[:] = batchnorm_inference(input_vals[0].asnumpy(), input_vals[1].asnumpy(), input_vals[2].asnumpy( ), self.save_mean, self.save_var, self.eps) else: CuDNN_Batch_Normalization_inference( input_vals[0], input_vals[1], input_vals[2], output_val, self.save_mean, self.save_var, self.eps, stream_handle) else: if self.on_cpu: if DNNL_LIB['DnnlBatchNorm']: if self.save_mean is None: dev_id = input_vals[0].handle.contents.ctx.device_id C = input_vals[0].shape[1] self.save_mean = ndarray.array( np.zeros([C], dtype=np.float32), ctx=ndarray.cpu(dev_id)) self.save_var = ndarray.array( np.zeros([C], dtype=np.float32), ctx=ndarray.cpu(dev_id)) cpu_batch_norm(input_vals[0], input_vals[1], input_vals[2], output_val, self.save_mean, self.save_var, self.momentum, self.eps) else: output_val[:], self.save_mean, self.save_var = batchnorm_forward(input_vals[0].asnumpy(), input_vals[1].asnumpy( ), input_vals[2].asnumpy( ), self.save_mean, self.save_var, self.momentum, self.eps) else: if self.save_mean == None: dev_id = input_vals[0].handle.contents.ctx.device_id C = input_vals[0].shape[1] self.save_mean = ndarray.array( np.zeros([1, C, 1, 1]), ctx=ndarray.gpu(dev_id)) self.save_var = ndarray.array( np.zeros([1, C, 1, 1]), ctx=ndarray.gpu(dev_id)) self.running_mean = ndarray.array( np.zeros([1, C, 1, 1]), ctx=ndarray.gpu(dev_id)) self.running_var = ndarray.array( np.zeros([1, C, 1, 1]), ctx=ndarray.gpu(dev_id)) CuDNN_Batch_Normalization( input_vals[0], input_vals[1], input_vals[2], output_val, self.save_mean, self.save_var, self.running_mean, self.running_var, self.momentum, self.eps, stream_handle) def gradient(self, output_grad): bn_gradient_node = batch_normalization_gradient_op( output_grad, self.inputs[0], self.inputs[1], self, self.eps, ctx=self.raw_ctx) data_gradient = batch_normalization_gradient_of_data_op( bn_gradient_node, self.inputs[0], ctx=self.raw_ctx) scale_gradient = batch_normalization_gradient_of_scale_op( bn_gradient_node, self.inputs[1], ctx=self.raw_ctx) bias_gradient = batch_normalization_gradient_of_bias_op( bn_gradient_node, self.inputs[2], ctx=self.raw_ctx) return [data_gradient, scale_gradient, bias_gradient] def infer_shape(self, input_shapes): return input_shapes[0] class Batch_Normalization_GradientOp(Op): def __init__(self, out_gradient, in_node, bn_scale, forward_node, eps, ctx=None): super().__init__(Batch_Normalization_GradientOp, [out_gradient, in_node, bn_scale], ctx) self.tmp_gradient_in_arr = None self.tmp_gradient_bn_bias = None self.tmp_gradient_bn_scale = None self.forward_node = forward_node self.eps = eps def update_mean_and_var(self, saved_mean, saved_var): self.saved_mean = saved_mean self.saved_var = saved_var def compute(self, input_vals, output_val, stream_handle=None): if self.on_cpu: if DNNL_LIB['DnnlBatchNorm_Gradient']: if self.tmp_gradient_bn_bias is None: shapebn = input_vals[2].shape self.tmp_gradient_bn_bias = np.zeros( shape=shapebn, dtype=np.float32) self.tmp_gradient_bn_scale = np.zeros( shape=shapebn, dtype=np.float32) self.tmp_gradient_in_arr = np.zeros( shape=input_vals[1].shape, dtype=np.float32) cpu_batch_norm_gradient(input_vals[0], input_vals[1], input_vals[2], bn_bias, self.tmp_gradient_in_arr, self.tmp_gradient_bn_scale, self.tmp_gradient_bn_bias, self.forward_node.running_mean, self.forward_node.running_var, self.eps) else: if self.tmp_gradient_bn_bias is None: typebn = input_vals[2].asnumpy().dtype shapebn = input_vals[2].asnumpy().shape self.tmp_gradient_bn_bias = np.zeros( shape=shapebn, dtype=typebn) self.tmp_gradient_bn_scale = np.zeros( shape=shapebn, dtype=typebn) self.tmp_gradient_in_arr, self.tmp_gradient_bn_scale, self.tmp_gradient_bn_bias = batchnorm_backward( input_vals[0].asnumpy(), input_vals[1].asnumpy( ), input_vals[2].asnumpy(), self.tmp_gradient_bn_scale, self.tmp_gradient_bn_bias, self.eps, self.forward_node.save_mean, self.forward_node.save_var) else: if self.tmp_gradient_bn_bias == None: shapebn = input_vals[2].shape self.tmp_gradient_bn_scale = ndarray.empty( shape=shapebn, ctx=input_vals[0].ctx) self.tmp_gradient_bn_bias = ndarray.empty( shape=shapebn, ctx=input_vals[0].ctx) self.tmp_gradient_in_arr = ndarray.empty( shape=input_vals[1].shape, ctx=input_vals[0].ctx) CuDNN_Batch_Normalization_gradient(input_vals[0], input_vals[1], input_vals[2], self.tmp_gradient_in_arr, self.tmp_gradient_bn_scale, self.tmp_gradient_bn_bias, self.forward_node.save_mean, self.forward_node.save_var, self.eps, stream_handle) def gradient(self, output_grad): raise NotImplementedError def infer_shape(self, input_shapes): return (1,) class Batch_Normalization_Gradient_of_DataOp(Op): def __init__(self, bn_gradient, in_arr, ctx=None): super().__init__(Batch_Normalization_Gradient_of_DataOp, [bn_gradient, in_arr], ctx) def compute(self, input_vals, output_val, stream_handle=None): if self.on_cpu: output_val[:] = self.inputs[0].tmp_gradient_in_arr else: self.inputs[0].tmp_gradient_in_arr.copyto(output_val) def gradient(self, output_grad): raise NotImplementedError def infer_shape(self, input_shapes): return input_shapes[1] class Batch_Normalization_Gradient_of_ScaleOp(Op): def __init__(self, bn_gradient, in_scale, ctx=None): super().__init__(Batch_Normalization_Gradient_of_ScaleOp, [bn_gradient, in_scale], ctx) def compute(self, input_vals, output_val, stream_handle=None): if self.on_cpu: output_val[:] = self.inputs[0].tmp_gradient_bn_scale else: self.inputs[0].tmp_gradient_bn_scale.copyto(output_val) def gradient(self, output_grad): raise NotImplementedError def infer_shape(self, input_shapes): return input_shapes[1] class Batch_Normalization_Gradient_of_BiasOp(Op): def __init__(self, bn_gradient, in_bias, ctx=None): super().__init__(Batch_Normalization_Gradient_of_BiasOp, [bn_gradient, in_bias], ctx) def compute(self, input_vals, output_val, stream_handle=None): if self.on_cpu: output_val[:] = self.inputs[0].tmp_gradient_bn_bias else: self.inputs[0].tmp_gradient_bn_bias.copyto(output_val) def gradient(self, output_grad): raise NotImplementedError def infer_shape(self, input_shapes): return input_shapes[1] def batch_normalization_op(node_in, bn_scale, bn_bias, momentum=0.99, eps=0.01, ctx=None): """Batch normalization layer node. Parameters: ---- node_in : Node Input data. bn_scale : float scaling parameter bn_bias : learnable bias parameter momentum : float Acting on the calculation of mean and variance, the mean and variance values in historical batch are retained. eps : float Epsilon value for numerical stability. Returns: ---- A new Node instance created by Op. """ return Batch_NormalizationOp(node_in, bn_scale, bn_bias, momentum, eps, ctx=ctx) def batch_normalization_gradient_op(out_gradient, in_node, bn_scale, forward_node, eps, ctx=None): """Gradient node of batch normalization. Parameters: ---- out_gradient : The gradient array. in_node : Node Input node of bn layer. bn_scale : Scaling parameter. Returns: ---- A new Node instance created by Op. """ return Batch_Normalization_GradientOp(out_gradient, in_node, bn_scale, forward_node, eps, ctx=ctx) def batch_normalization_gradient_of_data_op(bn_gradient, in_arr, ctx=None): """Gradient node of data of batch normalization. Parameters: ---- bn_gradient : The gradient array. in_arr : Node Input array of bn layer. Returns: ---- A new Node instance created by Op. """ return Batch_Normalization_Gradient_of_DataOp(bn_gradient, in_arr, ctx=ctx) def batch_normalization_gradient_of_scale_op(bn_gradient, in_scale, ctx=None): """Gradient node of scale parameter of batch normalization. Parameters: ---- bn_gradient : The gradient array. in_scale : Scaling parameter of bn layer. Returns: ---- A new Node instance created by Op. """ return Batch_Normalization_Gradient_of_ScaleOp(bn_gradient, in_scale, ctx=ctx) def batch_normalization_gradient_of_bias_op(bn_gradient, in_bias, ctx=None): """Gradient node of bias parameter of batch normalization. Parameters: ---- bn_gradient : The gradient array. in_bias : Bias parameter of bn layer. Returns: ---- A new Node instance created by Op. """ return Batch_Normalization_Gradient_of_BiasOp(bn_gradient, in_bias, ctx=ctx) def batchnorm_forward(x, bn_scale, bn_bias, save_mean, save_var, momentum=0.99, eps=0.01): D = x.shape[1] if save_mean is None: save_mean = np.zeros(D, dtype=x.dtype) if save_var is None: save_var = np.ones(D, dtype=x.dtype) sample_mean = x.mean(axis=(0, 2, 3), dtype=x.dtype) sample_var = x.var(axis=(0, 2, 3), dtype=x.dtype) save_mean = momentum * sample_mean + (1.0 - momentum) * save_mean save_var = momentum * sample_var + (1.0 - momentum) * save_var std = np.sqrt(sample_var.reshape(1, D, 1, 1) + eps, dtype=x.dtype) x_centered = x - sample_mean.reshape(1, D, 1, 1) x_norm = x_centered / std out = bn_scale.reshape(1, D, 1, 1) * x_norm + bn_bias.reshape(1, D, 1, 1) return out, save_mean, save_mean def batchnorm_inference(x, bn_scale, bn_bias, save_mean, save_var, eps=0.01): D = x.shape[1] std = np.sqrt(save_var.reshape(1, D, 1, 1) + eps, dtype=x.dtype) x_centered = x - save_mean.reshape(1, D, 1, 1) x_norm = x_centered / std out = bn_scale.reshape(1, D, 1, 1) * x_norm + bn_bias.reshape(1, D, 1, 1) return out def batchnorm_backward(gradient_Y, x, bn_scale, dbn_scale, dbn_bias, eps, save_mean, save_var): D = gradient_Y.shape[1] sample_mean = save_mean sample_var = save_var std = np.sqrt(sample_var.reshape(1, D, 1, 1) + eps) x_centered = x - sample_mean.reshape(1, D, 1, 1) x_norm = x_centered / std dbn_scale = (gradient_Y * x_norm).sum(axis=(0, 2, 3)) dbn_bias = gradient_Y.sum(axis=(0, 2, 3)) dx_norm = gradient_Y * bn_scale.reshape(1, D, 1, 1) dx_centered = dx_norm / std dmean = -(dx_centered.sum(axis=(0, 2, 3)) + 2 / D * x_centered.sum(axis=(0, 2, 3))).reshape(1, D, 1, 1) dstd = (dx_norm * x_centered * -std ** (-2) ).sum(axis=(0, 2, 3)).reshape(1, D, 1, 1) dvar = dstd / 2 / std dx = dx_centered + (dmean + dvar * 2 * x_centered) / D return dx, dbn_scale, dbn_bias