You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

convert_utils_py.cc 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/convert_utils_py.h"
  17. #include <vector>
  18. #include <string>
  19. #include <memory>
  20. #include <algorithm>
  21. #include <list>
  22. #include <utility>
  23. #include <cfloat>
  24. #include "abstract/abstract_value.h"
  25. #include "abstract/utils.h"
  26. #include "pipeline/jit/parse/parse.h"
  27. #include "pipeline/jit/parse/parse_base.h"
  28. #include "pipeline/jit/parse/resolve.h"
  29. #include "ir/value.h"
  30. #include "ir/tensor.h"
  31. #include "ir/param_info.h"
  32. #include "pybind_api/ir/base_ref_py.h"
  33. #include "utils/ms_context.h"
  34. namespace mindspore {
  35. py::object BuiltinsToPyData(const Any &value);
  36. py::object BuiltinsToPyData(const BaseRef &value);
  37. py::object VectorToPyData(const Any &value);
  38. py::object VectorRefToPyData(const VectorRef &value);
  39. py::object TensorToPyData(const tensor::TensorPtr &tensor) {
  40. MS_EXCEPTION_IF_NULL(tensor);
  41. if (tensor->NeedWait()) {
  42. py::gil_scoped_release release;
  43. tensor->Wait();
  44. }
  45. py::tuple v(1);
  46. v[0] = tensor;
  47. return v[0];
  48. }
  49. py::object ScalarPtrToPyData(const ScalarPtr &value) {
  50. py::int_ int_v;
  51. py::float_ float_v;
  52. py::bool_ bool_v;
  53. TypeId scalar_type = value->type()->type_id();
  54. switch (scalar_type) {
  55. case kNumberTypeUInt8:
  56. MS_LOG(DEBUG) << "uint8";
  57. int_v = value->cast<UInt8ImmPtr>()->value();
  58. return std::move(int_v);
  59. case kNumberTypeUInt16:
  60. MS_LOG(DEBUG) << "uint16";
  61. int_v = value->cast<UInt16ImmPtr>()->value();
  62. return std::move(int_v);
  63. case kNumberTypeUInt32:
  64. MS_LOG(DEBUG) << "uint32";
  65. int_v = value->cast<UInt32ImmPtr>()->value();
  66. return std::move(int_v);
  67. case kNumberTypeUInt64:
  68. MS_LOG(DEBUG) << "uint64";
  69. int_v = value->cast<UInt64ImmPtr>()->value();
  70. return std::move(int_v);
  71. case kNumberTypeInt8:
  72. MS_LOG(DEBUG) << "int8";
  73. int_v = value->cast<Int8ImmPtr>()->value();
  74. return std::move(int_v);
  75. case kNumberTypeInt16:
  76. MS_LOG(DEBUG) << "int16";
  77. int_v = value->cast<Int16ImmPtr>()->value();
  78. return std::move(int_v);
  79. case kNumberTypeInt32:
  80. MS_LOG(DEBUG) << "int32";
  81. int_v = value->cast<Int32ImmPtr>()->value();
  82. return std::move(int_v);
  83. case kNumberTypeInt64:
  84. MS_LOG(DEBUG) << "int64";
  85. int_v = value->cast<Int64ImmPtr>()->value();
  86. return std::move(int_v);
  87. case kNumberTypeFloat32:
  88. MS_LOG(DEBUG) << "float";
  89. float_v = value->cast<FP32ImmPtr>()->value();
  90. return std::move(float_v);
  91. case kNumberTypeFloat64:
  92. MS_LOG(DEBUG) << "double";
  93. float_v = value->cast<FP64ImmPtr>()->value();
  94. return std::move(float_v);
  95. case kNumberTypeBool:
  96. MS_LOG(DEBUG) << "bool";
  97. bool_v = value->cast<BoolImmPtr>()->value();
  98. return std::move(bool_v);
  99. default:
  100. MS_EXCEPTION(TypeError) << "Unsupported scalar converted to py data: " << value->ToString();
  101. }
  102. }
  103. using ConverterFunction = std::function<py::object(const ValuePtr &value)>;
  104. using ValueNameToConverterVector = std::vector<std::pair<const char *, ConverterFunction>>;
  105. // (Value Type Name) -> (Converter Function)
  106. // The converter function is used to convert Value object to Python data object.
  107. static ValueNameToConverterVector value_name_to_converter = {
  108. // Scalar
  109. {typeid(Scalar).name(),
  110. [](const ValuePtr &value) -> py::object { return ScalarPtrToPyData(value->cast<ScalarPtr>()); }},
  111. // Tensor
  112. {typeid(tensor::Tensor).name(),
  113. [](const ValuePtr &value) -> py::object {
  114. auto tensor_ptr = value->cast<tensor::TensorPtr>();
  115. return TensorToPyData(tensor_ptr);
  116. }},
  117. // MetaTenser
  118. {typeid(tensor::MetaTensor).name(),
  119. [](const ValuePtr &value) -> py::object {
  120. py::tuple tuple_container(1);
  121. tuple_container[0] = value->cast<tensor::MetaTensorPtr>();
  122. return tuple_container[0];
  123. }},
  124. // RefKey
  125. {typeid(RefKey).name(),
  126. [](const ValuePtr &value) -> py::object {
  127. py::tuple tuple_container(1);
  128. tuple_container[0] = value->cast<RefKeyPtr>();
  129. return tuple_container[0];
  130. }},
  131. // Type
  132. {typeid(Type).name(),
  133. [](const ValuePtr &value) -> py::object {
  134. py::tuple tuple_container(1);
  135. tuple_container[0] = value->cast<TypePtr>();
  136. return tuple_container[0];
  137. }},
  138. // StringImm
  139. {typeid(StringImm).name(),
  140. [](const ValuePtr &value) -> py::object {
  141. py::str res = value->cast<StringImmPtr>()->value();
  142. return res;
  143. }},
  144. // ValueSequeue
  145. {typeid(ValueSequeue).name(),
  146. [](const ValuePtr &value) -> py::object {
  147. auto value_sequeue = value->cast<ValueSequeuePtr>()->value();
  148. py::tuple res_sequeue(value_sequeue.size());
  149. for (size_t i = 0; i < value_sequeue.size(); i++) {
  150. res_sequeue[i] = ValueToPyData(value_sequeue[i]);
  151. }
  152. if (value->isa<ValueTuple>()) {
  153. return res_sequeue;
  154. }
  155. return res_sequeue.cast<py::list>();
  156. }},
  157. // ValueDictionary
  158. {typeid(ValueDictionary).name(),
  159. [](const ValuePtr &value) -> py::object {
  160. auto value_list = value->cast<ValueDictionaryPtr>()->value();
  161. py::dict res_dict;
  162. for (const auto &value : value_list) {
  163. res_dict[py::str(value.first)] = ValueToPyData(value.second);
  164. }
  165. return res_dict;
  166. }},
  167. // ValueSlice
  168. {typeid(ValueSlice).name(),
  169. [](const ValuePtr &value) -> py::object {
  170. auto slice = value->cast<ValueSlicePtr>();
  171. auto start = ValueToPyData(slice->start());
  172. auto end = ValueToPyData(slice->stop());
  173. auto step = ValueToPyData(slice->step());
  174. return parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end,
  175. step);
  176. }},
  177. // KeywordArg
  178. {typeid(KeywordArg).name(),
  179. [](const ValuePtr &value) -> py::object {
  180. auto abs_keyword_arg = value->ToAbstract()->cast<abstract::AbstractKeywordArgPtr>();
  181. auto key = abs_keyword_arg->get_key();
  182. auto val = abs_keyword_arg->get_arg()->BuildValue();
  183. auto py_value = ValueToPyData(val);
  184. auto kwargs = py::kwargs();
  185. kwargs[key.c_str()] = py_value;
  186. return kwargs;
  187. }},
  188. // parse::NameSpace
  189. {typeid(parse::NameSpace).name(),
  190. [](const ValuePtr &value) -> py::object {
  191. auto ns = value->cast<parse::NameSpacePtr>();
  192. return ns->module_obj();
  193. }},
  194. // parse::ClassType
  195. {typeid(parse::ClassType).name(),
  196. [](const ValuePtr &value) -> py::object {
  197. auto class_type = value->cast<parse::ClassTypePtr>();
  198. return class_type->obj();
  199. }},
  200. // parse::InterpretedObject
  201. {typeid(parse::InterpretedObject).name(),
  202. [](const ValuePtr &value) -> py::object {
  203. auto interpreted_object = value->cast<parse::InterpretedObjectPtr>();
  204. return interpreted_object->obj();
  205. }},
  206. // None
  207. {typeid(None).name(), [](const ValuePtr &value) -> py::object { return py::none(); }},
  208. // AnyValue
  209. {typeid(AnyValue).name(), [](const ValuePtr &value) -> py::object { return py::none(); }},
  210. // FuncGraph
  211. {typeid(FuncGraph).name(), [](const ValuePtr &value) -> py::object { return py::none(); }},
  212. // Monad
  213. {typeid(Monad).name(), [](const ValuePtr &value) -> py::object { return py::none(); }},
  214. // Ellipsis
  215. {typeid(Ellipsis).name(), [](const ValuePtr &value) -> py::object { return py::ellipsis(); }}};
  216. py::object ValueToPyData(const ValuePtr &value) {
  217. if (value == nullptr) {
  218. MS_LOG(EXCEPTION) << "The `value` should not be null";
  219. }
  220. for (auto &iter : value_name_to_converter) {
  221. if (value->IsFromTypeId(Base::GetTypeId(iter.first))) {
  222. return iter.second(value);
  223. }
  224. }
  225. MS_LOG(EXCEPTION) << "Unsupported to convert " << value->ToString() << "[" << value->type_name() << "] to a PyData";
  226. }
  227. py::object AnyToPyData(const Any &value) {
  228. py::object ret;
  229. MS_LOG(DEBUG) << "AnyToPyData " << value.GetString();
  230. if (value.is<int>() || value.is<float>() || value.is<double>() || value.is<bool>()) {
  231. ret = BuiltinsToPyData(value);
  232. } else if (value.is<ValuePtr>()) {
  233. MS_LOG(DEBUG) << "ValuePtr";
  234. ValuePtr v = value.cast<ValuePtr>();
  235. ret = ValueToPyData(v);
  236. } else if (value.is<tensor::TensorPtr>()) {
  237. MS_LOG(DEBUG) << "tensor";
  238. auto tensor_ptr = value.cast<tensor::TensorPtr>();
  239. ret = TensorToPyData(tensor_ptr);
  240. } else if (value.is<py::object>()) {
  241. MS_LOG(DEBUG) << "py obj";
  242. ret = value.cast<py::object>();
  243. } else if (value.is<std::vector<tensor::TensorPtr>>() || value.is<std::vector<Any>>()) {
  244. ret = VectorToPyData(value);
  245. } else if (value.is<std::list<Any>>()) {
  246. MS_LOG(DEBUG) << "list_any";
  247. auto value_list = value.cast<std::list<Any>>();
  248. py::list rets = py::list();
  249. for (auto &v : value_list) {
  250. rets.append(AnyToPyData(v));
  251. }
  252. ret = rets;
  253. } else if (value.is<std::vector<Any>>()) {
  254. auto value_list = value.cast<std::vector<Any>>();
  255. py::tuple rets(value_list.size());
  256. for (size_t i = 0; i < value_list.size(); i++) {
  257. rets[i] = AnyToPyData(value_list[i]);
  258. }
  259. ret = rets;
  260. } else if (value.is<TypePtr>()) {
  261. py::tuple v(1);
  262. v[0] = value.cast<TypePtr>();
  263. ret = v[0];
  264. } else {
  265. MS_LOG(EXCEPTION) << "value is not support type";
  266. }
  267. return ret;
  268. }
  269. py::object BaseRefToPyData(const BaseRef &value) {
  270. py::object ret;
  271. MS_LOG(DEBUG) << "BaseRefToPyData " << value.ToString();
  272. if (utils::isa<int>(value) || utils::isa<float>(value) || utils::isa<double>(value) || utils::isa<bool>(value)) {
  273. ret = BuiltinsToPyData(value);
  274. } else if (utils::isa<ValuePtr>(value)) {
  275. MS_LOG(DEBUG) << "ValuePtr";
  276. ValuePtr v = utils::cast<ValuePtr>(value);
  277. ret = ValueToPyData(v);
  278. } else if (utils::isa<tensor::TensorPtr>(value)) {
  279. MS_LOG(DEBUG) << "tensor";
  280. auto tensor_ptr = utils::cast<tensor::TensorPtr>(value);
  281. ret = TensorToPyData(tensor_ptr);
  282. } else if (utils::isa<PyObjectRef>(value)) {
  283. MS_LOG(DEBUG) << "py obj";
  284. PyObjectRef py_ref = utils::cast<PyObjectRef>(value);
  285. ret = py_ref.object_;
  286. } else if (utils::isa<VectorRef>(value)) {
  287. auto vec_ref = utils::cast<VectorRef>(value);
  288. ret = VectorRefToPyData(vec_ref);
  289. } else if (utils::isa<TypePtr>(value)) {
  290. py::tuple v(1);
  291. v[0] = utils::cast<TypePtr>(value);
  292. ret = v[0];
  293. } else {
  294. MS_LOG(EXCEPTION) << "value is not support type";
  295. }
  296. return ret;
  297. }
  298. py::object BuiltinsToPyData(const Any &value) {
  299. if (value.is<int>()) {
  300. MS_LOG(DEBUG) << "int";
  301. py::int_ ret = value.cast<int>();
  302. return std::move(ret);
  303. } else if (value.is<float>()) {
  304. MS_LOG(DEBUG) << "float";
  305. py::float_ ret = value.cast<float>();
  306. return std::move(ret);
  307. } else if (value.is<double>()) {
  308. MS_LOG(DEBUG) << "double";
  309. py::float_ ret = value.cast<double>();
  310. return std::move(ret);
  311. } else {
  312. MS_LOG(DEBUG) << "bool";
  313. py::bool_ ret = value.cast<bool>();
  314. return std::move(ret);
  315. }
  316. }
  317. py::object BuiltinsToPyData(const BaseRef &value) {
  318. if (utils::isa<int>(value)) {
  319. MS_LOG(DEBUG) << "int";
  320. py::int_ ret = utils::cast<int>(value);
  321. return std::move(ret);
  322. } else if (utils::isa<float>(value)) {
  323. MS_LOG(DEBUG) << "float";
  324. py::float_ ret = utils::cast<float>(value);
  325. return std::move(ret);
  326. } else if (utils::isa<double>(value)) {
  327. MS_LOG(DEBUG) << "double";
  328. py::float_ ret = utils::cast<double>(value);
  329. return std::move(ret);
  330. } else {
  331. MS_LOG(DEBUG) << "bool";
  332. py::bool_ ret = utils::cast<bool>(value);
  333. return std::move(ret);
  334. }
  335. }
  336. py::object VectorToPyData(const Any &value) {
  337. py::object ret;
  338. if (value.is<std::vector<tensor::TensorPtr>>()) {
  339. MS_LOG(DEBUG) << "vector_tensor";
  340. std::vector<tensor::TensorPtr> outputs;
  341. outputs = value.cast<std::vector<tensor::TensorPtr>>();
  342. py::tuple tensor_tuple(outputs.size());
  343. for (std::size_t i = 0; i < outputs.size(); ++i) {
  344. tensor_tuple[i] = *outputs[i];
  345. }
  346. ret = tensor_tuple;
  347. } else {
  348. MS_LOG(DEBUG) << "vector_any";
  349. auto value_list = value.cast<std::vector<Any>>();
  350. py::tuple any_tuple = py::tuple(value_list.size());
  351. size_t i = 0;
  352. for (auto &v : value_list) {
  353. any_tuple[i] = AnyToPyData(v);
  354. i++;
  355. }
  356. ret = any_tuple;
  357. }
  358. return ret;
  359. }
  360. py::object VectorRefToPyData(const VectorRef &value_list) {
  361. py::object ret;
  362. MS_LOG(DEBUG) << "vector_ref";
  363. size_t value_size = value_list.size();
  364. auto ref_tuple = py::tuple(value_size);
  365. for (size_t i = 0; i < value_size; i++) {
  366. ref_tuple[i] = BaseRefToPyData(value_list[i]);
  367. }
  368. ret = ref_tuple;
  369. return ret;
  370. }
  371. void SetValueRange(const AbstractBasePtr &tensor, const py::object &output) {
  372. if (output.is_none()) {
  373. return;
  374. }
  375. py::object obj_min =
  376. output.contains(py::str(ATTR_MIN_VALUE)) ? (py::object)output[ATTR_MIN_VALUE] : (py::object)py::none();
  377. py::object obj_max =
  378. output.contains(py::str(ATTR_MAX_VALUE)) ? (py::object)output[ATTR_MAX_VALUE] : (py::object)py::none();
  379. if (!obj_min.is_none() && !obj_max.is_none()) {
  380. bool converted = true;
  381. ValuePtr min_value = nullptr;
  382. ValuePtr max_value = nullptr;
  383. converted = parse::ConvertData(obj_min, &min_value);
  384. if (!converted) {
  385. MS_LOG(EXCEPTION) << "Convert shape min value data failed";
  386. }
  387. converted = parse::ConvertData(obj_max, &max_value);
  388. if (!converted) {
  389. MS_LOG(EXCEPTION) << "Convert shape max value data failed";
  390. }
  391. auto abs_tensor = dyn_cast<abstract::AbstractTensor>(tensor);
  392. abs_tensor->set_value_range(min_value, max_value);
  393. }
  394. }
  395. AbstractBasePtr MakePyInferRes2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
  396. const py::object &output) {
  397. auto ret_vec = shape_obj.cast<ShapeVector>();
  398. auto ret_dtype = type_obj.cast<TypePtr>();
  399. ShapeVector min_shape_vec;
  400. ShapeVector max_shape_vec;
  401. if (!output.is_none()) {
  402. py::object min_shape =
  403. output.contains(py::str(ATTR_MIN_SHAPE)) ? (py::object)output[ATTR_MIN_SHAPE] : (py::object)py::none();
  404. py::object max_shape =
  405. output.contains(py::str(ATTR_MAX_SHAPE)) ? (py::object)output[ATTR_MAX_SHAPE] : (py::object)py::none();
  406. if (!min_shape.is_none()) {
  407. min_shape_vec = min_shape.cast<ShapeVector>();
  408. }
  409. if (!max_shape.is_none()) {
  410. max_shape_vec = max_shape.cast<ShapeVector>();
  411. }
  412. }
  413. auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
  414. AbstractBasePtr tensor = MakeAbstractTensor(ret_shape, ret_dtype);
  415. SetValueRange(tensor, output);
  416. return tensor;
  417. }
  418. static bool IsMonadType(const py::object &type_obj) {
  419. if (py::isinstance<Type>(type_obj)) {
  420. auto type = type_obj.cast<Type *>();
  421. return type->isa<MonadType>();
  422. }
  423. return false;
  424. }
  425. static AbstractBasePtr ToMonadAbstract(const py::object &type_obj) {
  426. if (py::isinstance<Type>(type_obj)) {
  427. auto type = type_obj.cast<Type *>();
  428. if (!type->isa<MonadType>()) {
  429. MS_LOG(EXCEPTION) << "Not a monad type object: " << py::str(type_obj);
  430. }
  431. return abstract::MakeMonadAbstract(type->cast<MonadTypePtr>());
  432. }
  433. MS_LOG(EXCEPTION) << "Not a type object: " << py::str(type_obj);
  434. }
  435. AbstractBasePtr MakePyInferRes2Abstract(const py::object &shape_obj, const py::object &type_obj,
  436. const py::object &output) {
  437. if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) && py::isinstance<Type>(type_obj)) {
  438. auto ret_vec = shape_obj.cast<ShapeVector>();
  439. auto ret_dtype = type_obj.cast<TypePtr>();
  440. MS_EXCEPTION_IF_NULL(ret_dtype);
  441. // if the size of shape list is empty, return an scalar abstract
  442. if (ret_vec.empty() && (!ret_dtype->isa<TensorType>())) {
  443. abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
  444. return abs_scalar;
  445. }
  446. return MakePyInferRes2AbstractTensor(shape_obj, type_obj, output);
  447. } else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
  448. auto shape_tuple = shape_obj.cast<py::tuple>();
  449. auto typeid_tuple = type_obj.cast<py::tuple>();
  450. AbstractBasePtrList ptr_list;
  451. for (size_t it = 0; it < shape_tuple.size(); ++it) {
  452. auto tensor_it = MakePyInferRes2Abstract(shape_tuple[it], typeid_tuple[it]);
  453. ptr_list.push_back(tensor_it);
  454. }
  455. auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
  456. return tuple;
  457. } else if (py::isinstance<py::list>(shape_obj) && py::isinstance<py::list>(type_obj)) {
  458. auto shape_list = shape_obj.cast<py::list>();
  459. auto typeid_list = type_obj.cast<py::list>();
  460. AbstractBasePtrList ptr_list;
  461. for (size_t it = 0; it < shape_list.size(); ++it) {
  462. auto tensor_it = MakePyInferRes2Abstract(shape_list[it], typeid_list[it]);
  463. ptr_list.push_back(tensor_it);
  464. }
  465. auto list = std::make_shared<abstract::AbstractList>(ptr_list);
  466. return list;
  467. } else if (shape_obj.is_none() && type_obj.is_none()) {
  468. // AbstractNone indicates there is no output for this CNode node.
  469. auto abstract_none = std::make_shared<abstract::AbstractNone>();
  470. return abstract_none;
  471. } else if (IsMonadType(type_obj)) {
  472. // Return monad abstract if it is monad type.
  473. return ToMonadAbstract(type_obj);
  474. } else {
  475. // When sparse enabled, the undetermined might be raised and eliminated in opt passes
  476. auto context = MsContext::GetInstance();
  477. MS_EXCEPTION_IF_NULL(context);
  478. bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
  479. if (enable_sparse) {
  480. return std::make_shared<abstract::AbstractUndetermined>();
  481. }
  482. MS_LOG(EXCEPTION) << "Python evaluator return invalid shape or type. " << (std::string)py::str(type_obj);
  483. }
  484. }
  485. bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
  486. const std::shared_ptr<py::object> &ret_val) {
  487. if (output->isa<ValueNode>()) {
  488. MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
  489. ValuePtr value = GetValueNode(output);
  490. *ret_val = ValueToPyData(value);
  491. return true;
  492. }
  493. // Adapter will transform values in __init__() and construct() to parameters, this could cause
  494. // inputs (a.k.a args in current function) size less than parameters'.
  495. if (output->isa<Parameter>()) {
  496. MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute.";
  497. // Find the right parameter as ret_val.
  498. auto func_graph = output->func_graph();
  499. MS_EXCEPTION_IF_NULL(func_graph);
  500. auto params = func_graph->parameters();
  501. if ((args.size() + func_graph->hyper_param_count()) != params.size()) {
  502. MS_LOG(EXCEPTION) << "Input size " << args.size() << " add Parameter count " << func_graph->hyper_param_count()
  503. << " not equal to graph input size " << params.size() << ", let graph to be executed.";
  504. }
  505. auto it = std::find(params.begin(), params.end(), output);
  506. if (it == params.end()) {
  507. MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters";
  508. }
  509. size_t index = it - params.cbegin();
  510. if (index >= args.size() + func_graph->hyper_param_count()) {
  511. MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size()
  512. << " add Parameter count " << func_graph->hyper_param_count() << ".";
  513. }
  514. if (index < args.size()) {
  515. *ret_val = args[index];
  516. } else {
  517. auto param = dyn_cast<Parameter>(params[index]);
  518. MS_EXCEPTION_IF_NULL(param);
  519. if (!param->has_default()) {
  520. MS_LOG(EXCEPTION) << "Can not determine value of Parameter " << index << " (" << param->name() << ")";
  521. }
  522. auto tensor = param->default_param();
  523. *ret_val = py::cast(tensor);
  524. }
  525. return true;
  526. }
  527. return false;
  528. }
  529. } // namespace mindspore