| @@ -36,6 +36,7 @@ SCAN_SUB_FOLDERS = [ | |||||
| ] | ] | ||||
| INDEXER_FILE = 'ast_indexer' | INDEXER_FILE = 'ast_indexer' | ||||
| DECORATOR_KEY = 'decorators' | DECORATOR_KEY = 'decorators' | ||||
| EXPRESS_KEY = 'express' | |||||
| FROM_IMPORT_KEY = 'from_imports' | FROM_IMPORT_KEY = 'from_imports' | ||||
| IMPORT_KEY = 'imports' | IMPORT_KEY = 'imports' | ||||
| FILE_NAME_KEY = 'filepath' | FILE_NAME_KEY = 'filepath' | ||||
| @@ -45,6 +46,9 @@ INDEX_KEY = 'index' | |||||
| REQUIREMENT_KEY = 'requirements' | REQUIREMENT_KEY = 'requirements' | ||||
| MODULE_KEY = 'module' | MODULE_KEY = 'module' | ||||
| CLASS_NAME = 'class_name' | CLASS_NAME = 'class_name' | ||||
| GROUP_KEY = 'group_key' | |||||
| MODULE_NAME = 'module_name' | |||||
| MODULE_CLS = 'module_cls' | |||||
| class AstScaning(object): | class AstScaning(object): | ||||
| @@ -53,6 +57,7 @@ class AstScaning(object): | |||||
| self.result_import = dict() | self.result_import = dict() | ||||
| self.result_from_import = dict() | self.result_from_import = dict() | ||||
| self.result_decorator = [] | self.result_decorator = [] | ||||
| self.express = [] | |||||
| def _is_sub_node(self, node: object) -> bool: | def _is_sub_node(self, node: object) -> bool: | ||||
| return isinstance(node, | return isinstance(node, | ||||
| @@ -108,6 +113,7 @@ class AstScaning(object): | |||||
| self.result_import = dict() | self.result_import = dict() | ||||
| self.result_from_import = dict() | self.result_from_import = dict() | ||||
| self.result_decorator = [] | self.result_decorator = [] | ||||
| self.result_express = [] | |||||
| def scan_ast(self, node: Union[ast.AST, None, str]): | def scan_ast(self, node: Union[ast.AST, None, str]): | ||||
| self._setup_global() | self._setup_global() | ||||
| @@ -243,13 +249,19 @@ class AstScaning(object): | |||||
| setattr(item, CLASS_NAME, node.name) | setattr(item, CLASS_NAME, node.name) | ||||
| self.result_decorator.extend(attr) | self.result_decorator.extend(attr) | ||||
| if attr != [] and type( | |||||
| attr | |||||
| ).__name__ == 'Call' and parent_node_name == 'Expr': | |||||
| self.result_express.append(attr) | |||||
| out += f'{indentstr()}{field}={representation},\n' | out += f'{indentstr()}{field}={representation},\n' | ||||
| out += indentstr() + ')' | out += indentstr() + ')' | ||||
| return { | return { | ||||
| IMPORT_KEY: self.result_import, | IMPORT_KEY: self.result_import, | ||||
| FROM_IMPORT_KEY: self.result_from_import, | FROM_IMPORT_KEY: self.result_from_import, | ||||
| DECORATOR_KEY: self.result_decorator | |||||
| DECORATOR_KEY: self.result_decorator, | |||||
| EXPRESS_KEY: self.result_express | |||||
| }, out | }, out | ||||
| def _parse_decorator(self, node: ast.AST) -> tuple: | def _parse_decorator(self, node: ast.AST) -> tuple: | ||||
| @@ -267,7 +279,10 @@ class AstScaning(object): | |||||
| def _get_args_name(nodes: list) -> list: | def _get_args_name(nodes: list) -> list: | ||||
| result = [] | result = [] | ||||
| for node in nodes: | for node in nodes: | ||||
| result.append(_get_attribute_item(node)) | |||||
| if type(node).__name__ == 'Str': | |||||
| result.append((node.s, None)) | |||||
| else: | |||||
| result.append(_get_attribute_item(node)) | |||||
| return result | return result | ||||
| def _get_keyword_name(nodes: ast.AST) -> list: | def _get_keyword_name(nodes: ast.AST) -> list: | ||||
| @@ -276,9 +291,11 @@ class AstScaning(object): | |||||
| if type(node).__name__ == 'keyword': | if type(node).__name__ == 'keyword': | ||||
| attribute_node = getattr(node, 'value') | attribute_node = getattr(node, 'value') | ||||
| if type(attribute_node).__name__ == 'Str': | if type(attribute_node).__name__ == 'Str': | ||||
| result.append((attribute_node.s, None)) | |||||
| result.append((getattr(node, | |||||
| 'arg'), attribute_node.s, None)) | |||||
| else: | else: | ||||
| result.append(_get_attribute_item(attribute_node)) | |||||
| result.append((getattr(node, 'arg'), ) | |||||
| + _get_attribute_item(attribute_node)) | |||||
| return result | return result | ||||
| functions = _get_attribute_item(node.func) | functions = _get_attribute_item(node.func) | ||||
| @@ -315,10 +332,26 @@ class AstScaning(object): | |||||
| args_list.append(default_group) | args_list.append(default_group) | ||||
| if len(keyword_list) == 0 and len(args_list) == 1: | if len(keyword_list) == 0 and len(args_list) == 1: | ||||
| args_list.append(class_name) | args_list.append(class_name) | ||||
| if len(keyword_list) == 1 and len(args_list) == 0: | |||||
| if len(keyword_list) > 0 and len(args_list) == 0: | |||||
| remove_group_item = None | |||||
| for item in keyword_list: | |||||
| key, name, attr = item | |||||
| if key == GROUP_KEY: | |||||
| args_list.append((name, attr)) | |||||
| remove_group_item = item | |||||
| if remove_group_item is not None: | |||||
| keyword_list.remove(remove_group_item) | |||||
| if len(args_list) == 0: | |||||
| args_list.append(default_group) | args_list.append(default_group) | ||||
| args_list.extend(keyword_list) | |||||
| for item in keyword_list: | |||||
| key, name, attr = item | |||||
| if key == MODULE_CLS: | |||||
| class_name = name | |||||
| else: | |||||
| args_list.append((name, attr)) | |||||
| for item in args_list: | for item in args_list: | ||||
| # the case empty input | # the case empty input | ||||
| @@ -347,9 +380,14 @@ class AstScaning(object): | |||||
| for node in nodes: | for node in nodes: | ||||
| if type(node).__name__ != 'Call': | if type(node).__name__ != 'Call': | ||||
| continue | continue | ||||
| class_name = getattr(node, CLASS_NAME, None) | |||||
| func = getattr(node, 'func') | |||||
| if getattr(func, 'attr', None) != REGISTER_MODULE: | |||||
| continue | |||||
| parse_output = self._parse_decorator(node) | parse_output = self._parse_decorator(node) | ||||
| index = self._registry_indexer(parse_output, | |||||
| getattr(node, CLASS_NAME)) | |||||
| index = self._registry_indexer(parse_output, class_name) | |||||
| if None is not index: | if None is not index: | ||||
| results.append(index) | results.append(index) | ||||
| return results | return results | ||||
| @@ -363,6 +401,8 @@ class AstScaning(object): | |||||
| node = gast.parse(data) | node = gast.parse(data) | ||||
| output, _ = self.scan_import(node, indent=' ', show_offsets=False) | output, _ = self.scan_import(node, indent=' ', show_offsets=False) | ||||
| output[DECORATOR_KEY] = self.parse_decorators(output[DECORATOR_KEY]) | output[DECORATOR_KEY] = self.parse_decorators(output[DECORATOR_KEY]) | ||||
| output[EXPRESS_KEY] = self.parse_decorators(output[EXPRESS_KEY]) | |||||
| output[DECORATOR_KEY].extend(output[EXPRESS_KEY]) | |||||
| return output | return output | ||||
| @@ -481,6 +521,13 @@ class FilesAstScaning(object): | |||||
| module_import[value_dict[MODULE_KEY]] = value_dict[IMPORT_KEY] | module_import[value_dict[MODULE_KEY]] = value_dict[IMPORT_KEY] | ||||
| return module_import | return module_import | ||||
| def _ignore_useless_keys(self, inverted_index): | |||||
| if ('OPTIMIZERS', 'default', 'name') in inverted_index: | |||||
| del inverted_index[('OPTIMIZERS', 'default', 'name')] | |||||
| if ('LR_SCHEDULER', 'default', 'name') in inverted_index: | |||||
| del inverted_index[('LR_SCHEDULER', 'default', 'name')] | |||||
| return inverted_index | |||||
| def get_files_scan_results(self, | def get_files_scan_results(self, | ||||
| target_dir=MODELSCOPE_PATH, | target_dir=MODELSCOPE_PATH, | ||||
| target_folders=SCAN_SUB_FOLDERS): | target_folders=SCAN_SUB_FOLDERS): | ||||
| @@ -514,6 +561,8 @@ class FilesAstScaning(object): | |||||
| MODULE_KEY: module_name | MODULE_KEY: module_name | ||||
| } | } | ||||
| inverted_index_with_results = self._inverted_index(result) | inverted_index_with_results = self._inverted_index(result) | ||||
| inverted_index_with_results = self._ignore_useless_keys( | |||||
| inverted_index_with_results) | |||||
| module_import = self._module_import(result) | module_import = self._module_import(result) | ||||
| index = { | index = { | ||||
| INDEX_KEY: inverted_index_with_results, | INDEX_KEY: inverted_index_with_results, | ||||