|
- from __future__ import absolute_import
- from .Node import Op
- import numpy as np
- from .. import ndarray
- from ..gpu_links import instance_normalization2d
- from ..gpu_links import instance_normalization2d_gradient
-
-
class Instance_Normalization2dOp(Op):
    """Instance normalization of a 4-D input, computed on GPU.

    The node keeps mean/variance buffers (``save_mean`` / ``save_var``) that the
    GPU kernel fills during ``compute``. They have the input's first two
    dimensions with the last two collapsed to 1. Instance_Normalization2d_GradientOp
    reads them back in the backward pass.
    """

    def __init__(self, node_in, eps=0.0000001, ctx=None):
        super().__init__(Instance_Normalization2dOp, [node_in], ctx)
        # Epsilon forwarded to the forward and gradient kernels.
        self.eps = eps
        # Statistics buffers; allocated lazily on the first GPU compute.
        self.save_mean = None
        self.save_var = None
        # Shape of the currently allocated statistics buffers (None = not yet).
        self.data_shape = None

    def compute(self, input_vals, output_val, stream_handle=None, inference=False):
        """Run the forward kernel on ``input_vals[0]`` and write ``output_val``.

        ``inference`` is accepted but unused here.

        Raises:
            AssertionError: if the input is not 4-D.
            NotImplementedError: when the node is placed on CPU.
        """
        input_shape = tuple(input_vals[0].shape)
        assert len(input_shape) == 4
        # Statistics keep the first two dimensions and collapse the last two.
        stat_shape = input_shape[:2] + (1, 1)
        if self.on_cpu:
            raise NotImplementedError
        # Covers both "never allocated" (data_shape is None) and "input shape
        # changed since the last call".
        if self.data_shape != stat_shape:
            self._allocate_stat_buffers(input_vals[0], stat_shape)
        instance_normalization2d(input_vals[0], self.save_mean, self.save_var,
                                 output_val, self.eps, stream_handle)

    def _allocate_stat_buffers(self, input_arr, stat_shape):
        """(Re)allocate the mean/variance buffers on ``input_arr``'s GPU."""
        # Drop references to any previous buffers first, so they can be
        # released (if nothing else holds them) before the new allocation.
        self.save_mean = None
        self.save_var = None
        dev_id = input_arr.handle.contents.ctx.device_id
        self.save_mean = ndarray.empty(stat_shape, ctx=ndarray.gpu(dev_id))
        self.save_var = ndarray.empty(stat_shape, ctx=ndarray.gpu(dev_id))
        self.data_shape = stat_shape

    def gradient(self, output_grad):
        """Return a one-element list: the gradient node w.r.t. the input."""
        return [instance_normalization2d_gradient_op(output_grad, self.inputs[0], self, ctx=self.ctx)]

    def infer_shape(self, input_shapes):
        """The output has the same shape as the single input."""
        assert len(input_shapes) == 1
        return input_shapes[0]
-
-
class Instance_Normalization2d_GradientOp(Op):
    """Backward node of Instance_Normalization2dOp (gradient w.r.t. its input).

    Node inputs are ``[out_gradient, in_node]``. The statistics buffers and eps
    are taken from ``forward_node`` when ``compute`` runs.
    """

    def __init__(self, out_gradient, in_node, forward_node, ctx=None):
        super().__init__(Instance_Normalization2d_GradientOp,
                         [out_gradient, in_node], ctx)
        # Not read anywhere in this class; kept as node attributes.
        self.tmp_gradient_in_arr = None
        self.data_shape = None
        # Forward node supplying save_mean / save_var / eps to the kernel.
        self.forward_node = forward_node

    def compute(self, input_vals, output_val, stream_handle=None):
        """Run the GPU gradient kernel; CPU placement is not supported."""
        if self.on_cpu:
            raise NotImplementedError
        fwd = self.forward_node
        grad_arr, in_arr = input_vals[0], input_vals[1]
        instance_normalization2d_gradient(
            grad_arr, in_arr, output_val,
            fwd.save_mean, fwd.save_var, fwd.eps, stream_handle)

    def gradient(self, output_grad):
        """Second-order gradients are not supported."""
        raise NotImplementedError

    def infer_shape(self, input_shapes):
        """Return the shape of ``out_gradient``.

        The forward op preserves shape, so this also matches ``in_node``'s shape.
        """
        grad_shape = input_shapes[0]
        return grad_shape
-
-
def instance_normalization2d_op(node_in, eps=0.01, ctx=None):
    """Instance normalization node for a 4-D input.

    The node computes on GPU only. Its mean/variance statistics have the
    input's first two dimensions, with the last two collapsed to 1.

    Parameters:
    ----
    node_in : Node
        Input data; must be 4-D when the node is computed.
    eps : float
        Epsilon value for numerical stability. NOTE: this factory's default
        (0.01) differs from the Instance_Normalization2dOp constructor
        default (1e-7).
    ctx : optional
        Device context forwarded unchanged to the Op constructor.

    Returns:
    ----
    A new Instance_Normalization2dOp node.

    """
    return Instance_Normalization2dOp(node_in, eps, ctx=ctx)
-
-
def instance_normalization2d_gradient_op(out_gradient, in_node, forward_node, ctx=None):
    """Gradient node of instance normalization with respect to its input.

    Parameters:
    ----
    out_gradient : Node
        Gradient flowing into the forward node's output.
    in_node : Node
        Input node of the forward instance-normalization op.
    forward_node : Instance_Normalization2dOp
        The forward node. Its save_mean, save_var and eps are read when the
        gradient is computed. The statistics buffers are only allocated by the
        forward compute, so the forward pass is expected to run first.
    ctx : optional
        Device context forwarded unchanged to the Op constructor.

    Returns:
    ----
    A new Instance_Normalization2d_GradientOp node.

    """
    return Instance_Normalization2d_GradientOp(out_gradient, in_node, forward_node, ctx=ctx)
|