You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

convert_utils.cc 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/convert_utils.h"
  17. #include <vector>
  18. #include <string>
  19. #include <memory>
  20. #include <algorithm>
  21. #include <list>
  22. #include <utility>
  23. #include <cfloat>
  24. #include "pybind11/pybind11.h"
  25. #include "pipeline/static_analysis/abstract_value.h"
  26. #include "pipeline/parse/parse.h"
  27. #include "pipeline/parse/parse_base.h"
  28. #include "ir/value.h"
  29. #include "ir/tensor.h"
  30. #include "ir/param_value.h"
  31. #include "utils/base_ref_extends.h"
  32. namespace mindspore {
  33. py::object BuiltinsToPyData(const Any &value);
  34. py::object BuiltinsToPyData(const BaseRef &value);
  35. py::object VectorToPyData(const Any &value);
  36. py::object VectorRefToPyData(const VectorRef &value);
// Converts a mindspore Value into the equivalent Python object (pybind11).
//
// Scalar immediates map to py::int_/py::float_/py::bool_/py::str.  Tensor,
// MetaTensor, RefKey and Type values are routed through a 1-element
// py::tuple so pybind11's registered type casters produce the Python
// object.  ValueTuple/ValueList recurse element-wise; ValueSlice is rebuilt
// by calling the Python-side parse helper.  AnyValue/None become py::none.
// Unsupported values only log at INFO level and return a default-constructed
// (null) py::object.
// NOTE(review): branch order may matter for subtype-related isa<> checks
// (e.g. Tensor is tested before MetaTensor) — confirm the class hierarchy
// in ir/tensor.h before reordering.
py::object ValuePtrToPyData(const ValuePtr &value) {
  if (value == nullptr) {
    MS_LOG(EXCEPTION) << "value is null";
  }
  py::object ret;
  if (value->isa<Int32Imm>()) {
    MS_LOG(DEBUG) << "int";
    py::int_ v = value->cast<Int32ImmPtr>()->value();
    ret = v;
  } else if (value->isa<UInt64Imm>()) {
    MS_LOG(DEBUG) << "uint64";
    py::int_ v = value->cast<UInt64ImmPtr>()->value();
    ret = v;
  } else if (value->isa<BoolImm>()) {
    MS_LOG(DEBUG) << "bool";
    py::bool_ v = value->cast<BoolImmPtr>()->value();
    ret = v;
  } else if (value->isa<FP64Imm>()) {
    MS_LOG(DEBUG) << "double";
    py::float_ v = value->cast<FP64ImmPtr>()->value();
    ret = v;
  } else if (value->isa<FP32Imm>()) {
    MS_LOG(DEBUG) << "float";
    py::float_ v = value->cast<FP32ImmPtr>()->value();
    ret = v;
  } else if (value->isa<StringImm>()) {
    MS_LOG(DEBUG) << "String";
    py::str v = value->cast<StringImmPtr>()->value();
    ret = v;
  } else if (value->isa<tensor::Tensor>()) {
    MS_LOG(DEBUG) << "tensor";
    // Assigning the shared_ptr into a tuple slot invokes the pybind11 caster.
    py::tuple v(1);
    v[0] = value->cast<tensor::TensorPtr>();
    ret = v[0];
  } else if (value->isa<tensor::MetaTensor>()) {
    MS_LOG(DEBUG) << "MetaTensor";
    py::tuple v(1);
    v[0] = value->cast<tensor::MetaTensorPtr>();
    ret = v[0];
  } else if (value->isa<RefKey>()) {
    MS_LOG(DEBUG) << "RefKey";
    py::tuple v(1);
    v[0] = value->cast<RefKeyPtr>();
    ret = v[0];
  } else if (value->isa<ValueTuple>()) {
    MS_LOG(DEBUG) << "tuple";
    auto value_tuple = value->cast<ValueTuplePtr>()->value();
    py::tuple rets(value_tuple.size());
    size_t i = 0;
    for (auto &v : value_tuple) {
      rets[i] = ValuePtrToPyData(v);
      i++;
    }
    ret = rets;
  } else if (value->isa<ValueList>()) {
    MS_LOG(DEBUG) << "list";
    auto value_list = value->cast<ValueListPtr>()->value();
    py::list rets(value_list.size());
    size_t i = 0;
    for (auto &v : value_list) {
      rets[i] = ValuePtrToPyData(v);
      i++;
    }
    ret = rets;
  } else if (value->isa<Ellipsis>()) {
    ret = py::ellipsis();
  } else if (value->isa<ValueSlice>()) {
    // Rebuild a Python slice object from its converted start/stop/step.
    auto slice = value->cast<ValueSlicePtr>();
    auto start = ValuePtrToPyData(slice->start());
    auto end = ValuePtrToPyData(slice->stop());
    auto step = ValuePtrToPyData(slice->step());
    ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end,
                                          step);
  } else if (value->isa<Type>()) {
    py::tuple v(1);
    v[0] = value->cast<TypePtr>();
    ret = v[0];
  } else if (value->isa<AnyValue>()) {
    ret = py::none();
  } else if (value->isa<None>()) {
    ret = py::none();
  } else {
    // Deliberately non-fatal: callers receive a null py::object.
    MS_LOG(INFO) << "Unsupported convert value: " << value->ToString() << " to a PyData.";
  }
  return ret;
}
  123. py::object AnyToPyData(const Any &value) {
  124. py::object ret;
  125. MS_LOG(DEBUG) << "AnyToPyData " << value.GetString();
  126. if (value.is<int>() || value.is<float>() || value.is<double>() || value.is<bool>()) {
  127. ret = BuiltinsToPyData(value);
  128. } else if (value.is<ValuePtr>()) {
  129. MS_LOG(DEBUG) << "ValuePtr";
  130. ValuePtr v = value.cast<ValuePtr>();
  131. ret = ValuePtrToPyData(v);
  132. } else if (value.is<tensor::TensorPtr>()) {
  133. MS_LOG(DEBUG) << "tensor";
  134. py::tuple v(1);
  135. v[0] = value.cast<tensor::TensorPtr>();
  136. ret = v[0];
  137. } else if (value.is<py::object>()) {
  138. MS_LOG(DEBUG) << "py obj";
  139. ret = value.cast<py::object>();
  140. } else if (value.is<std::vector<tensor::TensorPtr>>() || value.is<std::vector<Any>>()) {
  141. ret = VectorToPyData(value);
  142. } else if (value.is<std::list<Any>>()) {
  143. MS_LOG(DEBUG) << "list_any";
  144. auto value_list = value.cast<std::list<Any>>();
  145. py::list rets = py::list();
  146. for (auto &v : value_list) {
  147. rets.append(AnyToPyData(v));
  148. }
  149. ret = rets;
  150. } else if (value.is<std::vector<Any>>()) {
  151. auto value_list = value.cast<std::vector<Any>>();
  152. py::tuple rets(value_list.size());
  153. for (size_t i = 0; i < value_list.size(); i++) {
  154. rets[i] = AnyToPyData(value_list[i]);
  155. }
  156. ret = rets;
  157. } else if (value.is<TypePtr>()) {
  158. py::tuple v(1);
  159. v[0] = value.cast<TypePtr>();
  160. ret = v[0];
  161. } else {
  162. MS_LOG(EXCEPTION) << "value is not support type";
  163. }
  164. return ret;
  165. }
