# Copyright (c) Alibaba, Inc. and its affiliates.
import ast
import contextlib
import hashlib
import importlib
import os
import os.path as osp
import time
import traceback
from functools import reduce
from pathlib import Path
from typing import Generator, Union

import gast
import json

from modelscope import __version__
from modelscope.fileio.file import LocalStorage
from modelscope.metainfo import (Datasets, Heads, Hooks, LR_Schedulers,
                                 Metrics, Models, Optimizers, Pipelines,
                                 Preprocessors, TaskModels, Trainers)
from modelscope.utils.constant import Fields, Tasks
from modelscope.utils.file_utils import get_default_cache_dir
from modelscope.utils.logger import get_logger
from modelscope.utils.registry import default_group

logger = get_logger()
storage = LocalStorage()
p = Path(__file__)

# get the path of package 'modelscope'
MODELSCOPE_PATH = p.resolve().parents[1]
REGISTER_MODULE = 'register_module'
IGNORED_PACKAGES = ['modelscope', '.']
SCAN_SUB_FOLDERS = [
    'models', 'metrics', 'pipelines', 'preprocessors', 'trainers',
    'msdatasets'
]
INDEXER_FILE = 'ast_indexer'
DECORATOR_KEY = 'decorators'
EXPRESS_KEY = 'express'
FROM_IMPORT_KEY = 'from_imports'
IMPORT_KEY = 'imports'
FILE_NAME_KEY = 'filepath'
VERSION_KEY = 'version'
MD5_KEY = 'md5'
INDEX_KEY = 'index'
REQUIREMENT_KEY = 'requirements'
MODULE_KEY = 'module'
CLASS_NAME = 'class_name'
GROUP_KEY = 'group_key'
MODULE_NAME = 'module_name'
MODULE_CLS = 'module_cls'
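
# The index built below maps tuple keys of the form
# (REGISTRY_NAME, group_key, module_name) to the file that registers the
# component. A sketch of one entry (values illustrative, shapes taken from
# the load_index() docstring further down):
#
#   ('MODELS', 'nlp', 'bert') -> {
#       'filepath': '/abs/path/to/modelscope/models/nlp/bert.py',
#       'imports': ['os', 'torch'],
#       'module': 'modelscope.models.nlp.bert',
#   }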


class AstScaning(object):

    def __init__(self) -> None:
        self.result_import = dict()
        self.result_from_import = dict()
        self.result_decorator = []
        self.result_express = []

    def _is_sub_node(self, node: object) -> bool:
        return isinstance(node,
                          ast.AST) and not isinstance(node, ast.expr_context)

    def _is_leaf(self, node: ast.AST) -> bool:
        for field in node._fields:
            attr = getattr(node, field)
            if self._is_sub_node(attr):
                return False
            elif isinstance(attr, (list, tuple)):
                for val in attr:
                    if self._is_sub_node(val):
                        return False
        return True

    def _fields(self, n: ast.AST, show_offsets: bool = True) -> tuple:
        if show_offsets:
            return n._attributes + n._fields
        else:
            return n._fields

    def _leaf(self, node: ast.AST, show_offsets: bool = True) -> tuple:
        output = dict()
        local_print = list()
        if isinstance(node, ast.AST):
            local_dict = dict()
            for field in self._fields(node, show_offsets=show_offsets):
                field_output, field_prints = self._leaf(
                    getattr(node, field), show_offsets=show_offsets)
                local_dict[field] = field_output
                local_print.append('{}={}'.format(field, field_prints))
            prints = '{}({})'.format(
                type(node).__name__,
                ', '.join(local_print),
            )
            output[type(node).__name__] = local_dict
            return output, prints
        elif isinstance(node, list):
            # lists reaching this branch contain no AST sub-nodes
            if '_fields' not in node:
                return node, repr(node)
            for item in node:
                item_output, item_prints = self._leaf(
                    item, show_offsets=show_offsets)
                local_print.append(item_prints)
            return node, '[{}]'.format(', '.join(local_print))
        else:
            return node, repr(node)

    def _refresh(self):
        self.result_import = dict()
        self.result_from_import = dict()
        self.result_decorator = []
        self.result_express = []

    def scan_ast(self, node: Union[ast.AST, None, str]):
        self._refresh()
        self.scan_import(node, indent=' ', show_offsets=False)

    def scan_import(
        self,
        node: Union[ast.AST, None, str],
        indent: Union[str, int] = '    ',
        show_offsets: bool = True,
        _indent: int = 0,
        parent_node_name: str = '',
    ) -> tuple:
        if node is None:
            return node, repr(node)
        elif self._is_leaf(node):
            return self._leaf(node, show_offsets=show_offsets)
        else:
            if isinstance(indent, int):
                indent_s = indent * ' '
            else:
                indent_s = indent

            class state:
                indent = _indent

            @contextlib.contextmanager
            def indented() -> Generator[None, None, None]:
                state.indent += 1
                yield
                state.indent -= 1

            def indentstr() -> str:
                return state.indent * indent_s

            def _scan_import(el: Union[ast.AST, None, str],
                             _indent: int = 0,
                             parent_node_name: str = '') -> str:
                return self.scan_import(
                    el,
                    indent=indent,
                    show_offsets=show_offsets,
                    _indent=_indent,
                    parent_node_name=parent_node_name)

            out = type(node).__name__ + '(\n'
            outputs = dict()

            # rewrite relative imports ('from . import x') into an explicit
            # dotted module path before the fields are recorded
            if type(node).__name__ == 'ImportFrom':
                level = getattr(node, 'level')
                if level >= 1:
                    path_level = ''.join(['.'] * level)
                    setattr(node, 'level', 0)
                    module_name = getattr(node, 'module')
                    if module_name is None:
                        setattr(node, 'module', path_level)
                    else:
                        setattr(node, 'module', path_level + module_name)

            with indented():
                for field in self._fields(node, show_offsets=show_offsets):
                    attr = getattr(node, field)
                    if attr == []:
                        representation = '[]'
                        outputs[field] = []
                    elif (isinstance(attr, list) and len(attr) == 1
                          and isinstance(attr[0], ast.AST)
                          and self._is_leaf(attr[0])):
                        local_out, local_print = _scan_import(attr[0])
                        representation = f'[{local_print}]'
                        outputs[field] = local_out
                    elif isinstance(attr, list):
                        representation = '[\n'
                        el_dict = dict()
                        with indented():
                            for el in attr:
                                local_out, local_print = _scan_import(
                                    el, state.indent,
                                    type(el).__name__)
                                representation += '{}{},\n'.format(
                                    indentstr(),
                                    local_print,
                                )
                                name = type(el).__name__
                                if (name == 'Import' or name == 'ImportFrom'
                                        or parent_node_name == 'ImportFrom'
                                        or parent_node_name == 'Import'):
                                    if name not in el_dict:
                                        el_dict[name] = []
                                    el_dict[name].append(local_out)
                        representation += indentstr() + ']'
                        outputs[field] = el_dict
                    elif isinstance(attr, ast.AST):
                        output, representation = _scan_import(
                            attr, state.indent)
                        outputs[field] = output
                    else:
                        representation = repr(attr)
                        outputs[field] = attr

                    if (type(node).__name__ == 'Import'
                            or type(node).__name__ == 'ImportFrom'):
                        if type(node).__name__ == 'ImportFrom':
                            if field == 'module':
                                self.result_from_import[
                                    outputs[field]] = dict()
                            if field == 'names':
                                if isinstance(outputs[field]['alias'], list):
                                    item_name = []
                                    for item in outputs[field]['alias']:
                                        local_name = item['alias']['name']
                                        item_name.append(local_name)
                                    self.result_from_import[
                                        outputs['module']] = item_name
                                else:
                                    local_name = outputs[field]['alias'][
                                        'name']
                                    self.result_from_import[
                                        outputs['module']] = [local_name]

                        if type(node).__name__ == 'Import':
                            final_dict = outputs[field]['alias']
                            if isinstance(final_dict, list):
                                for item in final_dict:
                                    self.result_import[item['alias'][
                                        'name']] = item['alias']
                            else:
                                self.result_import[outputs[field]['alias']
                                                   ['name']] = final_dict

                    if field == 'decorator_list' and attr != []:
                        for item in attr:
                            setattr(item, CLASS_NAME, node.name)
                        self.result_decorator.extend(attr)

                    if (attr != [] and type(attr).__name__ == 'Call'
                            and parent_node_name == 'Expr'):
                        self.result_express.append(attr)

                    out += f'{indentstr()}{field}={representation},\n'

            out += indentstr() + ')'
            return {
                IMPORT_KEY: self.result_import,
                FROM_IMPORT_KEY: self.result_from_import,
                DECORATOR_KEY: self.result_decorator,
                EXPRESS_KEY: self.result_express
            }, out

    def _parse_decorator(self, node: ast.AST) -> tuple:

        def _get_attribute_item(node: ast.AST) -> tuple:
            value, id, attr = None, None, None
            if type(node).__name__ == 'Attribute':
                value = getattr(node, 'value')
                id = getattr(value, 'id')
                attr = getattr(node, 'attr')
            if type(node).__name__ == 'Name':
                id = getattr(node, 'id')
            return id, attr

        def _get_args_name(nodes: list) -> list:
            result = []
            for node in nodes:
                if type(node).__name__ == 'Str':
                    result.append((node.s, None))
                else:
                    result.append(_get_attribute_item(node))
            return result

        def _get_keyword_name(nodes: list) -> list:
            result = []
            for node in nodes:
                if type(node).__name__ == 'keyword':
                    attribute_node = getattr(node, 'value')
                    if type(attribute_node).__name__ == 'Str':
                        result.append((getattr(node, 'arg'),
                                       attribute_node.s, None))
                    elif type(attribute_node).__name__ == 'Constant':
                        result.append((getattr(node, 'arg'),
                                       attribute_node.value, None))
                    else:
                        result.append((getattr(node, 'arg'), )
                                      + _get_attribute_item(attribute_node))
            return result

        functions = _get_attribute_item(node.func)
        args_list = _get_args_name(node.args)
        keyword_list = _get_keyword_name(node.keywords)
        return functions, args_list, keyword_list

    def _get_registry_value(self, key_item):
        if key_item is None:
            return None
        if key_item == 'default_group':
            return default_group
        split_list = key_item.split('.')
        # in this case the key_item is raw data, not a registered attribute
        if len(split_list) == 1:
            return key_item
        else:
            return getattr(eval(split_list[0]), split_list[1])

    def _registry_indexer(self, parsed_input: tuple,
                          class_name: str) -> tuple:
        """format registry information to a tuple indexer

        Returns:
            tuple: (MODELS, Tasks.text-classification, Models.structbert)
        """
        functions, args_list, keyword_list = parsed_input

        # ignore decorators other than register_module
        if REGISTER_MODULE != functions[1]:
            return None
        output = [functions[0]]

        if len(args_list) == 0 and len(keyword_list) == 0:
            args_list.append(default_group)

        if len(keyword_list) == 0 and len(args_list) == 1:
            args_list.append(class_name)

        if len(keyword_list) > 0 and len(args_list) == 0:
            remove_group_item = None
            for item in keyword_list:
                key, name, attr = item
                if key == GROUP_KEY:
                    args_list.append((name, attr))
                    remove_group_item = item
            if remove_group_item is not None:
                keyword_list.remove(remove_group_item)

        if len(args_list) == 0:
            args_list.append(default_group)

        for item in keyword_list:
            key, name, attr = item
            if key == MODULE_CLS:
                class_name = name
            else:
                args_list.append((name, attr))

        for item in args_list:
            # the case of empty input
            if item is None:
                output.append(None)
            # the case of (default_group)
            elif item[1] is None:
                output.append(item[0])
            elif isinstance(item, str):
                output.append(item)
            else:
                output.append('.'.join(item))
        return (output[0], self._get_registry_value(output[1]),
                self._get_registry_value(output[2]))

    def parse_decorators(self, nodes: list) -> list:
        """parse the AST nodes of decorators object to registry indexer

        Args:
            nodes (list): list of AST decorator nodes

        Returns:
            list: list of registry indexers
        """
        results = []
        for node in nodes:
            if type(node).__name__ != 'Call':
                continue
            class_name = getattr(node, CLASS_NAME, None)
            func = getattr(node, 'func')
            if getattr(func, 'attr', None) != REGISTER_MODULE:
                continue
            parse_output = self._parse_decorator(node)
            index = self._registry_indexer(parse_output, class_name)
            if index is not None:
                results.append(index)
        return results

    def generate_ast(self, file):
        self._refresh()
        with open(file, 'r', encoding='utf8') as code:
            data = code.readlines()
        data = ''.join(data)
        node = gast.parse(data)
        output, _ = self.scan_import(node, indent=' ', show_offsets=False)
        output[DECORATOR_KEY] = self.parse_decorators(output[DECORATOR_KEY])
        output[EXPRESS_KEY] = self.parse_decorators(output[EXPRESS_KEY])
        output[DECORATOR_KEY].extend(output[EXPRESS_KEY])
        return output
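
# A minimal usage sketch for AstScaning; the file name and decorator below
# are hypothetical, assuming the scanned file registers a component via
# `@MODELS.register_module(Tasks.text_classification, module_name=...)`:
#
#   scanner = AstScaning()
#   output = scanner.generate_ast('path/to/some_model.py')
#   output[IMPORT_KEY]       # plain imports, e.g. {'os': ..., 'torch': ...}
#   output[FROM_IMPORT_KEY]  # e.g. {'modelscope.metainfo': ['Models'], ...}
#   output[DECORATOR_KEY]    # registry tuples, e.g. [('MODELS', <group>, <name>)]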


class FilesAstScaning(object):

    def __init__(self) -> None:
        self.astScaner = AstScaning()
        self.file_dirs = []

    def _parse_import_path(self,
                           import_package: str,
                           current_path: str = None) -> str:
        """
        Args:
            import_package (str): relative import or abs import
            current_path (str): path/to/current/file
        """
        if import_package.startswith(IGNORED_PACKAGES[0]):
            return str(MODELSCOPE_PATH) + '/' + '/'.join(
                import_package.split('.')[1:]) + '.py'
        elif import_package.startswith(IGNORED_PACKAGES[1]):
            current_path_list = current_path.split('/')
            import_package_list = import_package.split('.')
            level = 0
            for index, item in enumerate(import_package_list):
                if item != '':
                    level = index
                    break
            abs_path_list = current_path_list[0:-level]
            abs_path_list.extend(import_package_list[index:])
            return '/' + '/'.join(abs_path_list) + '.py'
        else:
            return current_path

    def _traversal_import(
        self,
        import_abs_path,
    ):
        pass

    def parse_import(self, scan_result: dict) -> list:
        """parse import and from import dicts to a third party package list

        Args:
            scan_result (dict): including the import and from import result

        Returns:
            list: a list of packages, with 'modelscope' and relative-path
                imports ignored
        """
        output = []
        output.extend(list(scan_result[IMPORT_KEY].keys()))
        output.extend(list(scan_result[FROM_IMPORT_KEY].keys()))

        # keep only the top-level package name
        for index, item in enumerate(output):
            if '' == item.split('.')[0]:
                output[index] = '.'
            else:
                output[index] = item.split('.')[0]

        ignored = set()
        for item in output:
            for ignored_package in IGNORED_PACKAGES:
                if item.startswith(ignored_package):
                    ignored.add(item)
        return list(set(output) - set(ignored))

    def traversal_files(self, path, check_sub_dir):
        self.file_dirs = []
        if check_sub_dir is None or len(check_sub_dir) == 0:
            self._traversal_files(path)
            return

        for item in check_sub_dir:
            sub_dir = os.path.join(path, item)
            if os.path.isdir(sub_dir):
                self._traversal_files(sub_dir)

    def _traversal_files(self, path):
        dir_list = os.scandir(path)
        for item in dir_list:
            if item.name.startswith('__'):
                continue
            if item.is_dir():
                self._traversal_files(item.path)
            elif item.is_file() and item.name.endswith('.py'):
                self.file_dirs.append(item.path)

    def _get_single_file_scan_result(self, file):
        try:
            output = self.astScaner.generate_ast(file)
        except Exception as e:
            detail = traceback.extract_tb(e.__traceback__)
            raise Exception(
                f'During ast indexing, error is in the file {detail[-1].filename}'
                f' line: {detail[-1].lineno}: "{detail[-1].line}" with error msg: '
                f'"{type(e).__name__}: {e}"')

        import_list = self.parse_import(output)
        return output[DECORATOR_KEY], import_list

    def _inverted_index(self, forward_index):
        inverted_index = dict()
        for index in forward_index:
            for item in forward_index[index][DECORATOR_KEY]:
                inverted_index[item] = {
                    FILE_NAME_KEY: index,
                    IMPORT_KEY: forward_index[index][IMPORT_KEY],
                    MODULE_KEY: forward_index[index][MODULE_KEY],
                }
        return inverted_index

    def _module_import(self, forward_index):
        module_import = dict()
        for index, value_dict in forward_index.items():
            module_import[value_dict[MODULE_KEY]] = value_dict[IMPORT_KEY]
        return module_import

    def _ignore_useless_keys(self, inverted_index):
        if ('OPTIMIZERS', 'default', 'name') in inverted_index:
            del inverted_index[('OPTIMIZERS', 'default', 'name')]
        if ('LR_SCHEDULER', 'default', 'name') in inverted_index:
            del inverted_index[('LR_SCHEDULER', 'default', 'name')]
        return inverted_index

    def get_files_scan_results(self,
                               target_dir=MODELSCOPE_PATH,
                               target_folders=SCAN_SUB_FOLDERS):
        """the entry method of the ast scan method

        Args:
            target_dir (str, optional): the absolute path of the target
                directory to be scanned. Defaults to MODELSCOPE_PATH.
            target_folders (list, optional): the list of sub-folders to be
                scanned in the target folder. Defaults to SCAN_SUB_FOLDERS.

        Returns:
            dict: indexer of registry
        """
        self.traversal_files(target_dir, target_folders)
        start = time.time()
        logger.info(
            f'AST-Scanning the path "{target_dir}" with the following '
            f'sub folders {target_folders}')

        result = dict()
        for file in self.file_dirs:
            filepath = file[file.rfind('modelscope'):]
            module_name = filepath.replace(osp.sep, '.').replace('.py', '')
            decorator_list, import_list = self._get_single_file_scan_result(
                file)
            result[file] = {
                DECORATOR_KEY: decorator_list,
                IMPORT_KEY: import_list,
                MODULE_KEY: module_name
            }
        inverted_index_with_results = self._inverted_index(result)
        inverted_index_with_results = self._ignore_useless_keys(
            inverted_index_with_results)
        module_import = self._module_import(result)
        index = {
            INDEX_KEY: inverted_index_with_results,
            REQUIREMENT_KEY: module_import
        }
        logger.info(
            f'Scanning done! A total of {len(inverted_index_with_results)}'
            f' modules indexed! Time consumed {time.time()-start}s')
        return index

    def files_mtime_md5(self,
                        target_path=MODELSCOPE_PATH,
                        target_subfolder=SCAN_SUB_FOLDERS):
        self.file_dirs = []
        self.traversal_files(target_path, target_subfolder)
        files_mtime = []
        for item in self.file_dirs:
            files_mtime.append(os.path.getmtime(item))
        result_str = reduce(lambda x, y: str(x) + str(y), files_mtime, '')
        md5 = hashlib.md5(result_str.encode())
        return md5.hexdigest()
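
# A hedged usage sketch for FilesAstScaning; with no arguments it scans the
# installed modelscope package itself (see the defaults above):
#
#   scanner = FilesAstScaning()
#   index = scanner.get_files_scan_results()
#   index[INDEX_KEY]        # {(registry, group, module): {filepath, imports, module}}
#   index[REQUIREMENT_KEY]  # {'modelscope.models....': ['torch', ...], ...}
#   scanner.files_mtime_md5()  # md5 over file mtimes, used to detect changes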


file_scanner = FilesAstScaning()


def _save_index(index, file_path):
    # convert tuple keys to str keys so that the index is JSON serializable
    index[INDEX_KEY] = {str(k): v for k, v in index[INDEX_KEY].items()}
    index[VERSION_KEY] = __version__
    index[MD5_KEY] = file_scanner.files_mtime_md5()
    json_index = json.dumps(index)
    storage.write(json_index.encode(), file_path)
    # restore the tuple keys so the in-memory index stays usable
    index[INDEX_KEY] = {
        ast.literal_eval(k): v
        for k, v in index[INDEX_KEY].items()
    }


def _load_index(file_path):
    bytes_index = storage.read(file_path)
    wrapped_index = json.loads(bytes_index)
    # convert str keys back to tuple keys
    wrapped_index[INDEX_KEY] = {
        ast.literal_eval(k): v
        for k, v in wrapped_index[INDEX_KEY].items()
    }
    return wrapped_index
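
# The JSON round trip above works because tuple keys are serialized with
# str() and restored with ast.literal_eval(); a small sketch of the
# invariant these two helpers rely on:
#
#   key = ('MODELS', 'nlp', 'bert')
#   assert ast.literal_eval(str(key)) == key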


def load_index(force_rebuild=False):
    """get the index from scan results or cache

    Args:
        force_rebuild: If True, rebuild the index regardless of the cache.

    Returns:
        dict: the index information for all registered modules, with the
        keys: index, requirements, version and md5, for example:

        {
            'index': {
                ('MODELS', 'nlp', 'bert'): {
                    'filepath': 'path/to/the/registered/model',
                    'imports': ['os', 'torch', 'typing'],
                    'module': 'modelscope.models.nlp.bert'
                },
                ...
            },
            'requirements': {
                'modelscope.models.nlp.bert': ['os', 'torch', 'typing'],
                'modelscope.models.nlp.structbert':
                    ['os', 'torch', 'typing'],
                ...
            },
            'version': '0.2.3',
            'md5': '8616924970fe6bc119d1562832625612',
        }
    """
    cache_dir = os.getenv('MODELSCOPE_CACHE', get_default_cache_dir())
    file_path = os.path.join(cache_dir, INDEXER_FILE)
    logger.info(f'Loading ast index from {file_path}')
    index = None
    if not force_rebuild and os.path.exists(file_path):
        wrapped_index = _load_index(file_path)
        md5 = file_scanner.files_mtime_md5()
        # the cached index is only valid for the same package version and
        # unchanged source files
        if (wrapped_index[VERSION_KEY] == __version__
                and wrapped_index[MD5_KEY] == md5):
            index = wrapped_index

    if index is None:
        if force_rebuild:
            logger.info('Force rebuilding ast index')
        else:
            logger.info(
                f'No valid ast index found from {file_path}, rebuilding ast index!'
            )
        index = file_scanner.get_files_scan_results()
        _save_index(index, file_path)
    logger.info(
        f'Loading done! Current index file version is {index[VERSION_KEY]}, '
        f'with md5 {index[MD5_KEY]}')
    return index
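
# A hedged usage sketch for load_index; the cache directory honours the
# MODELSCOPE_CACHE environment variable read above:
#
#   index = load_index()                    # reuse cache if version/md5 match
#   index = load_index(force_rebuild=True)  # always rescan and rewrite cache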


def check_import_module_avaliable(module_dicts: dict) -> list:
    missed_module = []
    for module in module_dicts.keys():
        loader = importlib.find_loader(module)
        if loader is None:
            missed_module.append(module)
    return missed_module


if __name__ == '__main__':
    index = load_index()
    print(index)
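
# A hedged sketch for check_import_module_avaliable; the dict keys are
# treated as importable module names (the example dict is illustrative):
#
#   missing = check_import_module_avaliable({'torch': [], 'not_a_module': []})
#   # -> ['not_a_module'] when that module cannot be located by importlib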