diff --git a/mindspore/_extends/parse/namespace.py b/mindspore/_extends/parse/namespace.py index 8d8b6fd30e..f32abed284 100644 --- a/mindspore/_extends/parse/namespace.py +++ b/mindspore/_extends/parse/namespace.py @@ -99,12 +99,19 @@ class ClassMemberNamespace(Namespace): obj (Object): A python class object. """ def __init__(self, obj): + self.__class_member_namespace__ = True label = f'{obj.__module__}..<{obj.__class__.__name__}::{id(obj)}>' super().__init__(label, obj) def __getitem__(self, name): d, = self.dicts + if name == "self": + return d + if name == "namespace": + return self try: - return getattr(d, name) + if hasattr(d, name): + return getattr(d, name) + return d.__dict__[name] except ValueError: raise UnboundLocalError(name) diff --git a/mindspore/_extends/parse/parser.py b/mindspore/_extends/parse/parser.py index 9d715fdf53..a69d62869c 100644 --- a/mindspore/_extends/parse/parser.py +++ b/mindspore/_extends/parse/parser.py @@ -70,6 +70,7 @@ parse_expr_statement_white_list = ( "append", ) + def create_slice_obj(start, end, step): """Create slice object""" return slice(start, end, step) @@ -201,9 +202,10 @@ def get_object_key(obj): if isinstance(obj, types.MethodType): method_instance = obj.__self__ instance_id = "%s_ID%d" % (str(method_instance.__class__.__name__), id(method_instance)) - obj_id = instance_id + obj_id + obj_id = instance_id + obj_id + str(obj.__hash__()) return obj_id, obj_key + def get_default_input(obj): if hasattr(obj, '__parameter__'): return obj.default_input @@ -213,6 +215,7 @@ def get_default_input(obj): return args return obj + def is_class_member(node): """Check the attr is class member variable.""" type_ = node.__class__.__name__ @@ -224,10 +227,12 @@ def is_class_member(node): return True return False + def get_obj_id(obj): """Get the obj id.""" return str(id(obj)) + def get_obj_type(obj): """Get the obj type.""" obj_type = RESOLVE_TYPE_INVALID @@ -320,6 +325,7 @@ def get_dataclass_methods(cls): if isinstance(getattr(cls, name), (types.FunctionType,))} return methods + class Parser: """ Parser python code to ast tree. @@ -453,6 +459,28 @@ class Parser: logger.debug("ops info = %r", ops_info) return ops_info + def analyze_super(self, father_class_node, subclass_instance): + """Analyze super and return a class instance.""" + father_class = None + if father_class_node is None: + father_class = type(subclass_instance) + if isinstance(father_class_node, ast.Name): + father_class_name = getattr(father_class_node, 'id') + father_class = self.global_namespace[father_class_name] + if isinstance(father_class_node, ast.Attribute): + value = getattr(father_class_node, 'value') + attr = getattr(father_class_node, 'attr') + module_name = getattr(value, 'id') + father_class_module = self.global_namespace[module_name] + father_class = getattr(father_class_module, attr) + if father_class is None: + raise ValueError("When call 'super', the father class is None.") + if not isinstance(subclass_instance, father_class): + raise ValueError("When call 'super', the second arg should be an instance of first arg.") + + target_class_instance = super(father_class, subclass_instance) + return target_class_instance + def get_location(self, node): """ Get location of node start and end line no. diff --git a/mindspore/_extends/parse/resources.py b/mindspore/_extends/parse/resources.py index e2b83331f5..6c246332ef 100644 --- a/mindspore/_extends/parse/resources.py +++ b/mindspore/_extends/parse/resources.py @@ -117,6 +117,7 @@ convert_object_map = { T.zip: C.zip_operation, T.print: F.print_, T.enumerate: M.enumerate_, + T.isinstance: M.isinstance_, # custom define operation T.iter: M.ms_iter, diff --git a/mindspore/_extends/parse/standard_method.py b/mindspore/_extends/parse/standard_method.py index d70c6edcf4..dbfd625f8d 100644 --- a/mindspore/_extends/parse/standard_method.py +++ b/mindspore/_extends/parse/standard_method.py @@ -114,6 +114,12 @@ def enumerate_(x, start=0): return ret +def isinstance_(x, base_type): + """Determine whether x is an instance of base_type.""" + x_type = F.typeof(x) + return check_type_same(x_type, base_type) + + def while_cond(x): """For while condtion, if the condition is a tensor, the loop will not be unrolled""" if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)): @@ -123,6 +129,12 @@ def while_cond(x): return x +@constexpr +def check_type_same(x_type, base_type): + """Check x_type is same as base_type.""" + return mstype.issubclass_(x_type, base_type) + + @constexpr def check_is_tuple_or_list(x, op_name, arg_name): """check whether x is list or tuple.""" @@ -141,6 +153,7 @@ def check_is_const_int(x, op_name, arg_name): return True + @constexpr def check_is_tensor_bool_cond(shp): """check if tensor is a bool condition""" @@ -148,6 +161,7 @@ def check_is_tensor_bool_cond(shp): return True raise ValueError("tensor as bool condition, its shape should be () or (1,), but got ", shp) + @constexpr def const_tensor_to_bool(x): """convert bool tensor to bool condition""" @@ -162,6 +176,7 @@ def const_tensor_to_bool(x): value = bool(x[0]) return value + def tensor_bool(x): """tensor as conditon, if is constant, return immediate bool value""" is_cond = check_is_tensor_bool_cond(F.shape(x)) diff --git a/mindspore/_extends/parse/trope.py b/mindspore/_extends/parse/trope.py index 28f3196975..674715ef59 100644 --- a/mindspore/_extends/parse/trope.py +++ b/mindspore/_extends/parse/trope.py @@ -27,7 +27,7 @@ from operator import ( # noqa # support system function call from builtins import ( # noqa - bool, getattr, setattr, len, iter, next, pow, range, map, zip, print, enumerate + bool, getattr, setattr, len, iter, next, pow, range, map, zip, print, enumerate, isinstance ) # support functools @@ -44,7 +44,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt', 'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains', 'matmul', 'getitem', 'setitem', 'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip', - 'partial', 'print', 'enumerate', + 'partial', 'print', 'enumerate', 'isinstance', 'exp', 'log', 'sin', 'cos', 'tan'] diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc index baef64481b..a7de687714 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc @@ -370,6 +370,8 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature } else if (py::hasattr(obj, PYTHON_ENVINSTANCE_FLAG)) { std::shared_ptr env = obj.cast>(); converted = env; + } else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) { + converted = std::make_shared(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj); } else if (py::hasattr(obj, "__parameter__")) { auto to_convert = py::cast(python_adapter::GetPyObjAttr(obj, "default_input")); ret = ConvertData(to_convert, &converted); diff --git a/mindspore/ccsrc/pipeline/jit/parse/function_block.cc b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc index 14e9f739d5..588c099082 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/function_block.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc @@ -109,7 +109,7 @@ AnfNodePtr FunctionBlock::MakeResolveClassMember(std::string attr) { // Make a resolve node for symbol string AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) { - if (value.compare(0, strlen("self."), "self.") == 0) { + if (value.compare(0, strlen("self"), "self") == 0) { auto start = value.find_first_of('.') + 1; if (start >= value.size()) { MS_LOG(ERROR) << "Find invalid resolve symbol str: " << value; diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index b2e95c5070..d168ae09b5 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -22,6 +22,7 @@ #include #include #include +#include "pipeline/jit/parse/resolve.h" #include "frontend/operator/ops.h" #include "pipeline/jit/parse/data_converter.h" #include "frontend/operator/composite/composite.h" @@ -504,14 +505,45 @@ AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::v [](AnfNodePtr arg) -> AnfNodePtr { return arg; }); return block->func_graph()->NewCNode(make_tuple_nodes); } + +AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &args) { + py::object father_class; + if (args.empty()) { + father_class = py::none(); + } else if (args.size() == 2) { + father_class = args[0]; + auto arg_type = AstSubType(py::cast(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, args[1]))); + if (arg_type != AST_SUB_TYPE_NAME || py::cast(python_adapter::GetPyObjAttr(args[1], "id")) != "self") { + MS_EXCEPTION(ArgumentError) << "When call 'super', the second arg should be 'self'."; + } + } else { + MS_EXCEPTION(ArgumentError) << "When call 'super', the args number should be 0 or 2, but got" << args.size() << "."; + } + py::object target_class_instance = ast()->CallParserObjMethod(PYTHON_PARSE_ANALYZE_SUPER, father_class, ast()->obj()); + py::object namespace_var = ast_->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, target_class_instance); + NameSpacePtr name_space = std::make_shared(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); + SymbolPtr symbol = std::make_shared("namespace"); + return block->MakeResolve(name_space, symbol); +} + // process function call, eg : f1(x, y) ... AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Call"; // process function call py::object function_ast_node = python_adapter::GetPyObjAttr(node, "func"); + py::list args = python_adapter::GetPyObjAttr(node, "args"); + + auto arg_type = + AstSubType(py::cast(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, function_ast_node))); + if (arg_type == AST_SUB_TYPE_NAME) { + auto name_id = py::cast(python_adapter::GetPyObjAttr(function_ast_node, "id")); + if (name_id == "super") { + return ParseSuper(block, args); + } + } + AnfNodePtr call_function_anf_node = ParseExprNode(block, function_ast_node); // function call arguments should be passed in as groups and unpacked later using unpack call - py::list args = python_adapter::GetPyObjAttr(node, "args"); std::vector packed_arguments; std::vector group_arguments; diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.h b/mindspore/ccsrc/pipeline/jit/parse/parse.h index c8779f9b42..dc0b43c3a6 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.h @@ -138,6 +138,8 @@ class Parser { AnfNodePtr ParseNameConstant(const FunctionBlockPtr &block, const py::object &node); // process a function call AnfNodePtr ParseCall(const FunctionBlockPtr &block, const py::object &node); + // process function 'super' + AnfNodePtr ParseSuper(const FunctionBlockPtr &block, const py::list &args); // process the if expression AnfNodePtr ParseIfExp(const FunctionBlockPtr &block, const py::object &node); // process class type define diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h index 553782ed6f..1750fe7309 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h @@ -81,6 +81,7 @@ const char PYTHON_PARSE_GET_LOCATION[] = "get_location"; const char PYTHON_PARSE_EXPAND_EXPR_STATEMENT[] = "expand_expr_statement"; const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope"; const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name"; +const char PYTHON_PARSE_ANALYZE_SUPER[] = "analyze_super"; const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj"; const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj"; diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.h b/mindspore/ccsrc/pipeline/jit/parse/resolve.h index 2cd88efb1d..1024012d46 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/resolve.h +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.h @@ -80,7 +80,7 @@ using SymbolPtr = std::shared_ptr; // PyObjectWrapper class wrappers resolved python object for further processing. class PyObjectWrapper : public Named { public: - explicit PyObjectWrapper(const py::object &obj, const std::string name = "Python object") : Named(name), obj_(obj) {} + explicit PyObjectWrapper(const py::object &obj, const std::string &name = "Python object") : Named(name), obj_(obj) {} ~PyObjectWrapper() override = default; MS_DECLARE_PARENT(PyObjectWrapper, Named); py::object obj() { return obj_; } @@ -93,7 +93,7 @@ class PyObjectWrapper : public Named { // ClassObject class wrappers dataclass class ClassObject : public PyObjectWrapper { public: - explicit ClassObject(const py::object &obj, const std::string name = "Python dataclass") + explicit ClassObject(const py::object &obj, const std::string &name = "Python dataclass") : PyObjectWrapper(obj, name) {} ~ClassObject() override = default; MS_DECLARE_PARENT(ClassObject, PyObjectWrapper); @@ -103,7 +103,7 @@ class ClassObject : public PyObjectWrapper { // ClassType class wrappers class name in python class ClassType : public PyObjectWrapper { public: - explicit ClassType(const py::object &obj, const std::string name = "Python class type") + explicit ClassType(const py::object &obj, const std::string &name = "Python class type") : PyObjectWrapper(obj, name) {} ~ClassType() override = default; MS_DECLARE_PARENT(ClassType, PyObjectWrapper); diff --git a/mindspore/ccsrc/pybind_api/export_flags.cc b/mindspore/ccsrc/pybind_api/export_flags.cc index a21cfd30bf..bbec26ee99 100644 --- a/mindspore/ccsrc/pybind_api/export_flags.cc +++ b/mindspore/ccsrc/pybind_api/export_flags.cc @@ -25,6 +25,7 @@ const char PYTHON_ENVINSTANCE_FLAG[] = "__envinstance_flag__"; const char PYTHON_DTYPE_FLAG[] = "__dtype_flag__"; const char PYTHON_CELL_AS_LIST[] = "__cell_as_list__"; const char PYTHON_DATACLASS_FIELDS[] = "__dataclass_fields__"; +const char PYTHON_CLASS_MEMBER_NAMESPACE[] = "__class_member_namespace__"; // flag names const char GRAPH_FLAG_MIX_PRECISION_FP16[] = "fp16"; diff --git a/mindspore/ccsrc/pybind_api/export_flags.h b/mindspore/ccsrc/pybind_api/export_flags.h index b84efda770..a927794462 100644 --- a/mindspore/ccsrc/pybind_api/export_flags.h +++ b/mindspore/ccsrc/pybind_api/export_flags.h @@ -27,6 +27,7 @@ extern const char PYTHON_ENVINSTANCE_FLAG[]; extern const char PYTHON_DTYPE_FLAG[]; extern const char PYTHON_CELL_AS_LIST[]; extern const char PYTHON_DATACLASS_FIELDS[]; +extern const char PYTHON_CLASS_MEMBER_NAMESPACE[]; extern const char GRAPH_FLAG_MIX_PRECISION_FP16[]; extern const char GRAPH_FLAG_MIX_PRECISION_FP32[]; diff --git a/mindspore/common/api.py b/mindspore/common/api.py index 3ef21b9626..68810355e9 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -62,6 +62,7 @@ def _wrap_func(fn): Returns: Function, a new function with return suitable format data. """ + @wraps(fn) def wrapper(*arg, **kwargs): results = fn(*arg, **kwargs) @@ -74,6 +75,7 @@ def _wrap_func(fn): if isinstance(data, list): return list(_convert_data(x) for x in data) return data + return _convert_data(results) return wrapper @@ -106,6 +108,7 @@ class _MindSporeFunction: obj (Object): If function is a method, obj is the owner of function, else, obj is none. """ + def __init__(self, fn, input_signature=None, obj=None): self.fn = fn self.save_graphs = context.get_context("save_graphs") @@ -245,6 +248,7 @@ def ms_function(fn=None, obj=None, input_signature=None): >>> out = tensor_add_with_dec(x, y) >>> out = tensor_add_with_sig(x, y) """ + def wrap_mindspore(func): @wraps(func) def staging_specialize(*args): @@ -275,6 +279,7 @@ def _generate_pip_args(obj, *args, method="construct"): obj.__parse_method__ = parse_method return args_names, args_list + class _PynativeExecutor: """ An pynative executor used to compile/manage/run graph. @@ -304,6 +309,7 @@ class _PynativeExecutor: def __call__(self, *args): return self._executor(args, "") + class _Executor: """ An executor used to compile/manage/run graph. @@ -532,6 +538,7 @@ class _Executor: return None return self._executor.fetch_info_for_quant_export(exec_id) + _executor = _Executor() _pynative_exec = _PynativeExecutor() diff --git a/tests/ut/python/pipeline/parse/test_super.py b/tests/ut/python/pipeline/parse/test_super.py new file mode 100644 index 0000000000..eb3bf1682d --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_super.py @@ -0,0 +1,116 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test super""" +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context + +context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + + +class FatherNet(nn.Cell): + def __init__(self, x): + super(FatherNet, self).__init__(x) + self.x = x + + def construct(self, x, y): + return self.x * x + + def test_father(self, x): + return self.x + x + + +class MatherNet(nn.Cell): + def __init__(self, y): + super(MatherNet, self).__init__() + self.y = y + + def construct(self, x, y): + return self.y * y + + def test_mather(self, y): + return self.y + y + + +class SingleSubNet(FatherNet): + def __init__(self, x, z): + super(SingleSubNet, self).__init__(x) + self.z = z + + def construct(self, x, y): + ret_father_construct = super().construct(x, y) + ret_father_test = super(SingleSubNet, self).test_father(x) + ret_father_x = super(SingleSubNet, self).x + ret_sub_z = self.z + + return ret_father_construct, ret_father_test, ret_father_x, ret_sub_z + + +class MulSubNet(FatherNet, MatherNet): + def __init__(self, x, y, z): + super(MulSubNet, self).__init__(x) + super(FatherNet, self).__init__(y) + self.z = z + + def construct(self, x, y): + ret_father_construct = super().construct(x, y) + ret_father_test = super(MulSubNet, self).test_father(x) + ret_father_x = super(MulSubNet, self).x + ret_mather_construct = super(FatherNet, self).construct(x, y) + ret_mather_test = super(FatherNet, self).test_mather(y) + ret_mather_y = super(FatherNet, self).y + ret_sub_z = self.z + + return ret_father_construct, ret_father_test, ret_father_x, \ + ret_mather_construct, ret_mather_test, ret_mather_y, ret_sub_z + + +class Net(nn.Cell): + def __init__(self, x): + super(Net, self).__init__() + self.x = x + + def construct(self, x, y): + ret = super(Net, self).construct(x, y) + return ret + + +def test_single_super(): + single_net = SingleSubNet(2, 3) + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + x = Tensor(np.ones([1, 2, 3], np.int32)) + y = Tensor(np.ones([1, 2, 3], np.int32)) + single_net(x, y) + + +def test_mul_super(): + mul_net = MulSubNet(2, 3, 4) + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + x = Tensor(np.ones([1, 2, 3], np.int32)) + y = Tensor(np.ones([1, 2, 3], np.int32)) + mul_net(x, y) + + +def test_super_cell(): + net = Net(2) + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + x = Tensor(np.ones([1, 2, 3], np.int32)) + y = Tensor(np.ones([1, 2, 3], np.int32)) + with pytest.raises(RuntimeError) as er: + net(x, y) + assert "Unsupported syntax 'Raise'" in str(er.value)