You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

ir_parser.py 2.1 kB

5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import akg.tvm
  15. from akg.tvm.hybrid import script
  16. from akg.backend import build_module
@script
def test_001(x):
    ''' TEST_CASE_01

    Hybrid-script kernel whose lowered IR is round-tripped through the
    text parser by test(). In IR pseudo-code (out = y, in = x; the outer
    bound 16 is x.shape[0] for the 16x16 input that test() builds) the
    loop nest is:

    for (i, 0, 16) {
      for (k, 0, 1) {
        if (k > 0) {
          for (j, 0, 1) {
            out(i, k) = 0
          }
        }
        out(i, i) = in(i, 0) * in(0, i)
      }
    }

    Args:
        x: input tensor, indexed with two subscripts. x[0, i] with
            i < x.shape[0] presumably requires x.shape[1] >= x.shape[0]
            -- confirm before using non-square inputs.

    Returns:
        y: output tensor with x's shape and dtype. Only the diagonal
        y[i, i] is assigned; other elements are left unwritten.
    '''
    # output_tensor is not imported in this file; presumably it is a
    # hybrid-script intrinsic resolved by the @script parser -- confirm.
    y = output_tensor(x.shape, x.dtype)
    for i in range(x.shape[0]):
        for k in range(1):
            # k only ever takes the value 0, so this branch (and the unused
            # j loop) never executes. It appears to exist only so the IR
            # contains an IfThenElse and a nested For for the parser to
            # handle. The loop-variable names appear verbatim in the
            # expected IR text (ans_001), so do not rename them.
            if k > 0:
                for j in range(1):
                    y[i, k] = 0
            y[i, i] = x[i, 0] * x[0, i]
    return y
# Expected text of test_001's statement after the IR-parser round-trip in
# test(). test() compares the printed statement against this string with
# exact equality, so every character (spacing, parentheses such as
# `if ((k > 0))`, and the trailing newline) must match the printer output.
# The output tensor appears under the kernel's name `test_001`, and the
# input appears as `input`, the placeholder name set in test().
# NOTE(review): this listing appears to have lost leading whitespace; if the
# IR printer indents nested scopes, the continuation lines below need the
# same indentation -- confirm against the original file.
ans_001 = '\
realize test_001<float16>([0, 16], [0, 16]) {\n\
produce test_001 {\n\
for (i, 0, 16) {\n\
for (k, 0, 1) {\n\
if ((k > 0)) {\n\
for (j, 0, 1) {\n\
test_001(i, k) = 0\n\
}\n\
}\n\
test_001(i, i) = (input(i, 0)*input(0, i))\n\
}\n\
}\n\
}\n\
}\n'
  54. def test(func, ans):
  55. shape = (16, 16)
  56. dtype = 'float16'
  57. x = akg.tvm.placeholder(shape, name='input', dtype=dtype)
  58. res = func(x)
  59. s = akg.tvm.create_schedule(res.op)
  60. bounds = akg.tvm.schedule.InferBound(s)
  61. stmt = akg.tvm.schedule.ScheduleOps(s, bounds)
  62. print('---------------BEFORE------------------')
  63. print(stmt)
  64. binds, _ = build_module.get_binds([x, res])
  65. stmt = akg.tvm.ParseHalideIRFromCode(str(stmt), binds)
  66. print('---------------AFTER-------------------')
  67. print(stmt)
  68. assert(str(stmt) == ans)
  69. if __name__ == "__main__":
  70. test(test_001, ans_001)