You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

Split.py 4.7 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. from __future__ import absolute_import
  2. import numpy as np
  3. from .Node import Op
  4. from ..gpu_links import matrix_slice_simple
  5. from ..gpu_links import matrix_slice_gradient_simple
  6. from .. import ndarray
  7. class SplitOp(Op):
  8. def __init__(self, node_A, axes, indices, splits, ctx=None):
  9. super().__init__(SplitOp, [node_A], ctx)
  10. self.axes = axes
  11. self.indices = indices
  12. self.splits = splits
  13. assert len(self.axes) == len(self.splits)
  14. assert all([x >= 0 for x in axes])
  15. assert all([x >= 1 for x in splits])
  16. assert all([x >= 0 and x < splits[i] for i, x in enumerate(indices)])
  17. def compute(self, input_vals, output_val, stream_handle=None):
  18. if self.on_cpu:
  19. index = tuple([slice(i, i+j)
  20. for i, j in zip(self.begin_pos, self.output_shape)])
  21. output_val[:] = input_vals[0].asnumpy()[index]
  22. else:
  23. # matrix_slice(input_vals[0], output_val, self.begin_pos, stream_handle)
  24. matrix_slice_simple(
  25. input_vals[0], output_val, self.gpu_buffer, stream_handle)
  26. def gradient(self, output_grad):
  27. self.grad_node = split_gradient_op(
  28. output_grad, self.axes, self.indices, self.splits, ctx=self.raw_ctx)
  29. return [self.grad_node]
  30. def infer_shape(self, input_shapes):
  31. assert len(input_shapes) == 1
  32. ori_shape = list(input_shapes[0])
  33. self.begin_pos = [0 for _ in ori_shape]
  34. self.output_shape = [x for x in ori_shape]
  35. for axe, ind, spl in zip(self.axes, self.indices, self.splits):
  36. part_size = ori_shape[axe] // spl
  37. self.begin_pos[axe] = ind * part_size
  38. self.output_shape[axe] = part_size if ind != spl - \
  39. 1 else ori_shape[axe] - self.begin_pos[axe]
  40. if hasattr(self, 'grad_node'):
  41. self.grad_node.begin_pos = self.begin_pos
  42. self.grad_node.output_shape = ori_shape
  43. # here we save the information on device for GPU computation
  44. if self.on_gpu:
  45. ndim = len(ori_shape)
  46. gpu_buf = [0 for _ in range(3 * ndim)]
  47. for i in range(ndim):
  48. gpu_buf[i] = self.begin_pos[i]
  49. gpu_buf[ndim + i] = ori_shape[i]
  50. gpu_buf[2 * ndim + i] = self.output_shape[i]
  51. self.gpu_buffer = ndarray.array(
  52. gpu_buf, self.ctx, data_type=np.uintc)
  53. return self.output_shape
  54. class SplitGradientOp(Op):
  55. def __init__(self, node_A, axes, indices, splits, ctx=None):
  56. super().__init__(SplitGradientOp, [node_A], ctx)
  57. self.axes = axes
  58. self.indices = indices
  59. self.splits = splits
  60. self.begin_pos = None
  61. self.output_shape = None
  62. assert len(self.axes) == len(self.splits)
  63. assert all([x >= 0 for x in axes])
  64. assert all([x >= 1 for x in splits])
  65. assert all([x >= 0 and x < splits[i] for i, x in enumerate(indices)])
  66. def compute(self, input_vals, output_val, stream_handle=None):
  67. if self.on_cpu:
  68. output_val[:] = np.zeros(self.output_shape, dtype=np.float32)
  69. index = tuple([slice(i, i+j)
  70. for i, j in zip(self.begin_pos, self.ori_shape)])
  71. output_val[index] = input_vals[0]
  72. else:
  73. # matrix_slice_gradient(input_vals[0], output_val, self.begin_pos, stream_handle)
  74. matrix_slice_gradient_simple(
  75. input_vals[0], output_val, self.gpu_buffer, stream_handle)
  76. def gradient(self, output_grad):
  77. raise NotImplementedError
  78. def infer_shape(self, input_shapes):
  79. assert self.output_shape != None and self.begin_pos != None
  80. assert len(input_shapes) == 1
  81. ori_shape = list(input_shapes[0])
  82. for i in range(len(ori_shape)):
  83. assert self.begin_pos[i] + ori_shape[i] <= self.output_shape[i]
  84. self.ori_shape = tuple(ori_shape)
  85. # here we save the information on device for GPU computation
  86. if self.on_gpu:
  87. ndim = len(ori_shape)
  88. gpu_buf = [0 for _ in range(3 * ndim)]
  89. for i in range(ndim):
  90. gpu_buf[i] = self.begin_pos[i]
  91. gpu_buf[ndim + i] = ori_shape[i]
  92. gpu_buf[2 * ndim + i] = self.output_shape[i]
  93. self.gpu_buffer = ndarray.array(
  94. gpu_buf, self.ctx, data_type=np.uintc)
  95. return self.output_shape
  96. def split_op(node, axes, indices, splits, ctx=None):
  97. return SplitOp(node, axes, indices, splits, ctx=ctx)
  98. def split_gradient_op(node, axes, indices, splits, ctx=None):
  99. return SplitGradientOp(node, axes, indices, splits, ctx=ctx)