| @@ -36,6 +36,7 @@ SCAN_SUB_FOLDERS = [ | |||
| ] | |||
| INDEXER_FILE = 'ast_indexer' | |||
| DECORATOR_KEY = 'decorators' | |||
| EXPRESS_KEY = 'express' | |||
| FROM_IMPORT_KEY = 'from_imports' | |||
| IMPORT_KEY = 'imports' | |||
| FILE_NAME_KEY = 'filepath' | |||
| @@ -45,6 +46,9 @@ INDEX_KEY = 'index' | |||
| REQUIREMENT_KEY = 'requirements' | |||
| MODULE_KEY = 'module' | |||
| CLASS_NAME = 'class_name' | |||
| GROUP_KEY = 'group_key' | |||
| MODULE_NAME = 'module_name' | |||
| MODULE_CLS = 'module_cls' | |||
| class AstScaning(object): | |||
| @@ -53,6 +57,7 @@ class AstScaning(object): | |||
| self.result_import = dict() | |||
| self.result_from_import = dict() | |||
| self.result_decorator = [] | |||
| self.express = [] | |||
| def _is_sub_node(self, node: object) -> bool: | |||
| return isinstance(node, | |||
| @@ -108,6 +113,7 @@ class AstScaning(object): | |||
| self.result_import = dict() | |||
| self.result_from_import = dict() | |||
| self.result_decorator = [] | |||
| self.result_express = [] | |||
| def scan_ast(self, node: Union[ast.AST, None, str]): | |||
| self._setup_global() | |||
| @@ -243,13 +249,19 @@ class AstScaning(object): | |||
| setattr(item, CLASS_NAME, node.name) | |||
| self.result_decorator.extend(attr) | |||
| if attr != [] and type( | |||
| attr | |||
| ).__name__ == 'Call' and parent_node_name == 'Expr': | |||
| self.result_express.append(attr) | |||
| out += f'{indentstr()}{field}={representation},\n' | |||
| out += indentstr() + ')' | |||
| return { | |||
| IMPORT_KEY: self.result_import, | |||
| FROM_IMPORT_KEY: self.result_from_import, | |||
| DECORATOR_KEY: self.result_decorator | |||
| DECORATOR_KEY: self.result_decorator, | |||
| EXPRESS_KEY: self.result_express | |||
| }, out | |||
| def _parse_decorator(self, node: ast.AST) -> tuple: | |||
| @@ -267,7 +279,10 @@ class AstScaning(object): | |||
| def _get_args_name(nodes: list) -> list: | |||
| result = [] | |||
| for node in nodes: | |||
| result.append(_get_attribute_item(node)) | |||
| if type(node).__name__ == 'Str': | |||
| result.append((node.s, None)) | |||
| else: | |||
| result.append(_get_attribute_item(node)) | |||
| return result | |||
| def _get_keyword_name(nodes: ast.AST) -> list: | |||
| @@ -276,9 +291,11 @@ class AstScaning(object): | |||
| if type(node).__name__ == 'keyword': | |||
| attribute_node = getattr(node, 'value') | |||
| if type(attribute_node).__name__ == 'Str': | |||
| result.append((attribute_node.s, None)) | |||
| result.append((getattr(node, | |||
| 'arg'), attribute_node.s, None)) | |||
| else: | |||
| result.append(_get_attribute_item(attribute_node)) | |||
| result.append((getattr(node, 'arg'), ) | |||
| + _get_attribute_item(attribute_node)) | |||
| return result | |||
| functions = _get_attribute_item(node.func) | |||
| @@ -315,10 +332,26 @@ class AstScaning(object): | |||
| args_list.append(default_group) | |||
| if len(keyword_list) == 0 and len(args_list) == 1: | |||
| args_list.append(class_name) | |||
| if len(keyword_list) == 1 and len(args_list) == 0: | |||
| if len(keyword_list) > 0 and len(args_list) == 0: | |||
| remove_group_item = None | |||
| for item in keyword_list: | |||
| key, name, attr = item | |||
| if key == GROUP_KEY: | |||
| args_list.append((name, attr)) | |||
| remove_group_item = item | |||
| if remove_group_item is not None: | |||
| keyword_list.remove(remove_group_item) | |||
| if len(args_list) == 0: | |||
| args_list.append(default_group) | |||
| args_list.extend(keyword_list) | |||
| for item in keyword_list: | |||
| key, name, attr = item | |||
| if key == MODULE_CLS: | |||
| class_name = name | |||
| else: | |||
| args_list.append((name, attr)) | |||
| for item in args_list: | |||
| # the case empty input | |||
| @@ -347,9 +380,14 @@ class AstScaning(object): | |||
| for node in nodes: | |||
| if type(node).__name__ != 'Call': | |||
| continue | |||
| class_name = getattr(node, CLASS_NAME, None) | |||
| func = getattr(node, 'func') | |||
| if getattr(func, 'attr', None) != REGISTER_MODULE: | |||
| continue | |||
| parse_output = self._parse_decorator(node) | |||
| index = self._registry_indexer(parse_output, | |||
| getattr(node, CLASS_NAME)) | |||
| index = self._registry_indexer(parse_output, class_name) | |||
| if None is not index: | |||
| results.append(index) | |||
| return results | |||
| @@ -363,6 +401,8 @@ class AstScaning(object): | |||
| node = gast.parse(data) | |||
| output, _ = self.scan_import(node, indent=' ', show_offsets=False) | |||
| output[DECORATOR_KEY] = self.parse_decorators(output[DECORATOR_KEY]) | |||
| output[EXPRESS_KEY] = self.parse_decorators(output[EXPRESS_KEY]) | |||
| output[DECORATOR_KEY].extend(output[EXPRESS_KEY]) | |||
| return output | |||
| @@ -481,6 +521,13 @@ class FilesAstScaning(object): | |||
| module_import[value_dict[MODULE_KEY]] = value_dict[IMPORT_KEY] | |||
| return module_import | |||
| def _ignore_useless_keys(self, inverted_index): | |||
| if ('OPTIMIZERS', 'default', 'name') in inverted_index: | |||
| del inverted_index[('OPTIMIZERS', 'default', 'name')] | |||
| if ('LR_SCHEDULER', 'default', 'name') in inverted_index: | |||
| del inverted_index[('LR_SCHEDULER', 'default', 'name')] | |||
| return inverted_index | |||
| def get_files_scan_results(self, | |||
| target_dir=MODELSCOPE_PATH, | |||
| target_folders=SCAN_SUB_FOLDERS): | |||
| @@ -514,6 +561,8 @@ class FilesAstScaning(object): | |||
| MODULE_KEY: module_name | |||
| } | |||
| inverted_index_with_results = self._inverted_index(result) | |||
| inverted_index_with_results = self._ignore_useless_keys( | |||
| inverted_index_with_results) | |||
| module_import = self._module_import(result) | |||
| index = { | |||
| INDEX_KEY: inverted_index_with_results, | |||