/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// NOTE(review): the copy of this file that was reviewed had every "<...>" span stripped (system
// includes, template arguments, the "<class xxxx>" examples in comments) and "&params" mis-encoded.
// They are restored below from how each value is used -- diff against the pristine file to confirm.

#include "pipeline/parse/data_converter.h"
#include <unordered_map>
#include <map>
#include <utility>
#include <string>
#include <memory>
#include <vector>
#include <list>
#include <sstream>
#include "pipeline/parse/resolve.h"
#include "pipeline/parse/python_adapter.h"
#include "operator/ops.h"
#include "operator/composite/composite.h"
#include "ir/func_graph_cloner.h"
#include "utils/symbolic.h"
#include "debug/trace.h"

namespace mindspore {
namespace parse {
using Tensor = mindspore::tensor::Tensor;
using TensorPtr = mindspore::tensor::TensorPtr;

namespace {
// Returns desc without its first and last character. Callers pass descriptions shaped like
// "<class xxxx>". A description shorter than two characters yields an empty string rather than
// an invalid iterator range.
std::string StripEnclosingBrackets(const std::string &desc) {
  if (desc.size() < 2) {
    return std::string();
  }
  return std::string(desc.begin() + 1, desc.end() - 1);
}

// Converts every element of an indexable python sequence with ConvertData and appends the results
// to value_list in order. Stops and returns false at the first element that fails to convert.
template <typename Sequence>
bool ConvertElements(const Sequence &seq, bool use_signature, std::vector<ValuePtr> *const value_list) {
  const size_t count = seq.size();
  value_list->reserve(value_list->size() + count);
  for (size_t index = 0; index < count; ++index) {
    ValuePtr out = nullptr;
    if (!ConvertData(seq[index], &out, use_signature)) {
      return false;
    }
    value_list->push_back(out);
  }
  return true;
}

// Converts a python tuple into a ValueTuple. *data is only written when every element converts.
bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) {
  MS_LOG(DEBUG) << "Converting python tuple";
  py::tuple tuple = obj.cast<py::tuple>();
  std::vector<ValuePtr> value_list;
  if (!ConvertElements(tuple, use_signature, &value_list)) {
    return false;
  }
  *data = std::make_shared<ValueTuple>(value_list);
  return true;
}

// Converts a python list into a ValueList. *data is only written when every element converts.
bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature) {
  MS_LOG(DEBUG) << "Converting python list";
  py::list list = obj.cast<py::list>();
  std::vector<ValuePtr> value_list;
  if (!ConvertElements(list, use_signature, &value_list)) {
    return false;
  }
  *data = std::make_shared<ValueList>(value_list);
  return true;
}

// Converts an object exposing PYTHON_CELL_AS_LIST (see ConvertData) by iterating it as a python
// sequence. The result is stored as a ValueTuple, not a ValueList.
bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signature) {
  MS_LOG(DEBUG) << "Converting cell list";
  py::sequence list = obj;
  std::vector<ValuePtr> value_list;
  if (!ConvertElements(list, use_signature, &value_list)) {
    return false;
  }
  *data = std::make_shared<ValueTuple>(value_list);
  return true;
}

// Converts a python dict into a ValueDictionary. Keys must be python str, any other key type
// raises through MS_LOG(EXCEPTION). Returns false if a value fails to convert.
bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) {
  MS_LOG(DEBUG) << "Converting python dict";

  py::dict dict_values = obj.cast<py::dict>();
  std::vector<std::pair<std::string, ValuePtr>> key_values;
  for (auto item : dict_values) {
    if (!py::isinstance<py::str>(item.first)) {
      MS_LOG(EXCEPTION) << "The key of dict is only support str.";
    }
    std::string key = py::str(item.first);
    ValuePtr out = nullptr;
    if (!ConvertData(dict_values[item.first], &out, use_signature)) {
      return false;
    }
    key_values.emplace_back(key, out);
  }
  *data = std::make_shared<ValueDictionary>(key_values);
  return true;
}

