You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

convert_utils_py.cc 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555
  1. /**
  2. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "include/common/utils/convert_utils_py.h"
  #include <vector>
  #include <string>
  #include <memory>
  #include <algorithm>
  #include <list>
  #include <utility>
  #include <cfloat>
  #include <cstdint>
  #include "abstract/abstract_value.h"
  #include "abstract/utils.h"
  #include "pipeline/jit/parse/parse_base.h"
  #include "pipeline/jit/parse/resolve.h"
  #include "ir/value.h"
  #include "ir/anf.h"
  #include "ir/tensor.h"
  #include "ir/param_info.h"
  #include "pybind_api/ir/base_ref_py.h"
  #include "ir/dtype/tensor_type.h"
  #include "utils/ms_context.h"
  #include "include/common/utils/convert_utils.h"
  36. namespace mindspore {
  37. py::object BuiltinsToPyData(const Any &value);
  38. py::object BuiltinsToPyData(const BaseRef &value);
  39. py::object VectorToPyData(const Any &value);
  40. py::object VectorRefToPyData(const VectorRef &value_list);
  41. py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr &output);
  42. // Wrap VectorRef to CSRTensor
  43. py::object MakeCSRTensor(const VectorRef &value_list);
  44. py::object MakeCOOTensor(const VectorRef &value_list);
  45. ShapeVector ConvertToShapeVector(const ValuePtr &shape_ptr, const VectorRef &value_list, size_t shape_idx);
  46. py::object CSRTensorToPyData(const tensor::CSRTensorPtr &csr_tensor) {
  47. auto ref = py::tuple(1);
  48. ref[0] = csr_tensor;
  49. return ref[0];
  50. }
  51. py::object TensorToPyData(const tensor::TensorPtr &tensor) {
  52. MS_EXCEPTION_IF_NULL(tensor);
  53. if (tensor->NeedWait()) {
  54. py::gil_scoped_release release;
  55. tensor->Wait();
  56. }
  57. py::tuple v(1);
  58. v[0] = tensor;
  59. return v[0];
  60. }
  61. py::object ScalarPtrToPyData(const ScalarPtr &value) {
  62. py::int_ int_v;
  63. py::float_ float_v;
  64. py::bool_ bool_v;
  65. TypeId scalar_type = value->type()->type_id();
  66. switch (scalar_type) {
  67. case kNumberTypeUInt8:
  68. MS_LOG(DEBUG) << "uint8";
  69. int_v = value->cast<UInt8ImmPtr>()->value();
  70. return std::move(int_v);
  71. case kNumberTypeUInt16:
  72. MS_LOG(DEBUG) << "uint16";
  73. int_v = value->cast<UInt16ImmPtr>()->value();
  74. return std::move(int_v);
  75. case kNumberTypeUInt32:
  76. MS_LOG(DEBUG) << "uint32";
  77. int_v = value->cast<UInt32ImmPtr>()->value();
  78. return std::move(int_v);
  79. case kNumberTypeUInt64:
  80. MS_LOG(DEBUG) << "uint64";
  81. int_v = value->cast<UInt64ImmPtr>()->value();
  82. return std::move(int_v);
  83. case kNumberTypeInt8:
  84. MS_LOG(DEBUG) << "int8";
  85. int_v = value->cast<Int8ImmPtr>()->value();
  86. return std::move(int_v);
  87. case kNumberTypeInt16:
  88. MS_LOG(DEBUG) << "int16";
  89. int_v = value->cast<Int16ImmPtr>()->value();
  90. return std::move(int_v);
  91. case kNumberTypeInt32:
  92. MS_LOG(DEBUG) << "int32";
  93. int_v = value->cast<Int32ImmPtr>()->value();
  94. return std::move(int_v);
  95. case kNumberTypeInt64:
  96. MS_LOG(DEBUG) << "int64";
  97. int_v = value->cast<Int64ImmPtr>()->value();
  98. return std::move(int_v);
  99. case kNumberTypeFloat32:
  100. MS_LOG(DEBUG) << "float";
  101. float_v = value->cast<FP32ImmPtr>()->value();
  102. return std::move(float_v);
  103. case kNumberTypeFloat64:
  104. MS_LOG(DEBUG) << "double";
  105. float_v = value->cast<FP64ImmPtr>()->value();
  106. return std::move(float_v);
  107. case kNumberTypeBool:
  108. MS_LOG(DEBUG) << "bool";
  109. bool_v = value->cast<BoolImmPtr>()->value();
  110. return std::move(bool_v);
  111. default:
  112. MS_EXCEPTION(TypeError) << "Unsupported scalar converted to py data: " << value->ToString();
  113. }
  114. }
  115. using ConverterFunction = std::function<py::object(const ValuePtr &value)>;
  116. using ValueNameToConverterVector = std::vector<std::pair<uint32_t, ConverterFunction>>;
  117. // (Value Type Name) -> (Converter Function)
  118. // The converter function is used to convert Value object to Python data object.
  119. static ValueNameToConverterVector value_name_to_converter = {
  120. // Scalar
  121. {Scalar::kTypeId, [](const ValuePtr &value) -> py::object { return ScalarPtrToPyData(value->cast<ScalarPtr>()); }},
  122. // Tensor
  123. {tensor::Tensor::kTypeId,
  124. [](const ValuePtr &value) -> py::object {
  125. auto tensor_ptr = value->cast<tensor::TensorPtr>();
  126. return TensorToPyData(tensor_ptr);
  127. }},
  128. // MetaTenser
  129. {tensor::MetaTensor::kTypeId,
  130. [](const ValuePtr &value) -> py::object {
  131. py::tuple tuple_container(1);
  132. tuple_container[0] = value->cast<tensor::MetaTensorPtr>();
  133. return tuple_container[0];
  134. }},
  135. // CSRTensor
  136. {tensor::CSRTensor::kTypeId,
  137. [](const ValuePtr &value) -> py::object {
  138. auto csr_tensor_ptr = value->cast<tensor::CSRTensorPtr>();
  139. return CSRTensorToPyData(csr_tensor_ptr);
  140. }},
  141. // RefKey
  142. {RefKey::kTypeId,
  143. [](const ValuePtr &value) -> py::object {
  144. py::tuple tuple_container(1);
  145. tuple_container[0] = value->cast<RefKeyPtr>();
  146. return tuple_container[0];
  147. }},
  148. // Type
  149. {Type::kTypeId,
  150. [](const ValuePtr &value) -> py::object {
  151. py::tuple tuple_container(1);
  152. tuple_container[0] = value->cast<TypePtr>();
  153. return tuple_container[0];
  154. }},
  155. // StringImm
  156. {StringImm::kTypeId,
  157. [](const ValuePtr &value) -> py::object {
  158. py::str res = value->cast<StringImmPtr>()->value();
  159. return res;
  160. }},
  161. // ValueSequence
  162. {ValueSequence::kTypeId,
  163. [](const ValuePtr &value) -> py::object {
  164. auto value_sequeue = value->cast<ValueSequencePtr>()->value();
  165. py::tuple res_sequeue(value_sequeue.size());
  166. for (size_t i = 0; i < value_sequeue.size(); i++) {
  167. res_sequeue[i] = ValueToPyData(value_sequeue[i]);
  168. }
  169. if (value->isa<ValueTuple>()) {
  170. return res_sequeue;
  171. }
  172. return res_sequeue.cast<py::list>();
  173. }},
  174. // ValueDictionary
  175. {ValueDictionary::kTypeId,
  176. [](const ValuePtr &value) -> py::object {
  177. auto value_list = value->cast<ValueDictionaryPtr>()->value();
  178. py::dict res_dict;
  179. for (const auto &v : value_list) {
  180. res_dict[py::str(v.first)] = ValueToPyData(v.second);
  181. }
  182. return res_dict;
  183. }},
  184. // ValueSlice
  185. {ValueSlice::kTypeId,
  186. [](const ValuePtr &value) -> py::object {
  187. auto slice = value->cast<ValueSlicePtr>();
  188. auto start = ValueToPyData(slice->start());
  189. auto end = ValueToPyData(slice->stop());
  190. auto step = ValueToPyData(slice->step());
  191. return python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end, step);
  192. }},
  193. // KeywordArg
  194. {KeywordArg::kTypeId,
  195. [](const ValuePtr &value) -> py::object {
  196. auto abs_keyword_arg = value->ToAbstract()->cast<abstract::AbstractKeywordArgPtr>();
  197. auto key = abs_keyword_arg->get_key();
  198. auto val = abs_keyword_arg->get_arg()->BuildValue();
  199. auto py_value = ValueToPyData(val);
  200. auto kwargs = py::kwargs();
  201. kwargs[key.c_str()] = py_value;
  202. return kwargs;
  203. }},
  204. // parse::NameSpace
  205. {parse::NameSpace::kTypeId,
  206. [](const ValuePtr &value) -> py::object {
  207. auto ns = value->cast<parse::NameSpacePtr>();
  208. return ns->module_obj();
  209. }},
  210. // parse::ClassType
  211. {parse::ClassType::kTypeId,
  212. [](const ValuePtr &value) -> py::object {
  213. auto class_type = value->cast<parse::ClassTypePtr>();
  214. return class_type->obj();
  215. }},
  216. // parse::InterpretedObject
  217. {parse::InterpretedObject::kTypeId,
  218. [](const ValuePtr &value) -> py::object {
  219. auto interpreted_object = value->cast<parse::InterpretedObjectPtr>();
  220. return interpreted_object->obj();
  221. }},
  222. // None
  223. {None::kTypeId, [](const ValuePtr &value) -> py::object { return py::none(); }},
  224. // AnyValue
  225. {AnyValue::kTypeId, [](const ValuePtr &value) -> py::object { return py::none(); }},
  226. // FuncGraph
  227. {FuncGraph::kTypeId, [](const ValuePtr &value) -> py::object { return py::none(); }},
  228. // Monad
  229. {Monad::kTypeId, [](const ValuePtr &value) -> py::object { return py::none(); }},
  230. // Ellipsis
  231. {Ellipsis::kTypeId, [](const ValuePtr &value) -> py::object { return py::ellipsis(); }}};
  232. py::object ValueToPyData(const ValuePtr &value) {
  233. if (value == nullptr) {
  234. MS_LOG(EXCEPTION) << "The `value` should not be null";
  235. }
  236. for (auto &iter : value_name_to_converter) {
  237. if (value->IsFromTypeId(iter.first)) {
  238. return iter.second(value);
  239. }
  240. }
  241. MS_LOG(EXCEPTION) << "Unsupported to convert " << value->ToString() << "[" << value->type_name() << "] to a PyData";
  242. }
  243. py::object AnyToPyData(const Any &value) {
  244. py::object ret;
  245. MS_LOG(DEBUG) << "AnyToPyData " << value.GetString();
  246. if (value.is<int>() || value.is<float>() || value.is<double>() || value.is<bool>()) {
  247. ret = BuiltinsToPyData(value);
  248. } else if (value.is<ValuePtr>()) {
  249. MS_LOG(DEBUG) << "ValuePtr";
  250. ValuePtr v = value.cast<ValuePtr>();
  251. ret = ValueToPyData(v);
  252. } else if (value.is<tensor::TensorPtr>()) {
  253. MS_LOG(DEBUG) << "tensor";
  254. auto tensor_ptr = value.cast<tensor::TensorPtr>();
  255. ret = TensorToPyData(tensor_ptr);
  256. } else if (value.is<py::object>()) {
  257. MS_LOG(DEBUG) << "py obj";
  258. ret = value.cast<py::object>();
  259. } else if (value.is<std::vector<tensor::TensorPtr>>() || value.is<std::vector<Any>>()) {
  260. ret = VectorToPyData(value);
  261. } else if (value.is<std::list<Any>>()) {
  262. MS_LOG(DEBUG) << "list_any";
  263. auto value_list = value.cast<std::list<Any>>();
  264. py::list rets = py::list();
  265. for (auto &v : value_list) {
  266. rets.append(AnyToPyData(v));
  267. }
  268. ret = rets;
  269. } else if (value.is<std::vector<Any>>()) {
  270. auto value_list = value.cast<std::vector<Any>>();
  271. py::tuple rets(value_list.size());
  272. for (size_t i = 0; i < value_list.size(); i++) {
  273. rets[i] = AnyToPyData(value_list[i]);
  274. }
  275. ret = rets;
  276. } else if (value.is<TypePtr>()) {
  277. py::tuple v(1);
  278. v[0] = value.cast<TypePtr>();
  279. ret = v[0];
  280. } else {
  281. MS_LOG(EXCEPTION) << "value is not support type";
  282. }
  283. return ret;
  284. }
  285. py::object BaseRefToPyData(const BaseRef &value, const AbstractBasePtr &output) {
  286. py::object ret;
  287. // If output value is a tuple, check if abstract is a COOTensor in funcgraph output
  288. if (utils::isa<VectorRef>(value)) {
  289. MS_LOG(DEBUG) << "BaseRefToPyData, value is tuple: " << value.ToString();
  290. auto vec_ref = utils::cast<VectorRef>(value);
  291. ret = VectorRefToPyData(vec_ref, output);
  292. } else {
  293. ret = BaseRefToPyData(value);
  294. }
  295. return ret;
  296. }
  297. py::object BaseRefToPyData(const BaseRef &value) {
  298. py::object ret;
  299. MS_LOG(DEBUG) << "BaseRefToPyData " << value.ToString();
  300. if (utils::isa<int>(value) || utils::isa<float>(value) || utils::isa<double>(value) || utils::isa<bool>(value)) {
  301. ret = BuiltinsToPyData(value);
  302. } else if (utils::isa<ValuePtr>(value)) {
  303. MS_LOG(DEBUG) << "ValuePtr";
  304. ValuePtr v = utils::cast<ValuePtr>(value);
  305. ret = ValueToPyData(v);
  306. } else if (utils::isa<tensor::TensorPtr>(value)) {
  307. MS_LOG(DEBUG) << "tensor";
  308. auto tensor_ptr = utils::cast<tensor::TensorPtr>(value);
  309. ret = TensorToPyData(tensor_ptr);
  310. } else if (utils::isa<PyObjectRef>(value)) {
  311. MS_LOG(DEBUG) << "py obj";
  312. PyObjectRef py_ref = utils::cast<PyObjectRef>(value);
  313. ret = py_ref.object_;
  314. } else if (utils::isa<VectorRef>(value)) {
  315. auto vec_ref = utils::cast<VectorRef>(value);
  316. ret = VectorRefToPyData(vec_ref);
  317. } else if (utils::isa<TypePtr>(value)) {
  318. py::tuple v(1);
  319. v[0] = utils::cast<TypePtr>(value);
  320. ret = v[0];
  321. } else {
  322. MS_LOG(EXCEPTION) << "value is not support type";
  323. }
  324. return ret;
  325. }
  326. py::object BuiltinsToPyData(const Any &value) {
  327. if (value.is<int>()) {
  328. MS_LOG(DEBUG) << "int";
  329. py::int_ ret = value.cast<int>();
  330. return std::move(ret);
  331. } else if (value.is<float>()) {
  332. MS_LOG(DEBUG) << "float";
  333. py::float_ ret = value.cast<float>();
  334. return std::move(ret);
  335. } else if (value.is<double>()) {
  336. MS_LOG(DEBUG) << "double";
  337. py::float_ ret = value.cast<double>();
  338. return std::move(ret);
  339. } else {
  340. MS_LOG(DEBUG) << "bool";
  341. py::bool_ ret = value.cast<bool>();
  342. return std::move(ret);
  343. }
  344. }
  345. py::object BuiltinsToPyData(const BaseRef &value) {
  346. if (utils::isa<int>(value)) {
  347. MS_LOG(DEBUG) << "int";
  348. py::int_ ret = utils::cast<int>(value);
  349. return std::move(ret);
  350. } else if (utils::isa<float>(value)) {
  351. MS_LOG(DEBUG) << "float";
  352. py::float_ ret = utils::cast<float>(value);
  353. return std::move(ret);
  354. } else if (utils::isa<double>(value)) {
  355. MS_LOG(DEBUG) << "double";
  356. py::float_ ret = utils::cast<double>(value);
  357. return std::move(ret);
  358. } else {
  359. MS_LOG(DEBUG) << "bool";
  360. py::bool_ ret = utils::cast<bool>(value);
  361. return std::move(ret);
  362. }
  363. }
  364. py::object VectorToPyData(const Any &value) {
  365. py::object ret;
  366. if (value.is<std::vector<tensor::TensorPtr>>()) {
  367. MS_LOG(DEBUG) << "vector_tensor";
  368. std::vector<tensor::TensorPtr> outputs;
  369. outputs = value.cast<std::vector<tensor::TensorPtr>>();
  370. py::tuple tensor_tuple(outputs.size());
  371. for (std::size_t i = 0; i < outputs.size(); ++i) {
  372. tensor_tuple[i] = *outputs[i];
  373. }
  374. ret = tensor_tuple;
  375. } else {
  376. MS_LOG(DEBUG) << "vector_any";
  377. auto value_list = value.cast<std::vector<Any>>();
  378. py::tuple any_tuple = py::tuple(value_list.size());
  379. size_t i = 0;
  380. for (auto &v : value_list) {
  381. any_tuple[i] = AnyToPyData(v);
  382. i++;
  383. }
  384. ret = any_tuple;
  385. }
  386. return ret;
  387. }
  388. py::object VectorRefToPyData(const VectorRef &value_list) {
  389. py::object ret;
  390. MS_LOG(DEBUG) << "vector_ref";
  391. size_t value_size = value_list.size();
  392. auto ref_tuple = py::tuple(value_size);
  393. for (size_t i = 0; i < value_size; i++) {
  394. ref_tuple[i] = BaseRefToPyData(value_list[i]);
  395. }
  396. ret = ref_tuple;
  397. return ret;
  398. }
  399. py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr &output) {
  400. MS_LOG(DEBUG) << "vector_ref";
  401. // Current VectorRef reflects a COOTensor type
  402. if (output->isa<abstract::AbstractCSRTensor>()) {
  403. return MakeCSRTensor(value_list);
  404. }
  405. if (output->isa<abstract::AbstractCOOTensor>()) {
  406. return MakeCOOTensor(value_list);
  407. }
  408. py::object ret;
  409. size_t value_size = value_list.size();
  410. auto ref_tuple = py::tuple(value_size);
  411. abstract::AbstractTuplePtr tuple_output = output->cast<abstract::AbstractTuplePtr>();
  412. bool is_abstract_tuple = tuple_output != nullptr;
  413. for (size_t i = 0; i < value_size; i++) {
  414. if (!is_abstract_tuple || i >= tuple_output->size()) {
  415. // Fall back to original process
  416. ref_tuple[i] = BaseRefToPyData(value_list[i]);
  417. } else {
  418. ref_tuple[i] = BaseRefToPyData(value_list[i], (*tuple_output)[i]);
  419. }
  420. }
  421. ret = ref_tuple;
  422. return ret;
  423. }
  424. bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
  425. const std::shared_ptr<py::object> &ret_val) {
  426. if (output->isa<ValueNode>()) {
  427. MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
  428. ValuePtr value = GetValueNode(output);
  429. *ret_val = ValueToPyData(value);
  430. return true;
  431. }
  432. // Adapter will transform values in __init__() and construct() to parameters, this could cause
  433. // inputs (a.k.a args in current function) size less than parameters'.
  434. if (output->isa<Parameter>()) {
  435. MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute.";
  436. // Find the right parameter as ret_val.
  437. auto func_graph = output->func_graph();
  438. MS_EXCEPTION_IF_NULL(func_graph);
  439. auto params = func_graph->parameters();
  440. if ((args.size() + func_graph->hyper_param_count()) != params.size()) {
  441. MS_LOG(EXCEPTION) << "Input size " << args.size() << " add Parameter count " << func_graph->hyper_param_count()
  442. << " not equal to graph input size " << params.size() << ", let graph to be executed.";
  443. }
  444. auto it = std::find(params.begin(), params.end(), output);
  445. if (it == params.end()) {
  446. MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters";
  447. }
  448. size_t index = it - params.cbegin();
  449. if (index >= args.size() + func_graph->hyper_param_count()) {
  450. MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size()
  451. << " add Parameter count " << func_graph->hyper_param_count() << ".";
  452. }
  453. if (index < args.size()) {
  454. *ret_val = args[index];
  455. } else {
  456. auto param = dyn_cast<Parameter>(params[index]);
  457. MS_EXCEPTION_IF_NULL(param);
  458. if (!param->has_default()) {
  459. MS_LOG(EXCEPTION) << "Can not determine value of Parameter " << index << " (" << param->name() << ")";
  460. }
  461. auto tensor = param->default_param();
  462. *ret_val = py::cast(tensor);
  463. }
  464. return true;
  465. }
  466. return false;
  467. }
  468. ShapeVector ConvertToShapeVector(const ValuePtr &shape_ptr, const VectorRef &value_list, size_t shape_idx) {
  469. MS_EXCEPTION_IF_NULL(shape_ptr);
  470. ShapeVector shape;
  471. ValueTuplePtr shape_tuple = shape_ptr->cast<ValueTuplePtr>();
  472. if (shape_tuple) {
  473. for (const auto &v : shape_tuple->value()) {
  474. MS_EXCEPTION_IF_NULL(v);
  475. ScalarPtr scalar = v->cast<ScalarPtr>();
  476. MS_EXCEPTION_IF_NULL(scalar);
  477. shape.push_back(GetValue<int64_t>(scalar));
  478. }
  479. } else {
  480. auto shape_ref = utils::cast<VectorRef>(value_list[shape_idx]);
  481. MS_EXCEPTION_IF_NULL(shape_ref);
  482. for (const auto &v : shape_ref) {
  483. MS_EXCEPTION_IF_NULL(v);
  484. auto tensorptr = utils::cast<tensor::TensorPtr>(v);
  485. MS_EXCEPTION_IF_NULL(tensorptr);
  486. if (tensorptr->DataDim() != 0) {
  487. MS_LOG(EXCEPTION) << "Element in COOTensor's shape must be scalar!";
  488. }
  489. tensorptr->data_sync(false);
  490. shape.push_back(*(static_cast<int64_t *>(tensorptr->data_c())));
  491. }
  492. }
  493. return shape;
  494. }
  495. py::object MakeCSRTensor(const VectorRef &value_list) {
  496. constexpr size_t kCSRTensorInputSize{4};
  497. if (value_list.size() != kCSRTensorInputSize) {
  498. MS_LOG(EXCEPTION) << "CSRTensor must have 4 inputs.";
  499. }
  500. using TensorPtr = tensor::TensorPtr;
  501. using CSRTensor = tensor::CSRTensor;
  502. constexpr size_t kIndptrIdx{0};
  503. constexpr size_t kIndicesIdx{1};
  504. constexpr size_t kValuesIdx{2};
  505. constexpr size_t kShapeIdx{3};
  506. TensorPtr indptr = utils::cast<TensorPtr>(value_list[kIndptrIdx]);
  507. TensorPtr indices = utils::cast<TensorPtr>(value_list[kIndicesIdx]);
  508. TensorPtr values = utils::cast<TensorPtr>(value_list[kValuesIdx]);
  509. ValuePtr shape_ptr = utils::cast<ValuePtr>(value_list[kShapeIdx]);
  510. ShapeVector shape = ConvertToShapeVector(shape_ptr, value_list, kShapeIdx);
  511. auto csr_tensor_ptr = std::make_shared<CSRTensor>(indptr, indices, values, shape);
  512. return CSRTensorToPyData(csr_tensor_ptr);
  513. }
  514. py::object MakeCOOTensor(const VectorRef &value_list) {
  515. constexpr size_t kCOOTensorInputSize{3};
  516. constexpr size_t kIndicesIdx{0};
  517. constexpr size_t kValuesIdx{1};
  518. constexpr size_t kShapeIdx{2};
  519. if (value_list.size() != kCOOTensorInputSize) {
  520. MS_LOG(EXCEPTION) << "COOTensor must have " << kCOOTensorInputSize << "inputs.";
  521. }
  522. tensor::TensorPtr indices = utils::cast<tensor::TensorPtr>(value_list[kIndicesIdx]);
  523. tensor::TensorPtr values = utils::cast<tensor::TensorPtr>(value_list[kValuesIdx]);
  524. ValuePtr shape_ptr = utils::cast<ValuePtr>(value_list[kShapeIdx]);
  525. ShapeVector shape = ConvertToShapeVector(shape_ptr, value_list, kShapeIdx);
  526. auto ref = py::tuple(1);
  527. auto coo_tensor_ptr = std::make_shared<tensor::COOTensor>(indices, values, shape);
  528. ref[0] = coo_tensor_ptr;
  529. return ref[0];
  530. }
  531. } // namespace mindspore