diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc
index a49824f18b..b85e63c11e 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc
@@ -203,11 +203,11 @@ Status GetNextInfo::InferReplaceOps(const StrategyPtr &) {
       MS_LOG(ERROR) << name_ << " : The dev num is 0.";
       return FAILED;
     }
-    if (out_shapes[i][0] % dev_num_ != 0) {
-      MS_LOG(ERROR) << name_ << " : batch num cannot floor div dev num.";
-      return FAILED;
-    }
     if (!full_batch) {
+      if (out_shapes[i][0] % dev_num_ != 0) {
+        MS_LOG(ERROR) << name_ << " : batch num cannot floor div dev num.";
+        return FAILED;
+      }
       out_shapes[i][0] = out_shapes[i][0] / dev_num_;
     }
   }
diff --git a/mindspore/nn/optim/lamb.py b/mindspore/nn/optim/lamb.py
index 60c72cc6b7..ddd8fd294e 100755
--- a/mindspore/nn/optim/lamb.py
+++ b/mindspore/nn/optim/lamb.py
@@ -238,7 +238,7 @@ class Lamb(Optimizer):
     Examples:
         >>> net = Net()
         >>> #1) All parameters use the same learning rate and weight decay
-        >>> optim = nn.Lamb(params=net.trainable_params())
+        >>> optim = nn.Lamb(params=net.trainable_params(), learning_rate=0.1)
        >>>
         >>> #2) Use parameter groups and set different values
         >>> poly_decay_lr = learning_rate_schedule.PolynomialDecayLR()
diff --git a/mindspore/nn/wrap/grad_reducer.py b/mindspore/nn/wrap/grad_reducer.py
index ea47b6ca73..b6ec7191b2 100644
--- a/mindspore/nn/wrap/grad_reducer.py
+++ b/mindspore/nn/wrap/grad_reducer.py
@@ -254,6 +254,8 @@ class DistributedGradReducer(Cell):
         >>> from mindspore.context import ParallelMode
         >>> from mindspore import nn
         >>> from mindspore import ParameterTuple
+        >>> from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
+        >>>                                        _get_parallel_mode)
         >>>
         >>> device_id = int(os.environ["DEVICE_ID"])
         >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
@@ -279,11 +281,8 @@ class DistributedGradReducer(Cell):
         >>>                                   ParallelMode.HYBRID_PARALLEL]:
         >>>             self.reducer_flag = True
         >>>         if self.reducer_flag:
-        >>>             mean = context.get_auto_parallel_context("gradients_mean")
-        >>>             if mean.get_device_num_is_set():
-        >>>                 degree = context.get_auto_parallel_context("device_num")
-        >>>             else:
-        >>>                 degree = get_group_size()
+        >>>             mean = _get_gradients_mean()
+        >>>             degree = _get_device_num()
         >>>             self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
         >>>
         >>>     def construct(self, *args):
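
The get_next_info.cc hunk moves the divisibility check under the !full_batch branch, so the global batch size only needs to be divisible by the device count when each device consumes a slice of the batch. A minimal Python sketch of the adjusted logic follows; the function name and the use of exceptions are illustrative only, not MindSpore API.

    # Sketch of the per-device batch inference after this change (hypothetical helper).
    def infer_per_device_batch(batch_size, dev_num, full_batch):
        if dev_num == 0:
            raise ValueError("The dev num is 0.")
        if not full_batch:
            # The divisibility check now only applies when the batch is split across devices.
            if batch_size % dev_num != 0:
                raise ValueError("batch num cannot floor div dev num.")
            return batch_size // dev_num
        # With full_batch enabled, every device keeps the global batch size,
        # so no divisibility constraint is required.
        return batch_size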