You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

data_converter.cc 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/jit/parse/data_converter.h"
  19. #include <unordered_map>
  20. #include <utility>
  21. #include <string>
  22. #include <memory>
  23. #include <vector>
  24. #include "pipeline/jit/parse/resolve.h"
  25. #include "pipeline/jit/parse/python_adapter.h"
  26. #include "frontend/operator/ops.h"
  27. #include "frontend/operator/composite/composite.h"
  28. #include "ir/func_graph_cloner.h"
  29. #include "ir/cell.h"
  30. #include "utils/symbolic.h"
  31. #include "utils/ms_context.h"
  32. namespace mindspore {
  33. namespace parse {
  34. using Tensor = mindspore::tensor::Tensor;
  35. using TensorPtr = mindspore::tensor::TensorPtr;
  36. using MetaTensor = mindspore::tensor::MetaTensor;
  37. using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;
  38. FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
  39. std::vector<std::string> results = data_converter::GetObjKey(obj);
  40. std::string obj_key = results[0];
  41. py::function bprop_func = py::getattr(obj, CUSTOM_BPROP_NAME);
  42. auto bprop_graph = std::make_shared<FuncGraph>();
  43. std::vector<AnfNodePtr> outputs;
  44. auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", py::object());
  45. fake_bprop->set_hook(bprop_func);
  46. (void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true));
  47. outputs.push_back(NewValueNode(fake_bprop));
  48. py::object code_obj = py::getattr(bprop_func, "__code__");
  49. // Three parameters self, out and dout need to be excluded
  50. size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - 3;
  51. for (size_t i = 0; i < inputs_num; ++i) {
  52. auto param = bprop_graph->add_parameter();
  53. outputs.push_back(param);
  54. }
  55. auto p1 = bprop_graph->add_parameter();
  56. auto p2 = bprop_graph->add_parameter();
  57. outputs.push_back(p1);
  58. outputs.push_back(p2);
  59. bprop_graph->set_output(bprop_graph->NewCNode(outputs));
  60. data_converter::SetObjGraphValue(obj_key, bprop_graph);
  61. return bprop_graph;
  62. }
  63. namespace {
  64. bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) {
  65. MS_LOG(DEBUG) << "Converting python tuple";
  66. auto tuple = obj.cast<py::tuple>();
  67. std::vector<ValuePtr> value_list;
  68. for (size_t it = 0; it < tuple.size(); ++it) {
  69. ValuePtr out = nullptr;
  70. bool success = ConvertData(tuple[it], &out, use_signature);
  71. if (!success) {
  72. return false;
  73. }
  74. value_list.push_back(out);
  75. }
  76. *data = std::make_shared<ValueTuple>(value_list);
  77. return true;
  78. }
  79. bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature) {
  80. MS_LOG(DEBUG) << "Converting python list";
  81. auto list = obj.cast<py::list>();
  82. std::vector<ValuePtr> value_list;
  83. for (size_t it = 0; it < list.size(); ++it) {
  84. ValuePtr out = nullptr;
  85. bool success = ConvertData(list[it], &out, use_signature);
  86. if (!success) {
  87. return false;
  88. }
  89. value_list.push_back(out);
  90. }
  91. *data = std::make_shared<ValueList>(value_list);
  92. return true;
  93. }
  94. bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signature) {
  95. MS_LOG(DEBUG) << "Converting cell list";
  96. py::sequence list = obj;
  97. std::vector<ValuePtr> value_list;
  98. for (size_t it = 0; it < list.size(); ++it) {
  99. ValuePtr out = nullptr;
  100. bool success = ConvertData(list[it], &out, use_signature);
  101. if (!success) {
  102. return false;
  103. }
  104. value_list.push_back(out);
  105. }
  106. *data = std::make_shared<ValueTuple>(value_list);
  107. return true;
  108. }
  109. bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) {
  110. MS_LOG(DEBUG) << "Converting python dict";
  111. auto dict_values = obj.cast<py::dict>();
  112. std::vector<std::pair<std::string, ValuePtr>> key_values;
  113. for (auto item : dict_values) {
  114. if (!py::isinstance<py::str>(item.first)) {
  115. MS_LOG(ERROR) << "The key of dict is only support str.";
  116. return false;
  117. }
  118. std::string key = py::str(item.first);
  119. ValuePtr out = nullptr;
  120. bool success = ConvertData(dict_values[item.first], &out, use_signature);
  121. if (!success) {
  122. return false;
  123. }
  124. key_values.emplace_back(key, out);
  125. }
  126. *data = std::make_shared<ValueDictionary>(key_values);
  127. return true;
  128. }
  129. void ConvertNameSpace(const py::object &obj, ValuePtr *const data) {
  130. MS_LOG(DEBUG) << "Converting python module";
  131. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  132. py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj);
  133. *data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, py::cast<py::module>(module_namespace));
  134. }
  135. void ConvertDataClass(py::object obj, ValuePtr *const data) {
  136. MS_LOG(DEBUG) << "Converting dataclass";
  137. // Maybe the obj is dataclass define
  138. auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
  139. // desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
  140. *data = std::make_shared<ClassObject>(obj, std::string(desc.begin() + 1, desc.end() - 1));
  141. }
  142. bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = false) {
  143. MS_LOG(DEBUG) << "Converting primitive object" << use_signature;
  144. // need check the primitive is class type or instance
  145. auto obj_type = data_converter::GetObjType(obj);
  146. if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
  147. auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
  148. // desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
  149. *data = std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
  150. } else {
  151. auto primitive = obj.cast<PrimitivePyPtr>();
  152. if (primitive == nullptr) {
  153. MS_LOG(ERROR) << "Resolve Primitive error, get ptr is null";
  154. return false;
  155. }
  156. if (py::hasattr(obj, "__setattr_flag__")) {
  157. if (py::hasattr(obj, "_clone")) {
  158. auto clone_fn = obj.attr("_clone");
  159. py::object new_obj = clone_fn();
  160. primitive = new_obj.cast<PrimitivePyPtr>();
  161. }
  162. }
  163. if (use_signature) {
  164. *data = std::make_shared<prim::DoSignaturePrimitive>(primitive->name(), primitive);
  165. } else {
  166. *data = primitive;
  167. }
  168. MS_LOG(DEBUG) << "Converting primitive object ok " << (*data)->ToString();
  169. }
  170. return true;
  171. }
  172. bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_signature = false) {
  173. MS_LOG(DEBUG) << "Converting MetaFuncGraph object";
  174. auto meta = obj.cast<MetaFuncGraphPtr>();
  175. if (meta == nullptr) {
  176. MS_LOG(ERROR) << "Resolve MetaFuncGraph error, get ptr is null";
  177. return false;
  178. }
  179. if (use_signature) {
  180. *data = std::make_shared<prim::DoSignaturePrimitive>(meta->name(), meta);
  181. } else {
  182. *data = meta;
  183. }
  184. return true;
  185. }
  186. bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
  187. MS_LOG(DEBUG) << "Converting slice object";
  188. auto slice_obj = obj.cast<py::slice>();
  189. auto convert_func = [obj](std::string attr) -> ValuePtr {
  190. auto py_attr = py::getattr(obj, attr.c_str());
  191. if (py::isinstance<py::none>(py_attr)) {
  192. return kNone;
  193. } else if (py::isinstance<py::int_>(py_attr)) {
  194. int64_t value = py::cast<int64_t>(py_attr);
  195. return MakeValue(value);
  196. } else {
  197. MS_LOG(EXCEPTION) << "Slice should contain only int64_t or none";
  198. }
  199. };
  200. ValuePtr start = convert_func("start");
  201. ValuePtr stop = convert_func("stop");
  202. ValuePtr step = convert_func("step");
  203. *data = std::make_shared<ValueSlice>(start, stop, step);
  204. return true;
  205. }
  206. bool ConvertCellObjToFuncGraph(const CellPtr &cell, ValuePtr *const data) {
  207. auto obj = py::cast(cell);
  208. FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
  209. if (func_graph == nullptr) {
  210. MS_LOG(ERROR) << "Parse resolve function error.";
  211. return false;
  212. }
  213. // if the cell object has specified bprop, it has user-defined bprop function parse and record it
  214. if (py::hasattr(obj, CUSTOM_BPROP_NAME)) {
  215. FuncGraphPtr bprop_graph = nullptr;
  216. bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
  217. if (enable_bprop_debug) {
  218. bprop_graph = ConvertToBpropCut(obj);
  219. } else {
  220. bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
  221. }
  222. if (bprop_graph != nullptr) {
  223. (void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
  224. (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
  225. func_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
  226. }
  227. }
  228. if (py::hasattr(obj, STAGE_NAME)) {
  229. auto stage = py::cast<int>(py::getattr(obj, STAGE_NAME));
  230. func_graph->set_stage(stage);
  231. }
  232. *data = func_graph;
  233. return true;
  234. }
  235. bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
  236. auto obj_type = data_converter::GetObjType(obj);
  237. MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";
  238. if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
  239. MS_LOG(DEBUG) << "Resolve the class type, need create class instance.";
  240. std::string desc = py::str(obj);
  241. // desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
  242. *data = std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
  243. return true;
  244. }
  245. if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD) {
  246. MS_LOG(DEBUG) << "Convert the obj to func graph, type is " << obj_type;
  247. FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
  248. if (func_graph == nullptr) {
  249. MS_LOG(ERROR) << "Parse resolve function error.";
  250. return false;
  251. }
  252. *data = func_graph;
  253. return true;
  254. }
  255. if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) {
  256. // Create the namespace for common class instance
  257. // When the obj is Cell, default parse the 'construct'
  258. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  259. py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
  260. *data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
  261. return true;
  262. }
  263. MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj));
  264. return false;
  265. }
  266. template <typename T>
  267. bool ConvertNumberWithType(const T &obj, ValuePtr *const data, TypePtr dtype) {
  268. auto int_dypte = dyn_cast<Int>(dtype);
  269. if (int_dypte != nullptr) {
  270. switch (int_dypte->nbits()) {
  271. case 8:
  272. *data = std::make_shared<Int8Imm>(obj);
  273. break;
  274. case 16:
  275. *data = std::make_shared<Int16Imm>(obj);
  276. break;
  277. case 32:
  278. *data = std::make_shared<Int32Imm>(obj);
  279. break;
  280. case 64:
  281. *data = std::make_shared<Int64Imm>(obj);
  282. break;
  283. default:
  284. *data = std::make_shared<Int64Imm>(obj);
  285. }
  286. return true;
  287. }
  288. auto uint_dypte = dyn_cast<UInt>(dtype);
  289. if (uint_dypte != nullptr) {
  290. switch (uint_dypte->nbits()) {
  291. case 8:
  292. *data = std::make_shared<UInt8Imm>(obj);
  293. break;
  294. case 16:
  295. *data = std::make_shared<UInt16Imm>(obj);
  296. break;
  297. case 32:
  298. *data = std::make_shared<UInt32Imm>(obj);
  299. break;
  300. case 64:
  301. *data = std::make_shared<UInt64Imm>(obj);
  302. break;
  303. default:
  304. *data = std::make_shared<UInt32Imm>(obj);
  305. }
  306. return true;
  307. }
  308. auto float_dypte = dyn_cast<Float>(dtype);
  309. if (float_dypte != nullptr) {
  310. switch (float_dypte->nbits()) {
  311. case 32:
  312. *data = std::make_shared<FP32Imm>(obj);
  313. break;
  314. case 64:
  315. *data = std::make_shared<FP64Imm>(obj);
  316. break;
  317. default:
  318. *data = std::make_shared<FP32Imm>(obj);
  319. }
  320. return true;
  321. }
  322. return false;
  323. }
  324. bool ConvertIntegerWithType(const int64_t &obj, ValuePtr *const data, TypePtr dtype = nullptr) {
  325. if (dtype == nullptr) {
  326. *data = std::make_shared<Int64Imm>(obj);
  327. return true;
  328. }
  329. return ConvertNumberWithType<int64_t>(obj, data, dtype);
  330. }
  331. bool ConvertFloatWithType(const float &obj, ValuePtr *const data, TypePtr dtype = nullptr) {
  332. if (dtype == nullptr) {
  333. *data = std::make_shared<FP32Imm>(obj);
  334. return true;
  335. }
  336. return ConvertNumberWithType<float>(obj, data, dtype);
  337. }
  338. } // namespace
  339. bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, TypePtr dtype) {
  340. // check parameter valid
  341. if (data == nullptr) {
  342. MS_LOG(ERROR) << "Data is null pointer";
  343. return false;
  344. }
  345. bool ret = true;
  346. ValuePtr converted = nullptr;
  347. if (py::isinstance<py::none>(obj)) {
  348. converted = kNone;
  349. } else if (py::isinstance<py::bool_>(obj)) {
  350. converted = std::make_shared<BoolImm>(py::cast<bool>(obj));
  351. } else if (py::isinstance<py::int_>(obj)) {
  352. ret = ConvertIntegerWithType(py::cast<int64_t>(obj), &converted, dtype);
  353. } else if (py::isinstance<py::float_>(obj)) {
  354. ret = ConvertFloatWithType(py::cast<float>(obj), &converted, dtype);
  355. } else if (py::isinstance<py::str>(obj)) {
  356. converted = std::make_shared<StringImm>(py::cast<std::string>(obj));
  357. } else if (py::isinstance<py::dict>(obj)) {
  358. ret = ConvertDict(obj, &converted, use_signature);
  359. } else if (py::isinstance<py::slice>(obj)) {
  360. ret = ConvertSlice(obj, &converted);
  361. } else if (py::isinstance<py::ellipsis>(obj)) {
  362. converted = kEllipsis;
  363. } else if (py::isinstance<py::tuple>(obj)) {
  364. ret = ConvertTuple(obj, &converted, use_signature);
  365. } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {
  366. ret = ConvertCellList(obj, &converted, use_signature);
  367. } else if (py::isinstance<Cell>(obj)) {
  368. return ConvertCellObjToFuncGraph(obj.cast<CellPtr>(), data);
  369. } else if (py::isinstance<py::list>(obj)) {
  370. ret = ConvertList(obj, &converted, use_signature);
  371. } else if (py::isinstance<py::module>(obj)) {
  372. ConvertNameSpace(obj, &converted);
  373. } else if (py::hasattr(obj, PYTHON_DATACLASS_FIELDS)) {
  374. ConvertDataClass(obj, &converted);
  375. } else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) {
  376. ret = ConvertPrimitive(obj, &converted, use_signature);
  377. } else if (py::isinstance<MetaFuncGraph>(obj)) {
  378. ret = ConvertMetaFuncGraph(obj, &converted, use_signature);
  379. } else if (py::isinstance<Type>(obj)) {
  380. converted = obj.cast<TypePtr>();
  381. } else if (py::isinstance<Tensor>(obj)) {
  382. converted = obj.cast<TensorPtr>();
  383. } else if (py::isinstance<MetaTensor>(obj)) {
  384. converted = obj.cast<MetaTensorPtr>();
  385. } else if (py::isinstance<UMonad>(obj)) {
  386. converted = obj.cast<UMonadPtr>();
  387. } else if (py::isinstance<IOMonad>(obj)) {
  388. converted = obj.cast<IOMonadPtr>();
  389. } else if (py::isinstance<EnvInstance>(obj)) {
  390. auto env = obj.cast<std::shared_ptr<EnvInstance>>();
  391. converted = env;
  392. } else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) {
  393. converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
  394. } else {
  395. ret = ConvertOtherObj(obj, &converted);
  396. }
  397. *data = converted;
  398. return ret;
  399. }
  400. // convert data to graph
  401. FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python_mod_get_parse_method) {
  402. std::vector<std::string> results = data_converter::GetObjKey(obj);
  403. std::string obj_id = results[0] + python_mod_get_parse_method;
  404. std::string obj_key = results[1];
  405. FuncGraphPtr func_graph = nullptr;
  406. ValuePtr value = nullptr;
  407. bool is_cache = data_converter::GetObjectValue(obj_id, &value);
  408. if (is_cache) {
  409. if (value && value->isa<FuncGraph>()) {
  410. MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id;
  411. func_graph = value->cast<FuncGraphPtr>();
  412. return func_graph;
  413. }
  414. }
  415. func_graph = ParsePythonCode(obj, python_mod_get_parse_method);
  416. if (func_graph == nullptr) {
  417. MS_LOG(ERROR) << "Parse resolve function error.";
  418. return nullptr;
  419. }
  420. data_converter::MakeProperNameToFuncGraph(func_graph, obj_id);
  421. data_converter::CacheObjectValue(obj_id, func_graph);
  422. if (!obj_key.empty()) {
  423. MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString();
  424. data_converter::SetObjGraphValue(obj_key, func_graph);
  425. }
  426. return func_graph;
  427. }
  428. namespace data_converter {
  429. static std::unordered_map<std::string, ValuePtr> object_map_;
  430. static std::unordered_map<std::string, std::vector<FuncGraphPtr>> object_graphs_map_;
  431. void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) {
  432. object_graphs_map_[obj_key].push_back(data);
  433. MS_LOG(DEBUG) << "Set func graph size:" << object_graphs_map_.size();
  434. }
  435. const std::unordered_map<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs() {
  436. MS_LOG(DEBUG) << "Obj size:" << object_graphs_map_.size();
  437. return object_graphs_map_;
  438. }
  439. void CacheObjectValue(const std::string &obj_key, const ValuePtr &data) { object_map_[obj_key] = data; }
  440. bool GetObjectValue(const std::string &obj_key, ValuePtr *const data) {
  441. if (object_map_.count(obj_key)) {
  442. *data = object_map_[obj_key];
  443. return true;
  444. }
  445. return false;
  446. }
  447. std::vector<std::string> GetObjKey(const py::object &obj) {
  448. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  449. py::tuple obj_tuple = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_KEY, obj);
  450. if (obj_tuple.size() != 2) {
  451. MS_LOG(EXCEPTION) << "Get_obj_key must return 2 elements";
  452. }
  453. return {py::cast<std::string>(obj_tuple[0]), py::cast<std::string>(obj_tuple[1])};
  454. }
  455. // get obj detail type
  456. ResolveTypeDef GetObjType(const py::object &obj) {
  457. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  458. auto obj_type =
  459. ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast<int32_t>());
  460. return obj_type;
  461. }
  462. // get class instance detail type
  463. ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) {
  464. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  465. auto class_type =
  466. ClassInstanceTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_CLASS_INSTANCE_TYPE, obj).cast<int32_t>());
  467. return class_type;
  468. }
  469. // check the object is Cell Instance
  470. bool IsCellInstance(const py::object &obj) {
  471. auto class_type = GetClassInstanceType(obj);
  472. bool isCell = (class_type == CLASS_INSTANCE_TYPE_CELL);
  473. return isCell;
  474. }
  475. // create the python class instance
  476. py::object CreatePythonObject(const py::object &type, const py::tuple &params) {
  477. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  478. py::object obj;
  479. if (params.empty()) {
  480. obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type);
  481. } else {
  482. obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params);
  483. }
  484. return obj;
  485. }
  486. // Generate an appropriate name and set to graph debuginfo
  487. // character <> can not used in the dot file, so change to another symbol
  488. void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name) {
  489. MS_EXCEPTION_IF_NULL(func_graph);
  490. MS_EXCEPTION_IF_NULL(func_graph->debug_info());
  491. // set detail name info of function
  492. std::ostringstream oss;
  493. for (size_t i = 0; i < name.size(); i++) {
  494. if (name[i] == '<') {
  495. oss << "「";
  496. } else if (name[i] == '>') {
  497. oss << "」";
  498. } else {
  499. oss << name[i];
  500. }
  501. }
  502. func_graph->debug_info()->set_full_name(oss.str());
  503. }
  504. ValuePtr PyDataToValue(const py::object &obj) {
  505. py::object to_convert = obj;
  506. ValuePtr value = nullptr;
  507. (void)ConvertData(to_convert, &value);
  508. return value;
  509. }
  510. void ClearObjectCache() {
  511. object_map_.clear();
  512. object_graphs_map_.clear();
  513. }
  514. } // namespace data_converter
  515. static std::unordered_map<std::string, ClassPtr> g_dataClassToClass = {};
  516. // parse dataclass to mindspore Class type
  517. ClassPtr ParseDataClass(const py::object &cls_obj) {
  518. std::string cls_name = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__name__"));
  519. std::string cls_module = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__module__"));
  520. std::string cls = cls_module + "." + cls_name;
  521. auto iterator = g_dataClassToClass.find(cls);
  522. if (iterator != g_dataClassToClass.end()) {
  523. return iterator->second;
  524. }
  525. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  526. ClassAttrVector attributes;
  527. py::dict names = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_ATTRS, cls_obj);
  528. for (auto &item : names) {
  529. auto type_value = item.second.cast<TypePtr>();
  530. MS_EXCEPTION_IF_NULL(type_value);
  531. MS_LOG(DEBUG) << "(Name: " << py::cast<std::string>(item.first) << ", type: " << type_value->ToString() << ")";
  532. attributes.push_back(std::make_pair(py::cast<std::string>(item.first), type_value));
  533. }
  534. std::unordered_map<std::string, ValuePtr> methods_map;
  535. py::dict methods = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_METHODS, cls_obj);
  536. for (auto &item : methods) {
  537. auto fun_name = item.first.cast<std::string>();
  538. auto obj = py::cast<py::object>(item.second);
  539. std::shared_ptr<PyObjectWrapper> method_obj = std::make_shared<PyObjectWrapper>(obj, fun_name);
  540. methods_map[fun_name] = method_obj;
  541. }
  542. std::shared_ptr<Class> me_class = std::make_shared<Class>(Named(cls_name), attributes, methods_map);
  543. // static Variable for cache
  544. // cppcheck-suppress unreadVariable
  545. g_dataClassToClass[cls] = me_class;
  546. return me_class;
  547. }
  548. void CleanDataClassToClassMap() { g_dataClassToClass.clear(); }
  549. } // namespace parse
  550. } // namespace mindspore