You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

convert_utils.cc 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "utils/convert_utils.h"

#include <algorithm>
#include <cfloat>
#include <list>
#include <memory>
#include <stack>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "pybind11/pybind11.h"

#include "pipeline/static_analysis/abstract_value.h"
#include "pipeline/parse/parse.h"
#include "pipeline/parse/parse_base.h"
#include "ir/value.h"
#include "ir/tensor.h"
#include "ir/param_value_py.h"
#include "utils/base_ref_extends.h"
  32. namespace mindspore {
  33. py::object BuiltinsToPyData(const Any &value);
  34. py::object BuiltinsToPyData(const BaseRef &value);
  35. py::object VectorToPyData(const Any &value);
  36. py::object VectorRefToPyData(const VectorRef &value);
// Convert a MindSpore IR value (ValuePtr) into the equivalent Python object.
// Scalar immediates map to py::int_/py::float_/py::bool_/py::str. Tensor,
// MetaTensor, RefKey and Type pointers are converted by assigning into a
// one-element py::tuple, which invokes the pybind11 type casters registered
// for those pointer types. ValueTuple/ValueList recurse element-wise, and
// ValueSlice is rebuilt via the Python-side slice helper. Unsupported kinds
// log at INFO level and return a default-constructed (null) py::object.
py::object ValuePtrToPyData(const ValuePtr &value) {
  if (value == nullptr) {
    MS_LOG(EXCEPTION) << "value is null";
  }
  py::object ret;
  if (value->isa<Int32Imm>()) {
    MS_LOG(DEBUG) << "int";
    py::int_ v = value->cast<Int32ImmPtr>()->value();
    ret = v;
  } else if (value->isa<UInt64Imm>()) {
    MS_LOG(DEBUG) << "uint64";
    py::int_ v = value->cast<UInt64ImmPtr>()->value();
    ret = v;
  } else if (value->isa<BoolImm>()) {
    MS_LOG(DEBUG) << "bool";
    py::bool_ v = value->cast<BoolImmPtr>()->value();
    ret = v;
  } else if (value->isa<FP64Imm>()) {
    MS_LOG(DEBUG) << "double";
    py::float_ v = value->cast<FP64ImmPtr>()->value();
    ret = v;
  } else if (value->isa<FP32Imm>()) {
    MS_LOG(DEBUG) << "float";
    py::float_ v = value->cast<FP32ImmPtr>()->value();
    ret = v;
  } else if (value->isa<StringImm>()) {
    MS_LOG(DEBUG) << "String";
    py::str v = value->cast<StringImmPtr>()->value();
    ret = v;
  } else if (value->isa<tensor::Tensor>()) {
    MS_LOG(DEBUG) << "tensor";
    // Assign through a tuple slot so the registered pybind11 caster runs.
    py::tuple v(1);
    v[0] = value->cast<tensor::TensorPtr>();
    ret = v[0];
  } else if (value->isa<tensor::MetaTensor>()) {
    MS_LOG(DEBUG) << "MetaTensor";
    py::tuple v(1);
    v[0] = value->cast<tensor::MetaTensorPtr>();
    ret = v[0];
  } else if (value->isa<RefKey>()) {
    MS_LOG(DEBUG) << "RefKey";
    py::tuple v(1);
    v[0] = value->cast<RefKeyPtr>();
    ret = v[0];
  } else if (value->isa<ValueTuple>()) {
    MS_LOG(DEBUG) << "tuple";
    auto value_tuple = value->cast<ValueTuplePtr>()->value();
    py::tuple rets(value_tuple.size());
    size_t i = 0;
    for (auto &v : value_tuple) {
      rets[i] = ValuePtrToPyData(v);  // recurse for each element
      i++;
    }
    ret = rets;
  } else if (value->isa<ValueList>()) {
    MS_LOG(DEBUG) << "list";
    auto value_list = value->cast<ValueListPtr>()->value();
    py::list rets(value_list.size());
    size_t i = 0;
    for (auto &v : value_list) {
      rets[i] = ValuePtrToPyData(v);
      i++;
    }
    ret = rets;
  } else if (value->isa<Ellipsis>()) {
    ret = py::ellipsis();
  } else if (value->isa<ValueSlice>()) {
    // Rebuild a Python slice object from the start/stop/step sub-values via
    // the parser module's slice helper.
    auto slice = value->cast<ValueSlicePtr>();
    auto start = ValuePtrToPyData(slice->start());
    auto end = ValuePtrToPyData(slice->stop());
    auto step = ValuePtrToPyData(slice->step());
    ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end,
                                          step);
  } else if (value->isa<Type>()) {
    py::tuple v(1);
    v[0] = value->cast<TypePtr>();
    ret = v[0];
  } else if (value->isa<AnyValue>()) {
    ret = py::none();
  } else if (value->isa<None>()) {
    ret = py::none();
  } else {
    // Not fatal: caller receives a null py::object.
    MS_LOG(INFO) << "Unsupported convert value: " << value->ToString() << " to a PyData.";
  }
  return ret;
}
  123. py::object AnyToPyData(const Any &value) {
  124. py::object ret;
  125. MS_LOG(DEBUG) << "AnyToPyData " << value.GetString();
  126. if (value.is<int>() || value.is<float>() || value.is<double>() || value.is<bool>()) {
  127. ret = BuiltinsToPyData(value);
  128. } else if (value.is<ValuePtr>()) {
  129. MS_LOG(DEBUG) << "ValuePtr";
  130. ValuePtr v = value.cast<ValuePtr>();
  131. ret = ValuePtrToPyData(v);
  132. } else if (value.is<tensor::TensorPtr>()) {
  133. MS_LOG(DEBUG) << "tensor";
  134. py::tuple v(1);
  135. v[0] = value.cast<tensor::TensorPtr>();
  136. ret = v[0];
  137. } else if (value.is<py::object>()) {
  138. MS_LOG(DEBUG) << "py obj";
  139. ret = value.cast<py::object>();
  140. } else if (value.is<std::vector<tensor::TensorPtr>>() || value.is<std::vector<Any>>()) {
  141. ret = VectorToPyData(value);
  142. } else if (value.is<std::list<Any>>()) {
  143. MS_LOG(DEBUG) << "list_any";
  144. auto value_list = value.cast<std::list<Any>>();
  145. py::list rets = py::list();
  146. for (auto &v : value_list) {
  147. rets.append(AnyToPyData(v));
  148. }
  149. ret = rets;
  150. } else if (value.is<std::vector<Any>>()) {
  151. auto value_list = value.cast<std::vector<Any>>();
  152. py::tuple rets(value_list.size());
  153. for (size_t i = 0; i < value_list.size(); i++) {
  154. rets[i] = AnyToPyData(value_list[i]);
  155. }
  156. ret = rets;
  157. } else if (value.is<TypePtr>()) {
  158. py::tuple v(1);
  159. v[0] = value.cast<TypePtr>();
  160. ret = v[0];
  161. } else {
  162. MS_LOG(EXCEPTION) << "value is not support type";
  163. }
  164. return ret;
  165. }
