You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or number, can include dashes ('-') and can be up to 35 characters long.

convert_utils.cc 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/convert_utils.h"
  17. #include <vector>
  18. #include <string>
  19. #include <memory>
  20. #include <algorithm>
  21. #include <list>
  22. #include <utility>
  23. #include <cfloat>
  24. #include "pybind11/pybind11.h"
  25. #include "ir/meta_tensor.h"
  26. #include "pipeline/parse/parse.h"
  27. #include "pipeline/parse/parse_base.h"
  28. #include "ir/value.h"
namespace mindspore {
// Forward declarations for the converter helpers defined later in this file.
py::object BuiltinsToPyData(const Any &value);
py::object BuiltinsToPyData(const BaseRef &value);
py::object VectorToPyData(const Any &value);
py::object VectorRefToPyData(const VectorRef &value);
// Convert a MindSpore IR value (ValuePtr) into the corresponding Python object.
// Scalar immediates map to pybind11 scalar types (py::int_/py::float_/py::bool_),
// strings to py::str; tensors, RefKeys and Types are routed through pybind11's
// registered type casters via a temporary 1-element tuple; ValueTuple/ValueList
// recurse element-wise. Ellipsis and slice values are constructed by calling
// back into the Python parse module. Raises (MS_LOG(EXCEPTION)) on nullptr or
// an unsupported value kind.
py::object ValuePtrToPyData(const ValuePtr &value) {
  if (value == nullptr) {
    MS_LOG(EXCEPTION) << "value is null";
  }
  py::object ret;
  if (value->isa<Int32Imm>()) {
    MS_LOG(DEBUG) << "int";
    py::int_ v = value->cast<Int32ImmPtr>()->value();
    ret = v;
  } else if (value->isa<UInt64Imm>()) {
    MS_LOG(DEBUG) << "uint64";
    py::int_ v = value->cast<UInt64ImmPtr>()->value();
    ret = v;
  } else if (value->isa<BoolImm>()) {
    MS_LOG(DEBUG) << "bool";
    py::bool_ v = value->cast<BoolImmPtr>()->value();
    ret = v;
  } else if (value->isa<FP64Imm>()) {
    MS_LOG(DEBUG) << "double";
    py::float_ v = value->cast<FP64ImmPtr>()->value();
    ret = v;
  } else if (value->isa<FP32Imm>()) {
    MS_LOG(DEBUG) << "float";
    py::float_ v = value->cast<FP32ImmPtr>()->value();
    ret = v;
  } else if (value->isa<StringImm>()) {
    MS_LOG(DEBUG) << "String";
    py::str v = value->cast<StringImmPtr>()->value();
    ret = v;
  } else if (value->isa<tensor::Tensor>()) {
    MS_LOG(DEBUG) << "tensor";
    // Assigning the shared_ptr into a tuple slot invokes pybind11's registered
    // caster; reading the slot back yields the converted py::object.
    py::tuple v(1);
    v[0] = value->cast<tensor::TensorPtr>();
    ret = v[0];
  } else if (value->isa<RefKey>()) {
    MS_LOG(DEBUG) << "RefKey";
    py::tuple v(1);
    v[0] = value->cast<RefKeyPtr>();
    ret = v[0];
  } else if (value->isa<ValueTuple>()) {
    MS_LOG(DEBUG) << "tuple";
    // Recursively convert each element into a Python tuple.
    auto value_tuple = value->cast<ValueTuplePtr>()->value();
    py::tuple rets(value_tuple.size());
    size_t i = 0;
    for (auto &v : value_tuple) {
      rets[i] = ValuePtrToPyData(v);
      i++;
    }
    ret = rets;
  } else if (value->isa<ValueList>()) {
    MS_LOG(DEBUG) << "list";
    // Recursively convert each element into a Python list.
    auto value_list = value->cast<ValueListPtr>()->value();
    py::list rets(value_list.size());
    size_t i = 0;
    for (auto &v : value_list) {
      rets[i] = ValuePtrToPyData(v);
      i++;
    }
    ret = rets;
  } else if (value->isa<EllipsisObj>()) {
    // Ellipsis has no C++ literal; fetch the Python-side singleton.
    ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_ELLIPSIS);
  } else if (value->isa<ValueSlice>()) {
    // Build a Python slice object from the converted start/stop/step values.
    auto slice = value->cast<ValueSlicePtr>();
    auto start = ValuePtrToPyData(slice->start());
    auto end = ValuePtrToPyData(slice->stop());
    auto step = ValuePtrToPyData(slice->step());
    ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end,
                                          step);
  } else if (value->isa<Type>()) {
    py::tuple v(1);
    v[0] = value->cast<TypePtr>();
    ret = v[0];
  } else if (value->isa<AnyValue>()) {
    // Unknown/abstract values surface as Python None.
    ret = py::none();
  } else if (value->isa<None>()) {
    ret = py::none();
  } else {
    MS_LOG(EXCEPTION) << "Unsupported convert value: " << value->ToString() << " to a PyData.";
  }
  return ret;
}
  115. py::object AnyToPyData(const Any &value) {
  116. py::object ret;
  117. MS_LOG(DEBUG) << "AnyToPyData " << value.GetString();
  118. if (value.is<int>() || value.is<float>() || value.is<double>() || value.is<bool>()) {
  119. ret = BuiltinsToPyData(value);
  120. } else if (value.is<ValuePtr>()) {
  121. MS_LOG(DEBUG) << "ValuePtr";
  122. ValuePtr v = value.cast<ValuePtr>();
  123. ret = ValuePtrToPyData(v);
  124. } else if (value.is<tensor::TensorPtr>()) {
  125. MS_LOG(DEBUG) << "tensor";
  126. py::tuple v(1);
  127. v[0] = value.cast<tensor::TensorPtr>();
  128. ret = v[0];
  129. } else if (value.is<py::object>()) {
  130. MS_LOG(DEBUG) << "py obj";
  131. ret = value.cast<py::object>();
  132. } else if (value.is<std::vector<tensor::TensorPtr>>() || value.is<std::vector<Any>>()) {
  133. ret = VectorToPyData(value);
  134. } else if (value.is<std::list<Any>>()) {
  135. MS_LOG(DEBUG) << "list_any";
  136. auto value_list = value.cast<std::list<Any>>();
  137. py::list rets = py::list();
  138. for (auto &v : value_list) {
  139. rets.append(AnyToPyData(v));
  140. }
  141. ret = rets;
  142. } else if (value.is<std::vector<Any>>()) {
  143. auto value_list = value.cast<std::vector<Any>>();
  144. py::tuple rets(value_list.size());
  145. for (size_t i = 0; i < value_list.size(); i++) {
  146. rets[i] = AnyToPyData(value_list[i]);
  147. }
  148. ret = rets;
  149. } else if (value.is<TypePtr>()) {
  150. py::tuple v(1);
  151. v[0] = value.cast<TypePtr>();
  152. ret = v[0];
  153. } else {
  154. MS_LOG(EXCEPTION) << "value is not support type";
  155. }
  156. return ret;
  157. }
  158. py::object BaseRefToPyData(const BaseRef &value) {
  159. py::object ret;
  160. MS_LOG(DEBUG) << "BaseRefToPyData " << common::SafeCStr(value.ToString());
  161. if (utils::isa<int>(value) || utils::isa<float>(value) || utils::isa<double>(value) || utils::isa<bool>(value)) {
  162. ret = BuiltinsToPyData(value);
  163. } else if (utils::isa<ValuePtr>(value)) {
  164. MS_LOG(DEBUG) << "ValuePtr";
  165. ValuePtr v = utils::cast<ValuePtr>(value);
  166. ret = ValuePtrToPyData(v);
  167. } else if (utils::isa<tensor::TensorPtr>(value)) {
  168. MS_LOG(DEBUG) << "tensor";
  169. py::tuple v(1);
  170. v[0] = utils::cast<tensor::TensorPtr>(value);
  171. ret = v[0];
  172. } else if (utils::isa<PyObjectRef>(value)) {
  173. MS_LOG(DEBUG) << "py obj";
  174. PyObjectRef py_ref = utils::cast<PyObjectRef>(value);
  175. ret = py_ref.object_;
  176. } else if (utils::isa<VectorRef>(value)) {
  177. auto vec_ref = utils::cast<VectorRef>(value);
  178. ret = VectorRefToPyData(vec_ref);
  179. } else if (utils::isa<TypePtr>(value)) {
  180. py::tuple v(1);
  181. v[0] = utils::cast<TypePtr>(value);
  182. ret = v[0];
  183. } else {
  184. MS_LOG(EXCEPTION) << "value is not support type";
  185. }
  186. return ret;
  187. }
  188. bool ValueToBool(const ValuePtr &v, bool *value) {
  189. MS_EXCEPTION_IF_NULL(v);
  190. if (v->isa<BoolImm>()) {
  191. *value = v->cast<BoolImmPtr>()->value();
  192. } else if (v->isa<Int32Imm>()) {
  193. *value = v->cast<Int32ImmPtr>()->value() == 0 ? false : true;
  194. } else if (v->isa<UInt32Imm>()) {
  195. *value = v->cast<UInt32ImmPtr>()->value() == 0 ? false : true;
  196. } else if (v->isa<FP32Imm>()) {
  197. *value = v->cast<FP32ImmPtr>()->value() == 0 ? false : true;
  198. } else if (v->isa<FP64Imm>()) {
  199. *value = v->cast<FP64ImmPtr>()->value() == 0 ? false : true;
  200. } else if (v->isa<tensor::Tensor>()) {
  201. auto tensor = v->cast<tensor::TensorPtr>();
  202. MS_EXCEPTION_IF_NULL(tensor);
  203. bool *tensor_data = static_cast<bool *>(tensor->data_c());
  204. // maybe need to support if tensor is a bool array
  205. auto vb = tensor_data[0];
  206. *value = vb;
  207. } else {
  208. MS_LOG(WARNING) << "value is not supported to cast to be bool";
  209. return false;
  210. }
  211. return true;
  212. }
  213. bool BaseRefToBool(const BaseRef &v, bool *value) {
  214. if (utils::isa<ValuePtr>(v)) {
  215. return ValueToBool(utils::cast<ValuePtr>(v), value);
  216. } else if (utils::isa<bool>(v)) {
  217. auto vb = utils::cast<bool>(v);
  218. if (vb == true) {
  219. *value = true;
  220. } else {
  221. *value = false;
  222. }
  223. } else if (utils::isa<int>(v)) {
  224. auto vb = utils::cast<int>(v);
  225. if (vb == 0) {
  226. *value = false;
  227. } else {
  228. *value = true;
  229. }
  230. } else if (utils::isa<unsigned int>(v)) {
  231. auto vb = utils::cast<unsigned int>(v);
  232. if (vb == 0) {
  233. *value = false;
  234. } else {
  235. *value = true;
  236. }
  237. } else if (utils::isa<float>(v)) {
  238. auto vb = utils::cast<float>(v);
  239. if (vb >= -FLT_EPSILON && vb <= FLT_EPSILON) {
  240. *value = false;
  241. } else {
  242. *value = true;
  243. }
  244. } else if (utils::isa<double>(v)) {
  245. auto vb = utils::cast<double>(v);
  246. if (vb >= -DBL_EPSILON && vb <= DBL_EPSILON) {
  247. *value = false;
  248. } else {
  249. *value = true;
  250. }
  251. } else {
  252. MS_LOG(DEBUG) << "value is not supported to cast to be bool";
  253. return false;
  254. }
  255. return true;
  256. }
  257. py::object BuiltinsToPyData(const Any &value) {
  258. if (value.is<int>()) {
  259. MS_LOG(DEBUG) << "int";
  260. py::int_ ret = value.cast<int>();
  261. return std::move(ret);
  262. } else if (value.is<float>()) {
  263. MS_LOG(DEBUG) << "float";
  264. py::float_ ret = value.cast<float>();
  265. return std::move(ret);
  266. } else if (value.is<double>()) {
  267. MS_LOG(DEBUG) << "double";
  268. py::float_ ret = value.cast<double>();
  269. return std::move(ret);
  270. } else {
  271. MS_LOG(DEBUG) << "bool";
  272. py::bool_ ret = value.cast<bool>();
  273. return std::move(ret);
  274. }
  275. }
  276. py::object BuiltinsToPyData(const BaseRef &value) {
  277. if (utils::isa<int>(value)) {
  278. MS_LOG(DEBUG) << "int";
  279. py::int_ ret = utils::cast<int>(value);
  280. return std::move(ret);
  281. } else if (utils::isa<float>(value)) {
  282. MS_LOG(DEBUG) << "float";
  283. py::float_ ret = utils::cast<float>(value);
  284. return std::move(ret);
  285. } else if (utils::isa<double>(value)) {
  286. MS_LOG(DEBUG) << "double";
  287. py::float_ ret = utils::cast<double>(value);
  288. return std::move(ret);
  289. } else {
  290. MS_LOG(DEBUG) << "bool";
  291. py::bool_ ret = utils::cast<bool>(value);
  292. return std::move(ret);
  293. }
  294. }
  295. py::object VectorToPyData(const Any &value) {
  296. py::object ret;
  297. if (value.is<std::vector<tensor::TensorPtr>>()) {
  298. MS_LOG(DEBUG) << "vector_tensor";
  299. std::vector<tensor::TensorPtr> outputs;
  300. outputs = value.cast<std::vector<tensor::TensorPtr>>();
  301. py::tuple tensor_tuple(outputs.size());
  302. for (std::size_t i = 0; i < outputs.size(); ++i) {
  303. tensor_tuple[i] = *outputs[i];
  304. }
  305. ret = tensor_tuple;
  306. } else {
  307. MS_LOG(DEBUG) << "vector_any";
  308. auto value_list = value.cast<std::vector<Any>>();
  309. py::tuple any_tuple = py::tuple(value_list.size());
  310. size_t i = 0;
  311. for (auto &v : value_list) {
  312. any_tuple[i] = AnyToPyData(v);
  313. i++;
  314. }
  315. ret = any_tuple;
  316. }
  317. return ret;
  318. }
  319. py::object VectorRefToPyData(const VectorRef &value_list) {
  320. py::object ret;
  321. MS_LOG(DEBUG) << "vector_ref";
  322. size_t value_size = value_list.size();
  323. auto ref_tuple = py::tuple(value_size);
  324. for (size_t i = 0; i < value_size; i++) {
  325. ref_tuple[i] = BaseRefToPyData(value_list[i]);
  326. }
  327. ret = ref_tuple;
  328. return ret;
  329. }
  330. AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj) {
  331. if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) &&
  332. py::hasattr(type_obj, PYTHON_DTYPE_FLAG)) {
  333. auto ret_vec = shape_obj.cast<std::vector<int>>();
  334. auto ret_dtype = type_obj.cast<TypePtr>();
  335. MS_EXCEPTION_IF_NULL(ret_dtype);
  336. // if the size of shape list is empty, return an scalar abstract
  337. if (ret_vec.empty() && (!ret_dtype->isa<TensorType>())) {
  338. abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
  339. return abs_scalar;
  340. }
  341. AbstractBasePtr tensor = nullptr;
  342. if (ret_dtype->isa<TensorType>()) {
  343. auto tensor_type = type_obj.cast<TensorTypePtr>();
  344. MS_EXCEPTION_IF_NULL(tensor_type);
  345. tensor = std::make_shared<abstract::AbstractTensor>(tensor_type->element(), ret_vec);
  346. } else {
  347. tensor = std::make_shared<abstract::AbstractTensor>(ret_dtype, ret_vec);
  348. }
  349. return tensor;
  350. } else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
  351. py::tuple shape_tuple = shape_obj.cast<py::tuple>();
  352. py::tuple typeid_tuple = type_obj.cast<py::tuple>();
  353. AbstractBasePtrList ptr_list;
  354. for (size_t it = 0; it < shape_tuple.size(); ++it) {
  355. auto tensor_it = PyListDtype2AbstractTensor(shape_tuple[it], typeid_tuple[it]);
  356. ptr_list.push_back(tensor_it);
  357. }
  358. auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
  359. return tuple;
  360. } else if (shape_obj.is_none() && type_obj.is_none()) {
  361. // AbstractNone indicates there is no output for this CNode node.
  362. auto abstract_none = std::make_shared<abstract::AbstractNone>();
  363. return abstract_none;
  364. } else {
  365. MS_LOG(EXCEPTION) << "Python evaluator return invalid shape or type. " << (std::string)py::str(type_obj);
  366. }
  367. }
