You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

arithmetic_simplify.cc 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "optimizer/irpass/arithmetic_simplify.h"

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstring>
#include <functional>
#include <memory>
#include <numeric>
#include <vector>

#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"
  27. namespace mindspore {
  28. namespace opt {
  29. namespace irpass {
  30. // {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0}
  31. // {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1}
  32. AnfNodePtr MultiplyByZeroOrOne::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  33. Reset();
  34. AnfVisitor::Match(prim::kPrimScalarMul)(node);
  35. if (is_zero_) {
  36. return NewValueNode(zero_);
  37. }
  38. if (is_one_) {
  39. return x_;
  40. }
  41. return nullptr;
  42. }
void MultiplyByZeroOrOne::Visit(const AnfNodePtr &node) {
  // Classify one operand of the ScalarMul. Once the constant 1 has been seen,
  // or when the operand is a CNode (non-constant), it is the surviving operand.
  if (is_one_ || node->isa<CNode>()) {
    x_ = node;
    return;
  }
  // Dispatch to the ValueNode overload, which tests for the 0/1 constants.
  AnfVisitor::Visit(node);
  if (!is_one_) {
    // The operand was not the constant 1, so keep it as the surviving operand.
    x_ = node;
  }
}
  53. void MultiplyByZeroOrOne::Visit(const ValueNodePtr &vnode) {
  54. auto value = vnode->value();
  55. if (*value == *zero_) {
  56. is_zero_ = true;
  57. } else if (*value == *one_) {
  58. is_one_ = true;
  59. }
  60. }
  61. void MultiplyByZeroOrOne::Reset() {
  62. x_ = nullptr;
  63. is_one_ = false;
  64. is_zero_ = false;
  65. }
  66. // Support class used for checking if all values of a Tensor are equal `check_value_`
  67. // Supported data types: double, float/float32, int/int32
  68. bool CheckTensorConstant::IsTensorConstant(const ValuePtr &value) {
  69. if (!value->isa<tensor::Tensor>()) {
  70. return false;
  71. }
  72. auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
  73. TypeId tensor_type = tensor_ptr->Dtype()->type_id();
  74. if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
  75. float *data2 = reinterpret_cast<float *>(tensor_ptr->data_c());
  76. for (int i = 0; i < tensor_ptr->DataSize(); i++) {
  77. if (fabs(data2[i] - check_value_) > FLT_EPSILON) {
  78. return false;
  79. }
  80. }
  81. return true;
  82. } else if (tensor_type == TypeId::kNumberTypeFloat64) {
  83. double *data2 = reinterpret_cast<double *>(tensor_ptr->data_c());
  84. for (int i = 0; i < tensor_ptr->DataSize(); i++) {
  85. if (fabs(data2[i] - check_value_) > DBL_EPSILON) {
  86. return false;
  87. }
  88. }
  89. return true;
  90. } else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) {
  91. int *data2 = reinterpret_cast<int *>(tensor_ptr->data_c());
  92. for (int i = 0; i < tensor_ptr->DataSize(); i++) {
  93. if (data2[i] != check_value_) {
  94. return false;
  95. }
  96. }
  97. return true;
  98. }
  99. // input Data Types is not supported
  100. return false;
  101. }
  102. bool CheckTensorConstant::IsTensorScalarConstant(const ValuePtr &value) {
  103. if (!value->isa<tensor::Tensor>()) {
  104. return false;
  105. }
  106. auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
  107. if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) {
  108. return false;
  109. }
  110. return IsTensorConstant(value);
  111. }
  112. void *TensorMultiplyBase::GetPointerToTensorData(const AnfNodePtr &node, bool writable) {
  113. if (!node->isa<ValueNode>()) {
  114. return nullptr;
  115. }
  116. auto value = node->cast<ValueNodePtr>()->value();
  117. if (!value->isa<tensor::Tensor>()) {
  118. return nullptr;
  119. }
  120. tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(value);
  121. return tensor_ptr->data_c();
  122. }
  123. // Make a new tensor (when possible) with the same shape as of `node`
  124. // If x is nullptr then fill new tensor will "0"
  125. // If x is a tensor with empty shape then fill new tensor with the single value of x
  126. // If x is a tensor with same shape as `node` then return x as result
