You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

convert_utils.cc 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "utils/convert_utils.h"

#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include <list>
#include <stack>
#include <unordered_set>
#include <utility>
#include <cfloat>

#include "pybind11/pybind11.h"

#include "abstract/abstract_value.h"
#include "pipeline/jit/parse/parse.h"
#include "pipeline/jit/parse/parse_base.h"
#include "ir/value.h"
#include "ir/tensor.h"
#include "ir/param_info.h"
#include "utils/base_ref_extends.h"
#include "utils/ms_context.h"
  33. namespace mindspore {
  34. py::object BuiltinsToPyData(const Any &value);
  35. py::object BuiltinsToPyData(const BaseRef &value);
  36. py::object VectorToPyData(const Any &value);
  37. py::object VectorRefToPyData(const VectorRef &value);
// Convert a MindSpore Value into the equivalent Python object.
//
// Scalar immediates (signed/unsigned ints of all widths, bool, float/double,
// string) map to the matching pybind11 scalar type. Tensor / MetaTensor /
// RefKey / Type values are assigned into a one-element py::tuple first so
// that pybind11's registered type casters perform the C++ -> Python
// conversion, and the converted element is returned. ValueTuple / ValueList
// recurse element-wise into py::tuple / py::list. ValueSlice is rebuilt as a
// Python slice object through the parse-module helper. AnyValue and None
// both become Python None.
//
// NOTE: the isa<> chain is order sensitive -- e.g. Tensor is tested before
// MetaTensor (Tensor derives from MetaTensor); do not reorder branches.
// An unsupported value only logs at INFO level and returns a
// default-constructed py::object (a null handle).
py::object ValuePtrToPyData(const ValuePtr &value) {
  if (value == nullptr) {
    MS_LOG(EXCEPTION) << "value is null";
  }
  py::object ret;
  if (value->isa<Int8Imm>()) {
    MS_LOG(DEBUG) << "int8";
    py::int_ v = value->cast<Int8ImmPtr>()->value();
    ret = v;
  } else if (value->isa<Int16Imm>()) {
    MS_LOG(DEBUG) << "int16";
    py::int_ v = value->cast<Int16ImmPtr>()->value();
    ret = v;
  } else if (value->isa<Int32Imm>()) {
    MS_LOG(DEBUG) << "int32";
    py::int_ v = value->cast<Int32ImmPtr>()->value();
    ret = v;
  } else if (value->isa<Int64Imm>()) {
    MS_LOG(DEBUG) << "int64";
    py::int_ v = value->cast<Int64ImmPtr>()->value();
    ret = v;
  } else if (value->isa<UInt8Imm>()) {
    MS_LOG(DEBUG) << "uint8";
    py::int_ v = value->cast<UInt8ImmPtr>()->value();
    ret = v;
  } else if (value->isa<UInt16Imm>()) {
    MS_LOG(DEBUG) << "uint16";
    py::int_ v = value->cast<UInt16ImmPtr>()->value();
    ret = v;
  } else if (value->isa<UInt32Imm>()) {
    MS_LOG(DEBUG) << "uint32";
    py::int_ v = value->cast<UInt32ImmPtr>()->value();
    ret = v;
  } else if (value->isa<UInt64Imm>()) {
    MS_LOG(DEBUG) << "uint64";
    py::int_ v = value->cast<UInt64ImmPtr>()->value();
    ret = v;
  } else if (value->isa<BoolImm>()) {
    MS_LOG(DEBUG) << "bool";
    py::bool_ v = value->cast<BoolImmPtr>()->value();
    ret = v;
  } else if (value->isa<FP64Imm>()) {
    MS_LOG(DEBUG) << "double";
    py::float_ v = value->cast<FP64ImmPtr>()->value();
    ret = v;
  } else if (value->isa<FP32Imm>()) {
    MS_LOG(DEBUG) << "float";
    py::float_ v = value->cast<FP32ImmPtr>()->value();
    ret = v;
  } else if (value->isa<StringImm>()) {
    MS_LOG(DEBUG) << "String";
    py::str v = value->cast<StringImmPtr>()->value();
    ret = v;
  } else if (value->isa<tensor::Tensor>()) {
    MS_LOG(DEBUG) << "tensor";
    // Route through a 1-tuple so the pybind11 caster converts TensorPtr.
    py::tuple v(1);
    v[0] = value->cast<tensor::TensorPtr>();
    ret = v[0];
  } else if (value->isa<tensor::MetaTensor>()) {
    MS_LOG(DEBUG) << "MetaTensor";
    py::tuple v(1);
    v[0] = value->cast<tensor::MetaTensorPtr>();
    ret = v[0];
  } else if (value->isa<RefKey>()) {
    MS_LOG(DEBUG) << "RefKey";
    py::tuple v(1);
    v[0] = value->cast<RefKeyPtr>();
    ret = v[0];
  } else if (value->isa<ValueTuple>()) {
    MS_LOG(DEBUG) << "tuple";
    auto value_tuple = value->cast<ValueTuplePtr>()->value();
    py::tuple rets(value_tuple.size());
    size_t i = 0;
    for (auto &v : value_tuple) {
      rets[i] = ValuePtrToPyData(v);
      i++;
    }
    ret = rets;
  } else if (value->isa<ValueList>()) {
    MS_LOG(DEBUG) << "list";
    auto value_list = value->cast<ValueListPtr>()->value();
    py::list rets(value_list.size());
    size_t i = 0;
    for (auto &v : value_list) {
      rets[i] = ValuePtrToPyData(v);
      i++;
    }
    ret = rets;
  } else if (value->isa<Ellipsis>()) {
    ret = py::ellipsis();
  } else if (value->isa<ValueSlice>()) {
    // Rebuild a Python slice(start, stop, step) via the parse module helper.
    auto slice = value->cast<ValueSlicePtr>();
    auto start = ValuePtrToPyData(slice->start());
    auto end = ValuePtrToPyData(slice->stop());
    auto step = ValuePtrToPyData(slice->step());
    ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end,
                                          step);
  } else if (value->isa<Type>()) {
    py::tuple v(1);
    v[0] = value->cast<TypePtr>();
    ret = v[0];
  } else if (value->isa<AnyValue>()) {
    ret = py::none();
  } else if (value->isa<None>()) {
    ret = py::none();
  } else {
    MS_LOG(INFO) << "Unsupported convert value: " << value->ToString() << " to a PyData.";
  }
  return ret;
}
  148. py::object AnyToPyData(const Any &value) {
  149. py::object ret;
  150. MS_LOG(DEBUG) << "AnyToPyData " << value.GetString();
  151. if (value.is<int>() || value.is<float>() || value.is<double>() || value.is<bool>()) {
  152. ret = BuiltinsToPyData(value);
  153. } else if (value.is<ValuePtr>()) {
  154. MS_LOG(DEBUG) << "ValuePtr";
  155. ValuePtr v = value.cast<ValuePtr>();
  156. ret = ValuePtrToPyData(v);
  157. } else if (value.is<tensor::TensorPtr>()) {
  158. MS_LOG(DEBUG) << "tensor";
  159. py::tuple v(1);
  160. v[0] = value.cast<tensor::TensorPtr>();
  161. ret = v[0];
  162. } else if (value.is<py::object>()) {
  163. MS_LOG(DEBUG) << "py obj";
  164. ret = value.cast<py::object>();
  165. } else if (value.is<std::vector<tensor::TensorPtr>>() || value.is<std::vector<Any>>()) {
  166. ret = VectorToPyData(value);
  167. } else if (value.is<std::list<Any>>()) {
  168. MS_LOG(DEBUG) << "list_any";
  169. auto value_list = value.cast<std::list<Any>>();
  170. py::list rets = py::list();
  171. for (auto &v : value_list) {
  172. rets.append(AnyToPyData(v));
  173. }
  174. ret = rets;
  175. } else if (value.is<std::vector<Any>>()) {
  176. auto value_list = value.cast<std::vector<Any>>();
  177. py::tuple rets(value_list.size());
  178. for (size_t i = 0; i < value_list.size(); i++) {
  179. rets[i] = AnyToPyData(value_list[i]);
  180. }
  181. ret = rets;
  182. } else if (value.is<TypePtr>()) {
  183. py::tuple v(1);
  184. v[0] = value.cast<TypePtr>();
  185. ret = v[0];
  186. } else {
  187. MS_LOG(EXCEPTION) << "value is not support type";
  188. }
  189. return ret;
  190. }
  191. py::object BaseRefToPyData(const BaseRef &value) {
  192. py::object ret;
  193. MS_LOG(DEBUG) << "BaseRefToPyData " << value.ToString();
  194. if (utils::isa<int>(value) || utils::isa<float>(value) || utils::isa<double>(value) || utils::isa<bool>(value)) {
  195. ret = BuiltinsToPyData(value);
  196. } else if (utils::isa<ValuePtr>(value)) {
  197. MS_LOG(DEBUG) << "ValuePtr";
  198. ValuePtr v = utils::cast<ValuePtr>(value);
  199. ret = ValuePtrToPyData(v);
  200. } else if (utils::isa<tensor::TensorPtr>(value)) {
  201. MS_LOG(DEBUG) << "tensor";
  202. py::tuple v(1);
  203. v[0] = utils::cast<tensor::TensorPtr>(value);
  204. ret = v[0];
  205. } else if (utils::isa<PyObjectRef>(value)) {
  206. MS_LOG(DEBUG) << "py obj";
  207. PyObjectRef py_ref = utils::cast<PyObjectRef>(value);
  208. ret = py_ref.object_;
  209. } else if (utils::isa<VectorRef>(value)) {
  210. auto vec_ref = utils::cast<VectorRef>(value);
  211. ret = VectorRefToPyData(vec_ref);
  212. } else if (utils::isa<TypePtr>(value)) {
  213. py::tuple v(1);
  214. v[0] = utils::cast<TypePtr>(value);
  215. ret = v[0];
  216. } else {
  217. MS_LOG(EXCEPTION) << "value is not support type";
  218. }
  219. return ret;
  220. }
  221. bool ValueToBool(const ValuePtr &v, bool *value) {
  222. MS_EXCEPTION_IF_NULL(v);
  223. if (v->isa<BoolImm>()) {
  224. *value = v->cast<BoolImmPtr>()->value();
  225. } else if (v->isa<Int32Imm>()) {
  226. *value = v->cast<Int32ImmPtr>()->value() == 0 ? false : true;
  227. } else if (v->isa<UInt32Imm>()) {
  228. *value = v->cast<UInt32ImmPtr>()->value() == 0 ? false : true;
  229. } else if (v->isa<FP32Imm>()) {
  230. *value = v->cast<FP32ImmPtr>()->value() == 0 ? false : true;
  231. } else if (v->isa<FP64Imm>()) {
  232. *value = v->cast<FP64ImmPtr>()->value() == 0 ? false : true;
  233. } else if (v->isa<tensor::Tensor>()) {
  234. auto tensor = v->cast<tensor::TensorPtr>();
  235. MS_EXCEPTION_IF_NULL(tensor);
  236. (void)tensor->data_sync();
  237. bool *tensor_data = static_cast<bool *>(tensor->data_c());
  238. // maybe need to support if tensor is a bool array
  239. auto vb = tensor_data[0];
  240. *value = vb;
  241. } else {
  242. MS_LOG(WARNING) << "value is not supported to cast to be bool";
  243. return false;
  244. }
  245. return true;
  246. }
  247. bool BaseRefToInt(const ValuePtr &v, int *value) {
  248. MS_EXCEPTION_IF_NULL(v);
  249. if (v->isa<tensor::Tensor>()) {
  250. auto tensor = v->cast<tensor::TensorPtr>();
  251. (void)tensor->data_sync();
  252. int *tensor_data = static_cast<int *>(tensor->data_c());
  253. auto vb = tensor_data[0];
  254. *value = vb;
  255. return true;
  256. }
  257. MS_LOG(ERROR) << "Index must be tensor type.";
  258. return false;
  259. }
  260. bool BaseRefToBool(const BaseRef &v, bool *value) {
  261. if (utils::isa<ValuePtr>(v)) {
  262. return ValueToBool(utils::cast<ValuePtr>(v), value);
  263. } else if (utils::isa<bool>(v)) {
  264. auto vb = utils::cast<bool>(v);
  265. if (vb == true) {
  266. *value = true;
  267. } else {
  268. *value = false;
  269. }
  270. } else if (utils::isa<int>(v)) {
  271. auto vb = utils::cast<int>(v);
  272. if (vb == 0) {
  273. *value = false;
  274. } else {
  275. *value = true;
  276. }
  277. } else if (utils::isa<unsigned int>(v)) {
  278. auto vb = utils::cast<unsigned int>(v);
  279. if (vb == 0) {
  280. *value = false;
  281. } else {
  282. *value = true;
  283. }
  284. } else if (utils::isa<float>(v)) {
  285. auto vb = utils::cast<float>(v);
  286. if (vb >= -FLT_EPSILON && vb <= FLT_EPSILON) {
  287. *value = false;
  288. } else {
  289. *value = true;
  290. }
  291. } else if (utils::isa<double>(v)) {
  292. auto vb = utils::cast<double>(v);
  293. if (vb >= -DBL_EPSILON && vb <= DBL_EPSILON) {
  294. *value = false;
  295. } else {
  296. *value = true;
  297. }
  298. } else {
  299. MS_LOG(DEBUG) << "value is not supported to cast to be bool";
  300. return false;
  301. }
  302. return true;
  303. }
  304. py::object BuiltinsToPyData(const Any &value) {
  305. if (value.is<int>()) {
  306. MS_LOG(DEBUG) << "int";
  307. py::int_ ret = value.cast<int>();
  308. return std::move(ret);
  309. } else if (value.is<float>()) {
  310. MS_LOG(DEBUG) << "float";
  311. py::float_ ret = value.cast<float>();
  312. return std::move(ret);
  313. } else if (value.is<double>()) {
  314. MS_LOG(DEBUG) << "double";
  315. py::float_ ret = value.cast<double>();
  316. return std::move(ret);
  317. } else {
  318. MS_LOG(DEBUG) << "bool";
  319. py::bool_ ret = value.cast<bool>();
  320. return std::move(ret);
  321. }
  322. }
  323. py::object BuiltinsToPyData(const BaseRef &value) {
  324. if (utils::isa<int>(value)) {
  325. MS_LOG(DEBUG) << "int";
  326. py::int_ ret = utils::cast<int>(value);
  327. return std::move(ret);
  328. } else if (utils::isa<float>(value)) {
  329. MS_LOG(DEBUG) << "float";
  330. py::float_ ret = utils::cast<float>(value);
  331. return std::move(ret);
  332. } else if (utils::isa<double>(value)) {
  333. MS_LOG(DEBUG) << "double";
  334. py::float_ ret = utils::cast<double>(value);
  335. return std::move(ret);
  336. } else {
  337. MS_LOG(DEBUG) << "bool";
  338. py::bool_ ret = utils::cast<bool>(value);
  339. return std::move(ret);
  340. }
  341. }
  342. py::object VectorToPyData(const Any &value) {
  343. py::object ret;
  344. if (value.is<std::vector<tensor::TensorPtr>>()) {
  345. MS_LOG(DEBUG) << "vector_tensor";
  346. std::vector<tensor::TensorPtr> outputs;
  347. outputs = value.cast<std::vector<tensor::TensorPtr>>();
  348. py::tuple tensor_tuple(outputs.size());
  349. for (std::size_t i = 0; i < outputs.size(); ++i) {
  350. tensor_tuple[i] = *outputs[i];
  351. }
  352. ret = tensor_tuple;
  353. } else {
  354. MS_LOG(DEBUG) << "vector_any";
  355. auto value_list = value.cast<std::vector<Any>>();
  356. py::tuple any_tuple = py::tuple(value_list.size());
  357. size_t i = 0;
  358. for (auto &v : value_list) {
  359. any_tuple[i] = AnyToPyData(v);
  360. i++;
  361. }
  362. ret = any_tuple;
  363. }
  364. return ret;
  365. }
  366. py::object VectorRefToPyData(const VectorRef &value_list) {
  367. py::object ret;
  368. MS_LOG(DEBUG) << "vector_ref";
  369. size_t value_size = value_list.size();
  370. auto ref_tuple = py::tuple(value_size);
  371. for (size_t i = 0; i < value_size; i++) {
  372. ref_tuple[i] = BaseRefToPyData(value_list[i]);
  373. }
  374. ret = ref_tuple;
  375. return ret;
  376. }
