You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

__init__.py 3.2 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2019-2021 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """
  17. Additional IR Pass for CCE
  18. """
  19. from __future__ import absolute_import as _abs
  20. import sys
  21. import os
  22. import logging
  def AKGAddPath():
      """Move this file's directory to the front of ``sys.path``.

      The directory is resolved with ``os.path.realpath``. If it is already
      listed in ``sys.path`` its first occurrence is removed before the
      directory is re-inserted at index 0, so the call never adds a second
      copy of an entry that was already present.
      """
      package_dir = os.path.realpath(os.path.dirname(os.path.realpath(__file__)))
      try:
          # Drop the existing entry (first occurrence only) so it can be
          # promoted to the front instead of being duplicated.
          sys.path.remove(package_dir)
      except ValueError:
          # Not on the path yet: nothing to remove.
          pass
      sys.path.insert(0, package_dir)
  class AKGMetaPathFinder:
      """Meta path finder that redirects ``akg.tvm*`` and ``akg.topi*`` imports.

      For a matching module name the leading ``"akg."`` (4 characters) is
      stripped and an ``AKGMetaPathLoader`` for the remaining name is returned,
      e.g. ``akg.tvm.expr`` is served by importing ``tvm.expr``. Every other
      name is declined by returning ``None`` so the regular import machinery
      handles it.
      """

      # Name prefixes (plain string prefixes, not package boundaries) that are
      # redirected to their un-prefixed counterpart.
      _REDIRECTED_PREFIXES = ("akg.tvm", "akg.topi")
      _DEPRECATED_MODULE = "akg.topi.cce.cce_extended_op_build"

      def find_module(self, fullname, path=None):
          """Legacy finder hook: return a loader for redirected names.

          Args:
              fullname: Fully qualified name of the module being imported.
              path: Unused; accepted for compatibility with the finder protocol.

          Returns:
              An ``AKGMetaPathLoader`` when ``fullname`` starts with
              ``"akg.tvm"`` or ``"akg.topi"``, otherwise ``None``.
          """
          # The deprecation notice must be checked before the prefix test:
          # the deprecated name itself starts with "akg.topi", so checking it
          # afterwards (as before) made the warning unreachable.
          if fullname == self._DEPRECATED_MODULE:
              logging.warning("akg error: 'akg.topi.cce.cce_extended_op_build' has been deprecated, please using "
                              "'akg.topi.cce.te_op_build' instead ")
          if fullname.startswith(self._REDIRECTED_PREFIXES):
              # Strip the leading "akg." to obtain the real module name.
              return AKGMetaPathLoader(fullname[4:])
          return None

      def find_spec(self, fullname, path=None, target=None):
          """Modern finder hook built on top of ``find_module``.

          Newer interpreters only consult ``find_spec`` on meta path finders,
          so without it the redirection would silently stop working there.
          The spec is built exactly the way the import system's former legacy
          fallback did: ``spec_from_loader(fullname, loader)``.

          Args:
              fullname: Fully qualified name of the module being imported.
              path: Forwarded to ``find_module`` (unused there).
              target: Unused; accepted for compatibility with the protocol.

          Returns:
              A module spec whose loader is the ``AKGMetaPathLoader`` returned
              by ``find_module``, or ``None`` when the name is not redirected.
          """
          # Local import keeps this block self-contained without touching the
          # file-level import list.
          import importlib.util

          loader = self.find_module(fullname, path)
          if loader is None:
              return None
          return importlib.util.spec_from_loader(fullname, loader)
  class AKGMetaPathLoader:
      """Loader that publishes a real module under an alias name."""

      def __init__(self, rname):
          # Name of the real module that load_module() imports.
          self.__rname = rname

      def load_module(self, fullname):
          """Import the real module and register it in ``sys.modules`` as ``fullname``.

          Args:
              fullname: Alias name under which the module was requested.

          Returns:
              The module object found in ``sys.modules`` under the real name
              after the import; the same object is stored under ``fullname``.
          """
          real_name = self.__rname
          # Forget any cached copy so the import below runs again, after
          # AKGAddPath() has adjusted sys.path.
          sys.modules.pop(real_name, None)
          AKGAddPath()
          __import__(real_name, globals(), locals())
          module = sys.modules[real_name]
          self.__target_module = module
          sys.modules[fullname] = module
          return module
  def schedule(sch, target='cuda'):
      """Build a decorator that bundles an operator's result with schedule info.

      Args:
          sch: Schedule object stored unchanged under the ``'schedule'`` key.
          target: Target name stored under the ``'target'`` key
              (default ``'cuda'``).

      Returns:
          A decorator. The decorated function, when called, returns a dict with
          the keys ``'schedule'``, ``'target'``, ``'output'``, ``'binds'`` and
          ``'op_name'``:

          * If the wrapped function returns a tuple, every ``dict`` element is
            removed from it; ``'output'`` is the tuple of remaining elements.
            ``'binds'`` is the ``'binds'`` value of the last dict element that
            has that key, or ``None`` when there is none.
          * Any other return value is passed through as ``'output'`` with
            ``'binds'`` set to ``None``.
          * ``'op_name'`` is the wrapped function's ``__name__``.
      """
      # Local import keeps this block self-contained without touching the
      # file-level import list.
      import functools

      def decorator(func):
          # functools.wraps keeps the operator's __name__/__doc__ on the wrapper
          # so introspection and stacked decorators see the original function.
          @functools.wraps(func)
          def wrapper(*args, **kwargs):
              output = func(*args, **kwargs)
              binds = None
              if isinstance(output, tuple):
                  for item in output:
                      # Later dicts override earlier ones, matching the
                      # original "last one wins" behavior.
                      if isinstance(item, dict) and "binds" in item:
                          binds = item['binds']
                  output = tuple(item for item in output if not isinstance(item, dict))
              return {
                  'schedule': sch,
                  'target': target,
                  'output': output,
                  'binds': binds,
                  'op_name': func.__name__,
              }
          return wrapper
      return decorator
  # Register the finder at the front of sys.meta_path before the package
  # imports below run, so it is consulted first for every import from here on.
  # NOTE(review): presumably the submodules imported below rely on the
  # akg.tvm / akg.topi redirection - confirm before reordering these lines.
  72. sys.meta_path.insert(0, AKGMetaPathFinder())
  73. from . import autodiff
  74. from .build_module import build, build_to_func, lower, build_config
  75. from .autodiff import differentiate
  76. from .autodiff import get_variables
  77. from .autodiff import register_variables
  78. from . import lang
  79. from .utils.dump_cuda_meta import dump_cuda_meta
  80. from .utils.dump_ascend_meta import tvm_callback_cce_postproc
  # Only `differentiate` is exported by a star-import of this package; the
  # other names imported above are still reachable as package attributes.
  81. __all__ = ["differentiate"]