You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

convert_utils.cc 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/convert_utils.h"
  17. #include <vector>
  18. #include <string>
  19. #include <memory>
  20. #include <algorithm>
  21. #include <list>
  22. #include <utility>
  23. #include <cfloat>
  24. #include "pybind11/pybind11.h"
  25. #include "ir/meta_tensor.h"
  26. #include "pipeline/parse/parse.h"
  27. #include "ir/value.h"
  28. namespace mindspore {
  29. py::object BuiltinsToPyData(const Any &value);
  30. py::object BuiltinsToPyData(const BaseRef &value);
  31. py::object VectorToPyData(const Any &value);
  32. py::object VectorRefToPyData(const VectorRef &value);
// Converts a MindSpore ValuePtr into the corresponding Python object.
// Dispatches on the concrete Value subtype; scalar immediates map to Python
// scalars, ValueTuple/ValueList recurse element-wise, and AnyValue/None map to
// Python None. Raises via MS_LOG(EXCEPTION) on null input or on a Value
// subtype with no conversion listed here.
py::object ValuePtrToPyData(const ValuePtr &value) {
  if (value == nullptr) {
    MS_LOG(EXCEPTION) << "value is null";
  }
  py::object ret;
  if (value->isa<Int32Imm>()) {
    MS_LOG(DEBUG) << "int";
    py::int_ v = value->cast<Int32ImmPtr>()->value();
    ret = v;
  } else if (value->isa<UInt64Imm>()) {
    MS_LOG(DEBUG) << "uint64";
    py::int_ v = value->cast<UInt64ImmPtr>()->value();
    ret = v;
  } else if (value->isa<BoolImm>()) {
    MS_LOG(DEBUG) << "bool";
    py::bool_ v = value->cast<BoolImmPtr>()->value();
    ret = v;
  } else if (value->isa<FP64Imm>()) {
    MS_LOG(DEBUG) << "double";
    py::float_ v = value->cast<FP64ImmPtr>()->value();
    ret = v;
  } else if (value->isa<FP32Imm>()) {
    MS_LOG(DEBUG) << "float";
    py::float_ v = value->cast<FP32ImmPtr>()->value();
    ret = v;
  } else if (value->isa<StringImm>()) {
    MS_LOG(DEBUG) << "String";
    py::str v = value->cast<StringImmPtr>()->value();
    ret = v;
  } else if (value->isa<tensor::Tensor>()) {
    MS_LOG(DEBUG) << "tensor";
    // NOTE(review): assigning the C++ shared pointer into a one-element tuple
    // and reading it back appears to be used to route through pybind11's
    // registered type caster for TensorPtr — confirm before simplifying.
    py::tuple v(1);
    v[0] = value->cast<tensor::TensorPtr>();
    ret = v[0];
  } else if (value->isa<RefKey>()) {
    MS_LOG(DEBUG) << "RefKey";
    // Same one-element-tuple conversion trick as the tensor branch above.
    py::tuple v(1);
    v[0] = value->cast<RefKeyPtr>();
    ret = v[0];
  } else if (value->isa<ValueTuple>()) {
    MS_LOG(DEBUG) << "tuple";
    // Recursively convert each element into a Python tuple.
    auto value_tuple = value->cast<ValueTuplePtr>()->value();
    py::tuple rets(value_tuple.size());
    size_t i = 0;
    for (auto &v : value_tuple) {
      rets[i] = ValuePtrToPyData(v);
      i++;
    }
    ret = rets;
  } else if (value->isa<ValueList>()) {
    MS_LOG(DEBUG) << "list";
    // Recursively convert each element into a Python list.
    auto value_list = value->cast<ValueListPtr>()->value();
    py::list rets(value_list.size());
    size_t i = 0;
    for (auto &v : value_list) {
      rets[i] = ValuePtrToPyData(v);
      i++;
    }
    ret = rets;
  } else if (value->isa<Type>()) {
    // Same one-element-tuple conversion trick as the tensor branch above.
    py::tuple v(1);
    v[0] = value->cast<TypePtr>();
    ret = v[0];
  } else if (value->isa<AnyValue>()) {
    ret = py::none();
  } else if (value->isa<None>()) {
    ret = py::none();
  } else {
    MS_LOG(EXCEPTION) << "Unsupported convert value: " << value->ToString() << " to a PyData.";
  }
  return ret;
}
  105. py::object AnyToPyData(const Any &value) {
  106. py::object ret;
  107. MS_LOG(DEBUG) << "AnyToPyData " << value.GetString();
  108. if (value.is<int>() || value.is<float>() || value.is<double>() || value.is<bool>()) {
  109. ret = BuiltinsToPyData(value);
  110. } else if (value.is<ValuePtr>()) {
  111. MS_LOG(DEBUG) << "ValuePtr";
  112. ValuePtr v = value.cast<ValuePtr>();
  113. ret = ValuePtrToPyData(v);
  114. } else if (value.is<tensor::TensorPtr>()) {
  115. MS_LOG(DEBUG) << "tensor";
  116. py::tuple v(1);
  117. v[0] = value.cast<tensor::TensorPtr>();
  118. ret = v[0];
  119. } else if (value.is<py::object>()) {
  120. MS_LOG(DEBUG) << "py obj";
  121. ret = value.cast<py::object>();
  122. } else if (value.is<std::vector<tensor::TensorPtr>>() || value.is<std::vector<Any>>()) {
  123. ret = VectorToPyData(value);
  124. } else if (value.is<std::list<Any>>()) {
  125. MS_LOG(DEBUG) << "list_any";
  126. auto value_list = value.cast<std::list<Any>>();
  127. py::list rets = py::list();
  128. for (auto &v : value_list) {
  129. rets.append(AnyToPyData(v));
  130. }
  131. ret = rets;
  132. } else if (value.is<std::vector<Any>>()) {
  133. auto value_list = value.cast<std::vector<Any>>();
  134. py::tuple rets(value_list.size());
  135. for (size_t i = 0; i < value_list.size(); i++) {
  136. rets[i] = AnyToPyData(value_list[i]);
  137. }
  138. ret = rets;
  139. } else if (value.is<TypePtr>()) {
  140. py::tuple v(1);
  141. v[0] = value.cast<TypePtr>();
  142. ret = v[0];
  143. } else {
  144. MS_LOG(EXCEPTION) << "value is not support type";
  145. }
  146. return ret;
  147. }
  148. py::object BaseRefToPyData(const BaseRef &value) {
  149. py::object ret;
  150. MS_LOG(DEBUG) << "BaseRefToPyData " << common::SafeCStr(value.ToString());
  151. if (utils::isa<int>(value) || utils::isa<float>(value) || utils::isa<double>(value) || utils::isa<bool>(value)) {
  152. ret = BuiltinsToPyData(value);
  153. } else if (utils::isa<ValuePtr>(value)) {
  154. MS_LOG(DEBUG) << "ValuePtr";
  155. ValuePtr v = utils::cast<ValuePtr>(value);
  156. ret = ValuePtrToPyData(v);
  157. } else if (utils::isa<tensor::TensorPtr>(value)) {
  158. MS_LOG(DEBUG) << "tensor";
  159. py::tuple v(1);
  160. v[0] = utils::cast<tensor::TensorPtr>(value);
  161. ret = v[0];
  162. } else if (utils::isa<PyObjectRef>(value)) {
  163. MS_LOG(DEBUG) << "py obj";
  164. PyObjectRef py_ref = utils::cast<PyObjectRef>(value);
  165. ret = py_ref.object_;
  166. } else if (utils::isa<VectorRef>(value)) {
  167. auto vec_ref = utils::cast<VectorRef>(value);
  168. ret = VectorRefToPyData(vec_ref);
  169. } else if (utils::isa<TypePtr>(value)) {
  170. py::tuple v(1);
  171. v[0] = utils::cast<TypePtr>(value);
  172. ret = v[0];
  173. } else {
  174. MS_LOG(EXCEPTION) << "value is not support type";
  175. }
  176. return ret;
  177. }
  178. bool ValueToBool(const ValuePtr &v, bool *value) {
  179. MS_EXCEPTION_IF_NULL(v);
  180. if (v->isa<BoolImm>()) {
  181. *value = v->cast<BoolImmPtr>()->value();
  182. } else if (v->isa<Int32Imm>()) {
  183. *value = v->cast<Int32ImmPtr>()->value() == 0 ? false : true;
  184. } else if (v->isa<UInt32Imm>()) {
  185. *value = v->cast<UInt32ImmPtr>()->value() == 0 ? false : true;
  186. } else if (v->isa<FP32Imm>()) {
  187. *value = v->cast<FP32ImmPtr>()->value() == 0 ? false : true;
  188. } else if (v->isa<FP64Imm>()) {
  189. *value = v->cast<FP64ImmPtr>()->value() == 0 ? false : true;
  190. } else if (v->isa<tensor::Tensor>()) {
  191. auto tensor = v->cast<tensor::TensorPtr>();
  192. MS_EXCEPTION_IF_NULL(tensor);
  193. bool *tensor_data = static_cast<bool *>(tensor->data_c());
  194. // maybe need to support if tensor is a bool array
  195. auto vb = tensor_data[0];
  196. *value = vb;
  197. } else {
  198. MS_LOG(WARNING) << "value is not supported to cast to be bool";
  199. return false;
  200. }
  201. return true;
  202. }
  203. bool BaseRefToBool(const BaseRef &v, bool *value) {
  204. if (utils::isa<ValuePtr>(v)) {
  205. return ValueToBool(utils::cast<ValuePtr>(v), value);
  206. } else if (utils::isa<bool>(v)) {
  207. auto vb = utils::cast<bool>(v);
  208. if (vb == true) {
  209. *value = true;
  210. } else {
  211. *value = false;
  212. }
  213. } else if (utils::isa<int>(v)) {
  214. auto vb = utils::cast<int>(v);
  215. if (vb == 0) {
  216. *value = false;
  217. } else {
  218. *value = true;
  219. }
  220. } else if (utils::isa<unsigned int>(v)) {
  221. auto vb = utils::cast<unsigned int>(v);
  222. if (vb == 0) {
  223. *value = false;
  224. } else {
  225. *value = true;
  226. }
  227. } else if (utils::isa<float>(v)) {
  228. auto vb = utils::cast<float>(v);
  229. if (vb >= -FLT_EPSILON && vb <= FLT_EPSILON) {
  230. *value = false;
  231. } else {
  232. *value = true;
  233. }
  234. } else if (utils::isa<double>(v)) {
  235. auto vb = utils::cast<double>(v);
  236. if (vb >= -DBL_EPSILON && vb <= DBL_EPSILON) {
  237. *value = false;
  238. } else {
  239. *value = true;
  240. }
  241. } else {
  242. MS_LOG(DEBUG) << "value is not supported to cast to be bool";
  243. return false;
  244. }
  245. return true;
  246. }
  247. py::object BuiltinsToPyData(const Any &value) {
  248. if (value.is<int>()) {
  249. MS_LOG(DEBUG) << "int";
  250. py::int_ ret = value.cast<int>();
  251. return std::move(ret);
  252. } else if (value.is<float>()) {
  253. MS_LOG(DEBUG) << "float";
  254. py::float_ ret = value.cast<float>();
  255. return std::move(ret);
  256. } else if (value.is<double>()) {
  257. MS_LOG(DEBUG) << "double";
  258. py::float_ ret = value.cast<double>();
  259. return std::move(ret);
  260. } else {
  261. MS_LOG(DEBUG) << "bool";
  262. py::bool_ ret = value.cast<bool>();
  263. return std::move(ret);
  264. }
  265. }
  266. py::object BuiltinsToPyData(const BaseRef &value) {
  267. if (utils::isa<int>(value)) {
  268. MS_LOG(DEBUG) << "int";
  269. py::int_ ret = utils::cast<int>(value);
  270. return std::move(ret);
  271. } else if (utils::isa<float>(value)) {
  272. MS_LOG(DEBUG) << "float";
  273. py::float_ ret = utils::cast<float>(value);
  274. return std::move(ret);
  275. } else if (utils::isa<double>(value)) {
  276. MS_LOG(DEBUG) << "double";
  277. py::float_ ret = utils::cast<double>(value);
  278. return std::move(ret);
  279. } else {
  280. MS_LOG(DEBUG) << "bool";
  281. py::bool_ ret = utils::cast<bool>(value);
  282. return std::move(ret);
  283. }
  284. }
  285. py::object VectorToPyData(const Any &value) {
  286. py::object ret;
  287. if (value.is<std::vector<tensor::TensorPtr>>()) {
  288. MS_LOG(DEBUG) << "vector_tensor";
  289. std::vector<tensor::TensorPtr> outputs;
  290. outputs = value.cast<std::vector<tensor::TensorPtr>>();
  291. py::tuple tensor_tuple(outputs.size());
  292. for (std::size_t i = 0; i < outputs.size(); ++i) {
  293. tensor_tuple[i] = *outputs[i];
  294. }
  295. ret = tensor_tuple;
  296. } else {
  297. MS_LOG(DEBUG) << "vector_any";
  298. auto value_list = value.cast<std::vector<Any>>();
  299. py::tuple any_tuple = py::tuple(value_list.size());
  300. size_t i = 0;
  301. for (auto &v : value_list) {
  302. any_tuple[i] = AnyToPyData(v);
  303. i++;
  304. }
  305. ret = any_tuple;
  306. }
  307. return ret;
  308. }
  309. py::object VectorRefToPyData(const VectorRef &value_list) {
  310. py::object ret;
  311. MS_LOG(DEBUG) << "vector_ref";
  312. size_t value_size = value_list.size();
  313. py::tuple ref_tuple = py::tuple(value_size);
  314. for (size_t i = 0; i < value_size; i++) {
  315. ref_tuple[i] = BaseRefToPyData(value_list[i]);
  316. }
  317. ret = ref_tuple;
  318. return ret;
  319. }
// Builds an abstract value from a python-side (shape, type) pair returned by a
// Python evaluator. Three cases, checked in order (the order matters: the
// first branch also matches tuples, so it must run before the tuple/tuple
// branch):
//   1. shape is a list/tuple and type has the dtype flag -> AbstractScalar
//      (empty shape, non-tensor dtype) or AbstractTensor;
//   2. shape and type are both tuples -> AbstractTuple built by recursing
//      element-wise (assumes the two tuples have equal length — the loop
//      indexes typeid_tuple with shape_tuple's size);
//   3. both None -> AbstractNone (the node has no output).
// Anything else raises via MS_LOG(EXCEPTION).
AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj) {
  if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) &&
      py::hasattr(type_obj, PYTHON_DTYPE_FLAG)) {
    auto ret_vec = shape_obj.cast<std::vector<int>>();
    auto ret_dtype = type_obj.cast<TypePtr>();
    MS_EXCEPTION_IF_NULL(ret_dtype);
    // if the size of shape list is empty, return an scalar abstract
    if (ret_vec.empty() && (!ret_dtype->isa<TensorType>())) {
      abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
      return abs_scalar;
    }
    AbstractBasePtr tensor = nullptr;
    if (ret_dtype->isa<TensorType>()) {
      // TensorType wraps its element dtype; unwrap it for the abstract tensor.
      auto tensor_type = type_obj.cast<TensorTypePtr>();
      MS_EXCEPTION_IF_NULL(tensor_type);
      tensor = std::make_shared<abstract::AbstractTensor>(tensor_type->element(), ret_vec);
    } else {
      tensor = std::make_shared<abstract::AbstractTensor>(ret_dtype, ret_vec);
    }
    return tensor;
  } else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
    py::tuple shape_tuple = shape_obj.cast<py::tuple>();
    py::tuple typeid_tuple = type_obj.cast<py::tuple>();
    AbstractBasePtrList ptr_list;
    for (size_t it = 0; it < shape_tuple.size(); ++it) {
      auto tensor_it = PyListDtype2AbstractTensor(shape_tuple[it], typeid_tuple[it]);
      ptr_list.push_back(tensor_it);
    }
    auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
    return tuple;
  } else if (shape_obj.is_none() && type_obj.is_none()) {
    // AbstractNone indicates there is no output for this CNode node.
    auto abstract_none = std::make_shared<abstract::AbstractNone>();
    return abstract_none;
  } else {
    MS_LOG(EXCEPTION) << "Python evaluator return invalid shape or type. " << (std::string)py::str(type_obj);
  }
}
// Checks whether the graph's output can be produced without executing the
// graph: when the output node is a constant (ValueNode) or one of the graph's
// input Parameters, fills *ret_val with the resolved Python object and
// returns true; otherwise returns false so the caller executes the graph.
// @param output  the graph's output node.
// @param args    python-side positional inputs; args[i] corresponds to the
//                graph parameter at position i.
// @param ret_val receives the resolved Python object on success.
bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
                                       const std::shared_ptr<py::object> &ret_val) {
  if (output->isa<ValueNode>()) {
    MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
    ValuePtr value = GetValueNode(output);
    *ret_val = ValuePtrToPyData(value);
    return true;
  }
  // Adapter will transform values in __init__() and construct() to parameters, this could cause
  // inputs (a.k.a args in current function) size less than parameters'.
  if (output->isa<Parameter>()) {
    MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute.";
    if (args.empty()) {
      MS_LOG(EXCEPTION) << "Inputs size is 0, let graph to be executed.";
    }
    // Find the right parameter as ret_val.
    auto func_graph = output->func_graph();
    MS_EXCEPTION_IF_NULL(func_graph);
    auto params = func_graph->parameters();
    if (params.empty()) {
      MS_EXCEPTION(UnknownError) << "Graph's parameters size is 0";
    }
    if (args.size() != params.size()) {
      MS_LOG(EXCEPTION) << "Input size " << args.size() << " not equal to params size " << params.size()
                        << ", let graph to be executed.";
    }
    auto it = std::find(params.begin(), params.end(), output);
    if (it == params.end()) {
      MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters";
    }
    // The parameter's position maps directly onto args: args[index] is the
    // caller-supplied value for this parameter.
    size_t index = it - params.cbegin();
    if (index >= args.size()) {
      // Defensive: cannot happen after the size-equality check above, but
      // guards against out-of-range tuple access.
      MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size() << ".";
    }
    *ret_val = args[index];
    return true;
  }
  return false;
}
  397. } // namespace mindspore