You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

data_converter.cc 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/jit/parse/data_converter.h"
  19. #include <unordered_map>
  20. #include <utility>
  21. #include <string>
  22. #include <memory>
  23. #include <vector>
  24. #include "pipeline/jit/parse/resolve.h"
  25. #include "pipeline/jit/parse/python_adapter.h"
  26. #include "frontend/operator/ops.h"
  27. #include "frontend/operator/composite/composite.h"
  28. #include "ir/func_graph_cloner.h"
  29. #include "ir/cell.h"
  30. #include "utils/symbolic.h"
  31. #include "utils/ms_context.h"
  32. namespace mindspore {
  33. namespace parse {
  34. using Tensor = mindspore::tensor::Tensor;
  35. using TensorPtr = mindspore::tensor::TensorPtr;
  36. using MetaTensor = mindspore::tensor::MetaTensor;
  37. using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;
  38. FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
  39. std::vector<std::string> results = data_converter::GetObjKey(obj);
  40. std::string obj_key = results[0];
  41. py::function bprop_func = py::getattr(obj, CUSTOM_BPROP_NAME);
  42. auto bprop_graph = std::make_shared<FuncGraph>();
  43. std::vector<AnfNodePtr> outputs;
  44. auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", py::object());
  45. fake_bprop->set_hook(bprop_func);
  46. (void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true));
  47. outputs.push_back(NewValueNode(fake_bprop));
  48. py::object code_obj = py::getattr(bprop_func, "__code__");
  49. // Three parameters self, out and dout need to be excluded
  50. size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - 3;
  51. for (size_t i = 0; i < inputs_num; ++i) {
  52. auto param = bprop_graph->add_parameter();
  53. outputs.push_back(param);
  54. }
  55. auto p1 = bprop_graph->add_parameter();
  56. auto p2 = bprop_graph->add_parameter();
  57. outputs.push_back(p1);
  58. outputs.push_back(p2);
  59. bprop_graph->set_output(bprop_graph->NewCNode(outputs));
  60. data_converter::SetObjGraphValue(obj_key, bprop_graph);
  61. return bprop_graph;
  62. }
  63. namespace {
  64. bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) {
  65. MS_LOG(DEBUG) << "Converting python tuple";
  66. auto tuple = obj.cast<py::tuple>();
  67. std::vector<ValuePtr> value_list;
  68. for (size_t it = 0; it < tuple.size(); ++it) {
  69. ValuePtr out = nullptr;
  70. bool success = ConvertData(tuple[it], &out, use_signature);
  71. if (!success) {
  72. return false;
  73. }
  74. value_list.push_back(out);
  75. }
  76. *data = std::make_shared<ValueTuple>(value_list);
  77. return true;
  78. }
  79. bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature) {
  80. MS_LOG(DEBUG) << "Converting python list";
  81. auto list = obj.cast<py::list>();
  82. std::vector<ValuePtr> value_list;
  83. for (size_t it = 0; it < list.size(); ++it) {
  84. ValuePtr out = nullptr;
  85. bool success = ConvertData(list[it], &out, use_signature);
  86. if (!success) {
  87. return false;
  88. }
  89. value_list.push_back(out);
  90. }
  91. *data = std::make_shared<ValueList>(value_list);
  92. return true;
  93. }
  94. bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signature) {
  95. MS_LOG(DEBUG) << "Converting cell list";
  96. py::sequence list = obj;
  97. std::vector<ValuePtr> value_list;
  98. for (size_t it = 0; it < list.size(); ++it) {
  99. ValuePtr out = nullptr;
  100. bool success = ConvertData(list[it], &out, use_signature);
  101. if (!success) {
  102. return false;
  103. }
  104. value_list.push_back(out);
  105. }
  106. *data = std::make_shared<ValueTuple>(value_list);
  107. return true;
  108. }
  109. bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) {
  110. MS_LOG(DEBUG) << "Converting python dict";
  111. auto dict_values = obj.cast<py::dict>();
  112. std::vector<std::pair<std::string, ValuePtr>> key_values;
  113. for (auto item : dict_values) {
  114. if (!py::isinstance<py::str>(item.first)) {
  115. MS_LOG(ERROR) << "The key of dict is only support str.";
  116. return false;
  117. }
  118. std::string key = py::str(item.first);
  119. ValuePtr out = nullptr;
  120. bool success = ConvertData(dict_values[item.first], &out, use_signature);
  121. if (!success) {
  122. return false;
  123. }
  124. key_values.emplace_back(key, out);
  125. }
  126. *data = std::make_shared<ValueDictionary>(key_values);
  127. return true;
  128. }
  129. void ConvertNameSpace(const py::object &obj, ValuePtr *const data) {
  130. MS_LOG(DEBUG) << "Converting python module";
  131. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  132. py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj);
  133. *data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, py::cast<py::module>(module_namespace));
  134. }
  135. void ConvertDataClass(py::object obj, ValuePtr *const data) {
  136. MS_LOG(DEBUG) << "Converting dataclass";
  137. // Maybe the obj is dataclass define
  138. auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
  139. // desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
  140. *data = std::make_shared<ClassObject>(obj, std::string(desc.begin() + 1, desc.end() - 1));
  141. }
  142. bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = false) {
  143. MS_LOG(DEBUG) << "Converting primitive object" << use_signature;
  144. // need check the primitive is class type or instance
  145. auto obj_type = data_converter::GetObjType(obj);
  146. if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
  147. auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
  148. // desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
  149. *data = std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
  150. } else {
  151. auto primitive = obj.cast<PrimitivePyPtr>();
  152. if (primitive == nullptr) {
  153. MS_LOG(ERROR) << "Resolve Primitive error, get ptr is null";
  154. return false;
  155. }
  156. if (py::hasattr(obj, "__setattr_flag__")) {
  157. if (py::hasattr(obj, "_clone")) {
  158. auto clone_fn = obj.attr("_clone");
  159. py::object new_obj = clone_fn();
  160. primitive = new_obj.cast<PrimitivePyPtr>();
  161. }
  162. }
  163. if (use_signature) {
  164. *data = std::make_shared<prim::DoSignaturePrimitive>(primitive->name(), primitive);
  165. } else {
  166. *data = primitive;
  167. }
  168. MS_LOG(DEBUG) << "Converting primitive object ok " << (*data)->ToString();
  169. }
  170. return true;
  171. }
  172. bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_signature = false) {
  173. MS_LOG(DEBUG) << "Converting MetaFuncGraph object";
  174. auto meta = obj.cast<MetaFuncGraphPtr>();
  175. if (meta == nullptr) {
  176. MS_LOG(ERROR) << "Resolve MetaFuncGraph error, get ptr is null";
  177. return false;
  178. }
  179. if (use_signature) {
  180. *data = std::make_shared<prim::DoSignaturePrimitive>(meta->name(), meta);
  181. } else {
  182. *data = meta;
  183. }
  184. return true;
  185. }
  186. bool ConvertFuncGraph(const py::object &obj, ValuePtr *const data) {
  187. MS_LOG(DEBUG) << "Converting FuncGraph object";
  188. auto func_graph = obj.cast<FuncGraphPtr>();
  189. if (func_graph == nullptr) {
  190. MS_LOG(ERROR) << "Resolve FuncGraph error, get ptr is null";
  191. return false;
  192. }
  193. auto new_fg = BasicClone(func_graph);
  194. new_fg->set_attr("is_load", MakeValue(true));
  195. *data = new_fg;
  196. return true;
  197. }
  198. bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
  199. MS_LOG(DEBUG) << "Converting slice object";
  200. auto slice_obj = obj.cast<py::slice>();
  201. auto convert_func = [obj](std::string attr) -> ValuePtr {
  202. auto py_attr = py::getattr(obj, attr.c_str());
  203. if (py::isinstance<py::none>(py_attr)) {
  204. return kNone;
  205. } else if (py::isinstance<py::int_>(py_attr)) {
  206. int64_t value = py::cast<int64_t>(py_attr);
  207. return MakeValue(value);
  208. } else {
  209. MS_LOG(EXCEPTION) << "Slice should contain only int64_t or none";
  210. }
  211. };
  212. ValuePtr start = convert_func("start");
  213. ValuePtr stop = convert_func("stop");
  214. ValuePtr step = convert_func("step");
  215. *data = std::make_shared<ValueSlice>(start, stop, step);
  216. return true;
  217. }
  218. bool ConvertCellObjToFuncGraph(const CellPtr &cell, ValuePtr *const data) {
  219. auto obj = py::cast(cell);
  220. FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
  221. if (func_graph == nullptr) {
  222. MS_LOG(ERROR) << "Parse resolve function error.";
  223. return false;
  224. }
  225. // if the cell object has specified bprop, it has user-defined bprop function parse and record it
  226. if (py::hasattr(obj, CUSTOM_BPROP_NAME)) {
  227. FuncGraphPtr bprop_graph = nullptr;
  228. bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
  229. if (enable_bprop_debug) {
  230. bprop_graph = ConvertToBpropCut(obj);
  231. } else {
  232. bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
  233. }
  234. if (bprop_graph != nullptr) {
  235. (void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
  236. (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
  237. func_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
  238. }
  239. }
  240. if (py::hasattr(obj, STAGE_NAME)) {
  241. auto stage = py::cast<int>(py::getattr(obj, STAGE_NAME));
  242. func_graph->set_stage(stage);
  243. }
  244. *data = func_graph;
  245. return true;
  246. }
  247. bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
  248. auto obj_type = data_converter::GetObjType(obj);
  249. MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";
  250. if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
  251. MS_LOG(DEBUG) << "Resolve the class type, need create class instance.";
  252. std::string desc = py::str(obj);
  253. // desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
  254. *data = std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
  255. return true;
  256. }
  257. if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD) {
  258. MS_LOG(DEBUG) << "Convert the obj to func graph, type is " << obj_type;
  259. FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
  260. if (func_graph == nullptr) {
  261. MS_LOG(ERROR) << "Parse resolve function error.";
  262. return false;
  263. }
  264. *data = func_graph;
  265. return true;
  266. }
  267. if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) {
  268. // Create the namespace for common class instance
  269. // When the obj is Cell, default parse the 'construct'
  270. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  271. py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
  272. *data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
  273. return true;
  274. }
  275. MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj));
  276. return false;
  277. }
  278. template <typename T>
  279. bool ConvertNumberWithType(const T &obj, ValuePtr *const data, TypePtr dtype) {
  280. auto int_dypte = dyn_cast<Int>(dtype);
  281. if (int_dypte != nullptr) {
  282. switch (int_dypte->nbits()) {
  283. case 8:
  284. *data = std::make_shared<Int8Imm>(obj);
  285. break;
  286. case 16:
  287. *data = std::make_shared<Int16Imm>(obj);
  288. break;
  289. case 32:
  290. *data = std::make_shared<Int32Imm>(obj);
  291. break;
  292. case 64:
  293. *data = std::make_shared<Int64Imm>(obj);
  294. break;
  295. default:
  296. *data = std::make_shared<Int64Imm>(obj);
  297. }
  298. return true;
  299. }
  300. auto uint_dypte = dyn_cast<UInt>(dtype);
  301. if (uint_dypte != nullptr) {
  302. switch (uint_dypte->nbits()) {
  303. case 8:
  304. *data = std::make_shared<UInt8Imm>(obj);
  305. break;
  306. case 16:
  307. *data = std::make_shared<UInt16Imm>(obj);
  308. break;
  309. case 32:
  310. *data = std::make_shared<UInt32Imm>(obj);
  311. break;
  312. case 64:
  313. *data = std::make_shared<UInt64Imm>(obj);
  314. break;
  315. default:
  316. *data = std::make_shared<UInt32Imm>(obj);
  317. }
  318. return true;
  319. }
  320. auto float_dypte = dyn_cast<Float>(dtype);
  321. if (float_dypte != nullptr) {
  322. switch (float_dypte->nbits()) {
  323. case 32:
  324. *data = std::make_shared<FP32Imm>(obj);
  325. break;
  326. case 64:
  327. *data = std::make_shared<FP64Imm>(obj);
  328. break;
  329. default:
  330. *data = std::make_shared<FP32Imm>(obj);
  331. }
  332. return true;
  333. }
  334. return false;
  335. }
  336. bool ConvertIntegerWithType(const int64_t &obj, ValuePtr *const data, TypePtr dtype = nullptr) {
  337. if (dtype == nullptr) {
  338. *data = std::make_shared<Int64Imm>(obj);
  339. return true;
  340. }
  341. return ConvertNumberWithType<int64_t>(obj, data, dtype);
  342. }
  343. bool ConvertFloatWithType(const float &obj, ValuePtr *const data, TypePtr dtype = nullptr) {
  344. if (dtype == nullptr) {
  345. *data = std::make_shared<FP32Imm>(obj);
  346. return true;
  347. }
  348. return ConvertNumberWithType<float>(obj, data, dtype);
  349. }
  350. } // namespace
  351. bool ConvertSingleData(const py::object &obj, ValuePtr *const data) {
  352. MS_EXCEPTION_IF_NULL(data);
  353. ValuePtr converted = nullptr;
  354. if (py::isinstance<py::none>(obj)) {
  355. converted = kNone;
  356. } else if (py::isinstance<py::bool_>(obj)) {
  357. converted = std::make_shared<BoolImm>(py::cast<bool>(obj));
  358. } else if (py::isinstance<py::str>(obj)) {
  359. converted = std::make_shared<StringImm>(py::cast<std::string>(obj));
  360. } else if (py::isinstance<py::ellipsis>(obj)) {
  361. converted = kEllipsis;
  362. } else if (py::isinstance<py::module>(obj)) {
  363. ConvertNameSpace(obj, &converted);
  364. } else if (py::hasattr(obj, PYTHON_DATACLASS_FIELDS)) {
  365. ConvertDataClass(obj, &converted);
  366. } else if (py::isinstance<Type>(obj)) {
  367. converted = obj.cast<TypePtr>();
  368. } else if (py::isinstance<Tensor>(obj)) {
  369. converted = obj.cast<TensorPtr>();
  370. } else if (py::isinstance<MetaTensor>(obj)) {
  371. converted = obj.cast<MetaTensorPtr>();
  372. } else if (py::isinstance<UMonad>(obj)) {
  373. converted = obj.cast<UMonadPtr>();
  374. } else if (py::isinstance<IOMonad>(obj)) {
  375. converted = obj.cast<IOMonadPtr>();
  376. } else if (py::isinstance<EnvInstance>(obj)) {
  377. auto env = obj.cast<std::shared_ptr<EnvInstance>>();
  378. converted = env;
  379. } else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) {
  380. converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
  381. } else {
  382. return false;
  383. }
  384. *data = converted;
  385. return true;
  386. }
  387. bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, TypePtr dtype) {
  388. // check parameter valid
  389. if (data == nullptr) {
  390. MS_LOG(ERROR) << "Data is null pointer";
  391. return false;
  392. }
  393. ValuePtr converted = nullptr;
  394. bool ret = ConvertSingleData(obj, &converted);
  395. if (ret) {
  396. *data = converted;
  397. return true;
  398. }
  399. if (py::isinstance<py::int_>(obj)) {
  400. ret = ConvertIntegerWithType(py::cast<int64_t>(obj), &converted, dtype);
  401. } else if (py::isinstance<py::float_>(obj)) {
  402. ret = ConvertFloatWithType(py::cast<float>(obj), &converted, dtype);
  403. } else if (py::isinstance<py::dict>(obj)) {
  404. ret = ConvertDict(obj, &converted, use_signature);
  405. } else if (py::isinstance<py::slice>(obj)) {
  406. ret = ConvertSlice(obj, &converted);
  407. } else if (py::isinstance<py::tuple>(obj)) {
  408. ret = ConvertTuple(obj, &converted, use_signature);
  409. } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {
  410. ret = ConvertCellList(obj, &converted, use_signature);
  411. } else if (py::isinstance<Cell>(obj)) {
  412. return ConvertCellObjToFuncGraph(obj.cast<CellPtr>(), data);
  413. } else if (py::isinstance<py::list>(obj)) {
  414. ret = ConvertList(obj, &converted, use_signature);
  415. } else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) {
  416. ret = ConvertPrimitive(obj, &converted, use_signature);
  417. } else if (py::isinstance<MetaFuncGraph>(obj)) {
  418. ret = ConvertMetaFuncGraph(obj, &converted, use_signature);
  419. } else if (py::isinstance<FuncGraph>(obj)) {
  420. ret = ConvertFuncGraph(obj, &converted);
  421. } else {
  422. ret = ConvertOtherObj(obj, &converted);
  423. }
  424. *data = converted;
  425. return ret;
  426. }
  427. // convert data to graph
  428. FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python_mod_get_parse_method) {
  429. std::vector<std::string> results = data_converter::GetObjKey(obj);
  430. std::string obj_id = results[0] + python_mod_get_parse_method;
  431. std::string obj_key = results[1];
  432. FuncGraphPtr func_graph = nullptr;
  433. ValuePtr value = nullptr;
  434. bool is_cache = data_converter::GetObjectValue(obj_id, &value);
  435. if (is_cache) {
  436. if (value && value->isa<FuncGraph>()) {
  437. MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id;
  438. func_graph = value->cast<FuncGraphPtr>();
  439. return func_graph;
  440. }
  441. }
  442. func_graph = ParsePythonCode(obj, python_mod_get_parse_method);
  443. if (func_graph == nullptr) {
  444. MS_LOG(ERROR) << "Parse resolve function error.";
  445. return nullptr;
  446. }
  447. data_converter::MakeProperNameToFuncGraph(func_graph, obj_id);
  448. data_converter::CacheObjectValue(obj_id, func_graph);
  449. if (!obj_key.empty()) {
  450. MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString();
  451. data_converter::SetObjGraphValue(obj_key, func_graph);
  452. }
  453. return func_graph;
  454. }
  455. namespace data_converter {
  456. static std::unordered_map<std::string, ValuePtr> object_map_;
  457. static std::unordered_map<std::string, std::vector<FuncGraphPtr>> object_graphs_map_;
  458. void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) {
  459. object_graphs_map_[obj_key].push_back(data);
  460. MS_LOG(DEBUG) << "Set func graph size:" << object_graphs_map_.size();
  461. }
  462. const std::unordered_map<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs() {
  463. MS_LOG(DEBUG) << "Obj size:" << object_graphs_map_.size();
  464. return object_graphs_map_;
  465. }
  466. void CacheObjectValue(const std::string &obj_key, const ValuePtr &data) { object_map_[obj_key] = data; }
  467. bool GetObjectValue(const std::string &obj_key, ValuePtr *const data) {
  468. if (object_map_.count(obj_key)) {
  469. *data = object_map_[obj_key];
  470. return true;
  471. }
  472. return false;
  473. }
  474. std::vector<std::string> GetObjKey(const py::object &obj) {
  475. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  476. py::tuple obj_tuple = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_KEY, obj);
  477. if (obj_tuple.size() != 2) {
  478. MS_LOG(EXCEPTION) << "Get_obj_key must return 2 elements";
  479. }
  480. return {py::cast<std::string>(obj_tuple[0]), py::cast<std::string>(obj_tuple[1])};
  481. }
  482. // get obj detail type
  483. ResolveTypeDef GetObjType(const py::object &obj) {
  484. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  485. auto obj_type =
  486. ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast<int32_t>());
  487. return obj_type;
  488. }
  489. // get class instance detail type
  490. ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) {
  491. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  492. auto class_type =
  493. ClassInstanceTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_CLASS_INSTANCE_TYPE, obj).cast<int32_t>());
  494. return class_type;
  495. }
  496. // check the object is Cell Instance
  497. bool IsCellInstance(const py::object &obj) {
  498. auto class_type = GetClassInstanceType(obj);
  499. bool isCell = (class_type == CLASS_INSTANCE_TYPE_CELL);
  500. return isCell;
  501. }
  502. // create the python class instance
  503. py::object CreatePythonObject(const py::object &type, const py::tuple &params) {
  504. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  505. py::object obj;
  506. if (params.empty()) {
  507. obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type);
  508. } else {
  509. obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params);
  510. }
  511. return obj;
  512. }
  513. // Generate an appropriate name and set to graph debuginfo
  514. // character <> can not used in the dot file, so change to another symbol
  515. void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name) {
  516. MS_EXCEPTION_IF_NULL(func_graph);
  517. MS_EXCEPTION_IF_NULL(func_graph->debug_info());
  518. // set detail name info of function
  519. std::ostringstream oss;
  520. for (size_t i = 0; i < name.size(); i++) {
  521. if (name[i] == '<') {
  522. oss << "「";
  523. } else if (name[i] == '>') {
  524. oss << "」";
  525. } else {
  526. oss << name[i];
  527. }
  528. }
  529. func_graph->debug_info()->set_full_name(oss.str());
  530. }
  531. ValuePtr PyDataToValue(const py::object &obj) {
  532. py::object to_convert = obj;
  533. ValuePtr value = nullptr;
  534. (void)ConvertData(to_convert, &value);
  535. return value;
  536. }
  537. void ClearObjectCache() {
  538. object_map_.clear();
  539. object_graphs_map_.clear();
  540. }
  541. } // namespace data_converter
  542. static std::unordered_map<std::string, ClassPtr> g_dataClassToClass = {};
  543. // parse dataclass to mindspore Class type
  544. ClassPtr ParseDataClass(const py::object &cls_obj) {
  545. std::string cls_name = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__name__"));
  546. std::string cls_module = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__module__"));
  547. std::string cls = cls_module + "." + cls_name;
  548. auto iterator = g_dataClassToClass.find(cls);
  549. if (iterator != g_dataClassToClass.end()) {
  550. return iterator->second;
  551. }
  552. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  553. ClassAttrVector attributes;
  554. py::dict names = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_ATTRS, cls_obj);
  555. for (auto &item : names) {
  556. auto type_value = item.second.cast<TypePtr>();
  557. MS_EXCEPTION_IF_NULL(type_value);
  558. MS_LOG(DEBUG) << "(Name: " << py::cast<std::string>(item.first) << ", type: " << type_value->ToString() << ")";
  559. attributes.push_back(std::make_pair(py::cast<std::string>(item.first), type_value));
  560. }
  561. std::unordered_map<std::string, ValuePtr> methods_map;
  562. py::dict methods = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_METHODS, cls_obj);
  563. for (auto &item : methods) {
  564. auto fun_name = item.first.cast<std::string>();
  565. auto obj = py::cast<py::object>(item.second);
  566. std::shared_ptr<PyObjectWrapper> method_obj = std::make_shared<PyObjectWrapper>(obj, fun_name);
  567. methods_map[fun_name] = method_obj;
  568. }
  569. std::shared_ptr<Class> me_class = std::make_shared<Class>(Named(cls_name), attributes, methods_map);
  570. // static Variable for cache
  571. // cppcheck-suppress unreadVariable
  572. g_dataClassToClass[cls] = me_class;
  573. return me_class;
  574. }
  575. void CleanDataClassToClassMap() { g_dataClassToClass.clear(); }
  576. } // namespace parse
  577. } // namespace mindspore