You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

convert_utils_py.cc 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554
  1. /**
  2. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/convert_utils_py.h"
  17. #include <vector>
  18. #include <string>
  19. #include <memory>
  20. #include <algorithm>
  21. #include <list>
  22. #include <utility>
  23. #include <cfloat>
  24. #include "abstract/abstract_value.h"
  25. #include "abstract/utils.h"
  26. #include "pipeline/jit/parse/parse.h"
  27. #include "pipeline/jit/parse/parse_base.h"
  28. #include "pipeline/jit/parse/resolve.h"
  29. #include "ir/value.h"
  30. #include "ir/tensor.h"
  31. #include "ir/param_info.h"
  32. #include "pybind_api/ir/base_ref_py.h"
  33. #include "utils/ms_context.h"
  34. namespace mindspore {
  35. py::object BuiltinsToPyData(const Any &value);
  36. py::object BuiltinsToPyData(const BaseRef &value);
  37. py::object VectorToPyData(const Any &value);
  38. py::object VectorRefToPyData(const VectorRef &value);
  39. py::object TensorToPyData(const tensor::TensorPtr &tensor) {
  40. MS_EXCEPTION_IF_NULL(tensor);
  41. if (tensor->NeedWait()) {
  42. py::gil_scoped_release release;
  43. tensor->Wait();
  44. }
  45. py::tuple v(1);
  46. v[0] = tensor;
  47. return v[0];
  48. }
  49. py::object ScalarPtrToPyData(const ScalarPtr &value) {
  50. py::int_ int_v;
  51. py::float_ float_v;
  52. py::bool_ bool_v;
  53. TypeId scalar_type = value->type()->type_id();
  54. switch (scalar_type) {
  55. case kNumberTypeUInt8:
  56. MS_LOG(DEBUG) << "uint8";
  57. int_v = value->cast<UInt8ImmPtr>()->value();
  58. return std::move(int_v);
  59. case kNumberTypeUInt16:
  60. MS_LOG(DEBUG) << "uint16";
  61. int_v = value->cast<UInt16ImmPtr>()->value();
  62. return std::move(int_v);
  63. case kNumberTypeUInt32:
  64. MS_LOG(DEBUG) << "uint32";
  65. int_v = value->cast<UInt32ImmPtr>()->value();
  66. return std::move(int_v);
  67. case kNumberTypeUInt64:
  68. MS_LOG(DEBUG) << "uint64";
  69. int_v = value->cast<UInt64ImmPtr>()->value();
  70. return std::move(int_v);
  71. case kNumberTypeInt8:
  72. MS_LOG(DEBUG) << "int8";
  73. int_v = value->cast<Int8ImmPtr>()->value();
  74. return std::move(int_v);
  75. case kNumberTypeInt16:
  76. MS_LOG(DEBUG) << "int16";
  77. int_v = value->cast<Int16ImmPtr>()->value();
  78. return std::move(int_v);
  79. case kNumberTypeInt32:
  80. MS_LOG(DEBUG) << "int32";
  81. int_v = value->cast<Int32ImmPtr>()->value();
  82. return std::move(int_v);
  83. case kNumberTypeInt64:
  84. MS_LOG(DEBUG) << "int64";
  85. int_v = value->cast<Int64ImmPtr>()->value();
  86. return std::move(int_v);
  87. case kNumberTypeFloat32:
  88. MS_LOG(DEBUG) << "float";
  89. float_v = value->cast<FP32ImmPtr>()->value();
  90. return std::move(float_v);
  91. case kNumberTypeFloat64:
  92. MS_LOG(DEBUG) << "double";
  93. float_v = value->cast<FP64ImmPtr>()->value();
  94. return std::move(float_v);
  95. case kNumberTypeBool:
  96. MS_LOG(DEBUG) << "bool";
  97. bool_v = value->cast<BoolImmPtr>()->value();
  98. return std::move(bool_v);
  99. default:
  100. MS_EXCEPTION(TypeError) << "Unsupported scalar converted to py data: " << value->ToString();
  101. }
  102. }
  103. using ConverterFunction = std::function<py::object(const ValuePtr &value)>;
  104. using ValueNameToConverterVector = std::vector<std::pair<uint32_t, ConverterFunction>>;
  105. // (Value Type Name) -> (Converter Function)
  106. // The converter function is used to convert Value object to Python data object.
  107. static ValueNameToConverterVector value_name_to_converter = {
  108. // Scalar
  109. {Scalar::kTypeId, [](const ValuePtr &value) -> py::object { return ScalarPtrToPyData(value->cast<ScalarPtr>()); }},
  110. // Tensor
  111. {tensor::Tensor::kTypeId,
  112. [](const ValuePtr &value) -> py::object {
  113. auto tensor_ptr = value->cast<tensor::TensorPtr>();
  114. return TensorToPyData(tensor_ptr);
  115. }},
  116. // MetaTenser
  117. {tensor::MetaTensor::kTypeId,
  118. [](const ValuePtr &value) -> py::object {
  119. py::tuple tuple_container(1);
  120. tuple_container[0] = value->cast<tensor::MetaTensorPtr>();
  121. return tuple_container[0];
  122. }},
  123. // RefKey
  124. {RefKey::kTypeId,
  125. [](const ValuePtr &value) -> py::object {
  126. py::tuple tuple_container(1);
  127. tuple_container[0] = value->cast<RefKeyPtr>();
  128. return tuple_container[0];
  129. }},
  130. // Type
  131. {Type::kTypeId,
  132. [](const ValuePtr &value) -> py::object {
  133. py::tuple tuple_container(1);
  134. tuple_container[0] = value->cast<TypePtr>();
  135. return tuple_container[0];
  136. }},
  137. // StringImm
  138. {StringImm::kTypeId,
  139. [](const ValuePtr &value) -> py::object {
  140. py::str res = value->cast<StringImmPtr>()->value();
  141. return res;
  142. }},
  143. // ValueSequeue
  144. {ValueSequeue::kTypeId,
  145. [](const ValuePtr &value) -> py::object {
  146. auto value_sequeue = value->cast<ValueSequeuePtr>()->value();
  147. py::tuple res_sequeue(value_sequeue.size());
  148. for (size_t i = 0; i < value_sequeue.size(); i++) {
  149. res_sequeue[i] = ValueToPyData(value_sequeue[i]);
  150. }
  151. if (value->isa<ValueTuple>()) {
  152. return res_sequeue;
  153. }
  154. return res_sequeue.cast<py::list>();
  155. }},
  156. // ValueDictionary
  157. {ValueDictionary::kTypeId,
  158. [](const ValuePtr &value) -> py::object {
  159. auto value_list = value->cast<ValueDictionaryPtr>()->value();
  160. py::dict res_dict;
  161. for (const auto &value : value_list) {
  162. res_dict[py::str(value.first)] = ValueToPyData(value.second);
  163. }
  164. return res_dict;
  165. }},
  166. // ValueSlice
  167. {ValueSlice::kTypeId,
  168. [](const ValuePtr &value) -> py::object {
  169. auto slice = value->cast<ValueSlicePtr>();
  170. auto start = ValueToPyData(slice->start());
  171. auto end = ValueToPyData(slice->stop());
  172. auto step = ValueToPyData(slice->step());
  173. return parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end,
  174. step);
  175. }},
  176. // KeywordArg
  177. {KeywordArg::kTypeId,
  178. [](const ValuePtr &value) -> py::object {
  179. auto abs_keyword_arg = value->ToAbstract()->cast<abstract::AbstractKeywordArgPtr>();
  180. auto key = abs_keyword_arg->get_key();
  181. auto val = abs_keyword_arg->get_arg()->BuildValue();
  182. auto py_value = ValueToPyData(val);
  183. auto kwargs = py::kwargs();
  184. kwargs[key.c_str()] = py_value;
  185. return kwargs;
  186. }},
  187. // parse::NameSpace
  188. {parse::NameSpace::kTypeId,
  189. [](const ValuePtr &value) -> py::object {
  190. auto ns = value->cast<parse::NameSpacePtr>();
  191. return ns->module_obj();
  192. }},
  193. // parse::ClassType
  194. {parse::ClassType::kTypeId,
  195. [](const ValuePtr &value) -> py::object {
  196. auto class_type = value->cast<parse::ClassTypePtr>();
  197. return class_type->obj();
  198. }},
  199. // parse::InterpretedObject
  200. {parse::InterpretedObject::kTypeId,
  201. [](const ValuePtr &value) -> py::object {
  202. auto interpreted_object = value->cast<parse::InterpretedObjectPtr>();
  203. return interpreted_object->obj();
  204. }},
  205. // None
  206. {None::kTypeId, [](const ValuePtr &value) -> py::object { return py::none(); }},
  207. // AnyValue
  208. {AnyValue::kTypeId, [](const ValuePtr &value) -> py::object { return py::none(); }},
  209. // FuncGraph
  210. {FuncGraph::kTypeId, [](const ValuePtr &value) -> py::object { return py::none(); }},
  211. // Monad
  212. {Monad::kTypeId, [](const ValuePtr &value) -> py::object { return py::none(); }},
  213. // Ellipsis
  214. {Ellipsis::kTypeId, [](const ValuePtr &value) -> py::object { return py::ellipsis(); }}};
  215. py::object ValueToPyData(const ValuePtr &value) {
  216. if (value == nullptr) {
  217. MS_LOG(EXCEPTION) << "The `value` should not be null";
  218. }
  219. for (auto &iter : value_name_to_converter) {
  220. if (value->IsFromTypeId(iter.first)) {
  221. return iter.second(value);
  222. }
  223. }
  224. MS_LOG(EXCEPTION) << "Unsupported to convert " << value->ToString() << "[" << value->type_name() << "] to a PyData";
  225. }
  226. py::object AnyToPyData(const Any &value) {
  227. py::object ret;
  228. MS_LOG(DEBUG) << "AnyToPyData " << value.GetString();
  229. if (value.is<int>() || value.is<float>() || value.is<double>() || value.is<bool>()) {
  230. ret = BuiltinsToPyData(value);
  231. } else if (value.is<ValuePtr>()) {
  232. MS_LOG(DEBUG) << "ValuePtr";
  233. ValuePtr v = value.cast<ValuePtr>();
  234. ret = ValueToPyData(v);
  235. } else if (value.is<tensor::TensorPtr>()) {
  236. MS_LOG(DEBUG) << "tensor";
  237. auto tensor_ptr = value.cast<tensor::TensorPtr>();
  238. ret = TensorToPyData(tensor_ptr);
  239. } else if (value.is<py::object>()) {
  240. MS_LOG(DEBUG) << "py obj";
  241. ret = value.cast<py::object>();
  242. } else if (value.is<std::vector<tensor::TensorPtr>>() || value.is<std::vector<Any>>()) {
  243. ret = VectorToPyData(value);
  244. } else if (value.is<std::list<Any>>()) {
  245. MS_LOG(DEBUG) << "list_any";
  246. auto value_list = value.cast<std::list<Any>>();
  247. py::list rets = py::list();
  248. for (auto &v : value_list) {
  249. rets.append(AnyToPyData(v));
  250. }
  251. ret = rets;
  252. } else if (value.is<std::vector<Any>>()) {
  253. auto value_list = value.cast<std::vector<Any>>();
  254. py::tuple rets(value_list.size());
  255. for (size_t i = 0; i < value_list.size(); i++) {
  256. rets[i] = AnyToPyData(value_list[i]);
  257. }
  258. ret = rets;
  259. } else if (value.is<TypePtr>()) {
  260. py::tuple v(1);
  261. v[0] = value.cast<TypePtr>();
  262. ret = v[0];
  263. } else {
  264. MS_LOG(EXCEPTION) << "value is not support type";
  265. }
  266. return ret;
  267. }
  268. py::object BaseRefToPyData(const BaseRef &value) {
  269. py::object ret;
  270. MS_LOG(DEBUG) << "BaseRefToPyData " << value.ToString();
  271. if (utils::isa<int>(value) || utils::isa<float>(value) || utils::isa<double>(value) || utils::isa<bool>(value)) {
  272. ret = BuiltinsToPyData(value);
  273. } else if (utils::isa<ValuePtr>(value)) {
  274. MS_LOG(DEBUG) << "ValuePtr";
  275. ValuePtr v = utils::cast<ValuePtr>(value);
  276. ret = ValueToPyData(v);
  277. } else if (utils::isa<tensor::TensorPtr>(value)) {
  278. MS_LOG(DEBUG) << "tensor";
  279. auto tensor_ptr = utils::cast<tensor::TensorPtr>(value);
  280. ret = TensorToPyData(tensor_ptr);
  281. } else if (utils::isa<PyObjectRef>(value)) {
  282. MS_LOG(DEBUG) << "py obj";
  283. PyObjectRef py_ref = utils::cast<PyObjectRef>(value);
  284. ret = py_ref.object_;
  285. } else if (utils::isa<VectorRef>(value)) {
  286. auto vec_ref = utils::cast<VectorRef>(value);
  287. ret = VectorRefToPyData(vec_ref);
  288. } else if (utils::isa<TypePtr>(value)) {
  289. py::tuple v(1);
  290. v[0] = utils::cast<TypePtr>(value);
  291. ret = v[0];
  292. } else {
  293. MS_LOG(EXCEPTION) << "value is not support type";
  294. }
  295. return ret;
  296. }
  297. py::object BuiltinsToPyData(const Any &value) {
  298. if (value.is<int>()) {
  299. MS_LOG(DEBUG) << "int";
  300. py::int_ ret = value.cast<int>();
  301. return std::move(ret);
  302. } else if (value.is<float>()) {
  303. MS_LOG(DEBUG) << "float";
  304. py::float_ ret = value.cast<float>();
  305. return std::move(ret);
  306. } else if (value.is<double>()) {
  307. MS_LOG(DEBUG) << "double";
  308. py::float_ ret = value.cast<double>();
  309. return std::move(ret);
  310. } else {
  311. MS_LOG(DEBUG) << "bool";
  312. py::bool_ ret = value.cast<bool>();
  313. return std::move(ret);
  314. }
  315. }
  316. py::object BuiltinsToPyData(const BaseRef &value) {
  317. if (utils::isa<int>(value)) {
  318. MS_LOG(DEBUG) << "int";
  319. py::int_ ret = utils::cast<int>(value);
  320. return std::move(ret);
  321. } else if (utils::isa<float>(value)) {
  322. MS_LOG(DEBUG) << "float";
  323. py::float_ ret = utils::cast<float>(value);
  324. return std::move(ret);
  325. } else if (utils::isa<double>(value)) {
  326. MS_LOG(DEBUG) << "double";
  327. py::float_ ret = utils::cast<double>(value);
  328. return std::move(ret);
  329. } else {
  330. MS_LOG(DEBUG) << "bool";
  331. py::bool_ ret = utils::cast<bool>(value);
  332. return std::move(ret);
  333. }
  334. }
  335. py::object VectorToPyData(const Any &value) {
  336. py::object ret;
  337. if (value.is<std::vector<tensor::TensorPtr>>()) {
  338. MS_LOG(DEBUG) << "vector_tensor";
  339. std::vector<tensor::TensorPtr> outputs;
  340. outputs = value.cast<std::vector<tensor::TensorPtr>>();
  341. py::tuple tensor_tuple(outputs.size());
  342. for (std::size_t i = 0; i < outputs.size(); ++i) {
  343. tensor_tuple[i] = *outputs[i];
  344. }
  345. ret = tensor_tuple;
  346. } else {
  347. MS_LOG(DEBUG) << "vector_any";
  348. auto value_list = value.cast<std::vector<Any>>();
  349. py::tuple any_tuple = py::tuple(value_list.size());
  350. size_t i = 0;
  351. for (auto &v : value_list) {
  352. any_tuple[i] = AnyToPyData(v);
  353. i++;
  354. }
  355. ret = any_tuple;
  356. }
  357. return ret;
  358. }
  359. py::object VectorRefToPyData(const VectorRef &value_list) {
  360. py::object ret;
  361. MS_LOG(DEBUG) << "vector_ref";
  362. size_t value_size = value_list.size();
  363. auto ref_tuple = py::tuple(value_size);
  364. for (size_t i = 0; i < value_size; i++) {
  365. ref_tuple[i] = BaseRefToPyData(value_list[i]);
  366. }
  367. ret = ref_tuple;
  368. return ret;
  369. }
  370. void SetValueRange(const AbstractBasePtr &tensor, const py::object &output) {
  371. if (output.is_none()) {
  372. return;
  373. }
  374. py::object obj_min =
  375. output.contains(py::str(ATTR_MIN_VALUE)) ? (py::object)output[ATTR_MIN_VALUE] : (py::object)py::none();
  376. py::object obj_max =
  377. output.contains(py::str(ATTR_MAX_VALUE)) ? (py::object)output[ATTR_MAX_VALUE] : (py::object)py::none();
  378. if (!obj_min.is_none() && !obj_max.is_none()) {
  379. bool converted = true;
  380. ValuePtr min_value = nullptr;
  381. ValuePtr max_value = nullptr;
  382. converted = parse::ConvertData(obj_min, &min_value);
  383. if (!converted) {
  384. MS_LOG(EXCEPTION) << "Convert shape min value data failed";
  385. }
  386. converted = parse::ConvertData(obj_max, &max_value);
  387. if (!converted) {
  388. MS_LOG(EXCEPTION) << "Convert shape max value data failed";
  389. }
  390. auto abs_tensor = dyn_cast<abstract::AbstractTensor>(tensor);
  391. abs_tensor->set_value_range(min_value, max_value);
  392. }
  393. }
  394. AbstractBasePtr MakePyInferRes2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
  395. const py::object &output) {
  396. auto ret_vec = shape_obj.cast<ShapeVector>();
  397. auto ret_dtype = type_obj.cast<TypePtr>();
  398. ShapeVector min_shape_vec;
  399. ShapeVector max_shape_vec;
  400. if (!output.is_none()) {
  401. py::object min_shape =
  402. output.contains(py::str(ATTR_MIN_SHAPE)) ? (py::object)output[ATTR_MIN_SHAPE] : (py::object)py::none();
  403. py::object max_shape =
  404. output.contains(py::str(ATTR_MAX_SHAPE)) ? (py::object)output[ATTR_MAX_SHAPE] : (py::object)py::none();
  405. if (!min_shape.is_none()) {
  406. min_shape_vec = min_shape.cast<ShapeVector>();
  407. }
  408. if (!max_shape.is_none()) {
  409. max_shape_vec = max_shape.cast<ShapeVector>();
  410. }
  411. }
  412. auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
  413. AbstractBasePtr tensor = MakeAbstractTensor(ret_shape, ret_dtype);
  414. SetValueRange(tensor, output);
  415. return tensor;
  416. }
  417. static bool IsMonadType(const py::object &type_obj) {
  418. if (py::isinstance<Type>(type_obj)) {
  419. auto type = type_obj.cast<Type *>();
  420. return type->isa<MonadType>();
  421. }
  422. return false;
  423. }
  424. static AbstractBasePtr ToMonadAbstract(const py::object &type_obj) {
  425. if (py::isinstance<Type>(type_obj)) {
  426. auto type = type_obj.cast<Type *>();
  427. if (!type->isa<MonadType>()) {
  428. MS_LOG(EXCEPTION) << "Not a monad type object: " << py::str(type_obj);
  429. }
  430. return abstract::MakeMonadAbstract(type->cast<MonadTypePtr>());
  431. }
  432. MS_LOG(EXCEPTION) << "Not a type object: " << py::str(type_obj);
  433. }
  434. AbstractBasePtr MakePyInferRes2Abstract(const py::object &shape_obj, const py::object &type_obj,
  435. const py::object &output) {
  436. if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) && py::isinstance<Type>(type_obj)) {
  437. auto ret_vec = shape_obj.cast<ShapeVector>();
  438. auto ret_dtype = type_obj.cast<TypePtr>();
  439. MS_EXCEPTION_IF_NULL(ret_dtype);
  440. // if the size of shape list is empty, return an scalar abstract
  441. if (ret_vec.empty() && (!ret_dtype->isa<TensorType>())) {
  442. abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
  443. return abs_scalar;
  444. }
  445. return MakePyInferRes2AbstractTensor(shape_obj, type_obj, output);
  446. } else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
  447. auto shape_tuple = shape_obj.cast<py::tuple>();
  448. auto typeid_tuple = type_obj.cast<py::tuple>();
  449. AbstractBasePtrList ptr_list;
  450. for (size_t it = 0; it < shape_tuple.size(); ++it) {
  451. auto tensor_it = MakePyInferRes2Abstract(shape_tuple[it], typeid_tuple[it]);
  452. ptr_list.push_back(tensor_it);
  453. }
  454. auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
  455. return tuple;
  456. } else if (py::isinstance<py::list>(shape_obj) && py::isinstance<py::list>(type_obj)) {
  457. auto shape_list = shape_obj.cast<py::list>();
  458. auto typeid_list = type_obj.cast<py::list>();
  459. AbstractBasePtrList ptr_list;
  460. for (size_t it = 0; it < shape_list.size(); ++it) {
  461. auto tensor_it = MakePyInferRes2Abstract(shape_list[it], typeid_list[it]);
  462. ptr_list.push_back(tensor_it);
  463. }
  464. auto list = std::make_shared<abstract::AbstractList>(ptr_list);
  465. return list;
  466. } else if (shape_obj.is_none() && type_obj.is_none()) {
  467. // AbstractNone indicates there is no output for this CNode node.
  468. auto abstract_none = std::make_shared<abstract::AbstractNone>();
  469. return abstract_none;
  470. } else if (IsMonadType(type_obj)) {
  471. // Return monad abstract if it is monad type.
  472. return ToMonadAbstract(type_obj);
  473. } else {
  474. // When sparse enabled, the undetermined might be raised and eliminated in opt passes
  475. auto context = MsContext::GetInstance();
  476. MS_EXCEPTION_IF_NULL(context);
  477. bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
  478. if (enable_sparse) {
  479. return std::make_shared<abstract::AbstractUndetermined>();
  480. }
  481. MS_LOG(EXCEPTION) << "Python evaluator return invalid shape or type. " << (std::string)py::str(type_obj);
  482. }
  483. }
  484. bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
  485. const std::shared_ptr<py::object> &ret_val) {
  486. if (output->isa<ValueNode>()) {
  487. MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
  488. ValuePtr value = GetValueNode(output);
  489. *ret_val = ValueToPyData(value);
  490. return true;
  491. }
  492. // Adapter will transform values in __init__() and construct() to parameters, this could cause
  493. // inputs (a.k.a args in current function) size less than parameters'.
  494. if (output->isa<Parameter>()) {
  495. MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute.";
  496. // Find the right parameter as ret_val.
  497. auto func_graph = output->func_graph();
  498. MS_EXCEPTION_IF_NULL(func_graph);
  499. auto params = func_graph->parameters();
  500. if ((args.size() + func_graph->hyper_param_count()) != params.size()) {
  501. MS_LOG(EXCEPTION) << "Input size " << args.size() << " add Parameter count " << func_graph->hyper_param_count()
  502. << " not equal to graph input size " << params.size() << ", let graph to be executed.";
  503. }
  504. auto it = std::find(params.begin(), params.end(), output);
  505. if (it == params.end()) {
  506. MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters";
  507. }
  508. size_t index = it - params.cbegin();
  509. if (index >= args.size() + func_graph->hyper_param_count()) {
  510. MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size()
  511. << " add Parameter count " << func_graph->hyper_param_count() << ".";
  512. }
  513. if (index < args.size()) {
  514. *ret_val = args[index];
  515. } else {
  516. auto param = dyn_cast<Parameter>(params[index]);
  517. MS_EXCEPTION_IF_NULL(param);
  518. if (!param->has_default()) {
  519. MS_LOG(EXCEPTION) << "Can not determine value of Parameter " << index << " (" << param->name() << ")";
  520. }
  521. auto tensor = param->default_param();
  522. *ret_val = py::cast(tensor);
  523. }
  524. return true;
  525. }
  526. return false;
  527. }
  528. } // namespace mindspore