You cannot select more than 25 topics. A topic must start with a Chinese character, a letter, or a number; it can include dashes ('-') and can be up to 35 characters long.

convert_utils.cc 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/convert_utils.h"
  17. #include <vector>
  18. #include <string>
  19. #include <memory>
  20. #include <algorithm>
  21. #include <list>
  22. #include <utility>
  23. #include <cfloat>
  24. #include "pybind11/pybind11.h"
  25. #include "ir/meta_tensor.h"
  26. #include "pipeline/parse/parse.h"
  27. #include "pipeline/parse/parse_base.h"
  28. #include "ir/value.h"
  29. namespace mindspore {
  30. py::object BuiltinsToPyData(const Any &value);
  31. py::object BuiltinsToPyData(const BaseRef &value);
  32. py::object VectorToPyData(const Any &value);
  33. py::object VectorRefToPyData(const VectorRef &value);
// Convert a MindSpore ValuePtr into the corresponding Python object.
// Scalar immediates map to py::int_/py::float_/py::bool_/py::str; tensors,
// RefKeys and Types are converted through a one-element py::tuple so that
// pybind11's registered type casters produce the wrapped Python object;
// ValueTuple/ValueList/ValueSlice recurse element-wise.
// Raises (MS_LOG(EXCEPTION)) on a null input or an unsupported value type.
py::object ValuePtrToPyData(const ValuePtr &value) {
  if (value == nullptr) {
    MS_LOG(EXCEPTION) << "value is null";
  }
  py::object ret;
  if (value->isa<Int32Imm>()) {
    MS_LOG(DEBUG) << "int";
    py::int_ v = value->cast<Int32ImmPtr>()->value();
    ret = v;
  } else if (value->isa<UInt64Imm>()) {
    MS_LOG(DEBUG) << "uint64";
    py::int_ v = value->cast<UInt64ImmPtr>()->value();
    ret = v;
  } else if (value->isa<BoolImm>()) {
    MS_LOG(DEBUG) << "bool";
    py::bool_ v = value->cast<BoolImmPtr>()->value();
    ret = v;
  } else if (value->isa<FP64Imm>()) {
    // FP64 is checked before FP32; the isa<> chain dispatches on the concrete
    // immediate type, so ordering between the two float kinds matters only
    // for readability here.
    MS_LOG(DEBUG) << "double";
    py::float_ v = value->cast<FP64ImmPtr>()->value();
    ret = v;
  } else if (value->isa<FP32Imm>()) {
    MS_LOG(DEBUG) << "float";
    py::float_ v = value->cast<FP32ImmPtr>()->value();
    ret = v;
  } else if (value->isa<StringImm>()) {
    MS_LOG(DEBUG) << "String";
    py::str v = value->cast<StringImmPtr>()->value();
    ret = v;
  } else if (value->isa<tensor::Tensor>()) {
    MS_LOG(DEBUG) << "tensor";
    // Assigning the shared_ptr into a tuple slot invokes pybind11's type
    // caster; reading the slot back yields the converted py::object.
    py::tuple v(1);
    v[0] = value->cast<tensor::TensorPtr>();
    ret = v[0];
  } else if (value->isa<RefKey>()) {
    MS_LOG(DEBUG) << "RefKey";
    // Same tuple-slot conversion trick as the tensor branch above.
    py::tuple v(1);
    v[0] = value->cast<RefKeyPtr>();
    ret = v[0];
  } else if (value->isa<ValueTuple>()) {
    MS_LOG(DEBUG) << "tuple";
    // Recursively convert each element into a Python tuple.
    auto value_tuple = value->cast<ValueTuplePtr>()->value();
    py::tuple rets(value_tuple.size());
    size_t i = 0;
    for (auto &v : value_tuple) {
      rets[i] = ValuePtrToPyData(v);
      i++;
    }
    ret = rets;
  } else if (value->isa<ValueList>()) {
    MS_LOG(DEBUG) << "list";
    // Recursively convert each element into a Python list.
    auto value_list = value->cast<ValueListPtr>()->value();
    py::list rets(value_list.size());
    size_t i = 0;
    for (auto &v : value_list) {
      rets[i] = ValuePtrToPyData(v);
      i++;
    }
    ret = rets;
  } else if (value->isa<ValueSlice>()) {
    // Rebuild a Python slice object via the parser module helper, converting
    // start/stop/step individually first.
    auto slice = value->cast<ValueSlicePtr>();
    auto start = ValuePtrToPyData(slice->start());
    auto end = ValuePtrToPyData(slice->stop());
    auto step = ValuePtrToPyData(slice->step());
    ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end,
                                          step);
  } else if (value->isa<Type>()) {
    // Tuple-slot conversion trick again, for TypePtr.
    py::tuple v(1);
    v[0] = value->cast<TypePtr>();
    ret = v[0];
  } else if (value->isa<AnyValue>()) {
    ret = py::none();
  } else if (value->isa<None>()) {
    ret = py::none();
  } else {
    MS_LOG(EXCEPTION) << "Unsupported convert value: " << value->ToString() << " to a PyData.";
  }
  return ret;
}
  113. py::object AnyToPyData(const Any &value) {
  114. py::object ret;
  115. MS_LOG(DEBUG) << "AnyToPyData " << value.GetString();
  116. if (value.is<int>() || value.is<float>() || value.is<double>() || value.is<bool>()) {
  117. ret = BuiltinsToPyData(value);
  118. } else if (value.is<ValuePtr>()) {
  119. MS_LOG(DEBUG) << "ValuePtr";
  120. ValuePtr v = value.cast<ValuePtr>();
  121. ret = ValuePtrToPyData(v);
  122. } else if (value.is<tensor::TensorPtr>()) {
  123. MS_LOG(DEBUG) << "tensor";
  124. py::tuple v(1);
  125. v[0] = value.cast<tensor::TensorPtr>();
  126. ret = v[0];
  127. } else if (value.is<py::object>()) {
  128. MS_LOG(DEBUG) << "py obj";
  129. ret = value.cast<py::object>();
  130. } else if (value.is<std::vector<tensor::TensorPtr>>() || value.is<std::vector<Any>>()) {
  131. ret = VectorToPyData(value);
  132. } else if (value.is<std::list<Any>>()) {
  133. MS_LOG(DEBUG) << "list_any";
  134. auto value_list = value.cast<std::list<Any>>();
  135. py::list rets = py::list();
  136. for (auto &v : value_list) {
  137. rets.append(AnyToPyData(v));
  138. }
  139. ret = rets;
  140. } else if (value.is<std::vector<Any>>()) {
  141. auto value_list = value.cast<std::vector<Any>>();
  142. py::tuple rets(value_list.size());
  143. for (size_t i = 0; i < value_list.size(); i++) {
  144. rets[i] = AnyToPyData(value_list[i]);
  145. }
  146. ret = rets;
  147. } else if (value.is<TypePtr>()) {
  148. py::tuple v(1);
  149. v[0] = value.cast<TypePtr>();
  150. ret = v[0];
  151. } else {
  152. MS_LOG(EXCEPTION) << "value is not support type";
  153. }
  154. return ret;
  155. }
  156. py::object BaseRefToPyData(const BaseRef &value) {
  157. py::object ret;
  158. MS_LOG(DEBUG) << "BaseRefToPyData " << common::SafeCStr(value.ToString());
  159. if (utils::isa<int>(value) || utils::isa<float>(value) || utils::isa<double>(value) || utils::isa<bool>(value)) {
  160. ret = BuiltinsToPyData(value);
  161. } else if (utils::isa<ValuePtr>(value)) {
  162. MS_LOG(DEBUG) << "ValuePtr";
  163. ValuePtr v = utils::cast<ValuePtr>(value);
  164. ret = ValuePtrToPyData(v);
  165. } else if (utils::isa<tensor::TensorPtr>(value)) {
  166. MS_LOG(DEBUG) << "tensor";
  167. py::tuple v(1);
  168. v[0] = utils::cast<tensor::TensorPtr>(value);
  169. ret = v[0];
  170. } else if (utils::isa<PyObjectRef>(value)) {
  171. MS_LOG(DEBUG) << "py obj";
  172. PyObjectRef py_ref = utils::cast<PyObjectRef>(value);
  173. ret = py_ref.object_;
  174. } else if (utils::isa<VectorRef>(value)) {
  175. auto vec_ref = utils::cast<VectorRef>(value);
  176. ret = VectorRefToPyData(vec_ref);
  177. } else if (utils::isa<TypePtr>(value)) {
  178. py::tuple v(1);
  179. v[0] = utils::cast<TypePtr>(value);
  180. ret = v[0];
  181. } else {
  182. MS_LOG(EXCEPTION) << "value is not support type";
  183. }
  184. return ret;
  185. }
