You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

data_converter.cc 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/parse/data_converter.h"
  19. #include <unordered_map>
  20. #include <map>
  21. #include <utility>
  22. #include <string>
  23. #include <memory>
  24. #include <vector>
  25. #include <list>
  26. #include "pipeline/parse/resolve.h"
  27. #include "pipeline/parse/python_adapter.h"
  28. #include "operator/ops.h"
  29. #include "operator/composite/composite.h"
  30. #include "ir/func_graph_cloner.h"
  31. #include "utils/symbolic.h"
  32. #include "debug/trace.h"
  33. namespace mindspore {
  34. namespace parse {
  35. using Tensor = mindspore::tensor::Tensor;
  36. using TensorPtr = mindspore::tensor::TensorPtr;
  37. namespace {
  38. bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) {
  39. MS_LOG(DEBUG) << "Converting python tuple";
  40. py::tuple tuple = obj.cast<py::tuple>();
  41. std::vector<ValuePtr> value_list;
  42. for (size_t it = 0; it < tuple.size(); ++it) {
  43. ValuePtr out = nullptr;
  44. bool success = ConvertData(tuple[it], &out, use_signature);
  45. if (!success) {
  46. return false;
  47. }
  48. value_list.push_back(out);
  49. }
  50. *data = std::make_shared<ValueTuple>(value_list);
  51. return true;
  52. }
  53. bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature) {
  54. MS_LOG(DEBUG) << "Converting python list";
  55. py::list list = obj.cast<py::list>();
  56. std::vector<ValuePtr> value_list;
  57. for (size_t it = 0; it < list.size(); ++it) {
  58. ValuePtr out = nullptr;
  59. bool success = ConvertData(list[it], &out, use_signature);
  60. if (!success) {
  61. return false;
  62. }
  63. value_list.push_back(out);
  64. }
  65. *data = std::make_shared<ValueList>(value_list);
  66. return true;
  67. }
  68. bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signature) {
  69. MS_LOG(DEBUG) << "Converting cell list";
  70. py::sequence list = obj;
  71. std::vector<ValuePtr> value_list;
  72. for (size_t it = 0; it < list.size(); ++it) {
  73. ValuePtr out = nullptr;
  74. bool success = ConvertData(list[it], &out, use_signature);
  75. if (!success) {
  76. return false;
  77. }
  78. value_list.push_back(out);
  79. }
  80. *data = std::make_shared<ValueTuple>(value_list);
  81. return true;
  82. }
  83. bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) {
  84. MS_LOG(DEBUG) << "Converting python dict";
  85. py::dict dict_values = obj.cast<py::dict>();
  86. std::vector<std::pair<std::string, ValuePtr>> key_values;
  87. for (auto item : dict_values) {
  88. if (!py::isinstance<py::str>(item.first)) {
  89. MS_LOG(EXCEPTION) << "The key of dict is only support str.";
  90. }
  91. std::string key = py::str(item.first);
  92. ValuePtr out = nullptr;
  93. bool success = ConvertData(dict_values[item.first], &out, use_signature);
  94. if (!success) {
  95. return false;
  96. }
  97. key_values.emplace_back(key, out);
  98. }
  99. *data = std::make_shared<ValueDictionary>(key_values);
  100. return true;
  101. }
  102. void ConvertNameSpace(const py::object &obj, ValuePtr *const data) {
  103. MS_LOG(DEBUG) << "Converting python module";
  104. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  105. py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj);
  106. *data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, py::cast<py::module>(module_namespace));
  107. }
  108. void ConvertDataClass(py::object obj, ValuePtr *const data) {
  109. MS_LOG(DEBUG) << "Converting dataclass";
  110. // Maybe the obj is dataclass define
  111. auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
  112. // desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
  113. *data = std::make_shared<ClassObject>(obj, std::string(desc.begin() + 1, desc.end() - 1));
  114. }
  115. bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = false) {
  116. MS_LOG(DEBUG) << "Converting primitive object";
  117. // need check the primitive is class type or instance
  118. auto obj_type = data_converter::GetObjType(obj);
  119. if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
  120. auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
  121. // desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
  122. *data = std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
  123. } else {
  124. auto primitive = obj.cast<PrimitivePyPtr>();
  125. if (primitive == nullptr) {
  126. MS_LOG(ERROR) << "Resolve Primitive error, get ptr is null";
  127. return false;
  128. }
  129. if (py::hasattr(obj, "__setattr_flag__")) {
  130. if (py::hasattr(obj, "_clone")) {
  131. auto clone_fn = obj.attr("_clone");
  132. py::object new_obj = clone_fn();
  133. primitive = new_obj.cast<PrimitivePyPtr>();
  134. }
  135. }
  136. if (use_signature) {
  137. *data = std::make_shared<prim::DoSignaturePrimitive>(primitive->name(), primitive);
  138. } else {
  139. *data = primitive;
  140. }
  141. }
  142. return true;
  143. }
  144. bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_signature = false) {
  145. MS_LOG(DEBUG) << "Converting MetaFuncGraph object";
  146. auto meta = obj.cast<MetaFuncGraphPtr>();
  147. if (meta == nullptr) {
  148. MS_LOG(ERROR) << "Resolve MetaFuncGraph error, get ptr is null";
  149. return false;
  150. }
  151. if (use_signature) {
  152. *data = std::make_shared<prim::DoSignaturePrimitive>(meta->name(), meta);
  153. } else {
  154. *data = meta;
  155. }
  156. return true;
  157. }
  158. bool ConvertDataType(const py::object &obj, ValuePtr *const data) {
  159. MS_LOG(DEBUG) << "Converting type object";
  160. auto typeptr = obj.cast<TypePtr>();
  161. if (typeptr == nullptr) {
  162. MS_LOG(ERROR) << "Resolve TypePtr error, get ptr is null";
  163. return false;
  164. }
  165. *data = typeptr;
  166. return true;
  167. }
  168. bool ConvertTensor(const py::object &obj, ValuePtr *const data) {
  169. MS_LOG(DEBUG) << "Converting tensor object";
  170. auto m_tensor = obj.cast<TensorPtr>();
  171. if (m_tensor == nullptr) {
  172. MS_LOG(ERROR) << "Resolve Tensor error, get ptr is null";
  173. return false;
  174. }
  175. *data = m_tensor;
  176. return true;
  177. }
  178. bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
  179. auto obj_type = data_converter::GetObjType(obj);
  180. MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";
  181. if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
  182. MS_LOG(DEBUG) << "Resolve the class type, need create class instance.";
  183. std::string desc = py::str(obj);
  184. // desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
  185. *data = std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
  186. return true;
  187. }
  188. if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD) {
  189. MS_LOG(DEBUG) << "Convert the obj to func graph, type is " << obj_type;
  190. FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
  191. if (func_graph == nullptr) {
  192. MS_LOG(ERROR) << "Parse resolve function error.";
  193. return false;
  194. }
  195. *data = func_graph;
  196. return true;
  197. }
  198. if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) {
  199. // Create the namespace for common class instance
  200. // When the obj is Cell, default parse the 'construct'
  201. if (data_converter::IsCellInstance(obj)) {
  202. FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
  203. if (func_graph == nullptr) {
  204. MS_LOG(ERROR) << "Parse resolve function error.";
  205. return false;
  206. }
  207. // if the cell object has specified bprop, it has user-defined bprop function parse and record it
  208. if (py::hasattr(obj, "bprop")) {
  209. FuncGraphPtr bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
  210. if (bprop_graph != nullptr) {
  211. (void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph)));
  212. (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
  213. func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
  214. }
  215. }
  216. *data = func_graph;
  217. } else {
  218. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  219. py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
  220. *data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
  221. }
  222. return true;
  223. }
  224. MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj));
  225. return false;
  226. }
  227. } // namespace
  228. bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature) {
  229. // check parameter valid
  230. if (data == nullptr) {
  231. MS_LOG(ERROR) << "Data is null pointer";
  232. return false;
  233. }
  234. bool ret = true;
  235. ValuePtr converted = nullptr;
  236. if (py::isinstance<py::none>(obj)) {
  237. converted = kNone;
  238. } else if (py::isinstance<py::bool_>(obj)) {
  239. converted = std::make_shared<BoolImm>(py::cast<bool>(obj));
  240. } else if (py::isinstance<py::int_>(obj)) {
  241. converted = std::make_shared<Int32Imm>(py::cast<int>(obj));
  242. } else if (py::isinstance<py::float_>(obj)) {
  243. converted = std::make_shared<FP32Imm>(py::cast<float>(obj));
  244. } else if (py::isinstance<py::str>(obj)) {
  245. converted = std::make_shared<StringImm>(py::cast<std::string>(obj));
  246. } else if (py::isinstance<py::dict>(obj)) {
  247. ret = ConvertDict(obj, &converted, use_signature);
  248. } else if (py::isinstance<py::tuple>(obj)) {
  249. ret = ConvertTuple(obj, &converted, use_signature);
  250. } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {
  251. ret = ConvertCellList(obj, &converted, use_signature);
  252. } else if (py::isinstance<py::list>(obj)) {
  253. ret = ConvertList(obj, &converted, use_signature);
  254. } else if (py::isinstance<py::module>(obj)) {
  255. ConvertNameSpace(obj, &converted);
  256. } else if (py::hasattr(obj, PYTHON_DATACLASS_FIELDS)) {
  257. ConvertDataClass(obj, &converted);
  258. } else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) {
  259. ret = ConvertPrimitive(obj, &converted, use_signature);
  260. } else if (py::hasattr(obj, PYTHON_METAFUNCGRAPH_FLAG)) {
  261. ret = ConvertMetaFuncGraph(obj, &converted, use_signature);
  262. } else if (py::hasattr(obj, PYTHON_DTYPE_FLAG)) {
  263. ret = ConvertDataType(obj, &converted);
  264. } else if (py::hasattr(obj, PYTHON_TENSOR_FLAG)) {
  265. ret = ConvertTensor(obj, &converted);
  266. } else if (py::hasattr(obj, PYTHON_ENVINSTANCE_FLAG)) {
  267. std::shared_ptr<EnvInstance> env = obj.cast<std::shared_ptr<EnvInstance>>();
  268. converted = env;
  269. } else {
  270. ret = ConvertOtherObj(obj, &converted);
  271. }
  272. *data = converted;
  273. return ret;
  274. }
  275. // convert data to graph
  276. FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python_mod_get_parse_method) {
  277. std::vector<std::string> results = data_converter::GetObjKey(obj);
  278. std::string obj_id = results[0] + python_mod_get_parse_method;
  279. std::string obj_key = results[1];
  280. FuncGraphPtr func_graph = nullptr;
  281. Any value = Any();
  282. bool is_cache = data_converter::GetObjectValue(obj_id, &value);
  283. if (is_cache) {
  284. if (value.is<FuncGraphPtr>()) {
  285. MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id;
  286. func_graph = value.cast<FuncGraphPtr>();
  287. return func_graph;
  288. }
  289. }
  290. func_graph = ParsePythonCode(obj, python_mod_get_parse_method);
  291. if (func_graph == nullptr) {
  292. MS_LOG(ERROR) << "Parse resolve function error.";
  293. return nullptr;
  294. }
  295. data_converter::MakeProperNameToFuncGraph(func_graph, obj_id);
  296. data_converter::CacheObjectValue(obj_id, func_graph);
  297. if (obj_key != "") {
  298. MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString();
  299. data_converter::SetObjGraphValue(obj_key, func_graph);
  300. }
  301. return func_graph;
  302. }
  303. namespace data_converter {
  304. static std::unordered_map<std::string, Any> object_map_ = std::unordered_map<std::string, Any>();
  305. static std::unordered_map<std::string, std::vector<FuncGraphPtr>> object_graphs_map_ =
  306. std::unordered_map<std::string, std::vector<FuncGraphPtr>>();
  307. void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) {
  308. object_graphs_map_[obj_key].push_back(data);
  309. MS_LOG(DEBUG) << "Set func graph size:" << object_graphs_map_.size();
  310. }
  311. const std::unordered_map<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs() {
  312. MS_LOG(DEBUG) << "Obj size:" << object_graphs_map_.size();
  313. return object_graphs_map_;
  314. }
  315. void CacheObjectValue(const std::string &obj_key, const Any &data) { object_map_[obj_key] = data; }
  316. bool GetObjectValue(const std::string &obj_key, Any *const data) {
  317. if (object_map_.count(obj_key)) {
  318. *data = object_map_[obj_key];
  319. return true;
  320. }
  321. return false;
  322. }
  323. std::vector<std::string> GetObjKey(const py::object &obj) {
  324. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  325. py::tuple obj_tuple = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_KEY, obj);
  326. if (obj_tuple.size() != 2) {
  327. MS_LOG(EXCEPTION) << "Get_obj_key must return 2 elements";
  328. }
  329. return {py::cast<std::string>(obj_tuple[0]), py::cast<std::string>(obj_tuple[1])};
  330. }
  331. // get obj detail type
  332. ResolveTypeDef GetObjType(const py::object &obj) {
  333. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  334. auto obj_type =
  335. ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast<int32_t>());
  336. return obj_type;
  337. }
  338. // get class instance detail type
  339. ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) {
  340. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  341. auto class_type =
  342. ClassInstanceTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_CLASS_INSTANCE_TYPE, obj).cast<int32_t>());
  343. return class_type;
  344. }
  345. // check the object is Cell Instance
  346. bool IsCellInstance(const py::object &obj) {
  347. auto class_type = GetClassInstanceType(obj);
  348. bool isCell = (class_type == CLASS_INSTANCE_TYPE_CELL);
  349. return isCell;
  350. }
  351. // create the python class instance
  352. py::object CreatePythonObject(const py::object &type, const py::tuple &params) {
  353. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  354. py::object obj;
  355. if (params.size() == 0) {
  356. obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type);
  357. } else {
  358. obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params);
  359. }
  360. return obj;
  361. }
  362. // Generate an appropriate name and set to graph debuginfo
  363. // character <> can not used in the dot file, so change to another symbol
  364. void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name) {
  365. MS_EXCEPTION_IF_NULL(func_graph);
  366. MS_EXCEPTION_IF_NULL(func_graph->debug_info());
  367. // set detail name info of function
  368. std::ostringstream oss;
  369. for (size_t i = 0; i < name.size(); i++) {
  370. if (name[i] == '<') {
  371. oss << "「";
  372. } else if (name[i] == '>') {
  373. oss << "」";
  374. } else {
  375. oss << name[i];
  376. }
  377. }
  378. func_graph->debug_info()->set_full_name(oss.str());
  379. }
  380. ValuePtr PyDataToValue(const py::object &obj) {
  381. py::object to_convert = obj;
  382. if (py::hasattr(obj, "__parameter__")) {
  383. to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));
  384. }
  385. ValuePtr value = nullptr;
  386. (void)ConvertData(to_convert, &value);
  387. return value;
  388. }
  389. void ClearObjectCache() {
  390. object_map_.clear();
  391. object_graphs_map_.clear();
  392. }
  393. } // namespace data_converter
  394. static std::unordered_map<std::string, ClassPtr> g_dataClassToClass = {};
  395. // parse dataclass to mindspore Class type
  396. ClassPtr ParseDataClass(const py::object &cls_obj) {
  397. std::string cls_name = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__name__"));
  398. std::string cls_module = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__module__"));
  399. std::string cls = cls_module + "." + cls_name;
  400. auto iterator = g_dataClassToClass.find(cls);
  401. if (iterator != g_dataClassToClass.end()) {
  402. return iterator->second;
  403. }
  404. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  405. ClassAttrVector attributes;
  406. py::dict names = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_ATTRS, cls_obj);
  407. for (auto &item : names) {
  408. TypePtr type_value = item.second.cast<TypePtr>();
  409. MS_EXCEPTION_IF_NULL(type_value);
  410. MS_LOG(DEBUG) << "(Name: " << py::cast<std::string>(item.first) << ", type: " << type_value->ToString() << ")";
  411. attributes.push_back(std::make_pair(py::cast<std::string>(item.first), type_value));
  412. }
  413. std::unordered_map<std::string, ValuePtr> methods_map;
  414. py::dict methods = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_METHODS, cls_obj);
  415. for (auto &item : methods) {
  416. std::string fun_name = item.first.cast<std::string>();
  417. py::object obj = py::cast<py::object>(item.second);
  418. std::shared_ptr<PyObjectWrapper> method_obj = std::make_shared<PyObjectWrapper>(obj, fun_name);
  419. methods_map[fun_name] = method_obj;
  420. }
  421. std::shared_ptr<Class> me_class = std::make_shared<Class>(Named(cls_name), attributes, methods_map);
  422. // static Variable for cache
  423. // cppcheck-suppress unreadVariable
  424. g_dataClassToClass[cls] = me_class;
  425. return me_class;
  426. }
  427. void CleanDataClassToClassMap() { g_dataClassToClass.clear(); }
  428. } // namespace parse
  429. } // namespace mindspore