| @@ -19,7 +19,7 @@ class LRSchedCallback(Callback): | |||||
| self.scheduler = scheduler | self.scheduler = scheduler | ||||
| self.step_on = 0 if step_on == 'batch' else 1 | self.step_on = 0 if step_on == 'batch' else 1 | ||||
| def on_before_optimizers_step(self, trainer, optimizers): | |||||
| def on_after_optimizers_step(self, trainer, optimizers): | |||||
| if self.step_on == 0: | if self.step_on == 0: | ||||
| self.scheduler.step() | self.scheduler.step() | ||||
| @@ -178,19 +178,11 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler | |||||
| # 中寻找;VAR_KEYWORD 代表 **kwargs | # 中寻找;VAR_KEYWORD 代表 **kwargs | ||||
| has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | ||||
| if has_variadic_kwargs: | if has_variadic_kwargs: | ||||
| init_params.update(dict(inspect.signature(DataLoader.__init__).parameters)) | |||||
| del init_params["self"] | |||||
| # 因为我们刚才可能用 DataLoader 的默认参数将用户定制的 dataloader 的参数覆盖掉了,因此需要重新弄一遍; | |||||
| # 将同时在实例名和参数名中出现且不是默认值的参数收集起来 | |||||
| non_default_params = {name for name, p in init_params.items() if | |||||
| name in instance_attrs and p.default != instance_attrs[name]} | |||||
| # add `dataset` as it might have been replaced with `*args` | |||||
| non_default_params.add("dataset") | |||||
| # 收集不是默认值的参数和它的值 | |||||
| reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} | |||||
| # persistent_workers 在类中的对应成员带有下划线,因此添加进来 | |||||
| for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items(): | |||||
| if key not in init_params and key != 'self': | |||||
| init_params[key] = value | |||||
| reconstruct_args = {k: v for k, v in instance_attrs.items() if k in init_params} | |||||
| reconstruct_args.update({ | reconstruct_args.update({ | ||||
| "batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1, | "batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1, | ||||
| "persistent_workers": dataloader._persistent_workers, | "persistent_workers": dataloader._persistent_workers, | ||||
| @@ -189,16 +189,11 @@ def replace_sampler(dataloader: "DataLoader", sampler): | |||||
| # 中寻找; | # 中寻找; | ||||
| has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | ||||
| if has_variadic_kwargs: | if has_variadic_kwargs: | ||||
| init_params.update(dict(inspect.signature(DataLoader.__init__).parameters)) | |||||
| del init_params["self"] | |||||
| for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items(): | |||||
| if key not in init_params and key != 'self': | |||||
| init_params[key] = value | |||||
| # 因为我们刚才可能用 DataLoader 的默认参数将用户定制的 dataloader 的参数覆盖掉了,因此需要重新弄一遍; | |||||
| non_default_params = {name for name, p in init_params.items() if | |||||
| name in instance_attrs and p.default != instance_attrs[name]} | |||||
| # add `dataset` as it might have been replaced with `*args` | |||||
| non_default_params.add("dataset") | |||||
| reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} | |||||
| reconstruct_args = {k: v for k, v in instance_attrs.items() if k in init_params} | |||||
| reconstruct_args.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler)) | reconstruct_args.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler)) | ||||
| required_args = { | required_args = { | ||||