// Converts a BaseRef into the equivalent Python object.
//
// Builtin scalars go through BuiltinsToPyData; ValuePtr payloads through
// ValuePtrToPyData; a PyObjectRef is unwrapped directly; a VectorRef is
// converted element-wise into a py::tuple; Tensor/Type pointers are routed
// through a 1-element tuple so pybind11's registered caster runs.
// Logs EXCEPTION (fatal) for unsupported payload types.
py::object BaseRefToPyData(const BaseRef &value) {
  py::object ret;
  MS_LOG(DEBUG) << "BaseRefToPyData " << value.ToString();
  if (utils::isa<int>(value) || utils::isa<float>(value) || utils::isa<double>(value) || utils::isa<bool>(value)) {
    ret = BuiltinsToPyData(value);
  } else if (utils::isa<ValuePtr>(value)) {
    MS_LOG(DEBUG) << "ValuePtr";
    ValuePtr v = utils::cast<ValuePtr>(value);
    ret = ValuePtrToPyData(v);
  } else if (utils::isa<tensor::TensorPtr>(value)) {
    MS_LOG(DEBUG) << "tensor";
    py::tuple v(1);
    v[0] = utils::cast<tensor::TensorPtr>(value);
    ret = v[0];
  } else if (utils::isa<PyObjectRef>(value)) {
    // Already a Python object — just unwrap the holder.
    MS_LOG(DEBUG) << "py obj";
    PyObjectRef py_ref = utils::cast<PyObjectRef>(value);
    ret = py_ref.object_;
  } else if (utils::isa<VectorRef>(value)) {
    auto vec_ref = utils::cast<VectorRef>(value);
    ret = VectorRefToPyData(vec_ref);
  } else if (utils::isa<TypePtr>(value)) {
    py::tuple v(1);
    v[0] = utils::cast<TypePtr>(value);
    ret = v[0];
  } else {
    MS_LOG(EXCEPTION) << "value is not support type";
  }
  return ret;
}
  196. bool ValueToBool(const ValuePtr &v, bool *value) {
  197. MS_EXCEPTION_IF_NULL(v);
  198. if (v->isa<BoolImm>()) {
  199. *value = v->cast<BoolImmPtr>()->value();
  200. } else if (v->isa<Int32Imm>()) {
  201. *value = v->cast<Int32ImmPtr>()->value() == 0 ? false : true;
  202. } else if (v->isa<UInt32Imm>()) {
  203. *value = v->cast<UInt32ImmPtr>()->value() == 0 ? false : true;
  204. } else if (v->isa<FP32Imm>()) {
  205. *value = v->cast<FP32ImmPtr>()->value() == 0 ? false : true;
  206. } else if (v->isa<FP64Imm>()) {
  207. *value = v->cast<FP64ImmPtr>()->value() == 0 ? false : true;
  208. } else if (v->isa<tensor::Tensor>()) {
  209. auto tensor = v->cast<tensor::TensorPtr>();
  210. MS_EXCEPTION_IF_NULL(tensor);
  211. (void)tensor->data_sync();
  212. bool *tensor_data = static_cast<bool *>(tensor->data_c());
  213. // maybe need to support if tensor is a bool array
  214. auto vb = tensor_data[0];
  215. *value = vb;
  216. } else {
  217. MS_LOG(WARNING) << "value is not supported to cast to be bool";
  218. return false;
  219. }
  220. return true;
  221. }
  222. bool BaseRefToInt(const ValuePtr &v, int *value) {
  223. MS_EXCEPTION_IF_NULL(v);
  224. if (v->isa<tensor::Tensor>()) {
  225. auto tensor = v->cast<tensor::TensorPtr>();
  226. (void)tensor->data_sync();
  227. int *tensor_data = static_cast<int *>(tensor->data_c());
  228. auto vb = tensor_data[0];
  229. *value = vb;
  230. return true;
  231. }
  232. MS_LOG(ERROR) << "Index must be tensor type.";
  233. return false;
  234. }
  235. bool BaseRefToBool(const BaseRef &v, bool *value) {
  236. if (utils::isa<ValuePtr>(v)) {
  237. return ValueToBool(utils::cast<ValuePtr>(v), value);
  238. } else if (utils::isa<bool>(v)) {
  239. auto vb = utils::cast<bool>(v);
  240. if (vb == true) {
  241. *value = true;
  242. } else {
  243. *value = false;
  244. }
  245. } else if (utils::isa<int>(v)) {
  246. auto vb = utils::cast<int>(v);
  247. if (vb == 0) {
  248. *value = false;
  249. } else {
  250. *value = true;
  251. }
  252. } else if (utils::isa<unsigned int>(v)) {
  253. auto vb = utils::cast<unsigned int>(v);
  254. if (vb == 0) {
  255. *value = false;
  256. } else {
  257. *value = true;
  258. }
  259. } else if (utils::isa<float>(v)) {
  260. auto vb = utils::cast<float>(v);
  261. if (vb >= -FLT_EPSILON && vb <= FLT_EPSILON) {
  262. *value = false;
  263. } else {
  264. *value = true;
  265. }
  266. } else if (utils::isa<double>(v)) {
  267. auto vb = utils::cast<double>(v);
  268. if (vb >= -DBL_EPSILON && vb <= DBL_EPSILON) {
  269. *value = false;
  270. } else {
  271. *value = true;
  272. }
  273. } else {
  274. MS_LOG(DEBUG) << "value is not supported to cast to be bool";
  275. return false;
  276. }
  277. return true;
  278. }
  279. py::object BuiltinsToPyData(const Any &value) {
  280. if (value.is<int>()) {
  281. MS_LOG(DEBUG) << "int";
  282. py::int_ ret = value.cast<int>();
  283. return std::move(ret);
  284. } else if (value.is<float>()) {
  285. MS_LOG(DEBUG) << "float";
  286. py::float_ ret = value.cast<float>();
  287. return std::move(ret);
  288. } else if (value.is<double>()) {
  289. MS_LOG(DEBUG) << "double";
  290. py::float_ ret = value.cast<double>();
  291. return std::move(ret);
  292. } else {
  293. MS_LOG(DEBUG) << "bool";
  294. py::bool_ ret = value.cast<bool>();
  295. return std::move(ret);
  296. }
  297. }
  298. py::object BuiltinsToPyData(const BaseRef &value) {
  299. if (utils::isa<int>(value)) {
  300. MS_LOG(DEBUG) << "int";
  301. py::int_ ret = utils::cast<int>(value);
  302. return std::move(ret);
  303. } else if (utils::isa<float>(value)) {
  304. MS_LOG(DEBUG) << "float";
  305. py::float_ ret = utils::cast<float>(value);
  306. return std::move(ret);
  307. } else if (utils::isa<double>(value)) {
  308. MS_LOG(DEBUG) << "double";
  309. py::float_ ret = utils::cast<double>(value);
  310. return std::move(ret);
  311. } else {
  312. MS_LOG(DEBUG) << "bool";
  313. py::bool_ ret = utils::cast<bool>(value);
  314. return std::move(ret);
  315. }
  316. }
  317. py::object VectorToPyData(const Any &value) {
  318. py::object ret;
  319. if (value.is<std::vector<tensor::TensorPtr>>()) {
  320. MS_LOG(DEBUG) << "vector_tensor";
  321. std::vector<tensor::TensorPtr> outputs;
  322. outputs = value.cast<std::vector<tensor::TensorPtr>>();
  323. py::tuple tensor_tuple(outputs.size());
  324. for (std::size_t i = 0; i < outputs.size(); ++i) {
  325. tensor_tuple[i] = *outputs[i];
  326. }
  327. ret = tensor_tuple;
  328. } else {
  329. MS_LOG(DEBUG) << "vector_any";
  330. auto value_list = value.cast<std::vector<Any>>();
  331. py::tuple any_tuple = py::tuple(value_list.size());
  332. size_t i = 0;
  333. for (auto &v : value_list) {
  334. any_tuple[i] = AnyToPyData(v);
  335. i++;
  336. }
  337. ret = any_tuple;
  338. }
  339. return ret;
  340. }
  341. py::object VectorRefToPyData(const VectorRef &value_list) {
  342. py::object ret;
  343. MS_LOG(DEBUG) << "vector_ref";
  344. size_t value_size = value_list.size();
  345. auto ref_tuple = py::tuple(value_size);
  346. for (size_t i = 0; i < value_size; i++) {
  347. ref_tuple[i] = BaseRefToPyData(value_list[i]);
  348. }
  349. ret = ref_tuple;
  350. return ret;
  351. }
