You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

convert_utils_py.cc 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637
  1. /**
  2. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/convert_utils_py.h"
  17. #include <vector>
  18. #include <string>
  19. #include <memory>
  20. #include <algorithm>
  21. #include <list>
  22. #include <utility>
  23. #include <cfloat>
  24. #include "abstract/abstract_value.h"
  25. #include "abstract/utils.h"
  26. #include "pipeline/jit/parse/parse.h"
  27. #include "pipeline/jit/parse/parse_base.h"
  28. #include "pipeline/jit/parse/resolve.h"
  29. #include "ir/value.h"
  30. #include "ir/anf.h"
  31. #include "ir/tensor.h"
  32. #include "ir/param_info.h"
  33. #include "pybind_api/ir/base_ref_py.h"
  34. #include "ir/dtype/tensor_type.h"
  35. #include "utils/ms_context.h"
  36. #include "utils/convert_utils.h"
  37. #include "backend/session/anf_runtime_algorithm.h"
  38. namespace mindspore {
// Forward declarations for helpers defined later in this translation unit.
py::object BuiltinsToPyData(const Any &value);
py::object BuiltinsToPyData(const BaseRef &value);
py::object VectorToPyData(const Any &value);
py::object VectorRefToPyData(const VectorRef &value_list);
py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr &output);
// Wrap VectorRef to CSRTensor
py::object MakeCSRTensor(const VectorRef &value_list);
  46. py::object TensorToPyData(const tensor::TensorPtr &tensor) {
  47. MS_EXCEPTION_IF_NULL(tensor);
  48. if (tensor->NeedWait()) {
  49. py::gil_scoped_release release;
  50. tensor->Wait();
  51. }
  52. py::tuple v(1);
  53. v[0] = tensor;
  54. return v[0];
  55. }
  56. py::object ScalarPtrToPyData(const ScalarPtr &value) {
  57. py::int_ int_v;
  58. py::float_ float_v;
  59. py::bool_ bool_v;
  60. TypeId scalar_type = value->type()->type_id();
  61. switch (scalar_type) {
  62. case kNumberTypeUInt8:
  63. MS_LOG(DEBUG) << "uint8";
  64. int_v = value->cast<UInt8ImmPtr>()->value();
  65. return std::move(int_v);
  66. case kNumberTypeUInt16:
  67. MS_LOG(DEBUG) << "uint16";
  68. int_v = value->cast<UInt16ImmPtr>()->value();
  69. return std::move(int_v);
  70. case kNumberTypeUInt32:
  71. MS_LOG(DEBUG) << "uint32";
  72. int_v = value->cast<UInt32ImmPtr>()->value();
  73. return std::move(int_v);
  74. case kNumberTypeUInt64:
  75. MS_LOG(DEBUG) << "uint64";
  76. int_v = value->cast<UInt64ImmPtr>()->value();
  77. return std::move(int_v);
  78. case kNumberTypeInt8:
  79. MS_LOG(DEBUG) << "int8";
  80. int_v = value->cast<Int8ImmPtr>()->value();
  81. return std::move(int_v);
  82. case kNumberTypeInt16:
  83. MS_LOG(DEBUG) << "int16";
  84. int_v = value->cast<Int16ImmPtr>()->value();
  85. return std::move(int_v);
  86. case kNumberTypeInt32:
  87. MS_LOG(DEBUG) << "int32";
  88. int_v = value->cast<Int32ImmPtr>()->value();
  89. return std::move(int_v);
  90. case kNumberTypeInt64:
  91. MS_LOG(DEBUG) << "int64";
  92. int_v = value->cast<Int64ImmPtr>()->value();
  93. return std::move(int_v);
  94. case kNumberTypeFloat32:
  95. MS_LOG(DEBUG) << "float";
  96. float_v = value->cast<FP32ImmPtr>()->value();
  97. return std::move(float_v);
  98. case kNumberTypeFloat64:
  99. MS_LOG(DEBUG) << "double";
  100. float_v = value->cast<FP64ImmPtr>()->value();
  101. return std::move(float_v);
  102. case kNumberTypeBool:
  103. MS_LOG(DEBUG) << "bool";
  104. bool_v = value->cast<BoolImmPtr>()->value();
  105. return std::move(bool_v);
  106. default:
  107. MS_EXCEPTION(TypeError) << "Unsupported scalar converted to py data: " << value->ToString();
  108. }
  109. }
// Signature of a converter: takes a Value and produces the Python object for it.
using ConverterFunction = std::function<py::object(const ValuePtr &value)>;
// Ordered (type id, converter) pairs; lookup is a linear scan in ValueToPyData.
using ValueNameToConverterVector = std::vector<std::pair<uint32_t, ConverterFunction>>;
// (Value Type Name) -> (Converter Function)
// The converter function is used to convert Value object to Python data object.
static ValueNameToConverterVector value_name_to_converter = {
  // Scalar -> Python int/float/bool
  {Scalar::kTypeId, [](const ValuePtr &value) -> py::object { return ScalarPtrToPyData(value->cast<ScalarPtr>()); }},
  // Tensor
  {tensor::Tensor::kTypeId,
   [](const ValuePtr &value) -> py::object {
     auto tensor_ptr = value->cast<tensor::TensorPtr>();
     return TensorToPyData(tensor_ptr);
   }},
  // MetaTensor: routed through a one-element tuple so pybind11 casts the pointer.
  {tensor::MetaTensor::kTypeId,
   [](const ValuePtr &value) -> py::object {
     py::tuple tuple_container(1);
     tuple_container[0] = value->cast<tensor::MetaTensorPtr>();
     return tuple_container[0];
   }},
  // RefKey
  {RefKey::kTypeId,
   [](const ValuePtr &value) -> py::object {
     py::tuple tuple_container(1);
     tuple_container[0] = value->cast<RefKeyPtr>();
     return tuple_container[0];
   }},
  // Type
  {Type::kTypeId,
   [](const ValuePtr &value) -> py::object {
     py::tuple tuple_container(1);
     tuple_container[0] = value->cast<TypePtr>();
     return tuple_container[0];
   }},
  // StringImm -> Python str
  {StringImm::kTypeId,
   [](const ValuePtr &value) -> py::object {
     py::str res = value->cast<StringImmPtr>()->value();
     return res;
   }},
  // ValueSequeue (project spelling): ValueTuple -> py::tuple, other sequences -> py::list.
  {ValueSequeue::kTypeId,
   [](const ValuePtr &value) -> py::object {
     auto value_sequeue = value->cast<ValueSequeuePtr>()->value();
     py::tuple res_sequeue(value_sequeue.size());
     for (size_t i = 0; i < value_sequeue.size(); i++) {
       res_sequeue[i] = ValueToPyData(value_sequeue[i]);
     }
     if (value->isa<ValueTuple>()) {
       return res_sequeue;
     }
     return res_sequeue.cast<py::list>();
   }},
  // ValueDictionary -> Python dict; keys become py::str.
  {ValueDictionary::kTypeId,
   [](const ValuePtr &value) -> py::object {
     auto value_list = value->cast<ValueDictionaryPtr>()->value();
     py::dict res_dict;
     // NOTE(review): the loop variable shadows the lambda's `value` parameter.
     for (const auto &value : value_list) {
       res_dict[py::str(value.first)] = ValueToPyData(value.second);
     }
     return res_dict;
   }},
  // ValueSlice -> Python slice, built via the parser helper module.
  {ValueSlice::kTypeId,
   [](const ValuePtr &value) -> py::object {
     auto slice = value->cast<ValueSlicePtr>();
     auto start = ValueToPyData(slice->start());
     auto end = ValueToPyData(slice->stop());
     auto step = ValueToPyData(slice->step());
     return parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end,
                                            step);
   }},
  // KeywordArg -> Python kwargs holding a single key/value pair.
  {KeywordArg::kTypeId,
   [](const ValuePtr &value) -> py::object {
     auto abs_keyword_arg = value->ToAbstract()->cast<abstract::AbstractKeywordArgPtr>();
     auto key = abs_keyword_arg->get_key();
     auto val = abs_keyword_arg->get_arg()->BuildValue();
     auto py_value = ValueToPyData(val);
     auto kwargs = py::kwargs();
     kwargs[key.c_str()] = py_value;
     return kwargs;
   }},
  // parse::NameSpace -> the wrapped Python module object.
  {parse::NameSpace::kTypeId,
   [](const ValuePtr &value) -> py::object {
     auto ns = value->cast<parse::NameSpacePtr>();
     return ns->module_obj();
   }},
  // parse::ClassType -> the wrapped Python class object.
  {parse::ClassType::kTypeId,
   [](const ValuePtr &value) -> py::object {
     auto class_type = value->cast<parse::ClassTypePtr>();
     return class_type->obj();
   }},
  // parse::InterpretedObject -> the wrapped Python object.
  {parse::InterpretedObject::kTypeId,
   [](const ValuePtr &value) -> py::object {
     auto interpreted_object = value->cast<parse::InterpretedObjectPtr>();
     return interpreted_object->obj();
   }},
  // None
  {None::kTypeId, [](const ValuePtr &value) -> py::object { return py::none(); }},
  // AnyValue: no concrete value available -> Python None.
  {AnyValue::kTypeId, [](const ValuePtr &value) -> py::object { return py::none(); }},
  // FuncGraph: not representable in Python -> None.
  {FuncGraph::kTypeId, [](const ValuePtr &value) -> py::object { return py::none(); }},
  // Monad: not representable in Python -> None.
  {Monad::kTypeId, [](const ValuePtr &value) -> py::object { return py::none(); }},
  // Ellipsis
  {Ellipsis::kTypeId, [](const ValuePtr &value) -> py::object { return py::ellipsis(); }}};
  222. py::object ValueToPyData(const ValuePtr &value) {
  223. if (value == nullptr) {
  224. MS_LOG(EXCEPTION) << "The `value` should not be null";
  225. }
  226. for (auto &iter : value_name_to_converter) {
  227. if (value->IsFromTypeId(iter.first)) {
  228. return iter.second(value);
  229. }
  230. }
  231. MS_LOG(EXCEPTION) << "Unsupported to convert " << value->ToString() << "[" << value->type_name() << "] to a PyData";
  232. }
  233. py::object AnyToPyData(const Any &value) {
  234. py::object ret;
  235. MS_LOG(DEBUG) << "AnyToPyData " << value.GetString();
  236. if (value.is<int>() || value.is<float>() || value.is<double>() || value.is<bool>()) {
  237. ret = BuiltinsToPyData(value);
  238. } else if (value.is<ValuePtr>()) {
  239. MS_LOG(DEBUG) << "ValuePtr";
  240. ValuePtr v = value.cast<ValuePtr>();
  241. ret = ValueToPyData(v);
  242. } else if (value.is<tensor::TensorPtr>()) {
  243. MS_LOG(DEBUG) << "tensor";
  244. auto tensor_ptr = value.cast<tensor::TensorPtr>();
  245. ret = TensorToPyData(tensor_ptr);
  246. } else if (value.is<py::object>()) {
  247. MS_LOG(DEBUG) << "py obj";
  248. ret = value.cast<py::object>();
  249. } else if (value.is<std::vector<tensor::TensorPtr>>() || value.is<std::vector<Any>>()) {
  250. ret = VectorToPyData(value);
  251. } else if (value.is<std::list<Any>>()) {
  252. MS_LOG(DEBUG) << "list_any";
  253. auto value_list = value.cast<std::list<Any>>();
  254. py::list rets = py::list();
  255. for (auto &v : value_list) {
  256. rets.append(AnyToPyData(v));
  257. }
  258. ret = rets;
  259. } else if (value.is<std::vector<Any>>()) {
  260. auto value_list = value.cast<std::vector<Any>>();
  261. py::tuple rets(value_list.size());
  262. for (size_t i = 0; i < value_list.size(); i++) {
  263. rets[i] = AnyToPyData(value_list[i]);
  264. }
  265. ret = rets;
  266. } else if (value.is<TypePtr>()) {
  267. py::tuple v(1);
  268. v[0] = value.cast<TypePtr>();
  269. ret = v[0];
  270. } else {
  271. MS_LOG(EXCEPTION) << "value is not support type";
  272. }
  273. return ret;
  274. }
  275. py::object BaseRefToPyData(const BaseRef &value, const AbstractBasePtr &output) {
  276. py::object ret;
  277. // If output value is a tuple, check if abstract is a SparseTensor in funcgraph output
  278. if (utils::isa<VectorRef>(value)) {
  279. MS_LOG(DEBUG) << "BaseRefToPyData, value is tuple: " << value.ToString();
  280. auto vec_ref = utils::cast<VectorRef>(value);
  281. ret = VectorRefToPyData(vec_ref, output);
  282. } else {
  283. ret = BaseRefToPyData(value);
  284. }
  285. return ret;
  286. }
  287. py::object BaseRefToPyData(const BaseRef &value) {
  288. py::object ret;
  289. MS_LOG(DEBUG) << "BaseRefToPyData " << value.ToString();
  290. if (utils::isa<int>(value) || utils::isa<float>(value) || utils::isa<double>(value) || utils::isa<bool>(value)) {
  291. ret = BuiltinsToPyData(value);
  292. } else if (utils::isa<ValuePtr>(value)) {
  293. MS_LOG(DEBUG) << "ValuePtr";
  294. ValuePtr v = utils::cast<ValuePtr>(value);
  295. ret = ValueToPyData(v);
  296. } else if (utils::isa<tensor::TensorPtr>(value)) {
  297. MS_LOG(DEBUG) << "tensor";
  298. auto tensor_ptr = utils::cast<tensor::TensorPtr>(value);
  299. ret = TensorToPyData(tensor_ptr);
  300. } else if (utils::isa<PyObjectRef>(value)) {
  301. MS_LOG(DEBUG) << "py obj";
  302. PyObjectRef py_ref = utils::cast<PyObjectRef>(value);
  303. ret = py_ref.object_;
  304. } else if (utils::isa<VectorRef>(value)) {
  305. auto vec_ref = utils::cast<VectorRef>(value);
  306. ret = VectorRefToPyData(vec_ref);
  307. } else if (utils::isa<TypePtr>(value)) {
  308. py::tuple v(1);
  309. v[0] = utils::cast<TypePtr>(value);
  310. ret = v[0];
  311. } else {
  312. MS_LOG(EXCEPTION) << "value is not support type";
  313. }
  314. return ret;
  315. }
  316. py::object BuiltinsToPyData(const Any &value) {
  317. if (value.is<int>()) {
  318. MS_LOG(DEBUG) << "int";
  319. py::int_ ret = value.cast<int>();
  320. return std::move(ret);
  321. } else if (value.is<float>()) {
  322. MS_LOG(DEBUG) << "float";
  323. py::float_ ret = value.cast<float>();
  324. return std::move(ret);
  325. } else if (value.is<double>()) {
  326. MS_LOG(DEBUG) << "double";
  327. py::float_ ret = value.cast<double>();
  328. return std::move(ret);
  329. } else {
  330. MS_LOG(DEBUG) << "bool";
  331. py::bool_ ret = value.cast<bool>();
  332. return std::move(ret);
  333. }
  334. }
  335. py::object BuiltinsToPyData(const BaseRef &value) {
  336. if (utils::isa<int>(value)) {
  337. MS_LOG(DEBUG) << "int";
  338. py::int_ ret = utils::cast<int>(value);
  339. return std::move(ret);
  340. } else if (utils::isa<float>(value)) {
  341. MS_LOG(DEBUG) << "float";
  342. py::float_ ret = utils::cast<float>(value);
  343. return std::move(ret);
  344. } else if (utils::isa<double>(value)) {
  345. MS_LOG(DEBUG) << "double";
  346. py::float_ ret = utils::cast<double>(value);
  347. return std::move(ret);
  348. } else {
  349. MS_LOG(DEBUG) << "bool";
  350. py::bool_ ret = utils::cast<bool>(value);
  351. return std::move(ret);
  352. }
  353. }
  354. py::object VectorToPyData(const Any &value) {
  355. py::object ret;
  356. if (value.is<std::vector<tensor::TensorPtr>>()) {
  357. MS_LOG(DEBUG) << "vector_tensor";
  358. std::vector<tensor::TensorPtr> outputs;
  359. outputs = value.cast<std::vector<tensor::TensorPtr>>();
  360. py::tuple tensor_tuple(outputs.size());
  361. for (std::size_t i = 0; i < outputs.size(); ++i) {
  362. tensor_tuple[i] = *outputs[i];
  363. }
  364. ret = tensor_tuple;
  365. } else {
  366. MS_LOG(DEBUG) << "vector_any";
  367. auto value_list = value.cast<std::vector<Any>>();
  368. py::tuple any_tuple = py::tuple(value_list.size());
  369. size_t i = 0;
  370. for (auto &v : value_list) {
  371. any_tuple[i] = AnyToPyData(v);
  372. i++;
  373. }
  374. ret = any_tuple;
  375. }
  376. return ret;
  377. }
  378. py::object VectorRefToPyData(const VectorRef &value_list) {
  379. py::object ret;
  380. MS_LOG(DEBUG) << "vector_ref";
  381. size_t value_size = value_list.size();
  382. auto ref_tuple = py::tuple(value_size);
  383. for (size_t i = 0; i < value_size; i++) {
  384. ref_tuple[i] = BaseRefToPyData(value_list[i]);
  385. }
  386. ret = ref_tuple;
  387. return ret;
  388. }
  389. py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr &output) {
  390. MS_LOG(DEBUG) << "vector_ref";
  391. // Current VectorRef reflects a SparseTensor type
  392. if (output->isa<abstract::AbstractCSRTensor>()) {
  393. return MakeCSRTensor(value_list);
  394. }
  395. py::object ret;
  396. size_t value_size = value_list.size();
  397. auto ref_tuple = py::tuple(value_size);
  398. abstract::AbstractTuplePtr tuple_output = output->cast<abstract::AbstractTuplePtr>();
  399. bool is_abstract_tuple = tuple_output != nullptr;
  400. for (size_t i = 0; i < value_size; i++) {
  401. if (!is_abstract_tuple || i >= tuple_output->size()) {
  402. // Fall back to original process
  403. ref_tuple[i] = BaseRefToPyData(value_list[i]);
  404. } else {
  405. ref_tuple[i] = BaseRefToPyData(value_list[i], (*tuple_output)[i]);
  406. }
  407. }
  408. ret = ref_tuple;
  409. return ret;
  410. }
  411. void SetValueRange(const AbstractBasePtr &tensor, const py::object &output) {
  412. if (output.is_none()) {
  413. return;
  414. }
  415. py::object obj_min =
  416. output.contains(py::str(ATTR_MIN_VALUE)) ? (py::object)output[ATTR_MIN_VALUE] : (py::object)py::none();
  417. py::object obj_max =
  418. output.contains(py::str(ATTR_MAX_VALUE)) ? (py::object)output[ATTR_MAX_VALUE] : (py::object)py::none();
  419. if (!obj_min.is_none() && !obj_max.is_none()) {
  420. bool converted = true;
  421. ValuePtr min_value = nullptr;
  422. ValuePtr max_value = nullptr;
  423. converted = parse::ConvertData(obj_min, &min_value);
  424. if (!converted) {
  425. MS_LOG(EXCEPTION) << "Convert shape min value data failed";
  426. }
  427. converted = parse::ConvertData(obj_max, &max_value);
  428. if (!converted) {
  429. MS_LOG(EXCEPTION) << "Convert shape max value data failed";
  430. }
  431. auto abs_tensor = dyn_cast<abstract::AbstractTensor>(tensor);
  432. abs_tensor->set_value_range(min_value, max_value);
  433. }
  434. }
  435. AbstractBasePtr MakePyInferRes2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
  436. const py::object &output) {
  437. auto ret_vec = shape_obj.cast<ShapeVector>();
  438. auto ret_dtype = type_obj.cast<TypePtr>();
  439. ShapeVector min_shape_vec;
  440. ShapeVector max_shape_vec;
  441. if (!output.is_none()) {
  442. py::object min_shape =
  443. output.contains(py::str(ATTR_MIN_SHAPE)) ? (py::object)output[ATTR_MIN_SHAPE] : (py::object)py::none();
  444. py::object max_shape =
  445. output.contains(py::str(ATTR_MAX_SHAPE)) ? (py::object)output[ATTR_MAX_SHAPE] : (py::object)py::none();
  446. if (!min_shape.is_none()) {
  447. min_shape_vec = min_shape.cast<ShapeVector>();
  448. }
  449. if (!max_shape.is_none()) {
  450. max_shape_vec = max_shape.cast<ShapeVector>();
  451. }
  452. }
  453. auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
  454. AbstractBasePtr tensor = MakeAbstractTensor(ret_shape, ret_dtype);
  455. SetValueRange(tensor, output);
  456. return tensor;
  457. }
  458. static bool IsMonadType(const py::object &type_obj) {
  459. if (py::isinstance<Type>(type_obj)) {
  460. auto type = type_obj.cast<Type *>();
  461. return type->isa<MonadType>();
  462. }
  463. return false;
  464. }
  465. static AbstractBasePtr ToMonadAbstract(const py::object &type_obj) {
  466. if (py::isinstance<Type>(type_obj)) {
  467. auto type = type_obj.cast<Type *>();
  468. if (!type->isa<MonadType>()) {
  469. MS_LOG(EXCEPTION) << "Not a monad type object: " << py::str(type_obj);
  470. }
  471. return abstract::MakeMonadAbstract(type->cast<MonadTypePtr>());
  472. }
  473. MS_LOG(EXCEPTION) << "Not a type object: " << py::str(type_obj);
  474. }
  475. AbstractBasePtr MakePyInferRes2Abstract(const py::object &shape_obj, const py::object &type_obj,
  476. const py::object &output) {
  477. if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) && py::isinstance<Type>(type_obj)) {
  478. auto ret_vec = shape_obj.cast<ShapeVector>();
  479. auto ret_dtype = type_obj.cast<TypePtr>();
  480. MS_EXCEPTION_IF_NULL(ret_dtype);
  481. // if the size of shape list is empty, return an scalar abstract
  482. if (ret_vec.empty() && (!ret_dtype->isa<TensorType>())) {
  483. abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
  484. return abs_scalar;
  485. }
  486. return MakePyInferRes2AbstractTensor(shape_obj, type_obj, output);
  487. } else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
  488. auto shape_tuple = shape_obj.cast<py::tuple>();
  489. auto typeid_tuple = type_obj.cast<py::tuple>();
  490. AbstractBasePtrList ptr_list;
  491. for (size_t it = 0; it < shape_tuple.size(); ++it) {
  492. auto tensor_it = MakePyInferRes2Abstract(shape_tuple[it], typeid_tuple[it]);
  493. ptr_list.push_back(tensor_it);
  494. }
  495. auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
  496. return tuple;
  497. } else if (py::isinstance<py::list>(shape_obj) && py::isinstance<py::list>(type_obj)) {
  498. auto shape_list = shape_obj.cast<py::list>();
  499. auto typeid_list = type_obj.cast<py::list>();
  500. AbstractBasePtrList ptr_list;
  501. for (size_t it = 0; it < shape_list.size(); ++it) {
  502. auto tensor_it = MakePyInferRes2Abstract(shape_list[it], typeid_list[it]);
  503. ptr_list.push_back(tensor_it);
  504. }
  505. auto list = std::make_shared<abstract::AbstractList>(ptr_list);
  506. return list;
  507. } else if (shape_obj.is_none() && type_obj.is_none()) {
  508. // AbstractNone indicates there is no output for this CNode node.
  509. auto abstract_none = std::make_shared<abstract::AbstractNone>();
  510. return abstract_none;
  511. } else if (IsMonadType(type_obj)) {
  512. // Return monad abstract if it is monad type.
  513. return ToMonadAbstract(type_obj);
  514. } else {
  515. // When sparse enabled, the undetermined might be raised and eliminated in opt passes
  516. auto context = MsContext::GetInstance();
  517. MS_EXCEPTION_IF_NULL(context);
  518. bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
  519. if (enable_sparse) {
  520. return std::make_shared<abstract::AbstractUndetermined>();
  521. }
  522. MS_LOG(EXCEPTION) << "Python evaluator return invalid shape or type. " << (std::string)py::str(type_obj);
  523. }
  524. }
// If the graph output is a constant ValueNode or a Parameter, compute the
// Python return value directly — no graph execution needed — and store it in
// *ret_val. Returns true when handled here, false when the graph must run.
bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
                                       const std::shared_ptr<py::object> &ret_val) {
  if (output->isa<ValueNode>()) {
    MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
    ValuePtr value = GetValueNode(output);
    *ret_val = ValueToPyData(value);
    return true;
  }
  // Adapter will transform values in __init__() and construct() to parameters, this could cause
  // inputs (a.k.a args in current function) size less than parameters'.
  if (output->isa<Parameter>()) {
    MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute.";
    // Find the right parameter as ret_val.
    auto func_graph = output->func_graph();
    MS_EXCEPTION_IF_NULL(func_graph);
    auto params = func_graph->parameters();
    if ((args.size() + func_graph->hyper_param_count()) != params.size()) {
      MS_LOG(EXCEPTION) << "Input size " << args.size() << " add Parameter count " << func_graph->hyper_param_count()
                        << " not equal to graph input size " << params.size() << ", let graph to be executed.";
    }
    auto it = std::find(params.begin(), params.end(), output);
    if (it == params.end()) {
      MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters";
    }
    // NOTE(review): the size check above implies parameters are laid out as
    // [user args..., hyper params...] — confirm against FuncGraph's layout.
    size_t index = it - params.cbegin();
    if (index >= args.size() + func_graph->hyper_param_count()) {
      MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size()
                                 << " add Parameter count " << func_graph->hyper_param_count() << ".";
    }
    if (index < args.size()) {
      // The output parameter is one of the user-supplied arguments: echo it back.
      *ret_val = args[index];
    } else {
      // The output parameter is a hyper parameter: return its default value.
      auto param = dyn_cast<Parameter>(params[index]);
      MS_EXCEPTION_IF_NULL(param);
      if (!param->has_default()) {
        MS_LOG(EXCEPTION) << "Can not determine value of Parameter " << index << " (" << param->name() << ")";
      }
      auto tensor = param->default_param();
      *ret_val = py::cast(tensor);
    }
    return true;
  }
  return false;
}
// Build a Python CSRTensor from a VectorRef of four inputs:
// (indptr, indices, values, shape). The shape encoding is backend-dependent.
py::object MakeCSRTensor(const VectorRef &value_list) {
  constexpr size_t kCSRTensorInputSize{4};
  if (value_list.size() != kCSRTensorInputSize) {
    MS_LOG(EXCEPTION) << "CSRTensor must have 4 inputs.";
  }
  using TensorPtr = tensor::TensorPtr;
  using CSRTensor = tensor::CSRTensor;
  // NOTE(review): indptr/indices/values are not null-checked here — presumably
  // guaranteed non-null by the caller; confirm.
  TensorPtr indptr = utils::cast<TensorPtr>(value_list[0]);
  TensorPtr indices = utils::cast<TensorPtr>(value_list[1]);
  TensorPtr values = utils::cast<TensorPtr>(value_list[2]);
  ValuePtr shape_ptr = utils::cast<ValuePtr>(value_list[3]);
  ValueTuplePtr shape_tuple = shape_ptr->cast<ValueTuplePtr>();
  ShapeVector shape{};
  // CSRTensor shape is a tuple on GPU and CPU
  if (shape_tuple) {
    for (const auto &v : shape_tuple->value()) {
      MS_EXCEPTION_IF_NULL(v);
      ScalarPtr scalar = v->cast<ScalarPtr>();
      MS_EXCEPTION_IF_NULL(scalar);
      shape.push_back(GetValue<int64_t>(scalar));
    }
    // CSRTensor shape is a VectorRef(TensorPtr, TensorPtr) on Ascend
  } else {
    auto shape_ref = utils::cast<VectorRef>(value_list[3]);
    MS_EXCEPTION_IF_NULL(shape_ref);
    for (const auto &v : shape_ref) {
      MS_EXCEPTION_IF_NULL(v);
      auto tensorptr = utils::cast<TensorPtr>(v);
      MS_EXCEPTION_IF_NULL(tensorptr);
      if (tensorptr->DataDim() != 0) {
        MS_LOG(EXCEPTION) << "Element in CSRTensor's shape must be scalar!";
      }
      // Read the single element of the 0-dim tensor.
      // NOTE(review): assumes the tensor's buffer holds int64 data — confirm.
      shape.push_back(*(static_cast<int64_t *>(tensorptr->data_c())));
    }
  }
  // Route through a one-element tuple so pybind11 converts the CSRTensor.
  auto ref = py::tuple(1);
  auto csr_tensor_ptr = std::make_shared<CSRTensor>(indptr, indices, values, shape);
  ref[0] = csr_tensor_ptr;
  return ref[0];
}
  609. } // namespace mindspore