from __future__ import absolute_import
import numpy as np
from .Node import Op
from .._base import DNNL_LIB
from ..cpu_links import sqrt as cpu_sqrt
from ..cpu_links import rsqrt as cpu_rsqrt
from ..gpu_links import matrix_sqrt
from ..gpu_links import matrix_rsqrt


class SqrtOp(Op):
    """Graph node computing the elementwise square root of one input node."""

    def __init__(self, node_A, ctx=None):
        super().__init__(SqrtOp, [node_A], ctx)

    def compute(self, input_vals, output_val, stream_handle=None):
        """Write sqrt(input_vals[0]) into ``output_val``.

        Off-CPU the ``matrix_sqrt`` kernel from ``gpu_links`` is used with
        ``stream_handle``. On CPU the DNNL kernel is used when
        ``DNNL_LIB['DnnlSqrt']`` is truthy, otherwise a NumPy fallback.
        """
        operand = input_vals[0]
        if not self.on_cpu:
            matrix_sqrt(operand, output_val, stream_handle)
        elif DNNL_LIB['DnnlSqrt']:
            cpu_sqrt(operand, output_val)
        else:
            # NumPy fallback: operand is converted via .asnumpy() and the
            # result is assigned into output_val in place.
            output_val[:] = np.sqrt(operand.asnumpy())

    def gradient(self, output_grad):
        """Return the input gradient: d/dx sqrt(x) = 0.5 * x ** -0.5."""
        inv_root = rsqrt_op(self.inputs[0], ctx=self.raw_ctx)
        return [0.5 * inv_root * output_grad]

    def infer_shape(self, input_shapes):
        """Elementwise op: the output shape equals the single input's shape."""
        assert len(input_shapes) == 1
        return input_shapes[0]


class ReciprocalSqrtOp(Op):
    """Graph node computing the elementwise 1 / sqrt(x) of one input node."""

    def __init__(self, node_A, ctx=None):
        super().__init__(ReciprocalSqrtOp, [node_A], ctx)

    def compute(self, input_vals, output_val, stream_handle=None):
        """Write 1 / sqrt(input_vals[0]) into ``output_val``.

        Off-CPU the ``matrix_rsqrt`` kernel from ``gpu_links`` is used with
        ``stream_handle``. On CPU the DNNL kernel is used when
        ``DNNL_LIB['DnnlReciprocalSqrt']`` is truthy, otherwise a NumPy
        fallback.
        """
        operand = input_vals[0]
        if not self.on_cpu:
            matrix_rsqrt(operand, output_val, stream_handle)
        elif DNNL_LIB['DnnlReciprocalSqrt']:
            cpu_rsqrt(operand, output_val)
        else:
            output_val[:] = 1 / np.sqrt(operand.asnumpy())

    def gradient(self, output_grad):
        """Return the input gradient: d/dx x ** -0.5 = -0.5 * x ** -0.5 / x."""
        # Imported here rather than at module top; presumably to avoid a
        # circular import with the Division module — confirm before hoisting.
        from .Division import div_op
        base = self.inputs[0]
        ratio = div_op(rsqrt_op(base, ctx=self.raw_ctx), base, ctx=self.raw_ctx)
        return [-0.5 * ratio * output_grad]

    def infer_shape(self, input_shapes):
        """Elementwise op: the output shape equals the single input's shape."""
        assert len(input_shapes) == 1
        return input_shapes[0]


def sqrt_op(node, ctx=None):
    """Build a node computing the elementwise square root of ``node``.

    Parameters
    ----------
    node : Node
        Input variable.
    ctx : optional
        Context forwarded to the ``SqrtOp`` constructor (default ``None``).

    Returns
    -------
    SqrtOp
        A new Node instance created by Op.
    """
    return SqrtOp(node, ctx=ctx)


def rsqrt_op(node, ctx=None):
    """Build a node computing the elementwise reciprocal square root of ``node``.

    Parameters
    ----------
    node : Node
        Input variable.
    ctx : optional
        Context forwarded to the ``ReciprocalSqrtOp`` constructor
        (default ``None``).

    Returns
    -------
    ReciprocalSqrtOp
        A new Node instance created by Op.
    """
    return ReciprocalSqrtOp(node, ctx=ctx)