from __future__ import absolute_import
from .Node import Op
from .._base import DNNL_LIB
from ..cpu_links import matrix_elementwise_multiply_by_const as cpu_matrix_elementwise_multiply_by_const
from ..gpu_links import matrix_elementwise_multiply_by_const


class MulByConstOp(Op):
    """Graph op that multiplies its single input node by a constant."""

    def __init__(self, node_A, const_val, ctx=None):
        """Register ``node_A`` as the only input and remember the constant.

        Parameters:
        ----
        node_A : Node
            The node whose value is multiplied.
        const_val : scalar value
            The constant factor; stored on the op as ``const_attr``.
        ctx : optional
            Forwarded unchanged to the base ``Op`` constructor.
        """
        super().__init__(MulByConstOp, [node_A], ctx)
        self.const_attr = const_val
        # Human-readable description, e.g. "<op name>(<input name>, <const>)".
        operands = '(%s, %s)' % (node_A.name, str(const_val))
        self.desc = self.name + operands

    def compute(self, input_vals, output_val, stream_handle=None):
        """Write ``input_vals[0] * const_attr`` into ``output_val``.

        Dispatches on ``self.on_cpu``: the non-CPU path calls the gpu_links
        kernel with ``stream_handle``; the CPU path uses the cpu_links kernel
        when the matching ``DNNL_LIB`` entry is truthy and otherwise falls
        back to a NumPy multiplication.
        """
        assert self.const_attr is not None

        if not self.on_cpu:
            matrix_elementwise_multiply_by_const(
                input_vals[0], self.const_attr, output_val, stream_handle)
            return

        # NOTE(review): a truthy entry presumably means the DNNL kernel was
        # found when the library was loaded -- confirm against _base.
        if DNNL_LIB['DnnlMatrixElementwiseMultiplyByConst']:
            cpu_matrix_elementwise_multiply_by_const(
                input_vals[0], self.const_attr, output_val)
            return

        # Fallback: multiply on the host and copy the result into the
        # existing output buffer through slice assignment.
        product = input_vals[0].asnumpy() * self.const_attr
        output_val[:] = product

    def gradient(self, output_grad):
        """Return a one-element list holding ``const_attr * output_grad``."""
        # The constant stays on the left-hand side, as in the original code.
        scaled_grad = self.const_attr * output_grad
        return [scaled_grad]

    def infer_shape(self, input_shapes):
        """Return the shape of the single input unchanged."""
        assert len(input_shapes) == 1
        only_shape = input_shapes[0]
        return only_shape


def mul_byconst_op(node_A, const_val, ctx=None):
    """Create a MulByConstOp node for ``node_A * const_val``.

    Parameters:
    ----
    node_A : Node
        The Node to be multiplied.
    const_val : scalar value
        The constant value to multiply by.
    ctx : optional
        Passed through to ``MulByConstOp`` as the ``ctx`` keyword.

    Returns:
    ----
    A new MulByConstOp instance.

    """
    new_op = MulByConstOp(node_A, const_val, ctx=ctx)
    return new_op