You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Slice.py 5.3 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. from __future__ import absolute_import
  2. import numpy as np
  3. from .Node import Op
  4. from ..gpu_links import matrix_slice_simple
  5. from ..gpu_links import matrix_slice_gradient_simple
  6. from .. import ndarray
  class SliceOp(Op):
      """Extract a contiguous window of the input tensor (the analogue of tf.slice).

      ``begin_pos`` is the start index along each axis and ``output_shape``
      the extent of the window along each axis. An extent of -1 is resolved
      in ``infer_shape`` to "from the start index to the end of that axis".
      """

      def __init__(self, node_A, begin_pos, output_shape, ctx=None):
          super().__init__(SliceOp, [node_A], ctx)
          self.begin_pos = tuple(begin_pos)
          # Working copy: -1 entries are overwritten with concrete extents
          # by infer_shape.
          self.output_shape = list(output_shape)
          # Pristine copy of the requested sizes, so -1 entries are still
          # recognisable on every later infer_shape call.
          self.ori_output_shape = list(output_shape)
          assert len(self.begin_pos) == len(self.output_shape)
          assert all(start >= 0 for start in self.begin_pos)

      def compute(self, input_vals, output_val, stream_handle=None):
          """Write the selected window of ``input_vals[0]`` into ``output_val``."""
          if not self.on_cpu:
              # The packed begin/shape buffer built in infer_shape is handed
              # to the device kernel.
              matrix_slice_simple(
                  input_vals[0], output_val, self.gpu_buffer, stream_handle)
              return
          window = tuple(
              slice(start, start + extent)
              for start, extent in zip(self.begin_pos, self.output_shape))
          output_val[:] = input_vals[0].asnumpy()[window]

      def gradient(self, output_grad):
          """Return the gradient node; its output shape is set later in infer_shape."""
          grad = slice_gradient_op(
              output_grad, self.begin_pos, None, ctx=self.raw_ctx)
          # Remembered so infer_shape can pass the input shape to it.
          self.grad_node = grad
          return [grad]

      def infer_shape(self, input_shapes):
          """Resolve -1 extents, validate the window, and return the output shape.

          Side effects: stores the input shape as ``self.ori_shape``, forwards
          it to the gradient node (if one was created), and on GPU packs
          [begin_pos..., input shape..., output shape...] into
          ``self.gpu_buffer``.
          """
          assert len(input_shapes) == 1
          in_shape = list(input_shapes[0])
          assert len(in_shape) == len(self.begin_pos)
          for axis, extent in enumerate(in_shape):
              start = self.begin_pos[axis]
              if self.ori_output_shape[axis] == -1:
                  # -1 means "take everything from start to the end".
                  self.output_shape[axis] = extent - start
              assert self.output_shape[axis] > 0
              assert start + self.output_shape[axis] <= extent
          self.ori_shape = tuple(in_shape)

          if hasattr(self, 'grad_node'):
              # The gradient is shaped like this op's input.
              self.grad_node.output_shape = self.ori_shape
              assert len(self.ori_shape) == len(self.grad_node.begin_pos)

          if self.on_gpu:
              ndim = len(in_shape)
              packed = ([self.begin_pos[k] for k in range(ndim)]
                        + in_shape
                        + [self.output_shape[k] for k in range(ndim)])
              self.gpu_buffer = ndarray.array(
                  packed, self.ctx, data_type=np.uintc)
          return self.output_shape
  class SliceGradientOp(Op):
      """Gradient of SliceOp: place the incoming gradient inside a larger tensor.

      The input node (the gradient w.r.t. a slice's output) is written at
      offset ``begin_pos`` into a tensor of shape ``output_shape``. That shape
      is typically the shape of the tensor that was sliced. On the CPU path
      every other element is zero.
      """

      def __init__(self, node_A, begin_pos, output_shape, ctx=None):
          super().__init__(SliceGradientOp, [node_A], ctx)
          self.begin_pos = tuple(begin_pos)
          # output_shape may be None here. It must then be assigned before
          # infer_shape runs; SliceOp.infer_shape does so for the node it
          # creates in SliceOp.gradient.
          self.output_shape = None
          # `is not None` rather than `!= None`: the latter is elementwise for
          # NumPy arrays and would make this `if` raise.
          if output_shape is not None:
              self.output_shape = tuple(output_shape)
              assert len(self.begin_pos) == len(self.output_shape)
          for start in self.begin_pos:
              assert start >= 0

      def compute(self, input_vals, output_val, stream_handle=None):
          """Fill ``output_val`` with the incoming gradient placed at ``begin_pos``."""
          if self.on_cpu:
              upstream = input_vals[0]
              # SliceOp.compute reads CPU values through .asnumpy(); accept
              # that wrapper as well as a plain NumPy array.
              if hasattr(upstream, 'asnumpy'):
                  upstream = upstream.asnumpy()
              window = tuple(
                  slice(start, start + extent)
                  for start, extent in zip(self.begin_pos, self.ori_shape))
              # Assemble the result in NumPy and copy it over with a full
              # slice, the same way SliceOp.compute writes its output,
              # instead of a partial tuple-indexed write into output_val.
              grad = np.zeros(self.output_shape, dtype=np.float32)
              grad[window] = upstream
              output_val[:] = grad
          else:
              matrix_slice_gradient_simple(
                  input_vals[0], output_val, self.gpu_buffer, stream_handle)

      def gradient(self, output_grad):
          raise NotImplementedError

      def infer_shape(self, input_shapes):
          """Validate the input shape against the target shape and return it.

          Side effects: stores the input shape as ``self.ori_shape`` and, on
          GPU, packs [begin_pos..., input shape..., output shape...] into
          ``self.gpu_buffer``.
          """
          assert self.output_shape is not None
          assert len(input_shapes) == 1
          ori_shape = list(input_shapes[0])
          assert len(ori_shape) == len(self.begin_pos)
          # output_shape may have been assigned after construction, so its
          # rank is re-checked here; the packed GPU buffer needs exactly
          # 3 * ndim entries.
          assert len(self.output_shape) == len(self.begin_pos)
          for start, extent, full in zip(
                  self.begin_pos, ori_shape, self.output_shape):
              assert start + extent <= full
          self.ori_shape = tuple(ori_shape)
          if self.on_gpu:
              gpu_buf = (list(self.begin_pos) + ori_shape
                         + list(self.output_shape))
              self.gpu_buffer = ndarray.array(
                  gpu_buf, self.ctx, data_type=np.uintc)
          return self.output_shape
  def slice_op(node, begin, size, ctx=None):
      """Create a node that represents tf.slice(node, begin, size).

      Parameters
      ----------
      node : Node
          The input node to slice.
      begin : tuple
          Start index of the slice along each axis. Every entry must be >= 0,
          and ``begin`` must have the same length as ``size``.
      size : tuple
          Extent of the slice along each axis, i.e. the output shape. An
          entry of -1 selects everything from ``begin`` to the end of that
          axis.
      ctx : optional
          Passed through unchanged to the ``SliceOp`` constructor.

      Returns
      -------
      SliceOp
          A new node created by the op.
      """
      return SliceOp(node, begin, size, ctx=ctx)
  def slice_gradient_op(node, begin, size=None, ctx=None):
      """Create a node that represents the gradient of tf.slice.

      The gradient ``node`` is placed at offset ``begin`` inside a tensor of
      shape ``size``. On the CPU path the remaining elements are zero.

      Parameters
      ----------
      node : Node
          The gradient w.r.t. the output of the slice.
      begin : tuple
          Start index of the original slice along each axis. Every entry
          must be >= 0.
      size : tuple or None
          Shape of the produced gradient, i.e. the shape of the tensor that
          was sliced. May be None; it must then be assigned to the node's
          ``output_shape`` before shape inference. ``SliceOp.infer_shape``
          does this for the gradient node it creates.
      ctx : optional
          Passed through unchanged to the ``SliceGradientOp`` constructor.

      Returns
      -------
      SliceGradientOp
          A new node created by the op.
      """
      return SliceGradientOp(node, begin, size, ctx=ctx)