// Build an abstract value from Python-side shape/type descriptions.
//
// Four cases:
//  1. (list|tuple shape, Type): an AbstractScalar when the shape is empty
//     and the dtype is not a TensorType; otherwise an AbstractTensor whose
//     Shape also carries the optional min/max shape bounds (used for
//     dynamic shapes).
//  2. (tuple of shapes, tuple of types): recurse per element and collect an
//     AbstractTuple. The recursive call passes only two arguments, so it
//     relies on default values for min_shape/max_shape declared in the
//     header (not visible in this file) -- presumably py::none().
//  3. (None, None): AbstractNone, i.e. the node produces no output.
//  4. anything else: AbstractUndetermined when sparse is enabled (such
//     nodes may be eliminated by later optimization passes), otherwise an
//     exception reporting the invalid type object.
AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
                                           const py::object &min_shape, const py::object &max_shape) {
  if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) && py::isinstance<Type>(type_obj)) {
    auto ret_vec = shape_obj.cast<std::vector<int>>();
    auto ret_dtype = type_obj.cast<TypePtr>();
    MS_EXCEPTION_IF_NULL(ret_dtype);
    // if the size of shape list is empty, return an scalar abstract
    if (ret_vec.empty() && (!ret_dtype->isa<TensorType>())) {
      abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
      return abs_scalar;
    }
    AbstractBasePtr tensor = nullptr;
    std::vector<int> min_shape_vec;
    std::vector<int> max_shape_vec;
    if (!min_shape.is_none()) {
      min_shape_vec = min_shape.cast<std::vector<int>>();
    }
    if (!max_shape.is_none()) {
      max_shape_vec = max_shape.cast<std::vector<int>>();
    }
    auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
    if (ret_dtype->isa<TensorType>()) {
      // TensorType: the abstract element carries the tensor's element dtype.
      auto tensor_type = type_obj.cast<TensorTypePtr>();
      MS_EXCEPTION_IF_NULL(tensor_type);
      auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, tensor_type->element());
      tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
    } else {
      // Plain dtype with a non-empty shape: wrap it as a tensor abstract.
      auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
      tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
    }
    return tensor;
  } else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
    py::tuple shape_tuple = shape_obj.cast<py::tuple>();
    py::tuple typeid_tuple = type_obj.cast<py::tuple>();
    AbstractBasePtrList ptr_list;
    for (size_t it = 0; it < shape_tuple.size(); ++it) {
      auto tensor_it = PyListDtype2AbstractTensor(shape_tuple[it], typeid_tuple[it]);
      ptr_list.push_back(tensor_it);
    }
    auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
    return tuple;
  } else if (shape_obj.is_none() && type_obj.is_none()) {
    // AbstractNone indicates there is no output for this CNode node.
    auto abstract_none = std::make_shared<abstract::AbstractNone>();
    return abstract_none;
  } else {
    // When sparse enabled, the undetermined might be raised and eliminated in opt passes
    auto context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context);
    bool enable_sparse = context->enable_sparse();
    if (enable_sparse) {
      return std::make_shared<abstract::AbstractUndetermined>();
    }
    MS_LOG(EXCEPTION) << "Python evaluator return invalid shape or type. " << (std::string)py::str(type_obj);
  }
}
// Short-circuit graph execution when the graph's output is trivially known.
//
// Returns true (and fills *ret_val) when:
//  - `output` is a ValueNode: its constant value is converted to Python data;
//  - `output` is a Parameter: the matching positional argument from `args`
//    is returned, or -- for a hyper parameter beyond the positional args --
//    the parameter's default value is converted with py::cast.
// Returns false when the graph actually has to be executed.
// Raises when arg/parameter counts are inconsistent or the chosen parameter
// has no default value.
bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
                                       const std::shared_ptr<py::object> &ret_val) {
  if (output->isa<ValueNode>()) {
    MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
    ValuePtr value = GetValueNode(output);
    *ret_val = ValuePtrToPyData(value);
    return true;
  }
  // Adapter will transform values in __init__() and construct() to parameters, this could cause
  // inputs (a.k.a args in current function) size less than parameters'.
  if (output->isa<Parameter>()) {
    MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute.";
    // Find the right parameter as ret_val.
    auto func_graph = output->func_graph();
    MS_EXCEPTION_IF_NULL(func_graph);
    auto params = func_graph->parameters();
    // Positional args plus hyper params must account for every parameter.
    if ((args.size() + func_graph->hyper_param_count()) != params.size()) {
      MS_LOG(EXCEPTION) << "Input size " << args.size() << " add Parameter count " << func_graph->hyper_param_count()
                        << " not equal to graph input size " << params.size() << ", let graph to be executed.";
    }
    auto it = std::find(params.begin(), params.end(), output);
    if (it == params.end()) {
      MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters";
    }
    size_t index = it - params.cbegin();
    if (index >= args.size() + func_graph->hyper_param_count()) {
      MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size()
                                 << " add Parameter count " << func_graph->hyper_param_count() << ".";
    }
    if (index < args.size()) {
      // Positional input: hand back the caller-provided Python object.
      *ret_val = args[index];
    } else {
      // Hyper parameter: fall back to its default value.
      auto param = dyn_cast<Parameter>(params[index]);
      MS_EXCEPTION_IF_NULL(param);
      if (!param->has_default()) {
        MS_LOG(EXCEPTION) << "Can not determine value of Parameter " << index << " (" << param->name() << ")";
      }
      auto tensor = param->default_param();
      *ret_val = py::cast(tensor);
    }
    return true;
  }
  return false;
}
  477. namespace {
  478. // Isomorphism
  479. bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
  480. NodeMapEquiv *const equiv_node);