AnfNodePtr TensorMultiplyBase::NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x) {
  // Build a tensor with the shape/dtype of `node`:
  //   x == nullptr            -> new tensor filled with zeros
  //   x is a 1-element tensor -> new tensor filled with x's single value
  //   x matches node's shape  -> x itself (CNode) or a copy of its data
  // Returns nullptr when the abstracts/shapes make the construction impossible.
  if ((node->abstract() == nullptr) || !node->abstract()->isa<abstract::AbstractTensor>()) {
    return nullptr;
  }
  auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>();
  TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
  std::vector<int> tensor_shape = tensor_abstract->shape()->shape();
  auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
  size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
  char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c());
  if (x == nullptr) {
    // No source tensor: zero-fill the freshly allocated buffer.
    std::memset(data, 0, mem_size);
    auto new_vnode = NewValueNode(new_tensor_ptr);
    new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
    return new_vnode;
  }
  // x is not nullptr
  if (x->isa<CNode>()) {
    // A computed tensor can only be reused as-is, and only when its static
    // shape matches `node`'s shape exactly.
    if ((x->abstract() == nullptr) || !x->abstract()->isa<abstract::AbstractTensor>()) {
      return nullptr;
    }
    auto x_abstract = x->abstract()->cast<abstract::AbstractTensorPtr>();
    std::vector<int> x_shape = x_abstract->shape()->shape();
    if (x_shape != tensor_shape) {
      return nullptr;
    }
    return x;
  }
  if (!x->isa<ValueNode>()) {
    return nullptr;
  }
  auto x_value = x->cast<ValueNodePtr>()->value();
  if (!x_value->isa<tensor::Tensor>()) {
    return nullptr;
  }
  auto x_tensor_ptr = dyn_cast<tensor::Tensor>(x_value);
  // Only a scalar source (size 1) or an exactly equal-sized tensor can fill
  // the result buffer.
  if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) {
    return nullptr;
  }
  char *source_data = reinterpret_cast<char *>(GetPointerToTensorData(x));
  if (x_tensor_ptr->DataSize() == 1) {
    // Broadcast the single element across every slot of the new tensor.
    for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) {
      memcpy(data + i * GetTypeByte(tensor_type_ptr), source_data, GetTypeByte(tensor_type_ptr));
    }
  } else {
    memcpy(data, source_data, mem_size);
  }
  auto new_vnode = NewValueNode(new_tensor_ptr);
  new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
  return new_vnode;
}
  178. // {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0}
  179. AnfNodePtr TensorMultiplyByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  180. Reset();
  181. AnfVisitor::Match(prim::kPrimMul)(node);
  182. if (is_zero_) {
  183. if (x_->func_graph() != node->func_graph()) {
  184. return nullptr;
  185. }
  186. return NewTensorFilledWithData(node);
  187. }
  188. return nullptr;
  189. }