// Coerce a ValuePtr to a C++ bool via *value.
// Returns true when the conversion succeeded, false (with a WARNING log) when
// the value's type is unsupported; *value is untouched on failure.
bool ValueToBool(const ValuePtr &v, bool *value) {
  MS_EXCEPTION_IF_NULL(v);
  if (v->isa<BoolImm>()) {
    *value = v->cast<BoolImmPtr>()->value();
  } else if (v->isa<Int32Imm>()) {
    *value = v->cast<Int32ImmPtr>()->value() == 0 ? false : true;
  } else if (v->isa<UInt32Imm>()) {
    *value = v->cast<UInt32ImmPtr>()->value() == 0 ? false : true;
  } else if (v->isa<FP32Imm>()) {
    // NOTE(review): exact float == 0 comparison here, whereas BaseRefToBool
    // below uses an FLT_EPSILON/DBL_EPSILON tolerance — confirm whether the
    // two are intended to differ.
    *value = v->cast<FP32ImmPtr>()->value() == 0 ? false : true;
  } else if (v->isa<FP64Imm>()) {
    *value = v->cast<FP64ImmPtr>()->value() == 0 ? false : true;
  } else if (v->isa<tensor::Tensor>()) {
    auto tensor = v->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    // Reinterprets the raw buffer as bool — presumably the tensor is
    // bool-typed and non-empty; only element 0 is consulted. TODO confirm.
    bool *tensor_data = static_cast<bool *>(tensor->data_c());
    // maybe need to support if tensor is a bool array
    auto vb = tensor_data[0];
    *value = vb;
  } else {
    MS_LOG(WARNING) << "value is not supported to cast to be bool";
    return false;
  }
  return true;
}
  211. bool BaseRefToBool(const BaseRef &v, bool *value) {
  212. if (utils::isa<ValuePtr>(v)) {
  213. return ValueToBool(utils::cast<ValuePtr>(v), value);
  214. } else if (utils::isa<bool>(v)) {
  215. auto vb = utils::cast<bool>(v);
  216. if (vb == true) {
  217. *value = true;
  218. } else {
  219. *value = false;
  220. }
  221. } else if (utils::isa<int>(v)) {
  222. auto vb = utils::cast<int>(v);
  223. if (vb == 0) {
  224. *value = false;
  225. } else {
  226. *value = true;
  227. }
  228. } else if (utils::isa<unsigned int>(v)) {
  229. auto vb = utils::cast<unsigned int>(v);
  230. if (vb == 0) {
  231. *value = false;
  232. } else {
  233. *value = true;
  234. }
  235. } else if (utils::isa<float>(v)) {
  236. auto vb = utils::cast<float>(v);
  237. if (vb >= -FLT_EPSILON && vb <= FLT_EPSILON) {
  238. *value = false;
  239. } else {
  240. *value = true;
  241. }
  242. } else if (utils::isa<double>(v)) {
  243. auto vb = utils::cast<double>(v);
  244. if (vb >= -DBL_EPSILON && vb <= DBL_EPSILON) {
  245. *value = false;
  246. } else {
  247. *value = true;
  248. }
  249. } else {
  250. MS_LOG(DEBUG) << "value is not supported to cast to be bool";
  251. return false;
  252. }
  253. return true;
  254. }
  255. py::object BuiltinsToPyData(const Any &value) {
  256. if (value.is<int>()) {
  257. MS_LOG(DEBUG) << "int";
  258. py::int_ ret = value.cast<int>();
  259. return std::move(ret);
  260. } else if (value.is<float>()) {
  261. MS_LOG(DEBUG) << "float";
  262. py::float_ ret = value.cast<float>();
  263. return std::move(ret);
  264. } else if (value.is<double>()) {
  265. MS_LOG(DEBUG) << "double";
  266. py::float_ ret = value.cast<double>();
  267. return std::move(ret);
  268. } else {
  269. MS_LOG(DEBUG) << "bool";
  270. py::bool_ ret = value.cast<bool>();
  271. return std::move(ret);
  272. }
  273. }
  274. py::object BuiltinsToPyData(const BaseRef &value) {
  275. if (utils::isa<int>(value)) {
  276. MS_LOG(DEBUG) << "int";
  277. py::int_ ret = utils::cast<int>(value);
  278. return std::move(ret);
  279. } else if (utils::isa<float>(value)) {
  280. MS_LOG(DEBUG) << "float";
  281. py::float_ ret = utils::cast<float>(value);
  282. return std::move(ret);
  283. } else if (utils::isa<double>(value)) {
  284. MS_LOG(DEBUG) << "double";
  285. py::float_ ret = utils::cast<double>(value);
  286. return std::move(ret);
  287. } else {
  288. MS_LOG(DEBUG) << "bool";
  289. py::bool_ ret = utils::cast<bool>(value);
  290. return std::move(ret);
  291. }
  292. }
  293. py::object VectorToPyData(const Any &value) {
  294. py::object ret;
  295. if (value.is<std::vector<tensor::TensorPtr>>()) {
  296. MS_LOG(DEBUG) << "vector_tensor";
  297. std::vector<tensor::TensorPtr> outputs;
  298. outputs = value.cast<std::vector<tensor::TensorPtr>>();
  299. py::tuple tensor_tuple(outputs.size());
  300. for (std::size_t i = 0; i < outputs.size(); ++i) {
  301. tensor_tuple[i] = *outputs[i];
  302. }
  303. ret = tensor_tuple;
  304. } else {
  305. MS_LOG(DEBUG) << "vector_any";
  306. auto value_list = value.cast<std::vector<Any>>();
  307. py::tuple any_tuple = py::tuple(value_list.size());
  308. size_t i = 0;
  309. for (auto &v : value_list) {
  310. any_tuple[i] = AnyToPyData(v);
  311. i++;
  312. }
  313. ret = any_tuple;
  314. }
  315. return ret;
  316. }
  317. py::object VectorRefToPyData(const VectorRef &value_list) {
  318. py::object ret;
  319. MS_LOG(DEBUG) << "vector_ref";
  320. size_t value_size = value_list.size();
  321. auto ref_tuple = py::tuple(value_size);
  322. for (size_t i = 0; i < value_size; i++) {
  323. ref_tuple[i] = BaseRefToPyData(value_list[i]);
  324. }
  325. ret = ref_tuple;
  326. return ret;
  327. }