// Convert a BaseRef into the equivalent Python object.
// Builtins dispatch to BuiltinsToPyData; ValuePtr to ValuePtrToPyData;
// a PyObjectRef simply unwraps its stored py::object; a VectorRef becomes a
// py::tuple via VectorRefToPyData. Tensor/Type pointers are converted by
// assigning into a one-element py::tuple so the registered pybind11 type
// caster runs. Unsupported payloads raise an exception.
py::object BaseRefToPyData(const BaseRef &value) {
  py::object ret;
  MS_LOG(DEBUG) << "BaseRefToPyData " << value.ToString();
  if (utils::isa<int>(value) || utils::isa<float>(value) || utils::isa<double>(value) || utils::isa<bool>(value)) {
    ret = BuiltinsToPyData(value);
  } else if (utils::isa<ValuePtr>(value)) {
    MS_LOG(DEBUG) << "ValuePtr";
    ValuePtr v = utils::cast<ValuePtr>(value);
    ret = ValuePtrToPyData(v);
  } else if (utils::isa<tensor::TensorPtr>(value)) {
    MS_LOG(DEBUG) << "tensor";
    py::tuple v(1);
    v[0] = utils::cast<tensor::TensorPtr>(value);
    ret = v[0];
  } else if (utils::isa<PyObjectRef>(value)) {
    MS_LOG(DEBUG) << "py obj";
    PyObjectRef py_ref = utils::cast<PyObjectRef>(value);
    ret = py_ref.object_;
  } else if (utils::isa<VectorRef>(value)) {
    auto vec_ref = utils::cast<VectorRef>(value);
    ret = VectorRefToPyData(vec_ref);
  } else if (utils::isa<TypePtr>(value)) {
    py::tuple v(1);
    v[0] = utils::cast<TypePtr>(value);
    ret = v[0];
  } else {
    MS_LOG(EXCEPTION) << "value is not support type";
  }
  return ret;
}
  196. bool ValueToBool(const ValuePtr &v, bool *value) {
  197. MS_EXCEPTION_IF_NULL(v);
  198. if (v->isa<BoolImm>()) {
  199. *value = v->cast<BoolImmPtr>()->value();
  200. } else if (v->isa<Int32Imm>()) {
  201. *value = v->cast<Int32ImmPtr>()->value() == 0 ? false : true;
  202. } else if (v->isa<UInt32Imm>()) {
  203. *value = v->cast<UInt32ImmPtr>()->value() == 0 ? false : true;
  204. } else if (v->isa<FP32Imm>()) {
  205. *value = v->cast<FP32ImmPtr>()->value() == 0 ? false : true;
  206. } else if (v->isa<FP64Imm>()) {
  207. *value = v->cast<FP64ImmPtr>()->value() == 0 ? false : true;
  208. } else if (v->isa<tensor::Tensor>()) {
  209. auto tensor = v->cast<tensor::TensorPtr>();
  210. MS_EXCEPTION_IF_NULL(tensor);
  211. (void)tensor->data_sync();
  212. bool *tensor_data = static_cast<bool *>(tensor->data_c());
  213. // maybe need to support if tensor is a bool array
  214. auto vb = tensor_data[0];
  215. *value = vb;
  216. } else {
  217. MS_LOG(WARNING) << "value is not supported to cast to be bool";
  218. return false;
  219. }
  220. return true;
  221. }
  222. bool BaseRefToBool(const BaseRef &v, bool *value) {
  223. if (utils::isa<ValuePtr>(v)) {
  224. return ValueToBool(utils::cast<ValuePtr>(v), value);
  225. } else if (utils::isa<bool>(v)) {
  226. auto vb = utils::cast<bool>(v);
  227. if (vb == true) {
  228. *value = true;
  229. } else {
  230. *value = false;
  231. }
  232. } else if (utils::isa<int>(v)) {
  233. auto vb = utils::cast<int>(v);
  234. if (vb == 0) {
  235. *value = false;
  236. } else {
  237. *value = true;
  238. }
  239. } else if (utils::isa<unsigned int>(v)) {
  240. auto vb = utils::cast<unsigned int>(v);
  241. if (vb == 0) {
  242. *value = false;
  243. } else {
  244. *value = true;
  245. }
  246. } else if (utils::isa<float>(v)) {
  247. auto vb = utils::cast<float>(v);
  248. if (vb >= -FLT_EPSILON && vb <= FLT_EPSILON) {
  249. *value = false;
  250. } else {
  251. *value = true;
  252. }
  253. } else if (utils::isa<double>(v)) {
  254. auto vb = utils::cast<double>(v);
  255. if (vb >= -DBL_EPSILON && vb <= DBL_EPSILON) {
  256. *value = false;
  257. } else {
  258. *value = true;
  259. }
  260. } else {
  261. MS_LOG(DEBUG) << "value is not supported to cast to be bool";
  262. return false;
  263. }
  264. return true;
  265. }
  266. py::object BuiltinsToPyData(const Any &value) {
  267. if (value.is<int>()) {
  268. MS_LOG(DEBUG) << "int";
  269. py::int_ ret = value.cast<int>();
  270. return std::move(ret);
  271. } else if (value.is<float>()) {
  272. MS_LOG(DEBUG) << "float";
  273. py::float_ ret = value.cast<float>();
  274. return std::move(ret);
  275. } else if (value.is<double>()) {
  276. MS_LOG(DEBUG) << "double";
  277. py::float_ ret = value.cast<double>();
  278. return std::move(ret);
  279. } else {
  280. MS_LOG(DEBUG) << "bool";
  281. py::bool_ ret = value.cast<bool>();
  282. return std::move(ret);
  283. }
  284. }
  285. py::object BuiltinsToPyData(const BaseRef &value) {
  286. if (utils::isa<int>(value)) {
  287. MS_LOG(DEBUG) << "int";
  288. py::int_ ret = utils::cast<int>(value);
  289. return std::move(ret);
  290. } else if (utils::isa<float>(value)) {
  291. MS_LOG(DEBUG) << "float";
  292. py::float_ ret = utils::cast<float>(value);
  293. return std::move(ret);
  294. } else if (utils::isa<double>(value)) {
  295. MS_LOG(DEBUG) << "double";
  296. py::float_ ret = utils::cast<double>(value);
  297. return std::move(ret);
  298. } else {
  299. MS_LOG(DEBUG) << "bool";
  300. py::bool_ ret = utils::cast<bool>(value);
  301. return std::move(ret);
  302. }
  303. }
  304. py::object VectorToPyData(const Any &value) {
  305. py::object ret;
  306. if (value.is<std::vector<tensor::TensorPtr>>()) {
  307. MS_LOG(DEBUG) << "vector_tensor";
  308. std::vector<tensor::TensorPtr> outputs;
  309. outputs = value.cast<std::vector<tensor::TensorPtr>>();
  310. py::tuple tensor_tuple(outputs.size());
  311. for (std::size_t i = 0; i < outputs.size(); ++i) {
  312. tensor_tuple[i] = *outputs[i];
  313. }
  314. ret = tensor_tuple;
  315. } else {
  316. MS_LOG(DEBUG) << "vector_any";
  317. auto value_list = value.cast<std::vector<Any>>();
  318. py::tuple any_tuple = py::tuple(value_list.size());
  319. size_t i = 0;
  320. for (auto &v : value_list) {
  321. any_tuple[i] = AnyToPyData(v);
  322. i++;
  323. }
  324. ret = any_tuple;
  325. }
  326. return ret;
  327. }
  328. py::object VectorRefToPyData(const VectorRef &value_list) {
  329. py::object ret;
  330. MS_LOG(DEBUG) << "vector_ref";
  331. size_t value_size = value_list.size();
  332. auto ref_tuple = py::tuple(value_size);
  333. for (size_t i = 0; i < value_size; i++) {
  334. ref_tuple[i] = BaseRefToPyData(value_list[i]);
  335. }
  336. ret = ref_tuple;
  337. return ret;
  338. }
  339. AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj) {
  340. if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) &&
  341. py::hasattr(type_obj, PYTHON_DTYPE_FLAG)) {
  342. auto ret_vec = shape_obj.cast<std::vector<int>>();
  343. auto ret_dtype = type_obj.cast<TypePtr>();
  344. MS_EXCEPTION_IF_NULL(ret_dtype);
  345. // if the size of shape list is empty, return an scalar abstract
  346. if (ret_vec.empty() && (!ret_dtype->isa<TensorType>())) {
  347. abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
  348. return abs_scalar;
  349. }
  350. AbstractBasePtr tensor = nullptr;
  351. if (ret_dtype->isa<TensorType>()) {
  352. auto tensor_type = type_obj.cast<TensorTypePtr>();
  353. MS_EXCEPTION_IF_NULL(tensor_type);
  354. tensor = std::make_shared<abstract::AbstractTensor>(tensor_type->element(), ret_vec);
  355. } else {
  356. tensor = std::make_shared<abstract::AbstractTensor>(ret_dtype, ret_vec);
  357. }
  358. return tensor;
  359. } else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
  360. py::tuple shape_tuple = shape_obj.cast<py::tuple>();
  361. py::tuple typeid_tuple = type_obj.cast<py::tuple>();
  362. AbstractBasePtrList ptr_list;
  363. for (size_t it = 0; it < shape_tuple.size(); ++it) {
  364. auto tensor_it = PyListDtype2AbstractTensor(shape_tuple[it], typeid_tuple[it]);
  365. ptr_list.push_back(tensor_it);
  366. }
  367. auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
  368. return tuple;
  369. } else if (shape_obj.is_none() && type_obj.is_none()) {
  370. // AbstractNone indicates there is no output for this CNode node.
  371. auto abstract_none = std::make_shared<abstract::AbstractNone>();
  372. return abstract_none;
  373. } else {
  374. MS_LOG(EXCEPTION) << "Python evaluator return invalid shape or type. " << (std::string)py::str(type_obj);
  375. }
  376. }
  377. bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
  378. const std::shared_ptr<py::object> &ret_val) {
  379. if (output->isa<ValueNode>()) {
  380. MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
  381. ValuePtr value = GetValueNode(output);
  382. *ret_val = ValuePtrToPyData(value);
  383. return true;
  384. }
  385. // Adapter will transform values in __init__() and construct() to parameters, this could cause
  386. // inputs (a.k.a args in current function) size less than parameters'.
  387. if (output->isa<Parameter>()) {
  388. MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute.";
  389. if (args.empty()) {
  390. MS_LOG(EXCEPTION) << "Inputs size is 0, let graph to be executed.";
  391. }
  392. // Find the right parameter as ret_val.
  393. auto func_graph = output->func_graph();
  394. MS_EXCEPTION_IF_NULL(func_graph);
  395. auto params = func_graph->parameters();
  396. if (params.empty()) {
  397. MS_EXCEPTION(UnknownError) << "Graph's parameters size is 0";
  398. }
  399. if ((args.size() + func_graph->hyper_param_count()) != params.size()) {
  400. MS_LOG(EXCEPTION) << "Input size " << args.size() << " add Parameter count " << func_graph->hyper_param_count()
  401. << " not equal to graph input size " << params.size() << ", let graph to be executed.";
  402. }
  403. auto it = std::find(params.begin(), params.end(), output);
  404. if (it == params.end()) {
  405. MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters";
  406. }
  407. size_t index = it - params.cbegin();
  408. if (index >= args.size() + func_graph->hyper_param_count()) {
  409. MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size()
  410. << " add Parameter count " << func_graph->hyper_param_count() << ".";
  411. }
  412. if (index < args.size()) {
  413. *ret_val = args[index];
  414. } else {
  415. auto param = dyn_cast<Parameter>(params[index]);
  416. MS_EXCEPTION_IF_NULL(param);
  417. if (!param->has_default()) {
  418. MS_LOG(EXCEPTION) << "Can not determine value of Parameter " << index << " (" << param->name() << ")";
  419. }
  420. auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param->default_param());
  421. *ret_val = param_value->value().attr("data");
  422. }
  423. return true;
  424. }
  425. return false;
  426. }
