@@ -239,30 +239,32 @@ usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [--device_n
                        [--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
                        [--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N]
                        [--accumulation_steps N]
+                       [--allreduce_post_accumulation ALLREDUCE_POST_ACCUMULATION]
                        [--save_checkpoint_path SAVE_CHECKPOINT_PATH]
                        [--load_checkpoint_path LOAD_CHECKPOINT_PATH]
                        [--save_checkpoint_steps N] [--save_checkpoint_num N]
                        [--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR] [train_steps N]
 
 options:
-    --device_target            device where the code will be implemented: "Ascend" | "GPU", default is "Ascend"
-    --distribute               pre_training by serveral devices: "true"(training by more than 1 device) | "false", default is "false"
-    --epoch_size               epoch size: N, default is 1
-    --device_num               number of used devices: N, default is 1
-    --device_id                device id: N, default is 0
-    --enable_save_ckpt         enable save checkpoint: "true" | "false", default is "true"
-    --enable_lossscale         enable lossscale: "true" | "false", default is "true"
-    --do_shuffle               enable shuffle: "true" | "false", default is "true"
-    --enable_data_sink         enable data sink: "true" | "false", default is "true"
-    --data_sink_steps          set data sink steps: N, default is 1
-    --accumulation_steps       accumulate gradients N times before weight update: N, default is 1
-    --save_checkpoint_path     path to save checkpoint files: PATH, default is ""
-    --load_checkpoint_path     path to load checkpoint files: PATH, default is ""
-    --save_checkpoint_steps    steps for saving checkpoint files: N, default is 1000
-    --save_checkpoint_num      number for saving checkpoint files: N, default is 1
-    --train_steps              Training Steps: N, default is -1
-    --data_dir                 path to dataset directory: PATH, default is ""
-    --schema_dir               path to schema.json file, PATH, default is ""
+    --device_target                  device where the code will be implemented: "Ascend" | "GPU", default is "Ascend"
+    --distribute                     pre_training by serveral devices: "true"(training by more than 1 device) | "false", default is "false"
+    --epoch_size                     epoch size: N, default is 1
+    --device_num                     number of used devices: N, default is 1
+    --device_id                      device id: N, default is 0
+    --enable_save_ckpt               enable save checkpoint: "true" | "false", default is "true"
+    --enable_lossscale               enable lossscale: "true" | "false", default is "true"
+    --do_shuffle                     enable shuffle: "true" | "false", default is "true"
+    --enable_data_sink               enable data sink: "true" | "false", default is "true"
+    --data_sink_steps                set data sink steps: N, default is 1
+    --accumulation_steps             accumulate gradients N times before weight update: N, default is 1
+    --allreduce_post_accumulation    allreduce after accumulation of N steps or after each step: "true" | "false", default is "true"
+    --save_checkpoint_path           path to save checkpoint files: PATH, default is ""
+    --load_checkpoint_path           path to load checkpoint files: PATH, default is ""
+    --save_checkpoint_steps          steps for saving checkpoint files: N, default is 1000
+    --save_checkpoint_num            number for saving checkpoint files: N, default is 1
+    --train_steps                    Training Steps: N, default is -1
+    --data_dir                       path to dataset directory: PATH, default is ""
+    --schema_dir                     path to schema.json file, PATH, default is ""
 ```
 
 ### Fine-Tuning and Evaluation
@@ -32,7 +32,9 @@ from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
 from mindspore import log as logger
 from mindspore.common import set_seed
 from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
-    BertTrainAccumulateStepsWithLossScaleCell, BertTrainOneStepWithLossScaleCellForAdam, \
+    BertTrainAccumulationAllReduceEachWithLossScaleCell, \
+    BertTrainAccumulationAllReducePostWithLossScaleCell, \
+    BertTrainOneStepWithLossScaleCellForAdam, \
     AdamWeightDecayForBert
 from src.dataset import create_bert_dataset
 from src.config import cfg, bert_net_cfg
@@ -122,6 +124,8 @@ def run_pretrain():
     parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.")
     parser.add_argument("--accumulation_steps", type=int, default="1",
                         help="Accumulating gradients N times before weight update, default is 1.")
+    parser.add_argument("--allreduce_post_accumulation", type=str, default="true", choices=["true", "false"],
+                        help="Whether to allreduce after accumulation of N steps or after each step, default is true.")
     parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
     parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
     parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
@@ -207,8 +211,9 @@ def run_pretrain():
         update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                  scale_factor=cfg.scale_factor,
                                                  scale_window=cfg.scale_window)
-        if args_opt.accumulation_steps <= 1:
+        accumulation_steps = args_opt.accumulation_steps
+        enable_global_norm = cfg.enable_global_norm
+        if accumulation_steps <= 1:
             if cfg.optimizer == 'AdamWeightDecay':
                 net_with_grads = BertTrainOneStepWithLossScaleCellForAdam(net_with_loss, optimizer=optimizer,
                                                                           scale_update_cell=update_cell)
@@ -216,11 +221,13 @@ def run_pretrain():
                 net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                                    scale_update_cell=update_cell)
             else:
-                accumulation_steps = args_opt.accumulation_steps
-                net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss, optimizer=optimizer,
-                                                                           scale_update_cell=update_cell,
-                                                                           accumulation_steps=accumulation_steps,
-                                                                           enable_global_norm=cfg.enable_global_norm)
+                allreduce_post = args_opt.distribute == "false" or args_opt.allreduce_post_accumulation == "true"
+                net_with_accumulation = (BertTrainAccumulationAllReducePostWithLossScaleCell if allreduce_post else
+                                         BertTrainAccumulationAllReduceEachWithLossScaleCell)
+                net_with_grads = net_with_accumulation(net_with_loss, optimizer=optimizer,
+                                                       scale_update_cell=update_cell,
+                                                       accumulation_steps=accumulation_steps,
+                                                       enable_global_norm=enable_global_norm)
     else:
         net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
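The hunk above only selects between the two wrapper cells; the behavioural difference is when the collective communication runs. Below is a minimal, framework-free sketch (plain NumPy; `allreduce`, `accumulate_post` and `accumulate_each` are illustrative stand-ins, not part of the patch): with post-accumulation allreduce the collective fires once per weight update, with per-step allreduce it fires on every sub-step, and because a summing collective is linear both produce the same update.

```python
import numpy as np

def allreduce(grad, world=2):
    # stand-in for a summing collective across `world` devices; here every
    # device is pretended to produce the same gradient
    return grad * world

def accumulate_post(grads, accumulation_steps):
    """AllReduce only on the weight-update sub-step (allreduce_post_accumulation=true)."""
    accu, updates = np.zeros_like(grads[0]), []
    for step, g in enumerate(grads, start=1):
        accu += g                                                   # local accumulation only
        if step % accumulation_steps == 0:
            updates.append(allreduce(accu) / accumulation_steps)    # one collective per update
            accu = np.zeros_like(accu)
    return updates

def accumulate_each(grads, accumulation_steps):
    """AllReduce after every sub-step (allreduce_post_accumulation=false)."""
    accu, updates = np.zeros_like(grads[0]), []
    for step, g in enumerate(grads, start=1):
        accu += allreduce(g)                                        # collective on every sub-step
        if step % accumulation_steps == 0:
            updates.append(accu / accumulation_steps)
            accu = np.zeros_like(accu)
    return updates

grads = [np.full(3, float(i)) for i in range(1, 5)]
# same update either way; only the communication pattern differs
print(accumulate_post(grads, accumulation_steps=4))
print(accumulate_each(grads, accumulation_steps=4))
```

In the real cells the communication is carried out by `DistributedGradReducer`; the new flag only changes whether it is invoked inside the accumulation loop or on the update step.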
@@ -7,6 +7,7 @@ do_shuffle=true
 enable_data_sink=true
 data_sink_steps=100
 accumulation_steps=1
+allreduce_post_accumulation=true
 save_checkpoint_path=./
 save_checkpoint_steps=10000
 save_checkpoint_num=1
@@ -16,7 +16,8 @@
 from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \
     BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \
     BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
-    BertTrainAccumulateStepsWithLossScaleCell
+    BertTrainAccumulationAllReduceEachWithLossScaleCell, \
+    BertTrainAccumulationAllReducePostWithLossScaleCell
 from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
     BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \
     EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \
@@ -25,7 +26,8 @@ from .adam import AdamWeightDecayForBert
 __all__ = [
     "BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss",
     "GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell",
-    "BertTrainOneStepWithLossScaleCell", "BertTrainAccumulateStepsWithLossScaleCell",
+    "BertTrainOneStepWithLossScaleCell", "BertTrainAccumulationAllReduceEachWithLossScaleCell",
+    "BertTrainAccumulationAllReducePostWithLossScaleCell",
     "BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput",
     "BertSelfAttention", "BertTransformer", "EmbeddingLookup",
     "EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", "AdamWeightDecayForBert",
@@ -556,11 +556,24 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
         return F.depend(ret, succ)
 
 
 cast = P.Cast()
-update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")
+add_grads = C.MultitypeFuncGraph("add_grads")
+
+@add_grads.register("Tensor", "Tensor")
+def _add_grads(accu_grad, grad):
+    return accu_grad + cast(grad, mstype.float32)
+
+update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")
 
 
 @update_accu_grads.register("Tensor", "Tensor")
 def _update_accu_grads(accu_grad, grad):
+    succ = True
+    return F.depend(succ, F.assign(accu_grad, cast(grad, mstype.float32)))
+
+accumulate_accu_grads = C.MultitypeFuncGraph("accumulate_accu_grads")
+
+@accumulate_accu_grads.register("Tensor", "Tensor")
+def _accumulate_accu_grads(accu_grad, grad):
     succ = True
     return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))
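As a rough mental model of the three registered helpers (plain Python/NumPy, not MindSpore; in the real code `hyper_map` broadcasts each helper over the tuples of accumulation buffers and fresh gradients): `add_grads` returns new values without touching the buffers, `accumulate_accu_grads` adds into the buffers in place (`assign_add`), and `update_accu_grads` overwrites them (`assign`).

```python
import numpy as np

# accu_grads plays the role of the persistent accumulation buffers,
# grads of the fresh (already loss-scaled) gradients from one sub-step.
accu_grads = [np.ones(2, np.float32), np.ones(3, np.float32)]
grads = [np.full(2, 0.5, np.float32), np.full(3, 2.0, np.float32)]

# add_grads: pure addition, returns new tensors and leaves the buffers untouched
added = [a + g for a, g in zip(accu_grads, grads)]

# accumulate_accu_grads: in-place accumulation into the buffers (F.assign_add)
for a, g in zip(accu_grads, grads):
    a += g

# update_accu_grads: in-place overwrite of the buffers (F.assign)
for a, g in zip(accu_grads, grads):
    a[...] = g

print(added, accu_grads)
```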
@@ -575,13 +588,17 @@ def _reset_accu_grads(accu_grad):
     return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))
 
 
-class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
+class BertTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell):
     """
     Encapsulation class of bert network training.
 
     Append an optimizer to the training network after that the construct
-    function can be called to create the backward graph. To mimic higher batch size, gradients are
-    accumulated N times before weight update.
+    function can be called to create the backward graph.
+
+    To mimic higher batch size, gradients are accumulated N times before weight update.
+
+    For distribution mode, allreduce will only be implemented in the weight updated step,
+    i.e. the sub-step after gradients accumulated N times.
 
     Args:
         network (Cell): The training network. Note that loss function should have been added.
@@ -591,7 +608,7 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
                                  batch_size * accumulation_steps. Default: 1.
     """
     def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
-        super(BertTrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
+        super(BertTrainAccumulationAllReducePostWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.network.set_grad()
         self.weights = optimizer.parameters
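A quick check of the bookkeeping stated in the docstring above, with made-up numbers: the samples contributing to one optimizer step grow with both the data-parallel degree and the accumulation depth, and the accumulated, loss-scaled gradient has to be divided by `loss_scale * degree * accumulation_steps` before the update (this is the `scaling` factor used in the allreduce-each cell added further down).

```python
# made-up numbers, just to check the bookkeeping
per_device_batch = 32
degree = 8                     # size of the data-parallel group
accumulation_steps = 4
loss_scale = 1024.0

global_batch = per_device_batch * accumulation_steps       # 128, as defined in the docstring
samples_per_update = global_batch * degree                  # 1024 across the whole group

# factor the summed, loss-scaled gradient is divided by before the optimizer step,
# cf. `scaling = scaling_sens * self.degree * self.accumulation_steps` below
scaling = loss_scale * degree * accumulation_steps           # 32768.0

print(global_batch, samples_per_update, scaling)
```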
@@ -680,7 +697,7 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
                                                  self.cast(scaling_sens,
                                                            mstype.float32))
-        accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
+        accu_succ = self.hyper_map(accumulate_accu_grads, self.accu_grads, grads)
         mean_loss = F.depend(mean_loss, accu_succ)
 
         self.get_status(init)
@@ -716,3 +733,151 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
         ret = (mean_loss, overflow, scaling_sens)
         return F.depend(ret, succ)
 
+
+class BertTrainAccumulationAllReduceEachWithLossScaleCell(nn.Cell):
+    """
+    Encapsulation class of bert network training.
+
+    Append an optimizer to the training network after that the construct
+    function can be called to create the backward graph.
+
+    To mimic higher batch size, gradients are accumulated N times before weight update.
+
+    For distribution mode, allreduce will be implemented after each sub-step and the trailing time
+    will be overided by backend optimization pass.
+
+    Args:
+        network (Cell): The training network. Note that loss function should have been added.
+        optimizer (Optimizer): Optimizer for updating the weights.
+        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
+        accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
+                                  batch_size * accumulation_steps. Default: 1.
+    """
+    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
+        super(BertTrainAccumulationAllReduceEachWithLossScaleCell, self).__init__(auto_prefix=False)
+        self.network = network
+        self.network.set_grad()
+        self.weights = optimizer.parameters
+        self.optimizer = optimizer
+        self.accumulation_steps = accumulation_steps
+        self.enable_global_norm = enable_global_norm
+        self.one = Tensor(np.array([1]).astype(np.int32))
+        self.zero = Tensor(np.array([0]).astype(np.int32))
+        self.local_step = Parameter(initializer(0, [1], mstype.int32))
+        self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
+        self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
+        self.accu_loss = Parameter(initializer(0, [1], mstype.float32))
+
+        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
+        self.reducer_flag = False
+        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
+        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
+            self.reducer_flag = True
+        self.grad_reducer = F.identity
+        self.degree = 1
+        if self.reducer_flag:
+            self.degree = get_group_size()
+            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
+        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
+        self.overflow_reducer = F.identity
+        if self.is_distributed:
+            self.overflow_reducer = P.AllReduce()
+        self.cast = P.Cast()
+        self.alloc_status = P.NPUAllocFloatStatus()
+        self.get_status = P.NPUGetFloatStatus()
+        self.clear_before_grad = P.NPUClearFloatStatus()
+        self.reduce_sum = P.ReduceSum(keep_dims=False)
+        self.base = Tensor(1, mstype.float32)
+        self.less_equal = P.LessEqual()
+        self.logical_or = P.LogicalOr()
+        self.not_equal = P.NotEqual()
+        self.select = P.Select()
+        self.reshape = P.Reshape()
+        self.hyper_map = C.HyperMap()
+        self.loss_scale = None
+        self.loss_scaling_manager = scale_update_cell
+        if scale_update_cell:
+            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
+
+    @C.add_flags(has_effect=True)
+    def construct(self,
+                  input_ids,
+                  input_mask,
+                  token_type_id,
+                  next_sentence_labels,
+                  masked_lm_positions,
+                  masked_lm_ids,
+                  masked_lm_weights,
+                  sens=None):
+        """Defines the computation performed."""
+        weights = self.weights
+        loss = self.network(input_ids,
+                            input_mask,
+                            token_type_id,
+                            next_sentence_labels,
+                            masked_lm_positions,
+                            masked_lm_ids,
+                            masked_lm_weights)
+        if sens is None:
+            scaling_sens = self.loss_scale
+        else:
+            scaling_sens = sens
+
+        # update accumulation parameters
+        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
+        self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
+        self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
+        mean_loss = self.accu_loss / self.local_step
+        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
+
+        # alloc status and clear should be right before gradoperation
+        init = self.alloc_status()
+        self.clear_before_grad(init)
+        grads = self.grad(self.network, weights)(input_ids,
+                                                 input_mask,
+                                                 token_type_id,
+                                                 next_sentence_labels,
+                                                 masked_lm_positions,
+                                                 masked_lm_ids,
+                                                 masked_lm_weights,
+                                                 self.cast(scaling_sens,
+                                                           mstype.float32))
+
+        accu_grads = self.hyper_map(add_grads, self.accu_grads, grads)
+        scaling = scaling_sens * self.degree * self.accumulation_steps
+        grads = self.hyper_map(F.partial(grad_scale, scaling), accu_grads)
+        grads = self.grad_reducer(grads)
+
+        self.get_status(init)
+        flag_sum = self.reduce_sum(init, (0,))
+        flag_reduce = self.overflow_reducer(flag_sum)
+        overflow = self.less_equal(self.base, flag_reduce)
+        overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
+        accu_overflow = self.select(overflow, self.one, self.zero)
+        self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
+        overflow = self.reshape(overflow, (()))
+
+        if is_accu_step:
+            succ = False
+            accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, accu_grads)
+            succ = F.depend(succ, accu_succ)
+        else:
+            if sens is None:
+                overflow = self.loss_scaling_manager(self.loss_scale, overflow)
+            if overflow:
+                succ = False
+            else:
+                if self.enable_global_norm:
+                    grads = C.clip_by_global_norm(grads, 1.0, None)
+                else:
+                    grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+                succ = self.optimizer(grads)
+
+            accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
+            succ = F.depend(succ, accu_succ)
+
+        ret = (mean_loss, overflow, scaling_sens)
+        return F.depend(ret, succ)
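One subtle piece of the new cell is the overflow handling: the float-status result of every sub-step is OR-ed into `accu_overflow`, and only the final sub-step of a window consults it, feeds it to the dynamic loss-scale manager and decides whether the accumulated gradients are applied. A toy sketch of that bookkeeping (plain Python; the random probe and the halving factor are illustrative stand-ins for the NPU status check and `DynamicLossScaleUpdateCell`, not the actual implementation):

```python
import random

random.seed(0)

def run_window(accumulation_steps, loss_scale):
    """Toy model of the per-window overflow bookkeeping; not MindSpore code."""
    accu_overflow = False
    applied = False
    for sub_step in range(1, accumulation_steps + 1):
        step_overflow = random.random() < 0.05       # stand-in for the NPU float-status probe
        accu_overflow = accu_overflow or step_overflow
        if sub_step != accumulation_steps:            # is_accu_step: keep accumulating
            continue
        # final sub-step of the window: decide once for all accumulated gradients
        if accu_overflow:
            loss_scale /= 2.0                         # dynamic loss scale backs off, update skipped
        else:
            applied = True                            # optimizer step with the accumulated grads
        # (growing the scale again after a run of clean windows, and resetting the
        #  accumulation buffers, happen in the real cells but are omitted here)
    return applied, loss_scale

print(run_window(accumulation_steps=4, loss_scale=65536.0))
```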