from __future__ import absolute_import
import numpy as np
from .Node import Op
from ..gpu_links import matrix_slice_simple
from ..gpu_links import matrix_slice_gradient_simple
from .. import ndarray


class SliceOp(Op):
    def __init__(self, node_A, begin_pos, output_shape, ctx=None):
        super().__init__(SliceOp, [node_A], ctx)
        self.begin_pos = tuple(begin_pos)
        self.output_shape = list(output_shape)
        self.ori_output_shape = list(output_shape)
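        # Keep the raw user-specified sizes: a -1 entry is a placeholder for
        # "extend to the end of this dimension" and is resolved against the
        # real input shape in infer_shape.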
        assert len(self.begin_pos) == len(self.output_shape)
        for i in range(len(self.begin_pos)):
            assert self.begin_pos[i] >= 0

    def compute(self, input_vals, output_val, stream_handle=None):
        if self.on_cpu:
            index = tuple([slice(i, i+j)
                           for i, j in zip(self.begin_pos, self.output_shape)])
            output_val[:] = input_vals[0].asnumpy()[index]
        else:
            # matrix_slice(input_vals[0], output_val, self.begin_pos, stream_handle)
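            # matrix_slice_simple reads the begin positions and shapes from
            # gpu_buffer, which infer_shape uploads to this op's device.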
            matrix_slice_simple(
                input_vals[0], output_val, self.gpu_buffer, stream_handle)

    def gradient(self, output_grad):
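        # The gradient's output shape equals this op's input shape, which is
        # unknown until shape inference, so pass None here and let
        # SliceOp.infer_shape fill it in.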
        self.grad_node = slice_gradient_op(
            output_grad, self.begin_pos, None, ctx=self.raw_ctx)
        return [self.grad_node]

    def infer_shape(self, input_shapes):
        assert len(input_shapes) == 1
        ori_shape = list(input_shapes[0])
        assert len(ori_shape) == len(self.begin_pos)
        for i in range(len(ori_shape)):
            if self.ori_output_shape[i] == -1:
                self.output_shape[i] = ori_shape[i] - self.begin_pos[i]
            assert self.output_shape[i] > 0
            assert self.begin_pos[i] + self.output_shape[i] <= ori_shape[i]
        self.ori_shape = tuple(ori_shape)
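        # Backfill the gradient node created in gradient(): its output shape
        # must match this op's input shape.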
        if hasattr(self, 'grad_node'):
            self.grad_node.output_shape = self.ori_shape
            assert len(self.ori_shape) == len(self.grad_node.begin_pos)

        # Save the slicing metadata on the device for the GPU kernel.
        if self.on_gpu:
            ndim = len(ori_shape)
            gpu_buf = [0 for _ in range(3 * ndim)]
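            # Buffer layout: [begin_pos | input_shape | output_shape],
            # ndim entries each.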
            for i in range(ndim):
                gpu_buf[i] = self.begin_pos[i]
                gpu_buf[ndim + i] = ori_shape[i]
                gpu_buf[2 * ndim + i] = self.output_shape[i]
            self.gpu_buffer = ndarray.array(
                gpu_buf, self.ctx, data_type=np.uintc)
        return self.output_shape


class SliceGradientOp(Op):
    def __init__(self, node_A, begin_pos, output_shape, ctx=None):
        super().__init__(SliceGradientOp, [node_A], ctx)
        self.begin_pos = tuple(begin_pos)
        self.output_shape = None
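        # output_shape may be None when this op is created from
        # SliceOp.gradient(); the forward op's infer_shape assigns the real
        # shape before this node is executed.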
        if output_shape is not None:
            self.output_shape = tuple(output_shape)
            assert len(self.begin_pos) == len(self.output_shape)
        for i in range(len(self.begin_pos)):
            assert self.begin_pos[i] >= 0

    def compute(self, input_vals, output_val, stream_handle=None):
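        # The gradient of a slice scatters the incoming gradient back into a
        # zero tensor shaped like the original (unsliced) input.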
        if self.on_cpu:
            grad = np.zeros(self.output_shape, dtype=np.float32)
            index = tuple([slice(i, i+j)
                           for i, j in zip(self.begin_pos, self.ori_shape)])
            grad[index] = input_vals[0].asnumpy()
            output_val[:] = grad
        else:
            # matrix_slice_gradient(input_vals[0], output_val, self.begin_pos, stream_handle)
            matrix_slice_gradient_simple(
                input_vals[0], output_val, self.gpu_buffer, stream_handle)

    def gradient(self, output_grad):
        raise NotImplementedError

    def infer_shape(self, input_shapes):
        assert self.output_shape is not None
        assert len(input_shapes) == 1
        ori_shape = list(input_shapes[0])
        assert len(ori_shape) == len(self.begin_pos)
        for i in range(len(ori_shape)):
            assert self.begin_pos[i] + ori_shape[i] <= self.output_shape[i]
        self.ori_shape = tuple(ori_shape)

        # Save the slicing metadata on the device for the GPU kernel.
        if self.on_gpu:
            ndim = len(ori_shape)
            gpu_buf = [0 for _ in range(3 * ndim)]
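            # Buffer layout: [begin_pos | grad_shape | output_shape],
            # ndim entries each.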
            for i in range(ndim):
                gpu_buf[i] = self.begin_pos[i]
                gpu_buf[ndim + i] = ori_shape[i]
                gpu_buf[2 * ndim + i] = self.output_shape[i]
            self.gpu_buffer = ndarray.array(
                gpu_buf, self.ctx, data_type=np.uintc)
        return self.output_shape


def slice_op(node, begin, size, ctx=None):
- """Creates a node that represents tf.slice(node, begin, size).
-
- Parameters:
- ----
- node : Node
- The Node needed to be summed.
- begin: tuple
- The beginning position of slice operation.
- size: tuple
- The shape(size) of output tensor.
-
- Returns:
- ----
- A new Node instance created by Op.
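
    Example (illustrative):
    ----
    y = slice_op(x, begin=(0, 1), size=(2, -1))
    # if x has shape (4, 6), y has shape (2, 5)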

    """
    return SliceOp(node, begin, size, ctx=ctx)


def slice_gradient_op(node, begin, size=None, ctx=None):
    """Creates a node that represents the gradient of tf.slice.

    Parameters:
    ----
    node : Node
        The gradient of the sliced output.
    begin: tuple
        The begin position of the slice in each dimension.
    size: tuple, optional
        The shape of the output tensor, i.e. the original input's shape.
        May be None; SliceOp.infer_shape fills it in before execution.

    Returns:
    ----
    A new Node instance created by Op.

    """
    return SliceGradientOp(node, begin, size, ctx=ctx)