void TensorMultiplyByZero::Visit(const AnfNodePtr &node) {
  // Classify one operand of the Mul. Once a zero operand has been found,
  // every remaining operand is simply recorded as the surviving operand x_.
  if (is_zero_) {
    x_ = node;
    return;
  }
  if (IsParam(node)) {
    // Parameters are never constant zero.
    x_ = node;
    return;
  }
  if (IsCNode(node)) {
    // {kPrimZerosLike, ...} represents a symbolic all-zero tensor.
    CNodePtr cnode = node->cast<CNodePtr>();
    if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) {
      is_zero_ = true;
      return;
    }
    x_ = node;
    return;
  }
  // Remaining case: a ValueNode — test its payload for an all-zero tensor.
  auto value = node->cast<ValueNodePtr>()->value();
  if (CheckTensorConstant(0).IsTensorConstant(value)) {
    is_zero_ = true;
    return;
  }
  x_ = node;
}
  215. void TensorMultiplyByZero::Visit(const ValueNodePtr &vnode) {
  216. auto value = vnode->value();
  217. if (CheckTensorConstant(0).IsTensorConstant(value)) {
  218. is_zero_ = true;
  219. return;
  220. }
  221. x_ = vnode;
  222. }
  223. void TensorMultiplyByZero::Reset() {
  224. x_ = nullptr;
  225. is_zero_ = false;
  226. }
  227. // {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1}
  228. AnfNodePtr TensorMultiplyByOne::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  229. Reset();
  230. AnfVisitor::Match(prim::kPrimMul)(node);
  231. if (is_one_) {
  232. return NewTensorFilledWithData(node, x_);
  233. }
  234. return nullptr;
  235. }
  236. void TensorMultiplyByOne::Visit(const AnfNodePtr &node) {
  237. if (is_one_) {
  238. x_ = node;
  239. return;
  240. }
  241. if (IsParam(node) || IsCNode(node)) {
  242. x_ = node;
  243. return;
  244. }
  245. auto value = node->cast<ValueNodePtr>()->value();
  246. if (CheckTensorConstant(1).IsTensorConstant(value)) {
  247. is_one_ = true;
  248. return;
  249. }
  250. x_ = node;
  251. }
  252. void TensorMultiplyByOne::Visit(const ValueNodePtr &vnode) {
  253. auto value = vnode->value();
  254. if (CheckTensorConstant(1).IsTensorConstant(value)) {
  255. is_one_ = true;
  256. return;
  257. }
  258. x_ = vnode;
  259. }
  260. void TensorMultiplyByOne::Reset() {
  261. x_ = nullptr;
  262. is_one_ = false;
  263. }
  264. // {prim::kPrimScalarAdd, X, 0}
  265. // {prim::kPrimScalarAdd, 0, X}
  266. AnfNodePtr AddByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  267. Reset();
  268. AnfVisitor::Match(prim::kPrimScalarAdd)(node);
  269. if (is_zero_) {
  270. return x_;
  271. }
  272. return nullptr;
  273. }
void AddByZero::Visit(const AnfNodePtr &node) {
  // An operand is the additive identity when it is a constant equal to the
  // scalar zero, or a rank-0 single-element tensor whose value is zero.
  if (node->isa<ValueNode>() &&
      ((*GetValueNode(node) == *zero_) || CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node)))) {
    is_zero_ = true;
    return;
  }
  // Any other operand is the surviving summand.
  x_ = node;
}
  282. void AddByZero::Reset() {
  283. x_ = nullptr;
  284. is_zero_ = false;
  285. }
  286. // {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X},
  287. // {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}}
  288. AnfNodePtr TensorAddByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  289. Reset();
  290. AnfVisitor::Match(prim::kPrimTensorAdd)(node);
  291. if (is_zero_) {
  292. return x_;
  293. }
  294. return nullptr;
  295. }
  296. void TensorAddByZero::Visit(const AnfNodePtr &node) {
  297. if (node->isa<ValueNode>() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) {
  298. is_zero_ = true;
  299. return;
  300. }
  301. x_ = node;
  302. }
  303. void TensorAddByZero::Visit(const ValueNodePtr &vnode) {
  304. auto value = vnode->value();
  305. if (CheckTensorConstant(0).IsTensorConstant(value)) {
  306. is_zero_ = true;
  307. return;
  308. }
  309. }
  310. void TensorAddByZero::Reset() {
  311. x_ = nullptr;
  312. is_zero_ = false;
  313. }
  314. // {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y}
