You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

convert_utils_py.cc 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/convert_utils_py.h"
  17. #include <vector>
  18. #include <string>
  19. #include <memory>
  20. #include <algorithm>
  21. #include <list>
  22. #include <utility>
  23. #include <cfloat>
  24. #include "abstract/abstract_value.h"
  25. #include "pipeline/jit/parse/parse.h"
  26. #include "pipeline/jit/parse/parse_base.h"
  27. #include "ir/value.h"
  28. #include "ir/tensor.h"
  29. #include "ir/param_info.h"
  30. #include "pybind_api/ir/base_ref_py.h"
  31. #include "utils/ms_context.h"
  32. namespace mindspore {
  33. py::object BuiltinsToPyData(const Any &value);
  34. py::object BuiltinsToPyData(const BaseRef &value);
  35. py::object VectorToPyData(const Any &value);
  36. py::object VectorRefToPyData(const VectorRef &value);
// Convert a Tensor into a py::object handle that Python code can hold.
// If the tensor's data is still being produced asynchronously, block until it
// is ready so Python never observes an incomplete tensor.
py::object TensorToPyData(const tensor::TensorPtr &tensor) {
  MS_EXCEPTION_IF_NULL(tensor);
  if (tensor->NeedWait()) {
    // Release the GIL while blocking so other Python threads can keep running.
    py::gil_scoped_release release;
    tensor->Wait();
  }
  // Assigning into a py::tuple slot performs the C++ -> Python cast; reading
  // the slot back yields the resulting py::object.
  py::tuple v(1);
  v[0] = tensor;
  return v[0];
}
// Convert an IR ValuePtr into the equivalent Python object.
// Scalars map to py::int_/py::float_/py::bool_/py::str; tensors and a few IR
// objects are wrapped via the tuple-slot pybind cast; tuples, lists and slices
// convert element-wise/recursively.  Raises (MS_LOG(EXCEPTION)) for a null
// value or an unsupported Value subclass.
// NOTE(review): the order of the isa<> checks matters where subclassing is
// involved -- e.g. tensor::Tensor is tested before tensor::MetaTensor,
// presumably because a Tensor would also satisfy the MetaTensor check;
// confirm before reordering branches.
py::object ValuePtrToPyData(const ValuePtr &value) {
  if (value == nullptr) {
    MS_LOG(EXCEPTION) << "value is null";
  }
  py::object ret;
  // Signed integer immediates of every width become Python ints.
  if (value->isa<Int8Imm>()) {
    MS_LOG(DEBUG) << "int8";
    py::int_ v = value->cast<Int8ImmPtr>()->value();
    ret = v;
  } else if (value->isa<Int16Imm>()) {
    MS_LOG(DEBUG) << "int16";
    py::int_ v = value->cast<Int16ImmPtr>()->value();
    ret = v;
  } else if (value->isa<Int32Imm>()) {
    MS_LOG(DEBUG) << "int32";
    py::int_ v = value->cast<Int32ImmPtr>()->value();
    ret = v;
  } else if (value->isa<Int64Imm>()) {
    MS_LOG(DEBUG) << "int64";
    py::int_ v = value->cast<Int64ImmPtr>()->value();
    ret = v;
  } else if (value->isa<UInt8Imm>()) {
    MS_LOG(DEBUG) << "uint8";
    py::int_ v = value->cast<UInt8ImmPtr>()->value();
    ret = v;
  } else if (value->isa<UInt16Imm>()) {
    MS_LOG(DEBUG) << "uint16";
    py::int_ v = value->cast<UInt16ImmPtr>()->value();
    ret = v;
  } else if (value->isa<UInt32Imm>()) {
    MS_LOG(DEBUG) << "uint32";
    py::int_ v = value->cast<UInt32ImmPtr>()->value();
    ret = v;
  } else if (value->isa<UInt64Imm>()) {
    MS_LOG(DEBUG) << "uint64";
    py::int_ v = value->cast<UInt64ImmPtr>()->value();
    ret = v;
  } else if (value->isa<BoolImm>()) {
    MS_LOG(DEBUG) << "bool";
    py::bool_ v = value->cast<BoolImmPtr>()->value();
    ret = v;
  } else if (value->isa<FP64Imm>()) {
    MS_LOG(DEBUG) << "double";
    py::float_ v = value->cast<FP64ImmPtr>()->value();
    ret = v;
  } else if (value->isa<FP32Imm>()) {
    MS_LOG(DEBUG) << "float";
    py::float_ v = value->cast<FP32ImmPtr>()->value();
    ret = v;
  } else if (value->isa<StringImm>()) {
    MS_LOG(DEBUG) << "String";
    py::str v = value->cast<StringImmPtr>()->value();
    ret = v;
  } else if (value->isa<tensor::Tensor>()) {
    MS_LOG(DEBUG) << "tensor";
    auto tensor_ptr = value->cast<tensor::TensorPtr>();
    ret = TensorToPyData(tensor_ptr);
  } else if (value->isa<tensor::MetaTensor>()) {
    MS_LOG(DEBUG) << "MetaTensor";
    // The tuple-slot assignment performs the C++ -> Python cast.
    py::tuple v(1);
    v[0] = value->cast<tensor::MetaTensorPtr>();
    ret = v[0];
  } else if (value->isa<RefKey>()) {
    MS_LOG(DEBUG) << "RefKey";
    py::tuple v(1);
    v[0] = value->cast<RefKeyPtr>();
    ret = v[0];
  } else if (value->isa<ValueTuple>()) {
    // Element-wise recursion into a py::tuple.
    MS_LOG(DEBUG) << "tuple";
    auto value_tuple = value->cast<ValueTuplePtr>()->value();
    py::tuple rets(value_tuple.size());
    size_t i = 0;
    for (auto &v : value_tuple) {
      rets[i] = ValuePtrToPyData(v);
      i++;
    }
    ret = rets;
  } else if (value->isa<ValueList>()) {
    // Element-wise recursion into a py::list.
    MS_LOG(DEBUG) << "list";
    auto value_list = value->cast<ValueListPtr>()->value();
    py::list rets(value_list.size());
    size_t i = 0;
    for (auto &v : value_list) {
      rets[i] = ValuePtrToPyData(v);
      i++;
    }
    ret = rets;
  } else if (value->isa<Ellipsis>()) {
    ret = py::ellipsis();
  } else if (value->isa<ValueSlice>()) {
    // Rebuild a Python slice object on the Python side from the converted
    // start/stop/step components.
    auto slice = value->cast<ValueSlicePtr>();
    auto start = ValuePtrToPyData(slice->start());
    auto end = ValuePtrToPyData(slice->stop());
    auto step = ValuePtrToPyData(slice->step());
    ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end,
                                          step);
  } else if (value->isa<Type>()) {
    py::tuple v(1);
    v[0] = value->cast<TypePtr>();
    ret = v[0];
  } else if (value->isa<AnyValue>()) {
    ret = py::none();
  } else if (value->isa<None>()) {
    ret = py::none();
  } else if (value->isa<FuncGraph>()) {
    // FuncGraph is not used in the backend, return None
    ret = py::none();
  } else {
    MS_LOG(EXCEPTION) << "Unsupported convert value: " << value->ToString() << " to a PyData.";
  }
  return ret;
}
  159. py::object AnyToPyData(const Any &value) {
  160. py::object ret;
  161. MS_LOG(DEBUG) << "AnyToPyData " << value.GetString();
  162. if (value.is<int>() || value.is<float>() || value.is<double>() || value.is<bool>()) {
  163. ret = BuiltinsToPyData(value);
  164. } else if (value.is<ValuePtr>()) {
  165. MS_LOG(DEBUG) << "ValuePtr";
  166. ValuePtr v = value.cast<ValuePtr>();
  167. ret = ValuePtrToPyData(v);
  168. } else if (value.is<tensor::TensorPtr>()) {
  169. MS_LOG(DEBUG) << "tensor";
  170. auto tensor_ptr = value.cast<tensor::TensorPtr>();
  171. ret = TensorToPyData(tensor_ptr);
  172. } else if (value.is<py::object>()) {
  173. MS_LOG(DEBUG) << "py obj";
  174. ret = value.cast<py::object>();
  175. } else if (value.is<std::vector<tensor::TensorPtr>>() || value.is<std::vector<Any>>()) {
  176. ret = VectorToPyData(value);
  177. } else if (value.is<std::list<Any>>()) {
  178. MS_LOG(DEBUG) << "list_any";
  179. auto value_list = value.cast<std::list<Any>>();
  180. py::list rets = py::list();
  181. for (auto &v : value_list) {
  182. rets.append(AnyToPyData(v));
  183. }
  184. ret = rets;
  185. } else if (value.is<std::vector<Any>>()) {
  186. auto value_list = value.cast<std::vector<Any>>();
  187. py::tuple rets(value_list.size());
  188. for (size_t i = 0; i < value_list.size(); i++) {
  189. rets[i] = AnyToPyData(value_list[i]);
  190. }
  191. ret = rets;
  192. } else if (value.is<TypePtr>()) {
  193. py::tuple v(1);
  194. v[0] = value.cast<TypePtr>();
  195. ret = v[0];
  196. } else {
  197. MS_LOG(EXCEPTION) << "value is not support type";
  198. }
  199. return ret;
  200. }
  201. py::object BaseRefToPyData(const BaseRef &value) {
  202. py::object ret;
  203. MS_LOG(DEBUG) << "BaseRefToPyData " << value.ToString();
  204. if (utils::isa<int>(value) || utils::isa<float>(value) || utils::isa<double>(value) || utils::isa<bool>(value)) {
  205. ret = BuiltinsToPyData(value);
  206. } else if (utils::isa<ValuePtr>(value)) {
  207. MS_LOG(DEBUG) << "ValuePtr";
  208. ValuePtr v = utils::cast<ValuePtr>(value);
  209. ret = ValuePtrToPyData(v);
  210. } else if (utils::isa<tensor::TensorPtr>(value)) {
  211. MS_LOG(DEBUG) << "tensor";
  212. auto tensor_ptr = utils::cast<tensor::TensorPtr>(value);
  213. ret = TensorToPyData(tensor_ptr);
  214. } else if (utils::isa<PyObjectRef>(value)) {
  215. MS_LOG(DEBUG) << "py obj";
  216. PyObjectRef py_ref = utils::cast<PyObjectRef>(value);
  217. ret = py_ref.object_;
  218. } else if (utils::isa<VectorRef>(value)) {
  219. auto vec_ref = utils::cast<VectorRef>(value);
  220. ret = VectorRefToPyData(vec_ref);
  221. } else if (utils::isa<TypePtr>(value)) {
  222. py::tuple v(1);
  223. v[0] = utils::cast<TypePtr>(value);
  224. ret = v[0];
  225. } else {
  226. MS_LOG(EXCEPTION) << "value is not support type";
  227. }
  228. return ret;
  229. }
  230. py::object BuiltinsToPyData(const Any &value) {
  231. if (value.is<int>()) {
  232. MS_LOG(DEBUG) << "int";
  233. py::int_ ret = value.cast<int>();
  234. return std::move(ret);
  235. } else if (value.is<float>()) {
  236. MS_LOG(DEBUG) << "float";
  237. py::float_ ret = value.cast<float>();
  238. return std::move(ret);
  239. } else if (value.is<double>()) {
  240. MS_LOG(DEBUG) << "double";
  241. py::float_ ret = value.cast<double>();
  242. return std::move(ret);
  243. } else {
  244. MS_LOG(DEBUG) << "bool";
  245. py::bool_ ret = value.cast<bool>();
  246. return std::move(ret);
  247. }
  248. }
  249. py::object BuiltinsToPyData(const BaseRef &value) {
  250. if (utils::isa<int>(value)) {
  251. MS_LOG(DEBUG) << "int";
  252. py::int_ ret = utils::cast<int>(value);
  253. return std::move(ret);
  254. } else if (utils::isa<float>(value)) {
  255. MS_LOG(DEBUG) << "float";
  256. py::float_ ret = utils::cast<float>(value);
  257. return std::move(ret);
  258. } else if (utils::isa<double>(value)) {
  259. MS_LOG(DEBUG) << "double";
  260. py::float_ ret = utils::cast<double>(value);
  261. return std::move(ret);
  262. } else {
  263. MS_LOG(DEBUG) << "bool";
  264. py::bool_ ret = utils::cast<bool>(value);
  265. return std::move(ret);
  266. }
  267. }
  268. py::object VectorToPyData(const Any &value) {
  269. py::object ret;
  270. if (value.is<std::vector<tensor::TensorPtr>>()) {
  271. MS_LOG(DEBUG) << "vector_tensor";
  272. std::vector<tensor::TensorPtr> outputs;
  273. outputs = value.cast<std::vector<tensor::TensorPtr>>();
  274. py::tuple tensor_tuple(outputs.size());
  275. for (std::size_t i = 0; i < outputs.size(); ++i) {
  276. tensor_tuple[i] = *outputs[i];
  277. }
  278. ret = tensor_tuple;
  279. } else {
  280. MS_LOG(DEBUG) << "vector_any";
  281. auto value_list = value.cast<std::vector<Any>>();
  282. py::tuple any_tuple = py::tuple(value_list.size());
  283. size_t i = 0;
  284. for (auto &v : value_list) {
  285. any_tuple[i] = AnyToPyData(v);
  286. i++;
  287. }
  288. ret = any_tuple;
  289. }
  290. return ret;
  291. }
  292. py::object VectorRefToPyData(const VectorRef &value_list) {
  293. py::object ret;
  294. MS_LOG(DEBUG) << "vector_ref";
  295. size_t value_size = value_list.size();
  296. auto ref_tuple = py::tuple(value_size);
  297. for (size_t i = 0; i < value_size; i++) {
  298. ref_tuple[i] = BaseRefToPyData(value_list[i]);
  299. }
  300. ret = ref_tuple;
  301. return ret;
  302. }
  303. AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
  304. const py::object &min_shape, const py::object &max_shape) {
  305. if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) && py::isinstance<Type>(type_obj)) {
  306. auto ret_vec = shape_obj.cast<ShapeVector>();
  307. auto ret_dtype = type_obj.cast<TypePtr>();
  308. MS_EXCEPTION_IF_NULL(ret_dtype);
  309. // if the size of shape list is empty, return an scalar abstract
  310. if (ret_vec.empty() && (!ret_dtype->isa<TensorType>())) {
  311. abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
  312. return abs_scalar;
  313. }
  314. AbstractBasePtr tensor = nullptr;
  315. ShapeVector min_shape_vec;
  316. ShapeVector max_shape_vec;
  317. if (!min_shape.is_none()) {
  318. min_shape_vec = min_shape.cast<ShapeVector>();
  319. }
  320. if (!max_shape.is_none()) {
  321. max_shape_vec = max_shape.cast<ShapeVector>();
  322. }
  323. auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
  324. if (ret_dtype->isa<TensorType>()) {
  325. auto tensor_type = type_obj.cast<TensorTypePtr>();
  326. MS_EXCEPTION_IF_NULL(tensor_type);
  327. auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, tensor_type->element());
  328. tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
  329. } else {
  330. auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
  331. tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
  332. }
  333. return tensor;
  334. } else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
  335. py::tuple shape_tuple = shape_obj.cast<py::tuple>();
  336. py::tuple typeid_tuple = type_obj.cast<py::tuple>();
  337. AbstractBasePtrList ptr_list;
  338. for (size_t it = 0; it < shape_tuple.size(); ++it) {
  339. auto tensor_it = PyListDtype2AbstractTensor(shape_tuple[it], typeid_tuple[it]);
  340. ptr_list.push_back(tensor_it);
  341. }
  342. auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
  343. return tuple;
  344. } else if (py::isinstance<py::list>(shape_obj) && py::isinstance<py::list>(type_obj)) {
  345. py::list shape_list = shape_obj.cast<py::list>();
  346. py::list typeid_list = type_obj.cast<py::list>();
  347. AbstractBasePtrList ptr_list;
  348. for (size_t it = 0; it < shape_list.size(); ++it) {
  349. auto tensor_it = PyListDtype2AbstractTensor(shape_list[it], typeid_list[it]);
  350. ptr_list.push_back(tensor_it);
  351. }
  352. auto list = std::make_shared<abstract::AbstractList>(ptr_list);
  353. return list;
  354. } else if (shape_obj.is_none() && type_obj.is_none()) {
  355. // AbstractNone indicates there is no output for this CNode node.
  356. auto abstract_none = std::make_shared<abstract::AbstractNone>();
  357. return abstract_none;
  358. } else {
  359. // When sparse enabled, the undetermined might be raised and eliminated in opt passes
  360. auto context = MsContext::GetInstance();
  361. MS_EXCEPTION_IF_NULL(context);
  362. bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
  363. if (enable_sparse) {
  364. return std::make_shared<abstract::AbstractUndetermined>();
  365. }
  366. MS_LOG(EXCEPTION) << "Python evaluator return invalid shape or type. " << (std::string)py::str(type_obj);
  367. }
  368. }
