|
|
@@ -273,6 +273,13 @@ def _check_invalid_input(x1_shape, x2_shape): |
|
|
+ f'while x1 is ({len(x1_shape)}) and x2 is ({len(x2_shape)}).') |
|
|
+ f'while x1 is ({len(x1_shape)}) and x2 is ({len(x2_shape)}).') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
|
|
def _get_transpose_shape(x2_shape): |
|
|
|
|
|
x2_shape_range = tuple(range(len(x2_shape))) |
|
|
|
|
|
x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:] |
|
|
|
|
|
return x2_shape_transpose |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dot(x1, x2): |
|
|
def dot(x1, x2): |
|
|
""" |
|
|
""" |
|
|
Computation a dot product between samples in two tensors. |
|
|
Computation a dot product between samples in two tensors. |
|
|
@@ -304,8 +311,7 @@ def dot(x1, x2): |
|
|
_check_invalid_input(x1_shape, x2_shape) |
|
|
_check_invalid_input(x1_shape, x2_shape) |
|
|
|
|
|
|
|
|
if len(x1_shape) > 2 or len(x2_shape) > 2: |
|
|
if len(x1_shape) > 2 or len(x2_shape) > 2: |
|
|
x2_shape_range = range(len(x2_shape)) |
|
|
|
|
|
x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:] |
|
|
|
|
|
|
|
|
x2_shape_transpose = _get_transpose_shape(x2_shape) |
|
|
x2_transpose = transpose_op(x2, x2_shape_transpose) |
|
|
x2_transpose = transpose_op(x2, x2_shape_transpose) |
|
|
x1_reshape = reshape_op(x1, (-1, x1_shape[-1])) |
|
|
x1_reshape = reshape_op(x1, (-1, x1_shape[-1])) |
|
|
x2_reshape = reshape_op(x2_transpose, (x2_shape[-2], -1)) |
|
|
x2_reshape = reshape_op(x2_transpose, (x2_shape[-2], -1)) |
|
|
|