|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """generate json desc for BatchMatMul and MatMul"""
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
- from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
- from ._utils import Expander, ExpanderInfoValidator as VLD
-
-
@VLD.check_attrs('transpose_a', 'transpose_b', 'left_format', 'right_format')
class MatMul(Expander):
    """
    MatMul expander.

    Rewrites a MatMul whose reduction axis `k` has length 1 on both operands as an
    element-wise `Mul` (optionally followed by a `Cast`). Any other MatMul is rejected
    by raising `GKException` from `_expand`.
    """

    def __init__(self, expand_info):
        """
        Cache the attributes and input shapes needed by the expansion.

        Args:
            expand_info: operator description forwarded unchanged to the base `Expander`.
        """
        super(MatMul, self).__init__(expand_info)
        self.transpose_a = self.attrs['transpose_a']
        self.transpose_b = self.attrs['transpose_b']
        self.left_format = self.attrs['left_format']
        self.right_format = self.attrs['right_format']
        # At construction time the inputs are indexed as mappings carrying a 'shape' entry.
        self.shape_a = self.inputs[0]['shape']
        self.shape_b = self.inputs[1]['shape']

    def _optimize_to_mul(self):
        """
        Check if matmul can be replaced by mul.

        Returns:
            bool, True only when the processor is 'aicore', both operands use the default
            format, and the reduction axis of both operands has length 1.
        """
        if self.processor != 'aicore' or self.left_format != DF.DEFAULT or self.right_format != DF.DEFAULT:
            return False
        # The reduction axis k is the last axis of A and the second-to-last axis of B;
        # a transpose flag swaps which of the two trailing axes holds it.
        k_a = self.shape_a[-2] if self.transpose_a else self.shape_a[-1]
        k_b = self.shape_b[-1] if self.transpose_b else self.shape_b[-2]
        return k_a == 1 and k_b == 1

    def _check(self):
        """
        Validate the number of inputs.

        Raises:
            GKException: if fewer than two inputs are present.
        """
        input_num = len(self.inputs)
        if input_num < 2:
            raise GKException("MatMul inputs number should be bigger than 1, but got {}.".format(input_num))

    def _expand(self, graph_builder):
        """
        Emit the `Mul`-based replacement graph.

        Args:
            graph_builder: builder whose `emit` method creates the replacement operators.

        Returns:
            The result of the emitted `Mul`, or of the trailing `Cast` when one is needed.

        Raises:
            GKException: if the MatMul does not qualify for the `Mul` replacement.
        """
        def transpose(shape):
            """Return a copy of `shape` with its last two axes swapped."""
            trans_shape = list(shape)
            trans_shape[-2] = shape[-1]
            trans_shape[-1] = shape[-2]
            return trans_shape

        if not self._optimize_to_mul():
            raise GKException("MatMul/BatchMatMul do not need to be replaced by Mul")
        # Matmul is replaced by Mul([b m k], [b k n]) when k==1
        input_a = self.inputs[0]
        input_b = self.inputs[1]
        # With k == 1 one of the two swapped axes has length 1, so the transpose moves no
        # data and is emitted as a Reshape to the swapped shape.
        if self.transpose_a:
            shape_a_trans = transpose(self.shape_a)
            input_a = graph_builder.emit('Reshape', [input_a], attrs={'shape': shape_a_trans})
        if self.transpose_b:
            shape_b_trans = transpose(self.shape_b)
            input_b = graph_builder.emit('Reshape', [input_b], attrs={'shape': shape_b_trans})
        result = graph_builder.emit('Mul', [input_a, input_b])
        # NOTE(review): inputs are indexed as mappings in __init__ but read through `.dtype`
        # here; presumably the base class converts them to tensor objects before `_expand`
        # runs -- confirm against Expander.
        if 'dst_type' in self.attrs and self.inputs[0].dtype != self.attrs['dst_type']:
            result = graph_builder.emit('Cast', [result], attrs={'dst_type': self.attrs['dst_type']})
        return result
-
-
class BatchMatMul(MatMul):
    """
    BatchMatMul expander.

    Adds no behaviour of its own: everything is inherited from `MatMul`.
    """
|