from __future__ import absolute_import
from .Node import Op
import numpy as np
from .. import ndarray
from ..gpu_links import layer_normalization
from ..gpu_links import layer_normalization_gradient
from ..gpu_links import layer_normalization_inference


class Layer_NormalizationOp(Op):
    def __init__(self, node_in, ln_scale, ln_bias, eps=0.01, ctx=None):
        super().__init__(Layer_NormalizationOp,
                         [node_in, ln_scale, ln_bias], ctx)
        self.eps = eps
        # Per-row mean and variance over the last axis are cached here so
        # that the gradient op (which keeps a reference to this node) can
        # reuse them in the backward pass.
        self.save_mean = None
        self.save_var = None
        self.data_shape = None

    def compute(self, input_vals, output_val, stream_handle=None,
                inference=False):
        if inference:
            if self.on_cpu:
                input_vals = [n.asnumpy() for n in input_vals]
                data_type = input_vals[0].dtype
                std = np.sqrt(self.save_var + self.eps, dtype=data_type)
                centered_input = input_vals[0] - self.save_mean
                normed_input = centered_input / std
                # Broadcast the 1-D scale/bias across all leading axes.
                bc_shape = [1] * len(input_vals[0].shape)
                bc_shape[-1] = input_vals[0].shape[-1]
                output_val[:] = input_vals[1].reshape(bc_shape) * normed_input \
                    + input_vals[2].reshape(bc_shape)
            else:
                layer_normalization_inference(
                    input_vals[0], input_vals[1], input_vals[2],
                    self.save_mean, self.save_var, output_val,
                    self.eps, stream_handle)
        else:
            # The cached statistics have the input's shape with the last
            # axis reduced to 1.
            local_shape = list(input_vals[0].shape)
            local_shape[-1] = 1
            local_shape = tuple(local_shape)
            if self.on_cpu:
                input_vals = [n.asnumpy() for n in input_vals]
                data_type = input_vals[0].dtype
                # (Re)allocate the statistics buffers on the first call or
                # whenever the input shape changes.
                if self.data_shape != local_shape:
                    del self.save_mean
                    del self.save_var
                    self.save_mean = np.empty(local_shape, dtype=np.float32)
                    self.save_var = np.empty(local_shape, dtype=np.float32)
                    self.data_shape = local_shape
                self.save_mean[:] = input_vals[0].mean(
                    axis=-1, dtype=data_type, keepdims=True)
                self.save_var[:] = input_vals[0].var(
                    axis=-1, dtype=data_type, keepdims=True)
                std = np.sqrt(self.save_var + self.eps, dtype=data_type)
                centered_input = input_vals[0] - self.save_mean
                normed_input = centered_input / std
                bc_shape = [1] * len(input_vals[0].shape)
                bc_shape[-1] = input_vals[0].shape[-1]
                output_val[:] = input_vals[1].reshape(bc_shape) * normed_input \
                    + input_vals[2].reshape(bc_shape)
            else:
                if self.data_shape != local_shape:
                    # Free any stale buffers before allocating new ones on
                    # the input's device.
                    del self.save_mean
                    del self.save_var
                    dev_id = input_vals[0].handle.contents.ctx.device_id
                    self.save_mean = ndarray.empty(
                        local_shape, ctx=ndarray.gpu(dev_id))
                    self.save_var = ndarray.empty(
                        local_shape, ctx=ndarray.gpu(dev_id))
                    self.data_shape = local_shape
                layer_normalization(input_vals[0], input_vals[1],
                                    input_vals[2], self.save_mean,
                                    self.save_var, output_val,
                                    self.eps, stream_handle)

    def gradient(self, output_grad):
        # A single gradient node computes all three gradients in one pass;
        # the three ops below only copy out the buffers it fills.
        ln_gradient_node = layer_normalization_gradient_op(
            output_grad, self.inputs[0], self.inputs[1], self, self.eps,
            ctx=self.raw_ctx)
        data_gradient = layer_normalization_gradient_of_data_op(
            ln_gradient_node, self.inputs[0], ctx=self.raw_ctx)
        scale_gradient = layer_normalization_gradient_of_scale_op(
            ln_gradient_node, self.inputs[1], ctx=self.raw_ctx)
        bias_gradient = layer_normalization_gradient_of_bias_op(
            ln_gradient_node, self.inputs[2], ctx=self.raw_ctx)
        return [data_gradient, scale_gradient, bias_gradient]

    def infer_shape(self, input_shapes):
        assert len(input_shapes) == 3
        assert len(input_shapes[1]) == len(input_shapes[2]) == 1
        assert input_shapes[0][-1] == input_shapes[1][0] == input_shapes[2][0]
        return input_shapes[0]
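

# For reference, the forward computation implemented above is, per sample
# and over the last axis:
#     mean = x.mean(-1, keepdims=True)
#     var = x.var(-1, keepdims=True)
#     y = ln_scale * (x - mean) / np.sqrt(var + eps) + ln_bias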


class Layer_Normalization_GradientOp(Op):
    def __init__(self, out_gradient, in_node, ln_scale, forward_node, eps,
                 ctx=None):
        super().__init__(Layer_Normalization_GradientOp,
                         [out_gradient, in_node, ln_scale], ctx)
        # Buffers for the three gradients; filled once per compute call and
        # read back by the Layer_Normalization_Gradient_of_* ops below.
        self.tmp_gradient_in_arr = None
        self.tmp_gradient_ln_bias = None
        self.tmp_gradient_ln_scale = None
        self.data_shape = None
        self.forward_node = forward_node
        self.eps = eps

    def compute(self, input_vals, output_val, stream_handle=None):
        if self.on_cpu:
            # Work on numpy arrays, matching the forward op's CPU path.
            input_vals = [n.asnumpy() for n in input_vals]
            if self.tmp_gradient_ln_bias is None:
                shapeln = input_vals[2].shape
                self.data_shape = tuple(input_vals[0].shape)
                self.tmp_gradient_ln_scale = np.empty(
                    shape=shapeln, dtype=np.float32)
                self.tmp_gradient_ln_bias = np.empty(
                    shape=shapeln, dtype=np.float32)
                self.tmp_gradient_in_arr = np.empty(
                    shape=self.data_shape, dtype=np.float32)
            elif self.data_shape != tuple(input_vals[0].shape):
                # Only the data-gradient buffer depends on the batch shape;
                # the scale/bias buffers keep the (fixed) last dimension.
                self.data_shape = tuple(input_vals[0].shape)
                del self.tmp_gradient_in_arr
                self.tmp_gradient_in_arr = np.empty(
                    shape=self.data_shape, dtype=np.float32)
            red_axis = tuple(range(input_vals[0].ndim - 1))
            self.tmp_gradient_ln_bias[:] = input_vals[0].sum(red_axis)  # (X,)
            std = np.sqrt(self.forward_node.save_var + self.eps)  # (N, 1)
            x_centered = input_vals[1] - self.forward_node.save_mean  # (N, X)
            x_norm = x_centered / std  # (N, X)
            self.tmp_gradient_ln_scale[:] = (
                input_vals[0] * x_norm).sum(red_axis)  # (X,)
            last_dim = input_vals[1].shape[-1]
            dx_norm = input_vals[0] * input_vals[2].reshape(
                [1] * (input_vals[0].ndim - 1) + [-1])  # (N, X)
            dvar = (dx_norm * x_centered).sum(axis=-1, keepdims=True) * \
                -0.5 / (self.forward_node.save_var + self.eps) / std  # (N, 1)
            dx_mu_1 = dx_norm / std  # (N, X)
            dx_mu_2 = dvar * 2 * x_centered / last_dim  # (N, X)
            dx_1 = dx_mu_1 + dx_mu_2  # (N, X)
            dx_2 = -1 * dx_1.sum(axis=-1, keepdims=True) / last_dim  # (N, 1)
            self.tmp_gradient_in_arr[:] = dx_1 + dx_2  # (N, X)
        else:
            if self.tmp_gradient_ln_bias is None:
                shapeln = input_vals[2].shape
                self.data_shape = tuple(input_vals[0].shape)
                self.tmp_gradient_ln_bias = ndarray.empty(
                    shape=shapeln, ctx=input_vals[0].ctx)
                self.tmp_gradient_ln_scale = ndarray.empty(
                    shape=shapeln, ctx=input_vals[0].ctx)
                self.tmp_gradient_in_arr = ndarray.empty(
                    shape=self.data_shape, ctx=input_vals[0].ctx)
            elif self.data_shape != tuple(input_vals[0].shape):
                self.data_shape = tuple(input_vals[0].shape)
                del self.tmp_gradient_in_arr
                self.tmp_gradient_in_arr = ndarray.empty(
                    shape=self.data_shape, ctx=input_vals[0].ctx)
            layer_normalization_gradient(
                input_vals[0], input_vals[1], input_vals[2],
                self.tmp_gradient_in_arr, self.tmp_gradient_ln_scale,
                self.tmp_gradient_ln_bias, self.forward_node.save_mean,
                self.forward_node.save_var, self.eps, stream_handle)

    def gradient(self, output_grad):
        raise NotImplementedError

    def infer_shape(self, input_shapes):
        # This node's own output is a dummy scalar; the actual gradients
        # are exposed through the Gradient_of_* ops below.
        return (1,)
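

# Gradient math used above, with dy the incoming gradient, x_norm the
# normalized input and std = sqrt(var + eps) (mean/var cached by the
# forward node):
#     d_bias  = dy.sum(over all leading axes)                    # (X,)
#     d_scale = (dy * x_norm).sum(over all leading axes)         # (X,)
#     d_x     = (dx_norm - mean(dx_norm)
#                - x_norm * mean(dx_norm * x_norm)) / std,
# where dx_norm = dy * ln_scale and the means run over the last axis;
# the dvar/dx_mu decomposition in compute() is an equivalent expansion
# of this compact form.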


class Layer_Normalization_Gradient_of_DataOp(Op):
    def __init__(self, ln_gradient, in_arr, ctx=None):
        super().__init__(Layer_Normalization_Gradient_of_DataOp,
                         [ln_gradient, in_arr], ctx)

    def compute(self, input_vals, output_val, stream_handle=None):
        # Copy the data gradient out of the buffer filled by the gradient
        # node (self.inputs[0]).
        if self.on_cpu:
            output_val[:] = self.inputs[0].tmp_gradient_in_arr
        else:
            self.inputs[0].tmp_gradient_in_arr.copyto(output_val)

    def gradient(self, output_grad):
        raise NotImplementedError

    def infer_shape(self, input_shapes):
        return input_shapes[1]


class Layer_Normalization_Gradient_of_ScaleOp(Op):
    def __init__(self, ln_gradient, in_scale, ctx=None):
        super().__init__(Layer_Normalization_Gradient_of_ScaleOp,
                         [ln_gradient, in_scale], ctx)

    def compute(self, input_vals, output_val, stream_handle=None):
        if self.on_cpu:
            output_val[:] = self.inputs[0].tmp_gradient_ln_scale
        else:
            self.inputs[0].tmp_gradient_ln_scale.copyto(output_val)

    def gradient(self, output_grad):
        raise NotImplementedError

    def infer_shape(self, input_shapes):
        return input_shapes[1]


class Layer_Normalization_Gradient_of_BiasOp(Op):
    def __init__(self, ln_gradient, in_bias, ctx=None):
        super().__init__(Layer_Normalization_Gradient_of_BiasOp,
                         [ln_gradient, in_bias], ctx)

    def compute(self, input_vals, output_val, stream_handle=None):
        if self.on_cpu:
            output_val[:] = self.inputs[0].tmp_gradient_ln_bias
        else:
            self.inputs[0].tmp_gradient_ln_bias.copyto(output_val)

    def gradient(self, output_grad):
        raise NotImplementedError

    def infer_shape(self, input_shapes):
        return input_shapes[1]


def layer_normalization_op(node_in, ln_scale, ln_bias, eps=0.01, ctx=None):
    """Layer normalization node.

    Parameters:
    ----
    node_in : Node
        Input data.
    ln_scale : Node
        Learnable scale (gain) parameter, one value per feature.
    ln_bias : Node
        Learnable bias parameter, one value per feature.
    eps : float
        Epsilon value for numerical stability.

    Returns:
    ----
    A new Node instance created by Op.

    """
    return Layer_NormalizationOp(node_in, ln_scale, ln_bias, eps, ctx=ctx)


def layer_normalization_gradient_op(out_gradient, in_node, ln_scale,
                                    forward_node, eps, ctx=None):
    """Gradient node of layer normalization.

    Parameters:
    ----
    out_gradient : Node
        Gradient with respect to the layer's output.
    in_node : Node
        Input node of the layer-norm layer.
    ln_scale : Node
        Scale parameter of the layer-norm layer.
    forward_node : Layer_NormalizationOp
        Forward node whose cached mean/variance are reused.
    eps : float
        Epsilon value for numerical stability.

    Returns:
    ----
    A new Node instance created by Op.

    """
    return Layer_Normalization_GradientOp(out_gradient, in_node, ln_scale,
                                          forward_node, eps, ctx=ctx)


def layer_normalization_gradient_of_data_op(ln_gradient, in_arr, ctx=None):
    """Gradient node for the input data of layer normalization.

    Parameters:
    ----
    ln_gradient : Layer_Normalization_GradientOp
        Node holding the computed gradient buffers.
    in_arr : Node
        Input array of the layer-norm layer.

    Returns:
    ----
    A new Node instance created by Op.

    """
    return Layer_Normalization_Gradient_of_DataOp(ln_gradient, in_arr,
                                                  ctx=ctx)


def layer_normalization_gradient_of_scale_op(ln_gradient, in_scale, ctx=None):
    """Gradient node for the scale parameter of layer normalization.

    Parameters:
    ----
    ln_gradient : Layer_Normalization_GradientOp
        Node holding the computed gradient buffers.
    in_scale : Node
        Scale parameter of the layer-norm layer.

    Returns:
    ----
    A new Node instance created by Op.

    """
    return Layer_Normalization_Gradient_of_ScaleOp(ln_gradient, in_scale,
                                                   ctx=ctx)


def layer_normalization_gradient_of_bias_op(ln_gradient, in_bias, ctx=None):
    """Gradient node for the bias parameter of layer normalization.

    Parameters:
    ----
    ln_gradient : Layer_Normalization_GradientOp
        Node holding the computed gradient buffers.
    in_bias : Node
        Bias parameter of the layer-norm layer.

    Returns:
    ----
    A new Node instance created by Op.

    """
    return Layer_Normalization_Gradient_of_BiasOp(ln_gradient, in_bias,
                                                  ctx=ctx)
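

if __name__ == "__main__":
    # Lightweight numerical self-check (an illustrative sketch added here,
    # not part of the op API): it verifies in plain numpy that the
    # three-term backward formulation used in
    # Layer_Normalization_GradientOp.compute matches the compact closed
    # form of the layer-norm input gradient. Because of the relative
    # imports above, run it as a module (python -m <package>.<this module>).
    rng = np.random.default_rng(0)
    n, x_dim, eps = 4, 8, 1e-5
    x = rng.standard_normal((n, x_dim))
    grad_y = rng.standard_normal((n, x_dim))
    scale = rng.standard_normal(x_dim)

    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    std = np.sqrt(var + eps)
    x_norm = (x - mean) / std

    # Three-term formulation, mirroring compute() above.
    dx_norm = grad_y * scale
    dvar = (dx_norm * (x - mean)).sum(axis=-1, keepdims=True) \
        * -0.5 / (var + eps) / std
    dx_1 = dx_norm / std + dvar * 2 * (x - mean) / x_dim
    dx = dx_1 - dx_1.sum(axis=-1, keepdims=True) / x_dim

    # Compact form: (dx_norm - mean(dx_norm) - x_norm * mean(dx_norm * x_norm)) / std.
    ref = (dx_norm - dx_norm.mean(axis=-1, keepdims=True)
           - x_norm * (dx_norm * x_norm).mean(axis=-1, keepdims=True)) / std
    assert np.allclose(dx, ref)
    print("layer-norm backward self-check passed")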