// Builds an abstract value from a Python (shape, dtype) description.
//
// - (list/tuple shape, dtype object): an AbstractTensor of that shape, or
//   an AbstractScalar when the shape is empty and the dtype is not a
//   TensorType.
// - (tuple of shapes, tuple of dtypes): an AbstractTuple built by recursing
//   pairwise over the two tuples.
// - (None, None): AbstractNone — the node has no output.
// Anything else is a fatal error (invalid Python evaluator result).
AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj) {
  if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) &&
      py::hasattr(type_obj, PYTHON_DTYPE_FLAG)) {
    auto ret_vec = shape_obj.cast<std::vector<int>>();
    auto ret_dtype = type_obj.cast<TypePtr>();
    MS_EXCEPTION_IF_NULL(ret_dtype);
    // if the size of shape list is empty, return an scalar abstract
    if (ret_vec.empty() && (!ret_dtype->isa<TensorType>())) {
      abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
      return abs_scalar;
    }
    AbstractBasePtr tensor = nullptr;
    if (ret_dtype->isa<TensorType>()) {
      // A TensorType wraps its element dtype; unwrap it for the abstract.
      auto tensor_type = type_obj.cast<TensorTypePtr>();
      MS_EXCEPTION_IF_NULL(tensor_type);
      tensor = std::make_shared<abstract::AbstractTensor>(tensor_type->element(), ret_vec);
    } else {
      tensor = std::make_shared<abstract::AbstractTensor>(ret_dtype, ret_vec);
    }
    return tensor;
  } else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
    // NOTE(review): assumes shape_tuple and typeid_tuple have equal length
    // — verify against the Python-side evaluator contract.
    py::tuple shape_tuple = shape_obj.cast<py::tuple>();
    py::tuple typeid_tuple = type_obj.cast<py::tuple>();
    AbstractBasePtrList ptr_list;
    for (size_t it = 0; it < shape_tuple.size(); ++it) {
      auto tensor_it = PyListDtype2AbstractTensor(shape_tuple[it], typeid_tuple[it]);
      ptr_list.push_back(tensor_it);
    }
    auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
    return tuple;
  } else if (shape_obj.is_none() && type_obj.is_none()) {
    // AbstractNone indicates there is no output for this CNode node.
    auto abstract_none = std::make_shared<abstract::AbstractNone>();
    return abstract_none;
  } else {
    MS_LOG(EXCEPTION) << "Python evaluator return invalid shape or type. " << (std::string)py::str(type_obj);
  }
}
// Short-circuits graph execution when the graph's output is already known.
//
// Returns true (and writes *ret_val) when `output` is:
//  - a ValueNode: the constant is converted with ValuePtrToPyData;
//  - a Parameter: the matching positional arg from `args` is returned, or,
//    for a hyper-parameter appended after the user inputs, the parameter's
//    default value cast to Python.
// Returns false when the graph genuinely needs to be executed.  Size or
// lookup inconsistencies between args and graph parameters are fatal.
bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
                                       const std::shared_ptr<py::object> &ret_val) {
  if (output->isa<ValueNode>()) {
    MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
    ValuePtr value = GetValueNode(output);
    *ret_val = ValuePtrToPyData(value);
    return true;
  }
  // Adapter will transform values in __init__() and construct() to parameters, this could cause
  // inputs (a.k.a args in current function) size less than parameters'.
  if (output->isa<Parameter>()) {
    MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute.";
    if (args.empty()) {
      MS_LOG(EXCEPTION) << "Inputs size is 0, let graph to be executed.";
    }
    // Find the right parameter as ret_val.
    auto func_graph = output->func_graph();
    MS_EXCEPTION_IF_NULL(func_graph);
    auto params = func_graph->parameters();
    if (params.empty()) {
      MS_EXCEPTION(UnknownError) << "Graph's parameters size is 0";
    }
    // User args plus hyper-params must account for every graph parameter.
    if ((args.size() + func_graph->hyper_param_count()) != params.size()) {
      MS_LOG(EXCEPTION) << "Input size " << args.size() << " add Parameter count " << func_graph->hyper_param_count()
                        << " not equal to graph input size " << params.size() << ", let graph to be executed.";
    }
    auto it = std::find(params.begin(), params.end(), output);
    if (it == params.end()) {
      MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters";
    }
    size_t index = it - params.cbegin();
    if (index >= args.size() + func_graph->hyper_param_count()) {
      MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size()
                                 << " add Parameter count " << func_graph->hyper_param_count() << ".";
    }
    if (index < args.size()) {
      // The parameter is one of the user-supplied inputs.
      *ret_val = args[index];
    } else {
      // Hyper-parameter: use its stored default value.
      auto param = dyn_cast<Parameter>(params[index]);
      MS_EXCEPTION_IF_NULL(param);
      if (!param->has_default()) {
        MS_LOG(EXCEPTION) << "Can not determine value of Parameter " << index << " (" << param->name() << ")";
      }
      auto tensor = param->default_param()->value();
      *ret_val = py::cast(tensor);
    }
    return true;
  }
  return false;
}
  440. // Isomorphism
