You cannot select more than 25 topics. A topic must start with a Chinese character, a letter, or a number; it may include dashes ('-') and can be up to 35 characters long.

convert_utils_py.cc 26 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701
  1. /**
  2. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/convert_utils_py.h"
  17. #include <vector>
  18. #include <string>
  19. #include <memory>
  20. #include <algorithm>
  21. #include <list>
  22. #include <utility>
  23. #include <cfloat>
  24. #include "abstract/abstract_value.h"
  25. #include "abstract/utils.h"
  26. #include "pipeline/jit/parse/parse.h"
  27. #include "pipeline/jit/parse/parse_base.h"
  28. #include "pipeline/jit/parse/resolve.h"
  29. #include "ir/value.h"
  30. #include "ir/anf.h"
  31. #include "ir/tensor.h"
  32. #include "ir/param_info.h"
  33. #include "pybind_api/ir/base_ref_py.h"
  34. #include "ir/dtype/tensor_type.h"
  35. #include "utils/ms_context.h"
  36. #include "utils/convert_utils.h"
  37. #include "backend/session/anf_runtime_algorithm.h"
  38. namespace mindspore {
// Forward declarations for the conversion helpers defined later in this file.
py::object BuiltinsToPyData(const Any &value);
py::object BuiltinsToPyData(const BaseRef &value);
py::object VectorToPyData(const Any &value);
py::object VectorRefToPyData(const VectorRef &value_list);
py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr &output);
// Wrap VectorRef to CSRTensor
py::object MakeCSRTensor(const VectorRef &value_list);
// Wrap VectorRef to COOTensor
py::object MakeCOOTensor(const VectorRef &value_list);
ShapeVector ConvertToShapeVector(const ValuePtr &shape_ptr, const VectorRef &value_list, size_t shape_idx);
  48. py::object CSRTensorToPyData(const tensor::CSRTensorPtr &csr_tensor) {
  49. auto ref = py::tuple(1);
  50. ref[0] = csr_tensor;
  51. return ref[0];
  52. }
  53. py::object TensorToPyData(const tensor::TensorPtr &tensor) {
  54. MS_EXCEPTION_IF_NULL(tensor);
  55. if (tensor->NeedWait()) {
  56. py::gil_scoped_release release;
  57. tensor->Wait();
  58. }
  59. py::tuple v(1);
  60. v[0] = tensor;
  61. return v[0];
  62. }
  63. py::object ScalarPtrToPyData(const ScalarPtr &value) {
  64. py::int_ int_v;
  65. py::float_ float_v;
  66. py::bool_ bool_v;
  67. TypeId scalar_type = value->type()->type_id();
  68. switch (scalar_type) {
  69. case kNumberTypeUInt8:
  70. MS_LOG(DEBUG) << "uint8";
  71. int_v = value->cast<UInt8ImmPtr>()->value();
  72. return std::move(int_v);
  73. case kNumberTypeUInt16:
  74. MS_LOG(DEBUG) << "uint16";
  75. int_v = value->cast<UInt16ImmPtr>()->value();
  76. return std::move(int_v);
  77. case kNumberTypeUInt32:
  78. MS_LOG(DEBUG) << "uint32";
  79. int_v = value->cast<UInt32ImmPtr>()->value();
  80. return std::move(int_v);
  81. case kNumberTypeUInt64:
  82. MS_LOG(DEBUG) << "uint64";
  83. int_v = value->cast<UInt64ImmPtr>()->value();
  84. return std::move(int_v);
  85. case kNumberTypeInt8:
  86. MS_LOG(DEBUG) << "int8";
  87. int_v = value->cast<Int8ImmPtr>()->value();
  88. return std::move(int_v);
  89. case kNumberTypeInt16:
  90. MS_LOG(DEBUG) << "int16";
  91. int_v = value->cast<Int16ImmPtr>()->value();
  92. return std::move(int_v);
  93. case kNumberTypeInt32:
  94. MS_LOG(DEBUG) << "int32";
  95. int_v = value->cast<Int32ImmPtr>()->value();
  96. return std::move(int_v);
  97. case kNumberTypeInt64:
  98. MS_LOG(DEBUG) << "int64";
  99. int_v = value->cast<Int64ImmPtr>()->value();
  100. return std::move(int_v);
  101. case kNumberTypeFloat32:
  102. MS_LOG(DEBUG) << "float";
  103. float_v = value->cast<FP32ImmPtr>()->value();
  104. return std::move(float_v);
  105. case kNumberTypeFloat64:
  106. MS_LOG(DEBUG) << "double";
  107. float_v = value->cast<FP64ImmPtr>()->value();
  108. return std::move(float_v);
  109. case kNumberTypeBool:
  110. MS_LOG(DEBUG) << "bool";
  111. bool_v = value->cast<BoolImmPtr>()->value();
  112. return std::move(bool_v);
  113. default:
  114. MS_EXCEPTION(TypeError) << "Unsupported scalar converted to py data: " << value->ToString();
  115. }
  116. }
// Signature of a converter that turns a Value into a Python object.
using ConverterFunction = std::function<py::object(const ValuePtr &value)>;
// Ordered dispatch table: (value type id) -> (converter function).
using ValueNameToConverterVector = std::vector<std::pair<uint32_t, ConverterFunction>>;
// (Value Type Name) -> (Converter Function)
// The converter function is used to convert Value object to Python data object.
// ValueToPyData() scans this vector front-to-back and applies the first entry
// whose type id matches via IsFromTypeId, so entry order matters.
static ValueNameToConverterVector value_name_to_converter = {
  // Scalar
  {Scalar::kTypeId, [](const ValuePtr &value) -> py::object { return ScalarPtrToPyData(value->cast<ScalarPtr>()); }},
  // Tensor
  {tensor::Tensor::kTypeId,
   [](const ValuePtr &value) -> py::object {
     auto tensor_ptr = value->cast<tensor::TensorPtr>();
     return TensorToPyData(tensor_ptr);
   }},
  // MetaTensor
  {tensor::MetaTensor::kTypeId,
   [](const ValuePtr &value) -> py::object {
     // A 1-element tuple lets pybind11 perform the C++ -> Python conversion.
     py::tuple tuple_container(1);
     tuple_container[0] = value->cast<tensor::MetaTensorPtr>();
     return tuple_container[0];
   }},
  // CSRTensor
  {tensor::CSRTensor::kTypeId,
   [](const ValuePtr &value) -> py::object {
     auto csr_tensor_ptr = value->cast<tensor::CSRTensorPtr>();
     return CSRTensorToPyData(csr_tensor_ptr);
   }},
  // RefKey
  {RefKey::kTypeId,
   [](const ValuePtr &value) -> py::object {
     py::tuple tuple_container(1);
     tuple_container[0] = value->cast<RefKeyPtr>();
     return tuple_container[0];
   }},
  // Type
  {Type::kTypeId,
   [](const ValuePtr &value) -> py::object {
     py::tuple tuple_container(1);
     tuple_container[0] = value->cast<TypePtr>();
     return tuple_container[0];
   }},
  // StringImm
  {StringImm::kTypeId,
   [](const ValuePtr &value) -> py::object {
     py::str res = value->cast<StringImmPtr>()->value();
     return res;
   }},
  // ValueSequence
  {ValueSequence::kTypeId,
   [](const ValuePtr &value) -> py::object {
     // Convert each element recursively; tuples stay tuples, every other
     // sequence (e.g. ValueList) is returned as a Python list.
     auto value_sequeue = value->cast<ValueSequencePtr>()->value();
     py::tuple res_sequeue(value_sequeue.size());
     for (size_t i = 0; i < value_sequeue.size(); i++) {
       res_sequeue[i] = ValueToPyData(value_sequeue[i]);
     }
     if (value->isa<ValueTuple>()) {
       return res_sequeue;
     }
     return res_sequeue.cast<py::list>();
   }},
  // ValueDictionary
  {ValueDictionary::kTypeId,
   [](const ValuePtr &value) -> py::object {
     auto value_list = value->cast<ValueDictionaryPtr>()->value();
     py::dict res_dict;
     for (const auto &value : value_list) {
       res_dict[py::str(value.first)] = ValueToPyData(value.second);
     }
     return res_dict;
   }},
  // ValueSlice
  {ValueSlice::kTypeId,
   [](const ValuePtr &value) -> py::object {
     // Rebuild a Python slice object through the parser helper module.
     auto slice = value->cast<ValueSlicePtr>();
     auto start = ValueToPyData(slice->start());
     auto end = ValueToPyData(slice->stop());
     auto step = ValueToPyData(slice->step());
     return parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end,
                                            step);
   }},
  // KeywordArg
  {KeywordArg::kTypeId,
   [](const ValuePtr &value) -> py::object {
     // Surface a keyword argument as a Python kwargs dict with a single entry.
     auto abs_keyword_arg = value->ToAbstract()->cast<abstract::AbstractKeywordArgPtr>();
     auto key = abs_keyword_arg->get_key();
     auto val = abs_keyword_arg->get_arg()->BuildValue();
     auto py_value = ValueToPyData(val);
     auto kwargs = py::kwargs();
     kwargs[key.c_str()] = py_value;
     return kwargs;
   }},
  // parse::NameSpace
  {parse::NameSpace::kTypeId,
   [](const ValuePtr &value) -> py::object {
     auto ns = value->cast<parse::NameSpacePtr>();
     return ns->module_obj();
   }},
  // parse::ClassType
  {parse::ClassType::kTypeId,
   [](const ValuePtr &value) -> py::object {
     auto class_type = value->cast<parse::ClassTypePtr>();
     return class_type->obj();
   }},
  // parse::InterpretedObject
  {parse::InterpretedObject::kTypeId,
   [](const ValuePtr &value) -> py::object {
     auto interpreted_object = value->cast<parse::InterpretedObjectPtr>();
     return interpreted_object->obj();
   }},
  // The following value kinds have no Python data representation and map to None.
  // None
  {None::kTypeId, [](const ValuePtr &value) -> py::object { return py::none(); }},
  // AnyValue
  {AnyValue::kTypeId, [](const ValuePtr &value) -> py::object { return py::none(); }},
  // FuncGraph
  {FuncGraph::kTypeId, [](const ValuePtr &value) -> py::object { return py::none(); }},
  // Monad
  {Monad::kTypeId, [](const ValuePtr &value) -> py::object { return py::none(); }},
  // Ellipsis
  {Ellipsis::kTypeId, [](const ValuePtr &value) -> py::object { return py::ellipsis(); }}};
  235. py::object ValueToPyData(const ValuePtr &value) {
  236. if (value == nullptr) {
  237. MS_LOG(EXCEPTION) << "The `value` should not be null";
  238. }
  239. for (auto &iter : value_name_to_converter) {
  240. if (value->IsFromTypeId(iter.first)) {
  241. return iter.second(value);
  242. }
  243. }
  244. MS_LOG(EXCEPTION) << "Unsupported to convert " << value->ToString() << "[" << value->type_name() << "] to a PyData";
  245. }
  246. py::object AnyToPyData(const Any &value) {
  247. py::object ret;
  248. MS_LOG(DEBUG) << "AnyToPyData " << value.GetString();
  249. if (value.is<int>() || value.is<float>() || value.is<double>() || value.is<bool>()) {
  250. ret = BuiltinsToPyData(value);
  251. } else if (value.is<ValuePtr>()) {
  252. MS_LOG(DEBUG) << "ValuePtr";
  253. ValuePtr v = value.cast<ValuePtr>();
  254. ret = ValueToPyData(v);
  255. } else if (value.is<tensor::TensorPtr>()) {
  256. MS_LOG(DEBUG) << "tensor";
  257. auto tensor_ptr = value.cast<tensor::TensorPtr>();
  258. ret = TensorToPyData(tensor_ptr);
  259. } else if (value.is<py::object>()) {
  260. MS_LOG(DEBUG) << "py obj";
  261. ret = value.cast<py::object>();
  262. } else if (value.is<std::vector<tensor::TensorPtr>>() || value.is<std::vector<Any>>()) {
  263. ret = VectorToPyData(value);
  264. } else if (value.is<std::list<Any>>()) {
  265. MS_LOG(DEBUG) << "list_any";
  266. auto value_list = value.cast<std::list<Any>>();
  267. py::list rets = py::list();
  268. for (auto &v : value_list) {
  269. rets.append(AnyToPyData(v));
  270. }
  271. ret = rets;
  272. } else if (value.is<std::vector<Any>>()) {
  273. auto value_list = value.cast<std::vector<Any>>();
  274. py::tuple rets(value_list.size());
  275. for (size_t i = 0; i < value_list.size(); i++) {
  276. rets[i] = AnyToPyData(value_list[i]);
  277. }
  278. ret = rets;
  279. } else if (value.is<TypePtr>()) {
  280. py::tuple v(1);
  281. v[0] = value.cast<TypePtr>();
  282. ret = v[0];
  283. } else {
  284. MS_LOG(EXCEPTION) << "value is not support type";
  285. }
  286. return ret;
  287. }
  288. py::object BaseRefToPyData(const BaseRef &value, const AbstractBasePtr &output) {
  289. py::object ret;
  290. // If output value is a tuple, check if abstract is a COOTensor in funcgraph output
  291. if (utils::isa<VectorRef>(value)) {
  292. MS_LOG(DEBUG) << "BaseRefToPyData, value is tuple: " << value.ToString();
  293. auto vec_ref = utils::cast<VectorRef>(value);
  294. ret = VectorRefToPyData(vec_ref, output);
  295. } else {
  296. ret = BaseRefToPyData(value);
  297. }
  298. return ret;
  299. }
  300. py::object BaseRefToPyData(const BaseRef &value) {
  301. py::object ret;
  302. MS_LOG(DEBUG) << "BaseRefToPyData " << value.ToString();
  303. if (utils::isa<int>(value) || utils::isa<float>(value) || utils::isa<double>(value) || utils::isa<bool>(value)) {
  304. ret = BuiltinsToPyData(value);
  305. } else if (utils::isa<ValuePtr>(value)) {
  306. MS_LOG(DEBUG) << "ValuePtr";
  307. ValuePtr v = utils::cast<ValuePtr>(value);
  308. ret = ValueToPyData(v);
  309. } else if (utils::isa<tensor::TensorPtr>(value)) {
  310. MS_LOG(DEBUG) << "tensor";
  311. auto tensor_ptr = utils::cast<tensor::TensorPtr>(value);
  312. ret = TensorToPyData(tensor_ptr);
  313. } else if (utils::isa<PyObjectRef>(value)) {
  314. MS_LOG(DEBUG) << "py obj";
  315. PyObjectRef py_ref = utils::cast<PyObjectRef>(value);
  316. ret = py_ref.object_;
  317. } else if (utils::isa<VectorRef>(value)) {
  318. auto vec_ref = utils::cast<VectorRef>(value);
  319. ret = VectorRefToPyData(vec_ref);
  320. } else if (utils::isa<TypePtr>(value)) {
  321. py::tuple v(1);
  322. v[0] = utils::cast<TypePtr>(value);
  323. ret = v[0];
  324. } else {
  325. MS_LOG(EXCEPTION) << "value is not support type";
  326. }
  327. return ret;
  328. }
  329. py::object BuiltinsToPyData(const Any &value) {
  330. if (value.is<int>()) {
  331. MS_LOG(DEBUG) << "int";
  332. py::int_ ret = value.cast<int>();
  333. return std::move(ret);
  334. } else if (value.is<float>()) {
  335. MS_LOG(DEBUG) << "float";
  336. py::float_ ret = value.cast<float>();
  337. return std::move(ret);
  338. } else if (value.is<double>()) {
  339. MS_LOG(DEBUG) << "double";
  340. py::float_ ret = value.cast<double>();
  341. return std::move(ret);
  342. } else {
  343. MS_LOG(DEBUG) << "bool";
  344. py::bool_ ret = value.cast<bool>();
  345. return std::move(ret);
  346. }
  347. }
  348. py::object BuiltinsToPyData(const BaseRef &value) {
  349. if (utils::isa<int>(value)) {
  350. MS_LOG(DEBUG) << "int";
  351. py::int_ ret = utils::cast<int>(value);
  352. return std::move(ret);
  353. } else if (utils::isa<float>(value)) {
  354. MS_LOG(DEBUG) << "float";
  355. py::float_ ret = utils::cast<float>(value);
  356. return std::move(ret);
  357. } else if (utils::isa<double>(value)) {
  358. MS_LOG(DEBUG) << "double";
  359. py::float_ ret = utils::cast<double>(value);
  360. return std::move(ret);
  361. } else {
  362. MS_LOG(DEBUG) << "bool";
  363. py::bool_ ret = utils::cast<bool>(value);
  364. return std::move(ret);
  365. }
  366. }
  367. py::object VectorToPyData(const Any &value) {
  368. py::object ret;
  369. if (value.is<std::vector<tensor::TensorPtr>>()) {
  370. MS_LOG(DEBUG) << "vector_tensor";
  371. std::vector<tensor::TensorPtr> outputs;
  372. outputs = value.cast<std::vector<tensor::TensorPtr>>();
  373. py::tuple tensor_tuple(outputs.size());
  374. for (std::size_t i = 0; i < outputs.size(); ++i) {
  375. tensor_tuple[i] = *outputs[i];
  376. }
  377. ret = tensor_tuple;
  378. } else {
  379. MS_LOG(DEBUG) << "vector_any";
  380. auto value_list = value.cast<std::vector<Any>>();
  381. py::tuple any_tuple = py::tuple(value_list.size());
  382. size_t i = 0;
  383. for (auto &v : value_list) {
  384. any_tuple[i] = AnyToPyData(v);
  385. i++;
  386. }
  387. ret = any_tuple;
  388. }
  389. return ret;
  390. }
  391. py::object VectorRefToPyData(const VectorRef &value_list) {
  392. py::object ret;
  393. MS_LOG(DEBUG) << "vector_ref";
  394. size_t value_size = value_list.size();
  395. auto ref_tuple = py::tuple(value_size);
  396. for (size_t i = 0; i < value_size; i++) {
  397. ref_tuple[i] = BaseRefToPyData(value_list[i]);
  398. }
  399. ret = ref_tuple;
  400. return ret;
  401. }
  402. py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr &output) {
  403. MS_LOG(DEBUG) << "vector_ref";
  404. // Current VectorRef reflects a COOTensor type
  405. if (output->isa<abstract::AbstractCSRTensor>()) {
  406. return MakeCSRTensor(value_list);
  407. }
  408. if (output->isa<abstract::AbstractCOOTensor>()) {
  409. return MakeCOOTensor(value_list);
  410. }
  411. py::object ret;
  412. size_t value_size = value_list.size();
  413. auto ref_tuple = py::tuple(value_size);
  414. abstract::AbstractTuplePtr tuple_output = output->cast<abstract::AbstractTuplePtr>();
  415. bool is_abstract_tuple = tuple_output != nullptr;
  416. for (size_t i = 0; i < value_size; i++) {
  417. if (!is_abstract_tuple || i >= tuple_output->size()) {
  418. // Fall back to original process
  419. ref_tuple[i] = BaseRefToPyData(value_list[i]);
  420. } else {
  421. ref_tuple[i] = BaseRefToPyData(value_list[i], (*tuple_output)[i]);
  422. }
  423. }
  424. ret = ref_tuple;
  425. return ret;
  426. }
  427. void SetValueRange(const AbstractBasePtr &tensor, const py::object &output) {
  428. if (output.is_none()) {
  429. return;
  430. }
  431. py::object obj_min =
  432. output.contains(py::str(ATTR_MIN_VALUE)) ? (py::object)output[ATTR_MIN_VALUE] : (py::object)py::none();
  433. py::object obj_max =
  434. output.contains(py::str(ATTR_MAX_VALUE)) ? (py::object)output[ATTR_MAX_VALUE] : (py::object)py::none();
  435. if (!obj_min.is_none() && !obj_max.is_none()) {
  436. bool converted = true;
  437. ValuePtr min_value = nullptr;
  438. ValuePtr max_value = nullptr;
  439. converted = parse::ConvertData(obj_min, &min_value);
  440. if (!converted) {
  441. MS_LOG(EXCEPTION) << "Convert shape min value data failed";
  442. }
  443. converted = parse::ConvertData(obj_max, &max_value);
  444. if (!converted) {
  445. MS_LOG(EXCEPTION) << "Convert shape max value data failed";
  446. }
  447. auto abs_tensor = dyn_cast<abstract::AbstractTensor>(tensor);
  448. abs_tensor->set_value_range(min_value, max_value);
  449. }
  450. }
// Build an AbstractTensor from a Python-side inferred shape and dtype,
// picking up optional min/max shapes (dynamic shape bounds) and min/max
// values from the infer-result dict `output`.
AbstractBasePtr MakePyInferRes2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
                                              const py::object &output) {
  auto ret_vec = shape_obj.cast<ShapeVector>();
  auto ret_dtype = type_obj.cast<TypePtr>();
  ShapeVector min_shape_vec;
  ShapeVector max_shape_vec;
  if (!output.is_none()) {
    // Min/max shapes are optional; absent keys leave the vectors empty.
    py::object min_shape =
      output.contains(py::str(ATTR_MIN_SHAPE)) ? (py::object)output[ATTR_MIN_SHAPE] : (py::object)py::none();
    py::object max_shape =
      output.contains(py::str(ATTR_MAX_SHAPE)) ? (py::object)output[ATTR_MAX_SHAPE] : (py::object)py::none();
    if (!min_shape.is_none()) {
      min_shape_vec = min_shape.cast<ShapeVector>();
    }
    if (!max_shape.is_none()) {
      max_shape_vec = max_shape.cast<ShapeVector>();
    }
  }
  auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
  AbstractBasePtr tensor = MakeAbstractTensor(ret_shape, ret_dtype);
  // Attach optional min/max value range from `output` (no-op when absent).
  SetValueRange(tensor, output);
  return tensor;
}
  474. static bool IsMonadType(const py::object &type_obj) {
  475. if (py::isinstance<Type>(type_obj)) {
  476. auto type = type_obj.cast<Type *>();
  477. return type->isa<MonadType>();
  478. }
  479. return false;
  480. }
  481. static AbstractBasePtr ToMonadAbstract(const py::object &type_obj) {
  482. if (py::isinstance<Type>(type_obj)) {
  483. auto type = type_obj.cast<Type *>();
  484. if (!type->isa<MonadType>()) {
  485. MS_LOG(EXCEPTION) << "Not a monad type object: " << py::str(type_obj);
  486. }
  487. return abstract::MakeMonadAbstract(type->cast<MonadTypePtr>());
  488. }
  489. MS_LOG(EXCEPTION) << "Not a type object: " << py::str(type_obj);
  490. }
// Slice element `index` out of a tuple-shaped infer-result dict: builds a new
// per-element dict carrying that element's dtype/shape (and min/max shape
// when present) so it can be fed back into MakePyInferRes2Abstract.
static py::object GetPyAbsItemOfTupleOut(const py::object &output, const size_t index) {
  auto out_dict = output.cast<py::dict>();
  auto type_obj = out_dict[ATTR_DTYPE];
  auto shape_obj = out_dict[ATTR_SHAPE];
  auto out_item = py::dict();
  // dtype/shape entries are parallel tuples; take the index-th of each.
  auto shape_tuple = shape_obj.cast<py::tuple>();
  auto typeid_tuple = type_obj.cast<py::tuple>();
  out_item[ATTR_DTYPE] = typeid_tuple[index];
  out_item[ATTR_SHAPE] = shape_tuple[index];
  // Min/max shapes are optional and, when present, also indexed per element.
  if (output.contains(py::str(ATTR_MIN_SHAPE))) {
    out_item[ATTR_MIN_SHAPE] = output[ATTR_MIN_SHAPE].cast<py::tuple>()[index];
  }
  if (output.contains(py::str(ATTR_MAX_SHAPE))) {
    out_item[ATTR_MAX_SHAPE] = output[ATTR_MAX_SHAPE].cast<py::tuple>()[index];
  }
  out_item[ATTR_VALUE] = py::none();
  return out_item;
}
// Translate a Python infer-result dict ({dtype, shape, ...}) into an abstract:
// - list/tuple shape + Type dtype  -> scalar (empty shape, non-tensor dtype)
//                                     or AbstractTensor
// - tuple shape + tuple dtype      -> AbstractTuple (element-wise recursion)
// - list shape + list dtype        -> AbstractList (element-wise recursion)
// - None shape + None dtype        -> AbstractNone (node has no output)
// - monad dtype                    -> monad abstract
// - anything else                  -> AbstractUndetermined when sparse is
//                                     enabled, otherwise an exception.
AbstractBasePtr MakePyInferRes2Abstract(const py::object &output) {
  auto out_dict = output.cast<py::dict>();
  auto type_obj = out_dict[ATTR_DTYPE];
  auto shape_obj = out_dict[ATTR_SHAPE];
  if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) && py::isinstance<Type>(type_obj)) {
    auto ret_vec = shape_obj.cast<ShapeVector>();
    auto ret_dtype = type_obj.cast<TypePtr>();
    MS_EXCEPTION_IF_NULL(ret_dtype);
    // if the size of shape list is empty, return an scalar abstract
    if (ret_vec.empty() && (!ret_dtype->isa<TensorType>())) {
      abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
      return abs_scalar;
    }
    return MakePyInferRes2AbstractTensor(shape_obj, type_obj, output);
  } else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
    // Multiple outputs: recurse on each (dtype, shape) pair and collect a tuple.
    auto typeid_tuple = type_obj.cast<py::tuple>();
    AbstractBasePtrList ptr_list;
    for (size_t it = 0; it < typeid_tuple.size(); ++it) {
      auto output_it = GetPyAbsItemOfTupleOut(output, it);
      auto tensor_it = MakePyInferRes2Abstract(output_it);
      ptr_list.push_back(tensor_it);
    }
    auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
    return tuple;
  } else if (py::isinstance<py::list>(shape_obj) && py::isinstance<py::list>(type_obj)) {
    // Same as above but list-shaped results yield an AbstractList.
    auto typeid_list = type_obj.cast<py::list>();
    AbstractBasePtrList ptr_list;
    for (size_t it = 0; it < typeid_list.size(); ++it) {
      auto output_it = GetPyAbsItemOfTupleOut(output, it);
      auto tensor_it = MakePyInferRes2Abstract(output_it);
      ptr_list.push_back(tensor_it);
    }
    auto list = std::make_shared<abstract::AbstractList>(ptr_list);
    return list;
  } else if (shape_obj.is_none() && type_obj.is_none()) {
    // AbstractNone indicates there is no output for this CNode node.
    auto abstract_none = std::make_shared<abstract::AbstractNone>();
    return abstract_none;
  } else if (IsMonadType(type_obj)) {
    // Return monad abstract if it is monad type.
    return ToMonadAbstract(type_obj);
  } else {
    // When sparse enabled, the undetermined might be raised and eliminated in opt passes
    auto context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context);
    bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
    if (enable_sparse) {
      return std::make_shared<abstract::AbstractUndetermined>();
    }
    MS_LOG(EXCEPTION) << "Python evaluator return invalid shape or type. " << (std::string)py::str(type_obj);
  }
}
// If the graph's output is a constant or a Parameter, the graph need not be
// executed: resolve the result directly into *ret_val and return true.
// Returns false when the output is a real computation (caller must execute).
//
// For a Parameter output, the result is either the corresponding entry of
// `args` (when the parameter is a real input) or the parameter's default
// value (when it was hoisted from __init__/construct).
bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
                                       const std::shared_ptr<py::object> &ret_val) {
  if (output->isa<ValueNode>()) {
    MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
    ValuePtr value = GetValueNode(output);
    *ret_val = ValueToPyData(value);
    return true;
  }
  // Adapter will transform values in __init__() and construct() to parameters, this could cause
  // inputs (a.k.a args in current function) size less than parameters'.
  if (output->isa<Parameter>()) {
    MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute.";
    // Find the right parameter as ret_val.
    auto func_graph = output->func_graph();
    MS_EXCEPTION_IF_NULL(func_graph);
    auto params = func_graph->parameters();
    // Sanity check: user args plus hoisted (hyper) params must cover all graph params.
    if ((args.size() + func_graph->hyper_param_count()) != params.size()) {
      MS_LOG(EXCEPTION) << "Input size " << args.size() << " add Parameter count " << func_graph->hyper_param_count()
                        << " not equal to graph input size " << params.size() << ", let graph to be executed.";
    }
    auto it = std::find(params.begin(), params.end(), output);
    if (it == params.end()) {
      MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters";
    }
    size_t index = it - params.cbegin();
    if (index >= args.size() + func_graph->hyper_param_count()) {
      MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size()
                                 << " add Parameter count " << func_graph->hyper_param_count() << ".";
    }
    if (index < args.size()) {
      // The parameter maps to a user-supplied argument; return it as-is.
      *ret_val = args[index];
    } else {
      // The parameter was hoisted; return its default (weight) value.
      auto param = dyn_cast<Parameter>(params[index]);
      MS_EXCEPTION_IF_NULL(param);
      if (!param->has_default()) {
        MS_LOG(EXCEPTION) << "Can not determine value of Parameter " << index << " (" << param->name() << ")";
      }
      auto tensor = param->default_param();
      *ret_val = py::cast(tensor);
    }
    return true;
  }
  return false;
}
// Extract a ShapeVector for a sparse tensor output. Two encodings exist:
// - `shape_ptr` is a ValueTuple of scalars -> read the int64 values directly;
// - otherwise the shape lives in value_list[shape_idx] as a VectorRef of
//   0-dim int64 Tensors -> sync each tensor and read its single element.
ShapeVector ConvertToShapeVector(const ValuePtr &shape_ptr, const VectorRef &value_list, size_t shape_idx) {
  MS_EXCEPTION_IF_NULL(shape_ptr);
  ShapeVector shape;
  ValueTuplePtr shape_tuple = shape_ptr->cast<ValueTuplePtr>();
  if (shape_tuple) {
    for (const auto &v : shape_tuple->value()) {
      MS_EXCEPTION_IF_NULL(v);
      ScalarPtr scalar = v->cast<ScalarPtr>();
      MS_EXCEPTION_IF_NULL(scalar);
      shape.push_back(GetValue<int64_t>(scalar));
    }
  } else {
    auto shape_ref = utils::cast<VectorRef>(value_list[shape_idx]);
    // NOTE(review): MS_EXCEPTION_IF_NULL on a VectorRef value (not a pointer)
    // looks suspicious — confirm VectorRef is nullptr-comparable here.
    MS_EXCEPTION_IF_NULL(shape_ref);
    for (const auto &v : shape_ref) {
      MS_EXCEPTION_IF_NULL(v);
      auto tensorptr = utils::cast<tensor::TensorPtr>(v);
      MS_EXCEPTION_IF_NULL(tensorptr);
      if (tensorptr->DataDim() != 0) {
        MS_LOG(EXCEPTION) << "Element in COOTensor's shape must be scalar!";
      }
      // Ensure device data is synced to host before reading the raw buffer.
      tensorptr->data_sync(false);
      shape.push_back(*(static_cast<int64_t *>(tensorptr->data_c())));
    }
  }
  return shape;
}
  632. py::object MakeCSRTensor(const VectorRef &value_list) {
  633. constexpr size_t kCSRTensorInputSize{4};
  634. if (value_list.size() != kCSRTensorInputSize) {
  635. MS_LOG(EXCEPTION) << "CSRTensor must have 4 inputs.";
  636. }
  637. using TensorPtr = tensor::TensorPtr;
  638. using CSRTensor = tensor::CSRTensor;
  639. constexpr size_t kIndptrIdx{0};
  640. constexpr size_t kIndicesIdx{1};
  641. constexpr size_t kValuesIdx{2};
  642. constexpr size_t kShapeIdx{3};
  643. TensorPtr indptr = utils::cast<TensorPtr>(value_list[kIndptrIdx]);
  644. TensorPtr indices = utils::cast<TensorPtr>(value_list[kIndicesIdx]);
  645. TensorPtr values = utils::cast<TensorPtr>(value_list[kValuesIdx]);
  646. ValuePtr shape_ptr = utils::cast<ValuePtr>(value_list[kShapeIdx]);
  647. ShapeVector shape = ConvertToShapeVector(shape_ptr, value_list, kShapeIdx);
  648. auto csr_tensor_ptr = std::make_shared<CSRTensor>(indptr, indices, values, shape);
  649. return CSRTensorToPyData(csr_tensor_ptr);
  650. }
  651. py::object MakeCOOTensor(const VectorRef &value_list) {
  652. constexpr size_t kCOOTensorInputSize{3};
  653. constexpr size_t kIndicesIdx{0};
  654. constexpr size_t kValuesIdx{1};
  655. constexpr size_t kShapeIdx{2};
  656. if (value_list.size() != kCOOTensorInputSize) {
  657. MS_LOG(EXCEPTION) << "COOTensor must have " << kCOOTensorInputSize << "inputs.";
  658. }
  659. tensor::TensorPtr indices = utils::cast<tensor::TensorPtr>(value_list[kIndicesIdx]);
  660. tensor::TensorPtr values = utils::cast<tensor::TensorPtr>(value_list[kValuesIdx]);
  661. ValuePtr shape_ptr = utils::cast<ValuePtr>(value_list[kShapeIdx]);
  662. ShapeVector shape = ConvertToShapeVector(shape_ptr, value_list, kShapeIdx);
  663. auto ref = py::tuple(1);
  664. auto coo_tensor_ptr = std::make_shared<tensor::COOTensor>(indices, values, shape);
  665. ref[0] = coo_tensor_ptr;
  666. return ref[0];
  667. }
  668. } // namespace mindspore