AnfNodePtr OptUpdateZeroTensor::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  // Rewrites {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} into
  // {prim::kPrimMakeTuple, Z, Y} — when the first Momentum input is a
  // zeros-like tensor, the update is replaced by a tuple of the remaining
  // two inputs in swapped order.
  if (!IsPrimitiveCNode(node, prim::kPrimMomentum) || node->func_graph() == nullptr) {
    return nullptr;
  }
  // {PrimMomentum, {...}, Y, Z, Xs}
  auto &inputs = node->cast<CNodePtr>()->inputs();
  if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLike)) {
    return nullptr;
  }
  auto y = inputs[2];
  auto z = inputs[3];
  // {kPrimZerosLike, X} — must have exactly one operand (size 2 incl. prim).
  if (inputs[1]->cast<CNodePtr>()->size() != 2) {
    return nullptr;
  }
  // {prim::kPrimMakeTuple, Z, Y}
  return node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), z, y});
}
  333. // {prim::kPrimMul, Tensor1, {prim::kPrimMul, Tensor2, {...}}} ->
  334. // {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}}
  335. // Support function to multiply two constant tensors: partially support broadcasting shapes
  336. template <typename T>
  337. void ConstantDuplicateMul::Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size,
  338. void **out_data, int out_data_size) {
  339. T *data_1 = reinterpret_cast<T *>(in_data_1);
  340. T *data_2 = reinterpret_cast<T *>(in_data_2);
  341. T *data_out = new T[out_data_size];
  342. if (in_data_1_size == 1) {
  343. for (int i = 0; i < out_data_size; i++) {
  344. data_out[i] = data_1[0];
  345. }
  346. } else {
  347. for (int i = 0; i < out_data_size; i++) {
  348. data_out[i] = data_1[i];
  349. }
  350. }
  351. if (in_data_2_size == 1) {
  352. for (int i = 0; i < out_data_size; i++) {
  353. data_out[i] *= data_2[0];
  354. }
  355. } else {
  356. for (int i = 0; i < out_data_size; i++) {
  357. data_out[i] *= data_2[i];
  358. }
  359. }
  360. *out_data = reinterpret_cast<void *>(data_out);
  361. return;
  362. }
  363. AnfNodePtr ConstantDuplicateMul::MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2,
  364. const AnfNodePtr &node_3) {
  365. if (!vnode_1->isa<ValueNode>() || !vnode_2->isa<ValueNode>() || (vnode_1->abstract() == nullptr) ||
  366. (vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) {
  367. return nullptr;
  368. }
  369. auto value_1 = GetValueNode(vnode_1);
  370. auto value_2 = GetValueNode(vnode_2);
  371. if (!value_1->isa<tensor::Tensor>() || !value_2->isa<tensor::Tensor>()) {
  372. return nullptr;
  373. }
  374. auto tensor_ptr_1 = dyn_cast<tensor::Tensor>(value_1);
  375. auto tensor_ptr_2 = dyn_cast<tensor::Tensor>(value_2);
  376. auto tensor_1_abstract = vnode_1->abstract()->cast<abstract::AbstractTensorPtr>();
  377. auto tensor_2_abstract = vnode_1->abstract()->cast<abstract::AbstractTensorPtr>();
  378. auto tensor_3_abstract = node_3->abstract()->cast<abstract::AbstractTensorPtr>();
  379. TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType();
  380. TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType();
  381. TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType();
  382. if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) ||
  383. (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) {
  384. return nullptr;
  385. }
  386. std::vector<int> tensor_out_shape = tensor_3_abstract->shape()->shape();
  387. int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies<int>());
  388. if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) {
  389. return nullptr;
  390. }
  391. if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) {
  392. return nullptr;
  393. }
  394. void *data_out;
  395. if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat32) ||
  396. (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat)) {
  397. Multiply<float>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), tensor_ptr_2->DataSize(),
  398. &data_out, data_out_size);
  399. } else {
  400. if (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat64) {
  401. Multiply<double>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
  402. tensor_ptr_2->DataSize(), &data_out, data_out_size);
  403. } else {
  404. if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt32) ||
  405. (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt)) {
  406. Multiply<int>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
  407. tensor_ptr_2->DataSize(), &data_out, data_out_size);
  408. } else {
  409. // Un-support data types
  410. return nullptr;
  411. }
  412. }
  413. }
  414. auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_3_type_ptr->type_id(), tensor_out_shape);
  415. size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
  416. char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c());
  417. memcpy(data, data_out, mem_size);
  418. auto new_vnode = NewValueNode(new_tensor_ptr);
  419. new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
  420. return new_vnode;
  421. }
