You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

matmul.py 3.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ===========================================================================
  15. """generate json desc for BatchMatMul and MatMul"""
  16. from mindspore._extends.graph_kernel.model.model import DataFormat as DF
  17. from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
  18. from ._utils import Expander, ExpanderInfoValidator as VLD
  19. @VLD.check_attrs('transpose_a', 'transpose_b', 'left_format', 'right_format')
  20. class MatMul(Expander):
  21. """
  22. MatMul expander
  23. """
  24. def __init__(self, expand_info):
  25. super(MatMul, self).__init__(expand_info)
  26. self.transpose_a = self.attrs['transpose_a']
  27. self.transpose_b = self.attrs['transpose_b']
  28. self.left_format = self.attrs['left_format']
  29. self.right_format = self.attrs['right_format']
  30. self.shape_a = self.inputs[0]['shape']
  31. self.shape_b = self.inputs[1]['shape']
  32. def _optimize_to_mul(self):
  33. """check if matmul can be replace by mul"""
  34. if self.processor != 'aicore' or self.left_format != DF.DEFAULT or self.right_format != DF.DEFAULT:
  35. return False
  36. k_a = self.shape_a[-2] if self.transpose_a else self.shape_a[-1]
  37. k_b = self.shape_b[-1] if self.transpose_b else self.shape_b[-2]
  38. if k_a != 1 or k_b != 1:
  39. return False
  40. return True
  41. def _check(self):
  42. input_num = len(self.inputs)
  43. if input_num < 2:
  44. raise GKException("matul inputs number should bigger than 1, but got {}.".format(input_num))
  45. def _expand(self, graph_builder):
  46. def transpose(shape):
  47. trans_shape = list(shape)
  48. trans_shape[-2] = shape[-1]
  49. trans_shape[-1] = shape[-2]
  50. return trans_shape
  51. if not self._optimize_to_mul():
  52. raise GKException("MatMul/BatchMatMul do not need to be replaced by Mul")
  53. # Matmul is replaced by Mul([b m k], [b k n]) when k==1
  54. input_a = self.inputs[0]
  55. input_b = self.inputs[1]
  56. if self.transpose_a:
  57. shape_a_trans = transpose(self.shape_a)
  58. input_a = graph_builder.emit('Reshape', [input_a], attrs={'shape': shape_a_trans})
  59. if self.transpose_b:
  60. shape_b_trans = transpose(self.shape_b)
  61. input_b = graph_builder.emit('Reshape', [input_b], attrs={'shape': shape_b_trans})
  62. result = graph_builder.emit('Mul', [input_a, input_b])
  63. if 'dst_type' in self.attrs and self.inputs[0].dtype != self.attrs['dst_type']:
  64. result = graph_builder.emit('Cast', [result], attrs={'dst_type': self.attrs['dst_type']})
  65. return result
  66. class BatchMatMul(MatMul):
  67. """BatchMatMul expander"""