| @@ -390,4 +390,23 @@ class HasMonitorCallback(Callback): | |||
| if (self.larger_better and monitor_value1 > monitor_value2) or \ | |||
| (not self.larger_better and monitor_value1 < monitor_value2): | |||
| better = True | |||
| return better | |||
| return better | |||
| @property | |||
| def monitor_name(self): | |||
| """ | |||
| 返回 monitor 的名字,如果 monitor 是个 callable 的函数,则返回该函数的名称。 | |||
| :return: | |||
| """ | |||
| if callable(self.monitor): | |||
| try: | |||
| monitor_name = self.monitor.__qualname__ | |||
| except: | |||
| monitor_name = self.monitor.__name__ | |||
| elif self.monitor is None: | |||
| return None | |||
| else: | |||
| # 这里只能是monitor,而不能是real_monitor,因为用户再次运行的时候real_monitor被初始化为monitor了 | |||
| monitor_name = str(self.monitor) | |||
| return monitor_name | |||
| @@ -19,11 +19,11 @@ from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | |||
| class CheckpointCallback(HasMonitorCallback): | |||
| def __init__( | |||
| self, | |||
| monitor, | |||
| monitor:Optional[Union[str, Callable]]=None, | |||
| save_folder: Optional[Union[str, Path]] = None, | |||
| save_every_n_epochs: Optional[int] = None, | |||
| save_every_n_batches: Optional[int] = None, | |||
| save_last: bool = True, | |||
| save_last: bool = False, | |||
| save_topk: Optional[int] = None, | |||
| save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None, | |||
| larger_better: bool = True, | |||
| @@ -31,12 +31,32 @@ class CheckpointCallback(HasMonitorCallback): | |||
| model_save_fn: Optional[Callable] = None, | |||
| **kwargs, | |||
| ): | |||
| """ | |||
| 请使用 ModelCheckpointCallback 与 TrainerCheckpointCallback 。 | |||
| :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
| 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||
| 果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
| :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||
| 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||
| :param save_every_n_epochs: 多少个 epoch 保存一次。 | |||
| :param save_every_n_batches: 多少个 batch 保存一次。 | |||
| :param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 | |||
| :param save_topk: 保存 monitor 结果 topK 个。 | |||
| :param save_on_exception: 在出异常信息时,是否保存。传入需要捕获的异常的类。 | |||
| :param larger_better: monitor 的值是否是越大越好。 | |||
| :param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 | |||
| :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||
| 如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||
| :param kwargs: | |||
| """ | |||
| super().__init__(monitor=monitor, larger_better=larger_better, | |||
| must_have_monitor=save_topk is not None) | |||
| if save_folder is None: | |||
| logger.warning( | |||
| "Parameter `path` is None, and we will use the current work directory to find and load your model.") | |||
| save_folder = Path.cwd() | |||
| save_folder = Path(save_folder) | |||
| if not save_folder.exists(): | |||
| raise NotADirectoryError(f"Path '{save_folder.absolute()}' is not existed!") | |||
| elif save_folder.is_file(): | |||
| @@ -71,7 +91,7 @@ class CheckpointCallback(HasMonitorCallback): | |||
| else: | |||
| save_on_exception = [] | |||
| self.save_folder = Path(save_folder) | |||
| self.save_folder = save_folder | |||
| self.save_every_n_epochs = save_every_n_epochs | |||
| self.save_every_n_batches = save_every_n_batches | |||
| self.save_last = save_last | |||
| @@ -88,8 +108,7 @@ class CheckpointCallback(HasMonitorCallback): | |||
| # 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, | |||
| # FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; | |||
| self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | |||
| # 我们只需要保证这个创建文件夹的操作只在进程 0 上进行即可;因为后续的实际的保存操作,其它进程实际并不会去执行; | |||
| synchronize_mkdir(self.timestamp_path) | |||
| # 该 folder 只在保存真的要发生的时候再创建。 | |||
| def on_after_trainer_initialized(self, trainer, driver): | |||
| if self.save_topk is not None: | |||
| @@ -98,8 +117,6 @@ class CheckpointCallback(HasMonitorCallback): | |||
| logger.warning("You set `save_topk`, but `evaluate_dataloaders` is not set in Trainer.") | |||
| def on_validate_end(self, trainer, results): | |||
| if len(results) == 0: | |||
| return | |||
| self._save_topk(trainer, results) | |||
| def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): | |||
| @@ -136,16 +153,17 @@ class CheckpointCallback(HasMonitorCallback): | |||
| states['timestamp_path'] = str(self.timestamp_path.absolute()) | |||
| states['_topk_model'] = deepcopy(self._topk_model) | |||
| states['save_topk'] = 0 if self.save_topk is None else self.save_topk | |||
| states['_real_monitor'] = self._real_monitor | |||
| if isinstance(self._real_monitor, str): | |||
| states['_real_monitor'] = self._real_monitor | |||
| return states | |||
| def on_load_checkpoint(self, trainer, states: Optional[Dict]): | |||
| timestamp_path = states['timestamp_path'] | |||
| if not os.path.exists(timestamp_path): | |||
| logger.info(f"The resuming save folder {timestamp_path} is not exists, will checkpoint save to " | |||
| logger.info(f"The resuming checkpoint folder {timestamp_path} does not exist, will save checkpoint to " | |||
| f" {self.timestamp_path.absolute()}.") | |||
| else: | |||
| logger.info(f"Resume to save in path: {timestamp_path}.") | |||
| logger.info(f"Resume to checkpoint in path: {timestamp_path}.") | |||
| self.timestamp_path = Path(timestamp_path) | |||
| _topk_model = states['_topk_model'] | |||
| save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk']) | |||
| @@ -153,7 +171,8 @@ class CheckpointCallback(HasMonitorCallback): | |||
| assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \ | |||
| f"as {self.save_topk}." | |||
| self._topk_model.update(self._topk_model) | |||
| self._real_monitor = states["real_monitor"] | |||
| self._real_monitor = states["_real_monitor"] | |||
| def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict): | |||
| """ | |||
| @@ -231,9 +250,9 @@ class ModelCheckpointCallback(CheckpointCallback): | |||
| model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 | |||
| 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 | |||
| :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
| 的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型), | |||
| 返回一个 float 值作为 monitor 的结果。 | |||
| :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
| 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||
| 果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
| :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||
| 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||
| :param save_every_n_epochs: 多少个 epoch 保存一次。 | |||
| @@ -249,6 +268,11 @@ class ModelCheckpointCallback(CheckpointCallback): | |||
| """ | |||
| @property | |||
| def save_fn_name(self): | |||
| """ | |||
| 调用 Trainer 中的哪个函数。 | |||
| :return: | |||
| """ | |||
| return 'save_model' | |||
| @property | |||
| @@ -257,7 +281,7 @@ class ModelCheckpointCallback(CheckpointCallback): | |||
| 通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | |||
| :return: | |||
| """ | |||
| return f"model_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||
| return f"model_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||
| @property | |||
| def folder_prefix(self): | |||
| @@ -279,9 +303,9 @@ class TrainerCheckpointCallback(CheckpointCallback): | |||
| model_save_fn 为 None ,则以上每个 folder 中,将生成两个文件:fastnlp_trainer.pkl.tar 以及 fastnlp_model.pkl.tar 。 | |||
| 若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 | |||
| :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
| 的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型), | |||
| 返回一个 float 值作为 monitor 的结果。 | |||
| :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
| 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||
| 果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
| :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||
| 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | |||
| :param save_every_n_epochs: 多少个 epoch 保存一次。 | |||
| @@ -297,6 +321,11 @@ class TrainerCheckpointCallback(CheckpointCallback): | |||
| """ | |||
| @property | |||
| def save_fn_name(self): | |||
| """ | |||
| 调用 Trainer 中的哪个函数。 | |||
| :return: | |||
| """ | |||
| return 'save' | |||
| @property | |||
| @@ -305,7 +334,8 @@ class TrainerCheckpointCallback(CheckpointCallback): | |||
| 通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | |||
| :return: | |||
| """ | |||
| return f"trainer_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||
| return f"trainer_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" | |||
| @property | |||
| def folder_prefix(self): | |||
| @@ -12,8 +12,9 @@ class EarlyStopCallback(HasMonitorCallback): | |||
| def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10): | |||
| """ | |||
| :param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 | |||
| evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
| :param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
| 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||
| 果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
| :param larger_better: monitor 的值是否是越大越好。 | |||
| :param patience: 多少次 validate 没有提升就停止。 | |||
| """ | |||
| @@ -46,17 +47,20 @@ class EarlyStopCallback(HasMonitorCallback): | |||
| states = { | |||
| 'patience': self.patience, | |||
| 'wait': self.wait, | |||
| 'monitor': self.monitor, | |||
| 'monitor_value': self.monitor_value | |||
| } | |||
| if not callable(self._real_monitor): | |||
| states['_real_monitor'] = self._real_monitor | |||
| return states | |||
| def on_load_checkpoint(self, trainer, states): | |||
| self.patience = states['patience'] | |||
| self.wait = states['wait'] | |||
| self.monitor = states['monitor'] | |||
| self.monitor_value = float(states['monitor_value']) | |||
| if '_real_monitor' in states: | |||
| self._real_monitor = states['_real_monitor'] | |||
| @property | |||
| def callback_name(self): | |||
| return f'EarlyStopCallback#monitor-{self.monitor}#patience-{self.patience}' | |||
| return f'EarlyStopCallback#monitor-{self.monitor_name}#patience-{self.patience}' | |||
| @@ -21,8 +21,9 @@ class LoadBestModelCallback(HasMonitorCallback): | |||
| """ | |||
| 保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 | |||
| :param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 | |||
| evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
| :param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||
| 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||
| 果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
| :param larger_better: 该 metric 值是否是越大越好。 | |||
| :param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | |||
| 不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | |||
| @@ -44,10 +44,11 @@ class RichCallback(ProgressCallback): | |||
| :param print_every: 多少个 batch 更新一次显示。 | |||
| :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||
| :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。如果为 None ,会尝试使用 trainer 中设置的 monitor 。 | |||
| 也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
| :param larger_better: 是否是monitor的结果越大越好。 | |||
| :param format_json: 是否format json再打印 | |||
| :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | |||
| 完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | |||
| 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
| :param larger_better: 是否是 monitor 的结果越大越好。 | |||
| :param format_json: 是否格式化 json 再打印 | |||
| """ | |||
| super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) | |||
| self.print_every = print_every | |||
| @@ -136,8 +137,9 @@ class RawTextCallback(ProgressCallback): | |||
| :param print_every: 多少个 batch 更新一次显示。 | |||
| :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||
| :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。也可以传入一个函数,接受参数为 evaluation 的结果( | |||
| 字典类型),返回一个 float 值作为 monitor 的结果。 | |||
| :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | |||
| 完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | |||
| 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||
| :param larger_better: 是否是monitor的结果越大越好。 | |||
| :param format_json: 是否format json再打印 | |||
| """ | |||
| @@ -36,7 +36,7 @@ class Evaluator: | |||
| model, | |||
| dataloaders, | |||
| metrics: Optional[Union[Dict, Metric]] = None, | |||
| driver: Union[str, Driver] = 'single', | |||
| driver: Union[str, Driver] = 'torch', | |||
| device: Optional[Union[int, List[int], str]] = None, | |||
| batch_step_fn: Optional[callable] = None, | |||
| evaluate_fn: Optional[str] = None, # 首先尝试找 evaluate_step, 找不到 forward, callable | |||
| @@ -49,8 +49,8 @@ class Evaluator: | |||
| ): | |||
| """ | |||
| :param dataloaders: | |||
| :param model: | |||
| :param dataloaders: | |||
| :param metrics: 使用的 metric 。必须为 dict 类型,其中 key 为 metric 的名称,value 为一个 Metric 对象。支持 fastNLP 的 | |||
| metric ,torchmetrics,allennlpmetrics等。 | |||
| :param driver: 使用 driver 。 | |||
| @@ -120,7 +120,8 @@ class Evaluator: | |||
| if evaluate_fn is not None and not isinstance(evaluate_fn, str): | |||
| raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") | |||
| self._evaluate_step, self._evaluate_step_signature_fn = self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn) | |||
| self._evaluate_step, self._evaluate_step_signature_fn = \ | |||
| self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn) | |||
| self.evaluate_fn = evaluate_fn | |||
| self.dataloaders = {} | |||
| @@ -134,8 +135,6 @@ class Evaluator: | |||
| self.driver.barrier() | |||
| self.driver.check_dataloader_legality(self.dataloaders, "dataloaders", is_train=False) | |||
| def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: | |||
| """ | |||
| 返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 | |||
| @@ -20,7 +20,7 @@ class TrainBatchLoop(Loop): | |||
| else lambda *args, **kwargs: None | |||
| dataloader = iter(dataloader) | |||
| indices = None | |||
| while True: | |||
| while trainer.batch_idx_in_epoch<=trainer.num_batches_per_epoch: | |||
| try: | |||
| trainer.on_fetch_data_begin() | |||
| batch = next(dataloader) | |||
| @@ -30,10 +30,8 @@ class TrainBatchLoop(Loop): | |||
| batch = trainer.move_data_to_device(batch) | |||
| except StopIteration: | |||
| break | |||
| except EarlyStopException: # 在 Trainer 处理 earlystop 的 exception | |||
| break | |||
| except BaseException as e: | |||
| if indices: | |||
| if indices and not isinstance(e, EarlyStopException): | |||
| logger.debug(f"The following exception happens when running on samples: {indices}") | |||
| raise e | |||
| @@ -264,7 +264,6 @@ class Trainer(TrainerEventTrigger): | |||
| self.on_after_trainer_initialized(self.driver) | |||
| self.driver.barrier() | |||
| self.driver.check_dataloader_legality(self.train_dataloader, "train_dataloader", is_train=True) | |||
| def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, | |||
| num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, | |||
| @@ -310,7 +309,7 @@ class Trainer(TrainerEventTrigger): | |||
| self.num_batches_per_epoch = len(self.dataloader) | |||
| self.total_batches = self.num_batches_per_epoch * self.n_epochs | |||
| self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch | |||
| self.on_train_begin() | |||
| self.driver.barrier() | |||
| self.driver.zero_grad(self.set_grad_to_none) | |||
| @@ -637,6 +636,8 @@ class Trainer(TrainerEventTrigger): | |||
| :param folder: 保存断点重训 states 的文件地址; | |||
| :param resume_training: 是否从上次的 batch 开始训练,或者只从最近的 epoch 开始训练;注意如果 resume_training=True,那么我们 | |||
| 只会加载 model 和 optimizers 的状态;而其余的对象的值则根据用户的 Trainer 的初始化直接重置; | |||
| :param only_state_dict: 保存的 model 是否只包含了权重。 | |||
| :param model_load_fn: 使用的模型加载函数,参数应为一个 文件夹,不返回任何内容。 | |||
| """ | |||
| self.driver.barrier() | |||
| if isinstance(folder, str): | |||
| @@ -675,8 +676,6 @@ class Trainer(TrainerEventTrigger): | |||
| # 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 | |||
| # '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 | |||
| self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch') | |||
| self.trainer_state.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + \ | |||
| self.batch_idx_in_epoch | |||
| # 这个是防止用户在 Trainer.load 之后还没结束当前 epoch 又继续 save | |||
| self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | |||
| @@ -65,10 +65,10 @@ class TrainerState: | |||
| """ | |||
| n_epochs: Optional[int] = None # 无论如何重新算 | |||
| cur_epoch_idx: Optional[int] = None # 断点重训; 仅当 resume=False 时为0; | |||
| global_forward_batches: Optional[int] = None # 断点重训 | |||
| cur_epoch_idx: Optional[int] = 0 # 断点重训; 仅当 resume=False 时为0; | |||
| global_forward_batches: Optional[int] = 0 # 断点重训 | |||
| batch_idx_in_epoch: Optional[int] = None # 断点重训 | |||
| batch_idx_in_epoch: Optional[int] = 0 # 断点重训 | |||
| num_batches_per_epoch: Optional[int] = None # 无论如何重新算 | |||
| @@ -86,7 +86,7 @@ class Driver(ABC): | |||
| 函数; | |||
| :param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | |||
| :param fn: 由 Trainer 传入的用于网络前向传播一次的函数; | |||
| :param fn: 调用该函数进行一次计算。 | |||
| :param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call | |||
| 函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward; | |||
| :return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | |||
| @@ -126,17 +126,6 @@ class Driver(ABC): | |||
| def model(self, model): | |||
| self._model = model | |||
| @staticmethod | |||
| def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||
| r""" | |||
| 该函数会在 trainer 或者 evaluator 设置 dataloader 后检测 dataloader 的合法性,因为不同的深度学习的框架需要的 dataloader 的 | |||
| 行为是不相同的; | |||
| :param dataloader: 需要检测的输入的 `dataloader`; | |||
| :param dataloader_name: | |||
| """ | |||
| raise NotImplementedError("Each specific driver should implemented its own `check_dataloader_legality` function.") | |||
| @property | |||
| def optimizers(self) -> List: | |||
| r""" | |||
| @@ -34,10 +34,10 @@ if _NEED_IMPORT_PADDLE: | |||
| from paddle.optimizer import Optimizer | |||
| _reduces = { | |||
| 'max': paddle.max, | |||
| 'min': paddle.min, | |||
| 'mean': paddle.mean, | |||
| 'sum': paddle.sum | |||
| "max": paddle.max, | |||
| "min": paddle.min, | |||
| "mean": paddle.mean, | |||
| "sum": paddle.sum | |||
| } | |||
| class PaddleDriver(Driver): | |||
| @@ -254,24 +254,24 @@ class PaddleDriver(Driver): | |||
| else: | |||
| raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | |||
| num_consumed_batches = states.pop('num_consumed_batches') | |||
| if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | |||
| num_consumed_batches = states.pop("num_consumed_batches") | |||
| if hasattr(sampler, "state_dict") and callable(sampler.state_dict): | |||
| sampler_states = sampler.state_dict() | |||
| # 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||
| # 会造成多余实际消耗的问题。 | |||
| num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||
| # 会造成多余实际消耗的问题。 | |||
| num_consumed_samples_array = sampler_states.pop("num_consumed_samples_array", None) | |||
| if num_consumed_samples_array is not None: | |||
| if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||
| try: | |||
| num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||
| except: # 有可能 batch_size 为 None,就只有损失精度了 | |||
| num_consumed_batches = sampler_states['num_consumed_samples'] | |||
| sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||
| assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||
| sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] | |||
| else: | |||
| try: | |||
| sampler_states["num_consumed_samples"] = num_consumed_batches * dataloader_args.batch_size | |||
| except: # 有可能 batch_size 为 None,就只有损失精度了 | |||
| pass | |||
| assert sampler_states["num_consumed_samples"] != -1, "This is a bug, please report." | |||
| else: | |||
| raise RuntimeError( | |||
| 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | |||
| "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") | |||
| states["sampler_states"] = sampler_states | |||
| # 2. 保存模型的状态; | |||
| if should_save_model: | |||
| @@ -326,7 +326,7 @@ class PaddleDriver(Driver): | |||
| batch_size=dataloader_args.batch_size, | |||
| drop_last=dataloader_args.drop_last | |||
| ) | |||
| sampler.load_state_dict(states['sampler_states']) | |||
| sampler.load_state_dict(states["sampler_states"]) | |||
| states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) | |||
| # 4. 修改 trainer_state.batch_idx_in_epoch | |||
| @@ -355,7 +355,7 @@ class PaddleDriver(Driver): | |||
| return paddle.no_grad | |||
| @staticmethod | |||
| def move_model_to_device(model: 'paddle.nn.Layer', device: Union[str, int, 'paddle.CUDAPlace', 'paddle.CPUPlace']): | |||
| def move_model_to_device(model: "paddle.nn.Layer", device: Union[str, int, "paddle.CUDAPlace", "paddle.CPUPlace"]): | |||
| r""" | |||
| 用来将模型转移到指定的 device 上; | |||
| 在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 | |||
| @@ -363,7 +363,7 @@ class PaddleDriver(Driver): | |||
| if device is not None: | |||
| model.to(device) | |||
| def move_data_to_device(self, batch: 'paddle.Tensor'): | |||
| def move_data_to_device(self, batch: "paddle.Tensor"): | |||
| r""" | |||
| 将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。 | |||
| 在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 | |||
| @@ -404,7 +404,7 @@ class PaddleDriver(Driver): | |||
| if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None: | |||
| dataloader.worker_init_fn = partial(self.worker_init_function, rank=self.global_rank) | |||
| def set_sampler_epoch(self, dataloader: 'DataLoader', cur_epoch_idx): | |||
| def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx): | |||
| r""" | |||
| 对于分布式的 sampler,dataloader 需要在每一个 epoch 前设置随机数种子,来保证每一个进程上的 shuffle 是一样的; | |||
| @@ -406,7 +406,7 @@ class TorchDDPDriver(TorchDriver): | |||
| if hasattr(model, fn): | |||
| fn = getattr(model, fn) | |||
| if not callable(fn): | |||
| raise RuntimeError(f"The `{fn}` attribute is not `Callable`.") | |||
| raise RuntimeError(f"The `{fn}` attribute of model is not `Callable`.") | |||
| return fn, None | |||
| elif fn in {"train_step", "evaluate_step"}: | |||
| return model, model.forward | |||
| @@ -199,6 +199,7 @@ class TorchDriver(Driver): | |||
| num_consumed_batches = sampler_states['num_consumed_samples'] | |||
| sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||
| assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||
| states['sampler_states'] = sampler_states | |||
| else: | |||
| raise RuntimeError( | |||
| 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | |||
| @@ -1,19 +1,20 @@ | |||
| import pytest | |||
| import os | |||
| os.environ["FASTNLP_BACKEND"] = "paddle" | |||
| from typing import Any | |||
| from dataclasses import dataclass | |||
| from paddle.optimizer import Adam | |||
| from paddle.io import DataLoader | |||
| from fastNLP.core.controllers.trainer import Trainer | |||
| from fastNLP.core.metrics.accuracy import Accuracy | |||
| from fastNLP.core.callbacks.progress_callback import RichCallback | |||
| from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK | |||
| from paddle.optimizer import Adam | |||
| from paddle.io import DataLoader | |||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification | |||
| from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST | |||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||
| from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset | |||
| from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback | |||
| from tests.helpers.utils import magic_argv_env_context | |||
| @@ -48,64 +49,31 @@ class TrainerParameters: | |||
| output_mapping: Any = None | |||
| metrics: Any = None | |||
| # @pytest.fixture(params=[0], autouse=True) | |||
| # def model_and_optimizers(request): | |||
| # """ | |||
| # 初始化单卡模式的模型和优化器 | |||
| # """ | |||
| # trainer_params = TrainerParameters() | |||
| # print(paddle.device.get_device()) | |||
| # if request.param == 0: | |||
| # trainer_params.model = PaddleNormalModel_Classification( | |||
| # num_labels=MNISTTrainPaddleConfig.num_labels, | |||
| # feature_dimension=MNISTTrainPaddleConfig.feature_dimension | |||
| # ) | |||
| # trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001) | |||
| # train_dataloader = DataLoader( | |||
| # dataset=PaddleDataset_MNIST("train"), | |||
| # batch_size=MNISTTrainPaddleConfig.batch_size, | |||
| # shuffle=True | |||
| # ) | |||
| # val_dataloader = DataLoader( | |||
| # dataset=PaddleDataset_MNIST(evaluate_fn="test"), | |||
| # batch_size=MNISTTrainPaddleConfig.batch_size, | |||
| # shuffle=True | |||
| # ) | |||
| # trainer_params.train_dataloader = train_dataloader | |||
| # trainer_params.evaluate_dataloaders = val_dataloader | |||
| # trainer_params.evaluate_every = MNISTTrainPaddleConfig.evaluate_every | |||
| # trainer_params.metrics = {"acc": Accuracy()} | |||
| # return trainer_params | |||
| @pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)]) | |||
| @pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)]) | |||
| # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])]) | |||
| @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.7, larger_better=True), | |||
| RichCallback(5), RecordLossCallback(loss_threshold=0.3)]]) | |||
| @magic_argv_env_context | |||
| def test_trainer_paddle( | |||
| # model_and_optimizers: TrainerParameters, | |||
| driver, | |||
| device, | |||
| callbacks, | |||
| n_epochs=15, | |||
| n_epochs=2, | |||
| ): | |||
| trainer_params = TrainerParameters() | |||
| trainer_params.model = PaddleNormalModel_Classification( | |||
| trainer_params.model = PaddleNormalModel_Classification_1( | |||
| num_labels=MNISTTrainPaddleConfig.num_labels, | |||
| feature_dimension=MNISTTrainPaddleConfig.feature_dimension | |||
| ) | |||
| trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001) | |||
| train_dataloader = DataLoader( | |||
| dataset=PaddleDataset_MNIST("train"), | |||
| dataset=PaddleRandomMaxDataset(6400, 10), | |||
| batch_size=MNISTTrainPaddleConfig.batch_size, | |||
| shuffle=True | |||
| ) | |||
| val_dataloader = DataLoader( | |||
| dataset=PaddleDataset_MNIST(mode="test"), | |||
| dataset=PaddleRandomMaxDataset(1000, 10), | |||
| batch_size=MNISTTrainPaddleConfig.batch_size, | |||
| shuffle=True | |||
| ) | |||
| @@ -113,39 +81,19 @@ def test_trainer_paddle( | |||
| trainer_params.validate_dataloaders = val_dataloader | |||
| trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every | |||
| trainer_params.metrics = {"acc": Accuracy(backend="paddle")} | |||
| if not isinstance(device, (int, str)) and len(device) > 1 and FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||
| with pytest.raises(SystemExit) as exc: | |||
| trainer = Trainer( | |||
| model=trainer_params.model, | |||
| driver=driver, | |||
| device=device, | |||
| optimizers=trainer_params.optimizers, | |||
| train_dataloader=trainer_params.train_dataloader, | |||
| evaluate_dataloaders=trainer_params.validate_dataloaders, | |||
| evaluate_every=trainer_params.validate_every, | |||
| input_mapping=trainer_params.input_mapping, | |||
| output_mapping=trainer_params.output_mapping, | |||
| metrics=trainer_params.metrics, | |||
| n_epochs=n_epochs, | |||
| callbacks=callbacks, | |||
| ) | |||
| assert exc.value.code == 0 | |||
| return | |||
| else: | |||
| trainer = Trainer( | |||
| model=trainer_params.model, | |||
| driver=driver, | |||
| device=device, | |||
| optimizers=trainer_params.optimizers, | |||
| train_dataloader=trainer_params.train_dataloader, | |||
| evaluate_dataloaders=trainer_params.validate_dataloaders, | |||
| evaluate_every=trainer_params.validate_every, | |||
| input_mapping=trainer_params.input_mapping, | |||
| output_mapping=trainer_params.output_mapping, | |||
| metrics=trainer_params.metrics, | |||
| n_epochs=n_epochs, | |||
| callbacks=callbacks, | |||
| ) | |||
| trainer.run() | |||
| trainer = Trainer( | |||
| model=trainer_params.model, | |||
| driver=driver, | |||
| device=device, | |||
| optimizers=trainer_params.optimizers, | |||
| train_dataloader=trainer_params.train_dataloader, | |||
| validate_dataloaders=trainer_params.validate_dataloaders, | |||
| validate_every=trainer_params.validate_every, | |||
| input_mapping=trainer_params.input_mapping, | |||
| output_mapping=trainer_params.output_mapping, | |||
| metrics=trainer_params.metrics, | |||
| n_epochs=n_epochs, | |||
| callbacks=callbacks, | |||
| ) | |||
| trainer.run() | |||
| @@ -224,7 +224,6 @@ class TestSetDistReproDataloder: | |||
| """ | |||
| def setup_method(self): | |||
| self.dataset = PaddleNormalDataset(20) | |||
| self.dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) | |||
| model = PaddleNormalModel_Classification_1(10, 32) | |||
| self.driver = PaddleSingleDriver(model, device="cpu") | |||
| @@ -233,55 +232,59 @@ class TestSetDistReproDataloder: | |||
| 测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现 | |||
| 当dist为字符串时,此时应该返回原来的 dataloader | |||
| """ | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=False) | |||
| dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||
| assert replaced_loader is self.dataloader | |||
| assert replaced_loader is dataloader | |||
| def test_set_dist_repro_dataloader_with_reproducible_true(self): | |||
| """ | |||
| 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 | |||
| 当dist为字符串时,此时应该返回新的 dataloader,且 batch_sampler 为 RandomBatchSampler | |||
| """ | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=True) | |||
| dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | |||
| assert not (replaced_loader is self.dataloader) | |||
| assert not (replaced_loader is dataloader) | |||
| assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||
| assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) | |||
| assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size | |||
| assert replaced_loader.drop_last == self.dataloader.drop_last | |||
| assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||
| assert replaced_loader.drop_last == dataloader.drop_last | |||
| # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) | |||
| # self.check_set_dist_repro_dataloader(dataloader, replaced_loader) | |||
| def test_set_dist_repro_dataloader_with_dist_batch_sampler(self): | |||
| """ | |||
| 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler | |||
| 应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler | |||
| """ | |||
| dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) | |||
| dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False) | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False) | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | |||
| assert not (replaced_loader is self.dataloader) | |||
| assert not (replaced_loader is dataloader) | |||
| assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||
| assert replaced_loader.batch_sampler is dist | |||
| # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) | |||
| self.check_set_dist_repro_dataloader(dataloader, replaced_loader) | |||
| def test_set_dist_repro_dataloader_with_dist_sampler(self): | |||
| """ | |||
| 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现 | |||
| 应该返回新的 dataloader,并将 batch_sampler.sampler 替换为 dist 对应的 Sampler | |||
| """ | |||
| dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) | |||
| dist = RandomSampler(self.dataset, shuffle=True) | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist=dist, reproducible=False) | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | |||
| assert not (replaced_loader is self.dataloader) | |||
| assert not (replaced_loader is dataloader) | |||
| assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||
| assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||
| assert not (replaced_loader.batch_sampler is self.dataloader.batch_sampler) | |||
| assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||
| assert replaced_loader.batch_sampler.sampler is dist | |||
| assert replaced_loader.batch_sampler.batch_size == self.dataloader.batch_sampler.batch_size | |||
| assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||
| # self.check_set_dist_repro_dataloader(self.dataloader, replaced_loader) | |||
| self.check_set_dist_repro_dataloader(dataloader, replaced_loader) | |||
| def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self): | |||
| """ | |||
| @@ -295,11 +298,12 @@ class TestSetDistReproDataloder: | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||
| assert not (replaced_loader is dataloader) | |||
| assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||
| assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||
| assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||
| assert replaced_loader.drop_last == dataloader.drop_last | |||
| # self.check_set_dist_repro_dataloader(dataloader, replaced_loader) | |||
| self.check_set_dist_repro_dataloader(dataloader, replaced_loader) | |||
| def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self): | |||
| """ | |||
| @@ -316,34 +320,52 @@ class TestSetDistReproDataloder: | |||
| assert not (replaced_loader is dataloader) | |||
| assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||
| assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||
| assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | |||
| assert replaced_loader.batch_sampler.batch_size == 2 | |||
| assert replaced_loader.batch_sampler.sampler.shuffle == True | |||
| # self.check_set_dist_repro_dataloader(dataloader, replaced_loader) | |||
| self.check_set_dist_repro_dataloader(dataloader, replaced_loader) | |||
| def check_set_dist_repro_dataloader(self, dataloader, replaced_loader): | |||
| """ | |||
| 测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确 | |||
| """ | |||
| # 迭代两个 batch | |||
| # 这里会发生 BatchSampler 里 yield 了多次但 dataloader 只取出一次的情况。 | |||
| num_consumed_batches = 2 | |||
| already_seen_idx = set() | |||
| for idx, batch in replaced_loader: | |||
| already_seen_idx.update(batch) | |||
| if idx >= 1: | |||
| for idx, batch in enumerate(replaced_loader): | |||
| if idx >= num_consumed_batches: | |||
| break | |||
| already_seen_idx.update(batch) | |||
| if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | |||
| sampler_states = replaced_loader.batch_sampler.state_dict() | |||
| else: | |||
| sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | |||
| print(sampler_states["data_idx"]) | |||
| # 加载 num_consumed_samples_array,设置正确取出的 batch 数目 | |||
| num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||
| import time | |||
| time.sleep(5) | |||
| # 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range | |||
| left_idxes = set() | |||
| if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | |||
| batch_size = replaced_loader.batch_sampler.batch_size | |||
| if num_consumed_samples_array is not None: | |||
| sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] | |||
| else: | |||
| sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | |||
| replaced_loader.batch_sampler.load_state_dict(sampler_states) | |||
| else: | |||
| batch_size = replaced_loader.batch_sampler.batch_size | |||
| if num_consumed_samples_array is not None: | |||
| sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] | |||
| else: | |||
| sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | |||
| replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states) | |||
| replaced_loader.batch_sampler.sampler.set_epoch(0) | |||
| for idx, batch in enumerate(replaced_loader): | |||
| left_idxes.update(batch) | |||
| @@ -181,6 +181,7 @@ class TestCheckNumberOfParameters: | |||
| def test_get_fun_msg(): | |||
| # 测试运行 | |||
| def demo(x): | |||
| pass | |||
| @@ -1,3 +1,6 @@ | |||
| import numpy as np | |||
| class NormalIterator: | |||
| def __init__(self, num_of_data=1000): | |||
| self._num_of_data = num_of_data | |||
| @@ -15,4 +18,15 @@ class NormalIterator: | |||
| return self._data | |||
| def __len__(self): | |||
| return self._num_of_data | |||
| return self._num_of_data | |||
| class RandomDataset: | |||
| def __init__(self, num_data=10): | |||
| self.data = np.random.rand(num_data) | |||
| def __len__(self): | |||
| return len(self.data) | |||
| def __getitem__(self, item): | |||
| return self.data[item] | |||