You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dump_cuda_meta.py 3.2 kB

5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2020 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """save gpu param"""
  17. import os
  18. import fcntl
  19. import hashlib
  20. import akg.tvm
  21. from akg.global_configs import get_cuda_meta_path
  22. @akg.tvm.register_func
  23. def dump_cuda_meta(code, ptx, thread_info):
  24. """
  25. Function for dumping cuda meta.
  26. Args:
  27. code: gpu code.
  28. ptx: ptx code.
  29. thread_info: thread info, written to json file.
  30. """
  31. # kernel name
  32. kernel_name = code.split("_kernel")[0].split(" ")[-1]
  33. # sha256 of ptx
  34. sha256 = hashlib.sha256()
  35. sha256.update(ptx.encode("utf-8"))
  36. hash_str = sha256.hexdigest()
  37. # thread info
  38. thread_info_dict = {
  39. "blockIdx.x": "1",
  40. "blockIdx.y": "1",
  41. "blockIdx.z": "1",
  42. "threadIdx.x": "1",
  43. "threadIdx.y": "1",
  44. "threadIdx.z": "1"
  45. }
  46. for thread_tag in thread_info_dict.keys():
  47. if thread_tag in thread_info:
  48. if isinstance(thread_info[thread_tag], int):
  49. thread_info_dict[thread_tag] = str(thread_info[thread_tag])
  50. elif isinstance(thread_info[thread_tag], akg.tvm.expr.IntImm):
  51. thread_info_dict[thread_tag] = str(thread_info[thread_tag].value)
  52. meta_path = get_cuda_meta_path()
  53. cuda_path = os.path.realpath(meta_path)
  54. if not os.path.isdir(cuda_path):
  55. os.makedirs(cuda_path)
  56. # save ptx file to cuda meta
  57. ptx_file = os.path.realpath(meta_path + kernel_name + ".ptx")
  58. if os.path.exists(ptx_file):
  59. os.remove(ptx_file)
  60. with open(ptx_file, "at") as f:
  61. fcntl.flock(f.fileno(), fcntl.LOCK_EX)
  62. f.seek(0, 2)
  63. if f.tell() == 0:
  64. f.write(ptx)
  65. # modify the file permisson to 400
  66. os.chmod(ptx_file, 0o400)
  67. # save json file to cuda meta
  68. json_file = os.path.realpath(meta_path + kernel_name + ".json")
  69. if os.path.exists(json_file):
  70. os.remove(json_file)
  71. with os.fdopen(os.open(json_file, os.O_WRONLY | os.O_CREAT, 0o400), 'w') as fo:
  72. fo.write("{\n")
  73. fo.write('"kernelName" : ' + '"' + kernel_name + "_kernel0" + '",\n')
  74. fo.write('"blockIdx.x" : ' + thread_info_dict["blockIdx.x"] + ',\n')
  75. fo.write('"blockIdx.y" : ' + thread_info_dict["blockIdx.y"] + ',\n')
  76. fo.write('"blockIdx.z" : ' + thread_info_dict["blockIdx.z"] + ',\n')
  77. fo.write('"threadIdx.x" : ' + thread_info_dict["threadIdx.x"] + ',\n')
  78. fo.write('"threadIdx.y" : ' + thread_info_dict["threadIdx.y"] + ',\n')
  79. fo.write('"threadIdx.z" : ' + thread_info_dict["threadIdx.z"] + ',\n')
  80. fo.write('"sha256" : ' + '"' + hash_str + '"\n')
  81. fo.write("}\n")