// Build an abstract value from a Python (shape, type) pair returned by an
// evaluator:
//  - list/tuple shape + dtype object -> AbstractScalar (empty shape, non-
//    tensor dtype) or AbstractTensor;
//  - tuple shape + tuple type -> AbstractTuple, recursing element-wise;
//  - None + None -> AbstractNone (the node has no output);
//  - anything else raises (MS_LOG(EXCEPTION)).
AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj) {
  if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) &&
      py::hasattr(type_obj, PYTHON_DTYPE_FLAG)) {
    auto ret_vec = shape_obj.cast<std::vector<int>>();
    auto ret_dtype = type_obj.cast<TypePtr>();
    MS_EXCEPTION_IF_NULL(ret_dtype);
    // if the size of shape list is empty, return an scalar abstract
    if (ret_vec.empty() && (!ret_dtype->isa<TensorType>())) {
      abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
      return abs_scalar;
    }
    AbstractBasePtr tensor = nullptr;
    if (ret_dtype->isa<TensorType>()) {
      // TensorType wraps an element dtype; unwrap it for the abstract tensor.
      auto tensor_type = type_obj.cast<TensorTypePtr>();
      MS_EXCEPTION_IF_NULL(tensor_type);
      tensor = std::make_shared<abstract::AbstractTensor>(tensor_type->element(), ret_vec);
    } else {
      tensor = std::make_shared<abstract::AbstractTensor>(ret_dtype, ret_vec);
    }
    return tensor;
  } else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
    // Multiple outputs: shapes and types are parallel tuples — presumably of
    // equal length (typeid_tuple[it] is indexed by shape_tuple's size).
    py::tuple shape_tuple = shape_obj.cast<py::tuple>();
    py::tuple typeid_tuple = type_obj.cast<py::tuple>();
    AbstractBasePtrList ptr_list;
    for (size_t it = 0; it < shape_tuple.size(); ++it) {
      auto tensor_it = PyListDtype2AbstractTensor(shape_tuple[it], typeid_tuple[it]);
      ptr_list.push_back(tensor_it);
    }
    auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
    return tuple;
  } else if (shape_obj.is_none() && type_obj.is_none()) {
    // AbstractNone indicates there is no output for this CNode node.
    auto abstract_none = std::make_shared<abstract::AbstractNone>();
    return abstract_none;
  } else {
    MS_LOG(EXCEPTION) << "Python evaluator return invalid shape or type. " << (std::string)py::str(type_obj);
  }
}
  366. bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
  367. const std::shared_ptr<py::object> &ret_val) {
  368. if (output->isa<ValueNode>()) {
  369. MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
  370. ValuePtr value = GetValueNode(output);
  371. *ret_val = ValuePtrToPyData(value);
  372. return true;
  373. }
  374. // Adapter will transform values in __init__() and construct() to parameters, this could cause
  375. // inputs (a.k.a args in current function) size less than parameters'.
  376. if (output->isa<Parameter>()) {
  377. MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute.";
  378. if (args.empty()) {
  379. MS_LOG(EXCEPTION) << "Inputs size is 0, let graph to be executed.";
  380. }
  381. // Find the right parameter as ret_val.
  382. auto func_graph = output->func_graph();
  383. MS_EXCEPTION_IF_NULL(func_graph);
  384. auto params = func_graph->parameters();
  385. if (params.empty()) {
  386. MS_EXCEPTION(UnknownError) << "Graph's parameters size is 0";
  387. }
  388. if (args.size() != params.size()) {
  389. MS_LOG(EXCEPTION) << "Input size " << args.size() << " not equal to params size " << params.size()
  390. << ", let graph to be executed.";
  391. }
  392. auto it = std::find(params.begin(), params.end(), output);
  393. if (it == params.end()) {
  394. MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters";
  395. }
  396. size_t index = it - params.cbegin();
  397. if (index >= args.size()) {
  398. MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size() << ".";
  399. }
  400. *ret_val = args[index];
  401. return true;
  402. }
  403. return false;
  404. }
  405. } // namespace mindspore