You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dynamic_shape.py 4.5 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2020 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """dynamic shape function"""
  17. import akg
  18. import akg.tvm
  19. from akg.utils.format_transform import get_shape
# Node type name handed to akg.tvm.make.node by create_dynamic_shape_node.
# NOTE(review): presumably this must match a node type registered on the
# C++ side -- confirm before renaming.
NODE_TYPE = "DynamicShapeNode"
  21. def to_expanded_list(data):
  22. data_list = []
  23. if isinstance(data, (list, tuple)):
  24. for i in data:
  25. tmp_list = to_expanded_list(i)
  26. for ii in tmp_list:
  27. data_list.append(ii)
  28. else:
  29. data_list.append(data)
  30. return data_list
  31. def shape_is_dynamic(data):
  32. data_list = to_expanded_list(data)
  33. for i in data_list:
  34. shape = get_shape(i)
  35. if False in [isinstance(s, (int, akg.tvm.expr.IntImm)) for s in shape]:
  36. return True
  37. return False
  38. def preprocess_position(position):
  39. """check position's value is valid and turn integer position into list"""
  40. if isinstance(position, (list, tuple)):
  41. for p in position:
  42. if not isinstance(p, int):
  43. raise TypeError("Position of tensor should be a integer")
  44. elif isinstance(position, int):
  45. position = [position]
  46. else:
  47. raise TypeError(
  48. "Position of tensor should be a integer, list or a tuple")
  49. return position
  50. def preprocess_value_with_position(values, position):
  51. """check value is valid and compatible with position, and turn integer into list"""
  52. if isinstance(values, (list, tuple)):
  53. if len(values) != len(position):
  54. raise ValueError(
  55. "Length of values is not compatible with position.")
  56. for l in values:
  57. if not isinstance(l, int):
  58. raise TypeError(
  59. "Dynamic shape values of tensor should be a integer or a list/tuple of integer")
  60. elif isinstance(values, int):
  61. values = [values]
  62. else:
  63. raise TypeError(
  64. "Dynamic shape values of tensor should be a integer or a list/tuple of integer")
  65. return values
  66. def set_poly_upper_bound_for_tensor(tensor, upper_bound, position=None):
  67. """api for dsl to set poly upper bound for certain tensor."""
  68. if not isinstance(tensor, akg.tvm.tensor.Tensor):
  69. raise TypeError("Tensor should be tvm.tensor.Tensor")
  70. if position is None:
  71. position = [i for i, _ in enumerate(tensor.shape)]
  72. position = preprocess_position(position)
  73. upper_bound = preprocess_value_with_position(upper_bound, position)
  74. tensor_shape = get_shape(tensor)
  75. ret = list()
  76. for i, p in enumerate(position):
  77. # create limit for var will help poly to determine the upper bound
  78. if isinstance(tensor_shape[p], akg.tvm.expr.Var):
  79. ret.append(create_dynamic_shape_node(
  80. tensor_name=tensor_shape[p].name, pos=p, poly_upper_bound=upper_bound[i]))
  81. return ret
  82. def set_dynamic_shape_limit_for_tensor(tensor, limit, position=None):
  83. """api for dsl to set dynamic shape limit for certain tensor."""
  84. if not isinstance(tensor, akg.tvm.tensor.Tensor):
  85. raise TypeError("Tensor should be tvm.tensor.Tensor")
  86. if position is None:
  87. position = [i for i, _ in enumerate(tensor.shape)]
  88. position = preprocess_position(position)
  89. limit = preprocess_value_with_position(limit, position)
  90. tensor_name = tensor.op.name
  91. ret = list()
  92. for i, p in enumerate(position):
  93. # create limit for tensor in position p will help inferbound to determine the max bound
  94. ret.append(create_dynamic_shape_node(
  95. tensor_name=tensor_name, pos=p, dyn_shape_limit=limit[i]))
  96. return ret
  97. def create_dynamic_shape_node(tensor_name, pos, dyn_shape_limit=-1, poly_upper_bound=-1):
  98. return akg.tvm.make.node(NODE_TYPE,
  99. tensor_name=tensor_name,
  100. pos=pos,
  101. dyn_shape_limit=dyn_shape_limit,
  102. poly_upper_bound=poly_upper_bound)