You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

batch_matmul.py 4.4 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """operator dsl function: batch_matmul"""
  15. import numpy as np
  16. import akg.topi as topi
  17. import akg.tvm as tvm
  18. from akg.utils import validation_check as vc_util
  19. def batch_matmul(data1, data2, bias=None, layout1="NHDT", layout2="NHDT", layout_out="NHDT"):
  20. if len(data1.shape) == 4:
  21. res = batch_matmul_4D(data1, data2, bias, layout1, layout2, layout_out)
  22. elif len(data1.shape) == 2:
  23. res = batch_matmul_2D(data1, data2, bias, layout1, layout2, layout_out)
  24. else:
  25. res = batch_matmul_3D(data1, data2, bias, layout1, layout2, layout_out)
  26. return res
  27. def auto_in_transpose(data, layout="NHDT"):
  28. layout_int = layout.replace('N', '0').replace('H', '1').replace('D', '2').replace('T', '3')
  29. layout_list = list(layout_int)
  30. layout_axis = np.argsort(layout_list)
  31. data = topi.transpose(data, axes=tuple(layout_axis))
  32. return data
  33. def auto_out_transpose(expect, layout_out="NHDT"):
  34. if len(expect.shape) == 3:
  35. layout_out = layout_out[1:]
  36. if len(expect.shape) == 2:
  37. layout_out = layout_out[2:]
  38. layout_out_int = layout_out.replace('N', '0').replace('H', '1').replace('D', '2').replace('T', '3')
  39. layout_out_list = list(layout_out_int)
  40. layout_out_axis = np.argsort(layout_out_list)
  41. expect = topi.transpose(expect, axes=tuple(layout_out_axis))
  42. return expect
  43. def batch_matmul_3D(data1, data2, bias, layout1="NHDT", layout2="NHDT", layout_out="NHDT"):
  44. if layout1 != "NHDT":
  45. layout1 = layout1[1:]
  46. data1 = auto_in_transpose(data1, layout1)
  47. if layout2 != "NHDT":
  48. layout2 = layout2[1:]
  49. data2 = auto_in_transpose(data2, layout2)
  50. res = topi.nn.batch_matmul(data1, data2)
  51. if bias is not None:
  52. res = topi.add(res, bias)
  53. if layout_out != "NHDT":
  54. res = auto_out_transpose(res, layout_out)
  55. return res
  56. def batch_matmul_4D(data1, data2, bias, layout1="NHDT", layout2="NHDT", layout_out="NHDT"):
  57. if layout1 != "NHDT":
  58. data1 = auto_in_transpose(data1, layout1)
  59. if layout2 != "NHDT":
  60. data2 = auto_in_transpose(data2, layout2)
  61. b1, b2, m, k = data1.shape
  62. b1, b2, n, k = data2.shape
  63. reduce_axis = tvm.reduce_axis((0, k), name='reduce_axis')
  64. res = tvm.compute((b1, b2, m, n), lambda i_b1, i_b2, i_m, i_n: tvm.sum(data1[i_b1, i_b2, i_m, reduce_axis] *
  65. data2[i_b1, i_b2, i_n, reduce_axis],
  66. axis=reduce_axis), name='matmul_compute')
  67. if bias is not None:
  68. res = topi.add(res, bias)
  69. if layout_out != "NHDT":
  70. res = auto_out_transpose(res, layout_out)
  71. return res
  72. def batch_matmul_2D(data1, data2, bias, layout1="NHDT", layout2="NHDT", layout_out="NHDT"):
  73. if layout1 != "NHDT":
  74. layout1 = layout1[2:]
  75. data1 = auto_in_transpose(data1, layout1)
  76. if layout2 != "NHDT":
  77. layout2 = layout2[2:]
  78. data2 = auto_in_transpose(data2, layout2)
  79. m, k = data1.shape
  80. n, k = data2.shape
  81. reduce_axis = tvm.reduce_axis((0, k), name='reduce_axis')
  82. res = tvm.compute((m, n), lambda i_m, i_n: tvm.sum(data1[i_m, reduce_axis] *
  83. data2[i_n, reduce_axis],
  84. axis=reduce_axis), name='matmul_compute')
  85. if bias is not None:
  86. res = topi.add(res, bias)
  87. if layout_out != "NHDT":
  88. res = auto_out_transpose(res, layout_out)
  89. return res