You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

data_converter.cc 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/jit/parse/data_converter.h"
  19. #include <unordered_map>
  20. #include <utility>
  21. #include <string>
  22. #include <memory>
  23. #include <vector>
  24. #include "pipeline/jit/parse/resolve.h"
  25. #include "pipeline/jit/parse/python_adapter.h"
  26. #include "frontend/operator/ops.h"
  27. #include "frontend/operator/composite/composite.h"
  28. #include "ir/func_graph_cloner.h"
  29. #include "ir/cell.h"
  30. #include "utils/symbolic.h"
  31. #include "utils/ms_context.h"
  32. namespace mindspore {
  33. namespace parse {
  34. using Tensor = mindspore::tensor::Tensor;
  35. using TensorPtr = mindspore::tensor::TensorPtr;
  36. using MetaTensor = mindspore::tensor::MetaTensor;
  37. using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;
  38. FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
  39. std::vector<std::string> results = data_converter::GetObjKey(obj);
  40. std::string obj_key = results[0];
  41. py::function bprop_func = py::getattr(obj, CUSTOM_BPROP_NAME);
  42. auto bprop_graph = std::make_shared<FuncGraph>();
  43. std::vector<AnfNodePtr> outputs;
  44. auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", py::object());
  45. fake_bprop->set_hook(bprop_func);
  46. (void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true));
  47. outputs.push_back(NewValueNode(fake_bprop));
  48. py::object code_obj = py::getattr(bprop_func, "__code__");
  49. size_t inputs_num = py::cast<int>(py::getattr(code_obj, "co_argcount")) - 3;
  50. for (size_t i = 0; i < inputs_num; ++i) {
  51. auto param = bprop_graph->add_parameter();
  52. outputs.push_back(param);
  53. }
  54. auto p1 = bprop_graph->add_parameter();
  55. auto p2 = bprop_graph->add_parameter();
  56. outputs.push_back(p1);
  57. outputs.push_back(p2);
  58. bprop_graph->set_output(bprop_graph->NewCNode(outputs));
  59. data_converter::SetObjGraphValue(obj_key, bprop_graph);
  60. return bprop_graph;
  61. }
  62. namespace {
  63. bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) {
  64. MS_LOG(DEBUG) << "Converting python tuple";
  65. auto tuple = obj.cast<py::tuple>();
  66. std::vector<ValuePtr> value_list;
  67. for (size_t it = 0; it < tuple.size(); ++it) {
  68. ValuePtr out = nullptr;
  69. bool success = ConvertData(tuple[it], &out, use_signature);
  70. if (!success) {
  71. return false;
  72. }
  73. value_list.push_back(out);
  74. }
  75. *data = std::make_shared<ValueTuple>(value_list);
  76. return true;
  77. }
  78. bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature) {
  79. MS_LOG(DEBUG) << "Converting python list";
  80. auto list = obj.cast<py::list>();
  81. std::vector<ValuePtr> value_list;
  82. for (size_t it = 0; it < list.size(); ++it) {
  83. ValuePtr out = nullptr;
  84. bool success = ConvertData(list[it], &out, use_signature);
  85. if (!success) {
  86. return false;
  87. }
  88. value_list.push_back(out);
  89. }
  90. *data = std::make_shared<ValueList>(value_list);
  91. return true;
  92. }
  93. bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signature) {
  94. MS_LOG(DEBUG) << "Converting cell list";
  95. py::sequence list = obj;
  96. std::vector<ValuePtr> value_list;
  97. for (size_t it = 0; it < list.size(); ++it) {
  98. ValuePtr out = nullptr;
  99. bool success = ConvertData(list[it], &out, use_signature);
  100. if (!success) {
  101. return false;
  102. }
  103. value_list.push_back(out);
  104. }
  105. *data = std::make_shared<ValueTuple>(value_list);
  106. return true;
  107. }
  108. bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) {
  109. MS_LOG(DEBUG) << "Converting python dict";
  110. auto dict_values = obj.cast<py::dict>();
  111. std::vector<std::pair<std::string, ValuePtr>> key_values;
  112. for (auto item : dict_values) {
  113. if (!py::isinstance<py::str>(item.first)) {
  114. MS_LOG(ERROR) << "The key of dict is only support str.";
  115. return false;
  116. }
  117. std::string key = py::str(item.first);
  118. ValuePtr out = nullptr;
  119. bool success = ConvertData(dict_values[item.first], &out, use_signature);
  120. if (!success) {
  121. return false;
  122. }
  123. key_values.emplace_back(key, out);
  124. }
  125. *data = std::make_shared<ValueDictionary>(key_values);
  126. return true;
  127. }
  128. void ConvertNameSpace(const py::object &obj, ValuePtr *const data) {
  129. MS_LOG(DEBUG) << "Converting python module";
  130. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  131. py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj);
  132. *data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, py::cast<py::module>(module_namespace));
  133. }
  134. void ConvertDataClass(py::object obj, ValuePtr *const data) {
  135. MS_LOG(DEBUG) << "Converting dataclass";
  136. // Maybe the obj is dataclass define
  137. auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
  138. // desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
  139. *data = std::make_shared<ClassObject>(obj, std::string(desc.begin() + 1, desc.end() - 1));
  140. }
  141. bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = false) {
  142. MS_LOG(DEBUG) << "Converting primitive object" << use_signature;
  143. // need check the primitive is class type or instance
  144. auto obj_type = data_converter::GetObjType(obj);
  145. if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
  146. auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
  147. // desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
  148. *data = std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
  149. } else {
  150. auto primitive = obj.cast<PrimitivePyPtr>();
  151. if (primitive == nullptr) {
  152. MS_LOG(ERROR) << "Resolve Primitive error, get ptr is null";
  153. return false;
  154. }
  155. if (py::hasattr(obj, "__setattr_flag__")) {
  156. if (py::hasattr(obj, "_clone")) {
  157. auto clone_fn = obj.attr("_clone");
  158. py::object new_obj = clone_fn();
  159. primitive = new_obj.cast<PrimitivePyPtr>();
  160. }
  161. }
  162. if (use_signature) {
  163. *data = std::make_shared<prim::DoSignaturePrimitive>(primitive->name(), primitive);
  164. } else {
  165. *data = primitive;
  166. }
  167. MS_LOG(DEBUG) << "Converting primitive object ok " << (*data)->ToString();
  168. }
  169. return true;
  170. }
  171. bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_signature = false) {
  172. MS_LOG(DEBUG) << "Converting MetaFuncGraph object";
  173. auto meta = obj.cast<MetaFuncGraphPtr>();
  174. if (meta == nullptr) {
  175. MS_LOG(ERROR) << "Resolve MetaFuncGraph error, get ptr is null";
  176. return false;
  177. }
  178. if (use_signature) {
  179. *data = std::make_shared<prim::DoSignaturePrimitive>(meta->name(), meta);
  180. } else {
  181. *data = meta;
  182. }
  183. return true;
  184. }
  185. bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
  186. MS_LOG(DEBUG) << "Converting slice object";
  187. auto slice_obj = obj.cast<py::slice>();
  188. auto convert_func = [obj](std::string attr) -> ValuePtr {
  189. auto py_attr = py::getattr(obj, attr.c_str());
  190. if (py::isinstance<py::none>(py_attr)) {
  191. return kNone;
  192. } else if (py::isinstance<py::int_>(py_attr)) {
  193. int value = py::cast<int>(py_attr);
  194. return MakeValue(value);
  195. } else {
  196. MS_LOG(EXCEPTION) << "Slice should contain only int or none";
  197. }
  198. };
  199. ValuePtr start = convert_func("start");
  200. ValuePtr stop = convert_func("stop");
  201. ValuePtr step = convert_func("step");
  202. *data = std::make_shared<ValueSlice>(start, stop, step);
  203. return true;
  204. }
  205. bool ConvertCellObjToFuncGraph(const CellPtr &cell, ValuePtr *const data) {
  206. auto obj = py::cast(cell);
  207. FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
  208. if (func_graph == nullptr) {
  209. MS_LOG(ERROR) << "Parse resolve function error.";
  210. return false;
  211. }
  212. // if the cell object has specified bprop, it has user-defined bprop function parse and record it
  213. if (py::hasattr(obj, CUSTOM_BPROP_NAME)) {
  214. FuncGraphPtr bprop_graph = nullptr;
  215. bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
  216. if (enable_bprop_debug) {
  217. bprop_graph = ConvertToBpropCut(obj);
  218. } else {
  219. bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
  220. }
  221. if (bprop_graph != nullptr) {
  222. (void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
  223. (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
  224. func_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
  225. }
  226. }
  227. *data = func_graph;
  228. return true;
  229. }
  230. bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
  231. auto obj_type = data_converter::GetObjType(obj);
  232. MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";
  233. if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
  234. MS_LOG(DEBUG) << "Resolve the class type, need create class instance.";
  235. std::string desc = py::str(obj);
  236. // desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
  237. *data = std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
  238. return true;
  239. }
  240. if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD) {
  241. MS_LOG(DEBUG) << "Convert the obj to func graph, type is " << obj_type;
  242. FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
  243. if (func_graph == nullptr) {
  244. MS_LOG(ERROR) << "Parse resolve function error.";
  245. return false;
  246. }
  247. *data = func_graph;
  248. return true;
  249. }
  250. if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) {
  251. // Create the namespace for common class instance
  252. // When the obj is Cell, default parse the 'construct'
  253. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  254. py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
  255. *data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
  256. return true;
  257. }
  258. MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj));
  259. return false;
  260. }
  261. template <typename T>
  262. bool ConvertNumberWithType(const T &obj, ValuePtr *const data, TypePtr dtype) {
  263. auto int_dypte = dyn_cast<Int>(dtype);
  264. if (int_dypte != nullptr) {
  265. switch (int_dypte->nbits()) {
  266. case 8:
  267. *data = std::make_shared<Int8Imm>(obj);
  268. break;
  269. case 16:
  270. *data = std::make_shared<Int16Imm>(obj);
  271. break;
  272. case 32:
  273. *data = std::make_shared<Int32Imm>(obj);
  274. break;
  275. case 64:
  276. *data = std::make_shared<Int64Imm>(obj);
  277. break;
  278. default:
  279. *data = std::make_shared<Int32Imm>(obj);
  280. }
  281. return true;
  282. }
  283. auto uint_dypte = dyn_cast<UInt>(dtype);
  284. if (uint_dypte != nullptr) {
  285. switch (uint_dypte->nbits()) {
  286. case 8:
  287. *data = std::make_shared<UInt8Imm>(obj);
  288. break;
  289. case 16:
  290. *data = std::make_shared<UInt16Imm>(obj);
  291. break;
  292. case 32:
  293. *data = std::make_shared<UInt32Imm>(obj);
  294. break;
  295. case 64:
  296. *data = std::make_shared<UInt64Imm>(obj);
  297. break;
  298. default:
  299. *data = std::make_shared<UInt32Imm>(obj);
  300. }
  301. return true;
  302. }
  303. auto float_dypte = dyn_cast<Float>(dtype);
  304. if (float_dypte != nullptr) {
  305. switch (float_dypte->nbits()) {
  306. case 32:
  307. *data = std::make_shared<FP32Imm>(obj);
  308. break;
  309. case 64:
  310. *data = std::make_shared<FP64Imm>(obj);
  311. break;
  312. default:
  313. *data = std::make_shared<FP32Imm>(obj);
  314. }
  315. return true;
  316. }
  317. return false;
  318. }
  319. bool ConvertIntegerWithType(const int &obj, ValuePtr *const data, TypePtr dtype = nullptr) {
  320. if (dtype == nullptr) {
  321. *data = std::make_shared<Int32Imm>(obj);
  322. return true;
  323. }
  324. return ConvertNumberWithType<int>(obj, data, dtype);
  325. }
  326. bool ConvertFloatWithType(const float &obj, ValuePtr *const data, TypePtr dtype = nullptr) {
  327. if (dtype == nullptr) {
  328. *data = std::make_shared<FP32Imm>(obj);
  329. return true;
  330. }
  331. return ConvertNumberWithType<float>(obj, data, dtype);
  332. }
  333. } // namespace
  334. bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, TypePtr dtype) {
  335. // check parameter valid
  336. if (data == nullptr) {
  337. MS_LOG(ERROR) << "Data is null pointer";
  338. return false;
  339. }
  340. bool ret = true;
  341. ValuePtr converted = nullptr;
  342. if (py::isinstance<py::none>(obj)) {
  343. converted = kNone;
  344. } else if (py::isinstance<py::bool_>(obj)) {
  345. converted = std::make_shared<BoolImm>(py::cast<bool>(obj));
  346. } else if (py::isinstance<py::int_>(obj)) {
  347. ret = ConvertIntegerWithType(py::cast<int>(obj), &converted, dtype);
  348. } else if (py::isinstance<py::float_>(obj)) {
  349. ret = ConvertFloatWithType(py::cast<float>(obj), &converted, dtype);
  350. } else if (py::isinstance<py::str>(obj)) {
  351. converted = std::make_shared<StringImm>(py::cast<std::string>(obj));
  352. } else if (py::isinstance<py::dict>(obj)) {
  353. ret = ConvertDict(obj, &converted, use_signature);
  354. } else if (py::isinstance<py::slice>(obj)) {
  355. ret = ConvertSlice(obj, &converted);
  356. } else if (py::isinstance<py::ellipsis>(obj)) {
  357. converted = kEllipsis;
  358. } else if (py::isinstance<py::tuple>(obj)) {
  359. ret = ConvertTuple(obj, &converted, use_signature);
  360. } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {
  361. ret = ConvertCellList(obj, &converted, use_signature);
  362. } else if (py::isinstance<Cell>(obj)) {
  363. return ConvertCellObjToFuncGraph(obj.cast<CellPtr>(), data);
  364. } else if (py::isinstance<py::list>(obj)) {
  365. ret = ConvertList(obj, &converted, use_signature);
  366. } else if (py::isinstance<py::module>(obj)) {
  367. ConvertNameSpace(obj, &converted);
  368. } else if (py::hasattr(obj, PYTHON_DATACLASS_FIELDS)) {
  369. ConvertDataClass(obj, &converted);
  370. } else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) {
  371. ret = ConvertPrimitive(obj, &converted, use_signature);
  372. } else if (py::isinstance<MetaFuncGraph>(obj)) {
  373. ret = ConvertMetaFuncGraph(obj, &converted, use_signature);
  374. } else if (py::isinstance<Type>(obj)) {
  375. converted = obj.cast<TypePtr>();
  376. } else if (py::isinstance<Tensor>(obj)) {
  377. converted = obj.cast<TensorPtr>();
  378. } else if (py::isinstance<MetaTensor>(obj)) {
  379. converted = obj.cast<MetaTensorPtr>();
  380. } else if (py::isinstance<EnvInstance>(obj)) {
  381. auto env = obj.cast<std::shared_ptr<EnvInstance>>();
  382. converted = env;
  383. } else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) {
  384. converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
  385. } else {
  386. ret = ConvertOtherObj(obj, &converted);
  387. }
  388. *data = converted;
  389. return ret;
  390. }
  391. // convert data to graph
  392. FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python_mod_get_parse_method) {
  393. std::vector<std::string> results = data_converter::GetObjKey(obj);
  394. std::string obj_id = results[0] + python_mod_get_parse_method;
  395. std::string obj_key = results[1];
  396. FuncGraphPtr func_graph = nullptr;
  397. ValuePtr value = nullptr;
  398. bool is_cache = data_converter::GetObjectValue(obj_id, &value);
  399. if (is_cache) {
  400. if (value && value->isa<FuncGraph>()) {
  401. MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id;
  402. func_graph = value->cast<FuncGraphPtr>();
  403. return func_graph;
  404. }
  405. }
  406. func_graph = ParsePythonCode(obj, python_mod_get_parse_method);
  407. if (func_graph == nullptr) {
  408. MS_LOG(ERROR) << "Parse resolve function error.";
  409. return nullptr;
  410. }
  411. data_converter::MakeProperNameToFuncGraph(func_graph, obj_id);
  412. data_converter::CacheObjectValue(obj_id, func_graph);
  413. if (!obj_key.empty()) {
  414. MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString();
  415. data_converter::SetObjGraphValue(obj_key, func_graph);
  416. }
  417. return func_graph;
  418. }
  419. namespace data_converter {
  420. static std::unordered_map<std::string, ValuePtr> object_map_;
  421. static std::unordered_map<std::string, std::vector<FuncGraphPtr>> object_graphs_map_;
  422. void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) {
  423. object_graphs_map_[obj_key].push_back(data);
  424. MS_LOG(DEBUG) << "Set func graph size:" << object_graphs_map_.size();
  425. }
  426. const std::unordered_map<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs() {
  427. MS_LOG(DEBUG) << "Obj size:" << object_graphs_map_.size();
  428. return object_graphs_map_;
  429. }
  430. void CacheObjectValue(const std::string &obj_key, const ValuePtr &data) { object_map_[obj_key] = data; }
  431. bool GetObjectValue(const std::string &obj_key, ValuePtr *const data) {
  432. if (object_map_.count(obj_key)) {
  433. *data = object_map_[obj_key];
  434. return true;
  435. }
  436. return false;
  437. }
  438. std::vector<std::string> GetObjKey(const py::object &obj) {
  439. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  440. py::tuple obj_tuple = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_KEY, obj);
  441. if (obj_tuple.size() != 2) {
  442. MS_LOG(EXCEPTION) << "Get_obj_key must return 2 elements";
  443. }
  444. return {py::cast<std::string>(obj_tuple[0]), py::cast<std::string>(obj_tuple[1])};
  445. }
  446. // get obj detail type
  447. ResolveTypeDef GetObjType(const py::object &obj) {
  448. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  449. auto obj_type =
  450. ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast<int32_t>());
  451. return obj_type;
  452. }
  453. // get class instance detail type
  454. ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) {
  455. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  456. auto class_type =
  457. ClassInstanceTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_CLASS_INSTANCE_TYPE, obj).cast<int32_t>());
  458. return class_type;
  459. }
  460. // check the object is Cell Instance
  461. bool IsCellInstance(const py::object &obj) {
  462. auto class_type = GetClassInstanceType(obj);
  463. bool isCell = (class_type == CLASS_INSTANCE_TYPE_CELL);
  464. return isCell;
  465. }
  466. // create the python class instance
  467. py::object CreatePythonObject(const py::object &type, const py::tuple &params) {
  468. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  469. py::object obj;
  470. if (params.empty()) {
  471. obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type);
  472. } else {
  473. obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params);
  474. }
  475. return obj;
  476. }
  477. // Generate an appropriate name and set to graph debuginfo
  478. // character <> can not used in the dot file, so change to another symbol
  479. void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name) {
  480. MS_EXCEPTION_IF_NULL(func_graph);
  481. MS_EXCEPTION_IF_NULL(func_graph->debug_info());
  482. // set detail name info of function
  483. std::ostringstream oss;
  484. for (size_t i = 0; i < name.size(); i++) {
  485. if (name[i] == '<') {
  486. oss << "「";
  487. } else if (name[i] == '>') {
  488. oss << "」";
  489. } else {
  490. oss << name[i];
  491. }
  492. }
  493. func_graph->debug_info()->set_full_name(oss.str());
  494. }
  495. ValuePtr PyDataToValue(const py::object &obj) {
  496. py::object to_convert = obj;
  497. ValuePtr value = nullptr;
  498. (void)ConvertData(to_convert, &value);
  499. return value;
  500. }
  501. void ClearObjectCache() {
  502. object_map_.clear();
  503. object_graphs_map_.clear();
  504. }
  505. } // namespace data_converter
  506. static std::unordered_map<std::string, ClassPtr> g_dataClassToClass = {};
  507. // parse dataclass to mindspore Class type
  508. ClassPtr ParseDataClass(const py::object &cls_obj) {
  509. std::string cls_name = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__name__"));
  510. std::string cls_module = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__module__"));
  511. std::string cls = cls_module + "." + cls_name;
  512. auto iterator = g_dataClassToClass.find(cls);
  513. if (iterator != g_dataClassToClass.end()) {
  514. return iterator->second;
  515. }
  516. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  517. ClassAttrVector attributes;
  518. py::dict names = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_ATTRS, cls_obj);
  519. for (auto &item : names) {
  520. auto type_value = item.second.cast<TypePtr>();
  521. MS_EXCEPTION_IF_NULL(type_value);
  522. MS_LOG(DEBUG) << "(Name: " << py::cast<std::string>(item.first) << ", type: " << type_value->ToString() << ")";
  523. attributes.push_back(std::make_pair(py::cast<std::string>(item.first), type_value));
  524. }
  525. std::unordered_map<std::string, ValuePtr> methods_map;
  526. py::dict methods = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_METHODS, cls_obj);
  527. for (auto &item : methods) {
  528. auto fun_name = item.first.cast<std::string>();
  529. auto obj = py::cast<py::object>(item.second);
  530. std::shared_ptr<PyObjectWrapper> method_obj = std::make_shared<PyObjectWrapper>(obj, fun_name);
  531. methods_map[fun_name] = method_obj;
  532. }
  533. std::shared_ptr<Class> me_class = std::make_shared<Class>(Named(cls_name), attributes, methods_map);
  534. // static Variable for cache
  535. // cppcheck-suppress unreadVariable
  536. g_dataClassToClass[cls] = me_class;
  537. return me_class;
  538. }
  539. void CleanDataClassToClassMap() { g_dataClassToClass.clear(); }
  540. } // namespace parse
  541. } // namespace mindspore