|
|
|
@@ -17,11 +17,12 @@ |
|
|
|
"""The module of parser python object, called by c++.""" |
|
|
|
|
|
|
|
import os |
|
|
|
import sys |
|
|
|
import ast |
|
|
|
import hashlib |
|
|
|
import inspect |
|
|
|
import types |
|
|
|
import platform |
|
|
|
import importlib |
|
|
|
from dataclasses import is_dataclass |
|
|
|
from textwrap import dedent |
|
|
|
|
|
|
|
@@ -35,6 +36,7 @@ from mindspore.common.api import _MindsporeFunctionExecutor, _convert_data |
|
|
|
from mindspore.common.dtype import pytype_to_dtype |
|
|
|
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace, ClassAttrNamespace |
|
|
|
from .resources import parse_object_map, ops_symbol_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT |
|
|
|
from .jit_fallback_modules import jit_fallback_third_party_modules_whitelist |
|
|
|
|
|
|
|
# define return value |
|
|
|
RET_SUCCESS = 0 |
|
|
|
@@ -152,9 +154,33 @@ def get_bprop_method_of_class(obj, parse_method=None): |
|
|
|
method = getattr(obj, method_name) |
|
|
|
return method |
|
|
|
|
|
|
|
|
|
|
|
def get_env_support_modules(): |
|
|
|
"""Get support modules from environment variable.""" |
|
|
|
support_modules = os.getenv('MS_DEV_SUPPORT_MODULES') |
|
|
|
if support_modules is None: |
|
|
|
return [] |
|
|
|
env_support_modules = [] |
|
|
|
modules = support_modules.split(',') |
|
|
|
for module in modules: |
|
|
|
try: |
|
|
|
module_spec = importlib.util.find_spec(module) |
|
|
|
except (ModuleNotFoundError, ValueError): |
|
|
|
module = module[0:module.rfind('.')] |
|
|
|
module_spec = importlib.util.find_spec(module) |
|
|
|
if module_spec is None: |
|
|
|
raise ModuleNotFoundError(f"Cannot find module: {module}. " \ |
|
|
|
f"Please check if {module} is installed, or if MS_DEV_SUPPORT_MODULES is set correctly.") |
|
|
|
# Add the outermost module. |
|
|
|
env_support_modules.append(module.split('.')[0]) |
|
|
|
logger.debug(f"Get support modules from env: {env_support_modules}") |
|
|
|
return env_support_modules |
|
|
|
|
|
|
|
|
|
|
|
# The fallback feature is enabled in default. |
|
|
|
# Not support change the flag during the process is alive. |
|
|
|
support_fallback_ = os.getenv('MS_DEV_ENABLE_FALLBACK') |
|
|
|
support_modules_ = get_env_support_modules() |
|
|
|
|
|
|
|
|
|
|
|
def resolve_symbol(namespace, symbol): |
|
|
|
@@ -571,6 +597,41 @@ def get_args(node): |
|
|
|
return args |
|
|
|
|
|
|
|
|
|
|
|
def _in_sys_path(file_path): |
|
|
|
for path in list(sys.path): |
|
|
|
if file_path.startswith(path): |
|
|
|
return True |
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
def is_third_party_module(value): |
|
|
|
"""To check if value is a third-party module.""" |
|
|
|
# Check if value is a module or package. |
|
|
|
if not inspect.ismodule(value) or not hasattr(value, '__file__'): |
|
|
|
return False |
|
|
|
# Check if module file is under the sys path. |
|
|
|
module_file = value.__file__ |
|
|
|
if not _in_sys_path(module_file): |
|
|
|
return False |
|
|
|
|
|
|
|
# Get module leftmost name. |
|
|
|
if not hasattr(value, '__name__'): |
|
|
|
return False |
|
|
|
module_name = value.__name__ |
|
|
|
module_leftmost_name = module_name.split('.')[0] |
|
|
|
# Ignore mindspore package. |
|
|
|
if module_leftmost_name == "mindspore": |
|
|
|
return False |
|
|
|
# Check if module is in whitelist. |
|
|
|
if module_leftmost_name in support_modules_: |
|
|
|
logger.debug(f"Found support modules from env: {module_name}") |
|
|
|
return True |
|
|
|
if module_leftmost_name in jit_fallback_third_party_modules_whitelist: |
|
|
|
logger.debug(f"Found third-party module: {module_name}") |
|
|
|
return True |
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
def eval_script(exp_str, params): |
|
|
|
"""Evaluate a python expression.""" |
|
|
|
if not isinstance(params, tuple): |
|
|
|
@@ -614,17 +675,6 @@ class Parser: |
|
|
|
self.line_offset = 0 |
|
|
|
self.filename: str = inspect.getfile(inspect.unwrap(self.fn)) |
|
|
|
|
|
|
|
# Used to resolve mindspore builtin ops namespace. |
|
|
|
self.ms_common_ns = CellNamespace('mindspore.common') |
|
|
|
self.ms_nn_ns = CellNamespace('mindspore.nn') |
|
|
|
self.ms_ops_ns = CellNamespace('mindspore.ops') |
|
|
|
self.ms_ops_c_ns = CellNamespace('mindspore.ops.composite') |
|
|
|
self.ms_ops_c_multitype_ns = CellNamespace('mindspore.ops.composite.multitype_ops') |
|
|
|
self.ms_ops_p_ns = CellNamespace('mindspore.ops.operations') |
|
|
|
if platform.system().lower() != 'windows': |
|
|
|
self.ms_scipy_ns = CellNamespace('mindspore.scipy') |
|
|
|
else: |
|
|
|
self.ms_scipy_ns = {} |
|
|
|
# Used to resolve the function's globals namespace. |
|
|
|
self.global_namespace = CellNamespace(fn.__module__) |
|
|
|
self.function_module = fn.__module__ |
|
|
|
@@ -720,86 +770,6 @@ class Parser: |
|
|
|
error_info = f"The name '{var}' is not defined in function '{self.function_name}'." |
|
|
|
return None, error_info |
|
|
|
|
|
|
|
def is_rightmost_name_in_namespace_module(self, name): |
|
|
|
"""Check supported Module namespace.""" |
|
|
|
rightmost_name = name.split('.')[-1] |
|
|
|
if rightmost_name in self.ms_ops_ns: |
|
|
|
logger.debug(f"Found '{name}'({rightmost_name}) in ops namespace: {str(self.ms_ops_ns)}.") |
|
|
|
return True |
|
|
|
if rightmost_name in self.ms_ops_c_ns: |
|
|
|
logger.debug(f"Found '{name}'({rightmost_name}) in C namespace: {str(self.ms_ops_c_ns)}.") |
|
|
|
return True |
|
|
|
if rightmost_name in self.ms_ops_c_multitype_ns: |
|
|
|
logger.debug( |
|
|
|
f"Found '{name}'({rightmost_name}) in C.multitype namespace: {str(self.ms_ops_c_multitype_ns)}.") |
|
|
|
return True |
|
|
|
if rightmost_name in self.ms_ops_p_ns: |
|
|
|
logger.debug(f"Found '{name}'({rightmost_name}) in P namespace: {str(self.ms_ops_p_ns)}.") |
|
|
|
return True |
|
|
|
if rightmost_name in self.ms_common_ns: |
|
|
|
logger.debug(f"Found '{name}'({rightmost_name}) in common namespace: {str(self.ms_common_ns)}.") |
|
|
|
return True |
|
|
|
# Support nn.layer. To check if exclude other module. |
|
|
|
if rightmost_name in self.ms_nn_ns: |
|
|
|
logger.debug(f"Found '{name}'({rightmost_name}) in nn namespace: {str(self.ms_nn_ns)}.") |
|
|
|
return True |
|
|
|
if rightmost_name in self.ms_scipy_ns: |
|
|
|
logger.debug(f"Found '{name}'({rightmost_name}) in scipy namespace: {str(self.ms_scipy_ns)}.") |
|
|
|
return True |
|
|
|
if rightmost_name in trope_ns: |
|
|
|
logger.debug(f"Found '{name}'({rightmost_name}) in trope namespace: {str(trope_ns)}.") |
|
|
|
return True |
|
|
|
return False |
|
|
|
|
|
|
|
def is_supported_namespace_module(self, value): |
|
|
|
"""To check if the module is allowed to support.""" |
|
|
|
# Check `mindspore` namespace. |
|
|
|
if not hasattr(value, '__name__'): |
|
|
|
logger.debug(f"'{str(value)}' has no '__name__' attribute, we suppose it's supported.") |
|
|
|
return True |
|
|
|
name = value.__name__ |
|
|
|
if name == 'mindspore': |
|
|
|
logger.debug(f"Found 'mindspore' root namespace.") |
|
|
|
return True |
|
|
|
if name == 'mindspore.ops': |
|
|
|
logger.debug(f"Found 'mindspore.ops' namespace.") |
|
|
|
return True |
|
|
|
if name == 'mindspore.nn': |
|
|
|
logger.debug(f"Found 'mindspore.nn' namespace.") |
|
|
|
return True |
|
|
|
if name == 'mindspore.numpy': |
|
|
|
logger.debug(f"Found 'mindspore.numpy' namespace.") |
|
|
|
return True |
|
|
|
if platform.system().lower() != 'windows' and name == 'mindspore.scipy': |
|
|
|
logger.debug(f"Found 'mindspore.scipy' namespace.") |
|
|
|
return True |
|
|
|
if name == 'mindspore.context': |
|
|
|
logger.debug(f"Found 'mindspore.context' namespace.") |
|
|
|
return True |
|
|
|
|
|
|
|
if name == 'functools': |
|
|
|
logger.debug(f"Found 'functools' namespace.") |
|
|
|
return True |
|
|
|
|
|
|
|
# Check `builtins` namespace. |
|
|
|
if hasattr(value, '__module__'): # Not types.ModuleType |
|
|
|
mod = value.__module__ |
|
|
|
if mod == 'builtins': |
|
|
|
logger.debug(f"Found '{name}' in 'builtins' namespace.") |
|
|
|
return True |
|
|
|
|
|
|
|
# We suppose it's supported if not a Module. |
|
|
|
if not isinstance(value, types.ModuleType): |
|
|
|
logger.debug(f"Found '{name}', not a module.") |
|
|
|
return True |
|
|
|
|
|
|
|
# Check supported Module namespace. |
|
|
|
if self.is_rightmost_name_in_namespace_module(name): |
|
|
|
return True |
|
|
|
|
|
|
|
logger.info(f"Not found '{name}' in mindspore supported namespace.") |
|
|
|
return False |
|
|
|
|
|
|
|
def get_builtin_namespace_symbol(self, var: str): |
|
|
|
"""Get mindspore builtin namespace and symbol.""" |
|
|
|
if var in self.closure_namespace: |
|
|
|
@@ -817,7 +787,7 @@ class Parser: |
|
|
|
support_info = self.global_namespace, var, value, SYNTAX_UNSUPPORTED_EXTERNAL_TYPE |
|
|
|
elif self.is_unsupported_special_type(value): |
|
|
|
support_info = self.global_namespace, var, value, SYNTAX_UNSUPPORTED_SPECIAL_TYPE |
|
|
|
elif self.is_unsupported_namespace(value) or not self.is_supported_namespace_module(value): |
|
|
|
elif self.is_unsupported_namespace(value) or is_third_party_module(value): |
|
|
|
support_info = self.global_namespace, var, value, SYNTAX_UNSUPPORTED_NAMESPACE |
|
|
|
else: |
|
|
|
support_info = self.global_namespace, var, value, SYNTAX_SUPPORTED |
|
|
|
|