// If the graph's output is a ValueNode or a Parameter, fill *ret_val with the
// corresponding Python object and return true (the graph need not be
// executed); otherwise return false.
// - ValueNode output: converted directly via ValuePtrToPyData.
// - Parameter output: resolved to the matching positional arg in `args`, or
//   (for indices past args.size()) to the parameter's default value.
bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
                                       const std::shared_ptr<py::object> &ret_val) {
  if (output->isa<ValueNode>()) {
    MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
    ValuePtr value = GetValueNode(output);
    *ret_val = ValuePtrToPyData(value);
    return true;
  }
  // Adapter will transform values in __init__() and construct() to parameters, this could cause
  // inputs (a.k.a args in current function) size less than parameters'.
  if (output->isa<Parameter>()) {
    MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute.";
    // Find the right parameter as ret_val.
    auto func_graph = output->func_graph();
    MS_EXCEPTION_IF_NULL(func_graph);
    auto params = func_graph->parameters();
    // Graph parameters must be the user inputs plus the hyper params.
    if ((args.size() + func_graph->hyper_param_count()) != params.size()) {
      MS_LOG(EXCEPTION) << "Input size " << args.size() << " add Parameter count " << func_graph->hyper_param_count()
                        << " not equal to graph input size " << params.size() << ", let graph to be executed.";
    }
    auto it = std::find(params.begin(), params.end(), output);
    if (it == params.end()) {
      MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters";
    }
    size_t index = it - params.cbegin();
    // Redundant with the size check above, but kept as a defensive bound.
    if (index >= args.size() + func_graph->hyper_param_count()) {
      MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size()
                                 << " add Parameter count " << func_graph->hyper_param_count() << ".";
    }
    if (index < args.size()) {
      // The output parameter is one of the user-supplied inputs.
      *ret_val = args[index];
    } else {
      // The output parameter is a hyper param: use its default value.
      auto param = dyn_cast<Parameter>(params[index]);
      MS_EXCEPTION_IF_NULL(param);
      if (!param->has_default()) {
        MS_LOG(EXCEPTION) << "Can not determine value of Parameter " << index << " (" << param->name() << ")";
      }
      auto tensor = param->default_param();
      *ret_val = py::cast(tensor);
    }
    return true;
  }
  return false;
}
  413. } // namespace mindspore