from __future__ import absolute_import
import numpy as np
from .Node import Op
from ..gpu_links import matrix_slice_simple
from ..gpu_links import matrix_slice_gradient_simple
from .. import ndarray


class SplitOp(Op):
    """Slice one piece out of the input tensor.

    For each (axis, index, split) triple, the input is divided into `split`
    roughly equal parts along `axis` and part `index` is taken; the last
    part absorbs any remainder.
    """

    def __init__(self, node_A, axes, indices, splits, ctx=None):
        super().__init__(SplitOp, [node_A], ctx)
        self.axes = axes
        self.indices = indices
        self.splits = splits
        assert len(self.axes) == len(self.splits)
        assert all([x >= 0 for x in axes])
        assert all([x >= 1 for x in splits])
        assert all([x >= 0 and x < splits[i] for i, x in enumerate(indices)])

    def compute(self, input_vals, output_val, stream_handle=None):
        if self.on_cpu:
            # begin_pos and output_shape are filled in by infer_shape.
            index = tuple([slice(i, i + j)
                           for i, j in zip(self.begin_pos, self.output_shape)])
            output_val[:] = input_vals[0].asnumpy()[index]
        else:
            # matrix_slice(input_vals[0], output_val, self.begin_pos, stream_handle)
            matrix_slice_simple(
                input_vals[0], output_val, self.gpu_buffer, stream_handle)

    def gradient(self, output_grad):
        self.grad_node = split_gradient_op(
            output_grad, self.axes, self.indices, self.splits, ctx=self.raw_ctx)
        return [self.grad_node]

    def infer_shape(self, input_shapes):
        assert len(input_shapes) == 1
        ori_shape = list(input_shapes[0])
        self.begin_pos = [0 for _ in ori_shape]
        self.output_shape = [x for x in ori_shape]
        for axe, ind, spl in zip(self.axes, self.indices, self.splits):
            part_size = ori_shape[axe] // spl
            self.begin_pos[axe] = ind * part_size
            if ind != spl - 1:
                self.output_shape[axe] = part_size
            else:
                # The last part takes whatever remains along this axis.
                self.output_shape[axe] = ori_shape[axe] - self.begin_pos[axe]
        if hasattr(self, 'grad_node'):
            # Propagate the slice offset and the full input shape to the
            # gradient node before its own infer_shape runs.
            self.grad_node.begin_pos = self.begin_pos
            self.grad_node.output_shape = ori_shape
        # Save the slicing metadata on device for the GPU kernel; the buffer
        # holds [begin_pos..., input_shape..., output_shape...] as unsigned ints.
        if self.on_gpu:
            ndim = len(ori_shape)
            gpu_buf = [0 for _ in range(3 * ndim)]
            for i in range(ndim):
                gpu_buf[i] = self.begin_pos[i]
                gpu_buf[ndim + i] = ori_shape[i]
                gpu_buf[2 * ndim + i] = self.output_shape[i]
            self.gpu_buffer = ndarray.array(
                gpu_buf, self.ctx, data_type=np.uintc)
        return self.output_shape


class SplitGradientOp(Op):
    """Gradient of SplitOp: scatter the incoming gradient back into a zero
    tensor of the original input shape at the slice's offset.
    """

    def __init__(self, node_A, axes, indices, splits, ctx=None):
        super().__init__(SplitGradientOp, [node_A], ctx)
        self.axes = axes
        self.indices = indices
        self.splits = splits
        # Filled in by the forward SplitOp's infer_shape.
        self.begin_pos = None
        self.output_shape = None
        assert len(self.axes) == len(self.splits)
        assert all([x >= 0 for x in axes])
        assert all([x >= 1 for x in splits])
        assert all([x >= 0 and x < splits[i] for i, x in enumerate(indices)])

    def compute(self, input_vals, output_val, stream_handle=None):
        if self.on_cpu:
            output_val[:] = np.zeros(self.output_shape, dtype=np.float32)
            index = tuple([slice(i, i + j)
                           for i, j in zip(self.begin_pos, self.ori_shape)])
            output_val[index] = input_vals[0]
        else:
            # matrix_slice_gradient(input_vals[0], output_val, self.begin_pos, stream_handle)
            matrix_slice_gradient_simple(
                input_vals[0], output_val, self.gpu_buffer, stream_handle)

    def gradient(self, output_grad):
        raise NotImplementedError

    def infer_shape(self, input_shapes):
        assert self.output_shape is not None and self.begin_pos is not None
        assert len(input_shapes) == 1
        ori_shape = list(input_shapes[0])
        for i in range(len(ori_shape)):
            assert self.begin_pos[i] + ori_shape[i] <= self.output_shape[i]
        self.ori_shape = tuple(ori_shape)
        # Save the scatter metadata on device for the GPU kernel; the buffer
        # holds [begin_pos..., grad_slice_shape..., full_shape...] as unsigned ints.
        if self.on_gpu:
            ndim = len(ori_shape)
            gpu_buf = [0 for _ in range(3 * ndim)]
            for i in range(ndim):
                gpu_buf[i] = self.begin_pos[i]
                gpu_buf[ndim + i] = ori_shape[i]
                gpu_buf[2 * ndim + i] = self.output_shape[i]
            self.gpu_buffer = ndarray.array(
                gpu_buf, self.ctx, data_type=np.uintc)
        return self.output_shape


def split_op(node, axes, indices, splits, ctx=None):
    return SplitOp(node, axes, indices, splits, ctx=ctx)


def split_gradient_op(node, axes, indices, splits, ctx=None):
    return SplitGradientOp(node, axes, indices, splits, ctx=ctx)
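

# A minimal, illustrative sketch (not used by the ops above): the forward
# slice that SplitOp computes and the scatter its gradient performs, expressed
# with plain numpy. The shapes and split arguments below are hypothetical
# examples chosen for clarity, not part of the framework's API.
if __name__ == "__main__":
    # Equivalent of axes=[0], indices=[1], splits=[2] on a (5, 4) tensor.
    x = np.arange(20, dtype=np.float32).reshape(5, 4)
    part_size = x.shape[0] // 2          # 2
    begin = 1 * part_size                # 2
    out_rows = x.shape[0] - begin        # 3; the last piece absorbs the remainder
    out = x[begin:begin + out_rows, :]   # forward slice, shape (3, 4)
    # Gradient: scatter the slice back into a zero tensor of the full shape.
    grad = np.zeros_like(x)
    grad[begin:begin + out_rows, :] = out
    print(out.shape, grad.shape)         # (3, 4) (5, 4)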