
parameter_manager.cc

/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/parameter_manager.h"

#include <inttypes.h>
#include <sys/time.h>
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <cmath>

#include "utils/hash_map.h"
#include "base/core_ops.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/optimizer.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/node_check.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#include "utils/trace_base.h"
#include "include/common/utils/comm_manager.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "mindspore/core/utils/parallel_node_check.h"
#include "frontend/parallel/step_parallel_utils.h"

namespace mindspore {
namespace parallel {
static ParameterUsersInfo FindRefKeyNodeUsers(const RefKeyPair &ref_key_pair, bool (*IsCareNode)(const CNodePtr &)) {
  // Dealing with the RefKey case
  ParameterUsersInfo parameter_user_info;
  auto refkeys = ref_key_pair.second;
  auto cnode = ref_key_pair.first;

  auto cnode_ptr = cnode->cast<CNodePtr>();
  if ((cnode_ptr == nullptr) || !IsValueNode<Primitive>(cnode_ptr->input(0)) || !IsCareNode(cnode_ptr)) {
    return parameter_user_info;
  }

  if (refkeys.size() > 1) {
    MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << "'s inputs have more than one RefKey";
  }
  MS_EXCEPTION_IF_NULL(cnode->func_graph());
  auto cnode_func_graph = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager());

  // Find the RefKey being used
  auto candidate_set_by_refkey = cnode_func_graph->manager()->node_users()[refkeys[0]];
  for (auto &candidate : candidate_set_by_refkey) {
    auto candidate_node = candidate.first;
    auto c = candidate_node->cast<CNodePtr>();
    if ((c == nullptr) || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) {
      continue;
    }
    parameter_user_info.second.second.insert(candidate);
  }

  // Find the corresponding Parameter being used
  std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(refkeys[0], cnode_func_graph);
  if (parameters.size() != 1) {
    MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
  }
  parameter_user_info.first = parameters[0]->cast<ParameterPtr>()->name();
  parameter_user_info.second.first = parameters[0];
  auto candidate_set_by_para = cnode_func_graph->manager()->node_users()[parameters[0]];
  for (auto &candidate : candidate_set_by_para) {
    auto candidate_node = candidate.first;
    auto c = candidate_node->cast<CNodePtr>();
    if ((c == nullptr) || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) {
      continue;
    }
    parameter_user_info.second.second.insert(candidate);
  }
  return parameter_user_info;
}
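
// Collects all parallel-care users of a Parameter node. Users reached through
// Load (including Load -> Depend -> Load chains) are traversed transparently;
// CNodes without OperatorInfo user data and Receive nodes are skipped.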
static ParameterUsersInfo FindParameterNodeUsers(const AnfNodePtr &node) {
  // In this case, node is a Parameter
  ParameterUsersInfo parameter_user_info;
  MS_EXCEPTION_IF_NULL(node->func_graph());
  MS_EXCEPTION_IF_NULL(node->func_graph()->manager());
  auto candidate_set = node->func_graph()->manager()->node_users()[node];
  for (auto &candidate : candidate_set) {
    auto candidate_node = candidate.first;
    if (IsPrimitiveCNode(candidate_node, prim::kPrimLoad)) {
      if (candidate.second != 1) {
        continue;
      }
      auto load_node_users = node->func_graph()->manager()->node_users()[candidate_node];
      for (auto &node_user : load_node_users) {
        auto cnode = node_user.first->cast<CNodePtr>();
        if (IsSomePrimitive(cnode, DEPEND)) {
          auto depend_node_users = node->func_graph()->manager()->node_users()[node_user.first];
          for (auto depend_user : depend_node_users) {
            if (IsPrimitiveCNode(depend_user.first, prim::kPrimLoad)) {
              auto local_load_node_users = node->func_graph()->manager()->node_users()[depend_user.first];
              for (auto local_load_user : local_load_node_users) {
                auto local_cnode = local_load_user.first->cast<CNodePtr>();
                if (local_cnode == nullptr || !local_cnode->has_user_data<OperatorInfo>() ||
                    IsSomePrimitive(local_cnode, RECEIVE)) {
                  continue;
                }
                parameter_user_info.second.second.insert(local_load_user);
              }
            }
          }
        }
        if (cnode == nullptr || !cnode->has_user_data<OperatorInfo>() || IsSomePrimitive(cnode, RECEIVE)) {
          continue;
        }
        parameter_user_info.second.second.insert(node_user);
      }
    } else {
      auto c = candidate_node->cast<CNodePtr>();
      if (c == nullptr || !c->has_user_data<OperatorInfo>() || IsSomePrimitive(c, RECEIVE)) {
        continue;
      }
      parameter_user_info.second.second.insert(candidate);
    }
  }
  parameter_user_info.first = node->cast<ParameterPtr>()->name();
  parameter_user_info.second.first = node;
  return parameter_user_info;
}
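
// Returns the CNode paired with any RefKey inputs it carries; the first
// element is nullptr when the node is not a CNode or has no RefKey input.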
static RefKeyPair CNodeWithRefKeys(const AnfNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> refkeys;
  if (cnode->isa<CNode>()) {
    auto cnode_ptr = cnode->cast<CNodePtr>();
    auto inputs = cnode_ptr->inputs();
    for (auto &one_input : inputs) {
      if (IsValueNode<RefKey>(one_input)) {
        refkeys.push_back(one_input);
      }
    }
    if (refkeys.size() >= 1) {
      return std::make_pair(cnode, refkeys);
    }
  }
  return {nullptr, refkeys};
}
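
// Entry point for user lookup: dispatches to the RefKey path for CNodes that
// carry RefKey inputs, and to the direct path for Parameter nodes.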
ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &)) {
  ParameterUsersInfo parameter_users_info;
  auto cnode_with_refkeys = CNodeWithRefKeys(node);
  if (cnode_with_refkeys.first != nullptr) {
    // the node is a ref key node
    return FindRefKeyNodeUsers(cnode_with_refkeys, IsCareNode);
  } else if (node->isa<Parameter>()) {
    // the node is a parameter node
    return FindParameterNodeUsers(node);
  }
  return parameter_users_info;
}
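
// Recursively checks whether a parameter is actually consumed, following the
// parameter into called sub-graphs (including graphs wrapped by the J
// primitive) up to MAX_RECURSIVE_DEPTH levels.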
static bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter, size_t max_depth) {
  if (max_depth > MAX_RECURSIVE_DEPTH) {
    MS_LOG(EXCEPTION) << "Recursive call is larger than 100000.";
  }
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(parameter);
  auto manager = graph->manager();
  auto node_users = manager->node_users()[parameter];
  if (node_users.empty()) {
    return false;
  }
  for (auto node_user : node_users) {
    auto use_node = node_user.first->cast<CNodePtr>();
    if (IsValueNode<FuncGraph>(use_node->input(0))) {
      auto graph_sub = GetValueNode<FuncGraphPtr>(use_node->input(0));
      auto parameters = graph_sub->parameters();
      auto parameter_sub = parameters[IntToSize(node_user.second - 1)];
      return IsUsedParameter(graph_sub, parameter_sub, max_depth + 1);
    }
    if (use_node->input(0)->isa<CNode>()) {
      auto cnode = use_node->input(0)->cast<CNodePtr>();
      if (!IsSomePrimitive(cnode, J) || !IsValueNode<FuncGraph>(cnode->input(1))) {
        return true;
      }
      auto graph_sub = GetValueNode<FuncGraphPtr>(cnode->input(1));
      auto parameters = graph_sub->parameters();
      auto parameter_sub = parameters[IntToSize(node_user.second - 1)];
      return IsUsedParameter(graph_sub, parameter_sub, max_depth + 1);
    }
    return true;
  }
  return true;
}
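
// Checks that a parameter shared by multiple parallel-care users is sliced
// consistently: every user must derive the same TensorInfo for it, otherwise
// an exception is raised.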
void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
  for (auto &node : all_nodes) {
    ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsParallelCareNode);
    auto &users_set = parameter_users_info.second.second;
    if (users_set.size() <= 1) {
      continue;
    }

    auto parameter_name = parameter_users_info.first;
    MS_LOG(INFO) << "The parameter: " << parameter_name << " has " << users_set.size() << " users";
    auto &first_user = users_set.front();
    auto parameter_tensor_info = GetInputsTensorInfo(first_user);

    for (auto iter = users_set.begin() + 1; iter != users_set.end(); ++iter) {
      auto &user = *iter;
      auto user_tensor_info = GetInputsTensorInfo(user);
      if (parameter_tensor_info == user_tensor_info) {
        continue;
      } else {
        MS_LOG(EXCEPTION) << "The parameter: " << parameter_name
                          << " has multiple users, but their TensorInfos are different";
      }
    }
  }
}
namespace {
void RevertSymbolicKeyInstance(const FuncGraphPtr &root, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(root);
  MS_EXCEPTION_IF_NULL(node);
  auto symbolic_key = GetValueNode<SymbolicKeyInstancePtr>(node);
  MS_EXCEPTION_IF_NULL(symbolic_key);
  auto all_upstream_node = root->manager()->node_users()[node];
  for (auto &upstream_node : all_upstream_node) {
    FuncGraphPtr fg = upstream_node.first->func_graph();
    if (symbolic_key->node()->isa<Parameter>()) {
      for (auto &param : root->parameters()) {
        if (*param == *symbolic_key->node()) {
          AnfNodePtr reverted_node = root->NewCNode({NewValueNode(prim::kPrimEmbed), param});
          MS_EXCEPTION_IF_NULL(reverted_node);
          MS_LOG(DEBUG) << "before replace " << node->ToString() << " to node " << reverted_node->DebugString();
          (void)fg->manager()->Replace(node, reverted_node);
          MS_LOG(DEBUG) << "revert node " << node->ToString() << " to node " << reverted_node->DebugString();
        }
      }
    }
  }
}
}  // namespace

void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  MS_EXCEPTION_IF_NULL(root);
  for (auto &node : all_nodes) {
    // revert SymbolicKeyInstance back to the embed() primitive
    if (IsValueNode<SymbolicKeyInstance>(node)) {
      RevertSymbolicKeyInstance(root, node);
      continue;
    }
  }
}
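
// A parameter counts as cloned when its ParamInfo carries the cloned flag,
// which is presumably the case for optimizer states and accumulators created
// from an original weight.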
bool ParameterIsCloned(const AnfNodePtr &parameter_node) {
  MS_EXCEPTION_IF_NULL(parameter_node);
  auto cloned_parameter = parameter_node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(cloned_parameter);

  // find the clone parameter
  if (!cloned_parameter->has_default()) {
    return false;
  }
  auto param_value = cloned_parameter->param_info();
  if (param_value == nullptr) {
    return false;
  }
  bool cloned = param_value->cloned();
  if (!cloned) {
    return false;
  }

  MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned";
  return true;
}
void HandleNoUsedParameter(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  bool full_batch = ParallelContext::GetInstance()->full_batch();
  if (full_batch) {
    return;
  }

  // In grad accumulation mode with a dynamic lr, the optimizer holds some parameters that are unused by the first
  // graph but used by the second graph (such as global_step), so their shapes can not be changed.
  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
  if (grad_accumulation_step > 1) {
    MS_LOG(INFO) << "In grad accumulation mode, do not handle no used parameters";
    return;
  }

  auto dev_num = g_device_manager->stage_device_num();
  auto parameters = root->parameters();
  for (auto &parameter : parameters) {
    if (IsUsedParameter(root, parameter, 0)) {
      continue;
    }

    auto parameter_shape = GetNodeShape(parameter);
    if (parameter_shape.empty()) {
      continue;
    }
    Shape slice_shape = parameter_shape[0];
    if (slice_shape.empty() || slice_shape[0] < dev_num) {
      continue;
    }

    slice_shape[0] = slice_shape[0] / dev_num;
    auto slice_shape_ptr = std::make_shared<abstract::Shape>(slice_shape);
    auto abstract = parameter->abstract();
    MS_EXCEPTION_IF_NULL(abstract);
    auto abstract_cloned = abstract->Clone();
    MS_EXCEPTION_IF_NULL(abstract_cloned);
    abstract_cloned->set_shape(slice_shape_ptr);
    parameter->set_abstract(abstract_cloned);
  }
}
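
// A parameter is fully split when the number of devices holding an identical
// slice of it (derived from its tensor map and the device matrix) is at most
// allow_repeat_num, i.e. there is no replication beyond the allowed amount.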
bool IsFullySplitParameter(const ParameterPtr &param_ptr, size_t allow_repeat_num) {
  auto tensor_layout = param_ptr->user_data<parallel::TensorLayout>();
  if (tensor_layout == nullptr) {
    return false;
  }

  auto dev_mat_shape = tensor_layout->device_arrangement().array();
  auto tensor_map = tensor_layout->tensor_map().array();
  int64_t rank = g_device_manager->global_rank();
  RankList rank_list = g_device_manager->GetDeviceListInThisStage();
  DeviceMatrix dev_matrix(rank, rank_list, dev_mat_shape);
  RankList group_devices;
  if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
    MS_LOG(WARNING) << "Get devices by tensor map failed, invalid tensor layout";
    return false;
  }

  if (group_devices.size() <= allow_repeat_num) {
    MS_LOG(INFO) << "The parameter: " << param_ptr->name() << " is fully split";
    return true;
  }
  return false;
}

static void InsertFullySplitParamGradAccu(const std::pair<AnfNodePtr, int> &node_user,
                                          const FuncGraphManagerPtr &manager, const AnfNodePtr &accu_parameter) {
  auto cnode = node_user.first->cast<CNodePtr>();
  auto prim = GetCNodePrimitive(cnode);
  if (prim == nullptr) {
    MS_LOG(WARNING) << cnode->DebugString() << " can not insert fully split param grad accumulation node";
    return;
  }
  OperatorAttrs attrs;
  auto py_instance = CreateOpInstance(attrs, "_VirtualAdd", "grad_accu");
  auto value_node = NewValueNode(py_instance);
  std::vector<AnfNodePtr> virtual_node_input = {value_node, cnode->input(IntToSize(node_user.second)), accu_parameter};
  auto graph = cnode->func_graph();
  auto virtual_node = graph->NewCNode(virtual_node_input);
  manager->SetEdge(cnode, node_user.second, virtual_node);
}
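
// In grad accumulation mode, each fully split parameter gets a _VirtualAdd
// inserted at its first forward user, wiring the matching accu_grads
// parameter into the gradient accumulation.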
void HandleFullySplitParameters(const FuncGraphPtr &root) {
  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
  if ((grad_accumulation_step <= 1) || root->has_flag(kAccumulation)) {
    return;
  }

  auto parameters = root->parameters();
  auto node_users_map = root->manager()->node_users();
  for (auto &parameter : parameters) {
    auto param_ptr = parameter->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);

    if (!IsFullySplitParameter(param_ptr)) {
      continue;
    }

    auto accu_parameter = FindGradAccuParameter(parameters, param_ptr->name());
    if (!accu_parameter) {
      continue;  // some parameters need no handling, such as itself or lr
    }

    auto node_users = node_users_map[parameter];
    for (auto &user : node_users) {
      auto node = user.first;
      auto cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      if (!cnode->in_forward_flag()) {
        continue;
      }
      InsertFullySplitParamGradAccu(user, root->manager(), accu_parameter);
      MS_LOG(INFO) << "Insert full split assign add node for " << param_ptr->name();
      break;  // only need to insert once, if the parameter has many users
    }
  }
}
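
// Propagates the sliced shape and tensor layout of each original parameter to
// the parameters cloned from it (such as optimizer states and accu_grads), so
// that they are allocated with the parallel slice rather than the full shape.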
void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  auto grad_accumulation_shard = ParallelContext::GetInstance()->grad_accumulation_shard();
  for (auto &cloned_parameter_node : root->parameters()) {
    MS_EXCEPTION_IF_NULL(cloned_parameter_node);
    auto cloned_parameter = cloned_parameter_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(cloned_parameter);

    if (!ParameterIsCloned(cloned_parameter_node)) {
      continue;
    }
    auto param_value = cloned_parameter->param_info();
    if (param_value == nullptr) {
      continue;
    }
    // get the cloned index
    int64_t cloned_index = param_value->cloned_index();

    // find the parameter it was cloned from
    bool found_be_cloned_parameter = false;
    ParameterPtr cloned_from_parameter = nullptr;
    AnfNodePtr cloned_from_node = nullptr;
    for (auto &be_cloned_parameter_node : root->parameters()) {
      MS_EXCEPTION_IF_NULL(be_cloned_parameter_node);
      auto be_cloned_parameter = be_cloned_parameter_node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(be_cloned_parameter);
      if (!be_cloned_parameter->has_default()) {
        continue;
      }

      auto param_value_in = be_cloned_parameter->param_info();
      if (param_value_in == nullptr) {
        continue;
      }
      if (!param_value_in->be_cloned()) {
        continue;
      }

      // get the be_cloned index list
      auto &be_cloned_index = param_value_in->be_cloned_index();
      if (std::find(be_cloned_index.begin(), be_cloned_index.end(), cloned_index) != be_cloned_index.end()) {
        found_be_cloned_parameter = true;
        cloned_from_parameter = be_cloned_parameter;
        cloned_from_node = be_cloned_parameter_node;
      }
    }

    if (found_be_cloned_parameter) {
      // set the shape and tensor layout for cloned parameter
      std::string param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
      if (cloned_from_parameter->user_data<TensorLayout>() == nullptr) {
        MS_LOG(WARNING) << "The parameter " << param_name << " has no tensor layout, skip it";
        continue;
      }
      auto tensor_layout = cloned_from_parameter->user_data<TensorLayout>();
      MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
      MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
      auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
      MS_EXCEPTION_IF_NULL(cloned_abstract);
      // from pipeline or grad accumulation
      if (param_name.find(ACCU_GRADS) != std::string::npos) {
        auto slice_shape = cloned_from_parameter->user_data<TensorLayout>()->slice_shape().array();
        auto opt_shard_group = tensor_layout->opt_shard_group();
        auto opt_shard_shape = cloned_from_parameter->user_data<TensorLayout>()->opt_shard_slice_shape();
        std::shared_ptr<abstract::BaseShape> parallel_shape = nullptr;
        // set opt shard shape if the pipeline sharding is set
        if (grad_accumulation_shard && !opt_shard_group.empty()) {
          parallel_shape = std::make_shared<abstract::Shape>(opt_shard_shape);
        } else {
          parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
        }
        MS_EXCEPTION_IF_NULL(parallel_shape);
        cloned_abstract->set_shape(parallel_shape);
        // in opt shard, accu_grad's shape is different from the original param's shape
        // if grad_accumulation_shard is enabled, the accu_grads will have an opt-sharded shape
        if (!grad_accumulation_shard && ParallelContext::GetInstance()->enable_parallel_optimizer()) {
          TensorLayout new_layout = *tensor_layout;
          new_layout.set_opt_shard_group("");
          tensor_layout = std::make_shared<TensorLayout>(new_layout);
        }
      } else {
        cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack());
      }
      cloned_parameter->set_user_data<TensorLayout>(tensor_layout);
      cloned_parameter_node->set_abstract(cloned_abstract);
      // copy the fusion tag
      auto cloned_param_info = cloned_parameter->param_info();
      MS_EXCEPTION_IF_NULL(cloned_param_info);
      auto cloned_from_param_info = cloned_from_parameter->param_info();
      MS_EXCEPTION_IF_NULL(cloned_from_param_info);
      cloned_param_info->set_comm_fusion(cloned_from_param_info->comm_fusion());

      MS_LOG(INFO) << "The parameter: " << cloned_parameter->name()
                   << " is cloned, the source parameter is: " << cloned_from_parameter->name()
                   << ", clone index is: " << cloned_index;
    } else {
      MS_LOG(EXCEPTION) << "The parameter: " << cloned_parameter->name() << " is cloned, cloned index is "
                        << cloned_index << ", but the source parameter was not found";
    }
  }
}
// For the adafactor optimizer, the relationship between a parameter's shape and its states' shapes is as follows:
// 1) parameter: [A, B, C, D] (shape_size > 2), exp_avg_sq_row: [A, B, C], exp_avg_sq_col: [A, B, D], exp_avg_sq: [1]
//    If the parameter is opt shard, exp_avg_sq_row and exp_avg_sq_col need to be sharded accordingly.
// 2) parameter: [A, B] (shape_size = 2), exp_avg_sq_row: [A], exp_avg_sq_col: [B], exp_avg_sq: [1]
//    If the parameter is opt shard, exp_avg_sq_row needs to be sharded accordingly.
// 3) parameter: [A] (shape_size = 1), exp_avg_sq_row: [1], exp_avg_sq_col: [1], exp_avg_sq: [A]
//    If the parameter is opt shard, exp_avg_sq needs to be sharded accordingly.
static bool AdafactorStateIsOptShard(const std::string &opt_shard_group, size_t shape_size,
                                     const std::string &param_name, const std::string &state_name) {
  if (opt_shard_group.empty()) {
    return false;
  }

  std::string exp_row_name = EXP_AVG_SQ_ROW + param_name;
  std::string exp_col_name = EXP_AVG_SQ_COL + param_name;
  std::string exp_avg_name = EXP_AVG_SQ + param_name;

  if (shape_size > 2 && state_name == exp_avg_name) {
    return false;
  }

  if (shape_size == 2 && (state_name == exp_col_name || state_name == exp_avg_name)) {
    return false;
  }

  if (shape_size == 1 && (state_name == exp_row_name || state_name == exp_col_name)) {
    return false;
  }

  MS_LOG(INFO) << "The parameter " << param_name << " is opt shard";
  return true;
}

static bool IsOriginWeight(const ParameterPtr &param) {
  std::string param_name = param->name();
  if (param_name.find(EXP_AVG) != std::string::npos) {
    return false;
  }

  auto tensor_layout = param->user_data<TensorLayout>();
  if (tensor_layout == nullptr) {
    return false;
  }

  return true;
}
static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph,
                                                            const std::string &name = ALL_REDUCE) {
  if (IsValueNode<RefKey>(node)) {
    std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
    if (param_v.size() != 1) {
      MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, return vector size must be 1, real is "
                        << param_v.size();
    }
    auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
    if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty() &&
        name == ALL_REDUCE) {
      return std::make_pair(nullptr, true);
    }
    return std::make_pair(node, true);
  }
  return std::make_pair(nullptr, false);
}

static std::pair<AnfNodePtr, bool> FindParameterByParameter(const AnfNodePtr &node,
                                                            const std::string &name = ALL_REDUCE) {
  auto param_ptr = node->user_data<parallel::TensorLayout>();
  if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty() &&
      name == ALL_REDUCE) {
    return std::make_pair(nullptr, false);
  }
  return std::make_pair(node, false);
}

// Only used for InsertMirrorOps
std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
  if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
    return std::make_pair(nullptr, false);
  }

  if (node->isa<Parameter>()) {
    return FindParameterByParameter(node);
  }

  if (node->isa<ValueNode>()) {
    return FindParameterByValueNode(node, func_graph);
  }

  CNodePtr cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    for (size_t index = 0; index < cnode->inputs().size(); ++index) {
      auto res = FindParameter(cnode->input(index), func_graph);
      if (!res.first) {
        continue;
      }
      return res;
    }
  }

  // When opt shard is not fully used, both allgather and mirror are inserted.
  // Skip allgather here and find the parameter recursively.
  if (IsParallelCareNode(cnode) && !IsInAllGatherNodeList(cnode)) {
    return std::make_pair(nullptr, false);
  }

  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(prim_anf_node);
  for (size_t index = 0; index < cnode->inputs().size(); ++index) {
    PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(prim);
    if ((prim->name() == DEPEND || prim->name() == LOAD || IsInAllGatherNodeList(cnode)) && index != 1) {
      continue;
    }
    auto res = FindParameter(cnode->input(index), func_graph);
    if (!res.first) {
      continue;
    }
    return res;
  }
  return std::make_pair(nullptr, false);
}
// Used for allgather and reducescatter
std::pair<AnfNodePtr, bool> FindParameterWithAllgather(const AnfNodePtr &node, const FuncGraphPtr &func_graph,
                                                       const std::string &name) {
  if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
    return std::make_pair(nullptr, false);
  }

  if (node->isa<Parameter>()) {
    return FindParameterByParameter(node, name);
  }

  if (node->isa<ValueNode>()) {
    return FindParameterByValueNode(node, func_graph, name);
  }

  CNodePtr cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  for (size_t index = 0; index < cnode->inputs().size(); ++index) {
    if (index != 1) {
      continue;
    }
    auto res = FindParameterWithAllgather(cnode->input(index), func_graph, name);
    if (!res.first) {
      continue;
    }
    return res;
  }
  return std::make_pair(nullptr, false);
}
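
// Builds a map from "adasum_delta_weight."-prefixed parameter names to the
// tensor layouts of the corresponding original (non-cloned) parameters.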
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> AdaSumParamTensorLayout(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  std::unordered_map<std::string, std::shared_ptr<TensorLayout>> adasum_param_map;
  for (auto &parameter_node : root->parameters()) {
    MS_EXCEPTION_IF_NULL(parameter_node);
    auto cloned_parameter = parameter_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(cloned_parameter);
    if (!ParameterIsCloned(parameter_node)) {
      auto parameter_tensor_layout = cloned_parameter->user_data<TensorLayout>();
      adasum_param_map["adasum_delta_weight." + cloned_parameter->name()] = parameter_tensor_layout;
    }
  }
  return adasum_param_map;
}
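
// Divides each element of a value sequence by the matching element of scale,
// additionally multiplying the first dimension by expand_ratio. For example
// (illustrative values): value_seq (4, 8), scale (2, 2), expand_ratio 2
// yields [4 / 2 * 2, 8 / 2] = [4, 4].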
Shape ValueSequeueScaleToShape(const ValuePtr &value_seq, const Shape &scale, size_t expand_ratio = 1) {
  if (!value_seq->isa<ValueSequeue>()) {
    MS_LOG(EXCEPTION) << "The input is not a value_sequeue";
  }
  std::vector<int64_t> origin_value_vector;
  if (TransValueSequeueToVector(value_seq, &origin_value_vector) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Transform value_seq to vector failed";
  }
  if (origin_value_vector.size() != scale.size()) {
    MS_LOG(EXCEPTION) << "Shape not equal, cannot scale, value_seq size is: " << origin_value_vector.size()
                      << " scale size is: " << scale.size();
  }
  for (size_t i = 0; i < scale.size(); ++i) {
    origin_value_vector[i] = origin_value_vector[i] / scale[i];
    if (i == 0) {
      origin_value_vector[i] = origin_value_vector[i] * expand_ratio;
    }
  }
  return origin_value_vector;
}

ValuePtr ValueSequeueScale(const ValuePtr &value_seq, const Shape &scale, size_t expand_ratio = 1) {
  Shape origin_value_vector = ValueSequeueScaleToShape(value_seq, scale, expand_ratio);
  if (value_seq->isa<ValueTuple>()) {
    return TransVectorToValueSequeue<ValueTuple>(origin_value_vector);
  }
  return TransVectorToValueSequeue<ValueList>(origin_value_vector);
}

void ReplaceAdaSumStridedSliceValue(const CNodePtr &stridedslice_cnode1,
                                    const std::shared_ptr<TensorLayout> &target_param_layout,
                                    size_t slice_expand_ratio) {
  auto target_param_info = std::make_shared<TensorInfo>(target_param_layout->SqueezeShape());
  Dimensions param_strategy = target_param_info->InferStrategy();
  auto new_begin1_value =
    ValueSequeueScale(GetValueNode(stridedslice_cnode1->input(2)), param_strategy, slice_expand_ratio);
  auto new_end1_value =
    ValueSequeueScale(GetValueNode(stridedslice_cnode1->input(3)), param_strategy, slice_expand_ratio);
  ValueNodePtr new_begin_value_node = std::make_shared<ValueNode>(new_begin1_value);
  ValueNodePtr new_end_value_node = std::make_shared<ValueNode>(new_end1_value);
  stridedslice_cnode1->set_input(2, new_begin_value_node);
  stridedslice_cnode1->set_input(3, new_end_value_node);
}

RankList GetRankListByLayout(const std::shared_ptr<TensorLayout> &target_param_layout) {
  int64_t rank = g_device_manager->global_rank();
  auto dev_shape = target_param_layout->device_arrangement().array();
  auto stage_device_list = g_device_manager->GetDeviceListInThisStage();
  DeviceMatrix dev_matrix(rank, stage_device_list, dev_shape);
  RankList group_devices;
  if (dev_matrix.GetDevicesByTensorMap(target_param_layout->tensor_map().array(), &group_devices) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Get adasum parameter origin mirror group by tensor layout failed.";
  }
  return group_devices;
}
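
// Rewrites the group and fusion attributes of an AdaSum Send/Receive node and
// returns four border flags: {origin first node if forward, new first node if
// forward, origin last node if rollback, new last node if rollback}, decided
// by the rank distance to the opposite rank.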
std::vector<bool> IsBorderAdaSumSendReceive(const AnfNodePtr &node, const RankList &group_devices) {
  bool is_send = IsPrimitiveCNode(node, prim::kPrimSend);
  PrimitivePtr send_rec_prim = GetCNodePrimitive(node);
  int64_t origin_dest_rank = GetValue<int64_t>(send_rec_prim->GetAttr(OPPOSITE_RANK));
  int64_t rank = g_device_manager->global_rank();
  int64_t adasum_rank_distance = (group_devices.back() - group_devices.front()) / (group_devices.size() - 1);
  if (adasum_rank_distance < ADASUM_MIN_DIS) {
    adasum_rank_distance = ADASUM_MIN_DIS;
  }
  size_t border_step = size_t(log2(adasum_rank_distance / ADASUM_MIN_DIS));
  int64_t fusion_id = GetValue<int64_t>(send_rec_prim->GetAttr("origin_fusion"));
  // when cutting nodes, the fusion id should change.
  int64_t new_fusion_id = fusion_id + g_device_manager->DeviceNum() * (border_step + 1);
  send_rec_prim->set_attr(FUSION, MakeValue(new_fusion_id));
  std::vector<int64_t> group_list;
  int64_t new_dest_src_rank;
  if (rank > origin_dest_rank) {
    group_list = {origin_dest_rank, rank};
    new_dest_src_rank = 0;
  } else {
    group_list = {rank, origin_dest_rank};
    new_dest_src_rank = 1;
  }
  Group adasum_send_rec_group = g_device_manager->CreateGroup(group_list);
  send_rec_prim->set_attr(GROUP, MakeValue(adasum_send_rec_group.name()));
  if (is_send) {
    send_rec_prim->set_attr(DEST_RANK, MakeValue(new_dest_src_rank));
  } else {
    send_rec_prim->set_attr(SRC_RANK, MakeValue(new_dest_src_rank));
  }
  int64_t rank_dis = abs(origin_dest_rank - rank);
  if (adasum_rank_distance == ADASUM_MIN_DIS) {
    return {false, false, false, false};
  }
  bool is_origin_first_node_if_forward = false;
  bool is_new_first_node_if_forward = false;
  bool is_origin_last_node_if_rollback = false;
  bool is_new_last_node_if_rollback = false;
  if (rank_dis == ADASUM_MIN_DIS) {
    is_origin_first_node_if_forward = true;
    is_origin_last_node_if_rollback = true;
  }
  if (rank_dis == adasum_rank_distance) {
    is_new_first_node_if_forward = true;
  }
  if (rank_dis == adasum_rank_distance / 2) {
    is_new_last_node_if_rollback = true;
  }
  return {is_origin_first_node_if_forward, is_new_first_node_if_forward, is_origin_last_node_if_rollback,
          is_new_last_node_if_rollback};
}
void HandleAdaSumReshape(const CNodePtr &reshape_cnode, const std::shared_ptr<TensorLayout> &target_param_layout) {
  auto slice_shape = target_param_layout->slice_shape().array();
  auto slice_shape_value = TransVectorToValueSequeue<ValueTuple>(slice_shape);
  ValueNodePtr new_slice_shape_value_node = std::make_shared<ValueNode>(slice_shape_value);
  reshape_cnode->set_input(2, new_slice_shape_value_node);
}

void RemoveAdasumRedundantNodes(const FuncGraphManagerPtr &manager,
                                std::unordered_map<std::string, CNodePtr> *forward_origin_first_node_map,
                                std::unordered_map<std::string, CNodePtr> *forward_new_first_node_map,
                                std::unordered_map<std::string, CNodePtr> *rollback_origin_last_node_map,
                                std::unordered_map<std::string, CNodePtr> *rollback_new_last_node_map) {
  // connect forward last node and rollback first node
  if (forward_origin_first_node_map->size() != forward_new_first_node_map->size() ||
      rollback_origin_last_node_map->size() != rollback_new_last_node_map->size()) {
    MS_LOG(EXCEPTION) << "The number of border nodes is not equal between the adasum forward and rollback processes.";
  }
  for (auto node : *forward_origin_first_node_map) {
    std::string target_param = node.first;
    CNodePtr forward_origin_first_node = node.second;
    CNodePtr forward_new_first_node = (*forward_new_first_node_map)[target_param];
    manager->SetEdge(forward_new_first_node, 1, forward_origin_first_node->input(1));
  }
  for (auto node : *rollback_origin_last_node_map) {
    std::string target_param = node.first;
    CNodePtr rollback_origin_last_node = node.second;
    CNodePtr rollback_new_last_node = (*rollback_new_last_node_map)[target_param];
    manager->Replace(rollback_origin_last_node, rollback_new_last_node);
  }
}
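
// Regroups an AdaSum AllReduce: for steps at or beyond the border step, the
// node gets a neighbor group of 2 << step ranks and a fresh fusion id.
// A worked example, assuming ADASUM_MIN_DIS were 8: step = 1 gives
// double_d = 4, so rank 17 (node_rank 2) would join group {1, 9, 17, 25}.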
void HandleAdasumAllReduce(const PrimitivePtr &prim, const RankList &group_devices) {
  size_t step = size_t(GetValue<int64_t>(prim->GetAttr("step")));
  std::vector<int64_t> neighbor_ids;
  int64_t adasum_rank_distance = (group_devices.back() - group_devices.front()) / (group_devices.size() - 1);
  if (adasum_rank_distance < ADASUM_MIN_DIS) {
    adasum_rank_distance = ADASUM_MIN_DIS;
  }
  size_t border_step = size_t(log2(adasum_rank_distance / ADASUM_MIN_DIS));
  MS_LOG(INFO) << "current border step is: " << border_step;
  if (step < border_step) {
    return;
  }
  int64_t rank = g_device_manager->global_rank();
  size_t double_d = size_t(2 << step);
  for (size_t index = 0; index < double_d; ++index) {
    int64_t node_rank = rank / ADASUM_MIN_DIS;
    int64_t neighbor_id = (node_rank / double_d * double_d + index) * ADASUM_MIN_DIS + rank % ADASUM_MIN_DIS;
    neighbor_ids.push_back(neighbor_id);
  }
  Group adasum_allreduce_group = g_device_manager->CreateGroup(neighbor_ids);
  auto new_group_name = MakeValue(adasum_allreduce_group.name());
  int64_t fusion_id = GetValue<int64_t>(prim->GetAttr("origin_fusion"));
  int64_t new_fusion_id = fusion_id + g_device_manager->DeviceNum() * (border_step + 1);
  prim->set_attr(GROUP, new_group_name);
  prim->set_attr(FUSION, MakeValue(new_fusion_id));
}
void HandleAdasumSlice(const AnfNodePtr &stridedslice_node1, const std::shared_ptr<TensorLayout> &target_param_layout,
                       const std::string &target_param, size_t slice_expand_ratio) {
  auto stridedslice_cnode1 = stridedslice_node1->cast<CNodePtr>();
  ReplaceAdaSumStridedSliceValue(stridedslice_cnode1, target_param_layout, slice_expand_ratio);
  auto squeeze_node = RealInputNode(stridedslice_cnode1, 1);
  if (!IsPrimitiveCNode(squeeze_node, prim::kPrimSqueeze)) {
    MS_LOG(EXCEPTION) << "The stridedslice input node should be squeeze in adasum";
  }
  auto squeeze_cnode = squeeze_node->cast<CNodePtr>();
  FuncGraphManagerPtr manager = squeeze_node->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[squeeze_cnode];
  for (auto &node_pair : node_set) {
    if (IsPrimitiveCNode(node_pair.first, prim::kPrimStridedSlice) && node_pair.first != stridedslice_node1) {
      CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
      ReplaceAdaSumStridedSliceValue(use_apply, target_param_layout, slice_expand_ratio);
    }
  }
}

void HandleAdaSumConcat(const AnfNodePtr &concat_node, const std::vector<bool> &border_info,
                        const std::string &target_param,
                        std::unordered_map<std::string, CNodePtr> *rollback_new_last_node_map,
                        std::unordered_map<std::string, CNodePtr> *rollback_origin_last_node_map) {
  if (border_info[3]) {
    (*rollback_new_last_node_map)[target_param] = concat_node->cast<CNodePtr>();
  }
  if (border_info[2]) {
    auto manager = concat_node->func_graph()->manager();
    AnfNodeIndexSet concat_node_user_set = manager->node_users()[concat_node];
    for (auto &node_pair : concat_node_user_set) {
      if (IsPrimitiveCNode(node_pair.first, prim::kPrimMakeTuple)) {
        AnfNodeIndexSet make_tuple_node_user_set = manager->node_users()[node_pair.first];
        for (auto &tuple_user : make_tuple_node_user_set) {
          if (IsPrimitiveCNode(tuple_user.first, prim::kPrimConcat)) {
            (*rollback_origin_last_node_map)[target_param] = tuple_user.first->cast<CNodePtr>();
            return;
          }
        }
        return;
      }
    }
  }
}

void HandleAdaSumSqueeze(const AnfNodePtr &stridedslice_node1, const std::vector<bool> &border_info,
                         const std::string &target_param,
                         std::unordered_map<std::string, CNodePtr> *forward_origin_first_node_map,
                         std::unordered_map<std::string, CNodePtr> *forward_new_first_node_map) {
  auto squeeze_node = RealInputNode(stridedslice_node1->cast<CNodePtr>(), 1);
  if (border_info[0]) {
    (*forward_origin_first_node_map)[target_param] = squeeze_node->cast<CNodePtr>();
  }
  if (border_info[1]) {
    (*forward_new_first_node_map)[target_param] = squeeze_node->cast<CNodePtr>();
  }
}

void HandleAdaSumPureModelParallel(const AnfNodePtr &node) {
  if (!IsPrimitiveCNode(node, prim::kPrimSend) && !IsPrimitiveCNode(node, prim::kPrimReceive)) {
    return;
  }
  PrimitivePtr send_rec_prim = GetCNodePrimitive(node);
  int64_t origin_dest_rank = GetValue<int64_t>(send_rec_prim->GetAttr(OPPOSITE_RANK));
  int64_t rank = g_device_manager->global_rank();
  CNodePtr cnode = node->cast<CNodePtr>();
  auto pre_cnode = RealInputNode(cnode, 1);
  int64_t rank_dis = abs(origin_dest_rank - rank);
  if (rank_dis == ADASUM_MIN_DIS && IsPrimitiveCNode(pre_cnode, prim::kPrimStridedSlice)) {
    auto squeeze_node = pre_cnode->cast<CNodePtr>()->input(1);
    if (!IsPrimitiveCNode(squeeze_node, prim::kPrimSqueeze)) {
      return;
    }
    auto squeeze_input = squeeze_node->cast<CNodePtr>()->input(1);
    auto manager = squeeze_node->func_graph()->manager();
    AnfNodeIndexSet squeeze_input_node_user_set = manager->node_users()[squeeze_input];
    for (auto &squeeze_input_user : squeeze_input_node_user_set) {
      if (IsPrimitiveCNode(squeeze_input_user.first, prim::kPrimSqueeze) ||
          IsPrimitiveCNode(squeeze_input_user.first, prim::kPrimUpdateState) ||
          IsPrimitiveCNode(squeeze_input_user.first, prim::kPrimMakeTuple)) {
        continue;
      }
      manager->Replace(squeeze_input_user.first, squeeze_input);
    }
  }
}
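
// Walks all AllReduce/Reshape/Send/Receive nodes tagged with a TARGET_PARAM
// attribute and rewrites them for AdaSum: Reshapes get the sliced shape,
// AllReduces get neighbor groups, and Send/Receive border nodes are recorded
// so that redundant segments can be removed afterwards. Returns true if any
// node was handled.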
bool HandleAdaSum(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                  std::unordered_map<std::string, std::shared_ptr<TensorLayout>> *adasum_param_tensor_layout_map) {
  std::unordered_map<std::string, CNodePtr> forward_origin_first_node_map;
  std::unordered_map<std::string, CNodePtr> forward_new_first_node_map;
  std::unordered_map<std::string, CNodePtr> rollback_origin_last_node_map;
  std::unordered_map<std::string, CNodePtr> rollback_new_last_node_map;
  bool is_adasum = false;
  for (auto &node : all_nodes) {
    bool is_allreduce = IsPrimitiveCNode(node, prim::kPrimAllReduce);
    bool is_reshape = IsPrimitiveCNode(node, prim::kPrimReshape);
    bool is_send = IsPrimitiveCNode(node, prim::kPrimSend);
    bool is_receive = IsPrimitiveCNode(node, prim::kPrimReceive);
    if (!is_allreduce && !is_reshape && !is_send && !is_receive) {
      continue;
    }
    std::string target_param;
    CNodePtr cnode = node->cast<CNodePtr>();
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)->cast<ValueNodePtr>());
    if (!prim->HasAttr(TARGET_PARAM)) {
      continue;
    }
    target_param = GetValue<std::string>(prim->GetAttr(TARGET_PARAM));
    auto target_param_layout = (*adasum_param_tensor_layout_map)[target_param];
    RankList group_devices = GetRankListByLayout(target_param_layout);
    // only model parallel
    if (group_devices.size() == 1) {
      HandleAdaSumPureModelParallel(node);
      continue;
    }

    int64_t adasum_rank_distance = (group_devices.back() - group_devices.front()) / (group_devices.size() - 1);
    // when the repeated dim is rightmost, the parameter does not enable adasum.
    if (adasum_rank_distance == 1 && group_devices.size() < size_t(g_device_manager->stage_device_num())) {
      continue;
    }

    MS_LOG(INFO) << "Apply adasum in auto parallel, current dealing node is: " << node->fullname_with_scope();
    is_adasum = true;
    size_t slice_expand_ratio = adasum_rank_distance / ADASUM_MIN_DIS > 0 ? adasum_rank_distance / ADASUM_MIN_DIS : 1;
    if (is_reshape) {
      HandleAdaSumReshape(cnode, (*adasum_param_tensor_layout_map)[target_param]);
    }
    if (is_allreduce && prim->HasAttr("step")) {
      HandleAdasumAllReduce(prim, group_devices);
    }
    if (is_send || is_receive) {
      std::vector<bool> border_info = IsBorderAdaSumSendReceive(node, group_devices);
      if (is_receive) {
        auto target_param_info = std::make_shared<TensorInfo>(*target_param_layout);
        Dimensions param_strategy = target_param_info->InferStrategy();
        Shape new_rec_shape = ValueSequeueScaleToShape(prim->GetAttr(SHAPE), param_strategy, slice_expand_ratio);
        auto new_rec_shape_value = TransVectorToValueSequeue<ValueList>(new_rec_shape);
        prim->set_attr(SHAPE, new_rec_shape_value);
        continue;
      }
      auto stridedslice_node1 = RealInputNode(cnode, 1);
      if (IsPrimitiveCNode(stridedslice_node1, prim::kPrimConcat)) {
        HandleAdaSumConcat(stridedslice_node1, border_info, target_param, &rollback_new_last_node_map,
                           &rollback_origin_last_node_map);
        continue;
      }
      if (!IsPrimitiveCNode(stridedslice_node1, prim::kPrimStridedSlice)) {
        continue;
      }
      HandleAdasumSlice(stridedslice_node1, target_param_layout, target_param, slice_expand_ratio);
      HandleAdaSumSqueeze(stridedslice_node1, border_info, target_param, &forward_origin_first_node_map,
                          &forward_new_first_node_map);
    }
  }
  RemoveAdasumRedundantNodes(root->manager(), &forward_origin_first_node_map, &forward_new_first_node_map,
                             &rollback_origin_last_node_map, &rollback_new_last_node_map);
  return is_adasum;
}
void ResetMirrorAttr(const PrimitivePtr &prim, const RankList &new_group) {
  if (new_group.size() == 1) {
    prim->set_attr(DEV_NUM, MakeValue<int64_t>(new_group.size()));
    prim->set_attr(GROUP, MakeValue("one_rank_group"));
    prim->set_attr(GROUP_RANKS, MakeValue(std::to_string(new_group[0])));
    return;
  }
  Group adasum_mirror_group = g_device_manager->CreateGroup(new_group);
  auto new_group_name = MakeValue(adasum_mirror_group.name());
  prim->set_attr(GROUP, new_group_name);
  prim->set_attr(DEV_NUM, MakeValue<int64_t>(new_group.size()));
  std::string rank_list_name = g_device_manager->FindRankListNameByHashName(adasum_mirror_group.name());
  prim->set_attr(GROUP_RANKS, MakeValue(rank_list_name));
}

void HandleMirrorInAdaSum(
  const FuncGraphPtr &root,
  std::unordered_map<std::string, std::shared_ptr<TensorLayout>> *adasum_param_tensor_layout_map) {
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(root->get_return());
  for (auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimMirror)) {
      continue;
    }
    CNodePtr mirror_cnode = node->cast<CNodePtr>();
    auto param_node_pair = FindParameter(mirror_cnode->input(1), node->func_graph());
    if (!param_node_pair.first) {
      MS_LOG(EXCEPTION) << "Mirror input is not a param";
    }
    auto param_ptr = param_node_pair.first->cast<ParameterPtr>();
    std::string param_name = param_ptr->name();
    MS_LOG(INFO) << "Mirror param name is: " << param_name;
    std::string target_param = "adasum_delta_weight." + param_name;
    auto target_param_layout = (*adasum_param_tensor_layout_map)[target_param];

    // Change mirror group
    RankList group_devices = GetRankListByLayout(target_param_layout);
    int64_t rank = g_device_manager->global_rank();
    size_t group_dis = (group_devices.back() - group_devices.front()) / (group_devices.size() - 1);
    auto prim = GetCNodePrimitive(node);
    if (group_dis < ADASUM_MIN_DIS) {
      size_t new_group_size = size_t(ADASUM_MIN_DIS) / group_dis;
      // compute new group range
      size_t group_begin = 0;
      for (size_t group_end = new_group_size; group_end < group_devices.size() + new_group_size;
           group_end += new_group_size) {
        int64_t max_group_value =
          group_end >= group_devices.size() ? (group_devices.back() + 1) : group_devices[group_end];
        if (group_devices[group_begin] <= rank && rank < max_group_value) {
          std::vector<int64_t> new_group(group_devices.begin() + group_begin, group_devices.begin() + group_end);
          MS_LOG(INFO) << "Find new mirror group in adasum: " << new_group << " target_param:" << target_param;
          ResetMirrorAttr(prim, new_group);
          break;
        }
        group_begin = group_end;
      }
      continue;
    }
    ResetMirrorAttr(prim, {rank});
  }
}
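
// For every origin weight, finds its adafactor states (exp_avg_sq_row,
// exp_avg_sq_col, exp_avg_sq) and shrinks their shapes and tensor layouts to
// match the weight's slice, dropping the reduced dimension as described above
// AdafactorStateIsOptShard and propagating the opt shard group when needed.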
void HandleAdaFactorOpt(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  for (auto &param_node : root->parameters()) {
    MS_EXCEPTION_IF_NULL(param_node);
    auto param = param_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);

    if (!IsOriginWeight(param)) {
      continue;
    }

    int64_t row_col_count = 0;
    int64_t exp_avg_sq_count = 0;
    for (auto &row_col_node : root->parameters()) {
      if (row_col_count == 2 && exp_avg_sq_count == 1) {
        break;
      }
      MS_EXCEPTION_IF_NULL(row_col_node);
      auto row_col_param = row_col_node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(row_col_param);
      std::string row_col_param_name = row_col_param->name();
      std::string param_name = param->name();
      std::string exp_row_name = EXP_AVG_SQ_ROW + param_name;
      std::string exp_col_name = EXP_AVG_SQ_COL + param_name;
      std::string exp_avg_name = EXP_AVG_SQ + param_name;

      if ((row_col_param_name != exp_row_name) && (row_col_param_name != exp_col_name) &&
          (row_col_param_name != exp_avg_name)) {
        continue;
      }

      auto tensor_layout = param->user_data<TensorLayout>();
      MS_EXCEPTION_IF_NULL(tensor_layout);
      auto slice_shape = tensor_layout->slice_shape().array();
      Shape opt_shard_slice_shape = slice_shape;
      if (!tensor_layout->opt_shard_group().empty()) {
        opt_shard_slice_shape = tensor_layout->opt_shard_slice_shape();
      }

      auto shape_size = slice_shape.size();
      bool is_row_or_col_param = (row_col_param_name == exp_row_name) || (row_col_param_name == exp_col_name);
      if (is_row_or_col_param && shape_size <= 1) {
        row_col_count++;
        continue;
      }

      if (row_col_param_name == exp_avg_name && shape_size != 1) {
        exp_avg_sq_count++;
        continue;
      }

      auto origin_shape = tensor_layout->tensor_shape().array();
      auto dev_mat = tensor_layout->device_arrangement().array();
      auto tensor_map = tensor_layout->tensor_map().array();

      if (row_col_param_name == exp_row_name) {
        opt_shard_slice_shape.pop_back();
        origin_shape.pop_back();
        tensor_map.pop_back();
        row_col_count++;
      } else if (row_col_param_name == exp_col_name) {
        (void)opt_shard_slice_shape.erase(opt_shard_slice_shape.begin() +
                                          static_cast<different_type>(SECOND_FROM_END(shape_size)));
        (void)origin_shape.erase(origin_shape.begin() + static_cast<different_type>(SECOND_FROM_END(shape_size)));
        (void)tensor_map.erase(tensor_map.begin() + static_cast<different_type>(SECOND_FROM_END(shape_size)));
        row_col_count++;
      } else {
        exp_avg_sq_count++;
      }

      TensorLayout new_tensor_layout;
      if (new_tensor_layout.InitFromVector(dev_mat, tensor_map, origin_shape) != SUCCESS) {
        MS_LOG(EXCEPTION) << "Init tensor layout failed";
      }

      if (AdafactorStateIsOptShard(tensor_layout->opt_shard_group(), shape_size, param_name, row_col_param_name)) {
        new_tensor_layout.set_opt_shard_group(tensor_layout->opt_shard_group());
      }

      auto cloned_abstract = row_col_node->abstract()->Clone();
      MS_EXCEPTION_IF_NULL(cloned_abstract);
      std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(opt_shard_slice_shape);
      MS_EXCEPTION_IF_NULL(parallel_shape);
      cloned_abstract->set_shape(parallel_shape);
      row_col_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(new_tensor_layout));
      row_col_node->set_abstract(cloned_abstract);
      MS_LOG(INFO) << "Set the slice shape for " << row_col_param_name << ", origin shape is " << origin_shape
                   << ", new slice shape is " << opt_shard_slice_shape;
    }
  }
}
}  // namespace parallel
}  // namespace mindspore