from __future__ import absolute_import
import numpy as np
from .Node import Op
from .ReduceSum import reduce_sum_op
from ..gpu_links import broadcast_shape_simple
from .. import ndarray


class BroadcastShapeOp(Op):
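    """Broadcasts the input node to the given target shape (like np.broadcast_to).

    add_axes optionally lists output axes that are not present in the input;
    they are inserted with size 1 before broadcasting. When add_axes is empty,
    missing dimensions are treated as leading ones, following the usual NumPy
    broadcasting rules.
    """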
    def __init__(self, node_A, shape, add_axes=(), ctx=None):
        super().__init__(BroadcastShapeOp, [node_A], ctx)
        self.target_shape = shape
        self.add_axes = add_axes

    def compute(self, input_vals, output_val, stream_handle=None):
        assert self.target_shape is not None and self.add_axes is not None
        if self.on_cpu:
            input_shape = list(input_vals[0].shape)
            if self.add_axes:
                # the axes in add_axes are missing from the input; rebuild a
                # full-rank shape with 1s at those positions (mirrors the
                # in_dims construction in infer_shape)
                diff = len(self.target_shape) - len(input_shape)
                full_shape = [0] * len(self.target_shape)
                for i in range(diff):
                    full_shape[self.add_axes[i]] = 1
                ind = 0
                for dim in input_shape:
                    while full_shape[ind] == 1:
                        ind += 1
                    full_shape[ind] = dim
                    ind += 1
                input_shape = full_shape
            output_val[:] = np.broadcast_to(
                input_vals[0].asnumpy().reshape(input_shape), self.target_shape)
        else:
            if self.inplace:
                # lazy mode: record the broadcast as a view of the input
                input_vals[0].broadcast_to(
                    self.target_shape, output_val, self.add_axes)
            else:
                # broadcast_shape(input_vals[0], output_val, self.add_axes, stream_handle)
                # use the output strides / padded input dims cached by infer_shape
                broadcast_shape_simple(
                    input_vals[0], output_val, self.out_strides, self.in_dims, stream_handle)

    def gradient(self, output_grad):
        # the gradient of a broadcast is a reduce_sum over the broadcast axes;
        # the concrete axes / keepdims are filled in later by infer_shape
        self.grad_node = reduce_sum_op(
            output_grad, None, None, ctx=self.raw_ctx)
        return [self.grad_node]

    def infer_shape(self, input_shapes):
        assert self.target_shape is not None and self.add_axes is not None
        assert len(input_shapes) == 1
        input_shape = list(input_shapes[0])
        output_shape = list(self.target_shape)
        output_ndim = len(output_shape)
        assert len(input_shape) <= output_ndim
        diff = output_ndim - len(input_shape)
        if self.add_axes:
            # the new axes are given explicitly; every other output dimension
            # must match the corresponding input dimension exactly
            assert diff == len(self.add_axes) or input_shape == [1]
            assert all([axis < output_ndim for axis in self.add_axes])
            in_ind = 0
            for i in range(output_ndim):
                if i not in self.add_axes:
                    assert input_shape[in_ind] == output_shape[i]
                    in_ind += 1
            if hasattr(self, 'grad_node'):
                # the gradient sums over exactly the added axes
                self.grad_node.axes = tuple(self.add_axes)
                self.grad_node.keepdims = [False] * len(self.add_axes)
        else:
            # NumPy-style broadcasting: prepend missing dimensions, then expand
            # size-1 dimensions; -1 in the target shape means "keep the
            # (padded) input dimension"
            axes = list(range(diff))
            keepdims = [False] * diff
            input_shape = [1] * diff + input_shape
            for i in range(output_ndim):
                if output_shape[i] == -1:
                    output_shape[i] = input_shape[i]
                assert output_shape[i] > 0 and isinstance(output_shape[i], int)
                assert input_shape[i] == 1 or input_shape[i] == output_shape[i]
                if i >= diff and input_shape[i] == 1 and output_shape[i] > 1:
                    axes.append(i)
                    keepdims.append(True)
            if hasattr(self, 'grad_node'):
                self.grad_node.axes = axes
                self.grad_node.keepdims = keepdims

        # here we save the output strides and input dimensions for GPU computation
        if self.on_gpu and not self.inplace:
            input_shape = list(input_shapes[0])
            # row-major strides of the output tensor
            out_strides = [0 for _ in range(output_ndim)]
            temp_size = 1
            for i in range(output_ndim - 1, -1, -1):
                out_strides[i] = temp_size
                temp_size *= output_shape[i]
            if self.add_axes:
                # pad the input shape with 1s at the explicitly added axes
                in_dims = [0 for _ in range(output_ndim)]
                for i in range(diff):
                    in_dims[self.add_axes[i]] = 1
                temp_ind = 0
                for dim in input_shape:
                    while in_dims[temp_ind] == 1:
                        temp_ind += 1
                    in_dims[temp_ind] = dim
                    temp_ind += 1
            else:
                # pad the input shape with leading 1s
                in_dims = [1 for _ in range(diff)] + input_shape

            self.out_strides = ndarray.array(
                out_strides, self.ctx, data_type=np.int32)
            self.in_dims = ndarray.array(in_dims, self.ctx, data_type=np.int32)
        return tuple(output_shape)

    def backward_hook(self, config):
        # broadcasting can be performed lazily (as a view of the input) unless
        # this node's value is itself requested as an output
        self.inplace = config.enable_lazy and self not in config.eval_node_list


def broadcast_shape_op(node_A, shape, add_axes=(), ctx=None):
    """Creates a node that represents np.broadcast_to(node_A, shape).

    Parameters:
    ----
    node_A : Node
        The Node to be broadcast.
    shape : tuple
        Target shape.
    add_axes : tuple, optional
        Output axes that are not present in node_A; they are inserted with
        size 1 before broadcasting.
    ctx : optional
        Device context on which to place the op.

    Returns:
    ----
    A new Node instance created by Op.

    """
    return BroadcastShapeOp(node_A, shape, add_axes=add_axes, ctx=ctx)
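

# Illustrative usage (a sketch; the placeholder wiring below is an assumption
# about the surrounding framework, not something defined in this module):
#
#     x = placeholder_op(name='x')                       # runtime shape (3,)
#     y = broadcast_shape_op(x, (2, 3))                  # like np.broadcast_to(x, (2, 3))
#     z = broadcast_shape_op(x, (3, 4), add_axes=(1,))   # insert axis 1, then broadcast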