| @@ -184,7 +184,7 @@ class Callback: | |||||
| """ | """ | ||||
| pass | pass | ||||
| def on_before_optimizer_step(self, trainer, optimizers): | |||||
| def on_before_optimizers_step(self, trainer, optimizers): | |||||
| """ | """ | ||||
| 在进行 optimizer 优化进行前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | 在进行 optimizer 优化进行前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | ||||
| @@ -194,6 +194,16 @@ class Callback: | |||||
| """ | """ | ||||
| pass | pass | ||||
| def on_after_optimizers_step(self, trainer, optimizers): | |||||
| """ | |||||
| 在进行 optimizer 优化进行后调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | |||||
| :param trainer: | |||||
| :param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。 | |||||
| :return: | |||||
| """ | |||||
| pass | |||||
| def on_before_zero_grad(self, trainer, optimizers): | def on_before_zero_grad(self, trainer, optimizers): | ||||
| """ | """ | ||||
| 在进行模型梯度置零前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | 在进行模型梯度置零前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | ||||
| @@ -204,6 +214,16 @@ class Callback: | |||||
| """ | """ | ||||
| pass | pass | ||||
| def on_after_zero_grad(self, trainer, optimizers): | |||||
| """ | |||||
| 在进行模型梯度置零后调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | |||||
| :param trainer: | |||||
| :param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。 | |||||
| :return: | |||||
| """ | |||||
| pass | |||||
| def on_validate_begin(self, trainer): | def on_validate_begin(self, trainer): | ||||
| """ | """ | ||||
| 在将要进行 validate 时调用。如果是设置的以 step 数量 或 自定义地 决定 validate 的频率,该接口是在 on_train_batch_end 之后 | 在将要进行 validate 时调用。如果是设置的以 step 数量 或 自定义地 决定 validate 的频率,该接口是在 on_train_batch_end 之后 | ||||
| @@ -92,8 +92,10 @@ class Events(EventEnum): | |||||
| ON_LOAD_CHECKPOINT = "on_load_checkpoint" | ON_LOAD_CHECKPOINT = "on_load_checkpoint" | ||||
| ON_BEFORE_BACKWARD = "on_before_backward" | ON_BEFORE_BACKWARD = "on_before_backward" | ||||
| ON_AFTER_BACKWARD = "on_after_backward" | ON_AFTER_BACKWARD = "on_after_backward" | ||||
| ON_BEFORE_OPTIMIZER_STEP = "on_before_optimizer_step" | |||||
| ON_BEFORE_OPTIMIZERS_STEP = "on_before_optimizers_step" | |||||
| ON_AFTER_OPTIMIZERS_STEP = "on_after_optimizers_step" | |||||
| ON_BEFORE_ZERO_GRAD = "on_before_zero_grad" | ON_BEFORE_ZERO_GRAD = "on_before_zero_grad" | ||||
| ON_AFTER_ZERO_GRAD = "on_after_zero_grad" | |||||
| ON_VALIDATE_BEGIN = "on_validate_begin" | ON_VALIDATE_BEGIN = "on_validate_begin" | ||||
| ON_VALIDATE_END = "on_validate_end" | ON_VALIDATE_END = "on_validate_end" | ||||
| @@ -278,13 +278,21 @@ class CallbackManager: | |||||
| pass | pass | ||||
| @_transfer | @_transfer | ||||
| def on_before_optimizer_step(self, trainer, optimizers): | |||||
| def on_before_optimizers_step(self, trainer, optimizers): | |||||
| pass | |||||
| @_transfer | |||||
| def on_after_optimizers_step(self, trainer, optimizers): | |||||
| pass | pass | ||||
| @_transfer | @_transfer | ||||
| def on_before_zero_grad(self, trainer, optimizers): | def on_before_zero_grad(self, trainer, optimizers): | ||||
| pass | pass | ||||
| @_transfer | |||||
| def on_after_zero_grad(self, trainer, optimizers): | |||||
| pass | |||||
| @_transfer | @_transfer | ||||
| def on_validate_begin(self, trainer): | def on_validate_begin(self, trainer): | ||||
| pass | pass | ||||
| @@ -137,6 +137,7 @@ class Trainer(TrainerEventTrigger): | |||||
| else: | else: | ||||
| self.driver_name = driver.__class__.__name__ | self.driver_name = driver.__class__.__name__ | ||||
| self.device = device | self.device = device | ||||
| self.optimizers = optimizers | |||||
| self.fp16 = fp16 | self.fp16 = fp16 | ||||
| self.input_mapping = input_mapping | self.input_mapping = input_mapping | ||||
| self.output_mapping = output_mapping | self.output_mapping = output_mapping | ||||
| @@ -440,9 +441,11 @@ class Trainer(TrainerEventTrigger): | |||||
| 2. 函数作用 | 2. 函数作用 | ||||
| 这一函数的作用在于检查用户定制的 batch_step_fn / TrainBatchLoop 是否能够正确地调用 callback 函数,更准确地说,当用户实际 | 这一函数的作用在于检查用户定制的 batch_step_fn / TrainBatchLoop 是否能够正确地调用 callback 函数,更准确地说,当用户实际 | ||||
| 定制了 ("on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") / | |||||
| 定制了 ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", "on_before_zero_grad", | |||||
| "on_after_zero_grad") / | |||||
| ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end", | ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end", | ||||
| "on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") | |||||
| "on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", "on_before_zero_grad", | |||||
| "on_after_zero_grad") | |||||
| 这些 callabck_fn 后,如果其同样也定制了 batch_step_fn / TrainBatchLoop,那么其有可能忘记了在自己的 batch_step_fn 中 | 这些 callabck_fn 后,如果其同样也定制了 batch_step_fn / TrainBatchLoop,那么其有可能忘记了在自己的 batch_step_fn 中 | ||||
| 上述的这些 callback 函数,而这个函数的作用就在于检测用户是否产生了这一行为; | 上述的这些 callback 函数,而这个函数的作用就在于检测用户是否产生了这一行为; | ||||
| @@ -452,10 +455,12 @@ class Trainer(TrainerEventTrigger): | |||||
| 'batch_step_fn',为 False 时表示检测 'TrainBatchLoop'; | 'batch_step_fn',为 False 时表示检测 'TrainBatchLoop'; | ||||
| """ | """ | ||||
| if check_mode: | if check_mode: | ||||
| callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") | |||||
| callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", | |||||
| "on_before_zero_grad", "on_after_zero_grad") | |||||
| else: | else: | ||||
| callbacks = ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end", | callbacks = ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end", | ||||
| "on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") | |||||
| "on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", | |||||
| "on_before_zero_grad", "on_after_zero_grad") | |||||
| _not_called_callback_fns = [] | _not_called_callback_fns = [] | ||||
| for each_callback_fn in callbacks: | for each_callback_fn in callbacks: | ||||
| if each_callback_fn in self.callback_manager.callback_fns: | if each_callback_fn in self.callback_manager.callback_fns: | ||||
| @@ -699,13 +704,15 @@ class Trainer(TrainerEventTrigger): | |||||
| def zero_grad(self): | def zero_grad(self): | ||||
| if (self.global_forward_batches + 1) % self.accumulation_steps == 0: | if (self.global_forward_batches + 1) % self.accumulation_steps == 0: | ||||
| self.on_before_zero_grad(self.driver.optimizers) | |||||
| self.on_before_zero_grad(self.optimizers) | |||||
| self.driver.zero_grad(self.set_grad_to_none) | self.driver.zero_grad(self.set_grad_to_none) | ||||
| self.on_after_zero_grad(self.optimizers) | |||||
| def step(self): | def step(self): | ||||
| if (self.global_forward_batches + 1) % self.accumulation_steps == 0: | if (self.global_forward_batches + 1) % self.accumulation_steps == 0: | ||||
| self.on_before_optimizer_step(self.driver.optimizers) | |||||
| self.on_before_optimizers_step(self.optimizers) | |||||
| self.driver.step() | self.driver.step() | ||||
| self.on_after_optimizers_step(self.optimizers) | |||||
| def move_data_to_device(self, batch): | def move_data_to_device(self, batch): | ||||
| return self.driver.move_data_to_device(batch) | return self.driver.move_data_to_device(batch) | ||||
| @@ -817,3 +824,5 @@ class Trainer(TrainerEventTrigger): | |||||
| @@ -68,12 +68,18 @@ class TrainerEventTrigger: | |||||
| def on_after_backward(self): | def on_after_backward(self): | ||||
| self.callback_manager.on_after_backward(self) | self.callback_manager.on_after_backward(self) | ||||
| def on_before_optimizer_step(self, optimizers): | |||||
| self.callback_manager.on_before_optimizer_step(self, optimizers) | |||||
| def on_before_optimizers_step(self, optimizers): | |||||
| self.callback_manager.on_before_optimizers_step(self, optimizers) | |||||
| def on_after_optimizers_step(self, optimizers): | |||||
| self.callback_manager.on_after_optimizers_step(self, optimizers) | |||||
| def on_before_zero_grad(self, optimizers): | def on_before_zero_grad(self, optimizers): | ||||
| self.callback_manager.on_before_zero_grad(self, optimizers) | self.callback_manager.on_before_zero_grad(self, optimizers) | ||||
| def on_after_zero_grad(self, optimizers): | |||||
| self.callback_manager.on_after_zero_grad(self, optimizers) | |||||
| def on_validate_begin(self): | def on_validate_begin(self): | ||||
| self.callback_manager.on_validate_begin(self) | self.callback_manager.on_validate_begin(self) | ||||
| @@ -530,14 +530,6 @@ class TorchDDPDriver(TorchDriver): | |||||
| else: | else: | ||||
| raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | ||||
| def backward(self, loss): | |||||
| self.grad_scaler.scale(loss).backward() | |||||
| def step(self): | |||||
| for optimizer in self.optimizers: | |||||
| self.grad_scaler.step(optimizer) | |||||
| self.grad_scaler.update() | |||||
| def is_global_zero(self): | def is_global_zero(self): | ||||
| return self.global_rank == 0 | return self.global_rank == 0 | ||||
| @@ -107,14 +107,6 @@ class TorchSingleDriver(TorchDriver): | |||||
| else: | else: | ||||
| return self._train_step(batch) | return self._train_step(batch) | ||||
| def backward(self, loss): | |||||
| self.grad_scaler.scale(loss).backward() | |||||
| def step(self): | |||||
| for optimizer in self.optimizers: | |||||
| self.grad_scaler.step(optimizer) | |||||
| self.grad_scaler.update() | |||||
| def validate_step(self, batch) -> Dict: | def validate_step(self, batch) -> Dict: | ||||
| # 因为我们 Tester 的逻辑就是将所有的 metric 传给 tester,然后 tester 控制具体 metric 的 update 和 compute;因此不管用户是否 | # 因为我们 Tester 的逻辑就是将所有的 metric 传给 tester,然后 tester 控制具体 metric 的 update 和 compute;因此不管用户是否 | ||||
| # 实现 validate_step 函数,其都应该返回一个字典,具体使用哪些东西则是在 validate_batch_loop 中每一个具体的 metric 自己去拿的; | # 实现 validate_step 函数,其都应该返回一个字典,具体使用哪些东西则是在 validate_batch_loop 中每一个具体的 metric 自己去拿的; | ||||
| @@ -72,6 +72,14 @@ class TorchDriver(Driver): | |||||
| p.grad.requires_grad_(False) | p.grad.requires_grad_(False) | ||||
| p.grad.zero_() | p.grad.zero_() | ||||
| def backward(self, loss): | |||||
| self.grad_scaler.scale(loss).backward() | |||||
| def step(self): | |||||
| for optimizer in self.optimizers: | |||||
| self.grad_scaler.step(optimizer) | |||||
| self.grad_scaler.update() | |||||
| @staticmethod | @staticmethod | ||||
| def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | ||||
| if is_train: | if is_train: | ||||
| @@ -101,12 +101,18 @@ class RecordTrainerEventTriggerCallback(Callback): | |||||
| def on_after_backward(self, trainer): | def on_after_backward(self, trainer): | ||||
| print("on_after_backward") | print("on_after_backward") | ||||
| def on_before_optimizer_step(self, trainer, optimizers): | |||||
| print("on_before_optimizer_step") | |||||
| def on_before_optimizers_step(self, trainer, optimizers): | |||||
| print("on_before_optimizers_step") | |||||
| def on_after_optimizers_step(self, trainer, optimizers): | |||||
| print("on_after_optimizers_step") | |||||
| def on_before_zero_grad(self, trainer, optimizers): | def on_before_zero_grad(self, trainer, optimizers): | ||||
| print("on_before_zero_grad") | print("on_before_zero_grad") | ||||
| def on_after_zero_grad(self, trainer, optimizers): | |||||
| print("on_after_zero_grad") | |||||
| def on_validate_begin(self, trainer): | def on_validate_begin(self, trainer): | ||||
| print("on_validate_begin") | print("on_validate_begin") | ||||