You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

arithmetic_simplify.cc 28 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663
  1. /**
  2. * Copyright 2020-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/graph_kernel/arithmetic_simplify.h"
  17. #include <algorithm>
  18. #include <list>
  19. #include <string>
  20. #include <functional>
  21. #include <set>
  22. #include <vector>
  23. #include "utils/hash_map.h"
  24. #include "utils/hash_set.h"
  25. #include "common/graph_kernel/graph_kernel_helper.h"
  26. #include "common/graph_kernel/core/graph_builder.h"
  27. #include "common/graph_kernel/core/graph_kernel_utils.h"
  28. #include "backend/common/session/anf_runtime_algorithm.h"
  29. #include "include/common/utils/anfalgo.h"
  30. #include "ir/anf.h"
  31. #include "include/common/utils/context/graph_kernel_flags.h"
  32. namespace mindspore::graphkernel {
  33. // operator which follows commutative rules
  34. static mindspore::HashSet<std::string> commutative_ops{"Add", "Mul"};
  class PatternNode;
  using PatternNodePtr = std::shared_ptr<PatternNode>;
  using PatternNodePtrList = std::vector<PatternNodePtr>;

  // One node of a pattern expression tree: an op token plus the ordered list of its input nodes.
  class PatternNode {
   public:
    explicit PatternNode(const std::string &op) : op_(op) {}
    ~PatternNode() = default;

    // Token stored in this node. Returned by const reference so that callers do not pay for a copy.
    const std::string &op() const { return op_; }
    // Input nodes in the order they were added. Returned by const reference to avoid copying the vector.
    const std::vector<PatternNodePtr> &inputs() const { return inputs_; }
    // Appends `input` as the next input of this node.
    void AddInput(const PatternNodePtr &input) { inputs_.push_back(input); }

   private:
    std::string op_;  // ex. "Add","const1","A","0.5" (any op, const or parameter)
    std::vector<PatternNodePtr> inputs_;
  };
  49. using ParaMap = mindspore::HashMap<char, inner::NodePtr>;
  50. using ConstMap = mindspore::HashMap<std::string, inner::NodePtr>;
  51. /* This class works to store a kind of pattern tree; it needs a string expression to construct;
  52. Ex."Pow(Exp(A),B)=Exp(Mul(A,B))"
  53. then the left tree is
  54. A A B
  55. \ \ /
  56. Exp B Mul
  57. \ / \
  58. left tree: Pow right tree: Exp
  59. lhs_root_ is Pow ;lhs_root_ is Exp */
  60. class PatternTree {
  61. public:
  62. // pattern_str->ex."Pow(Exp(A),B)=Exp(Mul(A,B))"
  63. explicit PatternTree(const std::string &pattern_str) { BuildTree(pattern_str); }
  64. virtual ~PatternTree() = default;
  65. PatternNodePtr lhs_root() { return lhs_root_; }
  66. PatternNodePtr rhs_root() { return rhs_root_; }
  67. std::string GetRootOp() const { return lhs_root_ == nullptr ? "" : lhs_root_->op(); }
  68. // build tree with expression string
  69. PatternNodePtr BuildTree(const std::string &pattern_str);
  70. // traverse pattern tree, return order is topological order
  71. void DfsTraverse(const std::shared_ptr<PatternNodePtrList> &res, const PatternNodePtr &cur) const;
  72. // leverage pattern tree node and lite node's mapping relation to build lite node graph from pattern tree's right
  73. // side
  74. inner::NodePtr AlterGraph(const std::shared_ptr<ParaMap> &para_to_ref, const std::shared_ptr<ConstMap> &const_to_ref,
  75. const inner::NodePtr &origin_root);
  76. // invoke DfsMatchGraph
  77. inner::NodePtrList MatchGraph(const inner::NodePtr &root, const std::shared_ptr<ParaMap> &para_to_ref,
  78. const std::shared_ptr<ConstMap> &const_to_ref);
  79. protected:
  80. // set attributes for certain pattern node if needed;
  81. virtual mindspore::HashMap<PatternNodePtr, inner::DAttrs> SetAttributes(const inner::NodePtr &) {
  82. auto right_pattern = std::make_shared<PatternNodePtrList>();
  83. DfsTraverse(right_pattern, rhs_root_);
  84. mindspore::HashMap<PatternNodePtr, inner::DAttrs> attrs_map;
  85. for (auto &i : (*right_pattern)) {
  86. attrs_map[i] = {};
  87. }
  88. return attrs_map;
  89. }
  90. // check attributes meet requirements for certain pattern node if needed;
  91. virtual bool CheckAttributes(const inner::NodePtr &) const { return true; }
  92. private:
  93. PatternNodePtr lhs_root_ = nullptr; // left side's root
  94. PatternNodePtr rhs_root_ = nullptr; // right side's root
  95. };
  96. std::string CutStr(const string &s, size_t start_pos = 0, size_t len = std::string::npos) {
  97. std::string new_str = "";
  98. if (start_pos >= s.length()) {
  99. MS_LOG(EXCEPTION) << "Start index " << start_pos << " is out of range [0, " << s.length() << ") in string: " << s;
  100. }
  101. for (size_t i = 0; i < len; i++) {
  102. if (start_pos + i >= s.length()) break;
  103. new_str += s[start_pos + i];
  104. }
  105. return new_str;
  106. }
  // Returns true when `s` begins with `prefix`. An empty `prefix` matches every string.
  bool StartWith(const std::string &s, const std::string &prefix) {
    if (s.length() < prefix.length()) return false;
    // Compare only the leading characters instead of searching the whole string for the prefix.
    return s.compare(0, prefix.length(), prefix) == 0;
  }
  111. // build pattern tree ;left side's root is lhs_root_ ; right side's root is rhs_root_
  112. PatternNodePtr PatternTree::BuildTree(const std::string &pattern_str) {
  113. size_t pos = pattern_str.find("=");
  114. if (pos != std::string::npos) {
  115. auto left_expression = CutStr(pattern_str, 0, pos);
  116. lhs_root_ = BuildTree(left_expression);
  117. auto right_expression = CutStr(pattern_str, pos + 1);
  118. rhs_root_ = BuildTree(right_expression);
  119. } else {
  120. size_t p_start = pattern_str.find("(");
  121. if (p_start != std::string::npos) {
  122. size_t p_end = pattern_str.rfind(")");
  123. auto op_name = CutStr(pattern_str, 0, p_start);
  124. auto op_inputs = CutStr(pattern_str, p_start + 1, p_end - p_start - 1);
  125. PatternNodePtr cur_node = std::make_shared<PatternNode>(op_name);
  126. int tmp = 0;
  127. size_t comma = 0;
  128. while (comma < op_inputs.length()) {
  129. if (op_inputs[comma] == '(') {
  130. tmp++;
  131. }
  132. if (op_inputs[comma] == ')') {
  133. tmp--;
  134. }
  135. if (op_inputs[comma] == ',' && tmp == 0) {
  136. auto first_half = CutStr(op_inputs, 0, comma);
  137. cur_node->AddInput(BuildTree(first_half));
  138. auto second_half = CutStr(op_inputs, comma + 1);
  139. op_inputs = second_half;
  140. comma = 0;
  141. } else {
  142. comma++;
  143. }
  144. }
  145. cur_node->AddInput(BuildTree(op_inputs));
  146. return cur_node;
  147. } else {
  148. return std::make_shared<PatternNode>(pattern_str);
  149. }
  150. }
  151. return nullptr;
  152. }
  153. inner::NType PatternNodeType(const std::string &n) {
  154. // return (Primitive, Parameter or Value)
  155. if (n.length() > 0 && '0' <= n[n.length() - 1] && n[n.length() - 1] <= '9') {
  156. return inner::NType::Value;
  157. } else if (n.length() == 1 && 'A' <= n[0] && n[0] <= 'Z') {
  158. return inner::NType::Parameter;
  159. } else {
  160. return inner::NType::Primitive;
  161. }
  162. }
  // Returns a copy of `s` with every '[', ']' and space character removed.
  std::string CleanStr(const std::string &s) {
    std::string cleaned;
    for (const char ch : s) {
      const bool dropped = (ch == '[' || ch == ']' || ch == ' ');
      if (!dropped) {
        cleaned.push_back(ch);
      }
    }
    return cleaned;
  }
  172. bool CheckCurNode(const inner::NodePtr &tmp_node, const std::string &tmp_pattern_op,
  173. const std::shared_ptr<ParaMap> &para_to_ref, const std::shared_ptr<ConstMap> &const_to_ref) {
  174. // put lite graph node's mapping to pattern node into "para_to_ref" and "const_to_ref"
  175. switch (PatternNodeType(tmp_pattern_op)) {
  176. case inner::NType::Parameter: {
  177. if (para_to_ref->find(tmp_pattern_op[0]) != para_to_ref->end()) {
  178. if ((*para_to_ref)[tmp_pattern_op[0]] != tmp_node) {
  179. return false;
  180. }
  181. } else {
  182. (*para_to_ref)[tmp_pattern_op[0]] = tmp_node;
  183. }
  184. break;
  185. }
  186. case inner::NType::Value: {
  187. if (tmp_node->NodeType() != inner::NType::Value) {
  188. return false;
  189. }
  190. auto node_value_str = std::static_pointer_cast<inner::ConstTensorNode>(tmp_node)->ToString();
  191. double node_value = std::stod(CleanStr(node_value_str));
  192. if (StartWith(tmp_pattern_op, "const")) {
  193. if (const_to_ref->find(tmp_pattern_op) != const_to_ref->end()) {
  194. auto pattern_value_str =
  195. std::static_pointer_cast<inner::ConstTensorNode>((*const_to_ref)[tmp_pattern_op])->ToString();
  196. double pattern_value = std::stod(CleanStr(pattern_value_str));
  197. if (pattern_value != node_value) return false;
  198. } else {
  199. (*const_to_ref)[tmp_pattern_op] = tmp_node;
  200. }
  201. } else {
  202. double pattern_value = std::stod(tmp_pattern_op);
  203. if (pattern_value != node_value) {
  204. return false;
  205. }
  206. }
  207. break;
  208. }
  209. case inner::NType::Primitive: {
  210. if (tmp_node->NodeType() != inner::NType::Primitive ||
  211. std::static_pointer_cast<inner::PrimOp>(tmp_node)->op() != tmp_pattern_op) {
  212. return false;
  213. }
  214. break;
  215. }
  216. default:
  217. break;
  218. }
  219. return true;
  220. }
  221. // recursion for thr match of lite node graph and pattern tree's left side, store the mapping of pattern tree node to
  222. // lite graph
  223. bool DfsMatchGraph(const inner::NodePtr &tmp_node, const PatternNodePtr &tmp_pattern,
  224. const std::shared_ptr<ParaMap> &para_to_ref, const std::shared_ptr<ConstMap> &const_to_ref,
  225. const std::shared_ptr<inner::NodePtrList> &res) {
  226. std::string tmp_pattern_op = tmp_pattern->op();
  227. if (!CheckCurNode(tmp_node, tmp_pattern_op, para_to_ref, const_to_ref)) {
  228. return false;
  229. }
  230. std::vector<PatternNodePtr> tmp_pattern_inputs = tmp_pattern->inputs();
  231. auto tmp_node_inputs = tmp_node->inputs();
  232. // check if a node meets requiremnet,and DFS check its inputs
  233. if (tmp_pattern_inputs.size() != 0 && tmp_node_inputs.size() != tmp_pattern_inputs.size()) {
  234. return false;
  235. }
  236. if (PatternNodeType(tmp_pattern_op) == inner::NType::Primitive) {
  237. // exchange inputs for the node who meets commutative rules
  238. if (commutative_ops.find(tmp_pattern_op) != commutative_ops.end()) {
  239. ParaMap para_to_ref_copy = *para_to_ref;
  240. ConstMap const_to_ref_copy = *const_to_ref;
  241. bool first_match = DfsMatchGraph(tmp_node_inputs[0], tmp_pattern_inputs[0], para_to_ref, const_to_ref, res) &&
  242. DfsMatchGraph(tmp_node_inputs[1], tmp_pattern_inputs[1], para_to_ref, const_to_ref, res);
  243. if (!first_match) {
  244. res->clear();
  245. para_to_ref->clear();
  246. const_to_ref->clear();
  247. for (auto &i : para_to_ref_copy) {
  248. (*para_to_ref)[i.first] = i.second;
  249. }
  250. for (auto &i : const_to_ref_copy) {
  251. (*const_to_ref)[i.first] = i.second;
  252. }
  253. bool second_match = DfsMatchGraph(tmp_node_inputs[0], tmp_pattern_inputs[1], para_to_ref, const_to_ref, res) &&
  254. DfsMatchGraph(tmp_node_inputs[1], tmp_pattern_inputs[0], para_to_ref, const_to_ref, res);
  255. if (!second_match) {
  256. return false;
  257. }
  258. }
  259. } else {
  260. for (size_t i = 0; i < tmp_pattern_inputs.size(); i++) {
  261. if (!DfsMatchGraph(tmp_node_inputs[i], tmp_pattern_inputs[i], para_to_ref, const_to_ref, res)) {
  262. return false;
  263. }
  264. }
  265. }
  266. res->push_back(tmp_node);
  267. }
  268. return true;
  269. }
  270. // traverse pattern tree and return topological order
  271. void PatternTree::DfsTraverse(const std::shared_ptr<PatternNodePtrList> &res, const PatternNodePtr &cur) const {
  272. if (cur == nullptr) {
  273. return;
  274. }
  275. for (auto &p : cur->inputs()) {
  276. if (PatternNodeType(p->op()) == inner::NType::Primitive) {
  277. DfsTraverse(res, p);
  278. }
  279. }
  280. res->push_back(cur);
  281. }
  282. // invoke DfsMatchGraph
  283. inner::NodePtrList PatternTree::MatchGraph(const inner::NodePtr &root, const std::shared_ptr<ParaMap> &para_to_ref,
  284. const std::shared_ptr<ConstMap> &const_to_ref) {
  285. auto res = std::make_shared<inner::NodePtrList>();
  286. if (!DfsMatchGraph(root, lhs_root_, para_to_ref, const_to_ref, res)) {
  287. return {};
  288. }
  289. if (CheckAttributes(root)) {
  290. return *res;
  291. }
  292. return {};
  293. }
  294. // leverage pattern tree node and lite node's mapping relation to build new lite node graph from pattern tree's right
  295. // side
  296. inner::NodePtr PatternTree::AlterGraph(const std::shared_ptr<ParaMap> &para_to_ref,
  297. const std::shared_ptr<ConstMap> &const_to_ref,
  298. const inner::NodePtr &origin_root) {
  299. auto res = std::make_shared<PatternNodePtrList>();
  300. DfsTraverse(res, rhs_root_);
  301. auto all_attrs = SetAttributes(origin_root);
  302. inner::LiteGraph::GraphBuilder gb("");
  303. mindspore::HashMap<PatternNodePtr, inner::NodePtr> pattern_to_ref;
  304. for (auto &n : (*res)) {
  305. if (PatternNodeType(n->op()) != inner::NType::Primitive) continue;
  306. inner::NodePtrList inputs;
  307. for (auto &i : n->inputs()) {
  308. if (PatternNodeType(i->op()) == inner::NType::Primitive) {
  309. inputs.push_back(pattern_to_ref[i]);
  310. } else if (PatternNodeType(i->op()) == inner::NType::Parameter) {
  311. inputs.push_back((*para_to_ref)[i->op()[0]]);
  312. } else {
  313. if (StartWith(i->op(), "const")) {
  314. inputs.push_back((*const_to_ref)[i->op()]);
  315. } else {
  316. tensor::TensorPtr data = std::make_shared<tensor::Tensor>(static_cast<double>(std::stof(i->op())));
  317. inputs.push_back(gb.Value(data));
  318. }
  319. }
  320. }
  321. auto p = gb.Emit(n->op(), inputs, all_attrs[n]);
  322. pattern_to_ref[n] = p;
  323. }
  324. auto &alter_graph = gb.Get()->ops();
  325. if (alter_graph.empty()) {
  326. if (PatternNodeType(rhs_root_->op()) == inner::NType::Parameter) {
  327. return (*para_to_ref)[rhs_root_->op()[0]];
  328. } else {
  329. if (StartWith(rhs_root_->op(), "const")) {
  330. return (*const_to_ref)[rhs_root_->op()];
  331. } else {
  332. tensor::TensorPtr data = std::make_shared<tensor::Tensor>(static_cast<double>(std::stof(rhs_root_->op())));
  333. return gb.Value(data);
  334. }
  335. }
  336. }
  337. return alter_graph.back();
  338. }
  339. // Reduce(Reduce(A)) = Reduce(A)
  340. class ExtraReduce1PatternTree : public PatternTree {
  341. public:
  342. explicit ExtraReduce1PatternTree(const std::string &pattern_str) : PatternTree(pattern_str) {}
  343. ~ExtraReduce1PatternTree() = default;
  344. protected:
  345. bool CheckAttributes(const inner::NodePtr &origin_root) const override {
  346. return (GetValue<bool>((origin_root->inputs()[0])->attrs().find("keep_dims")->second) ==
  347. GetValue<bool>(origin_root->attrs().find("keep_dims")->second));
  348. }
  349. mindspore::HashMap<PatternNodePtr, inner::DAttrs> SetAttributes(const inner::NodePtr &origin_root) override {
  350. auto attrs_map = PatternTree::SetAttributes(origin_root);
  351. std::vector<int64_t> axis;
  352. std::set<int64_t> axis_set;
  353. auto first_reduce = origin_root->inputs()[0];
  354. bool keep_dims = GetValue<bool>(origin_root->attrs().find("keep_dims")->second);
  355. if (keep_dims) {
  356. for (auto &i : GetValue<std::vector<int64_t>>(origin_root->attrs().find("axis")->second)) {
  357. axis_set.insert(i);
  358. }
  359. for (auto &i : GetValue<std::vector<int64_t>>(first_reduce->attrs().find("axis")->second)) {
  360. axis_set.insert(i);
  361. }
  362. } else {
  363. auto first_axis = GetValue<std::vector<int64_t>>(first_reduce->attrs().find("axis")->second);
  364. auto second_axis = GetValue<std::vector<int64_t>>(origin_root->attrs().find("axis")->second);
  365. std::set<int64_t> st(first_axis.begin(), first_axis.end());
  366. mindspore::HashMap<int64_t, int64_t> mp;
  367. int64_t shift = 0;
  368. for (int64_t n = 0; n < SizeToLong(first_reduce->inputs()[0]->shape.size()); n++) {
  369. if (st.find(n) != st.end()) {
  370. shift++;
  371. } else {
  372. mp[n - shift] = n;
  373. }
  374. }
  375. std::for_each(first_axis.begin(), first_axis.end(), [&axis_set](auto &i) { axis_set.insert(i); });
  376. std::for_each(second_axis.begin(), second_axis.end(), [&axis_set, &mp](auto &i) { axis_set.insert(mp[i]); });
  377. }
  378. std::copy(axis_set.begin(), axis_set.end(), std::back_inserter(axis));
  379. attrs_map[this->rhs_root()] = {{"keep_dims", MakeValue(keep_dims)}, {"axis", MakeValue(axis)}};
  380. return attrs_map;
  381. }
  382. };
  383. // "ReduceSum(Neg(A))=Neg(ReduceSum(A))"
  384. class ExtraReduce2PatternTree : public PatternTree {
  385. public:
  386. explicit ExtraReduce2PatternTree(const std::string &pattern_str) : PatternTree(pattern_str) {}
  387. ~ExtraReduce2PatternTree() = default;
  388. protected:
  389. mindspore::HashMap<PatternNodePtr, inner::DAttrs> SetAttributes(const inner::NodePtr &origin_root) override {
  390. auto attrs_map = PatternTree::SetAttributes(origin_root);
  391. bool keep_dims = GetValue<bool>(origin_root->attrs().find("keep_dims")->second);
  392. auto axis = GetValue<std::vector<int64_t>>(origin_root->attrs().find("axis")->second);
  393. attrs_map[this->rhs_root()->inputs()[0]] = {{"keep_dims", MakeValue(keep_dims)}, {"axis", MakeValue(axis)}};
  394. return attrs_map;
  395. }
  396. };
  397. /* A
  398. /
  399. Neg
  400. / \
  401. Neg Mul
  402. Here we cannot transform Neg(Neg(A)) to A because Neg(A) is a input of Mul. OutsideRely is responsible for checking
  403. this case.
  404. */
  405. bool OutsideRely(const inner::NodePtrList &nodes, const inner::NodePtr &root) {
  406. mindspore::HashSet<inner::Node *> nodes_can_simplify;
  407. std::for_each(nodes.begin(), nodes.end(), [&nodes_can_simplify](auto n) { nodes_can_simplify.insert(n.get()); });
  408. for (auto &n : nodes) {
  409. if (n == root) {
  410. continue;
  411. }
  412. for (auto &usr : n->users()) {
  413. if (nodes_can_simplify.find(usr.first) == nodes_can_simplify.end()) {
  414. return true;
  415. }
  416. }
  417. }
  418. return false;
  419. }
  420. struct Expression {
  421. size_t id;
  422. std::string math_expr;
  423. std::function<PatternTreePtr(const std::string &)> func;
  424. };
  425. #define EXPR_PATTERN(cls) [](const std::string &expr) -> PatternTreePtr { return std::make_shared<cls>(expr); }
  426. static std::vector<Expression> expressions = {
  427. // add
  428. {1, "Add(A,0)=A", EXPR_PATTERN(PatternTree)},
  429. {2, "Add(Mul(A,C),Mul(A,B))=Mul(A,Add(B,C))", EXPR_PATTERN(PatternTree)},
  430. {3, "Add(Add(A,const1),const2)=Add(A,Add(const1,const2))", EXPR_PATTERN(PatternTree)},
  431. {4, "Add(A,Neg(A))=0", EXPR_PATTERN(PatternTree)},
  432. {5, "Add(Add(A,B),Neg(A))=B", EXPR_PATTERN(PatternTree)},
  433. {6, "Add(Add(A,B),Add(Neg(A),C))=Add(B,C)", EXPR_PATTERN(PatternTree)},
  434. // sub
  435. {7, "Sub(A,0)=A", EXPR_PATTERN(PatternTree)},
  436. {8, "Sub(A,const1)=Add(A,Neg(const1))", EXPR_PATTERN(PatternTree)},
  437. {9, "Sub(Mul(A,C),Mul(A,B))=Mul(A,Sub(B,C))", EXPR_PATTERN(PatternTree)},
  438. {10, "Sub(Mul(A,C),Mul(B,C))=Mul(Sub(A,B),C)", EXPR_PATTERN(PatternTree)},
  439. // log
  440. {11, "Log(Exp(A))=A", EXPR_PATTERN(PatternTree)},
  441. {12, "Log(Pow(A,B))=Mul(B,Log(Abs(A)))", EXPR_PATTERN(PatternTree)},
  442. {13, "Log(Sqrt(A))=Mul(0.5,Log(A))", EXPR_PATTERN(PatternTree)},
  443. {14, "Log(Rsqrt(A))=Mul(-0.5,Log(A))", EXPR_PATTERN(PatternTree)},
  444. // pow
  445. {15, "Pow(A,1)=A", EXPR_PATTERN(PatternTree)},
  446. {16, "Pow(Exp(A),B)=Exp(Mul(A,B))", EXPR_PATTERN(PatternTree)},
  447. {17, "Pow(A,2)=Mul(A,A)", EXPR_PATTERN(PatternTree)},
  448. {18, "Pow(A,-1)=Reciprocal(A)", EXPR_PATTERN(PatternTree)},
  449. // sqrt
  450. {19, "Sqrt(Mul(A,A))=Abs(A)", EXPR_PATTERN(PatternTree)},
  451. {20, "Rsqrt(Pow(A,-2))=Abs(A)", EXPR_PATTERN(PatternTree)},
  452. {21, "Rsqrt(RealDiv(1,A))=Sqrt(A)", EXPR_PATTERN(PatternTree)},
  453. {22, "Rsqrt(Reciprocal(A))=Sqrt(A)", EXPR_PATTERN(PatternTree)},
  454. // select
  455. {23, "Select(A,B,B)=B", EXPR_PATTERN(PatternTree)},
  456. // Neg
  457. {24, "Neg(Neg(A))=A", EXPR_PATTERN(PatternTree)},
  458. // mul
  459. {25, "Mul(Mul(A,const1),Mul(B,const2))=Mul(Mul(A,B),Mul(const1,const2))", EXPR_PATTERN(PatternTree)},
  460. {26, "Mul(Mul(A,const1),const2)=Mul(A,Mul(const1,const2))", EXPR_PATTERN(PatternTree)},
  461. {27, "Mul(Exp(A),Exp(B))=Exp(Add(A,B))", EXPR_PATTERN(PatternTree)},
  462. {28, "Mul(Mul(Exp(A),C),Exp(B))=Mul(Exp(Add(A,B)),C)", EXPR_PATTERN(PatternTree)},
  463. {29, "Mul(Mul(Exp(A),C),Mul(Exp(B),D))=Mul(Exp(Add(A,B)),Mul(C,D))", EXPR_PATTERN(PatternTree)},
  464. {30, "Mul(Sqrt(A),Sqrt(A))=A", EXPR_PATTERN(PatternTree)},
  465. {31, "Mul(Mul(A,Sqrt(B)),Mul(C,Sqrt(B)))=Mul(Mul(A,B),C)", EXPR_PATTERN(PatternTree)},
  466. {32, "Mul(Mul(A,Sqrt(B)),Sqrt(B))=Mul(A,B)", EXPR_PATTERN(PatternTree)},
  467. {33, "Mul(Sqrt(A),Sqrt(B))=Sqrt(Mul(A,B))", EXPR_PATTERN(PatternTree)},
  468. {34, "Mul(Rsqrt(A),Rsqrt(A))=Reciprocal(A)", EXPR_PATTERN(PatternTree)},
  469. {35, "Mul(Mul(A,Rsqrt(B)),Rsqrt(B))=RealDiv(A,B)", EXPR_PATTERN(PatternTree)},
  470. {36, "Mul(Mul(A,Rsqrt(B)),Mul(C,Rsqrt(B)))=RealDiv(Mul(A,C),B)", EXPR_PATTERN(PatternTree)},
  471. {37, "Mul(Rsqrt(A),Rsqrt(B))=Rsqrt(Mul(A,B))", EXPR_PATTERN(PatternTree)},
  472. {38, "Mul(A,Rsqrt(A))=Sqrt(A)", EXPR_PATTERN(PatternTree)},
  473. {39, "Mul(Abs(A),Abs(B))=Abs(Mul(A,B))", EXPR_PATTERN(PatternTree)},
  474. {40, "Mul(Mul(Abs(A),C),Abs(B))=Mul(Abs(Mul(A,B)),C)", EXPR_PATTERN(PatternTree)},
  475. {41, "Mul(Mul(Abs(A),C),Mul(Abs(B),D))=Mul(Abs(Mul(A,B)),Mul(C,D))", EXPR_PATTERN(PatternTree)},
  476. {42, "Mul(Neg(A),const1)=Mul(A,Neg(const1))", EXPR_PATTERN(PatternTree)},
  477. // realdiv
  478. {43, "RealDiv(A,1)=A", EXPR_PATTERN(PatternTree)},
  479. {44, "RealDiv(Exp(A),Exp(B))=Exp(Sub(A,B))", EXPR_PATTERN(PatternTree)},
  480. {45, "RealDiv(A,Exp(B))=Mul(A,Exp(Neg(B)))", EXPR_PATTERN(PatternTree)},
  481. {46, "RealDiv(A,Pow(B,const1))=Mul(A,Pow(B,Neg(const1)))", EXPR_PATTERN(PatternTree)},
  482. {47, "RealDiv(A,Sqrt(A))=Sqrt(A)", EXPR_PATTERN(PatternTree)},
  483. {48, "RealDiv(A,Sqrt(B))=Mul(A,Rsqrt(B))", EXPR_PATTERN(PatternTree)},
  484. {49, "RealDiv(A,Rsqrt(B))=Mul(A,Sqrt(B))", EXPR_PATTERN(PatternTree)},
  485. {50, "RealDiv(A,const1)=Mul(A,Reciprocal(const1))", EXPR_PATTERN(PatternTree)},
  486. {51, "RealDiv(RealDiv(A,B),RealDiv(C,D))=RealDiv(Mul(A,D),Mul(B,C))", EXPR_PATTERN(PatternTree)},
  487. {52, "RealDiv(Neg(A),const1)=RealDiv(A,Neg(const1))", EXPR_PATTERN(PatternTree)},
  488. {53, "RealDiv(RealDiv(A,B),C)=RealDiv(A,Mul(B,C))", EXPR_PATTERN(PatternTree)},
  489. {54, "RealDiv(A,RealDiv(B,C))=RealDiv(Mul(A,C),B)", EXPR_PATTERN(PatternTree)},
  490. // reduce1
  491. {55, "ReduceSum(ReduceSum(A))=ReduceSum(A)", EXPR_PATTERN(ExtraReduce1PatternTree)},
  492. {56, "ReduceMin(ReduceMin(A))=ReduceMin(A)", EXPR_PATTERN(ExtraReduce1PatternTree)},
  493. {57, "ReduceMax(ReduceMax(A))=ReduceMax(A)", EXPR_PATTERN(ExtraReduce1PatternTree)},
  494. // reduce2
  495. {58, "ReduceSum(Neg(A))=Neg(ReduceSum(A))", EXPR_PATTERN(ExtraReduce2PatternTree)},
  496. {59, "ReduceSum(RealDiv(A,const1))=RealDiv(ReduceSum(A),const1)", EXPR_PATTERN(ExtraReduce2PatternTree)},
  497. {60, "ReduceSum(Mul(A,const1))=Mul(ReduceSum(A),const1)", EXPR_PATTERN(ExtraReduce2PatternTree)},
  498. {61, "CReal(Complex(A,B))=A", EXPR_PATTERN(PatternTree)},
  499. {62, "CImag(Complex(A,B))=B", EXPR_PATTERN(PatternTree)},
  500. };
  501. mindspore::HashMap<std::string, std::vector<PatternTreePtr>> GetExpressions() {
  502. const auto &flags = GraphKernelFlags::GetInstance();
  503. mindspore::HashMap<std::string, std::vector<PatternTreePtr>> expression_map;
  504. mindspore::HashSet<std::string> enable_ids{flags.enable_simplify_exprs_only.begin(),
  505. flags.enable_simplify_exprs_only.end()};
  506. mindspore::HashSet<std::string> disable_ids{flags.disable_simplify_exprs.begin(), flags.disable_simplify_exprs.end()};
  507. for (auto &e : expressions) {
  508. if (!enable_ids.empty()) {
  509. if (enable_ids.count(std::to_string(e.id)) == 0) continue;
  510. } else {
  511. if (disable_ids.count(std::to_string(e.id)) > 0) continue;
  512. }
  513. PatternTreePtr pt = e.func(e.math_expr);
  514. expression_map[pt->GetRootOp()].push_back(pt);
  515. }
  516. return expression_map;
  517. }
  518. // arithmetic simplify
  519. bool ArithmeticSimplify::DoArithmeticTrans(const inner::LiteGraphPtr &litegraph) {
  520. auto ops_list = litegraph->ops();
  521. bool changed = false;
  522. inner::NodePtrList matched_nodes;
  523. auto para_to_ref = std::make_shared<ParaMap>(); // A(B,C ...)->Node* mapping
  524. auto const_to_ref = std::make_shared<ConstMap>(); // const->Node* mapping
  525. PatternTreePtr cur_pattern;
  526. auto iter = ops_list.rbegin();
  527. while (iter != ops_list.rend()) {
  528. bool can_simplify = false;
  529. auto this_op = std::static_pointer_cast<inner::PrimOp>(*iter)->op();
  530. if (expressions_map_.find(this_op) != expressions_map_.end()) {
  531. for (auto p : expressions_map_[this_op]) {
  532. cur_pattern = p;
  533. if (!para_to_ref->empty()) {
  534. para_to_ref->clear();
  535. }
  536. if (!const_to_ref->empty()) {
  537. const_to_ref->clear();
  538. }
  539. // match a pattern;if return is empty,then fails to match
  540. matched_nodes = p->MatchGraph(*iter, para_to_ref, const_to_ref);
  541. if (!matched_nodes.empty()) {
  542. auto right_root_type = PatternNodeType(p->rhs_root()->op());
  543. if (right_root_type == inner::NType::Primitive && OutsideRely(matched_nodes, *iter)) {
  544. continue;
  545. }
  546. // if no outside rely,then this is a successful match
  547. can_simplify = true;
  548. // get the new node to replace
  549. inner::NodePtr alter_graph_node = cur_pattern->AlterGraph(para_to_ref, const_to_ref, *iter);
  550. (*iter)->ReplaceWith(alter_graph_node);
  551. ops_list = litegraph->GetOrderedNodes();
  552. iter = ops_list.rbegin();
  553. changed = true;
  554. break;
  555. }
  556. }
  557. }
  558. if (!can_simplify) {
  559. ++iter;
  560. }
  561. }
  562. return changed;
  563. }
  564. // constant fold
  565. bool ArithmeticSimplify::DoConstantFold(const inner::LiteGraphPtr &litegraph) {
  566. auto ops_list = litegraph->GetOrderedNodes();
  567. bool changed = false;
  568. auto iter = ops_list.begin();
  569. while (iter != ops_list.end()) {
  570. auto this_op = std::static_pointer_cast<inner::PrimOp>(*iter);
  571. auto value = this_op->InferValue(this_op->inputs(), this_op->attrs(), this_op->op());
  572. if (value != nullptr) {
  573. (*iter)->ReplaceWith(value);
  574. ops_list = litegraph->GetOrderedNodes();
  575. iter = ops_list.begin();
  576. changed = true;
  577. } else {
  578. ++iter;
  579. }
  580. }
  581. return changed;
  582. }
  583. void ReorganizeEmptyGraph(const inner::LiteGraphPtr &litegraph) {
  584. auto &outputs = litegraph->GetOutputs();
  585. for (size_t i = 0; i < outputs.size(); i++) {
  586. if (outputs[i]->NodeType() == inner::NType::Value) {
  587. inner::LiteGraph::GraphBuilder gb;
  588. std::vector<int64_t> new_shape = {1};
  589. auto op_ptr = gb.Emit("BroadcastTo", {outputs[i]}, {{"shape", MakeValue(new_shape)}});
  590. litegraph->SetOutput(i, op_ptr);
  591. } else if (outputs[i]->NodeType() == inner::NType::Parameter) {
  592. inner::LiteGraph::GraphBuilder gb;
  593. auto op_ptr = gb.Emit("Reshape", {outputs[i]}, {{"shape", MakeValue(outputs[i]->shape)}});
  594. litegraph->SetOutput(i, op_ptr);
  595. }
  596. }
  597. return;
  598. }
  599. bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) {
  600. auto mng = func_graph->manager();
  601. bool do_simplify = false;
  602. expressions_map_ = GetExpressions();
  603. for (auto node : func_graph->GetOrderedCnodes()) {
  604. if (common::AnfAlgo::IsGraphKernel(node)) {
  605. auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
  606. inner::LiteGraphPtr lg = GkUtils::AnfGraph2LiteGraph(sub_graph);
  607. bool find_pattern = true;
  608. bool change_anf_graph = false;
  609. while (find_pattern) {
  610. find_pattern = false;
  611. find_pattern = DoConstantFold(lg) || find_pattern;
  612. find_pattern = DoArithmeticTrans(lg) || find_pattern;
  613. change_anf_graph = change_anf_graph || find_pattern;
  614. }
  615. if (!change_anf_graph) continue;
  616. ReorganizeEmptyGraph(lg);
  617. auto new_funcgraph = GkUtils::LiteGraph2AnfGraph(lg);
  618. new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  619. auto cnode = node->cast<CNodePtr>();
  620. AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
  621. auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs);
  622. mng->Replace(node, new_node);
  623. mng->AddFuncGraph(new_funcgraph);
  624. do_simplify = true;
  625. }
  626. }
  627. return do_simplify;
  628. }
  629. } // namespace mindspore::graphkernel