You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

convert_utils_py.cc 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/convert_utils_py.h"
  17. #include <vector>
  18. #include <string>
  19. #include <memory>
  20. #include <algorithm>
  21. #include <list>
  22. #include <utility>
  23. #include <cfloat>
  24. #include "abstract/abstract_value.h"
  25. #include "abstract/utils.h"
  26. #include "pipeline/jit/parse/parse.h"
  27. #include "pipeline/jit/parse/parse_base.h"
  28. #include "pipeline/jit/parse/resolve.h"
  29. #include "ir/value.h"
  30. #include "ir/tensor.h"
  31. #include "ir/param_info.h"
  32. #include "pybind_api/ir/base_ref_py.h"
  33. #include "utils/ms_context.h"
  34. namespace mindspore {
  35. py::object BuiltinsToPyData(const Any &value);
  36. py::object BuiltinsToPyData(const BaseRef &value);
  37. py::object VectorToPyData(const Any &value);
  38. py::object VectorRefToPyData(const VectorRef &value);
  39. py::object TensorToPyData(const tensor::TensorPtr &tensor) {
  40. MS_EXCEPTION_IF_NULL(tensor);
  41. if (tensor->NeedWait()) {
  42. py::gil_scoped_release release;
  43. tensor->Wait();
  44. }
  45. py::tuple v(1);
  46. v[0] = tensor;
  47. return v[0];
  48. }
  49. py::object ScalarPtrToPyData(const ScalarPtr &value) {
  50. py::int_ int_v;
  51. py::float_ float_v;
  52. py::bool_ bool_v;
  53. TypeId scalar_type = value->type()->type_id();
  54. switch (scalar_type) {
  55. case kNumberTypeUInt8:
  56. MS_LOG(DEBUG) << "uint8";
  57. int_v = value->cast<UInt8ImmPtr>()->value();
  58. return std::move(int_v);
  59. case kNumberTypeUInt16:
  60. MS_LOG(DEBUG) << "uint16";
  61. int_v = value->cast<UInt16ImmPtr>()->value();
  62. return std::move(int_v);
  63. case kNumberTypeUInt32:
  64. MS_LOG(DEBUG) << "uint32";
  65. int_v = value->cast<UInt32ImmPtr>()->value();
  66. return std::move(int_v);
  67. case kNumberTypeUInt64:
  68. MS_LOG(DEBUG) << "uint64";
  69. int_v = value->cast<UInt64ImmPtr>()->value();
  70. return std::move(int_v);
  71. case kNumberTypeInt8:
  72. MS_LOG(DEBUG) << "int8";
  73. int_v = value->cast<Int8ImmPtr>()->value();
  74. return std::move(int_v);
  75. case kNumberTypeInt16:
  76. MS_LOG(DEBUG) << "int16";
  77. int_v = value->cast<Int16ImmPtr>()->value();
  78. return std::move(int_v);
  79. case kNumberTypeInt32:
  80. MS_LOG(DEBUG) << "int32";
  81. int_v = value->cast<Int32ImmPtr>()->value();
  82. return std::move(int_v);
  83. case kNumberTypeInt64:
  84. MS_LOG(DEBUG) << "int64";
  85. int_v = value->cast<Int64ImmPtr>()->value();
  86. return std::move(int_v);
  87. case kNumberTypeFloat32:
  88. MS_LOG(DEBUG) << "float";
  89. float_v = value->cast<FP32ImmPtr>()->value();
  90. return std::move(float_v);
  91. case kNumberTypeFloat64:
  92. MS_LOG(DEBUG) << "double";
  93. float_v = value->cast<FP64ImmPtr>()->value();
  94. return std::move(float_v);
  95. case kNumberTypeBool:
  96. MS_LOG(DEBUG) << "bool";
  97. bool_v = value->cast<BoolImmPtr>()->value();
  98. return std::move(bool_v);
  99. default:
  100. MS_EXCEPTION(TypeError) << "Unsupported scalar converted to py data: " << value->ToString();
  101. }
  102. }
  103. using ConverterFunction = std::function<py::object(const ValuePtr &value)>;
  104. using ValueNameToConverterVector = std::vector<std::pair<const char *, ConverterFunction>>;
  105. // (Value Type Name) -> (Converter Function)
  106. // The converter function is used to convert Value object to Python data object.
  107. static ValueNameToConverterVector value_name_to_converter = {
  108. // Scalar
  109. {typeid(Scalar).name(),
  110. [](const ValuePtr &value) -> py::object { return ScalarPtrToPyData(value->cast<ScalarPtr>()); }},
  111. // Tensor
  112. {typeid(tensor::Tensor).name(),
  113. [](const ValuePtr &value) -> py::object {
  114. auto tensor_ptr = value->cast<tensor::TensorPtr>();
  115. return TensorToPyData(tensor_ptr);
  116. }},
  117. // MetaTenser
  118. {typeid(tensor::MetaTensor).name(),
  119. [](const ValuePtr &value) -> py::object {
  120. py::tuple tuple_container(1);
  121. tuple_container[0] = value->cast<tensor::MetaTensorPtr>();
  122. return tuple_container[0];
  123. }},
  124. // RefKey
  125. {typeid(RefKey).name(),
  126. [](const ValuePtr &value) -> py::object {
  127. py::tuple tuple_container(1);
  128. tuple_container[0] = value->cast<RefKeyPtr>();
  129. return tuple_container[0];
  130. }},
  131. // Type
  132. {typeid(Type).name(),
  133. [](const ValuePtr &value) -> py::object {
  134. py::tuple tuple_container(1);
  135. tuple_container[0] = value->cast<TypePtr>();
  136. return tuple_container[0];
  137. }},
  138. // StringImm
  139. {typeid(StringImm).name(),
  140. [](const ValuePtr &value) -> py::object {
  141. py::str res = value->cast<StringImmPtr>()->value();
  142. return res;
  143. }},
  144. // ValueSequeue
  145. {typeid(ValueSequeue).name(),
  146. [](const ValuePtr &value) -> py::object {
  147. auto value_sequeue = value->cast<ValueSequeuePtr>()->value();
  148. py::tuple res_sequeue(value_sequeue.size());
  149. for (size_t i = 0; i < value_sequeue.size(); i++) {
  150. res_sequeue[i] = ValueToPyData(value_sequeue[i]);
  151. }
  152. if (value->isa<ValueTuple>()) {
  153. return res_sequeue;
  154. }
  155. return res_sequeue.cast<py::list>();
  156. }},
  157. // ValueDictionary
  158. {typeid(ValueDictionary).name(),
  159. [](const ValuePtr &value) -> py::object {
  160. auto value_list = value->cast<ValueDictionaryPtr>()->value();
  161. py::dict res_dict;
  162. for (const auto &value : value_list) {
  163. res_dict[py::str(value.first)] = ValueToPyData(value.second);
  164. }
  165. return res_dict;
  166. }},
  167. // ValueSlice
  168. {typeid(ValueSlice).name(),
  169. [](const ValuePtr &value) -> py::object {
  170. auto slice = value->cast<ValueSlicePtr>();
  171. auto start = ValueToPyData(slice->start());
  172. auto end = ValueToPyData(slice->stop());
  173. auto step = ValueToPyData(slice->step());
  174. return parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end,
  175. step);
  176. }},
  177. // KeywordArg
  178. {typeid(KeywordArg).name(),
  179. [](const ValuePtr &value) -> py::object {
  180. auto abs_keyword_arg = value->ToAbstract()->cast<abstract::AbstractKeywordArgPtr>();
  181. auto key = abs_keyword_arg->get_key();
  182. auto val = abs_keyword_arg->get_arg()->BuildValue();
  183. auto py_value = ValueToPyData(val);
  184. auto kwargs = py::kwargs();
  185. kwargs[key.c_str()] = py_value;
  186. return kwargs;
  187. }},
  188. // parse::NameSpace
  189. {typeid(parse::NameSpace).name(),
  190. [](const ValuePtr &value) -> py::object {
  191. auto ns = value->cast<parse::NameSpacePtr>();
  192. return ns->module_obj();
  193. }},
  194. // parse::ClassType
  195. {typeid(parse::ClassType).name(),
  196. [](const ValuePtr &value) -> py::object {
  197. auto class_type = value->cast<parse::ClassTypePtr>();
  198. return class_type->obj();
  199. }},
  200. // None
  201. {typeid(None).name(), [](const ValuePtr &value) -> py::object { return py::none(); }},
  202. // AnyValue
  203. {typeid(AnyValue).name(), [](const ValuePtr &value) -> py::object { return py::none(); }},
  204. // FuncGraph
  205. {typeid(FuncGraph).name(), [](const ValuePtr &value) -> py::object { return py::none(); }},
  206. // Monad
  207. {typeid(Monad).name(), [](const ValuePtr &value) -> py::object { return py::none(); }},
  208. // Ellipsis
  209. {typeid(Ellipsis).name(), [](const ValuePtr &value) -> py::object { return py::ellipsis(); }}};
  210. py::object ValueToPyData(const ValuePtr &value) {
  211. if (value == nullptr) {
  212. MS_LOG(EXCEPTION) << "The `value` should not be null";
  213. }
  214. for (auto &iter : value_name_to_converter) {
  215. if (value->IsFromTypeId(Base::GetTypeId(iter.first))) {
  216. return iter.second(value);
  217. }
  218. }
  219. MS_LOG(EXCEPTION) << "Unsupported to convert " << value->ToString() << "[" << value->type_name() << "] to a PyData";
  220. }
  221. py::object AnyToPyData(const Any &value) {
  222. py::object ret;
  223. MS_LOG(DEBUG) << "AnyToPyData " << value.GetString();
  224. if (value.is<int>() || value.is<float>() || value.is<double>() || value.is<bool>()) {
  225. ret = BuiltinsToPyData(value);
  226. } else if (value.is<ValuePtr>()) {
  227. MS_LOG(DEBUG) << "ValuePtr";
  228. ValuePtr v = value.cast<ValuePtr>();
  229. ret = ValueToPyData(v);
  230. } else if (value.is<tensor::TensorPtr>()) {
  231. MS_LOG(DEBUG) << "tensor";
  232. auto tensor_ptr = value.cast<tensor::TensorPtr>();
  233. ret = TensorToPyData(tensor_ptr);
  234. } else if (value.is<py::object>()) {
  235. MS_LOG(DEBUG) << "py obj";
  236. ret = value.cast<py::object>();
  237. } else if (value.is<std::vector<tensor::TensorPtr>>() || value.is<std::vector<Any>>()) {
  238. ret = VectorToPyData(value);
  239. } else if (value.is<std::list<Any>>()) {
  240. MS_LOG(DEBUG) << "list_any";
  241. auto value_list = value.cast<std::list<Any>>();
  242. py::list rets = py::list();
  243. for (auto &v : value_list) {
  244. rets.append(AnyToPyData(v));
  245. }
  246. ret = rets;
  247. } else if (value.is<std::vector<Any>>()) {
  248. auto value_list = value.cast<std::vector<Any>>();
  249. py::tuple rets(value_list.size());
  250. for (size_t i = 0; i < value_list.size(); i++) {
  251. rets[i] = AnyToPyData(value_list[i]);
  252. }
  253. ret = rets;
  254. } else if (value.is<TypePtr>()) {
  255. py::tuple v(1);
  256. v[0] = value.cast<TypePtr>();
  257. ret = v[0];
  258. } else {
  259. MS_LOG(EXCEPTION) << "value is not support type";
  260. }
  261. return ret;
  262. }
  263. py::object BaseRefToPyData(const BaseRef &value) {
  264. py::object ret;
  265. MS_LOG(DEBUG) << "BaseRefToPyData " << value.ToString();
  266. if (utils::isa<int>(value) || utils::isa<float>(value) || utils::isa<double>(value) || utils::isa<bool>(value)) {
  267. ret = BuiltinsToPyData(value);
  268. } else if (utils::isa<ValuePtr>(value)) {
  269. MS_LOG(DEBUG) << "ValuePtr";
  270. ValuePtr v = utils::cast<ValuePtr>(value);
  271. ret = ValueToPyData(v);
  272. } else if (utils::isa<tensor::TensorPtr>(value)) {
  273. MS_LOG(DEBUG) << "tensor";
  274. auto tensor_ptr = utils::cast<tensor::TensorPtr>(value);
  275. ret = TensorToPyData(tensor_ptr);
  276. } else if (utils::isa<PyObjectRef>(value)) {
  277. MS_LOG(DEBUG) << "py obj";
  278. PyObjectRef py_ref = utils::cast<PyObjectRef>(value);
  279. ret = py_ref.object_;
  280. } else if (utils::isa<VectorRef>(value)) {
  281. auto vec_ref = utils::cast<VectorRef>(value);
  282. ret = VectorRefToPyData(vec_ref);
  283. } else if (utils::isa<TypePtr>(value)) {
  284. py::tuple v(1);
  285. v[0] = utils::cast<TypePtr>(value);
  286. ret = v[0];
  287. } else {
  288. MS_LOG(EXCEPTION) << "value is not support type";
  289. }
  290. return ret;
  291. }
  292. py::object BuiltinsToPyData(const Any &value) {
  293. if (value.is<int>()) {
  294. MS_LOG(DEBUG) << "int";
  295. py::int_ ret = value.cast<int>();
  296. return std::move(ret);
  297. } else if (value.is<float>()) {
  298. MS_LOG(DEBUG) << "float";
  299. py::float_ ret = value.cast<float>();
  300. return std::move(ret);
  301. } else if (value.is<double>()) {
  302. MS_LOG(DEBUG) << "double";
  303. py::float_ ret = value.cast<double>();
  304. return std::move(ret);
  305. } else {
  306. MS_LOG(DEBUG) << "bool";
  307. py::bool_ ret = value.cast<bool>();
  308. return std::move(ret);
  309. }
  310. }
  311. py::object BuiltinsToPyData(const BaseRef &value) {
  312. if (utils::isa<int>(value)) {
  313. MS_LOG(DEBUG) << "int";
  314. py::int_ ret = utils::cast<int>(value);
  315. return std::move(ret);
  316. } else if (utils::isa<float>(value)) {
  317. MS_LOG(DEBUG) << "float";
  318. py::float_ ret = utils::cast<float>(value);
  319. return std::move(ret);
  320. } else if (utils::isa<double>(value)) {
  321. MS_LOG(DEBUG) << "double";
  322. py::float_ ret = utils::cast<double>(value);
  323. return std::move(ret);
  324. } else {
  325. MS_LOG(DEBUG) << "bool";
  326. py::bool_ ret = utils::cast<bool>(value);
  327. return std::move(ret);
  328. }
  329. }
  330. py::object VectorToPyData(const Any &value) {
  331. py::object ret;
  332. if (value.is<std::vector<tensor::TensorPtr>>()) {
  333. MS_LOG(DEBUG) << "vector_tensor";
  334. std::vector<tensor::TensorPtr> outputs;
  335. outputs = value.cast<std::vector<tensor::TensorPtr>>();
  336. py::tuple tensor_tuple(outputs.size());
  337. for (std::size_t i = 0; i < outputs.size(); ++i) {
  338. tensor_tuple[i] = *outputs[i];
  339. }
  340. ret = tensor_tuple;
  341. } else {
  342. MS_LOG(DEBUG) << "vector_any";
  343. auto value_list = value.cast<std::vector<Any>>();
  344. py::tuple any_tuple = py::tuple(value_list.size());
  345. size_t i = 0;
  346. for (auto &v : value_list) {
  347. any_tuple[i] = AnyToPyData(v);
  348. i++;
  349. }
  350. ret = any_tuple;
  351. }
  352. return ret;
  353. }
  354. py::object VectorRefToPyData(const VectorRef &value_list) {
  355. py::object ret;
  356. MS_LOG(DEBUG) << "vector_ref";
  357. size_t value_size = value_list.size();
  358. auto ref_tuple = py::tuple(value_size);
  359. for (size_t i = 0; i < value_size; i++) {
  360. ref_tuple[i] = BaseRefToPyData(value_list[i]);
  361. }
  362. ret = ref_tuple;
  363. return ret;
  364. }
  365. void SetValueRange(const AbstractBasePtr &tensor, const py::object &output) {
  366. if (output.is_none()) {
  367. return;
  368. }
  369. py::object obj_min =
  370. output.contains(py::str(ATTR_MIN_VALUE)) ? (py::object)output[ATTR_MIN_VALUE] : (py::object)py::none();
  371. py::object obj_max =
  372. output.contains(py::str(ATTR_MAX_VALUE)) ? (py::object)output[ATTR_MAX_VALUE] : (py::object)py::none();
  373. if (!obj_min.is_none() && !obj_max.is_none()) {
  374. bool converted = true;
  375. ValuePtr min_value = nullptr;
  376. ValuePtr max_value = nullptr;
  377. converted = parse::ConvertData(obj_min, &min_value);
  378. if (!converted) {
  379. MS_LOG(EXCEPTION) << "Convert shape min value data failed";
  380. }
  381. converted = parse::ConvertData(obj_max, &max_value);
  382. if (!converted) {
  383. MS_LOG(EXCEPTION) << "Convert shape max value data failed";
  384. }
  385. auto abs_tensor = dyn_cast<abstract::AbstractTensor>(tensor);
  386. abs_tensor->set_value_range(min_value, max_value);
  387. }
  388. }
  389. AbstractBasePtr MakePyInferRes2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
  390. const py::object &output) {
  391. auto ret_vec = shape_obj.cast<ShapeVector>();
  392. auto ret_dtype = type_obj.cast<TypePtr>();
  393. ShapeVector min_shape_vec;
  394. ShapeVector max_shape_vec;
  395. if (!output.is_none()) {
  396. py::object min_shape =
  397. output.contains(py::str(ATTR_MIN_SHAPE)) ? (py::object)output[ATTR_MIN_SHAPE] : (py::object)py::none();
  398. py::object max_shape =
  399. output.contains(py::str(ATTR_MAX_SHAPE)) ? (py::object)output[ATTR_MAX_SHAPE] : (py::object)py::none();
  400. if (!min_shape.is_none()) {
  401. min_shape_vec = min_shape.cast<ShapeVector>();
  402. }
  403. if (!max_shape.is_none()) {
  404. max_shape_vec = max_shape.cast<ShapeVector>();
  405. }
  406. }
  407. auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
  408. AbstractBasePtr tensor = MakeAbstractTensor(ret_shape, ret_dtype);
  409. SetValueRange(tensor, output);
  410. return tensor;
  411. }
  412. static bool IsMonadType(const py::object &type_obj) {
  413. if (py::isinstance<Type>(type_obj)) {
  414. auto type = type_obj.cast<Type *>();
  415. return type->isa<MonadType>();
  416. }
  417. return false;
  418. }
  419. static AbstractBasePtr ToMonadAbstract(const py::object &type_obj) {
  420. if (py::isinstance<Type>(type_obj)) {
  421. auto type = type_obj.cast<Type *>();
  422. if (!type->isa<MonadType>()) {
  423. MS_LOG(EXCEPTION) << "Not a monad type object: " << py::str(type_obj);
  424. }
  425. return abstract::MakeMonadAbstract(type->cast<MonadTypePtr>());
  426. }
  427. MS_LOG(EXCEPTION) << "Not a type object: " << py::str(type_obj);
  428. }
  429. AbstractBasePtr MakePyInferRes2Abstract(const py::object &shape_obj, const py::object &type_obj,
  430. const py::object &output) {
  431. if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) && py::isinstance<Type>(type_obj)) {
  432. auto ret_vec = shape_obj.cast<ShapeVector>();
  433. auto ret_dtype = type_obj.cast<TypePtr>();
  434. MS_EXCEPTION_IF_NULL(ret_dtype);
  435. // if the size of shape list is empty, return an scalar abstract
  436. if (ret_vec.empty() && (!ret_dtype->isa<TensorType>())) {
  437. abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
  438. return abs_scalar;
  439. }
  440. return MakePyInferRes2AbstractTensor(shape_obj, type_obj, output);
  441. } else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
  442. auto shape_tuple = shape_obj.cast<py::tuple>();
  443. auto typeid_tuple = type_obj.cast<py::tuple>();
  444. AbstractBasePtrList ptr_list;
  445. for (size_t it = 0; it < shape_tuple.size(); ++it) {
  446. auto tensor_it = MakePyInferRes2Abstract(shape_tuple[it], typeid_tuple[it]);
  447. ptr_list.push_back(tensor_it);
  448. }
  449. auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
  450. return tuple;
  451. } else if (py::isinstance<py::list>(shape_obj) && py::isinstance<py::list>(type_obj)) {
  452. auto shape_list = shape_obj.cast<py::list>();
  453. auto typeid_list = type_obj.cast<py::list>();
  454. AbstractBasePtrList ptr_list;
  455. for (size_t it = 0; it < shape_list.size(); ++it) {
  456. auto tensor_it = MakePyInferRes2Abstract(shape_list[it], typeid_list[it]);
  457. ptr_list.push_back(tensor_it);
  458. }
  459. auto list = std::make_shared<abstract::AbstractList>(ptr_list);
  460. return list;
  461. } else if (shape_obj.is_none() && type_obj.is_none()) {
  462. // AbstractNone indicates there is no output for this CNode node.
  463. auto abstract_none = std::make_shared<abstract::AbstractNone>();
  464. return abstract_none;
  465. } else if (IsMonadType(type_obj)) {
  466. // Return monad abstract if it is monad type.
  467. return ToMonadAbstract(type_obj);
  468. } else {
  469. // When sparse enabled, the undetermined might be raised and eliminated in opt passes
  470. auto context = MsContext::GetInstance();
  471. MS_EXCEPTION_IF_NULL(context);
  472. bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
  473. if (enable_sparse) {
  474. return std::make_shared<abstract::AbstractUndetermined>();
  475. }
  476. MS_LOG(EXCEPTION) << "Python evaluator return invalid shape or type. " << (std::string)py::str(type_obj);
  477. }
  478. }
  479. bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
  480. const std::shared_ptr<py::object> &ret_val) {
  481. if (output->isa<ValueNode>()) {
  482. MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
  483. ValuePtr value = GetValueNode(output);
  484. *ret_val = ValueToPyData(value);
  485. return true;
  486. }
  487. // Adapter will transform values in __init__() and construct() to parameters, this could cause
  488. // inputs (a.k.a args in current function) size less than parameters'.
  489. if (output->isa<Parameter>()) {
  490. MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute.";
  491. // Find the right parameter as ret_val.
  492. auto func_graph = output->func_graph();
  493. MS_EXCEPTION_IF_NULL(func_graph);
  494. auto params = func_graph->parameters();
  495. if ((args.size() + func_graph->hyper_param_count()) != params.size()) {
  496. MS_LOG(EXCEPTION) << "Input size " << args.size() << " add Parameter count " << func_graph->hyper_param_count()
  497. << " not equal to graph input size " << params.size() << ", let graph to be executed.";
  498. }
  499. auto it = std::find(params.begin(), params.end(), output);
  500. if (it == params.end()) {
  501. MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters";
  502. }
  503. size_t index = it - params.cbegin();
  504. if (index >= args.size() + func_graph->hyper_param_count()) {
  505. MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size()
  506. << " add Parameter count " << func_graph->hyper_param_count() << ".";
  507. }
  508. if (index < args.size()) {
  509. *ret_val = args[index];
  510. } else {
  511. auto param = dyn_cast<Parameter>(params[index]);
  512. MS_EXCEPTION_IF_NULL(param);
  513. if (!param->has_default()) {
  514. MS_LOG(EXCEPTION) << "Can not determine value of Parameter " << index << " (" << param->name() << ")";
  515. }
  516. auto tensor = param->default_param();
  517. *ret_val = py::cast(tensor);
  518. }
  519. return true;
  520. }
  521. return false;
  522. }
  523. } // namespace mindspore