// Wraps a python module in a NameSpace tagged RESOLVE_NAMESPACE_NAME_MODULE. The namespace object
// itself is produced by the python-side helper PYTHON_MOD_GET_MODULE_NAMESPACE.
void ConvertNameSpace(const py::object &obj, ValuePtr *const data) {
  MS_LOG(DEBUG) << "Converting python module";
  py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj);
  *data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, py::cast<py::module>(module_namespace));
}

// Wraps a dataclass object in a ClassObject named after its python description.
void ConvertDataClass(py::object obj, ValuePtr *const data) {
  MS_LOG(DEBUG) << "Converting dataclass";
  // Maybe the obj is dataclass define
  auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
  // desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
  *data = std::make_shared<ClassObject>(obj, StripEnclosingBrackets(desc));
}

// Converts an object carrying PYTHON_PRIMITIVE_FLAG.
// - A primitive class (RESOLVE_TYPE_CLASS_TYPE) becomes a ClassType.
// - A primitive instance becomes the primitive itself, or a DoSignaturePrimitive wrapping it when
//   use_signature is set.
// Returns false when the python object cannot be cast to a primitive pointer.
bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = false) {
  MS_LOG(DEBUG) << "Converting primitive object";

  // need check the primitive is class type or instance
  auto obj_type = data_converter::GetObjType(obj);
  if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
    auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
    // desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
    *data = std::make_shared<ClassType>(obj, StripEnclosingBrackets(desc));
    return true;
  }

  auto primitive = obj.cast<PrimitivePyPtr>();
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "Resolve Primitive error, get ptr is null";
    return false;
  }
  // When both attributes are present, the object returned by _clone() is used instead of obj.
  if (py::hasattr(obj, "__setattr_flag__") && py::hasattr(obj, "_clone")) {
    auto clone_fn = obj.attr("_clone");
    py::object new_obj = clone_fn();
    primitive = new_obj.cast<PrimitivePyPtr>();
    // The clone is dereferenced below, so it gets the same null check as the original object.
    if (primitive == nullptr) {
      MS_LOG(ERROR) << "Resolve Primitive error, get ptr is null";
      return false;
    }
  }
  if (use_signature) {
    *data = std::make_shared<prim::DoSignaturePrimitive>(primitive->name(), primitive);
  } else {
    *data = primitive;
  }
  return true;
}

// Converts an object carrying PYTHON_METAFUNCGRAPH_FLAG, wrapping it in a DoSignaturePrimitive
// when use_signature is set. Returns false when the cast yields a null pointer.
bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_signature = false) {
  MS_LOG(DEBUG) << "Converting MetaFuncGraph object";
  auto meta = obj.cast<MetaFuncGraphPtr>();
  if (meta == nullptr) {
    MS_LOG(ERROR) << "Resolve MetaFuncGraph error, get ptr is null";
    return false;
  }
  if (use_signature) {
    *data = std::make_shared<prim::DoSignaturePrimitive>(meta->name(), meta);
  } else {
    *data = meta;
  }
  return true;
}

// Converts an object carrying PYTHON_DTYPE_FLAG to its TypePtr.
bool ConvertDataType(const py::object &obj, ValuePtr *const data) {
  MS_LOG(DEBUG) << "Converting type object";
  auto typeptr = obj.cast<TypePtr>();
  if (typeptr == nullptr) {
    MS_LOG(ERROR) << "Resolve TypePtr error, get ptr is null";
    return false;
  }
  *data = typeptr;
  return true;
}

// Converts an object carrying PYTHON_TENSOR_FLAG to its TensorPtr.
bool ConvertTensor(const py::object &obj, ValuePtr *const data) {
  MS_LOG(DEBUG) << "Converting tensor object";
  auto m_tensor = obj.cast<TensorPtr>();
  if (m_tensor == nullptr) {
    MS_LOG(ERROR) << "Resolve Tensor error, get ptr is null";
    return false;
  }
  *data = m_tensor;
  return true;
}

