from __future__ import absolute_import
from .Node import Op
from .._base import DNNL_LIB
from ..cpu_links import (
    matrix_elementwise_multiply as cpu_matrix_elementwise_multiply,
)
from ..cpu_links import (
    matrix_elementwise_multiply_by_const
    as cpu_matrix_elementwise_multiply_by_const,
)
from ..gpu_links import (
    matrix_elementwise_multiply,
    matrix_elementwise_multiply_by_const,
)


class MulOp(Op):
    """Element-wise multiplication of two nodes.

    The two operands must either have identical shapes, or one of them
    must have shape ``(1,)``, in which case its single element is used as
    a constant multiplier for the other operand.
    """

    def __init__(self, node_A, node_B, ctx=None):
        """Create the op with ``node_A`` and ``node_B`` as its two inputs."""
        super().__init__(MulOp, [node_A, node_B], ctx)

    def compute(self, input_vals, output_val, stream_handle=None):
        """Write ``input_vals[0] * input_vals[1]`` into ``output_val``.

        Parameters
        ----------
        input_vals : sequence
            The two operand arrays; each must expose ``.shape`` and
            ``.asnumpy()``.
        output_val : array
            Destination array, written in place.
        stream_handle : optional
            Forwarded to the GPU kernels; unused on the CPU path.

        Raises
        ------
        ValueError
            On the GPU path, when the shapes differ and neither operand
            has shape ``(1,)`` (no kernel handles that case).
        """
        lhs, rhs = input_vals[0], input_vals[1]

        if self.on_cpu:
            # The DNNL_LIB lookup is evaluated before the shape test in each
            # condition, matching the original short-circuit order.
            if DNNL_LIB['DnnlMatrixElementwiseMultiply'] \
                    and lhs.shape == rhs.shape:
                cpu_matrix_elementwise_multiply(lhs, rhs, output_val)
            elif DNNL_LIB['DnnlMatrixElementwiseMultiplyByConst'] \
                    and (lhs.shape == (1,) or rhs.shape == (1,)):
                # When both operands are (1,), the right-hand one is
                # treated as the constant.
                if rhs.shape == (1,):
                    const_val = rhs.asnumpy()[0]
                    cpu_matrix_elementwise_multiply_by_const(
                        lhs, const_val, output_val)
                else:
                    const_val = lhs.asnumpy()[0]
                    cpu_matrix_elementwise_multiply_by_const(
                        rhs, const_val, output_val)
            else:
                # No usable DNNL kernel: fall back to a NumPy multiply.
                output_val[:] = lhs.asnumpy() * rhs.asnumpy()
            return

        # GPU path.
        if lhs.shape == rhs.shape:
            matrix_elementwise_multiply(lhs, rhs, output_val, stream_handle)
        elif rhs.shape == (1,):
            const_val = rhs.asnumpy()[0]
            matrix_elementwise_multiply_by_const(
                lhs, const_val, output_val, stream_handle)
        elif lhs.shape == (1,):
            const_val = lhs.asnumpy()[0]
            matrix_elementwise_multiply_by_const(
                rhs, const_val, output_val, stream_handle)
        else:
            # Previously this case fell through without launching any kernel,
            # leaving output_val unwritten. Fail loudly instead.
            raise ValueError(
                "can't do elementwise multiply between variables of "
                "different sizes: %s vs %s" % (lhs.shape, rhs.shape))

    def gradient(self, output_grad):
        """Return the gradient nodes for both inputs.

        By the product rule, the gradient for input 0 is
        ``inputs[1] * output_grad`` and the gradient for input 1 is
        ``inputs[0] * output_grad``.
        """
        # NOTE(review): when one input has shape (1,), its gradient node is
        # produced at the broadcast shape and is not reduced back to (1,)
        # here -- confirm whether a caller performs that reduction.
        return [mul_op(self.inputs[1], output_grad, ctx=self.raw_ctx),
                mul_op(self.inputs[0], output_grad, ctx=self.raw_ctx)]

    def infer_shape(self, input_shapes):
        """Return the output shape for the two given input shapes.

        Equal shapes yield that shape. If exactly one shape is ``(1,)``,
        the other shape is returned.

        Raises
        ------
        AssertionError
            If ``input_shapes`` does not hold exactly two shapes, or the
            shapes differ and neither is ``(1,)``.
        """
        # Explicit raises rather than `assert` statements: asserts are
        # stripped under `python -O`, which would otherwise turn a shape
        # mismatch into an UnboundLocalError at the return. AssertionError
        # is kept so the exception type seen by callers is unchanged.
        if len(input_shapes) != 2:
            raise AssertionError(
                "MulOp expects exactly 2 input shapes, got %d"
                % len(input_shapes))

        shape_a, shape_b = input_shapes[0], input_shapes[1]
        if shape_a == shape_b:
            return shape_a
        if shape_a == (1,):
            return shape_b
        if shape_b == (1,):
            return shape_a
        raise AssertionError(
            "can't do elementwise multiply between variables of different sizes.")


def mul_op(node_A, node_B, ctx=None):
    """Make a new instance of matrix element-wise multiplication.

    Parameters:
    ----
    node_A : Node
        The Node to be multiplied.
    node_B : Node
        Another Node to be multiplied.
    ctx : optional
        Context forwarded to the ``MulOp`` constructor.

    Returns:
    ----
    A new Node instance created by Op.

    """
    return MulOp(node_A, node_B, ctx=ctx)