You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

message.py 5.8 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # Copyright 2019-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """message"""
  15. import importlib.util
  16. import json
  17. import json.decoder as jd
  18. import logging
  19. import traceback
  20. import os
  21. from pathlib import Path
  22. import akg.tvm
  23. from akg.utils import kernel_exec as utils
  24. from akg.utils import validation_check as vc_util
  25. from akg import composite
  26. from . import cce
  27. from . import gpu
  28. from . import op_build
  29. from akg.global_configs import get_dump_ir_flag
  30. from akg.global_configs import get_dump_code_flag
  31. @vc_util.check_input_type(str)
  32. def _compilewithjson_to_module(json_str):
  33. """compile with json."""
  34. try:
  35. kernel_info = json.loads(json_str)
  36. except jd.JSONDecodeError:
  37. logging.error(traceback.format_exc())
  38. return False
  39. supported_processors = ['cuda', 'aicore']
  40. processor = 'cuda'
  41. if 'process' in kernel_info:
  42. processor = kernel_info['process']
  43. if processor not in supported_processors:
  44. logging.error("supported processors: {}, current processor: {}".format(supported_processors, processor))
  45. return False
  46. if 'composite' in kernel_info and kernel_info['composite'] is True:
  47. try:
  48. composite.build(json_str)
  49. return True
  50. except Exception:
  51. logging.error(traceback.format_exc())
  52. return False
  53. op_name = kernel_info['name']
  54. op_func = None
  55. # get custom ops implementation first.
  56. if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
  57. impl_path = os.path.realpath(kernel_info['impl_path'])
  58. if os.path.isfile(impl_path):
  59. custom_mod_name = Path(impl_path).resolve().stem
  60. mod_spec = importlib.util.spec_from_file_location(
  61. custom_mod_name, impl_path)
  62. custom_mod = importlib.util.module_from_spec(mod_spec)
  63. mod_spec.loader.exec_module(custom_mod)
  64. op_func = getattr(custom_mod, op_name, None)
  65. # get built-in ops.
  66. if op_func is None:
  67. if processor == 'cuda':
  68. op_func = getattr(gpu, op_name, None)
  69. if op_func is not None:
  70. input_shapes = []
  71. input_types = []
  72. for input_desc in kernel_info['input_desc']:
  73. input_shapes.append(input_desc[0]['shape'])
  74. input_types.append(input_desc[0]['data_type'])
  75. op_attrs = []
  76. if kernel_info['attr']:
  77. for ext_arg in kernel_info['attr']:
  78. op_attrs.append(ext_arg['value'])
  79. dump_ir = os.getenv(get_dump_ir_flag()) == "on"
  80. dump_code = os.getenv(get_dump_code_flag()) == "on"
  81. utils.op_build(op_func, input_shapes, input_types, op_attrs, kernel_info['op'], dump_ir=dump_ir,
  82. dump_code=dump_code)
  83. return True
  84. else:
  85. op_func = getattr(cce, op_name, None)
  86. if op_func is None:
  87. logging.error(
  88. "this op not support by akg, please check op name %s", str(op_name))
  89. return False
  90. args = {}
  91. tsr = []
  92. for input_desc in kernel_info['input_desc']:
  93. if len(input_desc) == 1:
  94. tensor_shape = input_desc[0]['shape']
  95. tensor_shape = (1,) if not tensor_shape else tensor_shape
  96. vc_util.shape_dtype_max_size_check(
  97. tensor_shape, input_desc[0]['data_type'])
  98. args[input_desc[0]['name']] = akg.tvm.placeholder(
  99. shape=tensor_shape, name=input_desc[0]['tensor_name'], dtype=input_desc[0]['data_type'])
  100. tsr.append(args[input_desc[0]['name']])
  101. else:
  102. tmp_input = []
  103. for tmp_desc in input_desc:
  104. tensor_shape = tmp_desc['shape']
  105. tensor_shape = (1,) if not tensor_shape else tensor_shape
  106. vc_util.shape_dtype_max_size_check(
  107. tensor_shape, tmp_desc['data_type'])
  108. tmp_input.append(akg.tvm.placeholder(
  109. shape=tensor_shape, name=tmp_desc['tensor_name'], dtype=tmp_desc['data_type']))
  110. args[input_desc[0]['name']] = tmp_input
  111. tsr = tsr + tmp_input
  112. if kernel_info['attr']:
  113. for ext_arg in kernel_info['attr']:
  114. args[ext_arg['name']] = ext_arg['value']
  115. output = op_func(**args)
  116. schedule_func = None
  117. attrs = {}
  118. if isinstance(output, (list, tuple)):
  119. from inspect import isfunction
  120. tmp_outputs = []
  121. for elem in output:
  122. if isfunction(elem):
  123. schedule_func = elem
  124. elif isinstance(elem, dict):
  125. for key, value in elem.items():
  126. if key not in attrs or not attrs[key]:
  127. attrs[key] = value
  128. else:
  129. tmp_outputs.append(elem)
  130. output = tmp_outputs
  131. else:
  132. output = [output]
  133. tsr = tsr + [i for i in output if utils.TensorUtils.is_output_value(i)]
  134. build_res = op_build([op_name], output, tsr, schedule_func, processor, kernel_info['op'], attrs)
  135. if not build_res:
  136. return False
  137. return True
  138. def compilewithjson(json_str):
  139. return _compilewithjson_to_module(json_str)