GitOrigin-RevId: bf6136cc1a
tags/v1.9.0
@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
import collections
from functools import lru_cache
from typing import Union

import numpy as np

@@ -24,8 +25,8 @@ from .utils import (
    astype,
    cast_tensors,
    convert_inputs,
    isscalar,
    make_shape_tuple,
    subgraph,
)

_ElwMod = builtin.Elemwise.Mode
@@ -73,23 +74,292 @@ def _elwise(*args, mode):
    return _elwise_apply(args, mode)
def _matmul(inp1, inp2):

@lru_cache(maxsize=None)
def _get_extentedMatrixMulOp(
    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
    def extentedMatrixMulOp(inputs, f, c):
        assert len(inputs) == 2
        inp1, inp2 = inputs
        _dim1, _dim2 = dim1, dim2

        def build_shape_head(shape, idx=-1):
            # shape[:idx]
            return f(
                builtin.Subtensor(items=[[0, False, True, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def build_shape_tail(shape, idx=-1):
            # shape[idx:]
            return f(
                builtin.Subtensor(items=[[0, True, False, False, False]]),
                shape,
                c(idx, "int32"),
            )

        # promote 1-d operands to matrices; the extra axis is removed at the end
        remove_row, remove_col = False, False
        if _dim1 == 1:
            _dim1 = 2
            remove_row = True
        if _dim2 == 1:
            _dim2 = 2
            remove_col = True
        if remove_row:
            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
        if remove_col:
            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
        shape1 = f(builtin.GetVarShape(), inp1)
        shape2 = f(builtin.GetVarShape(), inp2)
        # fold any leading axes into the row axis so MatrixMul sees 2-d inputs
        if _dim1 > 2:
            inp1 = f(
                builtin.Reshape(),
                inp1,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
                    build_shape_tail(shape1),
                ),
            )
        if _dim2 > 2:
            inp2 = f(
                builtin.Reshape(),
                inp2,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
                    build_shape_tail(shape2),
                ),
            )
        op = builtin.MatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy.value,
        )
        result = f(op, inp1, inp2)
        result_shape = f(builtin.GetVarShape(), result)
        # unfold the leading axes of the result back to the original shape
        if _dim1 > 2:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    build_shape_head(shape1),
                    build_shape_tail(result_shape),
                ),
            )
        if _dim2 > 2:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    build_shape_head(shape2),
                    build_shape_tail(result_shape),
                ),
            )
        maxdim = _dim1 if _dim1 > _dim2 else _dim2
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
        return (result,), (True,)

    return extentedMatrixMulOp
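
# [Editor's sketch, not part of the diff] extentedMatrixMulOp expresses, in
# graph ops, the usual "fold leading axes" trick: collapse all but the last
# axis of the left operand, run a plain 2-d MatrixMul, then unfold the result.
# A minimal NumPy equivalent, assuming a 2-d right operand; the helper name is
# hypothetical:
import numpy as np

def _matmul_nd_2d_sketch(a, b):
    lead = a.shape[:-1]                              # build_shape_head(shape1)
    a2 = a.reshape(int(np.prod(lead)), a.shape[-1])  # Reduce(product) + Reshape
    out = a2 @ b                                     # the plain 2-d MatrixMul
    return out.reshape(*lead, b.shape[-1])           # shape1[:-1] ++ result_shape[-1:]

assert _matmul_nd_2d_sketch(np.ones((3, 10, 4)), np.ones((4, 7))).shape == (3, 10, 7)
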
@lru_cache(maxsize=None)
def _get_extentedBatchedMatrixMulOp(
    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
    def extentedBatchedMatrixMulOp(inputs, f, c):
        assert len(inputs) == 2
        inp1, inp2 = inputs
        _dim1, _dim2 = dim1, dim2

        def build_shape_head(shape, idx=-2):
            # shape[:idx]
            return f(
                builtin.Subtensor(items=[[0, False, True, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def build_shape_tail(shape, idx=-2):
            # shape[idx:]
            return f(
                builtin.Subtensor(items=[[0, True, False, False, False]]),
                shape,
                c(idx, "int32"),
            )

        # promote 1-d operands to matrices; the extra axis is removed at the end
        remove_row, remove_col = False, False
        if _dim1 == 1:
            _dim1 = 2
            remove_row = True
        if _dim2 == 1:
            _dim2 = 2
            remove_col = True
        if remove_row:
            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
        if remove_col:
            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
        shape1 = f(builtin.GetVarShape(), inp1)
        shape2 = f(builtin.GetVarShape(), inp2)
        maxdim = _dim1 if _dim1 > _dim2 else _dim2
        if _dim1 > _dim2:
            # broadcast inp2 over the extra leading (batch) axes of inp1
            shape2 = f(
                builtin.Concat(axis=0, comp_node=device),
                build_shape_head(shape1, idx=-_dim2),  # shape1[:-_dim2]
                shape2,
            )
            inp2 = f(builtin.Broadcast(), inp2, shape2)
            batch_shape = build_shape_head(shape1)
        if _dim2 > _dim1:
            # broadcast inp1 over the extra leading (batch) axes of inp2
            shape1 = f(
                builtin.Concat(axis=0, comp_node=device),
                build_shape_head(shape2, idx=-_dim1),  # shape2[:-_dim1]
                shape1,
            )
            inp1 = f(builtin.Broadcast(), inp1, shape1)
            batch_shape = build_shape_head(shape2)
        if _dim1 == _dim2:
            batch_shape = build_shape_head(shape1)
        # compress inputs to 3d (collapse all batch axes into one)
        if maxdim > 3:
            inp1 = f(
                builtin.Reshape(),
                inp1,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
                    build_shape_tail(shape1),
                ),
            )
            inp2 = f(
                builtin.Reshape(),
                inp2,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
                    build_shape_tail(shape2),
                ),
            )
        op = builtin.BatchedMatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy.value,
        )
        result = f(op, inp1, inp2)
        if maxdim > 3:
            # restore the original batch shape of the result
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    batch_shape,
                    build_shape_tail(f(builtin.GetVarShape(), result)),
                ),
            )
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
        return (result,), (True,)

    return extentedBatchedMatrixMulOp
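
# [Editor's sketch, not part of the diff] The batched variant first broadcasts
# the lower-rank operand over the extra leading (batch) axes of the other, then
# collapses all batch axes into one so a single 3-d BatchedMatrixMul suffices.
# A rough NumPy equivalent; the helper is hypothetical and einsum stands in for
# the kernel:
import numpy as np

def _batched_matmul_sketch(a, b):
    if a.ndim > b.ndim:   # Concat(shape1[:-dim2], shape2) + Broadcast
        b = np.broadcast_to(b, a.shape[: a.ndim - b.ndim] + b.shape)
    elif b.ndim > a.ndim:
        a = np.broadcast_to(a, b.shape[: b.ndim - a.ndim] + a.shape)
    batch = a.shape[:-2]
    n = int(np.prod(batch))
    a3 = a.reshape(n, *a.shape[-2:])   # "compress inputs to 3d"
    b3 = b.reshape(n, *b.shape[-2:])
    out = np.einsum("bij,bjk->bik", a3, b3)
    return out.reshape(*batch, *out.shape[-2:])

assert _batched_matmul_sketch(np.ones((2, 3, 5, 4)), np.ones((3, 4, 6))).shape == (2, 3, 5, 6)
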
class _Hashable:
    def __init__(self, value) -> None:
        self.value = value

    def __hash__(self) -> int:
        return hash(str(self.value))

    def __eq__(self, o: object) -> bool:
        if not isinstance(o, _Hashable):
            return False
        return self.value == o.value
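
# [Editor's note, not part of the diff] _Hashable exists so that values which
# may not be usefully hashable on their own (the Strategy bitmask below is
# wrapped this way before being passed to the lru_cache'd factories) can still
# serve as cache keys: the hash is taken over str(value). A tiny illustration
# with an otherwise-unhashable dict:
@lru_cache(maxsize=None)
def _compile_sketch(config):   # config arrives as a _Hashable
    return object()

assert _compile_sketch(_Hashable({"profile": True})) is _compile_sketch(
    _Hashable({"profile": True})
)  # equal keys hit the same cache entry
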
def _matmul(
    inp1,
    inp2,
    transpose_a=False,
    transpose_b=False,
    compute_mode="default",
    format="default",
):
    if amp._enabled:
        compute_mode = "float32"
        inp1, inp2 = cast_tensors(inp1, inp2)
    else:
        compute_mode = "default"
        dtype = dtype_promotion(inp1, inp2)
        if inp1.dtype != dtype:
            inp1 = inp1.astype(dtype)
        if inp2.dtype != dtype:
            inp2 = inp2.astype(dtype)
    dim1, dim2 = inp1.ndim, inp2.ndim
    assert dim1 > 0 and dim2 > 0
    maxdim = dim1 if dim1 > dim2 else dim2
    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
    op = builtin.MatrixMul(
        transposeA=False, transposeB=False, compute_mode=compute_mode, format="default"
    )
    (result,) = apply(op, inp1, inp2)
    return result
    Strategy = builtin.ops.MatrixMul.Strategy
    strategy = Strategy(0)
    if _config._benchmark_kernel:
        strategy |= Strategy.PROFILE
    else:
        strategy |= Strategy.HEURISTIC
    if _config._deterministic_kernel:
        strategy |= Strategy.REPRODUCIBLE
    if dim1 == 1 and dim2 == 1:  # dispatch to Dot
        (result,) = apply(builtin.Dot(), inp1, inp2)
        return result
    elif maxdim <= 2 or dim2 <= 2:  # dispatch to MatrixMul
        extentedMatrixMulOp = _get_extentedMatrixMulOp(
            inp1.device,
            inp1.dtype,
            dim1,
            dim2,
            transpose_a,
            transpose_b,
            compute_mode,
            format,
            strategy=_Hashable(strategy),
        )
        (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
        return result
    else:  # dispatch to BatchedMatrixMul
        extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
            inp1.device,
            inp1.dtype,
            dim1,
            dim2,
            transpose_a,
            transpose_b,
            compute_mode,
            format,
            strategy=_Hashable(strategy),
        )
        (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
        return result
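
# [Editor's note, not part of the diff] The dispatch rule above, by operand
# rank (shapes are illustrative):
#   (4,)       @ (4,)        -> Dot                        (dim1 == dim2 == 1)
#   (10, 4)    @ (4, 10)     -> extentedMatrixMulOp        (maxdim <= 2)
#   (3, 10, 4) @ (4, 10)     -> extentedMatrixMulOp        (dim2 <= 2)
#   (3, 10, 4) @ (3, 4, 10)  -> extentedBatchedMatrixMulOp (otherwise)
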
def _transpose(data, axes):

@@ -8,24 +8,18 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import math
from functools import lru_cache
from typing import Iterable, Optional, Sequence, Tuple, Union

from ..core import _config
from ..core._imperative_rt.core2 import apply, dtype_promotion
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
from ..core.ops.special import Const
from ..core.tensor import amp
from ..core.tensor.utils import _normalize_axis, cast_tensors, subgraph
from ..jit import exclude_from_trace
from ..core.tensor.array_method import _matmul
from ..core.tensor.utils import _normalize_axis
from ..tensor import Tensor
from ..utils.deprecation import deprecated_kwargs_default
from .debug_param import get_execution_strategy
from .elemwise import clip, minimum
from .tensor import broadcast_to, concat, expand_dims, squeeze
from .elemwise import clip
from .tensor import expand_dims, squeeze

__all__ = [
    "argmax",
@@ -794,229 +788,6 @@ def matinv(inp: Tensor) -> Tensor:
    return result


class _Hashable:
    def __init__(self, value) -> None:
        self.value = value

    def __hash__(self) -> int:
        return hash(str(self.value))

    def __eq__(self, o: object) -> bool:
        if not isinstance(o, _Hashable):
            return False
        return self.value == o.value
@lru_cache(maxsize=None)
def _get_extentedMatrixMulOp(
    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
    def extentedMatrixMulOp(inputs, f, c):
        assert len(inputs) == 2
        inp1, inp2 = inputs
        _dim1, _dim2 = dim1, dim2

        def build_shape_head(shape, idx=-1):
            # shape[:idx]
            return f(
                builtin.Subtensor(items=[[0, False, True, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def build_shape_tail(shape, idx=-1):
            # shape[idx:]
            return f(
                builtin.Subtensor(items=[[0, True, False, False, False]]),
                shape,
                c(idx, "int32"),
            )

        remove_row, remove_col = False, False
        if _dim1 == 1:
            _dim1 = 2
            remove_row = True
        if _dim2 == 1:
            _dim2 = 2
            remove_col = True
        if remove_row:
            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
        if remove_col:
            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
        shape1 = f(GetVarShape(), inp1)
        shape2 = f(GetVarShape(), inp2)
        if _dim1 > 2:
            inp1 = f(
                builtin.Reshape(),
                inp1,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
                    build_shape_tail(shape1),
                ),
            )
        if _dim2 > 2:
            inp2 = f(
                builtin.Reshape(),
                inp2,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
                    build_shape_tail(shape2),
                ),
            )
        op = builtin.MatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy.value,
        )
        result = f(op, inp1, inp2)
        result_shape = f(GetVarShape(), result)
        if _dim1 > 2:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    build_shape_head(shape1),
                    build_shape_tail(result_shape),
                ),
            )
        if _dim2 > 2:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    build_shape_head(shape2),
                    build_shape_tail(result_shape),
                ),
            )
        maxdim = _dim1 if _dim1 > _dim2 else _dim2
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
        return (result,), (True,)

    return extentedMatrixMulOp
@lru_cache(maxsize=None)
def _get_extentedBatchedMatrixMulOp(
    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
    def extentedBatchedMatrixMulOp(inputs, f, c):
        assert len(inputs) == 2
        inp1, inp2 = inputs
        _dim1, _dim2 = dim1, dim2

        def build_shape_head(shape, idx=-2):
            # shape[:idx]
            return f(
                builtin.Subtensor(items=[[0, False, True, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def build_shape_tail(shape, idx=-2):
            # shape[idx:]
            return f(
                builtin.Subtensor(items=[[0, True, False, False, False]]),
                shape,
                c(idx, "int32"),
            )

        remove_row, remove_col = False, False
        if _dim1 == 1:
            _dim1 = 2
            remove_row = True
        if _dim2 == 1:
            _dim2 = 2
            remove_col = True
        if remove_row:
            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
        if remove_col:
            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
        shape1 = f(GetVarShape(), inp1)
        shape2 = f(GetVarShape(), inp2)
        maxdim = _dim1 if _dim1 > _dim2 else _dim2
        if _dim1 > _dim2:
            # broadcast
            shape2 = f(
                builtin.Concat(axis=0, comp_node=device),
                build_shape_head(shape1, idx=-_dim2),  # shape1[:-_dim2]
                shape2,
            )
            inp2 = f(builtin.Broadcast(), inp2, shape2)
            batch_shape = build_shape_head(shape1)
        if _dim2 > _dim1:
            # broadcast
            shape1 = f(
                builtin.Concat(axis=0, comp_node=device),
                build_shape_head(shape2, idx=-_dim1),  # shape2[:-_dim1]
                shape1,
            )
            inp1 = f(builtin.Broadcast(), inp1, shape1)
            batch_shape = build_shape_head(shape2)
        if _dim1 == _dim2:
            batch_shape = build_shape_head(shape1)
        # compress inputs to 3d
        if maxdim > 3:
            inp1 = f(
                builtin.Reshape(),
                inp1,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
                    build_shape_tail(shape1),
                ),
            )
            inp2 = f(
                builtin.Reshape(),
                inp2,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
                    build_shape_tail(shape2),
                ),
            )
        op = builtin.BatchedMatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy.value,
        )
        result = f(op, inp1, inp2)
        if maxdim > 3:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    batch_shape,
                    build_shape_tail(f(GetVarShape(), result)),
                ),
            )
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
        return (result,), (True,)

    return extentedBatchedMatrixMulOp
def matmul(
    inp1: Tensor,
    inp2: Tensor,

@@ -1067,50 +838,7 @@ def matmul(
        [[10. 13.]
         [28. 40.]]
    """
    if amp._enabled:
        compute_mode = "float32"
        inp1, inp2 = cast_tensors(inp1, inp2)
    else:
        dtype = dtype_promotion(inp1, inp2)
        if inp1.dtype != dtype:
            inp1 = inp1.astype(dtype)
        if inp2.dtype != dtype:
            inp2 = inp2.astype(dtype)
    dim1, dim2 = inp1.ndim, inp2.ndim
    assert dim1 > 0 and dim2 > 0
    maxdim = dim1 if dim1 > dim2 else dim2
    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
    if dim1 == 1 and dim2 == 1:  # dispatch to Dot
        return dot(inp1, inp2)
    elif maxdim <= 2 or dim2 <= 2:  # dispatch to MatrixMul
        extentedMatrixMulOp = _get_extentedMatrixMulOp(
            inp1.device,
            inp1.dtype,
            dim1,
            dim2,
            transpose_a,
            transpose_b,
            compute_mode,
            format,
            strategy=_Hashable(get_execution_strategy()),
        )
        (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
        return result
    else:  # dispatch to BatchedMatrixMul
        extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
            inp1.device,
            inp1.dtype,
            dim1,
            dim2,
            transpose_a,
            transpose_b,
            compute_mode,
            format,
            strategy=_Hashable(get_execution_strategy()),
        )
        (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
        return result
    return _matmul(inp1, inp2, transpose_a, transpose_b, compute_mode, format)
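
# [Editor's note, not part of the diff] Per the hunk header above (50 old
# lines -> 7 new), matmul now delegates to _matmul, so functional.matmul and
# the @ operator share one dispatch path. Illustrative usage, assuming a
# standard MegEngine install:
#
#     import numpy as np
#     import megengine as mge
#     import megengine.functional as F
#
#     a = mge.tensor(np.random.rand(3, 10, 4).astype("float32"))
#     b = mge.tensor(np.random.rand(4, 10).astype("float32"))
#     np.testing.assert_allclose(F.matmul(a, b).numpy(), (a @ b).numpy(), rtol=1e-6)
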
def dot(inp1: Tensor, inp2: Tensor) -> Tensor:

@@ -46,14 +46,17 @@ def test_literal_arith(is_varnode):
@pytest.mark.parametrize("is_varnode", [True, False])
def test_matmul(is_varnode):
@pytest.mark.parametrize(
    "shape_a, shape_b", [((4,), (4,)), ((10, 4), (4, 10)), ((3, 10, 4), (3, 4, 10)),],
)
def test_matmul(is_varnode, shape_a, shape_b):
    if is_varnode:
        network = Network()
    else:
        network = None

    A = make_tensor(np.random.rand(5, 7).astype("float32"), network)
    B = make_tensor(np.random.rand(7, 10).astype("float32"), network)
    A = make_tensor(np.random.rand(*shape_a).astype("float32"), network)
    B = make_tensor(np.random.rand(*shape_b).astype("float32"), network)
    C = A @ B
    if is_varnode:
        np.testing.assert_almost_equal(