// Fallback conversion for objects that no earlier branch of ConvertData matched. It dispatches on
// the resolve type reported by the python side:
// - class type: ClassType
// - function or method: parsed FuncGraph
// - class instance: parsed FuncGraph for Cell instances, otherwise a class-member NameSpace
// Any other resolve type is an error.
bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
  auto obj_type = data_converter::GetObjType(obj);
  MS_LOG(DEBUG) << "Converting the object(" << static_cast<std::string>(py::str(obj))
                << ") detail type: " << obj_type << " ";

  if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
    MS_LOG(DEBUG) << "Resolve the class type, need create class instance.";
    std::string desc = py::str(obj);
    // desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
    *data = std::make_shared<ClassType>(obj, StripEnclosingBrackets(desc));
    return true;
  }

  if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD) {
    MS_LOG(DEBUG) << "Convert the obj to func graph, type is " << obj_type;
    FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
    if (func_graph == nullptr) {
      MS_LOG(ERROR) << "Parse resolve function error.";
      return false;
    }
    *data = func_graph;
    return true;
  }

  if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) {
    // Create the namespace for common class instance
    // When the obj is Cell, default parse the 'construct'
    if (data_converter::IsCellInstance(obj)) {
      FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
      if (func_graph == nullptr) {
        MS_LOG(ERROR) << "Parse resolve function error.";
        return false;
      }
      // if the cell object has specified bprop, it has user-defined bprop function parse and record it
      if (py::hasattr(obj, "bprop")) {
        FuncGraphPtr bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
        if (bprop_graph != nullptr) {
          // Cross-link the two graphs: forward -> "bprop", backward -> "primal".
          (void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph)));
          (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
          func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
        }
      }
      *data = func_graph;
    } else {
      py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
      py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
      *data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
    }
    return true;
  }

  MS_LOG(ERROR) << "Resolve type is invalid " << static_cast<std::string>(py::str(obj));
  return false;
}
}  // namespace

// Converts a python object into a graph Value.
// - obj: the python object to convert.
// - data: output; must not be null. It is always assigned once the null check passes, and it may
//   be left as nullptr when the conversion fails.
// - use_signature: forwarded to the container, primitive and MetaFuncGraph converters.
// Returns true on success. The branches are tried in order, so an object exposing
// PYTHON_CELL_AS_LIST is handled as a cell list before the plain py::list check is reached.
bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature) {
  // check parameter valid
  if (data == nullptr) {
    MS_LOG(ERROR) << "Data is null pointer";
    return false;
  }

  bool ret = true;
  ValuePtr converted = nullptr;
  if (py::isinstance<py::none>(obj)) {
    converted = kNone;
  } else if (py::isinstance<py::bool_>(obj)) {
    converted = std::make_shared<BoolImm>(py::cast<bool>(obj));
  } else if (py::isinstance<py::int_>(obj)) {
    converted = std::make_shared<Int32Imm>(py::cast<int>(obj));
  } else if (py::isinstance<py::float_>(obj)) {
    converted = std::make_shared<FP32Imm>(py::cast<float>(obj));
  } else if (py::isinstance<py::str>(obj)) {
    converted = std::make_shared<StringImm>(py::cast<std::string>(obj));
  } else if (py::isinstance<py::dict>(obj)) {
    ret = ConvertDict(obj, &converted, use_signature);
  } else if (py::isinstance<py::tuple>(obj)) {
    ret = ConvertTuple(obj, &converted, use_signature);
  } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {
    ret = ConvertCellList(obj, &converted, use_signature);
  } else if (py::isinstance<py::list>(obj)) {
    ret = ConvertList(obj, &converted, use_signature);
  } else if (py::isinstance<py::module>(obj)) {
    ConvertNameSpace(obj, &converted);
  } else if (py::hasattr(obj, PYTHON_DATACLASS_FIELDS)) {
    ConvertDataClass(obj, &converted);
  } else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) {
    ret = ConvertPrimitive(obj, &converted, use_signature);
  } else if (py::hasattr(obj, PYTHON_METAFUNCGRAPH_FLAG)) {
    ret = ConvertMetaFuncGraph(obj, &converted, use_signature);
  } else if (py::hasattr(obj, PYTHON_DTYPE_FLAG)) {
    ret = ConvertDataType(obj, &converted);
  } else if (py::hasattr(obj, PYTHON_TENSOR_FLAG)) {
    ret = ConvertTensor(obj, &converted);
  } else if (py::hasattr(obj, PYTHON_ENVINSTANCE_FLAG)) {
    std::shared_ptr<EnvInstance> env = obj.cast<std::shared_ptr<EnvInstance>>();
    converted = env;
  } else {
    ret = ConvertOtherObj(obj, &converted);
  }

  *data = converted;
  return ret;
}