// Isomorphism
// Shallow equivalence of two nodes (does not recurse into CNode inputs —
// SameNode handles that). FuncGraph value nodes recurse into Isomorphic;
// other value nodes compare by primitive name, tensor value, or generic
// value equality; parameters compare by name. Hits in *equiv_node are
// accepted immediately as already-proven equivalences.
static bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
                            NodeMapEquiv *const equiv_node) {
  if (equiv_node == nullptr) {
    MS_LOG(ERROR) << "Invalid equiv_node";
    return false;
  }
  // Already proven equivalent on a previous visit.
  if (equiv_node->count(node1) > 0 && (*equiv_node)[node1] == node2) {
    return true;
  }
  if (IsValueNode<FuncGraph>(node1) && IsValueNode<FuncGraph>(node2)) {
    return Isomorphic(GetValueNode<FuncGraphPtr>(node1), GetValueNode<FuncGraphPtr>(node2), equiv_func_graph,
                      equiv_node);
  }
  if (node1->isa<ValueNode>() && node2->isa<ValueNode>()) {
    auto a1 = GetValueNode(node1);
    auto a2 = GetValueNode(node2);
    if (a1->isa<Primitive>() && a2->isa<Primitive>()) {
      // Primitives compare by name only.
      return a1->cast<PrimitivePtr>()->name() == a2->cast<PrimitivePtr>()->name();
    } else if (a1->isa<tensor::Tensor>() && a2->isa<tensor::Tensor>()) {
      // Tensors compare by contents, not identity.
      return a1->cast<tensor::TensorPtr>()->ValueEqual(*(a2->cast<tensor::TensorPtr>()));
    } else {
      return *a1 == *a2;
    }
  }
  if (node1->isa<Parameter>() && node2->isa<Parameter>()) {
    auto para1 = node1->cast<ParameterPtr>();
    auto para2 = node2->cast<ParameterPtr>();
    if (para1->name() == para2->name()) {
      return true;
    }
    MS_LOG(DEBUG) << "two parameters are not equal.";
    return false;
  }
  // Mixed node kinds (e.g. ValueNode vs Parameter) are never equivalent.
  MS_LOG(ERROR) << "type error";
  return false;
}
  464. static bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
  465. NodeMapEquiv *const equiv_node) {
  466. MS_EXCEPTION_IF_NULL(node1);
  467. MS_EXCEPTION_IF_NULL(node2);
  468. if (node1->isa<CNode>() && node2->isa<CNode>()) {
  469. auto &inputs1 = node1->cast<CNodePtr>()->inputs();
  470. auto &inputs2 = node2->cast<CNodePtr>()->inputs();
  471. for (std::size_t i = 0; i < inputs1.size(); ++i) {
  472. if (!SameNodeShallow(inputs1[i], inputs2[i], equiv_func_graph, equiv_node)) {
  473. return false;
  474. }
  475. }
  476. return true;
  477. }
  478. return SameNodeShallow(node1, node2, equiv_func_graph, equiv_node);
  479. }