// Compare two nodes for equivalence without descending into CNode inputs
// (CNode pairs are delegated to SameNode, which does descend).
//
// Rules, in order:
//  - a pair already recorded in equiv_node is equal;
//  - two FuncGraph value nodes are compared with Isomorphic;
//  - two plain value nodes: Primitives compare by name, Tensors by
//    ValueEqual, anything else by Value's operator==;
//  - two Parameters are equal iff their names match;
//  - two CNodes go through the full SameNode comparison.
// Any other combination logs an error and counts as not equal.
bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
                     NodeMapEquiv *const equiv_node) {
  if (equiv_node == nullptr) {
    MS_LOG(ERROR) << "Invalid equiv_node";
    return false;
  }
  if (equiv_node->count(node1) > 0 && (*equiv_node)[node1] == node2) {
    return true;
  }
  if (IsValueNode<FuncGraph>(node1) && IsValueNode<FuncGraph>(node2)) {
    return Isomorphic(GetValueNode<FuncGraphPtr>(node1), GetValueNode<FuncGraphPtr>(node2), equiv_func_graph,
                      equiv_node);
  }
  if (node1->isa<ValueNode>() && node2->isa<ValueNode>()) {
    auto a1 = GetValueNode(node1);
    auto a2 = GetValueNode(node2);
    if (a1->isa<Primitive>() && a2->isa<Primitive>()) {
      return a1->cast<PrimitivePtr>()->name() == a2->cast<PrimitivePtr>()->name();
    } else if (a1->isa<tensor::Tensor>() && a2->isa<tensor::Tensor>()) {
      return a1->cast<tensor::TensorPtr>()->ValueEqual(*(a2->cast<tensor::TensorPtr>()));
    } else {
      return *a1 == *a2;
    }
  }
  if (node1->isa<Parameter>() && node2->isa<Parameter>()) {
    auto para1 = node1->cast<ParameterPtr>();
    auto para2 = node2->cast<ParameterPtr>();
    if (para1->name() == para2->name()) {
      return true;
    }
    MS_LOG(DEBUG) << "two parameters are not equal.";
    return false;
  }
  if (node1->isa<CNode>() && node2->isa<CNode>()) {
    return SameNode(node1, node2, equiv_func_graph, equiv_node);
  }
  MS_LOG(ERROR) << "type error";
  return false;
}
  520. bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
  521. NodeMapEquiv *const equiv_node) {
  522. MS_EXCEPTION_IF_NULL(node1);
  523. MS_EXCEPTION_IF_NULL(node2);
  524. if (node1->isa<CNode>() && node2->isa<CNode>()) {
  525. auto &inputs1 = node1->cast<CNodePtr>()->inputs();
  526. auto &inputs2 = node2->cast<CNodePtr>()->inputs();
  527. for (std::size_t i = 0; i < inputs1.size(); ++i) {
  528. if (!SameNodeShallow(inputs1[i], inputs2[i], equiv_func_graph, equiv_node)) {
  529. return false;
  530. }
  531. }
  532. return true;
  533. }
  534. return SameNodeShallow(node1, node2, equiv_func_graph, equiv_node);
  535. }
