1. f_rich_progress: temporarily stop the live display when no bar is active; 2. change how samplers obtain the dataset length so they work with jittor (tag: v1.0.0alpha)
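All of the sampler hunks below replace direct `len(self.dataset)` calls with `getattr(self.dataset, 'total_len', len(self.dataset))`, presumably because a jittor dataset exposes its sample count through a `total_len` attribute rather than through a plain `len()` call. A minimal sketch of the pattern; the helper name and the toy dataset are illustrative only and not part of this patch:

def dataset_len(dataset) -> int:
    # Prefer a `total_len` attribute (what a jittor dataset is assumed to expose);
    # otherwise fall back to the ordinary len().
    return getattr(dataset, 'total_len', len(dataset))

class JittorLikeDataset:
    # Toy stand-in that reports its sample count only via `total_len`.
    def __init__(self, n: int):
        self.total_len = n

print(dataset_len(JittorLikeDataset(100)))  # -> 100
print(dataset_len([1, 2, 3]))               # -> 3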
@@ -85,7 +85,8 @@ class MoreEvaluateCallback(HasMonitorCallback):
         if watch_monitor is None and evaluate_every is None:
             raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be None at the same time.")
         if watch_monitor is not None and evaluate_every is not None:
-            raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be set at the same time.")
+            raise RuntimeError(f"`evaluate_every`({evaluate_every}) and `watch_monitor`({watch_monitor}) "
+                               f"cannot be set at the same time.")
         if topk_monitor is not None and topk == 0:
             raise RuntimeError("`topk_monitor` is set, but `topk` is 0.")
@@ -36,7 +36,8 @@ class Saver:
                  model_save_fn:Callable=None, **kwargs):
         if folder is None:
             folder = Path.cwd().absolute()
-            logger.info(f"Parameter `folder` is None, and fastNLP will use {folder} to save and load your model.")
+            if save_object is not None:
+                logger.info(f"Parameter `folder` is None, and fastNLP will use {folder} to save and load your model.")
         folder = Path(folder)
         if not folder.exists():
             folder.mkdir(parents=True, exist_ok=True)
@@ -208,7 +209,7 @@ class TopkSaver(ResultsMonitor, Saver):
         if topk is None:
             topk = 0
         ResultsMonitor.__init__(self, monitor, larger_better)
-        Saver.__init__(self, folder, save_object, only_state_dict, model_save_fn, **kwargs)
+        Saver.__init__(self, folder, save_object if topk!=0 else None, only_state_dict, model_save_fn, **kwargs)
         if monitor is not None and topk == 0:
             raise RuntimeError("`monitor` is set, but `topk` is 0.")
@@ -210,7 +210,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
         self.num_consumed_samples = 0
         self.during_iter = True
-        indices = list(range(len(self.dataset)))
+        indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
         if self.shuffle:
             if self.num_consumed_samples > 0:  # first restore the original ordering, then drop the already-consumed extras
@@ -237,7 +237,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
         if len(indices)%self.batch_size!=0:
             batches.append(indices[_num_batches*self.batch_size:])
-        need_pad_num = (len(self.dataset)-self.num_consumed_samples) % self.num_replicas
+        need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas
         if self.pad and need_pad_num !=0 and need_pad_num<=self.rank:
             if len(batches) > 0:
                 if len(batches[-1])<self.batch_size:
@@ -290,9 +290,9 @@ class RandomBatchSampler(ReproducibleBatchSampler):
     @property
     def batch_idx_in_epoch(self):
         if self.drop_last:
-            return len(self.dataset) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
+            return getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
         else:
-            return (len(self.dataset) // self.num_replicas + self.batch_size - 1) // self.batch_size - \
+            return (getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas + self.batch_size - 1) // self.batch_size - \
                 (self.num_left_samples + self.batch_size - 1) // self.batch_size

     @property
@@ -313,8 +313,8 @@ class RandomBatchSampler(ReproducibleBatchSampler):
         :return:
         """
         num_consumed_samples = self.num_consumed_samples
-        return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \
-            self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas))
+        return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \
+            self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas))

     def __len__(self)->int:
         """
| @@ -332,7 +332,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
| raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | ||||
| " consumed. ") | " consumed. ") | ||||
| states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | ||||
| 'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, | |||||
| 'sampler_type': self.__class__.__name__, 'length': getattr(self.dataset, 'total_len', len(self.dataset)), 'shuffle': self.shuffle, | |||||
| 'batch_size': self.batch_size, | 'batch_size': self.batch_size, | ||||
| 'num_replicas': self.num_replicas} | 'num_replicas': self.num_replicas} | ||||
@@ -347,7 +347,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
             f"we cannot use {self.__class__.__name__} to load it."
         length = states['length']
-        assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \
+        assert length == getattr(self.dataset, 'total_len', len(self.dataset)), "The number of samples is different between the checkpoint record " \
                                             "and current dataset."
         self.seed = states['seed']
         self.epoch = states['epoch']
@@ -464,8 +464,8 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         :return:
         """
         num_consumed_samples = self.num_consumed_samples
-        return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \
-            self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas))
+        return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \
+            self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas))

     def __len__(self)->int:
         """
@@ -515,7 +515,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         if len(sorted_indices)%self.batch_size!=0:
             batches.append(sorted_indices[_num_batches*self.batch_size:])
-        need_pad_num = (len(self.dataset)-self.num_consumed_samples) % self.num_replicas
+        need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas
         if self.pad and need_pad_num !=0 and need_pad_num<=self.rank:
             if len(batches) > 0:
                 if len(batches[-1])<self.batch_size:
@@ -593,7 +593,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
             raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been"
                                " consumed. ")
         states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
-                  'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
+                  'sampler_type': self.__class__.__name__, 'length': getattr(self.dataset, 'total_len', len(self.dataset)), 'shuffle': self.shuffle,
                   'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket,
                   'num_replicas': self.num_replicas
                   }
@@ -609,7 +609,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
             f"we cannot use {self.__class__.__name__} to load it."
         length = states['length']
-        assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \
+        assert length == getattr(self.dataset, 'total_len', len(self.dataset)), "The number of samples is different between the checkpoint record " \
                                             "and current dataset."
         self.seed = states['seed']
         self.epoch = states['epoch']
@@ -630,7 +630,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
     @property
     def batch_idx_in_epoch(self):
         if self.drop_last:
-            return len(self.dataset) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
+            return getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
         else:
-            return (len(self.dataset) // self.num_replicas + self.batch_size - 1) // self.batch_size - \
+            return (getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas + self.batch_size - 1) // self.batch_size - \
                 (self.num_left_samples + self.batch_size - 1) // self.batch_size
@@ -131,14 +131,14 @@ class RandomSampler(ReproducibleSampler):
         :return:
         """
         if self.shuffle:
-            indices = list(range(len(self.dataset)))
+            indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
             seed = self.seed + self.epoch
             rng = np.random.default_rng(abs(seed))
             rng.shuffle(indices)
             if self.epoch < 0:  # guard against the user forgetting to call set_epoch; at least this guarantees a different index order each epoch.
                 self.epoch -= 1
         else:
-            indices = list(range(len(self.dataset)))
+            indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
         return indices

     def state_dict(self) -> Dict:
@@ -155,8 +155,8 @@ class RandomSampler(ReproducibleSampler):
             f"we cannot use {self.__class__.__name__} to load it."
         length = states['length']
-        assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \
-                                            f"and current dataset({len(self.dataset)})."
+        assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \
+                                            f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})."
         self.seed = states['seed']
         self.epoch = states['epoch']
         self.num_consumed_samples = states['num_consumed_samples']
@@ -208,8 +208,8 @@ class RandomSampler(ReproducibleSampler):
         :return:
         """
         num_consumed_samples = self.num_consumed_samples
-        return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \
-            self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas))
+        return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \
+            self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas))

 class SequentialSampler(RandomSampler):
@@ -258,11 +258,11 @@ class SequentialSampler(RandomSampler):
         :return:
         """
-        return list(range(len(self.dataset)))
+        return list(range(getattr(self.dataset, 'total_len', len(self.dataset))))

     def state_dict(self) -> Dict:
         states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__,
-                  'length': len(self.dataset)
+                  'length': getattr(self.dataset, 'total_len', len(self.dataset))
                   }
         return states
@@ -275,8 +275,8 @@ class SequentialSampler(RandomSampler):
             f"we cannot use {self.__class__.__name__} to load it."
         length = states['length']
-        assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \
-                                            f"and current dataset({len(self.dataset)})."
+        assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \
+                                            f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})."
         self.num_consumed_samples = states['num_consumed_samples']
         if self.num_consumed_samples >= length:  # if the last sample had already been reached when saving, simply reset to 0
             self.num_consumed_samples = 0
@@ -314,9 +314,9 @@ class SortedSampler(SequentialSampler):
         except BaseException as e:
             logger.error(f"Cannot use {self.__class__.__name__} as length, since it is not sortable.")
-        assert len(length) == len(dataset), f"The length of `dataset`({len(dataset)}) and " \
-                                            f"`length`({len(length)}) should be equal."
-        assert len(self.sorted_indices) == len(dataset), "The indices and dataset should have equal length."
+        assert len(length) == getattr(self.dataset, 'total_len', len(self.dataset)), f"The length of `dataset`({len(dataset)}) and " \
+                                            f"`length`({getattr(self.dataset, 'total_len', len(self.dataset))}) should be equal."
+        assert len(self.sorted_indices) == getattr(self.dataset, 'total_len', len(self.dataset)), "The indices and dataset should have equal length."
         self.length = np.array(length, dtype=int)  # the indices, arranged from longest to shortest.
         self.sorted_indices = np.argsort(self.length)[::-1].tolist()  # sorted by length from longest to shortest
@@ -42,8 +42,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler):
         Returns how many indices one full iteration of the sampler will produce. In the multi-card case, only the current rank is considered;
         :return:
         """
-        num_common = len(self.dataset)//self.num_replicas
-        num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas))
+        num_common = getattr(self.dataset, 'total_len', len(self.dataset))//self.num_replicas
+        num_samples = num_common + int(self.rank < (getattr(self.dataset, 'total_len', len(self.dataset))-num_common*self.num_replicas))
         return num_samples

     def __iter__(self):
@@ -63,14 +63,14 @@ class UnrepeatedRandomSampler(UnrepeatedSampler):
         :return:
         """
         if self.shuffle:
-            indices = list(range(len(self.dataset)))
+            indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
             seed = self.seed + self.epoch
             rng = np.random.default_rng(abs(seed))
             rng.shuffle(indices)
             if self.epoch < 0:  # guard against the user forgetting to call set_epoch; at least this guarantees a different index order each epoch.
                 self.epoch -= 1
         else:
-            indices = list(range(len(self.dataset)))
+            indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
         return indices

     def set_epoch(self, epoch: int) -> None:
@@ -84,8 +84,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler):
         :param rank:
         :return:
         """
-        assert num_replicas<=len(self.dataset), f"The number of replicas({num_replicas}) should be lesser than the " \
-                                                f"number of samples({len(self.dataset)})."
+        assert num_replicas<=getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of replicas({num_replicas}) should be lesser than the " \
+                                                f"number of samples({getattr(self.dataset, 'total_len', len(self.dataset))})."
         assert num_replicas>0 and isinstance(num_replicas, int)
         assert isinstance(rank, int) and 0<=rank<num_replicas
         # Note: at initialization time, all states should default to the state at the very beginning of an epoch;
@@ -147,5 +147,5 @@ class UnrepeatedSequentialSampler(UnrepeatedRandomSampler):
             yield index

     def generate_indices(self) -> List[int]:
-        return list(range(len(self.dataset)))
+        return list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
@@ -149,9 +149,12 @@ class FRichProgress(Progress, metaclass=Singleton):
         super().stop_task(task_id)
         super().remove_task(task_id)
         self.refresh()  # so that no bar is left behind
-        # This had to be commented out because automatic line-wrapping problems appeared when a dataset went through apply several times. It was presumably kept before because the bar would not disappear after evaluate finished.
-        # if len(self._tasks) == 0:
-        #     self.live.stop()
+        if len(self._tasks) == 0:
+            # hack the line function here to prevent a blank line from being printed on stop
+            old_line = getattr(self.live.console, 'line')
+            setattr(self.live.console, 'line', lambda *args,**kwargs:...)
+            self.live.stop()
+            setattr(self.live.console, 'line', old_line)

     def start(self) -> None:
         super().start()
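The new stop_task branch stops the rich Live display once the last task is removed, but first swaps `console.line` for a no-op so that stopping does not print a stray blank line (the earlier, commented-out version had been disabled because of unwanted line breaks when a dataset was applied repeatedly). The same trick in isolation, as a rough sketch against plain rich rather than fastNLP's FRichProgress:

from rich.live import Live

live = Live("working...")
live.start()
# ... progress tasks would run and be removed here ...

# Temporarily silence console.line so that live.stop() leaves no blank line behind.
old_line = live.console.line
live.console.line = lambda *args, **kwargs: None
live.stop()
live.console.line = old_line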