Merge pull request !3237 from zhangbuxue/support_call_the_parent_class_construct_functiontags/v0.6.0-beta
| @@ -99,12 +99,19 @@ class ClassMemberNamespace(Namespace): | |||
| obj (Object): A python class object. | |||
| """ | |||
| def __init__(self, obj): | |||
| self.__class_member_namespace__ = True | |||
| label = f'{obj.__module__}..<{obj.__class__.__name__}::{id(obj)}>' | |||
| super().__init__(label, obj) | |||
| def __getitem__(self, name): | |||
| d, = self.dicts | |||
| if name == "self": | |||
| return d | |||
| if name == "namespace": | |||
| return self | |||
| try: | |||
| return getattr(d, name) | |||
| if hasattr(d, name): | |||
| return getattr(d, name) | |||
| return d.__dict__[name] | |||
| except ValueError: | |||
| raise UnboundLocalError(name) | |||
| @@ -70,6 +70,7 @@ parse_expr_statement_white_list = ( | |||
| "append", | |||
| ) | |||
| def create_slice_obj(start, end, step): | |||
| """Create slice object""" | |||
| return slice(start, end, step) | |||
| @@ -201,9 +202,10 @@ def get_object_key(obj): | |||
| if isinstance(obj, types.MethodType): | |||
| method_instance = obj.__self__ | |||
| instance_id = "%s_ID%d" % (str(method_instance.__class__.__name__), id(method_instance)) | |||
| obj_id = instance_id + obj_id | |||
| obj_id = instance_id + obj_id + str(obj.__hash__()) | |||
| return obj_id, obj_key | |||
| def get_default_input(obj): | |||
| if hasattr(obj, '__parameter__'): | |||
| return obj.default_input | |||
| @@ -213,6 +215,7 @@ def get_default_input(obj): | |||
| return args | |||
| return obj | |||
| def is_class_member(node): | |||
| """Check the attr is class member variable.""" | |||
| type_ = node.__class__.__name__ | |||
| @@ -224,10 +227,12 @@ def is_class_member(node): | |||
| return True | |||
| return False | |||
| def get_obj_id(obj): | |||
| """Get the obj id.""" | |||
| return str(id(obj)) | |||
| def get_obj_type(obj): | |||
| """Get the obj type.""" | |||
| obj_type = RESOLVE_TYPE_INVALID | |||
| @@ -320,6 +325,7 @@ def get_dataclass_methods(cls): | |||
| if isinstance(getattr(cls, name), (types.FunctionType,))} | |||
| return methods | |||
| class Parser: | |||
| """ | |||
| Parser python code to ast tree. | |||
| @@ -453,6 +459,28 @@ class Parser: | |||
| logger.debug("ops info = %r", ops_info) | |||
| return ops_info | |||
| def analyze_super(self, father_class_node, subclass_instance): | |||
| """Analyze super and return a class instance.""" | |||
| father_class = None | |||
| if father_class_node is None: | |||
| father_class = type(subclass_instance) | |||
| if isinstance(father_class_node, ast.Name): | |||
| father_class_name = getattr(father_class_node, 'id') | |||
| father_class = self.global_namespace[father_class_name] | |||
| if isinstance(father_class_node, ast.Attribute): | |||
| value = getattr(father_class_node, 'value') | |||
| attr = getattr(father_class_node, 'attr') | |||
| module_name = getattr(value, 'id') | |||
| father_class_module = self.global_namespace[module_name] | |||
| father_class = getattr(father_class_module, attr) | |||
| if father_class is None: | |||
| raise ValueError("When call 'super', the father class is None.") | |||
| if not isinstance(subclass_instance, father_class): | |||
| raise ValueError("When call 'super', the second arg should be an instance of first arg.") | |||
| target_class_instance = super(father_class, subclass_instance) | |||
| return target_class_instance | |||
| def get_location(self, node): | |||
| """ | |||
| Get location of node start and end line no. | |||
| @@ -117,6 +117,7 @@ convert_object_map = { | |||
| T.zip: C.zip_operation, | |||
| T.print: F.print_, | |||
| T.enumerate: M.enumerate_, | |||
| T.isinstance: M.isinstance_, | |||
| # custom define operation | |||
| T.iter: M.ms_iter, | |||
| @@ -114,6 +114,12 @@ def enumerate_(x, start=0): | |||
| return ret | |||
| def isinstance_(x, base_type): | |||
| """Determine whether x is an instance of base_type.""" | |||
| x_type = F.typeof(x) | |||
| return check_type_same(x_type, base_type) | |||
| def while_cond(x): | |||
| """For while condtion, if the condition is a tensor, the loop will not be unrolled""" | |||
| if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)): | |||
| @@ -123,6 +129,12 @@ def while_cond(x): | |||
| return x | |||
| @constexpr | |||
| def check_type_same(x_type, base_type): | |||
| """Check x_type is same as base_type.""" | |||
| return mstype.issubclass_(x_type, base_type) | |||
| @constexpr | |||
| def check_is_tuple_or_list(x, op_name, arg_name): | |||
| """check whether x is list or tuple.""" | |||
| @@ -141,6 +153,7 @@ def check_is_const_int(x, op_name, arg_name): | |||
| return True | |||
| @constexpr | |||
| def check_is_tensor_bool_cond(shp): | |||
| """check if tensor is a bool condition""" | |||
| @@ -148,6 +161,7 @@ def check_is_tensor_bool_cond(shp): | |||
| return True | |||
| raise ValueError("tensor as bool condition, its shape should be () or (1,), but got ", shp) | |||
| @constexpr | |||
| def const_tensor_to_bool(x): | |||
| """convert bool tensor to bool condition""" | |||
| @@ -162,6 +176,7 @@ def const_tensor_to_bool(x): | |||
| value = bool(x[0]) | |||
| return value | |||
| def tensor_bool(x): | |||
| """tensor as conditon, if is constant, return immediate bool value""" | |||
| is_cond = check_is_tensor_bool_cond(F.shape(x)) | |||
| @@ -27,7 +27,7 @@ from operator import ( # noqa | |||
| # support system function call | |||
| from builtins import ( # noqa | |||
| bool, getattr, setattr, len, iter, next, pow, range, map, zip, print, enumerate | |||
| bool, getattr, setattr, len, iter, next, pow, range, map, zip, print, enumerate, isinstance | |||
| ) | |||
| # support functools | |||
| @@ -44,7 +44,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt', | |||
| 'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains', | |||
| 'matmul', 'getitem', 'setitem', | |||
| 'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip', | |||
| 'partial', 'print', 'enumerate', | |||
| 'partial', 'print', 'enumerate', 'isinstance', | |||
| 'exp', 'log', 'sin', 'cos', 'tan'] | |||
| @@ -370,6 +370,8 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature | |||
| } else if (py::hasattr(obj, PYTHON_ENVINSTANCE_FLAG)) { | |||
| std::shared_ptr<EnvInstance> env = obj.cast<std::shared_ptr<EnvInstance>>(); | |||
| converted = env; | |||
| } else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) { | |||
| converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj); | |||
| } else if (py::hasattr(obj, "__parameter__")) { | |||
| auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input")); | |||
| ret = ConvertData(to_convert, &converted); | |||
| @@ -109,7 +109,7 @@ AnfNodePtr FunctionBlock::MakeResolveClassMember(std::string attr) { | |||
| // Make a resolve node for symbol string | |||
| AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) { | |||
| if (value.compare(0, strlen("self."), "self.") == 0) { | |||
| if (value.compare(0, strlen("self"), "self") == 0) { | |||
| auto start = value.find_first_of('.') + 1; | |||
| if (start >= value.size()) { | |||
| MS_LOG(ERROR) << "Find invalid resolve symbol str: " << value; | |||
| @@ -22,6 +22,7 @@ | |||
| #include <sstream> | |||
| #include <unordered_map> | |||
| #include <algorithm> | |||
| #include "pipeline/jit/parse/resolve.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "pipeline/jit/parse/data_converter.h" | |||
| #include "frontend/operator/composite/composite.h" | |||
| @@ -504,14 +505,45 @@ AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::v | |||
| [](AnfNodePtr arg) -> AnfNodePtr { return arg; }); | |||
| return block->func_graph()->NewCNode(make_tuple_nodes); | |||
| } | |||
| AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &args) { | |||
| py::object father_class; | |||
| if (args.empty()) { | |||
| father_class = py::none(); | |||
| } else if (args.size() == 2) { | |||
| father_class = args[0]; | |||
| auto arg_type = AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, args[1]))); | |||
| if (arg_type != AST_SUB_TYPE_NAME || py::cast<std::string>(python_adapter::GetPyObjAttr(args[1], "id")) != "self") { | |||
| MS_EXCEPTION(ArgumentError) << "When call 'super', the second arg should be 'self'."; | |||
| } | |||
| } else { | |||
| MS_EXCEPTION(ArgumentError) << "When call 'super', the args number should be 0 or 2, but got" << args.size() << "."; | |||
| } | |||
| py::object target_class_instance = ast()->CallParserObjMethod(PYTHON_PARSE_ANALYZE_SUPER, father_class, ast()->obj()); | |||
| py::object namespace_var = ast_->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, target_class_instance); | |||
| NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); | |||
| SymbolPtr symbol = std::make_shared<Symbol>("namespace"); | |||
| return block->MakeResolve(name_space, symbol); | |||
| } | |||
| // process function call, eg : f1(x, y) ... | |||
| AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &node) { | |||
| MS_LOG(DEBUG) << "Process ast Call"; | |||
| // process function call | |||
| py::object function_ast_node = python_adapter::GetPyObjAttr(node, "func"); | |||
| py::list args = python_adapter::GetPyObjAttr(node, "args"); | |||
| auto arg_type = | |||
| AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, function_ast_node))); | |||
| if (arg_type == AST_SUB_TYPE_NAME) { | |||
| auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(function_ast_node, "id")); | |||
| if (name_id == "super") { | |||
| return ParseSuper(block, args); | |||
| } | |||
| } | |||
| AnfNodePtr call_function_anf_node = ParseExprNode(block, function_ast_node); | |||
| // function call arguments should be passed in as groups and unpacked later using unpack call | |||
| py::list args = python_adapter::GetPyObjAttr(node, "args"); | |||
| std::vector<AnfNodePtr> packed_arguments; | |||
| std::vector<AnfNodePtr> group_arguments; | |||
| @@ -138,6 +138,8 @@ class Parser { | |||
| AnfNodePtr ParseNameConstant(const FunctionBlockPtr &block, const py::object &node); | |||
| // process a function call | |||
| AnfNodePtr ParseCall(const FunctionBlockPtr &block, const py::object &node); | |||
| // process function 'super' | |||
| AnfNodePtr ParseSuper(const FunctionBlockPtr &block, const py::list &args); | |||
| // process the if expression | |||
| AnfNodePtr ParseIfExp(const FunctionBlockPtr &block, const py::object &node); | |||
| // process class type define | |||
| @@ -81,6 +81,7 @@ const char PYTHON_PARSE_GET_LOCATION[] = "get_location"; | |||
| const char PYTHON_PARSE_EXPAND_EXPR_STATEMENT[] = "expand_expr_statement"; | |||
| const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope"; | |||
| const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name"; | |||
| const char PYTHON_PARSE_ANALYZE_SUPER[] = "analyze_super"; | |||
| const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj"; | |||
| const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj"; | |||
| @@ -80,7 +80,7 @@ using SymbolPtr = std::shared_ptr<Symbol>; | |||
| // PyObjectWrapper class wrappers resolved python object for further processing. | |||
| class PyObjectWrapper : public Named { | |||
| public: | |||
| explicit PyObjectWrapper(const py::object &obj, const std::string name = "Python object") : Named(name), obj_(obj) {} | |||
| explicit PyObjectWrapper(const py::object &obj, const std::string &name = "Python object") : Named(name), obj_(obj) {} | |||
| ~PyObjectWrapper() override = default; | |||
| MS_DECLARE_PARENT(PyObjectWrapper, Named); | |||
| py::object obj() { return obj_; } | |||
| @@ -93,7 +93,7 @@ class PyObjectWrapper : public Named { | |||
| // ClassObject class wrappers dataclass | |||
| class ClassObject : public PyObjectWrapper { | |||
| public: | |||
| explicit ClassObject(const py::object &obj, const std::string name = "Python dataclass") | |||
| explicit ClassObject(const py::object &obj, const std::string &name = "Python dataclass") | |||
| : PyObjectWrapper(obj, name) {} | |||
| ~ClassObject() override = default; | |||
| MS_DECLARE_PARENT(ClassObject, PyObjectWrapper); | |||
| @@ -103,7 +103,7 @@ class ClassObject : public PyObjectWrapper { | |||
| // ClassType class wrappers class name in python | |||
| class ClassType : public PyObjectWrapper { | |||
| public: | |||
| explicit ClassType(const py::object &obj, const std::string name = "Python class type") | |||
| explicit ClassType(const py::object &obj, const std::string &name = "Python class type") | |||
| : PyObjectWrapper(obj, name) {} | |||
| ~ClassType() override = default; | |||
| MS_DECLARE_PARENT(ClassType, PyObjectWrapper); | |||
| @@ -25,6 +25,7 @@ const char PYTHON_ENVINSTANCE_FLAG[] = "__envinstance_flag__"; | |||
| const char PYTHON_DTYPE_FLAG[] = "__dtype_flag__"; | |||
| const char PYTHON_CELL_AS_LIST[] = "__cell_as_list__"; | |||
| const char PYTHON_DATACLASS_FIELDS[] = "__dataclass_fields__"; | |||
| const char PYTHON_CLASS_MEMBER_NAMESPACE[] = "__class_member_namespace__"; | |||
| // flag names | |||
| const char GRAPH_FLAG_MIX_PRECISION_FP16[] = "fp16"; | |||
| @@ -27,6 +27,7 @@ extern const char PYTHON_ENVINSTANCE_FLAG[]; | |||
| extern const char PYTHON_DTYPE_FLAG[]; | |||
| extern const char PYTHON_CELL_AS_LIST[]; | |||
| extern const char PYTHON_DATACLASS_FIELDS[]; | |||
| extern const char PYTHON_CLASS_MEMBER_NAMESPACE[]; | |||
| extern const char GRAPH_FLAG_MIX_PRECISION_FP16[]; | |||
| extern const char GRAPH_FLAG_MIX_PRECISION_FP32[]; | |||
| @@ -62,6 +62,7 @@ def _wrap_func(fn): | |||
| Returns: | |||
| Function, a new function with return suitable format data. | |||
| """ | |||
| @wraps(fn) | |||
| def wrapper(*arg, **kwargs): | |||
| results = fn(*arg, **kwargs) | |||
| @@ -74,6 +75,7 @@ def _wrap_func(fn): | |||
| if isinstance(data, list): | |||
| return list(_convert_data(x) for x in data) | |||
| return data | |||
| return _convert_data(results) | |||
| return wrapper | |||
| @@ -106,6 +108,7 @@ class _MindSporeFunction: | |||
| obj (Object): If function is a method, obj is the owner of function, | |||
| else, obj is none. | |||
| """ | |||
| def __init__(self, fn, input_signature=None, obj=None): | |||
| self.fn = fn | |||
| self.save_graphs = context.get_context("save_graphs") | |||
| @@ -245,6 +248,7 @@ def ms_function(fn=None, obj=None, input_signature=None): | |||
| >>> out = tensor_add_with_dec(x, y) | |||
| >>> out = tensor_add_with_sig(x, y) | |||
| """ | |||
| def wrap_mindspore(func): | |||
| @wraps(func) | |||
| def staging_specialize(*args): | |||
| @@ -275,6 +279,7 @@ def _generate_pip_args(obj, *args, method="construct"): | |||
| obj.__parse_method__ = parse_method | |||
| return args_names, args_list | |||
| class _PynativeExecutor: | |||
| """ | |||
| An pynative executor used to compile/manage/run graph. | |||
| @@ -304,6 +309,7 @@ class _PynativeExecutor: | |||
| def __call__(self, *args): | |||
| return self._executor(args, "") | |||
| class _Executor: | |||
| """ | |||
| An executor used to compile/manage/run graph. | |||
| @@ -532,6 +538,7 @@ class _Executor: | |||
| return None | |||
| return self._executor.fetch_info_for_quant_export(exec_id) | |||
| _executor = _Executor() | |||
| _pynative_exec = _PynativeExecutor() | |||
| @@ -0,0 +1,116 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test super""" | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| class FatherNet(nn.Cell): | |||
| def __init__(self, x): | |||
| super(FatherNet, self).__init__(x) | |||
| self.x = x | |||
| def construct(self, x, y): | |||
| return self.x * x | |||
| def test_father(self, x): | |||
| return self.x + x | |||
| class MatherNet(nn.Cell): | |||
| def __init__(self, y): | |||
| super(MatherNet, self).__init__() | |||
| self.y = y | |||
| def construct(self, x, y): | |||
| return self.y * y | |||
| def test_mather(self, y): | |||
| return self.y + y | |||
| class SingleSubNet(FatherNet): | |||
| def __init__(self, x, z): | |||
| super(SingleSubNet, self).__init__(x) | |||
| self.z = z | |||
| def construct(self, x, y): | |||
| ret_father_construct = super().construct(x, y) | |||
| ret_father_test = super(SingleSubNet, self).test_father(x) | |||
| ret_father_x = super(SingleSubNet, self).x | |||
| ret_sub_z = self.z | |||
| return ret_father_construct, ret_father_test, ret_father_x, ret_sub_z | |||
| class MulSubNet(FatherNet, MatherNet): | |||
| def __init__(self, x, y, z): | |||
| super(MulSubNet, self).__init__(x) | |||
| super(FatherNet, self).__init__(y) | |||
| self.z = z | |||
| def construct(self, x, y): | |||
| ret_father_construct = super().construct(x, y) | |||
| ret_father_test = super(MulSubNet, self).test_father(x) | |||
| ret_father_x = super(MulSubNet, self).x | |||
| ret_mather_construct = super(FatherNet, self).construct(x, y) | |||
| ret_mather_test = super(FatherNet, self).test_mather(y) | |||
| ret_mather_y = super(FatherNet, self).y | |||
| ret_sub_z = self.z | |||
| return ret_father_construct, ret_father_test, ret_father_x, \ | |||
| ret_mather_construct, ret_mather_test, ret_mather_y, ret_sub_z | |||
| class Net(nn.Cell): | |||
| def __init__(self, x): | |||
| super(Net, self).__init__() | |||
| self.x = x | |||
| def construct(self, x, y): | |||
| ret = super(Net, self).construct(x, y) | |||
| return ret | |||
| def test_single_super(): | |||
| single_net = SingleSubNet(2, 3) | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| x = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| y = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| single_net(x, y) | |||
| def test_mul_super(): | |||
| mul_net = MulSubNet(2, 3, 4) | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| x = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| y = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| mul_net(x, y) | |||
| def test_super_cell(): | |||
| net = Net(2) | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| x = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| y = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| with pytest.raises(RuntimeError) as er: | |||
| net(x, y) | |||
| assert "Unsupported syntax 'Raise'" in str(er.value) | |||