// Check whether the subgraphs rooted at root1/root2 are isomorphic, using an
// explicit work stack instead of recursion.
//
// For each pending pair, the incoming-node lists (SuccIncoming) must have
// equal sizes; unprocessed children are pushed first, so a pair is only
// compared with SameNode once all of its children have been handled
// (post-order). Matched pairs are recorded in equiv_node so later
// comparisons can short-circuit; the first mismatch aborts with false.
bool SameSubgraph(AnfNodePtr root1, AnfNodePtr root2, FuncGraphPairMapEquiv *equiv_func_graph,
                  NodeMapEquiv *const equiv_node) {
  std::unordered_set<AnfNodePtr> done;
  std::stack<std::pair<AnfNodePtr, AnfNodePtr>> todo;
  todo.push(std::make_pair(root1, root2));
  while (todo.size() > 0) {
    AnfNodePtr node1 = todo.top().first;
    if (done.count(node1) > 0) {
      // Already matched via another path; nothing more to do for this pair.
      todo.pop();
      continue;
    }
    AnfNodePtr node2 = todo.top().second;
    bool condition = false;
    std::vector<AnfNodePtr> s1 = SuccIncoming(node1);
    std::vector<AnfNodePtr> s2 = SuccIncoming(node2);
    if (s1.size() != s2.size()) {
      return false;
    }
    for (std::size_t i = 0; i < s1.size(); ++i) {
      if (done.count(s1[i]) == 0) {
        // Defer this pair until its children are processed (post-order).
        todo.push(std::make_pair(s1[i], s2[i]));
        condition = true;
      }
    }
    if (condition) {
      continue;
    }
    (void)done.insert(node1);
    auto res = SameNode(node1, node2, equiv_func_graph, equiv_node);
    if (res) {
      (*equiv_node)[node1] = node2;
    } else {
      return false;
    }
    todo.pop();
  }
  return true;
}
  574. } // namespace
