GitOrigin-RevId: bf6136cc1a
tags/v1.9.0
@@ -8,6 +8,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import abc
 import collections
+from functools import lru_cache
 from typing import Union

 import numpy as np
@@ -24,8 +25,8 @@ from .utils import (
     astype,
     cast_tensors,
     convert_inputs,
-    isscalar,
     make_shape_tuple,
+    subgraph,
 )

 _ElwMod = builtin.Elemwise.Mode
@@ -73,23 +74,292 @@ def _elwise(*args, mode):
     return _elwise_apply(args, mode)


-def _matmul(inp1, inp2):
+@lru_cache(maxsize=None)
+def _get_extentedMatrixMulOp(
+    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
+):
+    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
+    def extentedMatrixMulOp(inputs, f, c):
+        assert len(inputs) == 2
+        inp1, inp2 = inputs
+        _dim1, _dim2 = dim1, dim2
+
+        def build_shape_head(shape, idx=-1):
+            # shape[:idx]
+            return f(
+                builtin.Subtensor(items=[[0, False, True, False, False]]),
+                shape,
+                c(idx, "int32"),
+            )
+
+        def build_shape_tail(shape, idx=-1):
+            # shape[idx:]
+            return f(
+                builtin.Subtensor(items=[[0, True, False, False, False]]),
+                shape,
+                c(idx, "int32"),
+            )
+
+        remove_row, remove_col = False, False
+        if _dim1 == 1:
+            _dim1 = 2
+            remove_row = True
+        if _dim2 == 1:
+            _dim2 = 2
+            remove_col = True
+        if remove_row:
+            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
+        if remove_col:
+            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
+        shape1 = f(builtin.GetVarShape(), inp1)
+        shape2 = f(builtin.GetVarShape(), inp2)
+        if _dim1 > 2:
+            inp1 = f(
+                builtin.Reshape(),
+                inp1,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
+                    build_shape_tail(shape1),
+                ),
+            )
+        if _dim2 > 2:
+            inp2 = f(
+                builtin.Reshape(),
+                inp2,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
+                    build_shape_tail(shape2),
+                ),
+            )
+        op = builtin.MatrixMul(
+            transposeA=transpose_a,
+            transposeB=transpose_b,
+            compute_mode=compute_mode,
+            format=format,
+            strategy=strategy.value,
+        )
+        result = f(op, inp1, inp2)
+        result_shape = f(builtin.GetVarShape(), result)
+        if _dim1 > 2:
+            result = f(
+                builtin.Reshape(),
+                result,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    build_shape_head(shape1),
+                    build_shape_tail(result_shape),
+                ),
+            )
+        if _dim2 > 2:
+            result = f(
+                builtin.Reshape(),
+                result,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    build_shape_head(shape2),
+                    build_shape_tail(result_shape),
+                ),
+            )
+        maxdim = _dim1 if _dim1 > _dim2 else _dim2
+        if remove_row:
+            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
+        if remove_col:
+            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
+        return (result,), (True,)
+
+    return extentedMatrixMulOp
+
+
+@lru_cache(maxsize=None)
+def _get_extentedBatchedMatrixMulOp(
+    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
+):
+    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
+    def extentedBatchedMatrixMulOp(inputs, f, c):
+        assert len(inputs) == 2
+        inp1, inp2 = inputs
+        _dim1, _dim2 = dim1, dim2
+
+        def build_shape_head(shape, idx=-2):
+            # shape[:idx]
+            return f(
+                builtin.Subtensor(items=[[0, False, True, False, False]]),
+                shape,
+                c(idx, "int32"),
+            )
+
+        def build_shape_tail(shape, idx=-2):
+            # shape[idx:]
+            return f(
+                builtin.Subtensor(items=[[0, True, False, False, False]]),
+                shape,
+                c(idx, "int32"),
+            )
+
+        remove_row, remove_col = False, False
+        if _dim1 == 1:
+            _dim1 = 2
+            remove_row = True
+        if _dim2 == 1:
+            _dim2 = 2
+            remove_col = True
+        if remove_row:
+            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
+        if remove_col:
+            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
+        shape1 = f(builtin.GetVarShape(), inp1)
+        shape2 = f(builtin.GetVarShape(), inp2)
+        maxdim = _dim1 if _dim1 > _dim2 else _dim2
+        if _dim1 > _dim2:
+            # broadcast
+            shape2 = f(
+                builtin.Concat(axis=0, comp_node=device),
+                build_shape_head(shape1, idx=-_dim2),  # shape1[:-_dim2]
+                shape2,
+            )
+            inp2 = f(builtin.Broadcast(), inp2, shape2)
+            batch_shape = build_shape_head(shape1)
+        if _dim2 > _dim1:
+            # broadcast
+            shape1 = f(
+                builtin.Concat(axis=0, comp_node=device),
+                build_shape_head(shape2, idx=-_dim1),  # shape2[:-_dim1]
+                shape1,
+            )
+            inp1 = f(builtin.Broadcast(), inp1, shape1)
+            batch_shape = build_shape_head(shape2)
+        if _dim1 == _dim2:
+            batch_shape = build_shape_head(shape1)
+        # compress inputs to 3d
+        if maxdim > 3:
+            inp1 = f(
+                builtin.Reshape(),
+                inp1,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
+                    build_shape_tail(shape1),
+                ),
+            )
+            inp2 = f(
+                builtin.Reshape(),
+                inp2,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
+                    build_shape_tail(shape2),
+                ),
+            )
+        op = builtin.BatchedMatrixMul(
+            transposeA=transpose_a,
+            transposeB=transpose_b,
+            compute_mode=compute_mode,
+            format=format,
+            strategy=strategy.value,
+        )
+        result = f(op, inp1, inp2)
+        if maxdim > 3:
+            result = f(
+                builtin.Reshape(),
+                result,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    batch_shape,
+                    build_shape_tail(f(builtin.GetVarShape(), result)),
+                ),
+            )
+        if remove_row:
+            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
+        if remove_col:
+            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
+        return (result,), (True,)
+
+    return extentedBatchedMatrixMulOp
+
+
+class _Hashable:
+    def __init__(self, value) -> None:
+        self.value = value
+
+    def __hash__(self) -> int:
+        return hash(str(self.value))
+
+    def __eq__(self, o: object) -> bool:
+        if not isinstance(o, _Hashable):
+            return False
+        return self.value == o.value
+
+
+def _matmul(
+    inp1,
+    inp2,
+    transpose_a=False,
+    transpose_b=False,
+    compute_mode="default",
+    format="default",
+):
     if amp._enabled:
         compute_mode = "float32"
         inp1, inp2 = cast_tensors(inp1, inp2)
     else:
-        compute_mode = "default"
         dtype = dtype_promotion(inp1, inp2)
         if inp1.dtype != dtype:
             inp1 = inp1.astype(dtype)
         if inp2.dtype != dtype:
             inp2 = inp2.astype(dtype)
+    dim1, dim2 = inp1.ndim, inp2.ndim
+    assert dim1 > 0 and dim2 > 0
+    maxdim = dim1 if dim1 > dim2 else dim2
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
-    op = builtin.MatrixMul(
-        transposeA=False, transposeB=False, compute_mode=compute_mode, format="default"
-    )
-    (result,) = apply(op, inp1, inp2)
-    return result
+    Strategy = builtin.ops.MatrixMul.Strategy
+    strategy = Strategy(0)
+    if _config._benchmark_kernel:
+        strategy |= Strategy.PROFILE
+    else:
+        strategy |= Strategy.HEURISTIC
+    if _config._deterministic_kernel:
+        strategy |= Strategy.REPRODUCIBLE
+    if dim1 == 1 and dim2 == 1:  # dispatch to Dot
+        (result,) = apply(builtin.Dot(), inp1, inp2)
+        return result
+    elif maxdim <= 2 or dim2 <= 2:  # dispatch to MatrixMul
+        extentedMatrixMulOp = _get_extentedMatrixMulOp(
+            inp1.device,
+            inp1.dtype,
+            dim1,
+            dim2,
+            transpose_a,
+            transpose_b,
+            compute_mode,
+            format,
+            strategy=_Hashable(strategy),
+        )
+        (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
+        return result
+    else:  # dispatch to BatchedMatrixMul
+        extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
+            inp1.device,
+            inp1.dtype,
+            dim1,
+            dim2,
+            transpose_a,
+            transpose_b,
+            compute_mode,
+            format,
+            strategy=_Hashable(strategy),
+        )
+        (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
+        return result


 def _transpose(data, axes):
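For readers tracing the new `_matmul` above, its three code paths can be summarized with a small standalone sketch. `pick_matmul_kernel` is a hypothetical helper used only for illustration (it is not part of the diff); it simply mirrors the branch conditions in `_matmul`:

def pick_matmul_kernel(dim1: int, dim2: int) -> str:
    # Mirrors the dispatch in _matmul: vector x vector goes to Dot,
    # cases where the second operand has at most 2 dims go through the
    # extentedMatrixMulOp subgraph (MatrixMul), everything else through
    # the extentedBatchedMatrixMulOp subgraph (BatchedMatrixMul).
    assert dim1 > 0 and dim2 > 0
    maxdim = dim1 if dim1 > dim2 else dim2
    if dim1 == 1 and dim2 == 1:
        return "Dot"
    elif maxdim <= 2 or dim2 <= 2:
        return "MatrixMul"
    else:
        return "BatchedMatrixMul"

for shape_a, shape_b in [((4,), (4,)), ((10, 4), (4, 10)), ((2, 5, 4), (4, 3)), ((3, 10, 4), (3, 4, 10))]:
    print(shape_a, shape_b, "->", pick_matmul_kernel(len(shape_a), len(shape_b)))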
@@ -8,24 +8,18 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
 import math
-from functools import lru_cache
 from typing import Iterable, Optional, Sequence, Tuple, Union

-from ..core import _config
 from ..core._imperative_rt.core2 import apply, dtype_promotion
 from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
-from ..core._trace_option import use_symbolic_shape
 from ..core.ops import builtin
-from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
 from ..core.ops.special import Const
-from ..core.tensor import amp
-from ..core.tensor.utils import _normalize_axis, cast_tensors, subgraph
-from ..jit import exclude_from_trace
+from ..core.tensor.array_method import _matmul
+from ..core.tensor.utils import _normalize_axis
 from ..tensor import Tensor
 from ..utils.deprecation import deprecated_kwargs_default
-from .debug_param import get_execution_strategy
-from .elemwise import clip, minimum
-from .tensor import broadcast_to, concat, expand_dims, squeeze
+from .elemwise import clip
+from .tensor import expand_dims, squeeze

 __all__ = [
     "argmax",
@@ -794,229 +788,6 @@ def matinv(inp: Tensor) -> Tensor:
     return result


-class _Hashable:
-    def __init__(self, value) -> None:
-        self.value = value
-
-    def __hash__(self) -> int:
-        return hash(str(self.value))
-
-    def __eq__(self, o: object) -> bool:
-        if not isinstance(o, _Hashable):
-            return False
-        return self.value == o.value
-
-
-@lru_cache(maxsize=None)
-def _get_extentedMatrixMulOp(
-    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
-):
-    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
-    def extentedMatrixMulOp(inputs, f, c):
-        assert len(inputs) == 2
-        inp1, inp2 = inputs
-        _dim1, _dim2 = dim1, dim2
-
-        def build_shape_head(shape, idx=-1):
-            # shape[:idx]
-            return f(
-                builtin.Subtensor(items=[[0, False, True, False, False]]),
-                shape,
-                c(idx, "int32"),
-            )
-
-        def build_shape_tail(shape, idx=-1):
-            # shape[idx:]
-            return f(
-                builtin.Subtensor(items=[[0, True, False, False, False]]),
-                shape,
-                c(idx, "int32"),
-            )
-
-        remove_row, remove_col = False, False
-        if _dim1 == 1:
-            _dim1 = 2
-            remove_row = True
-        if _dim2 == 1:
-            _dim2 = 2
-            remove_col = True
-        if remove_row:
-            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
-        if remove_col:
-            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
-        shape1 = f(GetVarShape(), inp1)
-        shape2 = f(GetVarShape(), inp2)
-        if _dim1 > 2:
-            inp1 = f(
-                builtin.Reshape(),
-                inp1,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
-                    build_shape_tail(shape1),
-                ),
-            )
-        if _dim2 > 2:
-            inp2 = f(
-                builtin.Reshape(),
-                inp2,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
-                    build_shape_tail(shape2),
-                ),
-            )
-        op = builtin.MatrixMul(
-            transposeA=transpose_a,
-            transposeB=transpose_b,
-            compute_mode=compute_mode,
-            format=format,
-            strategy=strategy.value,
-        )
-        result = f(op, inp1, inp2)
-        result_shape = f(GetVarShape(), result)
-        if _dim1 > 2:
-            result = f(
-                builtin.Reshape(),
-                result,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    build_shape_head(shape1),
-                    build_shape_tail(result_shape),
-                ),
-            )
-        if _dim2 > 2:
-            result = f(
-                builtin.Reshape(),
-                result,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    build_shape_head(shape2),
-                    build_shape_tail(result_shape),
-                ),
-            )
-        maxdim = _dim1 if _dim1 > _dim2 else _dim2
-        if remove_row:
-            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
-        if remove_col:
-            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
-        return (result,), (True,)
-
-    return extentedMatrixMulOp
-
-
-@lru_cache(maxsize=None)
-def _get_extentedBatchedMatrixMulOp(
-    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
-):
-    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
-    def extentedBatchedMatrixMulOp(inputs, f, c):
-        assert len(inputs) == 2
-        inp1, inp2 = inputs
-        _dim1, _dim2 = dim1, dim2
-
-        def build_shape_head(shape, idx=-2):
-            # shape[:idx]
-            return f(
-                builtin.Subtensor(items=[[0, False, True, False, False]]),
-                shape,
-                c(idx, "int32"),
-            )
-
-        def build_shape_tail(shape, idx=-2):
-            # shape[idx:]
-            return f(
-                builtin.Subtensor(items=[[0, True, False, False, False]]),
-                shape,
-                c(idx, "int32"),
-            )
-
-        remove_row, remove_col = False, False
-        if _dim1 == 1:
-            _dim1 = 2
-            remove_row = True
-        if _dim2 == 1:
-            _dim2 = 2
-            remove_col = True
-        if remove_row:
-            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
-        if remove_col:
-            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
-        shape1 = f(GetVarShape(), inp1)
-        shape2 = f(GetVarShape(), inp2)
-        maxdim = _dim1 if _dim1 > _dim2 else _dim2
-        if _dim1 > _dim2:
-            # broadcast
-            shape2 = f(
-                builtin.Concat(axis=0, comp_node=device),
-                build_shape_head(shape1, idx=-_dim2),  # shape1[:-_dim2]
-                shape2,
-            )
-            inp2 = f(builtin.Broadcast(), inp2, shape2)
-            batch_shape = build_shape_head(shape1)
-        if _dim2 > _dim1:
-            # broadcast
-            shape1 = f(
-                builtin.Concat(axis=0, comp_node=device),
-                build_shape_head(shape2, idx=-_dim1),  # shape2[:-_dim1]
-                shape1,
-            )
-            inp1 = f(builtin.Broadcast(), inp1, shape1)
-            batch_shape = build_shape_head(shape2)
-        if _dim1 == _dim2:
-            batch_shape = build_shape_head(shape1)
-        # compress inputs to 3d
-        if maxdim > 3:
-            inp1 = f(
-                builtin.Reshape(),
-                inp1,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
-                    build_shape_tail(shape1),
-                ),
-            )
-            inp2 = f(
-                builtin.Reshape(),
-                inp2,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
-                    build_shape_tail(shape2),
-                ),
-            )
-        op = builtin.BatchedMatrixMul(
-            transposeA=transpose_a,
-            transposeB=transpose_b,
-            compute_mode=compute_mode,
-            format=format,
-            strategy=strategy.value,
-        )
-        result = f(op, inp1, inp2)
-        if maxdim > 3:
-            result = f(
-                builtin.Reshape(),
-                result,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    batch_shape,
-                    build_shape_tail(f(GetVarShape(), result)),
-                ),
-            )
-        if remove_row:
-            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
-        if remove_col:
-            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
-        return (result,), (True,)
-
-    return extentedBatchedMatrixMulOp
-
-
 def matmul(
     inp1: Tensor,
     inp2: Tensor,
@@ -1067,50 +838,7 @@ def matmul(
         [[10. 13.]
          [28. 40.]]
     """
-    if amp._enabled:
-        compute_mode = "float32"
-        inp1, inp2 = cast_tensors(inp1, inp2)
-    else:
-        dtype = dtype_promotion(inp1, inp2)
-        if inp1.dtype != dtype:
-            inp1 = inp1.astype(dtype)
-        if inp2.dtype != dtype:
-            inp2 = inp2.astype(dtype)
-    dim1, dim2 = inp1.ndim, inp2.ndim
-    assert dim1 > 0 and dim2 > 0
-    maxdim = dim1 if dim1 > dim2 else dim2
-    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
-    if dim1 == 1 and dim2 == 1:  # dispatch to Dot
-        return dot(inp1, inp2)
-    elif maxdim <= 2 or dim2 <= 2:  # dispath to MatrixMul
-        extentedMatrixMulOp = _get_extentedMatrixMulOp(
-            inp1.device,
-            inp1.dtype,
-            dim1,
-            dim2,
-            transpose_a,
-            transpose_b,
-            compute_mode,
-            format,
-            strategy=_Hashable(get_execution_strategy()),
-        )
-        (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
-        return result
-    else:  # dispath to BatchedMatrixMul
-        extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
-            inp1.device,
-            inp1.dtype,
-            dim1,
-            dim2,
-            transpose_a,
-            transpose_b,
-            compute_mode,
-            format,
-            strategy=_Hashable(get_execution_strategy()),
-        )
-        (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
-        return result
+    return _matmul(inp1, inp2, transpose_a, transpose_b, compute_mode, format)


 def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
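One detail worth calling out from the code above: `_get_extentedMatrixMulOp` and `_get_extentedBatchedMatrixMulOp` are memoized with `functools.lru_cache`, and `lru_cache` requires every argument to be hashable, which is presumably why the execution strategy is passed wrapped in `_Hashable`. A minimal sketch of the same pattern, with purely illustrative names (`HashableWrapper` and `build_op` are not MegEngine APIs):

from functools import lru_cache


class HashableWrapper:
    # Stand-in for _Hashable: hash and equality are derived from the wrapped
    # value so the wrapper can act as part of an lru_cache key.
    def __init__(self, value):
        self.value = value

    def __hash__(self):
        return hash(str(self.value))

    def __eq__(self, other):
        return isinstance(other, HashableWrapper) and self.value == other.value


@lru_cache(maxsize=None)
def build_op(dim1, dim2, strategy):
    # Stand-in for _get_extentedMatrixMulOp: the (expensive) subgraph is built
    # once per distinct (dim1, dim2, strategy) combination and then reused.
    return ("subgraph", dim1, dim2, strategy.value)


first = build_op(2, 2, HashableWrapper(1))
second = build_op(2, 2, HashableWrapper(1))
assert first is second  # equal keys hit the cache and reuse the same object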
@@ -46,14 +46,17 @@ def test_literal_arith(is_varnode):


 @pytest.mark.parametrize("is_varnode", [True, False])
-def test_matmul(is_varnode):
+@pytest.mark.parametrize(
+    "shape_a, shape_b", [((4,), (4,)), ((10, 4), (4, 10)), ((3, 10, 4), (3, 4, 10)),],
+)
+def test_matmul(is_varnode, shape_a, shape_b):
     if is_varnode:
         network = Network()
     else:
         network = None

-    A = make_tensor(np.random.rand(5, 7).astype("float32"), network)
-    B = make_tensor(np.random.rand(7, 10).astype("float32"), network)
+    A = make_tensor(np.random.rand(*shape_a).astype("float32"), network)
+    B = make_tensor(np.random.rand(*shape_b).astype("float32"), network)
     C = A @ B
     if is_varnode:
         np.testing.assert_almost_equal(
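The three parametrizations added to `test_matmul` line up with the three dispatch paths (Dot, MatrixMul, BatchedMatrixMul). Assuming NumPy's `matmul` semantics for these shapes, the expected result shapes can be checked with a quick standalone snippet (illustrative only, not part of the test file):

import numpy as np

# (4,) @ (4,)             -> Dot path,              result shape ()
# (10, 4) @ (4, 10)       -> MatrixMul path,        result shape (10, 10)
# (3, 10, 4) @ (3, 4, 10) -> BatchedMatrixMul path, result shape (3, 10, 10)
for shape_a, shape_b in [((4,), (4,)), ((10, 4), (4, 10)), ((3, 10, 4), (3, 4, 10))]:
    a = np.random.rand(*shape_a).astype("float32")
    b = np.random.rand(*shape_b).astype("float32")
    print(shape_a, shape_b, "->", np.matmul(a, b).shape)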