You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

broadcast_util.py 3.7 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2020 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """broadcast_util"""
  17. from functools import reduce
  18. import akg.tvm
  19. from akg.utils.format_transform import get_shape
  20. from akg.ms.utils import NC1HWC0
  21. def need_broadcast(main_data_shape, main_logical_shape, with_shape):
  22. """return False if main_data needn't to do broadcast"""
  23. if not with_shape:
  24. return False
  25. if not main_logical_shape:
  26. return False
  27. with_data_num = reduce(lambda x, y: x * y, with_shape)
  28. if with_data_num == 1:
  29. return False
  30. if main_logical_shape == with_shape:
  31. return False
  32. main_logical_shape_new = main_logical_shape if main_logical_shape else (1,)
  33. # No special broadcast is needed if there is no pad in data
  34. if main_logical_shape_new == main_data_shape:
  35. return False
  36. if len(main_logical_shape) >= len(with_shape):
  37. for i in range(0 - len(with_shape), 0):
  38. if main_logical_shape[i] < with_shape[i]:
  39. return True
  40. return False
  41. return True
  42. def broadcast_by_format(ori_data, logical_shape, format_in, with_shape):
  43. """
  44. Do special broadcast for special formats when padding axis needs to broadcast, such as C in NCHW(NC1HWC0).
  45. Rewrite padding value to broadcast value in special case, for example: op1 * op2, where op1 and op2 are both
  46. NC1HWC0, and their logical 4D shapes are (4, 1, 3, 3) and (4, 4, 3, 3). op1's shape become (4, 1, 3, 3, 16) after
  47. transformation from 4D to NC1HWC0. we need to fill the data of axis C0 with broadcast value but not padding value.
  48. Note:
  49. There is no need to do broadcast for scalar and DefaultFormat(or NHWC) here.
  50. """
  51. ori_data_shape = tuple(get_shape(ori_data))
  52. if not need_broadcast(ori_data_shape, tuple(logical_shape),
  53. tuple(with_shape)):
  54. return ori_data
  55. nchw_shape_len = fracz_shape_len = 4
  56. nc1hwc0_shape_len = 5
  57. logical_shape_new = tuple(logical_shape) if logical_shape else (1,)
  58. data_num = reduce(lambda x, y: x * y, logical_shape_new)
  59. if data_num == 1:
  60. # this is a scalar
  61. if len(ori_data_shape) == fracz_shape_len:
  62. new_data = akg.tvm.compute((1,), lambda i: ori_data[0, 0, 0, i])
  63. elif len(ori_data_shape) == nc1hwc0_shape_len:
  64. new_data = akg.tvm.compute((1,), lambda i: ori_data[0, 0, 0, 0, i])
  65. else:
  66. raise RuntimeError("Unsupported shape {}".format(ori_data_shape))
  67. return new_data
  68. # NC1HWC0
  69. if format_in == NC1HWC0:
  70. if len(with_shape) != nchw_shape_len:
  71. raise ValueError("with_shape must be 4D, while it is {}".format(with_shape))
  72. # rewrite padding value to broadcast value only if C(NCHW, NHWC is not considered) is the broadcast axis
  73. if logical_shape[1] == 1:
  74. new_data = akg.tvm.compute(ori_data_shape, lambda n, c1, h, w, c0: ori_data[n, c1, h, w, 0])
  75. return new_data
  76. return ori_data
  77. raise RuntimeError("Broadcast is unsupported when logical_shape is {}, and format is {}".
  78. format(logical_shape, format_in))