You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_utils.py 4.4 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """test_utils"""
  15. import math
  16. import random
  17. import numpy as np
  18. import akg.tvm
  19. from akg.utils.validation_check import MAX_DATA_SIZE
  20. from akg.utils.format_transform import get_bytes
  21. def compute_blockdim(shape):
  22. size = 1
  23. if isinstance(shape, (list, tuple)):
  24. for i in shape:
  25. size = size * i
  26. elif isinstance(shape, int):
  27. size = shape
  28. else:
  29. size = 2
  30. return min(32, math.ceil(size / 8192 + 1))
  31. def process_dynamic_shape(shapes, attrs, keep_axis=None):
  32. dynamic_shape_args = []
  33. if len(shapes) == 0 or not attrs.get("dynamic"):
  34. return shapes, dynamic_shape_args
  35. new_shapes = []
  36. prefix = "I"
  37. keep_axis_local = keep_axis
  38. if isinstance(keep_axis, int):
  39. keep_axis_local = [keep_axis]
  40. for shape in shapes:
  41. dynamic_shape = []
  42. for i in range(len(shape)):
  43. if (i in keep_axis_local) or ((i - len(shape)) in keep_axis_local):
  44. dynamic_shape.append(shape[i])
  45. else:
  46. dynamic_shape.append(akg.tvm.var(prefix + str(i)))
  47. dynamic_shape_args.append(shape[i])
  48. new_shapes.append(dynamic_shape)
  49. prefix += "I"
  50. return new_shapes, dynamic_shape_args
  51. def gen_random_shape(shape_dim, slope=0, min_value=None, max_value=None):
  52. """
  53. Generate a list of random integer with length equals shape_dim within range [min_value, max_value];
  54. Args:
  55. shape_dim : length of output random shape
  56. slope : only represents the tendency of random shape's value, not mathematical slope of random shape;
  57. slope = -1 tend to generate random shape list with largest value at the beginning and smallest value at the end
  58. slope = 0 tend to generate random shape list with nearly average value among list
  59. slope = 1 tend to generate random shape list with smallest value at the beginning and largest value at the end
  60. """
  61. if shape_dim <= 0:
  62. raise ValueError("Shape dim should be positive.")
  63. def _build_limit(limit, default):
  64. if limit is None:
  65. limit = default
  66. res = list()
  67. nonlocal shape_dim
  68. if isinstance(limit, (tuple, list)):
  69. if len(limit) != shape_dim:
  70. raise ValueError(
  71. "Min/Max value should have same length with shape_dim")
  72. res = limit
  73. elif isinstance(limit, int):
  74. res = [limit] * shape_dim
  75. else:
  76. raise TypeError(
  77. "Min/Max value should be int or list of int with same length of shape_dim")
  78. return res
  79. device_limit = MAX_DATA_SIZE // get_bytes("float32")
  80. if max_value is None and shape_dim > 1:
  81. limit_avg = int(math.pow(device_limit, 1 / shape_dim))
  82. if slope == 0:
  83. max_value = [limit_avg] * shape_dim
  84. else:
  85. ratio = np.arange(-1/2, 1/2 + 1/shape_dim, 1/shape_dim)
  86. if len(ratio) > shape_dim:
  87. new_ratio = list()
  88. for i, r in enumerate(ratio):
  89. if i == len(ratio)//2 - 1:
  90. new_ratio.append(0)
  91. elif i == len(ratio)//2:
  92. continue
  93. else:
  94. new_ratio.append(r)
  95. ratio = new_ratio
  96. if slope == -1:
  97. ratio.reverse()
  98. max_value = list()
  99. for i, r in enumerate(ratio):
  100. max_value.append(int((1 + ratio[i]) * limit_avg))
  101. shape_min = _build_limit(min_value, 1)
  102. shape_extent = _build_limit(max_value, device_limit)
  103. random_shape = list()
  104. for mn, mx in zip(shape_min, shape_extent):
  105. random_shape.append(random.randint(mn, mx))
  106. return random_shape