You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

bias_add.py 3.5 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2019 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """operator dsl function: bias_add"""
  17. import akg
  18. from akg.ops.array.reshape import reshape
  19. from akg.utils import validation_check as vc_util
  20. from akg.utils.format_transform import get_shape
  21. @vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, str)
  22. def bias_add(data1, data2, data_format):
  23. """
  24. Adds bias data2 to input tensor data1.
  25. Args:
  26. data1 (tvm.tensor.Tensor): Tensor of type float16, float32.
  27. data2 (tvm.tensor.Tensor): The bias tensor, should be of same type as data1.
  28. If shape(data2) != shape(data1), broadcast will happen.
  29. data_format (str): Data format of input tensors, could be NC1HWC0, NHWC or DefaultFormat.
  30. Returns:
  31. tvm.tensor.Tensor of same shape and type as data1.
  32. """
  33. vc_util.check_shape(data1.shape)
  34. vc_util.check_shape(data2.shape)
  35. shape1 = get_shape(data1)
  36. shape2 = get_shape(data2)
  37. vc_util.davinci_format_check(shape1, data_format)
  38. vc_util.ops_dtype_check([data1.dtype, data2.dtype], vc_util.DtypeForDavinci.ALL_FLOAT)
  39. if data_format == 'NC1HWC0':
  40. data2_new = akg.lang.cce.broadcast(data2, shape1)
  41. res = akg.lang.cce.vadd(data1, data2_new)
  42. else:
  43. if len(shape2) != 1:
  44. raise RuntimeError("data2 should be a 1D Tensor!")
  45. if data_format == "NHWC":
  46. if len(shape1) != 4:
  47. raise RuntimeError("bias_add only support 4D shape when data format is NHWC!")
  48. c_dim_len = shape1[3]
  49. if c_dim_len != shape2[0]:
  50. raise ValueError("The size of bias should be equal to the channel dimension, "
  51. " while the size of bias is {0} and the channel dimension is "
  52. "{1}".format(shape2[0], c_dim_len))
  53. data2_reshaped, _ = reshape(data2, [1, 1, 1, shape2[0]])
  54. elif data_format == "DefaultFormat":
  55. if len(shape1) != 2 and len(shape1) != 4:
  56. raise RuntimeError("bias_add only support 2D and 4D shape when data format is DefaultFormat!")
  57. c_dim_len = shape1[1]
  58. if c_dim_len != shape2[0]:
  59. raise ValueError("The size of bias should be equal to the channel dimension, "
  60. " while the size of bias is {0} and the channel dimension is "
  61. "{1}".format(shape2[0], c_dim_len))
  62. if len(shape1) == 2:
  63. data2_reshaped, _ = reshape(data2, [1, shape2[0]])
  64. else:
  65. # NCHW
  66. data2_reshaped, _ = reshape(data2, [1, shape2[0], 1, 1])
  67. data2_new = akg.lang.cce.broadcast(data2_reshaped, shape1)
  68. res = akg.lang.cce.vadd(data1, data2_new)
  69. akg.register_variables("reshape_diff", [data2], data2_reshaped)
  70. return res