// Test whether two function graphs are isomorphic.
//
// Results are cached in equiv_func_graph; a kPending marker is stored
// before recursing so mutually-recursive graph pairs terminate (a pending
// entry found on re-entry compares != kNotEquiv, i.e. tentatively equal).
// Parameter lists must have equal sizes and are paired up in equiv_node
// before comparing the graph bodies via SameSubgraph on the return nodes.
// Null maps or graphs log an error and yield false.
bool Isomorphic(FuncGraphPtr fg1, FuncGraphPtr fg2, FuncGraphPairMapEquiv *equiv_func_graph,
                NodeMapEquiv *const equiv_node) {
  auto fg1_fg2 = std::make_pair(fg1, fg2);
  if (equiv_func_graph == nullptr) {
    MS_LOG(ERROR) << "equiv_func_graph not init";
    return false;
  }
  if (equiv_func_graph->find(fg1_fg2) != equiv_func_graph->end()) {
    return (*equiv_func_graph)[fg1_fg2] != kNotEquiv;
  }
  if (fg1 == nullptr || fg2 == nullptr) {
    MS_LOG(ERROR) << "Invalid function graph";
    return false;
  }
  if (fg1->parameters().size() != fg2->parameters().size()) {
    MS_LOG(DEBUG) << "parameters size not match";
    return false;
  }
  if (equiv_node != nullptr) {
    for (std::size_t i = 0; i < fg1->parameters().size(); ++i) {
      (*equiv_node)[fg1->parameters()[i]] = fg2->parameters()[i];
    }
    (*equiv_func_graph)[fg1_fg2] = kPending;
    auto result = SameSubgraph(fg1->get_return(), fg2->get_return(), equiv_func_graph, equiv_node);
    (*equiv_func_graph)[fg1_fg2] = EquivState(result);
    return result;
  }
  MS_LOG(ERROR) << "equiv_node not init";
  return false;
}
// Promote a scalar immediate to a Tensor:
//  - FloatImm   -> float32 tensor,
//  - IntergerImm -> int32 tensor ("IntergerImm" [sic] is the class name as
//    declared in ir -- do not "correct" the spelling here),
//  - BoolImm    -> bool tensor holding 1 or 0.
// Raises ArgumentError on nullptr and an exception for any other scalar
// type (reporting its type string).
tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) {
  if (scalar == nullptr) {
    MS_EXCEPTION(ArgumentError) << "Nullptr Error!";
  }
  tensor::TensorPtr tensor = nullptr;
  if (scalar->isa<FloatImm>()) {
    tensor = std::make_shared<tensor::Tensor>(static_cast<double>(GetValue<float>(scalar)), kFloat32);
  } else if (scalar->isa<IntergerImm>()) {
    tensor = std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int>(scalar)), kInt32);
  } else if (scalar->isa<BoolImm>()) {
    const int64_t bool_value = GetValue<bool>(scalar) ? 1 : 0;
    tensor = std::make_shared<tensor::Tensor>(bool_value, kBool);
  } else {
    auto type = scalar->type();
    auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
    MS_LOG(EXCEPTION) << "Invalid scalar type: " << type_str;
  }
  MS_EXCEPTION_IF_NULL(tensor);
  return tensor;
}
  625. void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *tensors) {
  626. MS_EXCEPTION_IF_NULL(value);
  627. MS_EXCEPTION_IF_NULL(tensors);
  628. if (value->isa<ValueTuple>()) {
  629. auto value_tuple = value->cast<ValueTuplePtr>();
  630. MS_EXCEPTION_IF_NULL(value_tuple);
  631. for (size_t i = 0; i < value_tuple->size(); ++i) {
  632. ValuePtr element = value_tuple->value()[i];
  633. if (element->isa<tensor::Tensor>()) {
  634. auto tensor = element->cast<tensor::TensorPtr>();
  635. MS_EXCEPTION_IF_NULL(tensor);
  636. tensors->push_back(tensor);
  637. }
  638. }
  639. } else if (value->isa<tensor::Tensor>()) {
  640. tensor::TensorPtr tensor = value->cast<tensor::TensorPtr>();
  641. MS_EXCEPTION_IF_NULL(tensor);
  642. tensors->push_back(tensor);
  643. }
  644. }
  645. } // namespace mindspore