You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_inferbound.py 3.8 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import akg
  15. import akg.tvm
  def run(cons, e):
      """Print a constraint list, then the bound inferred for ``e`` under it.

      Args:
          cons: list of constraint expressions. Each one is printed on its own
              line, and the list is then handed to the inference pass.
          e: the expression whose bound is requested.

      Calls ``akg.tvm.ir_pass.TestInferBoundWithCond(e, cons)`` and prints its
      result between TEST/END banners. Returns nothing.
      """
      header, footer = "=========TEST==========", "=========END=========="
      print(header)
      print("constraints:")
      for constraint in cons:
          print(constraint)
      print("Res of infer bound of ", e)
      print(akg.tvm.ir_pass.TestInferBoundWithCond(e, cons))
      print(footer)
  def run1(cons, e, var_set):
      """Like ``run``, but also forwards ``var_set`` to the inference pass.

      Args:
          cons: list of constraint expressions; each is printed on its own line.
          e: the expression whose bound is requested.
          var_set: passed unchanged as the third argument of
              ``TestInferBoundWithCond``. NOTE(review): presumably the set of
              variables the bound may be expressed in — confirm against the pass.

      Prints the pass result between TEST/END banners. Returns nothing.
      """
      header, footer = "=========TEST==========", "=========END=========="
      print(header)
      print("constraints:")
      for constraint in cons:
          print(constraint)
      print("Res of infer bound of ", e)
      print(akg.tvm.ir_pass.TestInferBoundWithCond(e, cons, var_set))
      print(footer)
  def test1():
      """Bound ``T1 - 20`` given ``0 < T1 <= 128``.

      Expected output recorded by the original author: (-19, 108), i.e.
      ``1 - 20`` and ``128 - 20``.
      """
      t1 = akg.tvm.var("T1")
      target = akg.tvm.expr.Sub(t1, 20)
      run([t1 > 0, t1 <= 128], target)
  def test_floordiv1():
      """Bound ``(T1 + 15) / 16 * 16`` under a budget on rounded-up sizes.

      Constraints: ``T1 > 0`` and
      ``FloorDiv(T1 + 15, 16) * 16 + FloorDiv(T1 + 7, 8) * 8 <= 3968``.

      Expected output recorded by the original author: (16, 1984). At
      T1 = 1984 the budget is exactly 1984 + 1984 = 3968.
      """
      T1 = akg.tvm.var("T1")
      # NOTE(review): the target uses truncating ``Div`` while the constraint
      # uses ``FloorDiv``. Under T1 > 0 the operands are positive, so both
      # round the same way. Kept as-is to avoid changing what is being tested.
      e = akg.tvm.expr.Div(T1 + 15, 16) * 16
      fd1 = akg.tvm.expr.FloorDiv(T1 + 15, 16) * 16
      fd2 = akg.tvm.expr.FloorDiv(T1 + 7, 8) * 8
      # ``T1 > 0`` used to be appended twice; a single copy is sufficient.
      cons = [T1 > 0, fd1 + fd2 <= 3968]
      run(cons, e)
  def test_maxpool():
      """Bound ``9*b + 144`` under a budget that is linear in ``b``.

      Expected output recorded by the original author: (153, 2119).
      NOTE(review): the budget reduces to ``576*b <= 126400``. With integer
      ``b`` that gives ``b <= 219`` and an upper bound of 2115. The recorded
      2119 equals the real-valued relaxation ``9 * 126400/576 + 144``, so
      confirm which one the pass is meant to produce.
      """
      a = akg.tvm.var("a")
      b = akg.tvm.var("b")
      target = b * 9 + 144
      # ``a`` appears only in its own positivity constraint. The budget uses the
      # literal ``1*3+3`` where test_polynominial uses ``a*3+3``.
      budget = b*384 + (1*3+3)*(b+1)*32 <= 126592
      run([a > 0, b > 0, b <= 1819, budget], target)
  def test_scale_inequality():
      """Bound ``FloorDiv(b - 1, 16384) + 1`` given a budget on ``a`` and ``b < a``.

      The budget is
      ``FloorDiv(a + 15, 16) * 96 + FloorDiv(a + 7, 8) * 64 <= 126968``.
      Neither variable gets an explicit lower bound.

      Expected output recorded by the original author: (0, 1).
      """
      a = akg.tvm.var("a")
      b = akg.tvm.var("b")
      floordiv = akg.tvm.expr.FloorDiv
      target = floordiv((b - 1), 16384) + 1
      # Alternative target from the original, not exercised:
      #   e = akg.tvm.expr.FloorDiv((a - 1), 128) + 1
      a_budget = floordiv(a+15, 16) * 96 + floordiv(a+7, 8) * 64 <= 126968
      run([a_budget, b < a], target)
  def test_polynominial():
      """Bound ``(3a + 3)(b + 1) * 16`` under a matching polynomial budget.

      Constraints: ``a > 0``, ``0 < b <= 1819`` and
      ``384*b + (3a + 3)(b + 1) * 32 <= 126592``.

      Expected output recorded by the original author: (192, 63072). The lower
      value corresponds to a = b = 1.
      """
      a, b = akg.tvm.var("a"), akg.tvm.var("b")
      constraints = [
          a > 0,
          b > 0,
          b <= 1819,
          b*384 + ((a*3)+3)*(b+1)*32 <= 126592,
      ]
      target = (a*3+3)*(b+1)*16
      run(constraints, target)
  def test_conv():
      """Bound ``256*a`` given positive a, b, c and three pairwise product limits.

      Expected output recorded by the original author: (256, 16128).
      ``(a*16)*c <= 1023`` with ``c >= 1`` caps ``a`` at 63, and 256 * 63 = 16128.
      """
      a, b, c = (akg.tvm.var(name) for name in ("a", "b", "c"))
      positivity = [v > 0 for v in (a, b, c)]
      product_limits = [
          (b*16)*a <= 2047,
          (a*16)*c <= 1023,
          (b*16)*c <= 1023,
      ]
      run(positivity + product_limits, 256*a)
  def test_min():
      """Bound a product of two clipped ``Min``/``Max`` terms.

      Builds ``e = ((CI1 * h_term) * w_term) * 16`` where
        ``h_term = (Min(2, H - cc2) + 1) - Max(0, 1 - cc2)`` and
        ``w_term = (Min(29, W - cc3*28) + 1) - Max(0, 1 - cc3*28)``,
      under ``0 <= cc3 < (W - 1) div 28 + 1``, ``0 <= cc2 < H``, ``H > 0``, ``W > 0``.
      NOTE(review): the terms look like clipped window extents along H and W —
      confirm. No expected result was recorded for this case.
      """
      expr = akg.tvm.expr
      CI1 = akg.tvm.var("CI1")
      H = akg.tvm.var("H")
      cc2 = akg.tvm.var("CC2")  # NOTE(review): printed as "CC2", unlike "cc3"
      cc3 = akg.tvm.var("cc3")
      W = akg.tvm.var("W")
      # Same expression tree as the original one-liner, split into named
      # factors. The unused variables ``a`` and ``b`` were dropped.
      h_term = (expr.Min(2, (H - cc2)) + 1) - expr.Max(0, (1 - cc2))
      w_term = (expr.Min(29, (W - (cc3*28))) + 1) - expr.Max(0, (1 - (cc3*28)))
      e = ((CI1*h_term)*w_term)*16
      cond = [
          (cc3 >= 0),
          (cc3 < (akg.tvm.div((W - 1), 28) + 1)),
          (cc2 >= 0),
          (cc2 < H),
          (H > 0),
          (W > 0),
      ]
      run(cond, e)
  def test_poly2():
      """Bound the variable ``a`` itself under a bilinear budget with ``0 < b < 40``.

      The budget is ``480*a*b + 64*a + 4192 <= 131072``. No expected result
      was recorded for this case.
      """
      a = akg.tvm.var("a")
      b = akg.tvm.var("b")
      budget = a*b*480 + a*64 + 4192 <= 131072
      run([budget, b < 40, b > 0], a)
  if __name__ == "__main__":
      # Run every case in the original order; each prints its own TEST/END section.
      for case in (test1, test_floordiv1, test_maxpool, test_scale_inequality,
                   test_polynominial, test_conv, test_min, test_poly2):
          case()