You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_expr_simplify.py 1.9 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import akg
  15. import akg.tvm
  16. def test1():
  17. T1 = akg.tvm.var("T1")
  18. T2 = akg.tvm.var("T2")
  19. H = akg.tvm.var("H")
  20. cc2 = akg.tvm.var("cc2")
  21. e = T1 <= ((((T1*-1)*cc2) + 1) + akg.tvm.expr.Div((((1 + 1) - 3) + H), 2))
  22. print("=========TEST==========")
  23. print("constraints:")
  24. print("Res of infer bound of ", T1)
  25. stmt = akg.tvm.ir_pass.TestReduceInequality(e, cc2)
  26. print(stmt)
  27. print("=========END==========")
  28. def test2():
  29. T1 = akg.tvm.var("T1")
  30. T2 = akg.tvm.var("T2")
  31. H = akg.tvm.var("H")
  32. cc2 = akg.tvm.var("cc2")
  33. e = ((((T1 - 1)*2) + 3) - 1) <= ((H + (1 - 1)) - (cc2*(T1*2)))
  34. print("=========TEST==========")
  35. print("constraints:")
  36. print("Res of infer bound of ", T1)
  37. stmt = akg.tvm.ir_pass.TestReduceInequality(e, cc2)
  38. print(stmt)
  39. print("=========END==========")
  40. def test_simplify():
  41. T1 = akg.tvm.var("T1")
  42. e = akg.tvm.expr.Div(T1, T1)
  43. print("=========TEST==========")
  44. stmt = akg.tvm.ir_pass.TestSimplify(e)
  45. print(stmt)
  46. print("=========END==========")
  47. def test_gcd():
  48. T1 = akg.tvm.var("T1")
  49. print("=========TEST==========")
  50. stmt = akg.tvm.ir_pass.TestGcd(T1, 0)
  51. print(stmt)
  52. print("=========END==========")
  53. if __name__ == "__main__":
  54. test1()
  55. test_simplify()