You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

dsl_create.py 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """dsl create helping function"""
  15. import _akg
  16. from _akg.utils import format_transform as ft_util
  17. class TensorUtils:
  18. """Class for creating tensor."""
  19. CREATE_SCH_ONLY = 'create_sch_only'
  20. @classmethod
  21. def get_tensor_attrs(cls, tensor):
  22. """get tensor attrs."""
  23. tensor_attrs = dict()
  24. if "attrs" in dir(tensor.op):
  25. tensor_attrs = dict(tensor.op.attrs.items())
  26. return tensor_attrs
  27. @classmethod
  28. def update_tensor_attrs(cls, tensor, attrs):
  29. """update tensor attrs."""
  30. tensor_attrs = cls.get_tensor_attrs(tensor)
  31. tensor_attrs.update(attrs)
  32. tensor = _akg.tvm.compute(tensor.shape,
  33. lambda *indice: tensor[indice],
  34. name=tensor.op.name,
  35. tag=tensor.op.tag,
  36. attrs=tensor_attrs)
  37. return tensor
  38. @classmethod
  39. def is_create_sch_only(cls, tensor):
  40. tensor_attrs = cls.get_tensor_attrs(tensor)
  41. if cls.CREATE_SCH_ONLY in tensor_attrs.keys():
  42. return True
  43. return False
  44. @classmethod
  45. def is_output_value(cls, tensor):
  46. """check output value."""
  47. return not cls.is_create_sch_only(tensor)
  48. @classmethod
  49. def inplace_set(cls, input_tensor, output_tensor, buffer_name="data_buf"):
  50. """inplace set."""
  51. input_tensor_shape = ft_util.get_shape(input_tensor)
  52. output_tensor_shape = ft_util.get_shape(output_tensor)
  53. if not input_tensor_shape == output_tensor_shape:
  54. raise RuntimeError("Shape of the input_tensor and the output_tensor should be equal, "
  55. "but got %s and %s"%(input_tensor_shape, output_tensor_shape))
  56. output_tensor = cls.update_tensor_attrs(output_tensor, {cls.CREATE_SCH_ONLY: 1})
  57. data_buf = _akg.tvm.decl_buffer(input_tensor.shape, input_tensor.dtype, name=buffer_name)
  58. binds_info = {input_tensor: data_buf, output_tensor: data_buf}
  59. return output_tensor, binds_info
  60. @classmethod
  61. def inplace_set_tensors(cls, input_tensors, output_tensors, buffer_names=None):
  62. """
  63. inplace set for tensors
  64. Args:
  65. in_tensors (Union[list, tuple]): Origin input tensors.
  66. out_tensors (Union[list, tuple]): Origin output tensors.
  67. buffer_names (Union[list, tuple] or None): Buffer names used to bind.
  68. Return:
  69. inplace_tensors (list): Output tensors with the inplace info.
  70. binds_infos (dict): Dictionary that maps the input tensor and the output
  71. tensor to buffer.
  72. """
  73. if not buffer_names:
  74. buffer_names = ["data_buf_%s" % i for i in range(len(input_tensors))]
  75. for arg in (input_tensors, output_tensors, buffer_names):
  76. if not isinstance(arg, (tuple, list)):
  77. raise RuntimeError("arg must be tuple or list!")
  78. if len(input_tensors) != len(output_tensors) or len(input_tensors) != len(buffer_names):
  79. raise RuntimeError("length of the input_tensors, output_tensors and buffer_names must be equal!")
  80. inplace_tensors = []
  81. binds_infos = dict()
  82. for input_tensor, output_tensor, buffer_name in zip(input_tensors, output_tensors, buffer_names):
  83. inplace_tensor, binds_info = cls.inplace_set(input_tensor, output_tensor, buffer_name)
  84. inplace_tensors.append(inplace_tensor)
  85. binds_infos.update(binds_info)
  86. return inplace_tensors, binds_infos
  87. def produce_shapes(shape1, shape2):
  88. """two input shapes produce three output shape."""
  89. shape1 = list(shape1)
  90. shape2 = list(shape2)
  91. flag = 0
  92. if len(shape1) < len(shape2):
  93. shape1, shape2 = shape2, shape1
  94. flag = 1
  95. output_shape_len = len(shape1)
  96. dec = output_shape_len - len(shape2)
  97. for i in range(dec):
  98. shape2 = [1] + shape2
  99. out_shape = []
  100. for i in range(output_shape_len):
  101. if (shape1[i] != shape2[i]) and (shape1[i] != 1) and (shape2[i] != 1):
  102. raise RuntimeError("input shapes not match!")
  103. out_shape.append(shape1[i] if shape1[i] > shape2[i] else shape2[i])
  104. if flag == 1:
  105. shape1, shape2 = shape2, shape1
  106. return shape1, shape2, out_shape