from __future__ import absolute_import
import numpy as np
from .Node import Op
from .ReduceSum import reduce_sum_op
from ..gpu_links import broadcast_shape_simple
from .. import ndarray


class BroadcastToOp(Op):
    """Op that broadcasts its first input to the shape of its second input."""

    def __init__(self, node_A, node_B, ctx=None):
        """Register node_A (data) and node_B (shape donor) as the op inputs."""
        super().__init__(BroadcastToOp, [node_A, node_B], ctx)

    def compute(self, input_vals, output_val, stream_handle=None):
        """Write input_vals[0], broadcast to input_vals[1]'s shape, to output_val.

        Three paths are taken depending on the op's flags: a NumPy path when
        ``self.on_cpu`` is set, the input array's own ``broadcast_to`` method
        when ``self.inplace`` is set, and otherwise the
        ``broadcast_shape_simple`` kernel, which consumes the strides and
        dimensions cached by ``infer_shape``.
        """
        if self.on_cpu:
            target_shape = list(input_vals[1].shape)
            broadcasted = np.broadcast_to(
                input_vals[0].asnumpy(), target_shape)
            output_val[:] = broadcasted
            return

        if self.inplace:
            input_vals[0].broadcast_to(input_vals[1].shape, output_val)
            return

        broadcast_shape_simple(
            input_vals[0],
            output_val,
            self.out_strides,
            self.in_dims,
            stream_handle,
        )

    def gradient(self, output_grad):
        """Return the gradient nodes for both inputs.

        The first input receives a reduce-sum of ``output_grad``; the node is
        remembered on ``self.grad_node`` so that ``infer_shape`` can later
        fill in its ``axes`` and ``keepdims``. The second input only supplies
        a shape and gets no gradient.
        """
        grad = reduce_sum_op(output_grad, None, None, ctx=self.raw_ctx)
        self.grad_node = grad
        return [grad, None]

    def infer_shape(self, input_shapes):
        """Validate the broadcast and return the output shape.

        Parameters:
        ----
        input_shapes : sequence of two shapes
            ``input_shapes[0]`` is the shape being broadcast and
            ``input_shapes[1]`` is the target shape.

        Returns:
        ----
        ``input_shapes[1]``, unchanged.

        Side effects: updates ``axes``/``keepdims`` on the stored gradient
        node (if one exists) and, on the non-inplace GPU path, caches
        ``self.out_strides`` and ``self.in_dims``.
        """
        assert len(input_shapes) == 2
        src_shape = list(input_shapes[0])
        dst_shape = list(input_shapes[1])
        dst_rank = len(dst_shape)
        assert len(src_shape) <= dst_rank

        # Left-pad the source shape with ones so both shapes share a rank.
        pad = dst_rank - len(src_shape)
        padded_src = [1] * pad + src_shape

        # Padded leading axes are reduced without keepdims; axes that were
        # stretched from size 1 are reduced with keepdims.
        reduce_axes = list(range(pad))
        keep_flags = [False] * pad
        for axis, (src_dim, dst_dim) in enumerate(zip(padded_src, dst_shape)):
            assert dst_dim > 0 and isinstance(dst_dim, int)
            assert src_dim == 1 or src_dim == dst_dim
            if axis >= pad and src_dim == 1 and dst_dim > 1:
                reduce_axes.append(axis)
                keep_flags.append(True)

        if hasattr(self, 'grad_node'):
            grad = self.grad_node
            grad.axes = reduce_axes
            grad.keepdims = keep_flags

        # Cache the output strides and padded input dims for the GPU kernel.
        if self.on_gpu and not self.inplace:
            # Row-major strides: running product of extents, innermost first.
            strides = []
            block = 1
            for extent in reversed(dst_shape):
                strides.append(block)
                block *= extent
            strides.reverse()

            self.out_strides = ndarray.array(
                strides, self.ctx, data_type=np.int32)
            self.in_dims = ndarray.array(
                padded_src, self.ctx, data_type=np.int32)

        return input_shapes[1]

    def backward_hook(self, config):
        """Set ``self.inplace`` from the executor configuration.

        The flag is truthy only when ``config.enable_lazy`` is truthy and this
        node is not listed in ``config.eval_node_list``.
        """
        lazy_enabled = config.enable_lazy
        self.inplace = lazy_enabled and self not in config.eval_node_list


def broadcastto_op(node_A, node_B, ctx=None):
    """Creates a node that represents np.broadcast_to(node_A, node_B.shape).

    Parameters:
    ----
    node_A : Node
        The Node to be broadcast.
    node_B : Node
        Another Node with the target shape.
    ctx : optional
        Context forwarded unchanged to the BroadcastToOp constructor.

    Returns:
    ----
    A new Node instance created by Op.

    """
    return BroadcastToOp(node_A, node_B, ctx=ctx)