You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

data_converter.cc 28 kB

5 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/jit/parse/data_converter.h"
  19. #include <utility>
  20. #include <string>
  21. #include <memory>
  22. #include <vector>
  23. #include "utils/hash_map.h"
  24. #include "pipeline/jit/parse/resolve.h"
  25. #include "pipeline/jit/parse/python_adapter.h"
  26. #include "frontend/operator/ops.h"
  27. #include "frontend/operator/composite/composite.h"
  28. #include "ir/func_graph_cloner.h"
  29. #include "ir/cell.h"
  30. #include "utils/symbolic.h"
  31. #include "utils/ms_context.h"
  32. #include "utils/utils.h"
  33. namespace mindspore {
  34. namespace parse {
  35. using Tensor = mindspore::tensor::Tensor;
  36. using TensorPtr = mindspore::tensor::TensorPtr;
  37. using MetaTensor = mindspore::tensor::MetaTensor;
  38. using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;
  39. using InstanceCheckFunc = std::function<bool(const py::object &)>;
  40. using InstanceConvertFunc = std::function<ValuePtr(const py::object &, bool, const TypePtr &)>;
  41. static constexpr int kBit8 = 8;
  42. static constexpr int kBit16 = 16;
  43. static constexpr int kBit32 = 32;
  44. static constexpr int kBit64 = 64;
  45. class DataConverter {
  46. public:
  47. explicit DataConverter(InstanceConvertFunc convert_func) : convert_func_(std::move(convert_func)) {}
  48. virtual ~DataConverter() = default;
  49. virtual bool Matched(const py::object &obj) = 0;
  50. virtual ValuePtr ConvertPyObject(const py::object &obj, bool use_sig, const TypePtr &dtype) {
  51. if (convert_func_ == nullptr) {
  52. MS_LOG(EXCEPTION) << "convert func is null";
  53. }
  54. return convert_func_(obj, use_sig, dtype);
  55. }
  56. private:
  57. InstanceConvertFunc convert_func_ = nullptr;
  58. };
  59. using DataConverterPtr = std::shared_ptr<DataConverter>;
  60. using ArgsObjConvertFunc = std::function<ValuePtr(const py::object &)>;
  61. using ArgsObjSigConvertFunc = std::function<ValuePtr(const py::object &, bool)>;
  62. using ArgsOjbTypeConvertFunc = std::function<ValuePtr(const py::object &, const TypePtr &)>;
  63. // Convert the data according instance type
  64. template <typename T>
  65. class ByTypeDataConverter : public DataConverter {
  66. public:
  67. explicit ByTypeDataConverter(const InstanceConvertFunc &convert_func)
  68. : DataConverter(convert_func), check_func_(py::isinstance<T>) {}
  69. explicit ByTypeDataConverter(const ValuePtr &converted_type)
  70. : DataConverter(
  71. [converted_type](const py::object &, bool, const TypePtr &) -> ValuePtr { return converted_type; }),
  72. check_func_(py::isinstance<T>) {}
  73. explicit ByTypeDataConverter(const ArgsObjConvertFunc &convert_func)
  74. : DataConverter(
  75. [convert_func](const py::object &obj, bool, const TypePtr &) -> ValuePtr { return convert_func(obj); }),
  76. check_func_(py::isinstance<T>) {}
  77. explicit ByTypeDataConverter(const ArgsObjSigConvertFunc &convert_func)
  78. : DataConverter([convert_func](const py::object &obj, bool use_sig, const TypePtr &) -> ValuePtr {
  79. return convert_func(obj, use_sig);
  80. }),
  81. check_func_(py::isinstance<T>) {}
  82. explicit ByTypeDataConverter(const ArgsOjbTypeConvertFunc &convert_func)
  83. : DataConverter([convert_func](const py::object &obj, bool, const TypePtr &dtype) -> ValuePtr {
  84. return convert_func(obj, dtype);
  85. }),
  86. check_func_(py::isinstance<T>) {}
  87. ~ByTypeDataConverter() override = default;
  88. bool Matched(const py::object &obj) override { return check_func_ != nullptr ? check_func_(obj) : false; }
  89. private:
  90. InstanceCheckFunc check_func_ = nullptr;
  91. };
  92. // Convert the data according object attribute.
  93. class ByAttrDataConverter : public DataConverter {
  94. public:
  95. ByAttrDataConverter(const char *attr_name, const ArgsObjConvertFunc &convert_func)
  96. : DataConverter(
  97. [convert_func](const py::object &obj, bool, const TypePtr &) -> ValuePtr { return convert_func(obj); }),
  98. attr_name_(attr_name) {}
  99. ByAttrDataConverter(const char *attr_name, const ArgsObjSigConvertFunc &convert_func)
  100. : DataConverter([convert_func](const py::object &obj, bool use_sig, const TypePtr &) -> ValuePtr {
  101. return convert_func(obj, use_sig);
  102. }),
  103. attr_name_(attr_name) {}
  104. ~ByAttrDataConverter() override = default;
  105. bool Matched(const py::object &obj) override { return py::hasattr(obj, attr_name_); }
  106. private:
  107. const char *attr_name_ = nullptr;
  108. };
  109. FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
  110. std::vector<std::string> results = data_converter::GetObjKey(obj);
  111. std::string obj_key = results[0];
  112. py::function bprop_func = py::getattr(obj, CUSTOM_BPROP_NAME);
  113. auto bprop_graph = std::make_shared<FuncGraph>();
  114. std::vector<AnfNodePtr> outputs;
  115. auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut");
  116. fake_bprop->set_hook(bprop_func);
  117. (void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true));
  118. outputs.push_back(NewValueNode(fake_bprop));
  119. py::object code_obj = py::getattr(bprop_func, "__code__");
  120. // Three parameters self, out and dout need to be excluded
  121. constexpr auto kBpropExcludeParamNum = 3;
  122. size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - kBpropExcludeParamNum;
  123. for (size_t i = 0; i < inputs_num; ++i) {
  124. auto param = bprop_graph->add_parameter();
  125. outputs.push_back(param);
  126. }
  127. auto p1 = bprop_graph->add_parameter();
  128. auto p2 = bprop_graph->add_parameter();
  129. outputs.push_back(p1);
  130. outputs.push_back(p2);
  131. bprop_graph->set_output(bprop_graph->NewCNode(std::move(outputs)));
  132. data_converter::SetObjGraphValue(obj_key, bprop_graph);
  133. return bprop_graph;
  134. }
  135. namespace {
  136. ValuePtr ConvertTuple(const py::object &obj, bool use_signature) {
  137. MS_LOG(DEBUG) << "Converting python tuple";
  138. auto tuple = obj.cast<py::tuple>();
  139. std::vector<ValuePtr> value_list;
  140. for (size_t it = 0; it < tuple.size(); ++it) {
  141. ValuePtr out = nullptr;
  142. bool success = ConvertData(tuple[it], &out, use_signature);
  143. if (!success) {
  144. return nullptr;
  145. }
  146. value_list.push_back(out);
  147. }
  148. return std::make_shared<ValueTuple>(value_list);
  149. }
  150. ValuePtr ConvertList(const py::object &obj, bool use_signature) {
  151. MS_LOG(DEBUG) << "Converting python list";
  152. auto list = obj.cast<py::list>();
  153. std::vector<ValuePtr> value_list;
  154. for (size_t it = 0; it < list.size(); ++it) {
  155. ValuePtr out = nullptr;
  156. bool success = ConvertData(list[it], &out, use_signature);
  157. if (!success) {
  158. return nullptr;
  159. }
  160. value_list.push_back(out);
  161. }
  162. return std::make_shared<ValueList>(value_list);
  163. }
  164. ValuePtr ConvertCellList(const py::object &obj, bool use_signature) {
  165. MS_LOG(DEBUG) << "Converting cell list";
  166. py::sequence list = obj;
  167. std::vector<ValuePtr> value_list;
  168. for (size_t it = 0; it < list.size(); ++it) {
  169. ValuePtr out = nullptr;
  170. bool success = ConvertData(list[it], &out, use_signature);
  171. if (!success) {
  172. return nullptr;
  173. }
  174. value_list.push_back(out);
  175. }
  176. return std::make_shared<ValueTuple>(value_list);
  177. }
  178. ValuePtr ConvertDict(const py::object &obj, bool use_signature) {
  179. MS_LOG(DEBUG) << "Converting python dict";
  180. auto dict_values = obj.cast<py::dict>();
  181. std::vector<std::pair<std::string, ValuePtr>> key_values;
  182. for (auto item : dict_values) {
  183. if (!py::isinstance<py::str>(item.first)) {
  184. MS_LOG(ERROR) << "The key of dict is only support str.";
  185. return nullptr;
  186. }
  187. std::string key = py::str(item.first);
  188. ValuePtr out = nullptr;
  189. bool success = ConvertData(dict_values[item.first], &out, use_signature);
  190. if (!success) {
  191. return nullptr;
  192. }
  193. key_values.emplace_back(key, out);
  194. }
  195. return std::make_shared<ValueDictionary>(key_values);
  196. }
  197. ValuePtr ConvertModuleNameSpace(const py::object &obj) {
  198. MS_LOG(DEBUG) << "Converting python module";
  199. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  200. py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj);
  201. auto converted =
  202. std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, py::cast<py::module>(module_namespace), obj);
  203. MS_LOG(DEBUG) << "name_space: " << converted->ToString();
  204. return converted;
  205. }
  206. ValuePtr ConvertDataClass(const py::object &obj) {
  207. MS_LOG(DEBUG) << "Converting dataclass";
  208. // Maybe the obj is dataclass define
  209. auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
  210. // desc has format "<class xxxx>", strip the '<' and '>' by offset 1
  211. auto converted = std::make_shared<ClassObject>(obj, std::string(desc.begin() + 1, desc.end() - 1));
  212. return converted;
  213. }
  214. ValuePtr ConvertPrimitive(const py::object &obj, bool use_signature = false) {
  215. MS_LOG(DEBUG) << "Converting primitive object" << use_signature;
  216. // need check the primitive is class type or instance
  217. auto obj_type = data_converter::GetObjType(obj);
  218. if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
  219. auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
  220. // desc has format "<class xxxx>", strip the '<' and '>' by offset 1.
  221. return std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
  222. }
  223. py::object adapter_obj = obj;
  224. if (py::hasattr(obj, "__setattr_flag__")) {
  225. if (py::hasattr(obj, "_clone")) {
  226. auto clone_fn = obj.attr("_clone");
  227. adapter_obj = clone_fn();
  228. }
  229. }
  230. auto prim_adapter = adapter_obj.cast<PrimitivePyAdapterPtr>();
  231. MS_EXCEPTION_IF_NULL(prim_adapter);
  232. auto primitive = prim_adapter->attached_primitive();
  233. if (primitive == nullptr) {
  234. primitive = std::make_shared<PrimitivePy>(adapter_obj, prim_adapter);
  235. prim_adapter->set_attached_primitive(primitive);
  236. }
  237. if (use_signature) {
  238. return std::make_shared<prim::DoSignaturePrimitive>(primitive->name(), primitive);
  239. }
  240. return primitive;
  241. }
  242. ValuePtr ConvertMetaFuncGraph(const py::object &obj, bool use_signature = false) {
  243. MS_LOG(DEBUG) << "Converting MetaFuncGraph object";
  244. auto meta = obj.cast<MetaFuncGraphPtr>();
  245. if (meta == nullptr) {
  246. MS_LOG(ERROR) << "Resolve MetaFuncGraph error, get ptr is null";
  247. return nullptr;
  248. }
  249. if (use_signature) {
  250. return std::make_shared<prim::DoSignaturePrimitive>(meta->name(), meta);
  251. }
  252. return meta;
  253. }
  254. ValuePtr ConvertFuncGraph(const py::object &obj) {
  255. MS_LOG(DEBUG) << "Converting FuncGraph object";
  256. auto func_graph = obj.cast<FuncGraphPtr>();
  257. if (func_graph == nullptr) {
  258. MS_LOG(ERROR) << "Resolve FuncGraph error, get ptr is null";
  259. return nullptr;
  260. }
  261. auto new_fg = BasicClone(func_graph);
  262. new_fg->set_attr("is_load", MakeValue(true));
  263. return new_fg;
  264. }
  265. ValuePtr ConvertSlice(const py::object &obj) {
  266. MS_LOG(DEBUG) << "Converting slice object";
  267. auto convert_func = [obj](const std::string &attr) -> ValuePtr {
  268. auto py_attr = py::getattr(obj, attr.c_str());
  269. if (py::isinstance<py::none>(py_attr)) {
  270. return kNone;
  271. }
  272. if (py::isinstance<py::int_>(py_attr)) {
  273. auto value = py::cast<int64_t>(py_attr);
  274. return MakeValue(value);
  275. }
  276. if (py::isinstance<Tensor>(py_attr)) {
  277. return py::cast<TensorPtr>(py_attr);
  278. }
  279. MS_LOG(EXCEPTION) << "Attribute '" << attr << "' of " << py::str(obj)
  280. << " should be int or Tensor with Int type but got " << py::str(py_attr);
  281. };
  282. ValuePtr start = convert_func(kSliceStart);
  283. ValuePtr stop = convert_func(kSliceStop);
  284. ValuePtr step = convert_func(kSliceStep);
  285. return std::make_shared<ValueSlice>(start, stop, step);
  286. }
  287. ValuePtr ConvertCellObjToFuncGraph(const py::object &obj) {
  288. FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
  289. if (func_graph == nullptr) {
  290. MS_LOG(ERROR) << "Parse resolve function error.";
  291. return nullptr;
  292. }
  293. // if the cell object has specified bprop, it has user-defined bprop function parse and record it
  294. if (py::hasattr(obj, CUSTOM_BPROP_NAME)) {
  295. bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
  296. FuncGraphPtr bprop_graph =
  297. enable_bprop_debug ? ConvertToBpropCut(obj) : ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
  298. if (bprop_graph != nullptr) {
  299. (void)func_graph->transforms().emplace(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph));
  300. (void)bprop_graph->transforms().emplace("primal", FuncGraphTransform(func_graph));
  301. func_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
  302. }
  303. }
  304. if (py::hasattr(obj, STAGE_NAME)) {
  305. auto stage = py::cast<int>(py::getattr(obj, STAGE_NAME));
  306. func_graph->set_stage(stage);
  307. }
  308. return func_graph;
  309. }
  310. ValuePtr ConvertOtherObj(const py::object &obj) {
  311. auto obj_type = data_converter::GetObjType(obj);
  312. MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";
  313. if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
  314. MS_LOG(DEBUG) << "Resolve the class type, need create class instance.";
  315. std::string desc = py::str(obj);
  316. // desc has format "<class xxxx>", strip the '<' and '>' by offset 1.
  317. return std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
  318. }
  319. if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD) {
  320. MS_LOG(DEBUG) << "Convert the obj to func graph, type is " << obj_type;
  321. FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
  322. if (func_graph == nullptr) {
  323. MS_LOG(ERROR) << "Parse resolve function error.";
  324. return nullptr;
  325. }
  326. return func_graph;
  327. }
  328. if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) {
  329. // Create the namespace for common class instance
  330. // When the obj is Cell, default parse the 'construct'
  331. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  332. py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
  333. auto res = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
  334. MS_LOG(DEBUG) << "name_space: " << res->ToString();
  335. return res;
  336. }
  337. // Start RESOLVE_TYPE_INVALID...
  338. // The fallback feature is enabled in default.
  339. // Not support change the flag during the process is alive.
  340. static const auto support_fallback = common::GetEnv("ENV_SUPPORT_FALLBACK");
  341. static const auto use_fallback = (support_fallback != "0");
  342. if (use_fallback) {
  343. auto res = std::make_shared<InterpretedObject>(obj, py::str(obj));
  344. MS_LOG(DEBUG) << "Get interpreted object: " << res->ToString();
  345. return res;
  346. }
  347. MS_LOG(ERROR) << "Resolve type is invalid, obj: " << py::str(obj);
  348. return nullptr;
  349. }
  350. template <typename T>
  351. ValuePtr ConvertNumberWithType(const T &obj, const TypePtr &dtype) {
  352. ValuePtr data = nullptr;
  353. auto int_dypte = dyn_cast<Int>(dtype);
  354. if (int_dypte != nullptr) {
  355. switch (int_dypte->nbits()) {
  356. case kBit8:
  357. data = std::make_shared<Int8Imm>(obj);
  358. break;
  359. case kBit16:
  360. data = std::make_shared<Int16Imm>(obj);
  361. break;
  362. case kBit32:
  363. data = std::make_shared<Int32Imm>(obj);
  364. break;
  365. case kBit64:
  366. data = std::make_shared<Int64Imm>(obj);
  367. break;
  368. default:
  369. data = std::make_shared<Int64Imm>(obj);
  370. }
  371. return data;
  372. }
  373. auto uint_dypte = dyn_cast<UInt>(dtype);
  374. if (uint_dypte != nullptr) {
  375. switch (uint_dypte->nbits()) {
  376. case kBit8:
  377. data = std::make_shared<UInt8Imm>(obj);
  378. break;
  379. case kBit16:
  380. data = std::make_shared<UInt16Imm>(obj);
  381. break;
  382. case kBit32:
  383. data = std::make_shared<UInt32Imm>(obj);
  384. break;
  385. case kBit64:
  386. data = std::make_shared<UInt64Imm>(obj);
  387. break;
  388. default:
  389. data = std::make_shared<UInt32Imm>(obj);
  390. }
  391. return data;
  392. }
  393. auto float_dypte = dyn_cast<Float>(dtype);
  394. if (float_dypte != nullptr) {
  395. switch (float_dypte->nbits()) {
  396. case kBit32:
  397. data = std::make_shared<FP32Imm>(obj);
  398. break;
  399. case kBit64:
  400. data = std::make_shared<FP64Imm>(obj);
  401. break;
  402. default:
  403. data = std::make_shared<FP32Imm>(obj);
  404. }
  405. return data;
  406. }
  407. return nullptr;
  408. }
  409. ValuePtr ConvertIntegerWithType(const py::object &obj, const TypePtr &dtype = nullptr) {
  410. auto obj_int64 = py::cast<int64_t>(obj);
  411. if (dtype == nullptr) {
  412. return std::make_shared<Int64Imm>(obj_int64);
  413. }
  414. return ConvertNumberWithType<int64_t>(obj_int64, dtype);
  415. }
  416. ValuePtr ConvertFloatWithType(const py::object &obj, const TypePtr &dtype = nullptr) {
  417. auto obj_float64 = py::cast<float>(obj);
  418. if (dtype == nullptr) {
  419. return std::make_shared<FP32Imm>(obj_float64);
  420. }
  421. return ConvertNumberWithType<float>(obj_float64, dtype);
  422. }
  423. template <typename T, typename U>
  424. ValuePtr PyCast(const py::object &obj) {
  425. return std::make_shared<T>(py::cast<U>(obj));
  426. }
  427. template <typename T>
  428. ValuePtr ObjCast(const py::object &obj) {
  429. return obj.cast<T>();
  430. }
  431. std::vector<DataConverterPtr> GetDataConverters() {
  432. static std::vector<DataConverterPtr> data_converters = {
  433. // Convert data by python object type.
  434. std::make_shared<ByTypeDataConverter<Tensor>>(ObjCast<TensorPtr>),
  435. std::make_shared<ByTypeDataConverter<MetaTensor>>(ObjCast<MetaTensorPtr>),
  436. std::make_shared<ByTypeDataConverter<py::tuple>>(ConvertTuple),
  437. std::make_shared<ByTypeDataConverter<py::list>>(ConvertList),
  438. std::make_shared<ByTypeDataConverter<py::bool_>>(PyCast<BoolImm, bool>),
  439. std::make_shared<ByTypeDataConverter<py::int_>>(ConvertIntegerWithType),
  440. std::make_shared<ByTypeDataConverter<py::float_>>(ConvertFloatWithType),
  441. std::make_shared<ByTypeDataConverter<py::str>>(PyCast<StringImm, string>),
  442. std::make_shared<ByTypeDataConverter<py::none>>(kNone),
  443. std::make_shared<ByTypeDataConverter<py::ellipsis>>(kEllipsis),
  444. std::make_shared<ByTypeDataConverter<py::module>>(ConvertModuleNameSpace),
  445. std::make_shared<ByAttrDataConverter>(PYTHON_DATACLASS_FIELDS, ConvertDataClass),
  446. std::make_shared<ByTypeDataConverter<Type>>(ObjCast<TypePtr>),
  447. std::make_shared<ByTypeDataConverter<UMonad>>(ObjCast<UMonadPtr>),
  448. std::make_shared<ByTypeDataConverter<IOMonad>>(ObjCast<IOMonadPtr>),
  449. std::make_shared<ByTypeDataConverter<EnvInstance>>(ObjCast<std::shared_ptr<EnvInstance>>),
  450. std::make_shared<ByAttrDataConverter>(PYTHON_CLASS_MEMBER_NAMESPACE,
  451. [](const py::object &obj) -> ValuePtr {
  452. auto res =
  453. std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
  454. MS_LOG(DEBUG) << "name_space: " << res->ToString();
  455. return res;
  456. }),
  457. std::make_shared<ByTypeDataConverter<py::dict>>(ConvertDict),
  458. std::make_shared<ByTypeDataConverter<py::slice>>(ConvertSlice),
  459. std::make_shared<ByAttrDataConverter>(PYTHON_CELL_AS_LIST, ConvertCellList),
  460. std::make_shared<ByTypeDataConverter<Cell>>(ConvertCellObjToFuncGraph),
  461. std::make_shared<ByAttrDataConverter>(PYTHON_PRIMITIVE_FLAG, ConvertPrimitive),
  462. std::make_shared<ByTypeDataConverter<MetaFuncGraph>>(ConvertMetaFuncGraph),
  463. std::make_shared<ByTypeDataConverter<FuncGraph>>(ConvertFuncGraph),
  464. };
  465. return data_converters;
  466. }
  467. } // namespace
  468. bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, const TypePtr &dtype) {
  469. // Check parameter valid
  470. if (data == nullptr) {
  471. MS_LOG(ERROR) << "Data is null pointer";
  472. return false;
  473. }
  474. ValuePtr converted = nullptr;
  475. bool matched = false;
  476. auto converters = GetDataConverters();
  477. for (auto &converter : converters) {
  478. if (converter->Matched(obj)) {
  479. converted = converter->ConvertPyObject(obj, use_signature, dtype);
  480. matched = true;
  481. break;
  482. }
  483. }
  484. if (!matched) {
  485. converted = ConvertOtherObj(obj);
  486. }
  487. *data = converted;
  488. return converted != nullptr;
  489. }
  490. // Convert data to graph
  491. FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python_mod_get_parse_method) {
  492. std::vector<std::string> results = data_converter::GetObjKey(obj);
  493. std::string obj_id = results[0] + python_mod_get_parse_method;
  494. std::string obj_key = results[1];
  495. FuncGraphPtr func_graph = nullptr;
  496. ValuePtr value = nullptr;
  497. bool is_cache = data_converter::GetObjectValue(obj_id, &value);
  498. if (is_cache && value != nullptr && value->isa<FuncGraph>()) {
  499. MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id;
  500. func_graph = value->cast<FuncGraphPtr>();
  501. if (!func_graph->dropped()) {
  502. return func_graph;
  503. }
  504. }
  505. func_graph = ParsePythonCode(obj, python_mod_get_parse_method);
  506. if (func_graph == nullptr) {
  507. MS_LOG(ERROR) << "Parse resolve function error.";
  508. return nullptr;
  509. }
  510. data_converter::MakeProperNameToFuncGraph(func_graph, obj_id);
  511. data_converter::CacheObjectValue(obj_id, func_graph);
  512. if (!obj_key.empty()) {
  513. MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString();
  514. data_converter::SetObjGraphValue(obj_key, func_graph);
  515. }
  516. return func_graph;
  517. }
  518. namespace data_converter {
  519. static mindspore::HashMap<std::string, ValuePtr> object_map_;
  520. static mindspore::HashMap<std::string, std::vector<FuncGraphPtr>> object_graphs_map_;
  521. void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) {
  522. object_graphs_map_[obj_key].push_back(data);
  523. MS_LOG(DEBUG) << "Set func graph size:" << object_graphs_map_.size();
  524. }
  525. const mindspore::HashMap<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs() {
  526. MS_LOG(DEBUG) << "Obj size:" << object_graphs_map_.size();
  527. return object_graphs_map_;
  528. }
  529. void CacheObjectValue(const std::string &obj_key, const ValuePtr &data) { object_map_[obj_key] = data; }
  530. bool GetObjectValue(const std::string &obj_key, ValuePtr *const data) {
  531. if (object_map_.count(obj_key)) {
  532. *data = object_map_[obj_key];
  533. return true;
  534. }
  535. return false;
  536. }
  537. std::vector<std::string> GetObjKey(const py::object &obj) {
  538. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  539. py::tuple obj_tuple = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_KEY, obj);
  540. if (obj_tuple.size() != 2) {
  541. MS_LOG(EXCEPTION) << "Get_obj_key must return 2 elements";
  542. }
  543. return {py::cast<std::string>(obj_tuple[0]), py::cast<std::string>(obj_tuple[1])};
  544. }
  545. // Get obj detail type
  546. ResolveTypeDef GetObjType(const py::object &obj) {
  547. try {
  548. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  549. auto obj_type =
  550. ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast<int32_t>());
  551. return obj_type;
  552. } catch (const py::error_already_set &ex) {
  553. MS_LOG(ERROR) << "Meet a exception from Python when get the type of `" << py::str(obj) << "`.\n" << ex.what();
  554. std::rethrow_exception(std::current_exception());
  555. } catch (const py::type_error &ex) {
  556. MS_LOG(ERROR) << "Meet a exception when get the type of `" << py::str(obj) << "`.\n" << ex.what();
  557. std::rethrow_exception(std::current_exception());
  558. }
  559. }
  560. // Get class instance detail type.
  561. ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) {
  562. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  563. auto class_type =
  564. ClassInstanceTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_CLASS_INSTANCE_TYPE, obj).cast<int32_t>());
  565. return class_type;
  566. }
  567. // Check the object is Cell Instance.
  568. bool IsCellInstance(const py::object &obj) {
  569. auto class_type = GetClassInstanceType(obj);
  570. bool isCell = (class_type == CLASS_INSTANCE_TYPE_CELL);
  571. return isCell;
  572. }
  573. // Create the python class instance.
  574. py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs) {
  575. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  576. // `args_kwargs` maybe a tuple(*args), tuple(**kwargs), or tuple(*args, **kwargs).
  577. return args_kwargs.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_INSTANCE, type)
  578. : python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_INSTANCE, type, args_kwargs);
  579. }
  580. // Call the python script string.
  581. py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs) {
  582. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  583. // `args_kwargs` is a tuple(dict(global), dict(local)).
  584. return args_kwargs.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_EVAL_PY_SCRIPT, script)
  585. : python_adapter::CallPyModFn(mod, PYTHON_MOD_EVAL_PY_SCRIPT, script, args_kwargs);
  586. }
  587. // Generate an appropriate name and set to graph debuginfo,
  588. // character <> can not used in the dot file, so change to another symbol.
  589. void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name) {
  590. MS_EXCEPTION_IF_NULL(func_graph);
  591. MS_EXCEPTION_IF_NULL(func_graph->debug_info());
  592. // Set detail name info of function
  593. std::ostringstream oss;
  594. for (size_t i = 0; i < name.size(); i++) {
  595. if (name[i] == '<') {
  596. oss << "「";
  597. } else if (name[i] == '>') {
  598. oss << "」";
  599. } else {
  600. oss << name[i];
  601. }
  602. }
  603. func_graph->debug_info()->set_full_name(oss.str());
  604. }
  605. ValuePtr PyDataToValue(const py::object &obj) {
  606. py::object to_convert = obj;
  607. ValuePtr value = nullptr;
  608. (void)ConvertData(to_convert, &value);
  609. return value;
  610. }
  611. void ClearObjectCache() {
  612. object_map_.clear();
  613. object_graphs_map_.clear();
  614. }
  615. } // namespace data_converter
  616. static mindspore::HashMap<std::string, ClassPtr> g_dataClassToClass = {};
  617. // Parse dataclass to mindspore Class type
  618. ClassPtr ParseDataClass(const py::object &cls_obj) {
  619. std::string cls_name = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__name__"));
  620. std::string cls_module = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__module__"));
  621. std::string cls = cls_module + "." + cls_name;
  622. auto iterator = g_dataClassToClass.find(cls);
  623. if (iterator != g_dataClassToClass.end()) {
  624. return iterator->second;
  625. }
  626. py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  627. ClassAttrVector attributes;
  628. py::dict names = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_ATTRS, cls_obj);
  629. for (auto &item : names) {
  630. auto type_value = item.second.cast<TypePtr>();
  631. MS_EXCEPTION_IF_NULL(type_value);
  632. MS_LOG(DEBUG) << "(Name: " << py::cast<std::string>(item.first) << ", type: " << type_value->ToString() << ")";
  633. attributes.push_back(std::make_pair(py::cast<std::string>(item.first), type_value));
  634. }
  635. mindspore::HashMap<std::string, ValuePtr> methods_map;
  636. py::dict methods = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_METHODS, cls_obj);
  637. for (auto &item : methods) {
  638. auto fun_name = item.first.cast<std::string>();
  639. auto obj = py::cast<py::object>(item.second);
  640. std::shared_ptr<PyObjectWrapper> method_obj = std::make_shared<PyObjectWrapper>(obj, fun_name);
  641. methods_map[fun_name] = method_obj;
  642. }
  643. std::shared_ptr<Class> me_class = std::make_shared<Class>(Named(cls_name), attributes, methods_map);
  644. // static Variable for cache
  645. // cppcheck-suppress unreadVariable
  646. g_dataClassToClass[cls] = me_class;
  647. return me_class;
  648. }
  649. void CleanDataClassToClassMap() { g_dataClassToClass.clear(); }
  650. } // namespace parse
  651. } // namespace mindspore