// Short-circuit graph execution when the graph's output is trivially known.
// Returns true (and fills *ret_val) when the output is a constant ValueNode
// or a Parameter that maps 1:1 onto one of the caller-supplied args; returns
// false when the graph must actually be executed.
// NOTE(review): the MS_LOG(EXCEPTION) branches below ("let graph to be
// executed") appear to be used as control flow that a caller catches to fall
// back to normal execution — confirm against the call sites.
bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
                                       const std::shared_ptr<py::object> &ret_val) {
  if (output->isa<ValueNode>()) {
    MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
    ValuePtr value = GetValueNode(output);
    *ret_val = ValuePtrToPyData(value);
    return true;
  }
  // Adapter will transform values in __init__() and construct() to parameters, this could cause
  // inputs (a.k.a args in current function) size less than parameters'.
  if (output->isa<Parameter>()) {
    MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute.";
    if (args.empty()) {
      MS_LOG(EXCEPTION) << "Inputs size is 0, let graph to be executed.";
    }
    // Find the right parameter as ret_val.
    auto func_graph = output->func_graph();
    MS_EXCEPTION_IF_NULL(func_graph);
    auto params = func_graph->parameters();
    if (params.empty()) {
      MS_EXCEPTION(UnknownError) << "Graph's parameters size is 0";
    }
    // The arg list must align positionally with the graph's parameters for
    // the index mapping below to be valid.
    if (args.size() != params.size()) {
      MS_LOG(EXCEPTION) << "Input size " << args.size() << " not equal to params size " << params.size()
                        << ", let graph to be executed.";
    }
    auto it = std::find(params.begin(), params.end(), output);
    if (it == params.end()) {
      MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters";
    }
    // Positional index of the output parameter; the matching arg is returned.
    size_t index = it - params.cbegin();
    if (index >= args.size()) {
      MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size() << ".";
    }
    *ret_val = args[index];
    return true;
  }
  return false;
}
  407. } // namespace mindspore