AnfNodePtr ConstantDuplicateMul::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  // {prim::kPrimMul, Tensor1, {prim::kPrimMul, Tensor2, {...}}} ->
  // {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}}
  // Re-associates nested multiplications so the two constant tensors sit in
  // their own Mul, which is constant-folded when possible.
  Reset();
  // {prim::kPrimMul, Tensor1, {...}}
  AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node);
  if (vnode_ == nullptr || c_p_node_ == nullptr) {
    return nullptr;
  }
  // The non-constant operand must itself be a CNode (the inner Mul candidate).
  if (!IsCNode(c_p_node_)) {
    return nullptr;
  }
  auto tensor1 = vnode_;
  auto mul = c_p_node_->cast<CNodePtr>();
  Reset();
  // {prim::kPrimMul, Tensor2, {...}}
  AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul);
  if (vnode_ == nullptr || c_p_node_ == nullptr) {
    return nullptr;
  }
  auto tensor2 = vnode_;
  auto c_p_node = c_p_node_;
  auto PrimMul = GetValueNode<PrimitivePtr>(mul->input(0));
  auto fg = node->func_graph();
  // Try to fold the two constant tensors into one right away.
  auto new_mul_tensor = MulConstantTensors(tensor1, tensor2, c_p_node);
  if (new_mul_tensor == nullptr) {
    // Folding failed: still hoist the constants into their own Mul node.
    auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg);
    return NewCNode({NewValueNode(PrimMul), c_p_node, ttmul}, fg);
  }
  return NewCNode({NewValueNode(PrimMul), c_p_node, new_mul_tensor}, fg);
}
  451. void ConstantDuplicateMul::Visit(const AnfNodePtr &node) {
  452. if (IsValueNode<tensor::Tensor>(node)) {
  453. vnode_ = node;
  454. }
  455. if (IsCNode(node) || IsParam(node)) {
  456. c_p_node_ = node;
  457. }
  458. }
  459. void ConstantDuplicateMul::Reset() {
  460. vnode_ = nullptr;
  461. c_p_node_ = nullptr;
  462. }
AnfNodePtr PowerOneEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  // pow(X, 1) -> X when the exponent is a constant scalar (float or integer).
  if (!IsPrimitiveCNode(node, prim::kPrimPow) || node->func_graph() == nullptr) {
    return nullptr;
  }
  // inputs: [prim, base, exponent]
  auto &inputs = node->cast<CNodePtr>()->inputs();
  if (!IsValueNode<Scalar>(inputs[2])) {
    return nullptr;
  }
  auto scalar = GetValueNode<ScalarPtr>(inputs[2]);
  // Exact float comparison with 1.0 is intentional: only a literal exponent
  // of one is eliminated. ("IntergerImm" is the project's own type spelling.)
  if (scalar->isa<FloatImm>() && GetValue<float>(scalar) == 1.0) {
    return inputs[1];
  } else if (scalar->isa<IntergerImm>() && GetValue<int>(scalar) == 1) {
    return inputs[1];
  }
  return nullptr;
}
  479. // grad = AllReduce(grad) / worker_number
  480. // grad = grad + weight * decy
  481. // ->
  482. // grad = grad + weight * decy
  483. // grad = AllReduce(grad) / worker_number
  484. // {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
  485. // {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y}
AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  // Reorders the weight-decay addition before the AllReduce:
  // {AddN, {MakeTuple, {Mul, {AllReduce, X}, Y}, Z}} ->
  // {Mul, {AllReduce, {AddN, {MakeTuple, Z, X}}}, Y}
  Reset();
  // {prim::kPrimAddN, Zs}
  if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
    return nullptr;
  }
  auto addn = node->cast<CNodePtr>();
  if (addn->size() != 2) {
    return nullptr;
  }
  // Visit the MakeTuple's two operands; on success Visit() fills
  // x_, y_, z_, mul_, all_reduce_ and all_reduce_fg_.
  AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
  if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) {
    return nullptr;
  }
  auto addn_maketuple = addn->input(1);
  auto fg = all_reduce_fg_;
  // addn inputs cross the graph, make the inputs same as allreduce node.
  if (z_->isa<CNode>() && fg != z_->func_graph()) {
    auto cnode_z = z_->cast<CNodePtr>();
    z_ = NewCNode(cnode_z->inputs(), fg);
  }
  // Rebuild the expression bottom-up inside the AllReduce's graph.
  auto addn_op_node = addn->input(0);
  auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0);
  AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
  AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
  AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg);
  AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg);
  ProcessDependEdge(fg, addn_maketuple, all_reduce);
  return mul;
}
void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple,
                                              const AnfNodePtr &new_node) {
  // If has dynamic loss scale.
  // Re-point every MakeTuple user of the old Mul (other than the AddN's own
  // tuple) at the freshly created node.
  auto &users_map = fg->manager()->node_users();
  auto it = users_map.find(mul_cnode_);
  if (it != users_map.end()) {
    auto users = it->second;
    for (auto &user_pair : users) {
      auto node = user_pair.first;
      if (node != addn_maketuple) {
        if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
          // user_pair.second is the input index of mul_cnode_ within `node`.
          fg->manager()->SetEdge(node, user_pair.second, new_node);
        }
      }
    }
  }
}
void AdjustAllReduceMulAdd::Visit(const AnfNodePtr &node) {
  // Two-level visitor driven by `level_`:
  //   level 0: operands of the MakeTuple — recursively matches a
  //            {Mul, {AllReduce, X}, Y} operand; the non-matching operand
  //            becomes z_.
  //   level 1: operands of that Mul — an {AllReduce, X} fills x_/all_reduce_;
  //            the other operand is held in tmp_ (promoted to y_ on a match).
  if (level_ == 0) {
    level_ = 1;
    is_reduce_match_ = false;
    // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}
    AnfVisitor::Match(prim::kPrimMul)(node);
    level_ = 0;
    if (is_reduce_match_) {
      mul_ = node->cast<CNodePtr>()->input(0);
      mul_cnode_ = node->cast<CNodePtr>();
      y_ = tmp_;
    } else {
      z_ = node;
    }
  }
  if (level_ == 1) {
    // {prim::kPrimAllReduce, X}
    if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
      auto cnode = node->cast<CNodePtr>();
      if (cnode->size() > 1) {
        all_reduce_ = cnode->input(0);
        x_ = cnode->input(1);
        is_reduce_match_ = true;
        all_reduce_fg_ = cnode->func_graph();
      }
    } else {
      tmp_ = node;
    }
  }
}
  563. void AdjustAllReduceMulAdd::Reset() {
  564. level_ = 0;
  565. is_reduce_match_ = false;
  566. x_ = nullptr;
  567. y_ = nullptr;
  568. z_ = nullptr;
  569. tmp_ = nullptr;
  570. all_reduce_fg_ = nullptr;
  571. }
  572. AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
  573. AnfNodePtr new_node;
  574. for (auto &eliminater : eliminaters_) {
  575. new_node = (*eliminater)(optimizer, node);
  576. if (new_node != nullptr) {
  577. return new_node;
  578. }
  579. }
  580. return nullptr;
  581. }
  582. AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
  583. AnfNodePtr new_node;
  584. for (auto &eliminater : eliminaters_) {
  585. new_node = (*eliminater)(optimizer, node);
  586. if (new_node != nullptr) {
  587. return new_node;
  588. }
  589. }
  590. return nullptr;
  591. }
  592. } // namespace irpass
  593. } // namespace opt
  594. } // namespace mindspore