|
- #!/usr/bin/env python3
- # coding: utf-8
- # Copyright 2019-2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """
- Additional IR Pass for CCE
- """
- from __future__ import absolute_import as _abs
- import sys
- import os
- import logging
-
def AKGAddPath():
    """Move the directory that contains this file to the front of ``sys.path``.

    The (symlink-resolved) directory of this module is placed at index 0 of
    ``sys.path``. If the entry is already present, its first occurrence is
    removed before re-inserting, so repeated calls do not add more copies.
    """
    pkg_dir = os.path.realpath(os.path.dirname(os.path.realpath(__file__)))
    # Both original branches end by inserting at index 0; only the
    # "already present" case needs the old entry taken out first.
    if pkg_dir in sys.path:
        sys.path.remove(pkg_dir)
    sys.path.insert(0, pkg_dir)
-
-
class AKGMetaPathFinder:
    """Meta path finder that redirects ``akg.tvm*`` and ``akg.topi*`` imports.

    A requested name starting with ``akg.tvm`` or ``akg.topi`` is served by an
    ``AKGMetaPathLoader`` for the same dotted name with the leading ``akg.``
    removed (e.g. ``akg.tvm.ir`` -> ``tvm.ir``). Every other name is declined
    by returning ``None``, leaving it to the remaining ``sys.meta_path`` entries.

    NOTE(review): Python 3.12 removed the import system's fallback to
    ``find_module`` for meta path finders; on 3.12+ this finder is skipped
    unless a ``find_spec`` method is added. Confirm the supported Python range.
    """

    def find_module(self, fullname, path=None):
        """Return a loader for ``fullname`` if it lives under akg.tvm / akg.topi.

        Args:
            fullname: fully qualified name of the module being imported.
            path: the parent package's ``__path__`` (unused).

        Returns:
            An ``AKGMetaPathLoader`` for ``fullname`` minus its ``akg.`` prefix,
            or ``None`` when this finder does not handle the name.
        """
        # The deprecated module also starts with "akg.topi", so this check has
        # to come before the prefix test below; previously it came after and
        # the warning could never be emitted. The import itself is still routed
        # through the redirect below, exactly as before.
        if fullname == "akg.topi.cce.cce_extended_op_build":
            logging.warning("akg error: 'akg.topi.cce.cce_extended_op_build' has been deprecated, please using "
                            "'akg.topi.cce.te_op_build' instead ")
        if fullname.startswith(("akg.tvm", "akg.topi")):
            # fullname[4:] drops the 4-character "akg." prefix.
            return AKGMetaPathLoader(fullname[4:])
        return None
-
-
class AKGMetaPathLoader:
    """Loader that imports a module under its real name and aliases it.

    ``rname`` is the module name that is actually imported. ``AKGMetaPathFinder``
    passes the requested name with its ``akg.`` prefix stripped. The imported
    module object is then also registered in ``sys.modules`` under the
    requested name.
    """

    def __init__(self, rname):
        # Real module name to import, e.g. "tvm.ir" when "akg.tvm.ir" was requested.
        self.__rname = rname

    def load_module(self, fullname):
        """Import the real module afresh and register it as ``fullname``.

        Any cached ``sys.modules`` entry for the real name is discarded first,
        so ``__import__`` does not just return the cached object. ``AKGAddPath``
        runs right before the import so this file's directory is at the front
        of ``sys.path``.

        Args:
            fullname: the name the import statement asked for.

        Returns:
            The module object now stored under both the real name and ``fullname``.
        """
        real_name = self.__rname
        sys.modules.pop(real_name, None)
        AKGAddPath()
        __import__(real_name, globals(), locals())
        module = sys.modules[real_name]
        self.__target_module = module
        sys.modules[fullname] = module
        return module
-
def schedule(sch, target='cuda'):
    """Decorator factory that bundles an op's output with build metadata.

    The decorated function's return value is wrapped in a dict together with
    ``sch`` and ``target``. If the function returns a tuple, every ``dict``
    element is removed from it. A ``"binds"`` entry found in one of those dicts
    is reported under the ``'binds'`` key; if several dicts carry one, the
    last wins.

    Args:
        sch: value stored under the ``'schedule'`` key of the result.
        target: value stored under the ``'target'`` key (default ``'cuda'``).

    Returns:
        A decorator. The decorated function returns a dict with the keys
        ``'schedule'``, ``'target'``, ``'output'``, ``'binds'`` (``None`` when
        no binds were supplied) and ``'op_name'`` (the wrapped function's
        ``__name__``).
    """
    import functools

    def decorator(func):
        # Preserve the op's __name__/__doc__ on the wrapper, which the bare
        # wrapper previously lost.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            binds = None
            output = func(*args, **kwargs)
            if isinstance(output, tuple):
                # Dict elements carry attributes, not outputs: pick up "binds",
                # then drop them from the returned output tuple.
                for attr in output:
                    if isinstance(attr, dict) and "binds" in attr:
                        binds = attr['binds']
                output = tuple(t for t in output if not isinstance(t, dict))
            return {'schedule': sch, 'target': target, 'output': output,
                    'binds': binds, 'op_name': func.__name__}
        return wrapper
    return decorator
-
- sys.meta_path.insert(0, AKGMetaPathFinder())
-
- from . import autodiff
- from .build_module import build, build_to_func, lower, build_config
- from .autodiff import differentiate
- from .autodiff import get_variables
- from .autodiff import register_variables
- from . import lang
- from .utils.dump_cuda_meta import dump_cuda_meta
- from .utils.dump_ascend_meta import tvm_callback_cce_postproc
-
# Names exported by a star-import of this package. The other names imported
# into this module (build, lower, lang, get_variables, ...) remain reachable
# as attributes but are not star-exported.
__all__ = ["differentiate"]
|