diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc index fcad535cc9..71392af278 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc @@ -38,6 +38,79 @@ using TensorPtr = mindspore::tensor::TensorPtr; using MetaTensor = mindspore::tensor::MetaTensor; using MetaTensorPtr = mindspore::tensor::MetaTensorPtr; +using InstanceCheckFunc = std::function; +using InstanceConvertFunc = std::function; +class DataConverter { + public: + explicit DataConverter(InstanceConvertFunc convert_func) : convert_func_(std::move(convert_func)) {} + virtual ~DataConverter() = default; + virtual bool Matched(const py::object &obj) = 0; + virtual ValuePtr ConvertPyObject(const py::object &obj, bool use_sig, const TypePtr &dtype) { + if (convert_func_ == nullptr) { + MS_LOG(EXCEPTION) << "convert func is null"; + } + return convert_func_(obj, use_sig, dtype); + } + + private: + InstanceConvertFunc convert_func_ = nullptr; +}; +using DataConverterPtr = std::shared_ptr; + +using ArgsObjConvertFunc = std::function; +using ArgsObjSigConvertFunc = std::function; +using ArgsOjbTypeConvertFunc = std::function; + +// Convert the data according instance type +template +class ByTypeDataConverter : public DataConverter { + public: + explicit ByTypeDataConverter(const InstanceConvertFunc &convert_func) + : DataConverter(convert_func), check_func_(py::isinstance) {} + explicit ByTypeDataConverter(const ValuePtr &converted_type) + : DataConverter( + [converted_type](const py::object &, bool, const TypePtr &) -> ValuePtr { return converted_type; }), + check_func_(py::isinstance) {} + explicit ByTypeDataConverter(const ArgsObjConvertFunc &convert_func) + : DataConverter( + [convert_func](const py::object &obj, bool, const TypePtr &) -> ValuePtr { return convert_func(obj); }), + check_func_(py::isinstance) {} + explicit ByTypeDataConverter(const ArgsObjSigConvertFunc &convert_func) + : DataConverter([convert_func](const py::object &obj, bool use_sig, const TypePtr &) -> ValuePtr { + return convert_func(obj, use_sig); + }), + check_func_(py::isinstance) {} + explicit ByTypeDataConverter(const ArgsOjbTypeConvertFunc &convert_func) + : DataConverter([convert_func](const py::object &obj, bool, const TypePtr &dtype) -> ValuePtr { + return convert_func(obj, dtype); + }), + check_func_(py::isinstance) {} + ~ByTypeDataConverter() override = default; + + bool Matched(const py::object &obj) override { return check_func_ != nullptr ? check_func_(obj) : false; } + + private: + InstanceCheckFunc check_func_ = nullptr; +}; + +// Convert the data according object attribute. +class ByAttrDataConverter : public DataConverter { + public: + ByAttrDataConverter(const char *attr_name, const ArgsObjConvertFunc &convert_func) + : DataConverter( + [convert_func](const py::object &obj, bool, const TypePtr &) -> ValuePtr { return convert_func(obj); }), + attr_name_(attr_name) {} + ByAttrDataConverter(const char *attr_name, const ArgsObjSigConvertFunc &convert_func) + : DataConverter([convert_func](const py::object &obj, bool use_sig, const TypePtr &) -> ValuePtr { + return convert_func(obj, use_sig); + }), + attr_name_(attr_name) {} + bool Matched(const py::object &obj) override { return py::hasattr(obj, attr_name_); } + + private: + const char *attr_name_ = nullptr; +}; + FuncGraphPtr ConvertToBpropCut(const py::object &obj) { std::vector results = data_converter::GetObjKey(obj); std::string obj_key = results[0]; @@ -69,7 +142,7 @@ FuncGraphPtr ConvertToBpropCut(const py::object &obj) { } namespace { -bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) { +ValuePtr ConvertTuple(const py::object &obj, bool use_signature) { MS_LOG(DEBUG) << "Converting python tuple"; auto tuple = obj.cast(); std::vector value_list; @@ -77,16 +150,14 @@ bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signatur ValuePtr out = nullptr; bool success = ConvertData(tuple[it], &out, use_signature); if (!success) { - return false; + return nullptr; } value_list.push_back(out); } - *data = std::make_shared(value_list); - - return true; + return std::make_shared(value_list); } -bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature) { +ValuePtr ConvertList(const py::object &obj, bool use_signature) { MS_LOG(DEBUG) << "Converting python list"; auto list = obj.cast(); @@ -95,15 +166,14 @@ bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature ValuePtr out = nullptr; bool success = ConvertData(list[it], &out, use_signature); if (!success) { - return false; + return nullptr; } value_list.push_back(out); } - *data = std::make_shared(value_list); - return true; + return std::make_shared(value_list); } -bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signature) { +ValuePtr ConvertCellList(const py::object &obj, bool use_signature) { MS_LOG(DEBUG) << "Converting cell list"; py::sequence list = obj; std::vector value_list; @@ -111,15 +181,14 @@ bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signa ValuePtr out = nullptr; bool success = ConvertData(list[it], &out, use_signature); if (!success) { - return false; + return nullptr; } value_list.push_back(out); } - *data = std::make_shared(value_list); - return true; + return std::make_shared(value_list); } -bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) { +ValuePtr ConvertDict(const py::object &obj, bool use_signature) { MS_LOG(DEBUG) << "Converting python dict"; auto dict_values = obj.cast(); @@ -127,36 +196,37 @@ bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) { for (auto item : dict_values) { if (!py::isinstance(item.first)) { MS_LOG(ERROR) << "The key of dict is only support str."; - return false; + return nullptr; } std::string key = py::str(item.first); ValuePtr out = nullptr; bool success = ConvertData(dict_values[item.first], &out, use_signature); if (!success) { - return false; + return nullptr; } key_values.emplace_back(key, out); } - *data = std::make_shared(key_values); - return true; + return std::make_shared(key_values); } -void ConvertNameSpace(const py::object &obj, ValuePtr *const data) { +ValuePtr ConvertNameSpace(const py::object &obj) { MS_LOG(DEBUG) << "Converting python module"; py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj); - *data = std::make_shared(RESOLVE_NAMESPACE_NAME_MODULE, py::cast(module_namespace)); + auto converted = std::make_shared(RESOLVE_NAMESPACE_NAME_MODULE, py::cast(module_namespace)); + return converted; } -void ConvertDataClass(py::object obj, ValuePtr *const data) { +ValuePtr ConvertDataClass(const py::object &obj) { MS_LOG(DEBUG) << "Converting dataclass"; // Maybe the obj is dataclass define auto desc = py::cast(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj)); // desc has format "", strip the '<' and '>' by offset 1; - *data = std::make_shared(obj, std::string(desc.begin() + 1, desc.end() - 1)); + auto converted = std::make_shared(obj, std::string(desc.begin() + 1, desc.end() - 1)); + return converted; } -bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = false) { +ValuePtr ConvertPrimitive(const py::object &obj, bool use_signature = false) { MS_LOG(DEBUG) << "Converting primitive object" << use_signature; // need check the primitive is class type or instance @@ -164,96 +234,81 @@ bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = if (obj_type == RESOLVE_TYPE_CLASS_TYPE) { auto desc = py::cast(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj)); // desc has format "", strip the '<' and '>' by offset 1; - *data = std::make_shared(obj, std::string(desc.begin() + 1, desc.end() - 1)); - } else { - auto primitive = obj.cast(); - if (primitive == nullptr) { - MS_LOG(ERROR) << "Resolve Primitive error, get ptr is null"; - return false; - } - if (py::hasattr(obj, "__setattr_flag__")) { - if (py::hasattr(obj, "_clone")) { - auto clone_fn = obj.attr("_clone"); - py::object new_obj = clone_fn(); - primitive = new_obj.cast(); - } - } - if (use_signature) { - *data = std::make_shared(primitive->name(), primitive); - } else { - *data = primitive; - } - MS_LOG(DEBUG) << "Converting primitive object ok " << (*data)->ToString(); + return std::make_shared(obj, std::string(desc.begin() + 1, desc.end() - 1)); + } + auto primitive = obj.cast(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "Resolve Primitive error, get ptr is null"; + return nullptr; } - return true; + if (py::hasattr(obj, "__setattr_flag__") && py::hasattr(obj, "_clone")) { + auto clone_fn = obj.attr("_clone"); + py::object new_obj = clone_fn(); + primitive = new_obj.cast(); + } + if (use_signature) { + return std::make_shared(primitive->name(), primitive); + } + return primitive; } -bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_signature = false) { +ValuePtr ConvertMetaFuncGraph(const py::object &obj, bool use_signature = false) { MS_LOG(DEBUG) << "Converting MetaFuncGraph object"; auto meta = obj.cast(); if (meta == nullptr) { MS_LOG(ERROR) << "Resolve MetaFuncGraph error, get ptr is null"; - return false; + return nullptr; } if (use_signature) { - *data = std::make_shared(meta->name(), meta); - } else { - *data = meta; + return std::make_shared(meta->name(), meta); } - return true; + return meta; } -bool ConvertFuncGraph(const py::object &obj, ValuePtr *const data) { +ValuePtr ConvertFuncGraph(const py::object &obj) { MS_LOG(DEBUG) << "Converting FuncGraph object"; auto func_graph = obj.cast(); if (func_graph == nullptr) { MS_LOG(ERROR) << "Resolve FuncGraph error, get ptr is null"; - return false; + return nullptr; } auto new_fg = BasicClone(func_graph); new_fg->set_attr("is_load", MakeValue(true)); - *data = new_fg; - return true; + return new_fg; } -bool ConvertSlice(const py::object &obj, ValuePtr *const data) { +ValuePtr ConvertSlice(const py::object &obj) { MS_LOG(DEBUG) << "Converting slice object"; auto slice_obj = obj.cast(); - auto convert_func = [obj](std::string attr) -> ValuePtr { + auto convert_func = [obj](const std::string &attr) -> ValuePtr { auto py_attr = py::getattr(obj, attr.c_str()); if (py::isinstance(py_attr)) { return kNone; - } else if (py::isinstance(py_attr)) { - int64_t value = py::cast(py_attr); + } + if (py::isinstance(py_attr)) { + auto value = py::cast(py_attr); return MakeValue(value); - } else { - MS_LOG(EXCEPTION) << "Slice should contain only int64_t or none"; } + MS_LOG(EXCEPTION) << "Slice should contain only int64_t or none"; }; ValuePtr start = convert_func("start"); ValuePtr stop = convert_func("stop"); ValuePtr step = convert_func("step"); - *data = std::make_shared(start, stop, step); - return true; + return std::make_shared(start, stop, step); } -bool ConvertCellObjToFuncGraph(const CellPtr &cell, ValuePtr *const data) { - auto obj = py::cast(cell); +ValuePtr ConvertCellObjToFuncGraph(const py::object &obj) { FuncGraphPtr func_graph = ConvertToFuncGraph(obj); if (func_graph == nullptr) { MS_LOG(ERROR) << "Parse resolve function error."; - return false; + return nullptr; } // if the cell object has specified bprop, it has user-defined bprop function parse and record it if (py::hasattr(obj, CUSTOM_BPROP_NAME)) { - FuncGraphPtr bprop_graph = nullptr; bool enable_bprop_debug = py::cast(py::getattr(obj, "bprop_debug")); - if (enable_bprop_debug) { - bprop_graph = ConvertToBpropCut(obj); - } else { - bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD); - } + FuncGraphPtr bprop_graph = + enable_bprop_debug ? ConvertToBpropCut(obj) : ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD); if (bprop_graph != nullptr) { (void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph))); (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph))); @@ -264,200 +319,183 @@ bool ConvertCellObjToFuncGraph(const CellPtr &cell, ValuePtr *const data) { auto stage = py::cast(py::getattr(obj, STAGE_NAME)); func_graph->set_stage(stage); } - *data = func_graph; - return true; + return func_graph; } -bool ConvertOtherObj(py::object obj, ValuePtr *const data) { +ValuePtr ConvertOtherObj(const py::object &obj) { auto obj_type = data_converter::GetObjType(obj); MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " "; if (obj_type == RESOLVE_TYPE_CLASS_TYPE) { MS_LOG(DEBUG) << "Resolve the class type, need create class instance."; std::string desc = py::str(obj); // desc has format "", strip the '<' and '>' by offset 1; - *data = std::make_shared(obj, std::string(desc.begin() + 1, desc.end() - 1)); - return true; + return std::make_shared(obj, std::string(desc.begin() + 1, desc.end() - 1)); } if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD) { MS_LOG(DEBUG) << "Convert the obj to func graph, type is " << obj_type; FuncGraphPtr func_graph = ConvertToFuncGraph(obj); if (func_graph == nullptr) { MS_LOG(ERROR) << "Parse resolve function error."; - return false; + return nullptr; } - *data = func_graph; - return true; + return func_graph; } if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) { // Create the namespace for common class instance // When the obj is Cell, default parse the 'construct' py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj); - *data = std::make_shared(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); - return true; + return std::make_shared(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); } MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj)); - return false; + return nullptr; } template -bool ConvertNumberWithType(const T &obj, ValuePtr *const data, TypePtr dtype) { +ValuePtr ConvertNumberWithType(const T &obj, TypePtr dtype) { + ValuePtr data = nullptr; auto int_dypte = dyn_cast(dtype); if (int_dypte != nullptr) { switch (int_dypte->nbits()) { case 8: - *data = std::make_shared(obj); + data = std::make_shared(obj); break; case 16: - *data = std::make_shared(obj); + data = std::make_shared(obj); break; case 32: - *data = std::make_shared(obj); + data = std::make_shared(obj); break; case 64: - *data = std::make_shared(obj); + data = std::make_shared(obj); break; default: - *data = std::make_shared(obj); + data = std::make_shared(obj); } - return true; + return data; } auto uint_dypte = dyn_cast(dtype); if (uint_dypte != nullptr) { switch (uint_dypte->nbits()) { case 8: - *data = std::make_shared(obj); + data = std::make_shared(obj); break; case 16: - *data = std::make_shared(obj); + data = std::make_shared(obj); break; case 32: - *data = std::make_shared(obj); + data = std::make_shared(obj); break; case 64: - *data = std::make_shared(obj); + data = std::make_shared(obj); break; default: - *data = std::make_shared(obj); + data = std::make_shared(obj); } - return true; + return data; } auto float_dypte = dyn_cast(dtype); if (float_dypte != nullptr) { switch (float_dypte->nbits()) { case 32: - *data = std::make_shared(obj); + data = std::make_shared(obj); break; case 64: - *data = std::make_shared(obj); + data = std::make_shared(obj); break; default: - *data = std::make_shared(obj); + data = std::make_shared(obj); } - return true; + return data; } - - return false; + return nullptr; } -bool ConvertIntegerWithType(const int64_t &obj, ValuePtr *const data, TypePtr dtype = nullptr) { +ValuePtr ConvertIntegerWithType(const py::object &obj, const TypePtr &dtype = nullptr) { + auto obj_int64 = py::cast(obj); if (dtype == nullptr) { - *data = std::make_shared(obj); - return true; + return std::make_shared(obj_int64); } - - return ConvertNumberWithType(obj, data, dtype); + return ConvertNumberWithType(obj_int64, dtype); } -bool ConvertFloatWithType(const float &obj, ValuePtr *const data, TypePtr dtype = nullptr) { +ValuePtr ConvertFloatWithType(const py::object &obj, const TypePtr &dtype = nullptr) { + auto obj_float64 = py::cast(obj); if (dtype == nullptr) { - *data = std::make_shared(obj); - return true; + return std::make_shared(obj_float64); } + return ConvertNumberWithType(obj_float64, dtype); +} - return ConvertNumberWithType(obj, data, dtype); +template +ValuePtr PyCast(const py::object &obj) { + return std::make_shared(py::cast(obj)); } -} // namespace -bool ConvertSingleData(const py::object &obj, ValuePtr *const data) { - MS_EXCEPTION_IF_NULL(data); - ValuePtr converted = nullptr; - if (py::isinstance(obj)) { - converted = kNone; - } else if (py::isinstance(obj)) { - converted = std::make_shared(py::cast(obj)); - } else if (py::isinstance(obj)) { - converted = std::make_shared(py::cast(obj)); - } else if (py::isinstance(obj)) { - converted = kEllipsis; - } else if (py::isinstance(obj)) { - ConvertNameSpace(obj, &converted); - } else if (py::hasattr(obj, PYTHON_DATACLASS_FIELDS)) { - ConvertDataClass(obj, &converted); - } else if (py::isinstance(obj)) { - converted = obj.cast(); - } else if (py::isinstance(obj)) { - converted = obj.cast(); - } else if (py::isinstance(obj)) { - converted = obj.cast(); - } else if (py::isinstance(obj)) { - converted = obj.cast(); - } else if (py::isinstance(obj)) { - converted = obj.cast(); - } else if (py::isinstance(obj)) { - auto env = obj.cast>(); - converted = env; - } else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) { - converted = std::make_shared(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj); - } else { - return false; - } - *data = converted; - return true; +template +ValuePtr ObjCast(const py::object &obj) { + return obj.cast(); +} + +std::vector GetDataConverters() { + static std::vector data_converters = { + // Convert data by python object type. + std::make_shared>(kNone), + std::make_shared>(PyCast), + std::make_shared>(PyCast), + std::make_shared>(kEllipsis), + std::make_shared>(ConvertNameSpace), + std::make_shared(PYTHON_DATACLASS_FIELDS, ConvertDataClass), + std::make_shared>(ObjCast), + std::make_shared>(ObjCast), + std::make_shared>(ObjCast), + std::make_shared>(ObjCast), + std::make_shared>(ObjCast), + std::make_shared>(ObjCast>), + std::make_shared(PYTHON_CLASS_MEMBER_NAMESPACE, + [](const py::object &obj) -> ValuePtr { + return std::make_shared(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, + obj); + }), + std::make_shared>(ConvertIntegerWithType), + std::make_shared>(ConvertFloatWithType), + std::make_shared>(ConvertDict), + std::make_shared>(ConvertSlice), + std::make_shared>(ConvertTuple), + std::make_shared(PYTHON_CELL_AS_LIST, ConvertCellList), + std::make_shared>(ConvertCellObjToFuncGraph), + std::make_shared>(ConvertList), + std::make_shared(PYTHON_PRIMITIVE_FLAG, ConvertPrimitive), + std::make_shared>(ConvertMetaFuncGraph), + std::make_shared>(ConvertFuncGraph), + }; + return data_converters; } +} // namespace -bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, TypePtr dtype) { +bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, const TypePtr &dtype) { // check parameter valid if (data == nullptr) { MS_LOG(ERROR) << "Data is null pointer"; return false; } - ValuePtr converted = nullptr; - bool ret = ConvertSingleData(obj, &converted); - if (ret) { - *data = converted; - return true; + bool matched = false; + auto &&converters = GetDataConverters(); + for (auto &converter : converters) { + if (converter->Matched(obj)) { + converted = converter->ConvertPyObject(obj, use_signature, dtype); + matched = true; + break; + } } - if (py::isinstance(obj)) { - ret = ConvertIntegerWithType(py::cast(obj), &converted, dtype); - } else if (py::isinstance(obj)) { - ret = ConvertFloatWithType(py::cast(obj), &converted, dtype); - } else if (py::isinstance(obj)) { - ret = ConvertDict(obj, &converted, use_signature); - } else if (py::isinstance(obj)) { - ret = ConvertSlice(obj, &converted); - } else if (py::isinstance(obj)) { - ret = ConvertTuple(obj, &converted, use_signature); - } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) { - ret = ConvertCellList(obj, &converted, use_signature); - } else if (py::isinstance(obj)) { - return ConvertCellObjToFuncGraph(obj.cast(), data); - } else if (py::isinstance(obj)) { - ret = ConvertList(obj, &converted, use_signature); - } else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) { - ret = ConvertPrimitive(obj, &converted, use_signature); - } else if (py::isinstance(obj)) { - ret = ConvertMetaFuncGraph(obj, &converted, use_signature); - } else if (py::isinstance(obj)) { - ret = ConvertFuncGraph(obj, &converted); - } else { - ret = ConvertOtherObj(obj, &converted); + if (!matched) { + converted = ConvertOtherObj(obj); } *data = converted; - return ret; + return converted != nullptr; } // convert data to graph @@ -488,7 +526,6 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString(); data_converter::SetObjGraphValue(obj_key, func_graph); } - return func_graph; } namespace data_converter { @@ -549,13 +586,8 @@ bool IsCellInstance(const py::object &obj) { // create the python class instance py::object CreatePythonObject(const py::object &type, const py::tuple ¶ms) { py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); - py::object obj; - if (params.empty()) { - obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type); - } else { - obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params); - } - return obj; + return params.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type) + : python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params); } // Generate an appropriate name and set to graph debuginfo diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index 1a69c60d4b..f770a89648 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -142,48 +142,6 @@ void Parser::CleanParserResource() { ScopeManager::GetInstance().ClearScope(); } -AnfNodePtr AppendParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) { - MS_EXCEPTION_IF_NULL(func_graph); - auto value = py::cast(obj); - // Parameter object should not be none - if (value == nullptr || !value->is_parameter()) { - MS_LOG(EXCEPTION) << "Parameter error: because obj is not Parameter object."; - } - - // Get the parameter name from parameter object - auto param_name = value->param_info()->name(); - - auto top_graph = func_graph; - // If the parameter node has been created , return it - AnfNodePtr para_node = nullptr; - for (const auto ¶m : top_graph->parameters()) { - auto param_node = dyn_cast(param); - if (param_node != nullptr && param_node->name() == param_name) { - para_node = param; - break; - } - } - if (para_node == nullptr) { - auto node = top_graph->AddWeightParameter(param_name); - - node->set_default_param(value); - // set_abstract for parameter - auto abs = value->ToAbstract(); - // Boarden value - abs = abs->Broaden(); - node->set_abstract(abs); - para_node = node; - } - return para_node; -} - -void UpdataParam(const FuncGraphPtr &top_graph, const py::object &cell) { - auto params = py::list(cell.attr("get_parameters")()).cast>(); - for (const auto ¶m : params) { - (void)AppendParameterObj(top_graph, param); - } -} - void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr &ast) { // Check whether the functions referred by this function and itself are missing 'return' statement auto mng = Manage(fn, false); diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h index 73d6086707..e27444b076 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h @@ -163,7 +163,7 @@ enum ClassInstanceTypeDef { }; // Convert python object to ValuePtr -bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false, TypePtr dtype = nullptr); +bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false, const TypePtr &dtype = nullptr); // Convert python obj to graph FuncGraphPtr ConvertToFuncGraph(const py::object &obj,