From: @zh_qh Reviewed-by: @ginfung,@hwhewei Signed-off-by: @hwheweipull/15871/MERGE
| @@ -278,16 +278,42 @@ def _is_dataclass_instance(obj): | |||
| return is_dataclass(obj) and not isinstance(obj, type) | |||
| def create_obj_instance(cls_type, args_tuple=None): | |||
| def _convert_tuple_to_args_kwargs(params): | |||
| args = tuple() | |||
| kwargs = dict() | |||
| for param in params: | |||
| if isinstance(param, dict): | |||
| kwargs.update(param) | |||
| else: | |||
| args += (param,) | |||
| return (args, kwargs) | |||
| def create_obj_instance(cls_type, params=None): | |||
| """Create python instance.""" | |||
| if not isinstance(cls_type, type): | |||
| logger.warning(f"create_obj_instance(), cls_type is not a type, cls_type: {cls_type}") | |||
| return None | |||
| # Check the type, now only support nn.Cell and Primitive. | |||
| obj = None | |||
| if isinstance(cls_type, type): | |||
| # check the type, now only support nn.Cell and Primitive | |||
| if issubclass(cls_type, (nn.Cell, ops.Primitive)): | |||
| if args_tuple is not None: | |||
| obj = cls_type(*args_tuple) | |||
| else: | |||
| obj = cls_type() | |||
| if issubclass(cls_type, (nn.Cell, ops.Primitive)): | |||
| # Check arguments, only support *args or **kwargs. | |||
| if params is None: | |||
| obj = cls_type() | |||
| elif isinstance(params, tuple): | |||
| args, kwargs = _convert_tuple_to_args_kwargs(params) | |||
| logger.debug(f"create_obj_instance(), args: {args}, kwargs: {kwargs}") | |||
| if args and kwargs: | |||
| obj = cls_type(*args, **kwargs) | |||
| elif args: | |||
| obj = cls_type(*args) | |||
| elif kwargs: | |||
| obj = cls_type(**kwargs) | |||
| # If invalid parameters. | |||
| if obj is None: | |||
| raise ValueError(f"When call 'create_instance', the parameter should be *args or **kwargs, " | |||
| f"but got {params.__class__.__name__}, params: {params}") | |||
| return obj | |||
| @@ -481,7 +481,7 @@ std::vector<DataConverterPtr> GetDataConverters() { | |||
| } // namespace | |||
| bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, const TypePtr &dtype) { | |||
| // check parameter valid | |||
| // Check parameter valid | |||
| if (data == nullptr) { | |||
| MS_LOG(ERROR) << "Data is null pointer"; | |||
| return false; | |||
| @@ -503,7 +503,7 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature | |||
| return converted != nullptr; | |||
| } | |||
| // convert data to graph | |||
| // Convert data to graph | |||
| FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python_mod_get_parse_method) { | |||
| std::vector<std::string> results = data_converter::GetObjKey(obj); | |||
| std::string obj_id = results[0] + python_mod_get_parse_method; | |||
| @@ -565,7 +565,7 @@ std::vector<std::string> GetObjKey(const py::object &obj) { | |||
| return {py::cast<std::string>(obj_tuple[0]), py::cast<std::string>(obj_tuple[1])}; | |||
| } | |||
| // get obj detail type | |||
| // Get obj detail type | |||
| ResolveTypeDef GetObjType(const py::object &obj) { | |||
| py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); | |||
| auto obj_type = | |||
| @@ -573,7 +573,7 @@ ResolveTypeDef GetObjType(const py::object &obj) { | |||
| return obj_type; | |||
| } | |||
| // get class instance detail type | |||
| // Get class instance detail type. | |||
| ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) { | |||
| py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); | |||
| auto class_type = | |||
| @@ -581,26 +581,27 @@ ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) { | |||
| return class_type; | |||
| } | |||
| // check the object is Cell Instance | |||
| // Check the object is Cell Instance. | |||
| bool IsCellInstance(const py::object &obj) { | |||
| auto class_type = GetClassInstanceType(obj); | |||
| bool isCell = (class_type == CLASS_INSTANCE_TYPE_CELL); | |||
| return isCell; | |||
| } | |||
| // create the python class instance | |||
| py::object CreatePythonObject(const py::object &type, const py::tuple ¶ms) { | |||
| // Create the python class instance. | |||
| py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs) { | |||
| py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); | |||
| return params.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type) | |||
| : python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params); | |||
| // `args_kwargs` maybe a tuple(*args), tuple(**kwargs), or tuple(*args, **kwargs). | |||
| return args_kwargs.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type) | |||
| : python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, args_kwargs); | |||
| } | |||
| // Generate an appropriate name and set to graph debuginfo | |||
| // character <> can not used in the dot file, so change to another symbol | |||
| // Generate an appropriate name and set to graph debuginfo, | |||
| // character <> can not used in the dot file, so change to another symbol. | |||
| void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(func_graph->debug_info()); | |||
| // set detail name info of function | |||
| // Set detail name info of function | |||
| std::ostringstream oss; | |||
| for (size_t i = 0; i < name.size(); i++) { | |||
| if (name[i] == '<') { | |||
| @@ -629,7 +630,7 @@ void ClearObjectCache() { | |||
| static std::unordered_map<std::string, ClassPtr> g_dataClassToClass = {}; | |||
| // parse dataclass to mindspore Class type | |||
| // Parse dataclass to mindspore Class type | |||
| ClassPtr ParseDataClass(const py::object &cls_obj) { | |||
| std::string cls_name = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__name__")); | |||
| std::string cls_module = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__module__")); | |||
| @@ -44,7 +44,7 @@ ResolveTypeDef GetObjType(const py::object &obj); | |||
| ClassInstanceTypeDef GetClassInstanceType(const py::object &obj); | |||
| bool IsCellInstance(const py::object &obj); | |||
| py::object CreatePythonObject(const py::object &type, const py::tuple ¶ms); | |||
| py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs); | |||
| void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name); | |||
| ValuePtr PyDataToValue(const py::object &obj); | |||
| void ClearObjectCache(); | |||
| @@ -1107,7 +1107,7 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator { | |||
| MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty"; | |||
| } | |||
| // get the type parameter | |||
| // Get the type parameter. | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | |||
| TypePtr type = args_spec_list[0]->GetTypeTrack(); | |||
| if (type->type_id() != kMetaTypeTypeType) { | |||
| @@ -1131,17 +1131,17 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator { | |||
| auto class_type = type_obj->obj(); | |||
| MS_LOG(DEBUG) << "Get class type is " << type_obj->ToString() << "."; | |||
| // get the create instance obj's parameters | |||
| pybind11::tuple params = GetParameters(args_spec_list); | |||
| // Get the create instance obj's parameters, `params` may contain tuple(args, kwargs). | |||
| py::tuple params = GetParameters(args_spec_list); | |||
| // create class instance | |||
| // Create class instance. | |||
| auto obj = parse::data_converter::CreatePythonObject(class_type, params); | |||
| if (py::isinstance<py::none>(obj)) { | |||
| MS_LOG(EXCEPTION) << "Create python object" << py::str(class_type) | |||
| << " failed, only support create Cell or Primitive object."; | |||
| } | |||
| // process the object | |||
| // Process the object. | |||
| ValuePtr converted_ret = nullptr; | |||
| bool converted = parse::ConvertData(obj, &converted_ret, true); | |||
| if (!converted) { | |||
| @@ -167,6 +167,14 @@ py::object ValuePtrToPyData(const ValuePtr &value) { | |||
| } else if (value->isa<AnyValue>() || value->isa<None>() || value->isa<Monad>() || value->isa<FuncGraph>()) { | |||
| // FuncGraph is not used in the backend, return None | |||
| ret = py::none(); | |||
| } else if (value->isa<KeywordArg>()) { | |||
| auto abs_keyword_arg = value->ToAbstract()->cast<abstract::AbstractKeywordArgPtr>(); | |||
| auto key = abs_keyword_arg->get_key(); | |||
| auto val = abs_keyword_arg->get_arg()->BuildValue(); | |||
| auto py_value = ValuePtrToPyData(val); | |||
| auto kwargs = py::kwargs(); | |||
| kwargs[key.c_str()] = py_value; | |||
| ret = kwargs; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Unsupported convert value: " << value->ToString() << " to a PyData."; | |||
| } | |||
| @@ -27,7 +27,7 @@ import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore import context | |||
| from mindspore.common.api import ms_function | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.common import Tensor, Parameter | |||
| from mindspore.ops import operations as P | |||
| from ...ut_filter import non_graph_engine | |||
| @@ -48,7 +48,7 @@ class Net(nn.Cell): | |||
| return x | |||
| # Test: creat CELL OR Primitive instance on construct | |||
| # Test: Create Cell OR Primitive instance on construct | |||
| @non_graph_engine | |||
| def test_create_cell_object_on_construct(): | |||
| """ test_create_cell_object_on_construct """ | |||
| @@ -65,7 +65,7 @@ def test_create_cell_object_on_construct(): | |||
| log.debug("finished test_create_object_on_construct") | |||
| # Test: creat CELL OR Primitive instance on construct | |||
| # Test: Create Cell OR Primitive instance on construct | |||
| class Net1(nn.Cell): | |||
| """ Net1 definition """ | |||
| @@ -92,7 +92,7 @@ def test_create_primitive_object_on_construct(): | |||
| log.debug("finished test_create_object_on_construct") | |||
| # Test: creat CELL OR Primitive instance on construct use many parameter | |||
| # Test: Create Cell OR Primitive instance on construct use many parameter | |||
| class NetM(nn.Cell): | |||
| """ NetM definition """ | |||
| @@ -120,7 +120,7 @@ class NetC(nn.Cell): | |||
| return x | |||
| # Test: creat CELL OR Primitive instance on construct | |||
| # Test: Create Cell OR Primitive instance on construct | |||
| @non_graph_engine | |||
| def test_create_cell_object_on_construct_use_many_parameter(): | |||
| """ test_create_cell_object_on_construct_use_many_parameter """ | |||
| @@ -135,3 +135,60 @@ def test_create_cell_object_on_construct_use_many_parameter(): | |||
| print(np1) | |||
| print(out_me1) | |||
| log.debug("finished test_create_object_on_construct") | |||
| class NetD(nn.Cell): | |||
| """ NetD definition """ | |||
| def __init__(self): | |||
| super(NetD, self).__init__() | |||
| def construct(self, x, y): | |||
| concat = P.Concat(axis=1) | |||
| return concat((x, y)) | |||
| # Test: Create Cell OR Primitive instance on construct | |||
| @non_graph_engine | |||
| def test_create_primitive_object_on_construct_use_kwargs(): | |||
| """ test_create_primitive_object_on_construct_use_kwargs """ | |||
| log.debug("begin test_create_primitive_object_on_construct_use_kwargs") | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| x = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32)) | |||
| y = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32)) | |||
| net = NetD() | |||
| net(x, y) | |||
| log.debug("finished test_create_primitive_object_on_construct_use_kwargs") | |||
| class NetE(nn.Cell): | |||
| """ NetE definition """ | |||
| def __init__(self): | |||
| super(NetE, self).__init__() | |||
| self.w = Parameter(Tensor(np.ones([16, 16, 3, 3]).astype(np.float32)), name='w') | |||
| def construct(self, x): | |||
| out_channel = 16 | |||
| kernel_size = 3 | |||
| conv2d = P.Conv2D(out_channel, | |||
| kernel_size, | |||
| 1, | |||
| pad_mode='valid', | |||
| pad=0, | |||
| stride=1, | |||
| dilation=1, | |||
| group=1) | |||
| return conv2d(x, self.w) | |||
| # Test: Create Cell OR Primitive instance on construct | |||
| @non_graph_engine | |||
| def test_create_primitive_object_on_construct_use_args_and_kwargs(): | |||
| """ test_create_primitive_object_on_construct_use_args_and_kwargs """ | |||
| log.debug("begin test_create_primitive_object_on_construct_use_args_and_kwargs") | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| inputs = Tensor(np.ones([1, 16, 16, 16]).astype(np.float32)) | |||
| net = NetE() | |||
| net(inputs) | |||
| log.debug("finished test_create_primitive_object_on_construct_use_args_and_kwargs") | |||