// convert data to graph
// Parses obj into a FuncGraph using the given parse method, or returns the cached graph for it.
// The cache key is the object id from GetObjKey with the parse-method name appended. Returns
// nullptr when parsing fails. A newly parsed graph is cached, and also recorded under the object
// key when that key is not empty.
FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python_mod_get_parse_method) {
  // GetObjKey guarantees exactly two entries: [0] the object id, [1] the object key.
  std::vector<std::string> results = data_converter::GetObjKey(obj);
  std::string obj_id = results[0] + python_mod_get_parse_method;
  std::string obj_key = results[1];

  Any value = Any();
  bool is_cache = data_converter::GetObjectValue(obj_id, &value);
  // A cached entry of any other type is ignored and the object is parsed again.
  if (is_cache && value.is<FuncGraphPtr>()) {
    MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id;
    return value.cast<FuncGraphPtr>();
  }

  FuncGraphPtr func_graph = ParsePythonCode(obj, python_mod_get_parse_method);
  if (func_graph == nullptr) {
    MS_LOG(ERROR) << "Parse resolve function error.";
    return nullptr;
  }

  data_converter::MakeProperNameToFuncGraph(func_graph, obj_id);
  data_converter::CacheObjectValue(obj_id, func_graph);
  if (!obj_key.empty()) {
    MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString();
    data_converter::SetObjGraphValue(obj_key, func_graph);
  }

  return func_graph;
}

namespace data_converter {
// Cache of converted values, keyed by the id string built in ConvertToFuncGraph.
static std::unordered_map<std::string, Any> object_map_;

// Graphs recorded per object key. One key may own several graphs.
static std::unordered_map<std::string, std::vector<FuncGraphPtr>> object_graphs_map_;

// Appends data to the list of graphs recorded for obj_key.
void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) {
  object_graphs_map_[obj_key].push_back(data);
  MS_LOG(DEBUG) << "Set func graph size:" << object_graphs_map_.size();
}

// Returns the map from object key to the graphs recorded for it.
const std::unordered_map<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs() {
  MS_LOG(DEBUG) << "Obj size:" << object_graphs_map_.size();
  return object_graphs_map_;
}

// Stores data under obj_key, replacing any previous entry.
void CacheObjectValue(const std::string &obj_key, const Any &data) { object_map_[obj_key] = data; }

// Looks up obj_key in the value cache. On a hit copies the entry into *data and returns true.
// On a miss returns false and leaves *data untouched.
bool GetObjectValue(const std::string &obj_key, Any *const data) {
  auto iter = object_map_.find(obj_key);
  if (iter == object_map_.end()) {
    return false;
  }
  *data = iter->second;
  return true;
}

// Asks the python helper PYTHON_MOD_RESOLVE_GET_OBJ_KEY for the two key strings of obj and returns
// them in the order given. Raises through MS_LOG(EXCEPTION) unless exactly two elements come back.
std::vector<std::string> GetObjKey(const py::object &obj) {
  py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  py::tuple obj_tuple = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_KEY, obj);
  if (obj_tuple.size() != 2) {
    MS_LOG(EXCEPTION) << "Get_obj_key must return 2 elements";
  }
  return {py::cast<std::string>(obj_tuple[0]), py::cast<std::string>(obj_tuple[1])};
}

