You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

data_converter.cc 26 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/jit/parse/data_converter.h"
  19. #include <unordered_map>
  20. #include <utility>
  21. #include <string>
  22. #include <memory>
  23. #include <vector>
  24. #include "pipeline/jit/parse/resolve.h"
  25. #include "pipeline/jit/parse/python_adapter.h"
  26. #include "frontend/operator/ops.h"
  27. #include "frontend/operator/composite/composite.h"
  28. #include "ir/func_graph_cloner.h"
  29. #include "ir/cell.h"
  30. #include "utils/symbolic.h"
  31. #include "utils/ms_context.h"
  32. namespace mindspore {
  33. namespace parse {
  34. using Tensor = mindspore::tensor::Tensor;
  35. using TensorPtr = mindspore::tensor::TensorPtr;
  36. using MetaTensor = mindspore::tensor::MetaTensor;
  37. using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;
  38. using InstanceCheckFunc = std::function<bool(const py::object &)>;
  39. using InstanceConvertFunc = std::function<ValuePtr(const py::object &, bool, const TypePtr &)>;
  40. class DataConverter {
  41. public:
  42. explicit DataConverter(InstanceConvertFunc convert_func) : convert_func_(std::move(convert_func)) {}
  43. virtual ~DataConverter() = default;
  44. virtual bool Matched(const py::object &obj) = 0;
  45. virtual ValuePtr ConvertPyObject(const py::object &obj, bool use_sig, const TypePtr &dtype) {
  46. if (convert_func_ == nullptr) {
  47. MS_LOG(EXCEPTION) << "convert func is null";
  48. }
  49. return convert_func_(obj, use_sig, dtype);
  50. }
  51. private:
  52. InstanceConvertFunc convert_func_ = nullptr;
  53. };
  54. using DataConverterPtr = std::shared_ptr<DataConverter>;
  55. using ArgsObjConvertFunc = std::function<ValuePtr(const py::object &)>;
  56. using ArgsObjSigConvertFunc = std::function<ValuePtr(const py::object &, bool)>;
  57. using ArgsOjbTypeConvertFunc = std::function<ValuePtr(const py::object &, const TypePtr &)>;
  58. // Convert the data according instance type
  59. template <typename T>
  60. class ByTypeDataConverter : public DataConverter {
  61. public:
  62. explicit ByTypeDataConverter(const InstanceConvertFunc &convert_func)
  63. : DataConverter(convert_func), check_func_(py::isinstance<T>) {}
  64. explicit ByTypeDataConverter(const ValuePtr &converted_type)
  65. : DataConverter(
  66. [converted_type](const py::object &, bool, const TypePtr &) -> ValuePtr { return converted_type; }),
  67. check_func_(py::isinstance<T>) {}
  68. explicit ByTypeDataConverter(const ArgsObjConvertFunc &convert_func)
  69. : DataConverter(
  70. [convert_func](const py::object &obj, bool, const TypePtr &) -> ValuePtr { return convert_func(obj); }),
  71. check_func_(py::isinstance<T>) {}
  72. explicit ByTypeDataConverter(const ArgsObjSigConvertFunc &convert_func)
  73. : DataConverter([convert_func](const py::object &obj, bool use_sig, const TypePtr &) -> ValuePtr {
  74. return convert_func(obj, use_sig);
  75. }),
  76. check_func_(py::isinstance<T>) {}
  77. explicit ByTypeDataConverter(const ArgsOjbTypeConvertFunc &convert_func)
  78. : DataConverter([convert_func](const py::object &obj, bool, const TypePtr &dtype) -> ValuePtr {
  79. return convert_func(obj, dtype);
  80. }),
  81. check_func_(py::isinstance<T>) {}
  82. ~ByTypeDataConverter() override = default;
  83. bool Matched(const py::object &obj) override { return check_func_ != nullptr ? check_func_(obj) : false; }
  84. private:
  85. InstanceCheckFunc check_func_ = nullptr;
  86. };
  87. // Convert the data according object attribute.
  88. class ByAttrDataConverter : public DataConverter {
  89. public:
  90. ByAttrDataConverter(const char *attr_name, const ArgsObjConvertFunc &convert_func)
  91. : DataConverter(
  92. [convert_func](const py::object &obj, bool, const TypePtr &) -> ValuePtr { return convert_func(obj); }),
  93. attr_name_(attr_name) {}
  94. ByAttrDataConverter(const char *attr_name, const ArgsObjSigConvertFunc &convert_func)
  95. : DataConverter([convert_func](const py::object &obj, bool use_sig, const TypePtr &) -> ValuePtr {
  96. return convert_func(obj, use_sig);
  97. }),
  98. attr_name_(attr_name) {}
  99. bool Matched(const py::object &obj) override { return py::hasattr(obj, attr_name_); }
  100. private:
  101. const char *attr_name_ = nullptr;
  102. };
  103. FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
  104. std::vector<std::string> results = data_converter::GetObjKey(obj);
  105. std::string obj_key = results[0];
  106. py::function bprop_func = py::getattr(obj, CUSTOM_BPROP_NAME);
  107. auto bprop_graph = std::make_shared<FuncGraph>();
  108. std::vector<AnfNodePtr> outputs;
  109. auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut");
  110. fake_bprop->set_hook(bprop_func);
  111. (void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true));
  112. outputs.push_back(NewValueNode(fake_bprop));
  113. py::object code_obj = py::getattr(bprop_func, "__code__");
  114. // Three parameters self, out and dout need to be excluded
  115. size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - 3;
  116. for (size_t i = 0; i < inputs_num; ++i) {
  117. auto param = bprop_graph->add_parameter();
  118. outputs.push_back(param);
  119. }
  120. auto p1 = bprop_graph->add_parameter();
  121. auto p2 = bprop_graph->add_parameter();
  122. outputs.push_back(p1);
  123. outputs.push_back(p2);
  124. bprop_graph->set_output(bprop_graph->NewCNode(outputs));
  125. data_converter::SetObjGraphValue(obj_key, bprop_graph);
  126. return bprop_graph;
  127. }
  128. namespace {
  129. ValuePtr ConvertTuple(const py::object &obj, bool use_signature) {
  130. MS_LOG(DEBUG) << "Converting python tuple";
  131. auto tuple = obj.cast<py::tuple>();
  132. std::vector<ValuePtr> value_list;
  133. for (size_t it = 0; it < tuple.size(); ++it) {
  134. ValuePtr out = nullptr;
  135. bool success = ConvertData(tuple[it], &out, use_signature);
  136. if (!success) {
  137. return nullptr;
  138. }
  139. value_list.push_back(out);
  140. }
  141. return std::make_shared<ValueTuple>(value_list);
  142. }
  143. ValuePtr ConvertList(const py::object &obj, bool use_signature) {
  144. MS_LOG(DEBUG) << "Converting python list";
  145. auto list = obj.cast<py::list>();
  146. std::vector<ValuePtr> value_list;
  147. for (size_t it = 0; it < list.size(); ++it) {
  148. ValuePtr out = nullptr;
  149. bool success = ConvertData(list[it], &out, use_signature);
  150. if (!success) {
  151. return nullptr;
  152. }
  153. value_list.push_back(out);
  154. }
  155. return std::make_shared<ValueList>(value_list);
  156. }
  157. ValuePtr ConvertCellList(const py::object &obj, bool use_signature) {
  158. MS_LOG(DEBUG) << "Converting cell list";
  159. py::sequence list = obj;
  160. std::vector<ValuePtr> value_list;
  161. for (size_t it = 0; it < list.size(); ++it) {
  162. ValuePtr out = nullptr;
  163. bool success = ConvertData(list[it], &out, use_signature);
  164. if (!success) {
  165. return nullptr;
  166. }
  167. value_list.push_back(out);
  168. }
  169. return std::make_shared<ValueTuple>(value_list);
  170. }
  171. ValuePtr ConvertDict(const py::object &obj, bool use_signature) {
  172. MS_LOG(DEBUG) << "Converting python dict";
  173. auto dict_values = obj.cast<py::dict>();
  174. std::vector<std::pair<std::string, ValuePtr>> key_values;
  175. for (auto item : dict_values) {
  176. if (!py::isinstance<py::str>(item.first)) {
  177. MS_LOG(ERROR) << "The key of dict is only support str.";
  178. return nullptr;
  179. }
  180. std::string key = py::str(item.first);
  181. ValuePtr out = nullptr;
  182. bool success = ConvertData(dict_values[item.first], &out, use_signature);
  183. if (!success) {
  184. return nullptr;
  185. }
  186. key_values.emplace_back(key, out);
  187. }
  188. return std::make_shared<ValueDictionary>(key_values);
  189. }
  190. ValuePtr ConvertNameSpace(const py::object &obj) {
  191. MS_LOG(DEBUG) << "Converting python module";
  192. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  193. py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj);
  194. auto converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, py::cast<py::module>(module_namespace));
  195. return converted;
  196. }
  197. ValuePtr ConvertDataClass(const py::object &obj) {
  198. MS_LOG(DEBUG) << "Converting dataclass";
  199. // Maybe the obj is dataclass define
  200. auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
  201. // desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
  202. auto converted = std::make_shared<ClassObject>(obj, std::string(desc.begin() + 1, desc.end() - 1));
  203. return converted;
  204. }
  205. ValuePtr ConvertPrimitive(const py::object &obj, bool use_signature = false) {
  206. MS_LOG(DEBUG) << "Converting primitive object" << use_signature;
  207. // need check the primitive is class type or instance
  208. auto obj_type = data_converter::GetObjType(obj);
  209. if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
  210. auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
  211. // desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
  212. return std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
  213. }
  214. py::object adapter_obj = obj;
  215. if (py::hasattr(obj, "__setattr_flag__")) {
  216. if (py::hasattr(obj, "_clone")) {
  217. auto clone_fn = obj.attr("_clone");
  218. adapter_obj = clone_fn();
  219. }
  220. }
  221. auto prim_adapter = adapter_obj.cast<PrimitivePyAdapterPtr>();
  222. MS_EXCEPTION_IF_NULL(prim_adapter);
  223. auto primitive = prim_adapter->attached_primitive();
  224. if (primitive == nullptr) {
  225. primitive = std::make_shared<PrimitivePy>(adapter_obj, prim_adapter);
  226. prim_adapter->set_attached_primitive(primitive);
  227. }
  228. if (use_signature) {
  229. return std::make_shared<prim::DoSignaturePrimitive>(primitive->name(), primitive);
  230. }
  231. return primitive;
  232. }
  233. ValuePtr ConvertMetaFuncGraph(const py::object &obj, bool use_signature = false) {
  234. MS_LOG(DEBUG) << "Converting MetaFuncGraph object";
  235. auto meta = obj.cast<MetaFuncGraphPtr>();
  236. if (meta == nullptr) {
  237. MS_LOG(ERROR) << "Resolve MetaFuncGraph error, get ptr is null";
  238. return nullptr;
  239. }
  240. if (use_signature) {
  241. return std::make_shared<prim::DoSignaturePrimitive>(meta->name(), meta);
  242. }
  243. return meta;
  244. }
  245. ValuePtr ConvertFuncGraph(const py::object &obj) {
  246. MS_LOG(DEBUG) << "Converting FuncGraph object";
  247. auto func_graph = obj.cast<FuncGraphPtr>();
  248. if (func_graph == nullptr) {
  249. MS_LOG(ERROR) << "Resolve FuncGraph error, get ptr is null";
  250. return nullptr;
  251. }
  252. auto new_fg = BasicClone(func_graph);
  253. new_fg->set_attr("is_load", MakeValue(true));
  254. return new_fg;
  255. }
  256. ValuePtr ConvertSlice(const py::object &obj) {
  257. MS_LOG(DEBUG) << "Converting slice object";
  258. auto slice_obj = obj.cast<py::slice>();
  259. auto convert_func = [obj](const std::string &attr) -> ValuePtr {
  260. auto py_attr = py::getattr(obj, attr.c_str());
  261. if (py::isinstance<py::none>(py_attr)) {
  262. return kNone;
  263. }
  264. if (py::isinstance<py::int_>(py_attr)) {
  265. auto value = py::cast<int64_t>(py_attr);
  266. return MakeValue(value);
  267. }
  268. MS_LOG(EXCEPTION) << "Slice should contain only int64_t or none";
  269. };
  270. ValuePtr start = convert_func("start");
  271. ValuePtr stop = convert_func("stop");
  272. ValuePtr step = convert_func("step");
  273. return std::make_shared<ValueSlice>(start, stop, step);
  274. }
  275. ValuePtr ConvertCellObjToFuncGraph(const py::object &obj) {
  276. FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
  277. if (func_graph == nullptr) {
  278. MS_LOG(ERROR) << "Parse resolve function error.";
  279. return nullptr;
  280. }
  281. // if the cell object has specified bprop, it has user-defined bprop function parse and record it
  282. if (py::hasattr(obj, CUSTOM_BPROP_NAME)) {
  283. bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
  284. FuncGraphPtr bprop_graph =
  285. enable_bprop_debug ? ConvertToBpropCut(obj) : ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
  286. if (bprop_graph != nullptr) {
  287. (void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
  288. (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
  289. func_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
  290. }
  291. }
  292. if (py::hasattr(obj, STAGE_NAME)) {
  293. auto stage = py::cast<int>(py::getattr(obj, STAGE_NAME));
  294. func_graph->set_stage(stage);
  295. }
  296. return func_graph;
  297. }
  298. ValuePtr ConvertOtherObj(const py::object &obj) {
  299. auto obj_type = data_converter::GetObjType(obj);
  300. MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";
  301. if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
  302. MS_LOG(DEBUG) << "Resolve the class type, need create class instance.";
  303. std::string desc = py::str(obj);
  304. // desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
  305. return std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
  306. }
  307. if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD) {
  308. MS_LOG(DEBUG) << "Convert the obj to func graph, type is " << obj_type;
  309. FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
  310. if (func_graph == nullptr) {
  311. MS_LOG(ERROR) << "Parse resolve function error.";
  312. return nullptr;
  313. }
  314. return func_graph;
  315. }
  316. if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) {
  317. // Create the namespace for common class instance
  318. // When the obj is Cell, default parse the 'construct'
  319. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  320. py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
  321. return std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
  322. }
  323. MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj));
  324. return nullptr;
  325. }
  326. template <typename T>
  327. ValuePtr ConvertNumberWithType(const T &obj, TypePtr dtype) {
  328. ValuePtr data = nullptr;
  329. auto int_dypte = dyn_cast<Int>(dtype);
  330. if (int_dypte != nullptr) {
  331. switch (int_dypte->nbits()) {
  332. case 8:
  333. data = std::make_shared<Int8Imm>(obj);
  334. break;
  335. case 16:
  336. data = std::make_shared<Int16Imm>(obj);
  337. break;
  338. case 32:
  339. data = std::make_shared<Int32Imm>(obj);
  340. break;
  341. case 64:
  342. data = std::make_shared<Int64Imm>(obj);
  343. break;
  344. default:
  345. data = std::make_shared<Int64Imm>(obj);
  346. }
  347. return data;
  348. }
  349. auto uint_dypte = dyn_cast<UInt>(dtype);
  350. if (uint_dypte != nullptr) {
  351. switch (uint_dypte->nbits()) {
  352. case 8:
  353. data = std::make_shared<UInt8Imm>(obj);
  354. break;
  355. case 16:
  356. data = std::make_shared<UInt16Imm>(obj);
  357. break;
  358. case 32:
  359. data = std::make_shared<UInt32Imm>(obj);
  360. break;
  361. case 64:
  362. data = std::make_shared<UInt64Imm>(obj);
  363. break;
  364. default:
  365. data = std::make_shared<UInt32Imm>(obj);
  366. }
  367. return data;
  368. }
  369. auto float_dypte = dyn_cast<Float>(dtype);
  370. if (float_dypte != nullptr) {
  371. switch (float_dypte->nbits()) {
  372. case 32:
  373. data = std::make_shared<FP32Imm>(obj);
  374. break;
  375. case 64:
  376. data = std::make_shared<FP64Imm>(obj);
  377. break;
  378. default:
  379. data = std::make_shared<FP32Imm>(obj);
  380. }
  381. return data;
  382. }
  383. return nullptr;
  384. }
  385. ValuePtr ConvertIntegerWithType(const py::object &obj, const TypePtr &dtype = nullptr) {
  386. auto obj_int64 = py::cast<int64_t>(obj);
  387. if (dtype == nullptr) {
  388. return std::make_shared<Int64Imm>(obj_int64);
  389. }
  390. return ConvertNumberWithType<int64_t>(obj_int64, dtype);
  391. }
  392. ValuePtr ConvertFloatWithType(const py::object &obj, const TypePtr &dtype = nullptr) {
  393. auto obj_float64 = py::cast<float>(obj);
  394. if (dtype == nullptr) {
  395. return std::make_shared<FP32Imm>(obj_float64);
  396. }
  397. return ConvertNumberWithType<float>(obj_float64, dtype);
  398. }
  399. template <typename T, typename U>
  400. ValuePtr PyCast(const py::object &obj) {
  401. return std::make_shared<T>(py::cast<U>(obj));
  402. }
  403. template <typename T>
  404. ValuePtr ObjCast(const py::object &obj) {
  405. return obj.cast<T>();
  406. }
  407. std::vector<DataConverterPtr> GetDataConverters() {
  408. static std::vector<DataConverterPtr> data_converters = {
  409. // Convert data by python object type.
  410. std::make_shared<ByTypeDataConverter<py::none>>(kNone),
  411. std::make_shared<ByTypeDataConverter<py::bool_>>(PyCast<BoolImm, bool>),
  412. std::make_shared<ByTypeDataConverter<py::str>>(PyCast<StringImm, string>),
  413. std::make_shared<ByTypeDataConverter<py::ellipsis>>(kEllipsis),
  414. std::make_shared<ByTypeDataConverter<py::module>>(ConvertNameSpace),
  415. std::make_shared<ByAttrDataConverter>(PYTHON_DATACLASS_FIELDS, ConvertDataClass),
  416. std::make_shared<ByTypeDataConverter<Type>>(ObjCast<TypePtr>),
  417. std::make_shared<ByTypeDataConverter<Tensor>>(ObjCast<TensorPtr>),
  418. std::make_shared<ByTypeDataConverter<MetaTensor>>(ObjCast<MetaTensorPtr>),
  419. std::make_shared<ByTypeDataConverter<UMonad>>(ObjCast<UMonadPtr>),
  420. std::make_shared<ByTypeDataConverter<IOMonad>>(ObjCast<IOMonadPtr>),
  421. std::make_shared<ByTypeDataConverter<EnvInstance>>(ObjCast<std::shared_ptr<EnvInstance>>),
  422. std::make_shared<ByAttrDataConverter>(PYTHON_CLASS_MEMBER_NAMESPACE,
  423. [](const py::object &obj) -> ValuePtr {
  424. return std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER,
  425. obj);
  426. }),
  427. std::make_shared<ByTypeDataConverter<py::int_>>(ConvertIntegerWithType),
  428. std::make_shared<ByTypeDataConverter<py::float_>>(ConvertFloatWithType),
  429. std::make_shared<ByTypeDataConverter<py::dict>>(ConvertDict),
  430. std::make_shared<ByTypeDataConverter<py::slice>>(ConvertSlice),
  431. std::make_shared<ByTypeDataConverter<py::tuple>>(ConvertTuple),
  432. std::make_shared<ByAttrDataConverter>(PYTHON_CELL_AS_LIST, ConvertCellList),
  433. std::make_shared<ByTypeDataConverter<Cell>>(ConvertCellObjToFuncGraph),
  434. std::make_shared<ByTypeDataConverter<py::list>>(ConvertList),
  435. std::make_shared<ByAttrDataConverter>(PYTHON_PRIMITIVE_FLAG, ConvertPrimitive),
  436. std::make_shared<ByTypeDataConverter<MetaFuncGraph>>(ConvertMetaFuncGraph),
  437. std::make_shared<ByTypeDataConverter<FuncGraph>>(ConvertFuncGraph),
  438. };
  439. return data_converters;
  440. }
  441. } // namespace
  442. bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, const TypePtr &dtype) {
  443. // Check parameter valid
  444. if (data == nullptr) {
  445. MS_LOG(ERROR) << "Data is null pointer";
  446. return false;
  447. }
  448. ValuePtr converted = nullptr;
  449. bool matched = false;
  450. auto &&converters = GetDataConverters();
  451. for (auto &converter : converters) {
  452. if (converter->Matched(obj)) {
  453. converted = converter->ConvertPyObject(obj, use_signature, dtype);
  454. matched = true;
  455. break;
  456. }
  457. }
  458. if (!matched) {
  459. converted = ConvertOtherObj(obj);
  460. }
  461. *data = converted;
  462. return converted != nullptr;
  463. }
  464. // Convert data to graph
  465. FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python_mod_get_parse_method) {
  466. std::vector<std::string> results = data_converter::GetObjKey(obj);
  467. std::string obj_id = results[0] + python_mod_get_parse_method;
  468. std::string obj_key = results[1];
  469. FuncGraphPtr func_graph = nullptr;
  470. ValuePtr value = nullptr;
  471. bool is_cache = data_converter::GetObjectValue(obj_id, &value);
  472. if (is_cache && value != nullptr && value->isa<FuncGraph>()) {
  473. MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id;
  474. func_graph = value->cast<FuncGraphPtr>();
  475. if (!func_graph->dropped()) {
  476. return func_graph;
  477. }
  478. }
  479. func_graph = ParsePythonCode(obj, python_mod_get_parse_method);
  480. if (func_graph == nullptr) {
  481. MS_LOG(ERROR) << "Parse resolve function error.";
  482. return nullptr;
  483. }
  484. data_converter::MakeProperNameToFuncGraph(func_graph, obj_id);
  485. data_converter::CacheObjectValue(obj_id, func_graph);
  486. if (!obj_key.empty()) {
  487. MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString();
  488. data_converter::SetObjGraphValue(obj_key, func_graph);
  489. }
  490. return func_graph;
  491. }
  492. namespace data_converter {
  // Value cache written by CacheObjectValue and read by GetObjectValue (e.g. parsed FuncGraphs
  // stored by ConvertToFuncGraph under their object id). Emptied by ClearObjectCache.
  static std::unordered_map<std::string, ValuePtr> object_map_;
  // Graphs recorded per object key by SetObjGraphValue; one key may accumulate several graphs.
  // Exposed through GetObjGraphs and emptied by ClearObjectCache.
  static std::unordered_map<std::string, std::vector<FuncGraphPtr>> object_graphs_map_;
  495. void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) {
  496. object_graphs_map_[obj_key].push_back(data);
  497. MS_LOG(DEBUG) << "Set func graph size:" << object_graphs_map_.size();
  498. }
  499. const std::unordered_map<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs() {
  500. MS_LOG(DEBUG) << "Obj size:" << object_graphs_map_.size();
  501. return object_graphs_map_;
  502. }
  503. void CacheObjectValue(const std::string &obj_key, const ValuePtr &data) { object_map_[obj_key] = data; }
  504. bool GetObjectValue(const std::string &obj_key, ValuePtr *const data) {
  505. if (object_map_.count(obj_key)) {
  506. *data = object_map_[obj_key];
  507. return true;
  508. }
  509. return false;
  510. }
  511. std::vector<std::string> GetObjKey(const py::object &obj) {
  512. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  513. py::tuple obj_tuple = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_KEY, obj);
  514. if (obj_tuple.size() != 2) {
  515. MS_LOG(EXCEPTION) << "Get_obj_key must return 2 elements";
  516. }
  517. return {py::cast<std::string>(obj_tuple[0]), py::cast<std::string>(obj_tuple[1])};
  518. }
  519. // Get obj detail type
  520. ResolveTypeDef GetObjType(const py::object &obj) {
  521. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  522. auto obj_type =
  523. ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast<int32_t>());
  524. return obj_type;
  525. }
  526. // Get class instance detail type.
  527. ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) {
  528. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  529. auto class_type =
  530. ClassInstanceTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_CLASS_INSTANCE_TYPE, obj).cast<int32_t>());
  531. return class_type;
  532. }
  533. // Check the object is Cell Instance.
  534. bool IsCellInstance(const py::object &obj) {
  535. auto class_type = GetClassInstanceType(obj);
  536. bool isCell = (class_type == CLASS_INSTANCE_TYPE_CELL);
  537. return isCell;
  538. }
  539. // Create the python class instance.
  540. py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs) {
  541. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  542. // `args_kwargs` maybe a tuple(*args), tuple(**kwargs), or tuple(*args, **kwargs).
  543. return args_kwargs.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type)
  544. : python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, args_kwargs);
  545. }
  546. // Generate an appropriate name and set to graph debuginfo,
  547. // character <> can not used in the dot file, so change to another symbol.
  548. void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name) {
  549. MS_EXCEPTION_IF_NULL(func_graph);
  550. MS_EXCEPTION_IF_NULL(func_graph->debug_info());
  551. // Set detail name info of function
  552. std::ostringstream oss;
  553. for (size_t i = 0; i < name.size(); i++) {
  554. if (name[i] == '<') {
  555. oss << "「";
  556. } else if (name[i] == '>') {
  557. oss << "」";
  558. } else {
  559. oss << name[i];
  560. }
  561. }
  562. func_graph->debug_info()->set_full_name(oss.str());
  563. }
  564. ValuePtr PyDataToValue(const py::object &obj) {
  565. py::object to_convert = obj;
  566. ValuePtr value = nullptr;
  567. (void)ConvertData(to_convert, &value);
  568. return value;
  569. }
  570. void ClearObjectCache() {
  571. object_map_.clear();
  572. object_graphs_map_.clear();
  573. }
  574. } // namespace data_converter
  575. static std::unordered_map<std::string, ClassPtr> g_dataClassToClass = {};
  576. // Parse dataclass to mindspore Class type
  577. ClassPtr ParseDataClass(const py::object &cls_obj) {
  578. std::string cls_name = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__name__"));
  579. std::string cls_module = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__module__"));
  580. std::string cls = cls_module + "." + cls_name;
  581. auto iterator = g_dataClassToClass.find(cls);
  582. if (iterator != g_dataClassToClass.end()) {
  583. return iterator->second;
  584. }
  585. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  586. ClassAttrVector attributes;
  587. py::dict names = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_ATTRS, cls_obj);
  588. for (auto &item : names) {
  589. auto type_value = item.second.cast<TypePtr>();
  590. MS_EXCEPTION_IF_NULL(type_value);
  591. MS_LOG(DEBUG) << "(Name: " << py::cast<std::string>(item.first) << ", type: " << type_value->ToString() << ")";
  592. attributes.push_back(std::make_pair(py::cast<std::string>(item.first), type_value));
  593. }
  594. std::unordered_map<std::string, ValuePtr> methods_map;
  595. py::dict methods = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_METHODS, cls_obj);
  596. for (auto &item : methods) {
  597. auto fun_name = item.first.cast<std::string>();
  598. auto obj = py::cast<py::object>(item.second);
  599. std::shared_ptr<PyObjectWrapper> method_obj = std::make_shared<PyObjectWrapper>(obj, fun_name);
  600. methods_map[fun_name] = method_obj;
  601. }
  602. std::shared_ptr<Class> me_class = std::make_shared<Class>(Named(cls_name), attributes, methods_map);
  603. // static Variable for cache
  604. // cppcheck-suppress unreadVariable
  605. g_dataClassToClass[cls] = me_class;
  606. return me_class;
  607. }
  608. void CleanDataClassToClassMap() { g_dataClassToClass.clear(); }
  609. } // namespace parse
  610. } // namespace mindspore