Browse Source

!23517 fix parallel opt recompute pass warning

Merge pull request !23517 from yao_yf/parallel_opt_recompute_pass_warning_fix
tags/v1.5.0-rc1
i-robot Gitee 4 years ago
parent
commit
be10bb7af6
2 changed files with 4 additions and 2 deletions
  1. +1
    -1
      mindspore/ccsrc/backend/optimizer/pass/adjust_depend_for_parallel_optimizer_recompute_all_gather.cc
  2. +3
    -1
      mindspore/parallel/_utils.py

+ 1
- 1
mindspore/ccsrc/backend/optimizer/pass/adjust_depend_for_parallel_optimizer_recompute_all_gather.cc View File

@@ -30,7 +30,7 @@ bool AdjustDependForParallelOptimizerRecomputeAllGather::Run(const FuncGraphPtr
std::vector<int64_t> parallel_optimizer_recompute_allgather_fusion_ids;
std::vector<AnfNodePtr> parallel_optimizer_recompute_allgathers;
std::vector<AnfNodePtr> parallel_optimizer_recompute_first_fusion_allgathers;
int64_t unrecompute_max_fusion_id = 0;
int64_t unrecompute_max_fusion_id = -1;
int64_t recompute_min_fusion_id = 0;
for (auto &node : node_list) {
MS_EXCEPTION_IF_NULL(node);


+ 3
- 1
mindspore/parallel/_utils.py View File

@@ -274,7 +274,9 @@ def _reset_op_id():
def _parallel_predict_check():
"""validate parallel model prediction"""
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
if not context.get_auto_parallel_context("full_batch"):
dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
is_shard_dataset_mp = (dataset_strategy and dataset_strategy not in ("data_parallel", "full_batch"))
if not context.get_auto_parallel_context("full_batch") and not is_shard_dataset_mp:
raise RuntimeError('Model prediction only supports full batch dataset. Please set "full_batch" with True.')




Loading…
Cancel
Save