// get obj detail type
ResolveTypeDef GetObjType(const py::object &obj) {
  py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  auto obj_type = ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast<int32_t>());
  return obj_type;
}

// get class instance detail type
ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) {
  py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  auto class_type =
    ClassInstanceTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_CLASS_INSTANCE_TYPE, obj).cast<int32_t>());
  return class_type;
}

// check the object is Cell Instance
bool IsCellInstance(const py::object &obj) { return GetClassInstanceType(obj) == CLASS_INSTANCE_TYPE_CELL; }

// create the python class instance
// An empty params tuple calls the python helper with the type only, otherwise params is passed
// along as a second argument.
py::object CreatePythonObject(const py::object &type, const py::tuple &params) {
  py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  if (params.size() == 0) {
    return python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type);
  }
  return python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params);
}

// Generate an appropriate name and set to graph debuginfo
// character <> can not used in the dot file, so change to another symbol
void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(func_graph->debug_info());
  // set detail name info of function
  std::ostringstream oss;
  for (const char ch : name) {
    if (ch == '<') {
      oss << "「";
    } else if (ch == '>') {
      oss << "」";
    } else {
      oss << ch;
    }
  }
  func_graph->debug_info()->set_full_name(oss.str());
}

// Converts a python value to a ValuePtr. For objects with a "__parameter__" attribute the
// "default_input" attribute is converted instead. The success flag of ConvertData is deliberately
// ignored, so the returned pointer may be nullptr when the conversion fails.
ValuePtr PyDataToValue(const py::object &obj) {
  py::object to_convert = obj;
  if (py::hasattr(obj, "__parameter__")) {
    to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));
  }
  ValuePtr value = nullptr;
  (void)ConvertData(to_convert, &value);
  return value;
}

// Empties both the value cache and the per-object graph map.
void ClearObjectCache() {
  object_map_.clear();
  object_graphs_map_.clear();
}
}  // namespace data_converter

// Cache of parsed dataclasses, keyed by "<module>.<name>" of the python class.
static std::unordered_map<std::string, ClassPtr> g_dataClassToClass = {};

// parse dataclass to mindspore Class type
// Returns the cached Class when this "<module>.<name>" was parsed before. Otherwise builds one
// from the attributes and methods reported by the python helpers and caches it.
ClassPtr ParseDataClass(const py::object &cls_obj) {
  std::string cls_name = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__name__"));
  std::string cls_module = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__module__"));
  std::string cls = cls_module + "." + cls_name;
  auto iterator = g_dataClassToClass.find(cls);
  if (iterator != g_dataClassToClass.end()) {
    return iterator->second;
  }

  py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  ClassAttrVector attributes;
  py::dict names = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_ATTRS, cls_obj);
  for (const auto &item : names) {
    TypePtr type_value = item.second.cast<TypePtr>();
    MS_EXCEPTION_IF_NULL(type_value);
    MS_LOG(DEBUG) << "(Name: " << py::cast<std::string>(item.first) << ", type: " << type_value->ToString() << ")";
    attributes.push_back(std::make_pair(py::cast<std::string>(item.first), type_value));
  }

  std::unordered_map<std::string, ValuePtr> methods_map;
  py::dict methods = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_METHODS, cls_obj);
  for (const auto &item : methods) {
    std::string fun_name = item.first.cast<std::string>();
    py::object obj = py::cast<py::object>(item.second);
    std::shared_ptr<PyObjectWrapper> method_obj = std::make_shared<PyObjectWrapper>(obj, fun_name);
    methods_map[fun_name] = method_obj;
  }

  std::shared_ptr<Class> me_class = std::make_shared<Class>(Named(cls_name), attributes, methods_map);
  // static Variable for cache
  // cppcheck-suppress unreadVariable
  g_dataClassToClass[cls] = me_class;

  return me_class;
}

// Drops every cached dataclass so later ParseDataClass calls rebuild them.
void CleanDataClassToClassMap() { g_dataClassToClass.clear(); }
}  // namespace parse
}  // namespace mindspore