// Shallow node-equivalence check used by the isomorphism walk.
//
// Two nodes match when: they were already recorded as equivalent in
// *equiv_node; both are FuncGraph value nodes whose graphs are isomorphic;
// both are ValueNodes with equal payloads (Primitives by name, Tensors by
// ValueEqual, otherwise operator==); or both are Parameters with the same
// name.  Anything else is an error and compares unequal.
static bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
                            NodeMapEquiv *const equiv_node) {
  if (equiv_node == nullptr) {
    MS_LOG(ERROR) << "Invalid equiv_node";
    return false;
  }
  if (equiv_node->count(node1) > 0 && (*equiv_node)[node1] == node2) {
    return true;
  }
  if (IsValueNode<FuncGraph>(node1) && IsValueNode<FuncGraph>(node2)) {
    // Nested graphs recurse through Isomorphic (memoized).
    return Isomorphic(GetValueNode<FuncGraphPtr>(node1), GetValueNode<FuncGraphPtr>(node2), equiv_func_graph,
                      equiv_node);
  }
  if (node1->isa<ValueNode>() && node2->isa<ValueNode>()) {
    auto a1 = GetValueNode(node1);
    auto a2 = GetValueNode(node2);
    if (a1->isa<Primitive>() && a2->isa<Primitive>()) {
      return a1->cast<PrimitivePtr>()->name() == a2->cast<PrimitivePtr>()->name();
    } else if (a1->isa<tensor::Tensor>() && a2->isa<tensor::Tensor>()) {
      return a1->cast<tensor::TensorPtr>()->ValueEqual(*(a2->cast<tensor::TensorPtr>()));
    } else {
      return *a1 == *a2;
    }
  }
  if (node1->isa<Parameter>() && node2->isa<Parameter>()) {
    auto para1 = node1->cast<ParameterPtr>();
    auto para2 = node2->cast<ParameterPtr>();
    if (para1->name() == para2->name()) {
      return true;
    }
    MS_LOG(DEBUG) << "two parameters are not equal.";
    return false;
  }
  MS_LOG(ERROR) << "type error";
  return false;
}
  477. static bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
  478. NodeMapEquiv *const equiv_node) {
  479. MS_EXCEPTION_IF_NULL(node1);
  480. MS_EXCEPTION_IF_NULL(node2);
  481. if (node1->isa<CNode>() && node2->isa<CNode>()) {
  482. auto &inputs1 = node1->cast<CNodePtr>()->inputs();
  483. auto &inputs2 = node2->cast<CNodePtr>()->inputs();
  484. for (std::size_t i = 0; i < inputs1.size(); ++i) {
  485. if (!SameNodeShallow(inputs1[i], inputs2[i], equiv_func_graph, equiv_node)) {
  486. return false;
  487. }
  488. }
  489. return true;
  490. }
  491. return SameNodeShallow(node1, node2, equiv_func_graph, equiv_node);
  492. }
