You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

InstanceNorm2d.py 3.7 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. from __future__ import absolute_import
  2. from .Node import Op
  3. import numpy as np
  4. from .. import ndarray
  5. from ..gpu_links import instance_normalization2d
  6. from ..gpu_links import instance_normalization2d_gradient
  7. class Instance_Normalization2dOp(Op):
  8. def __init__(self, node_in, eps=0.0000001, ctx=None):
  9. super().__init__(Instance_Normalization2dOp, [node_in], ctx)
  10. self.eps = eps
  11. self.save_mean = None
  12. self.save_var = None
  13. self.data_shape = None
  14. def compute(self, input_vals, output_val, stream_handle=None, inference=False):
  15. local_shape = list(input_vals[0].shape)
  16. assert len(local_shape) == 4
  17. local_shape[-1] = 1
  18. local_shape[-2] = 1
  19. local_shape = tuple(local_shape)
  20. if self.on_cpu:
  21. raise NotImplementedError
  22. else:
  23. if self.data_shape is None:
  24. dev_id = input_vals[0].handle.contents.ctx.device_id
  25. self.save_mean = ndarray.empty(
  26. local_shape, ctx=ndarray.gpu(dev_id))
  27. self.save_var = ndarray.empty(
  28. local_shape, ctx=ndarray.gpu(dev_id))
  29. self.data_shape = local_shape
  30. elif self.data_shape != local_shape:
  31. del self.save_mean
  32. del self.save_var
  33. dev_id = input_vals[0].handle.contents.ctx.device_id
  34. self.save_mean = ndarray.empty(
  35. local_shape, ctx=ndarray.gpu(dev_id))
  36. self.save_var = ndarray.empty(
  37. local_shape, ctx=ndarray.gpu(dev_id))
  38. self.data_shape = local_shape
  39. instance_normalization2d(input_vals[0], self.save_mean, self.save_var,
  40. output_val, self.eps, stream_handle)
  41. def gradient(self, output_grad):
  42. return [instance_normalization2d_gradient_op(output_grad, self.inputs[0], self, ctx=self.ctx)]
  43. def infer_shape(self, input_shapes):
  44. assert len(input_shapes) == 1
  45. return input_shapes[0]
  46. class Instance_Normalization2d_GradientOp(Op):
  47. def __init__(self, out_gradient, in_node, forward_node, ctx=None):
  48. super().__init__(Instance_Normalization2d_GradientOp,
  49. [out_gradient, in_node], ctx)
  50. self.tmp_gradient_in_arr = None
  51. self.data_shape = None
  52. self.forward_node = forward_node
  53. def compute(self, input_vals, output_val, stream_handle=None):
  54. if self.on_cpu:
  55. raise NotImplementedError
  56. else:
  57. instance_normalization2d_gradient(input_vals[0], input_vals[1], output_val,
  58. self.forward_node.save_mean, self.forward_node.save_var,
  59. self.forward_node.eps, stream_handle)
  60. def gradient(self, output_grad):
  61. raise NotImplementedError
  62. def infer_shape(self, input_shapes):
  63. return input_shapes[0]
  64. def instance_normalization2d_op(node_in, eps=0.01, ctx=None):
  65. """Layer normalization node.
  66. Parameters:
  67. ----
  68. node_in : Node
  69. Input data.
  70. eps : float
  71. Epsilon value for numerical stability.
  72. Returns:
  73. ----
  74. A new Node instance created by Op.
  75. """
  76. return Instance_Normalization2dOp(node_in, eps, ctx=ctx)
  77. def instance_normalization2d_gradient_op(out_gradient, in_node, forward_node, ctx=None):
  78. """Gradient node of layer normalization.
  79. Parameters:
  80. ----
  81. out_gradient :
  82. The gradient array.
  83. in_node : Node
  84. Input node of ln layer.
  85. Returns:
  86. ----
  87. A new Node instance created by Op.
  88. """
  89. return Instance_Normalization2d_GradientOp(out_gradient, in_node, forward_node, ctx=ctx)