Merge pull request !30855 from huangbingjian/ms_class_devr1.7
| @@ -23,58 +23,63 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace irpass { | |||
| // {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr} | |||
| // {prim::kPrimGetAttr, {getitem, {prim::kPrimResolve, namespace, symbol}, index}, attr} | |||
| // {prim::kPrimGetAttr, namespace, attr} | |||
| // {prim::kPrimGetAttr, bool, attr} | |||
| // {prim::kPrimGetAttr, object, attr} | |||
| // {prim::kPrimResolve, namespace, symbol} | |||
| AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||
| PatternNode<AnfNodePtr> getattr_operand, ns_node, sym_node, attr_node, bool_node; | |||
| auto GetAttrResolveLambda = [&node, &getattr_operand, &attr_node, &optimizer]() -> AnfNodePtr { | |||
| auto getattr_operand_node = getattr_operand.GetNode(node); | |||
| auto attr = attr_node.GetNode(node); | |||
| PatternNode<AnfNodePtr> object, attr, ns_node, sym_node; | |||
| auto GetAttrLambda = [&node, &object, &attr, &optimizer]() -> AnfNodePtr { | |||
| auto object_node = object.GetNode(node); | |||
| auto attr_node = attr.GetNode(node); | |||
| // {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr} | |||
| if (IsPrimitiveCNode(getattr_operand_node, prim::kPrimResolve)) { | |||
| auto [name_space, symbol] = parse::GetNamespaceAndSymbol(getattr_operand_node); | |||
| if (IsPrimitiveCNode(object_node, prim::kPrimResolve)) { | |||
| auto [name_space, symbol] = parse::GetNamespaceAndSymbol(object_node); | |||
| auto module_name = name_space->module(); | |||
| constexpr std::string_view parse_super_name = "namespace"; | |||
| if (module_name.find(parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER) != std::string::npos && | |||
| symbol->symbol() != parse_super_name) { | |||
| auto obj = parse::GetSymbolObject(name_space, symbol, node); | |||
| return parse::ResolveCellWithAttr(optimizer->manager(), obj, getattr_operand_node, attr); | |||
| auto symbol_obj = parse::GetSymbolObject(name_space, symbol, node); | |||
| return parse::ResolveCellWithAttr(optimizer->manager(), symbol_obj, object_node, attr_node); | |||
| } | |||
| } | |||
| // {prim::kPrimGetAttr, {getitem, {prim::kPrimResolve, namespace, symbol}, index}, attr} | |||
| auto operand_cnode = getattr_operand_node->cast<CNodePtr>(); | |||
| constexpr size_t getitem_inputs_size = 3; | |||
| if (operand_cnode != nullptr && operand_cnode->size() == getitem_inputs_size) { | |||
| constexpr auto prim_index = 0; | |||
| if (parse::IsGetItemCNode(object_node)) { | |||
| auto getitem_cnode = object_node->cast<CNodePtr>(); | |||
| constexpr auto resolve_index = 1; | |||
| constexpr auto index_index = 2; | |||
| auto prim_node = operand_cnode->input(prim_index); | |||
| auto resolve_node = operand_cnode->input(resolve_index); | |||
| auto index_node = operand_cnode->input(index_index); | |||
| if (!parse::IsResolveNodeWithGetItem(prim_node) || !IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) { | |||
| return nullptr; | |||
| auto resolve_node = getitem_cnode->input(resolve_index); | |||
| auto index_node = getitem_cnode->input(index_index); | |||
| if (IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) { | |||
| auto [name_space, symbol] = parse::GetNamespaceAndSymbol(resolve_node); | |||
| auto obj = parse::GetObjectFromSequence(name_space, symbol, resolve_node, index_node); | |||
| if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) { | |||
| return parse::ResolveSequenceWithAttr(optimizer->manager(), obj, resolve_node, attr_node, getitem_cnode); | |||
| } | |||
| return parse::ResolveCellWithAttr(optimizer->manager(), obj, resolve_node, attr_node); | |||
| } | |||
| auto [name_space, symbol] = parse::GetNamespaceAndSymbol(resolve_node); | |||
| auto obj = parse::GetObjectFromSequence(name_space, symbol, resolve_node, index_node); | |||
| if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) { | |||
| return parse::ResolveSequenceWithAttr(optimizer->manager(), obj, resolve_node, attr, operand_cnode); | |||
| } | |||
| return parse::ResolveCellWithAttr(optimizer->manager(), obj, resolve_node, attr); | |||
| } | |||
| return nullptr; | |||
| }; | |||
| auto GetAttrLambda = [&node, &ns_node, &attr_node, &optimizer]() -> AnfNodePtr { | |||
| auto name_space = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node)); | |||
| auto str = GetValue<std::string>(GetValueNode(attr_node.GetNode(node))); | |||
| parse::SymbolPtr symbol = std::make_shared<parse::Symbol>(str); | |||
| auto manager = optimizer->manager(); | |||
| return parse::ResolveSymbol(manager, name_space, symbol, node); | |||
| // {prim::kPrimGetAttr, namespace, attr} | |||
| if (IsValueNode<parse::NameSpace>(object_node)) { | |||
| auto name_space = GetValueNode<parse::NameSpacePtr>(object_node); | |||
| auto attr_str = GetValue<std::string>(GetValueNode(attr_node)); | |||
| parse::SymbolPtr symbol = std::make_shared<parse::Symbol>(attr_str); | |||
| return parse::ResolveSymbol(optimizer->manager(), name_space, symbol, node); | |||
| } | |||
| // {prim::kPrimGetAttr, MsClassObject, attr} | |||
| if (IsValueNode<parse::MsClassObject>(object_node)) { | |||
| auto ms_class = GetValueNode<parse::MsClassObjectPtr>(object_node); | |||
| auto attr_str = GetValue<std::string>(GetValueNode(attr_node)); | |||
| return parse::ResolveMsClassWithAttr(optimizer->manager(), ms_class, attr_str, node); | |||
| } | |||
| // {prim::kPrimGetAttr, bool, attr} | |||
| if (IsValueNode<BoolImm>(object_node)) { | |||
| return object_node; | |||
| } | |||
| return nullptr; | |||
| }; | |||
| auto ResolveLambda = [&node, &ns_node, &sym_node, &optimizer]() -> AnfNodePtr { | |||
| @@ -84,18 +89,9 @@ AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr | |||
| return parse::ResolveSymbol(manager, name_space, symbol, node); | |||
| }; | |||
| // {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr} | |||
| // {prim::kPrimGetAttr, {getitem, {prim::kPrimResolve, namespace, symbol}, index}, attr} | |||
| MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetAttr, getattr_operand, attr_node), GetAttrResolveLambda, | |||
| attr_node.CheckFunc(IsValueNode<StringImm>, node)); | |||
| // {prim::kPrimGetAttr, namespace, attr} | |||
| MATCH_REPLACE_LAMBDA_IF( | |||
| node, PPrimitive(prim::kPrimGetAttr, ns_node, attr_node), GetAttrLambda, | |||
| ns_node.CheckFunc(IsValueNode<parse::NameSpace>, node) && attr_node.CheckFunc(IsValueNode<StringImm>, node)); | |||
| // {prim::kPrimGetAttr, bool, attr} | |||
| MATCH_REPLACE_IF( | |||
| node, PPrimitive(prim::kPrimGetAttr, bool_node, attr_node), bool_node, | |||
| bool_node.CheckFunc(IsValueNode<BoolImm>, node) && attr_node.CheckFunc(IsValueNode<StringImm>, node)); | |||
| // {prim::kPrimGetAttr, object, attr} | |||
| MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetAttr, object, attr), GetAttrLambda, | |||
| attr.CheckFunc(IsValueNode<StringImm>, node)); | |||
| // {prim::kPrimResolve, namespace, symbol} | |||
| MATCH_REPLACE_LAMBDA_IF( | |||
| node, PPrimitive(prim::kPrimResolve, ns_node, sym_node), ResolveLambda, | |||
| @@ -40,6 +40,7 @@ namespace irpass { | |||
| // {prim::kPrimGetAttr, {prim::kPrimTupleGetItem, {prim::kPrimResolve, namespace, symbol}, index}, attr} | |||
| // {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr} | |||
| // {prim::kPrimGetAttr, namespace, attr} | |||
| // {prim::kPrimGetAttr, MsClassObject, attr} | |||
| // {prim::kPrimGetAttr, bool, attr} | |||
| // {prim::kPrimResolve, namespace, symbol} | |||
| class Resolver : public OptimizerCaller { | |||
| @@ -253,6 +253,15 @@ ValuePtr ConvertDataClass(const py::object &obj) { | |||
| return converted; | |||
| } | |||
| ValuePtr ConvertMsClass(const py::object &obj) { | |||
| MS_LOG(DEBUG) << "Converting ms class"; | |||
| // Convert class instance decorated with ms_class. | |||
| py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); | |||
| py::object name = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MS_CLASS_NAME, obj); | |||
| auto cls_name = py::cast<std::string>(name); | |||
| return std::make_shared<MsClassObject>(obj, cls_name); | |||
| } | |||
| ValuePtr ConvertPrimitive(const py::object &obj, bool use_signature = false) { | |||
| MS_LOG(DEBUG) << "Converting primitive object" << use_signature; | |||
| @@ -502,6 +511,7 @@ static const std::vector<DataConverterPtr> &GetDataConverters() { | |||
| std::make_shared<ByTypeDataConverter<py::ellipsis>>(kEllipsis), | |||
| std::make_shared<ByTypeDataConverter<py::module>>(ConvertModuleNameSpace), | |||
| std::make_shared<ByAttrDataConverter>(PYTHON_DATACLASS_FIELDS, ConvertDataClass), | |||
| std::make_shared<ByAttrDataConverter>(PYTHON_MS_CLASS, ConvertMsClass), | |||
| std::make_shared<ByTypeDataConverter<Type>>(ObjCast<TypePtr>), | |||
| std::make_shared<ByTypeDataConverter<UMonad>>(ObjCast<UMonadPtr>), | |||
| std::make_shared<ByTypeDataConverter<IOMonad>>(ObjCast<IOMonadPtr>), | |||
| @@ -67,6 +67,8 @@ const char PYTHON_MOD_CREATE_INSTANCE[] = "create_instance"; | |||
| const char PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE[] = "is_supported_create_instance_type"; | |||
| const char PYTHON_MOD_GET_DATACLASS_ATTRS[] = "get_dataclass_attributes"; | |||
| const char PYTHON_MOD_GET_DATACLASS_METHODS[] = "get_dataclass_methods"; | |||
| const char PYTHON_MOD_GET_MS_CLASS_NAME[] = "get_ms_class_name"; | |||
| const char PYTHON_MOD_GET_MS_CLASS_ATTR[] = "get_ms_class_attr"; | |||
| const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace"; | |||
| const char PYTHON_MOD_GET_ATTR_NAMESPACE_SYMBOL[] = "get_class_attr_namespace_symbol"; | |||
| const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol"; | |||
| @@ -307,7 +307,7 @@ AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, cons | |||
| AnfNodePtr resolved_node = nullptr; | |||
| bool success = ResolveObjectToNode(node->func_graph(), obj, &resolved_node); | |||
| if (!success) { | |||
| MS_LOG(EXCEPTION) << "Parse Resolve covert failed NodeInfo."; | |||
| MS_LOG(EXCEPTION) << "Parse Resolve covert failed."; | |||
| } | |||
| if (IsValueNode<FuncGraph>(resolved_node)) { | |||
| auto new_fg = GetValueNode<FuncGraphPtr>(resolved_node); | |||
| @@ -465,6 +465,40 @@ bool IsResolveNodeWithGetItem(const AnfNodePtr &node) { | |||
| return false; | |||
| } | |||
| bool IsGetItemCNode(const AnfNodePtr &node) { | |||
| if (!node->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| constexpr size_t getitem_inputs_size = 3; | |||
| if (cnode->size() != getitem_inputs_size) { | |||
| return false; | |||
| } | |||
| constexpr auto prim_index = 0; | |||
| return IsResolveNodeWithGetItem(cnode->input(prim_index)); | |||
| } | |||
| AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const MsClassObjectPtr &ms_class, | |||
| const std::string &attr, const AnfNodePtr &node) { | |||
| // Get attribute or method from ms_class obj. | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_LOG(DEBUG) << "Resolve ms_class obj (" << ms_class->name() << ") with attr " << attr << "."; | |||
| TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info())); | |||
| py::object cls_obj = ms_class->obj(); | |||
| if (!py::hasattr(cls_obj, attr.c_str())) { | |||
| MS_LOG(EXCEPTION) << ms_class->name() << " has not attribute: " << attr << "."; | |||
| } | |||
| const std::string fn = PYTHON_MOD_GET_MS_CLASS_ATTR; | |||
| const std::string module = "mindspore._extends.parse.parser"; | |||
| py::object attr_obj = python_adapter::GetPyFn(module, fn)(cls_obj, attr); | |||
| AnfNodePtr res_node = ResolveObjectAndAddToManager(manager, attr_obj, node); | |||
| TraceManager::ClearParseOrResolveDebugInfo(); | |||
| return res_node; | |||
| } | |||
| namespace { | |||
| opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) { | |||
| // For resolve and getattr primitive. | |||
| @@ -131,6 +131,18 @@ class InterpretedObject final : public PyObjectWrapper { | |||
| }; | |||
| using InterpretedObjectPtr = std::shared_ptr<InterpretedObject>; | |||
| class MsClassObject final : public PyObjectWrapper { | |||
| public: | |||
| explicit MsClassObject(const py::object &obj, const std::string &name = "ms class") | |||
| : PyObjectWrapper(obj, "MsClassObject: \'" + name + "\'") {} | |||
| ~MsClassObject() override = default; | |||
| MS_DECLARE_PARENT(MsClassObject, PyObjectWrapper); | |||
| abstract::AbstractBasePtr ToAbstract() override { | |||
| return std::make_shared<abstract::AbstractScalar>(shared_from_base<MsClassObject>(), std::make_shared<External>()); | |||
| } | |||
| }; | |||
| using MsClassObjectPtr = std::shared_ptr<MsClassObject>; | |||
| // ClassObject class wrappers dataclass | |||
| class ClassObject final : public PyObjectWrapper { | |||
| public: | |||
| @@ -168,8 +180,11 @@ AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::obj | |||
| AnfNodePtr ResolveSequenceWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj, | |||
| const AnfNodePtr &resolve_node, const AnfNodePtr &attr, | |||
| const CNodePtr &operand_cnode); | |||
| // Check if node is resolve node with getitem. | |||
| bool IsResolveNodeWithGetItem(const AnfNodePtr &node); | |||
| AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const MsClassObjectPtr &ms_class, | |||
| const std::string &attr, const AnfNodePtr &node); | |||
| // Check if node is cnode with getitem. | |||
| bool IsGetItemCNode(const AnfNodePtr &node); | |||
| // Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager(). | |||
| bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true); | |||
| @@ -19,5 +19,6 @@ namespace mindspore { | |||
| const char PYTHON_PRIMITIVE_FLAG[] = "__primitive_flag__"; | |||
| const char PYTHON_CELL_AS_LIST[] = "__cell_as_list__"; | |||
| const char PYTHON_DATACLASS_FIELDS[] = "__dataclass_fields__"; | |||
| const char PYTHON_MS_CLASS[] = "__ms_class__"; | |||
| const char PYTHON_CLASS_MEMBER_NAMESPACE[] = "__class_member_namespace__"; | |||
| } // namespace mindspore | |||
| @@ -22,6 +22,7 @@ namespace mindspore { | |||
| extern const char PYTHON_PRIMITIVE_FLAG[]; | |||
| extern const char PYTHON_CELL_AS_LIST[]; | |||
| extern const char PYTHON_DATACLASS_FIELDS[]; | |||
| extern const char PYTHON_MS_CLASS[]; | |||
| extern const char PYTHON_CLASS_MEMBER_NAMESPACE[]; | |||
| } // namespace mindspore | |||
| @@ -220,6 +220,12 @@ static ValueNameToConverterVector value_name_to_converter = { | |||
| auto class_type = value->cast<parse::ClassTypePtr>(); | |||
| return class_type->obj(); | |||
| }}, | |||
| // parse::MsClassObject | |||
| {parse::MsClassObject::kTypeId, | |||
| [](const ValuePtr &value) -> py::object { | |||
| auto ms_class_object = value->cast<parse::MsClassObjectPtr>(); | |||
| return ms_class_object->obj(); | |||
| }}, | |||
| // parse::InterpretedObject | |||
| {parse::InterpretedObject::kTypeId, | |||
| [](const ValuePtr &value) -> py::object { | |||
| @@ -23,7 +23,8 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type, | |||
| get_args, get_args_default_values, get_ast_namespace_symbol, get_operation_symbol, | |||
| get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name, eval_script, | |||
| expand_expr_statement, is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor, | |||
| get_object_description, get_class_attr_namespace_symbol) | |||
| get_object_description, get_class_attr_namespace_symbol, get_ms_class_name, | |||
| get_ms_class_attr) | |||
| __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol', | |||
| 'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_ast_type', 'get_node_type', | |||
| @@ -32,4 +33,5 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', | |||
| 'get_module_namespace', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', | |||
| 'get_dataclass_attributes', 'get_dataclass_methods', 'get_dataclass_methods', 'get_scope_name', | |||
| 'eval_script', 'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description', 'expand_expr_statement', | |||
| 'generate_scope', 'get_operation_symbol', 'get_class_attr_namespace_symbol'] | |||
| 'generate_scope', 'get_operation_symbol', 'get_class_attr_namespace_symbol', 'get_ms_class_name', | |||
| 'get_ms_class_attr'] | |||
| @@ -410,6 +410,30 @@ def get_dataclass_methods(cls): | |||
| return methods | |||
| def get_ms_class_name(cls): | |||
| """Get the name of the class instance decorated by ms_class.""" | |||
| # Check if cls is nn.Cell. | |||
| if isinstance(cls, nn.Cell): | |||
| raise TypeError(f"ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.") | |||
| if isinstance(cls, type): | |||
| name = cls.__name__ | |||
| else: | |||
| name = cls.__class__.__name__ | |||
| # Get the name of cls. | |||
| cls_name = cls.__module__ + '.' + name | |||
| return cls_name | |||
| def get_ms_class_attr(cls, name: str): | |||
| """Get attribute or method of ms_class obj.""" | |||
| # Don't take into account python magic methods and private variables. | |||
| if name.startswith('_'): | |||
| raise AttributeError(f"{name} is a private variable or magic method, which is not supported.") | |||
| if not hasattr(cls, name): | |||
| raise AttributeError(f"{cls} has no attribute: {name}.") | |||
| return getattr(cls, name) | |||
| def convert_to_ms_tensor(data): | |||
| """Convert C++ tensor to mindspore tensor.""" | |||
| return Tensor(data) | |||
| @@ -562,8 +586,8 @@ def eval_script(exp_str, params): | |||
| local_params = _convert_data(local_params) | |||
| obj = eval(exp_str, global_params, local_params) | |||
| except Exception as e: | |||
| error_info = f"When eval '{exp_str}' by using Fallback feature, an error occurred: " + str(e) + \ | |||
| ". You can try to turn off the Fallback feature by 'export MS_DEV_ENABLE_FALLBACK=0'." | |||
| error_info = f"When eval '{exp_str}' by using JIT Fallback feature, an error occurred: " + str(e) + \ | |||
| ". You can try to turn off JIT Fallback feature by 'export MS_DEV_ENABLE_FALLBACK=0'." | |||
| logger.error(error_info) | |||
| raise e | |||
| @@ -14,7 +14,7 @@ | |||
| # ============================================================================ | |||
| """Top-level reference to dtype of common module.""" | |||
| from . import dtype | |||
| from .api import ms_function, ms_memory_recycle, _convert_data | |||
| from .api import ms_function, ms_memory_recycle, ms_class, _convert_data | |||
| from .dtype import Type, int8, byte, int16, short, int32, intc, int64, intp, \ | |||
| uint8, ubyte, uint16, ushort, uint32, uintc, uint64, uintp, float16, half, \ | |||
| float32, single, float64, double, bool_, float_, list_, tuple_, int_, \ | |||
| @@ -54,7 +54,7 @@ __all__ = [ | |||
| __all__.extend([ | |||
| "Tensor", "RowTensor", "SparseTensor", "COOTensor", "CSRTensor", # tensor | |||
| 'ms_function', # api | |||
| 'ms_function', 'ms_class', # api | |||
| 'Parameter', 'ParameterTuple', # parameter | |||
| "dtype", "_convert_data", | |||
| "set_seed", "get_seed", # random seed | |||
| @@ -20,6 +20,7 @@ import sys | |||
| import os | |||
| import time | |||
| import ast | |||
| import inspect | |||
| import importlib | |||
| from collections import OrderedDict | |||
| from functools import wraps | |||
| @@ -439,12 +440,64 @@ def ms_function(fn=None, obj=None, input_signature=None): | |||
| return wrap_mindspore(fn) | |||
| return wrap_mindspore | |||
| def ms_class(cls): | |||
| """ | |||
| Class decorator for user-defined classes. | |||
| This allows MindSpore to identify user-defined classes and thus obtain their attributes and methods. | |||
| Args: | |||
| cls (Class): User-defined class. | |||
| Returns: | |||
| Class with __ms_class__ attribute. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore.nn as nn | |||
| >>> from mindspore import ms_class | |||
| ... | |||
| >>> @ms_class | |||
| >>> class UserDefinedNet: | |||
| ... def __init__(self): | |||
| ... self.value = 10 | |||
| ... | |||
| ... def func(self, x): | |||
| ... return 2 * x | |||
| ... | |||
| >>> class Net(nn.Cell): | |||
| ... def __init__(self): | |||
| ... super(Net, self).__init__() | |||
| ... self.net = UserDefinedNet() | |||
| ... | |||
| ... def construct(self, x): | |||
| ... out = self.net.value + self.net.func(x) | |||
| ... return out | |||
| ... | |||
| >>> net = Net() | |||
| >>> out = net(5) | |||
| >>> print(out) | |||
| 20 | |||
| """ | |||
| # Check if cls is of type class. | |||
| if not inspect.isclass(cls): | |||
| raise TypeError(f'Decorator ms_class can only be used for class type, but got {cls}.') | |||
| logger.info(f'Found ms_class: {cls}.') | |||
| setattr(cls, '__ms_class__', True) | |||
| return cls | |||
| def is_pynative_parallel(): | |||
| run_mode = context.get_context('mode') | |||
| parallel_mode = context.get_auto_parallel_context('parallel_mode') | |||
| return run_mode == context.PYNATIVE_MODE and parallel_mode in ( | |||
| context.ParallelMode.SEMI_AUTO_PARALLEL, context.ParallelMode.AUTO_PARALLEL) | |||
| def _get_auto_split_param_names(parameter_layout_dict): | |||
| auto_split_param_names = [] | |||
| for key, value in parameter_layout_dict.items(): | |||
| @@ -899,4 +952,4 @@ def ms_memory_recycle(): | |||
| _cell_graph_executor = _CellGraphExecutor() | |||
| _pynative_executor = _PynativeExecutor() | |||
| __all__ = ['ms_function', 'ms_memory_recycle'] | |||
| __all__ = ['ms_function', 'ms_memory_recycle', 'ms_class'] | |||
| @@ -243,117 +243,6 @@ def test_scipy_module(): | |||
| print(out) | |||
| def test_self_attr(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| Description: Test self.attr in graph. | |||
| Expectation: No exception. | |||
| """ | |||
| class Network(nn.Cell): | |||
| def __init__(self): | |||
| super(Network, self).__init__() | |||
| self.dim = 1 | |||
| def construct(self, x): | |||
| batch = x.shape[0] | |||
| one = Tensor(np.ones([batch, self.dim]), mstype.float16) | |||
| return one * x | |||
| net = Network() | |||
| x = Tensor([1, 2], mstype.float32) | |||
| out = net(x) | |||
| print(out) | |||
| def test_self_attr_2(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| Description: Test self.attr in graph. | |||
| Expectation: No exception. | |||
| """ | |||
| class Network(nn.Cell): | |||
| def __init__(self, fn): | |||
| super(Network, self).__init__() | |||
| self.fn = fn | |||
| def construct(self): | |||
| x = np.array([1, 2, 3]) | |||
| y = np.array([3, 4, 5]) | |||
| out = Tensor(self.fn(x, y)) | |||
| return out | |||
| def fn(x, y): | |||
| return x + y | |||
| net = Network(fn) | |||
| out = net() | |||
| print(out) | |||
| def test_self_attr_3(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| Description: Test self.attr in graph. | |||
| Expectation: No exception. | |||
| """ | |||
| class Network(nn.Cell): | |||
| def __init__(self): | |||
| super(Network, self).__init__() | |||
| self.value = [2, 2, 3] | |||
| def construct(self): | |||
| x = np.array(self.value.count(2)) | |||
| return Tensor(x) | |||
| net = Network() | |||
| out = net() | |||
| print(out) | |||
| def test_self_method(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| Description: Test self.method in graph. | |||
| Expectation: No exception. | |||
| """ | |||
| class Network(nn.Cell): | |||
| def construct(self): | |||
| x = np.array([1, 2, 3]) | |||
| y = np.array([3, 4, 5]) | |||
| out = Tensor(self.fn(x, y)) | |||
| return out | |||
| def fn(self, x, y): | |||
| return x + y | |||
| net = Network() | |||
| out = net() | |||
| print(out) | |||
| @pytest.mark.skip(reason='Not support in graph jit fallback feature yet') | |||
| def test_self_method_2(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| Description: Test self.method in graph. | |||
| Expectation: No exception. | |||
| """ | |||
| class Network(nn.Cell): | |||
| def construct(self): | |||
| x = np.array([1, 2, 3]) | |||
| y = np.array([3, 4, 5]) | |||
| z = self.fn(x, y) | |||
| out = Tensor(z) | |||
| return out | |||
| def fn(self, x, y): | |||
| return x + y | |||
| net = Network() | |||
| out = net() | |||
| print(out) | |||
| def test_probability_cauchy(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| @@ -0,0 +1,398 @@ | |||
| # Copyright 2022 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test graph fallback """ | |||
| import pytest | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore import Tensor, context, ms_class | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| def test_fallback_self_attr(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| Description: Test self.attr in graph. | |||
| Expectation: No exception. | |||
| """ | |||
| class Network(nn.Cell): | |||
| def __init__(self): | |||
| super(Network, self).__init__() | |||
| self.dim = 1 | |||
| def construct(self, x): | |||
| batch = x.shape[0] | |||
| one = Tensor(np.ones([batch, self.dim]), mstype.float32) | |||
| return one * x | |||
| net = Network() | |||
| x = Tensor([1, 2], mstype.float32) | |||
| out = net(x) | |||
| expect = np.array([[1., 2.], [1., 2.]]) | |||
| assert np.allclose(out.asnumpy(), expect, 1.e-2, 1.e-2) | |||
| def test_fallback_self_attr_fn(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| Description: Test self.attr in graph. | |||
| Expectation: No exception. | |||
| """ | |||
| class Network(nn.Cell): | |||
| def __init__(self, fn): | |||
| super(Network, self).__init__() | |||
| self.fn = fn | |||
| def construct(self): | |||
| x = np.array([1, 2, 3]) | |||
| y = np.array([3, 4, 5]) | |||
| out = Tensor(self.fn(x, y)) | |||
| return out | |||
| def fn(x, y): | |||
| return x + y | |||
| net = Network(fn) | |||
| out = net() | |||
| expect = np.array([4, 6, 8]) | |||
| assert np.all(out.asnumpy() == expect) | |||
| def test_fallback_self_attr_attr(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| Description: Test self.attr in graph. | |||
| Expectation: No exception. | |||
| """ | |||
| class Network(nn.Cell): | |||
| def __init__(self): | |||
| super(Network, self).__init__() | |||
| self.value = [2, 2, 3] | |||
| def construct(self): | |||
| x = np.array(self.value.count(2)) | |||
| return Tensor(x) | |||
| net = Network() | |||
| out = net() | |||
| assert out == 2 | |||
| def test_fallback_self_method(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| Description: Test self.method in graph. | |||
| Expectation: No exception. | |||
| """ | |||
| class Network(nn.Cell): | |||
| def construct(self): | |||
| x = np.array([1, 2, 3]) | |||
| y = np.array([3, 4, 5]) | |||
| out = Tensor(self.fn(x, y)) | |||
| return out | |||
| def fn(self, x, y): | |||
| return x + y | |||
| net = Network() | |||
| out = net() | |||
| expect = np.array([4, 6, 8]) | |||
| assert np.all(out.asnumpy() == expect) | |||
| @pytest.mark.skip(reason='Not support in graph jit fallback feature yet') | |||
| def test_fallback_self_method_tensor(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| Description: Test self.method in graph. | |||
| Expectation: No exception. | |||
| """ | |||
| class Network(nn.Cell): | |||
| def construct(self): | |||
| x = np.array([1, 2, 3]) | |||
| y = np.array([3, 4, 5]) | |||
| z = self.fn(x, y) | |||
| out = Tensor(z) | |||
| return out | |||
| def fn(self, x, y): | |||
| return x + y | |||
| net = Network() | |||
| out = net() | |||
| print(out) | |||
| def test_fallback_class_attr(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| Description: Test user-defined class attributes in graph. | |||
| Expectation: No exception. | |||
| """ | |||
| @ms_class | |||
| class InnerNet: | |||
| def __init__(self): | |||
| self.number = 1 | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.inner_net = InnerNet() | |||
| def construct(self): | |||
| out = self.inner_net.number | |||
| return out | |||
| net = Net() | |||
| out = net() | |||
| assert out == 1 | |||
| def test_fallback_class_method(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| Description: Test user-defined class methods in graph. | |||
| Expectation: No exception. | |||
| """ | |||
| @ms_class | |||
| class InnerNet: | |||
| def __init__(self): | |||
| self.val = 2 | |||
| def act(self, x, y): | |||
| return self.val * (x + y) | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.inner_net = InnerNet() | |||
| def construct(self): | |||
| out = self.inner_net.act(1, 2) | |||
| return out | |||
| net = Net() | |||
| out = net() | |||
| assert out == 6 | |||
| def test_fallback_class_input_attr(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| Description: Test user-defined class attributes in graph. | |||
| Expectation: No exception. | |||
| """ | |||
| @ms_class | |||
| class InnerNet: | |||
| def __init__(self): | |||
| self.number = Tensor(np.array([1, 2, 3])) | |||
| class Net(nn.Cell): | |||
| def __init__(self, net): | |||
| super(Net, self).__init__() | |||
| self.inner_net = net() | |||
| def construct(self): | |||
| out = self.inner_net.number | |||
| return out | |||
| net = Net(InnerNet) | |||
| out = net() | |||
| expect_res = np.array([1, 2, 3]) | |||
| assert np.all(out.asnumpy() == expect_res) | |||
| def test_fallback_class_input_method(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| Description: Test user-defined class methods in graph. | |||
| Expectation: No exception. | |||
| """ | |||
| @ms_class | |||
| class InnerNet: | |||
| def __init__(self): | |||
| self.val = 2 | |||
| def act(self, x, y): | |||
| return self.val * (x + y) | |||
| class Net(nn.Cell): | |||
| def __init__(self, net): | |||
| super(Net, self).__init__() | |||
| self.inner_net = net() | |||
| def construct(self): | |||
| out = self.inner_net.act(1, 2) | |||
| return out | |||
| net = Net(InnerNet) | |||
| out = net() | |||
| assert out == 6 | |||
| def test_fallback_class_class_nested(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| Description: Test nested ms_class in graph. | |||
| Expectation: No exception. | |||
| """ | |||
| @ms_class | |||
| class Inner: | |||
| def __init__(self): | |||
| self.number = 1 | |||
| @ms_class | |||
| class InnerNet: | |||
| def __init__(self): | |||
| self.inner = Inner() | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.inner_net = InnerNet() | |||
| def construct(self): | |||
| out = self.inner_net.inner.number | |||
| return out | |||
| net = Net() | |||
| out = net() | |||
| assert out == 1 | |||
| def test_fallback_class_cell_nested(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| Description: Test nested ms_class and cell in graph. | |||
| Expectation: No exception. | |||
| """ | |||
| class Net(nn.Cell): | |||
| def __init__(self, val): | |||
| super().__init__() | |||
| self.val = val | |||
| def construct(self, x): | |||
| return x + self.val | |||
| @ms_class | |||
| class TrainNet(): | |||
| class Loss(nn.Cell): | |||
| def __init__(self, net): | |||
| super().__init__() | |||
| self.net = net | |||
| def construct(self, x): | |||
| out = self.net(x) | |||
| return out * 2 | |||
| def __init__(self, net): | |||
| self.net = net | |||
| loss_net = self.Loss(self.net) | |||
| self.number = loss_net(10) | |||
| global_net = Net(1) | |||
| class LearnNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.value = TrainNet(global_net).number | |||
| def construct(self, x): | |||
| return x + self.value | |||
| leanrn_net = LearnNet() | |||
| out = leanrn_net(3) | |||
| print(out) | |||
| assert out == 25 | |||
| @pytest.mark.skip(reason='Not support in graph yet') | |||
| def test_fallback_class_isinstance(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| Description: Test ms_class in graph. | |||
| Expectation: No exception. | |||
| """ | |||
| @ms_class | |||
| class InnerNet: | |||
| def __init__(self): | |||
| self.number = 1 | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.inner_net = InnerNet() | |||
| def construct(self, x): | |||
| if isinstance(self.inner_net, InnerNet): | |||
| return x + 10 | |||
| return x | |||
| net = Net() | |||
| out = net(5) | |||
| assert out == 15 | |||
| def test_fallback_raise_error_not_class_type(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| Description: Test ms_class in graph. | |||
| Expectation: No exception. | |||
| """ | |||
| with pytest.raises(TypeError): | |||
| @ms_class | |||
| def func(x, y): | |||
| return x + y | |||
| func(1, 2) | |||
| def test_fallback_raise_error_not_class_instance(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| Description: Test ms_class in graph. | |||
| Expectation: No exception. | |||
| """ | |||
| @ms_class | |||
| class InnerNet: | |||
| def __init__(self): | |||
| self.number = 1 | |||
| class Net(nn.Cell): | |||
| def construct(self): | |||
| out = InnerNet().number | |||
| return out | |||
| with pytest.raises(ValueError): | |||
| net = Net() | |||
| net() | |||
| def test_fallback_raise_error_decorate_cell(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| Description: Test ms_class in graph. | |||
| Expectation: No exception. | |||
| """ | |||
| @ms_class | |||
| class Net(nn.Cell): | |||
| def construct(self, x): | |||
| return x | |||
| with pytest.raises(TypeError): | |||
| x = Tensor(1) | |||
| net = Net() | |||
| net(x) | |||