// Checks that the subgraphs rooted at root1/root2 are isomorphic, using an
// explicit stack instead of recursion.  Children (incoming nodes) are
// pushed and compared before their parent is matched: a node stays on the
// stack until all of its inputs are in `done`, then SameNode compares it
// and the pair is recorded in *equiv_node.  Any mismatch (including a
// differing fan-in count) aborts with false.
static bool SameSubgraph(AnfNodePtr root1, AnfNodePtr root2, FuncGraphPairMapEquiv *equiv_func_graph,
                         NodeMapEquiv *const equiv_node) {
  std::unordered_set<AnfNodePtr> done;
  std::stack<std::pair<AnfNodePtr, AnfNodePtr>> todo;
  todo.push(std::make_pair(root1, root2));
  while (todo.size() > 0) {
    AnfNodePtr node1 = todo.top().first;
    if (done.count(node1) > 0) {
      // Already matched via another path; discard.
      todo.pop();
      continue;
    }
    AnfNodePtr node2 = todo.top().second;
    bool condition = false;
    std::vector<AnfNodePtr> s1 = SuccIncoming(node1);
    std::vector<AnfNodePtr> s2 = SuccIncoming(node2);
    if (s1.size() != s2.size()) {
      return false;
    }
    for (std::size_t i = 0; i < s1.size(); ++i) {
      if (done.count(s1[i]) == 0) {
        // Defer this node until its unprocessed inputs are handled.
        todo.push(std::make_pair(s1[i], s2[i]));
        condition = true;
      }
    }
    if (condition) {
      continue;
    }
    // All inputs processed: compare the pair itself.
    (void)done.insert(node1);
    auto res = SameNode(node1, node2, equiv_func_graph, equiv_node);
    if (res) {
      (*equiv_node)[node1] = node2;
    } else {
      return false;
    }
    todo.pop();
  }
  return true;
}
// Checks structural equivalence of two function graphs.
//
// Results are memoized per (fg1, fg2) pair in *equiv_func_graph; kPending
// is stored before walking so that mutually recursive graph pairs hit the
// cache (a pending pair compares as equivalent, since kPending != kNotEquiv)
// instead of looping forever.  Parameter lists must have equal size and are
// pre-matched pairwise in *equiv_node before comparing the return subgraphs.
bool Isomorphic(FuncGraphPtr fg1, FuncGraphPtr fg2, FuncGraphPairMapEquiv *equiv_func_graph,
                NodeMapEquiv *const equiv_node) {
  auto fg1_fg2 = std::make_pair(fg1, fg2);
  if (equiv_func_graph == nullptr) {
    MS_LOG(ERROR) << "equiv_func_graph not init";
    return false;
  }
  if (equiv_func_graph->find(fg1_fg2) != equiv_func_graph->end()) {
    // Cached (or in-progress) result.
    return (*equiv_func_graph)[fg1_fg2] != kNotEquiv;
  }
  if (fg1 == nullptr || fg2 == nullptr) {
    MS_LOG(ERROR) << "Invalid function graph";
    return false;
  }
  if (fg1->parameters().size() != fg2->parameters().size()) {
    MS_LOG(DEBUG) << "parameters size not match";
    return false;
  }
  if (equiv_node != nullptr) {
    for (std::size_t i = 0; i < fg1->parameters().size(); ++i) {
      (*equiv_node)[fg1->parameters()[i]] = fg2->parameters()[i];
    }
    (*equiv_func_graph)[fg1_fg2] = kPending;
    auto result = SameSubgraph(fg1->get_return(), fg2->get_return(), equiv_func_graph, equiv_node);
    (*equiv_func_graph)[fg1_fg2] = EquivState(result);
    return result;
  }
  MS_LOG(ERROR) << "equiv_node not init";
  return false;
}
// Converts a scalar immediate into a 0-dim Tensor: float -> kFloat32,
// integer -> kInt32, bool -> kBool (stored as 0/1).  A null input or an
// unrecognized scalar type is a fatal error.
tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) {
  if (scalar == nullptr) {
    MS_EXCEPTION(ArgumentError) << "Nullptr Error!";
  }
  tensor::TensorPtr tensor = nullptr;
  if (scalar->isa<FloatImm>()) {
    tensor = std::make_shared<tensor::Tensor>(static_cast<double>(GetValue<float>(scalar)), kFloat32);
  } else if (scalar->isa<IntergerImm>()) {
    // "IntergerImm" (sic) matches the class name declared upstream; do not
    // "correct" the spelling here.
    tensor = std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int>(scalar)), kInt32);
  } else if (scalar->isa<BoolImm>()) {
    const int64_t bool_value = GetValue<bool>(scalar) ? 1 : 0;
    tensor = std::make_shared<tensor::Tensor>(bool_value, kBool);
  } else {
    auto type = scalar->type();
    auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
    MS_LOG(EXCEPTION) << "Invalid scalar type: " << type_str;
  }
  MS_EXCEPTION_IF_NULL(tensor);
  return tensor;
}
  581. } // namespace mindspore