// Check that the subgraphs rooted at root1/root2 are structurally equivalent,
// using an explicit-stack post-order DFS over paired nodes. A node pair is
// compared (via SameNode) only after all of its incoming successors have been
// processed; proven-equivalent pairs are recorded in *equiv_node. Returns
// false at the first structural mismatch.
static bool SameSubgraph(AnfNodePtr root1, AnfNodePtr root2, FuncGraphPairMapEquiv *equiv_func_graph,
                         NodeMapEquiv *const equiv_node) {
  std::unordered_set<AnfNodePtr> done;
  std::stack<std::pair<AnfNodePtr, AnfNodePtr>> todo;
  todo.push(std::make_pair(root1, root2));
  while (todo.size() > 0) {
    AnfNodePtr node1 = todo.top().first;
    if (done.count(node1) > 0) {
      // Already compared on an earlier path; skip.
      todo.pop();
      continue;
    }
    AnfNodePtr node2 = todo.top().second;
    // condition becomes true when children still need processing first.
    bool condition = false;
    std::vector<AnfNodePtr> s1 = SuccIncoming(node1);
    std::vector<AnfNodePtr> s2 = SuccIncoming(node2);
    if (s1.size() != s2.size()) {
      return false;
    }
    for (std::size_t i = 0; i < s1.size(); ++i) {
      if (done.count(s1[i]) == 0) {
        // Push unfinished children; current pair stays on the stack and is
        // revisited after they are done (post-order).
        todo.push(std::make_pair(s1[i], s2[i]));
        condition = true;
      }
    }
    if (condition) {
      continue;
    }
    (void)done.insert(node1);
    auto res = SameNode(node1, node2, equiv_func_graph, equiv_node);
    if (res) {
      (*equiv_node)[node1] = node2;
    } else {
      return false;
    }
    todo.pop();
  }
  return true;
}
// Check whether two function graphs are isomorphic.
// Results are memoized in *equiv_func_graph; a kPending entry is installed
// before recursing so mutually-recursive graph pairs terminate (a pending
// pair encountered again counts as equivalent). Parameters are paired up
// positionally in *equiv_node before the return subgraphs are compared.
bool Isomorphic(FuncGraphPtr fg1, FuncGraphPtr fg2, FuncGraphPairMapEquiv *equiv_func_graph,
                NodeMapEquiv *const equiv_node) {
  auto fg1_fg2 = std::make_pair(fg1, fg2);
  if (equiv_func_graph == nullptr) {
    MS_LOG(ERROR) << "equiv_func_graph not init";
    return false;
  }
  if (equiv_func_graph->find(fg1_fg2) != equiv_func_graph->end()) {
    // Memoized (or currently pending) pair: anything except kNotEquiv passes.
    return (*equiv_func_graph)[fg1_fg2] != kNotEquiv;
  }
  if (fg1 == nullptr || fg2 == nullptr) {
    MS_LOG(ERROR) << "Invalid function graph";
    return false;
  }
  if (fg1->parameters().size() != fg2->parameters().size()) {
    MS_LOG(DEBUG) << "parameters size not match";
    return false;
  }
  if (equiv_node != nullptr) {
    // Pair parameters positionally as the base equivalences.
    for (std::size_t i = 0; i < fg1->parameters().size(); ++i) {
      (*equiv_node)[fg1->parameters()[i]] = fg2->parameters()[i];
    }
    // Mark pending before recursing to break cycles between graphs.
    (*equiv_func_graph)[fg1_fg2] = kPending;
    auto result = SameSubgraph(fg1->get_return(), fg2->get_return(), equiv_func_graph, equiv_node);
    (*equiv_func_graph)[fg1_fg2] = EquivState(result);
    return result;
  }
  MS_LOG(ERROR) << "equiv_node not init";
  return false;
}
  548. tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) {
  549. if (scalar == nullptr) {
  550. MS_EXCEPTION(ArgumentError) << "Nullptr Error!";
  551. }
  552. tensor::TensorPtr tensor = nullptr;
  553. if (scalar->isa<FloatImm>()) {
  554. tensor = std::make_shared<tensor::Tensor>(static_cast<double>(GetValue<float>(scalar)), kFloat32);
  555. } else if (scalar->isa<IntergerImm>()) {
  556. tensor = std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int>(scalar)), kInt32);
  557. } else if (scalar->isa<BoolImm>()) {
  558. const int64_t bool_value = GetValue<bool>(scalar) ? 1 : 0;
  559. tensor = std::make_shared<tensor::Tensor>(bool_value, kBool);
  560. } else {
  561. auto type = scalar->type();
  562. auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
  563. MS_LOG(EXCEPTION) << "Invalid scalar type: " << type_str;
  564. }
  565. MS_EXCEPTION_IF_NULL(tensor);
  566. return tensor;
  567. }
  568. } // namespace mindspore