| @@ -244,17 +244,18 @@ class PaddleDriver(Driver): | |||||
| if hasattr(sampler, "state_dict") and callable(sampler.state_dict): | if hasattr(sampler, "state_dict") and callable(sampler.state_dict): | ||||
| sampler_states = sampler.state_dict() | sampler_states = sampler.state_dict() | ||||
| # 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | # 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | ||||
| # 会造成多余实际消耗的问题。 | |||||
| num_consumed_samples_array = sampler_states.pop("num_consumed_samples_array", None) | |||||
| # 会造成多余实际消耗的问题。 | |||||
| num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||||
| if num_consumed_samples_array is not None: | if num_consumed_samples_array is not None: | ||||
| sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] | |||||
| else: | |||||
| try: | |||||
| sampler_states["num_consumed_samples"] = num_consumed_batches * dataloader_args.batch_size | |||||
| except: # 有可能 batch_size 为 None,就只有损失精度了 | |||||
| pass | |||||
| assert sampler_states["num_consumed_samples"] != -1, "This is a bug, please report." | |||||
| states["sampler_states"] = sampler_states | |||||
| if isinstance(sampler, ReproducibleSampler): | |||||
| # 如果是 sampler 的话,需要计算出实际的 sample 数目 | |||||
| try: | |||||
| num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||||
| except: # 有可能 batch_size 为 None,就只有损失精度了 | |||||
| num_consumed_batches = sampler_states['num_consumed_samples'] | |||||
| sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||||
| assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||||
| states['sampler_states'] = sampler_states | |||||
| else: | else: | ||||
| raise RuntimeError( | raise RuntimeError( | ||||
| "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") | "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") | ||||