| @@ -24,7 +24,6 @@ from fastNLP.core.dataset import DataSet as FDataSet | |||||
| class _JittorDataset(Dataset): | class _JittorDataset(Dataset): | ||||
| """ | """ | ||||
| 对用户传的dataset进行封装,以便JittorDataLoader能够支持使用自定义的dataset | 对用户传的dataset进行封装,以便JittorDataLoader能够支持使用自定义的dataset | ||||
| """ | """ | ||||
| def __init__(self, dataset) -> None: | def __init__(self, dataset) -> None: | ||||
| @@ -83,7 +82,7 @@ class JittorDataLoader: | |||||
| # TODO 验证支持replacesampler (以后完成) 增加Sampler | # TODO 验证支持replacesampler (以后完成) 增加Sampler | ||||
| # 将内部dataset批次设置为1 | # 将内部dataset批次设置为1 | ||||
| if isinstance(dataset, Dataset): | if isinstance(dataset, Dataset): | ||||
| dataset.set_attrs(batch_size=1) | |||||
| dataset.set_attrs(batch_size=1, shuffle=False, endless=False) | |||||
| # FastNLP Datset, collate_fn not None | # FastNLP Datset, collate_fn not None | ||||
| if isinstance(dataset, FDataSet) and collate_fn is None: | if isinstance(dataset, FDataSet) and collate_fn is None: | ||||
| @@ -115,6 +114,12 @@ class JittorDataLoader: | |||||
| self.cur_batch_indices = None | self.cur_batch_indices = None | ||||
| def __getattr__(self, attr): | |||||
| if attr in ["batch_size", "shuffle", "drop_last", "num_workers", "buffer_size", "stop_grad", | |||||
| "keep_numpy_array", "endless", "sampler"]: | |||||
| return getattr(self.dataset, attr) | |||||
| raise AttributeError(f"{self} has not attribute '{attr}'") | |||||
| def __iter__(self): | def __iter__(self): | ||||
| # TODO 第一次迭代后不能设置collate_fn,设置是无效的 | # TODO 第一次迭代后不能设置collate_fn,设置是无效的 | ||||
| if self.cur_batch_indices is None: | if self.cur_batch_indices is None: | ||||
| @@ -10,7 +10,7 @@ if _NEED_IMPORT_JITTOR: | |||||
| __all__ = [] | __all__ = [] | ||||
| def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], model: jittor.Module, **kwargs) -> JittorDriver: | |||||
| def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], model: "jittor.Module", **kwargs) -> JittorDriver: | |||||
| r""" | r""" | ||||
| 用来根据参数 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例然后返回回去。 | 用来根据参数 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例然后返回回去。 | ||||
| @@ -30,7 +30,7 @@ def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], mo | |||||
| raise ValueError("Parameter `driver` can only be one of these values: ['jittor'].") | raise ValueError("Parameter `driver` can only be one of these values: ['jittor'].") | ||||
| # TODO 实现更详细的判断 | # TODO 实现更详细的判断 | ||||
| if device in ["cpu", "gpu", "cuda", "cuda:0", 0, None]: | |||||
| if device in ["cpu", "gpu", "cuda", None]: | |||||
| return JittorSingleDriver(model, device, **kwargs) | return JittorSingleDriver(model, device, **kwargs) | ||||
| elif type(device) is int: | elif type(device) is int: | ||||
| return JittorMPIDriver(model, device, **kwargs) | return JittorMPIDriver(model, device, **kwargs) | ||||
| @@ -1,23 +1,31 @@ | |||||
| import os | import os | ||||
| import random | |||||
| from pathlib import Path | from pathlib import Path | ||||
| from typing import Union, Optional | |||||
| from functools import partial | |||||
| import numpy as np | |||||
| from typing import Union, Optional, Dict | |||||
| from contextlib import nullcontext | |||||
| from dataclasses import dataclass | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | ||||
| from fastNLP.core.drivers.driver import Driver | from fastNLP.core.drivers.driver import Driver | ||||
| from fastNLP.core.dataloaders import JittorDataLoader | from fastNLP.core.dataloaders import JittorDataLoader | ||||
| from fastNLP.core.samplers import ReproducibleSampler, RandomSampler | |||||
| from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
| from fastNLP.core.utils import apply_to_collection | from fastNLP.core.utils import apply_to_collection | ||||
| from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_SEED_WORKERS | |||||
| from fastNLP.envs import ( | |||||
| FASTNLP_MODEL_FILENAME, | |||||
| FASTNLP_CHECKPOINT_FILENAME, | |||||
| ) | |||||
| if _NEED_IMPORT_JITTOR: | if _NEED_IMPORT_JITTOR: | ||||
| import jittor as jt | import jittor as jt | ||||
| from jittor import Module | from jittor import Module | ||||
| from jittor.optim import Optimizer | from jittor.optim import Optimizer | ||||
| from jittor.dataset import Dataset | from jittor.dataset import Dataset | ||||
| from jittor.dataset import ( | |||||
| BatchSampler as JittorBatchSampler, | |||||
| Sampler as JittorSampler, | |||||
| RandomSampler as JittorRandomSampler, | |||||
| SequentialSampler as JittorSequentialSampler | |||||
| ) | |||||
| _reduces = { | _reduces = { | ||||
| 'max': jt.max, | 'max': jt.max, | ||||
| @@ -56,6 +64,7 @@ class JittorDriver(Driver): | |||||
| else: | else: | ||||
| jt.flags.auto_mixed_precision_level = 0 | jt.flags.auto_mixed_precision_level = 0 | ||||
| self.fp16 = fp16 | self.fp16 = fp16 | ||||
| self._auto_cast = nullcontext | |||||
| # 用来设置是否关闭 auto_param_call 中的参数匹配问题; | # 用来设置是否关闭 auto_param_call 中的参数匹配问题; | ||||
| self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | ||||
| @@ -68,7 +77,7 @@ class JittorDriver(Driver): | |||||
| def _check_optimizer_legality(optimizers): | def _check_optimizer_legality(optimizers): | ||||
| for each_optimizer in optimizers: | for each_optimizer in optimizers: | ||||
| if not isinstance(each_optimizer, Optimizer): | if not isinstance(each_optimizer, Optimizer): | ||||
| raise ValueError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, " | |||||
| raise TypeError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, " | |||||
| f"not {type(each_optimizer)}.") | f"not {type(each_optimizer)}.") | ||||
| def step(self): | def step(self): | ||||
| @@ -117,30 +126,118 @@ class JittorDriver(Driver): | |||||
| model = self.unwrap_model() | model = self.unwrap_model() | ||||
| model.load(filepath) | model.load(filepath) | ||||
| def save_checkpoint(self): | |||||
| ... | |||||
| def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||||
| dataloader_args = self.get_dataloader_args(dataloader) | |||||
| if dataloader_args.sampler: | |||||
| sampler = dataloader_args.sampler | |||||
| else: | |||||
| raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | |||||
| num_consumed_batches = states.pop('num_consumed_batches') | |||||
| if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | |||||
| sampler_states = sampler.state_dict() | |||||
| # 需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||||
| # 会造成多余实际消耗的问题。因为 | |||||
| num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||||
| if num_consumed_samples_array is not None: | |||||
| if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||||
| if dataloader_args.batch_size is not None: | |||||
| num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||||
| else: # 有可能 batch_size 为 None,就只有损失精度了 | |||||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
| "it may cause missing some samples when reload.") | |||||
| num_consumed_batches = sampler_states['num_consumed_samples'] | |||||
| sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||||
| assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||||
| else: | |||||
| if dataloader_args.batch_size is not None: | |||||
| sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||||
| * num_consumed_batches | |||||
| else: | |||||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
| "it may cause missing some samples when reload.") | |||||
| states['sampler_states'] = sampler_states | |||||
| else: | |||||
| raise RuntimeError('The sampler has no `state_dict()` method, fastNLP cannot save the training ' | |||||
| 'state.') | |||||
| # 2. 保存模型的状态; | |||||
| if should_save_model: | |||||
| if not os.path.exists(folder): | |||||
| os.mkdir(folder) | |||||
| model_path = folder.joinpath(FASTNLP_MODEL_FILENAME) | |||||
| self.save_model(model_path, only_state_dict=only_state_dict) | |||||
| # 3. 保存 optimizers 的状态; | |||||
| states["optimizers_state_dict"] = self.get_optimizer_state() | |||||
| # 4. 保存fp16的状态 | |||||
| logger.debug("Save optimizer state dict") | |||||
| jt.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) | |||||
| def get_optimizer_state(self): | def get_optimizer_state(self): | ||||
| # optimizers_state_dict = {} | |||||
| # for i in range(len(self.optimizers)): | |||||
| # optimizer: torch.optim.Optimizer = self.optimizers[i] | |||||
| # optimizer_state = optimizer.state_dict() | |||||
| # optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu")) | |||||
| # optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | |||||
| # return optimizers_state_dict | |||||
| ... | |||||
| optimizers_state_dict = {} | |||||
| for i in range(len(self.optimizers)): | |||||
| optimizer: Optimizer = self.optimizers[i] | |||||
| optimizers_state_dict[f"optimizer{i}"] = optimizer.state_dict() # 注意这里没有使用 deepcopy,测试是不需要的; | |||||
| return optimizers_state_dict | |||||
| def load_optimizer_state(self, states): | def load_optimizer_state(self, states): | ||||
| # assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \ | |||||
| # f"checkpoint it is:{len(states)}" | |||||
| # for i in range(len(self.optimizers)): | |||||
| # optimizer: torch.optim.Optimizer = self.optimizers[i] | |||||
| # optimizer.load_state_dict(states[f"optimizer{i}"]) | |||||
| # logger.debug("Load optimizer state dict.") | |||||
| ... | |||||
| assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \ | |||||
| f"checkpoint it is:{len(states)}" | |||||
| for i in range(len(self.optimizers)): | |||||
| optimizer: Optimizer = self.optimizers[i] | |||||
| optimizer.load_state_dict(states[f"optimizer{i}"]) | |||||
| logger.debug("Load optimizer state dict.") | |||||
| def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||||
| states = jt.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) | |||||
| # 1. 加载 optimizers 的状态; | |||||
| optimizers_state_dict = states.pop("optimizers_state_dict") | |||||
| self.load_optimizer_state(optimizers_state_dict) | |||||
| # 2. 加载模型状态; | |||||
| if should_load_model: | |||||
| self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict) | |||||
| # 3. 加载fp16的状态 | |||||
| # 4. 恢复 sampler 的状态; | |||||
| dataloader_args = self.get_dataloader_args(dataloader) | |||||
| if dataloader_args.sampler is None: | |||||
| sampler = RandomSampler(dataloader_args.sampler.dataset, shuffle=dataloader_args.shuffle) | |||||
| elif isinstance(dataloader_args.sampler, ReproducibleSampler): | |||||
| sampler = dataloader_args.sampler | |||||
| elif isinstance(dataloader_args.sampler, JittorRandomSampler): | |||||
| sampler = RandomSampler(dataloader_args.sampler.dataset) | |||||
| logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.") | |||||
| elif isinstance(dataloader_args.sampler, JittorSequentialSampler): | |||||
| sampler = RandomSampler(dataloader_args.sampler.dataset, shuffle=False) | |||||
| logger.debug("Replace jittor Sampler into fastNLP RandomSampler without shuffle.") | |||||
| elif self.is_distributed(): | |||||
| raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our" | |||||
| "`ReproducibleSampler`.") | |||||
| else: | |||||
| raise RuntimeError(f"Jittor sampler {type(dataloader_args.sampler)} is not supported now.") | |||||
| sampler.load_state_dict(states.pop('sampler_states')) | |||||
| states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) | |||||
| # 4. 修改 trainer_state.batch_idx_in_epoch | |||||
| # sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; | |||||
| if dataloader_args.drop_last: | |||||
| batch_idx_in_epoch = len( | |||||
| sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size | |||||
| else: | |||||
| batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ | |||||
| (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size | |||||
| def load_checkpoint(self): | |||||
| ... | |||||
| states["batch_idx_in_epoch"] = batch_idx_in_epoch | |||||
| return states | |||||
| def get_evaluate_context(self): | def get_evaluate_context(self): | ||||
| return jt.no_grad | return jt.no_grad | ||||
| @@ -198,26 +295,8 @@ class JittorDriver(Driver): | |||||
| """ | """ | ||||
| return batch | return batch | ||||
| @staticmethod | |||||
| def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover | |||||
| global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) | |||||
| process_seed = jt.get_seed() | |||||
| # back out the base seed so we can use all the bits | |||||
| base_seed = process_seed - worker_id | |||||
| ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) | |||||
| # use 128 bits (4 x 32-bit words) | |||||
| np.random.seed(ss.generate_state(4)) | |||||
| # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module | |||||
| jittor_ss, stdlib_ss = ss.spawn(2) | |||||
| jt.set_global_seed(jittor_ss.generate_state(1, dtype=np.uint64)[0]) | |||||
| # use 128 bits expressed as an integer | |||||
| stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum() | |||||
| random.seed(stdlib_seed) | |||||
| def set_deterministic_dataloader(self, dataloader: Union["JittorDataLoader", "Dataset"]): | def set_deterministic_dataloader(self, dataloader: Union["JittorDataLoader", "Dataset"]): | ||||
| if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None: | |||||
| dataloader.worker_init_fn = partial(self.worker_init_function, | |||||
| rank=int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))) | |||||
| ... | |||||
| def set_sampler_epoch(self, dataloader: Union["JittorDataLoader", "Dataset"], cur_epoch_idx: int): | def set_sampler_epoch(self, dataloader: Union["JittorDataLoader", "Dataset"], cur_epoch_idx: int): | ||||
| # 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的; | # 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的; | ||||
| @@ -226,4 +305,45 @@ class JittorDriver(Driver): | |||||
| @staticmethod | @staticmethod | ||||
| def get_dataloader_args(dataloader: Union["JittorDataLoader", "Dataset"]): | def get_dataloader_args(dataloader: Union["JittorDataLoader", "Dataset"]): | ||||
| pass | |||||
| @dataclass | |||||
| class Res: | |||||
| dataset: Optional[Dataset] = None | |||||
| batch_sampler: Optional[JittorBatchSampler] = None | |||||
| sampler: Optional[JittorSampler] = None | |||||
| batch_size: Optional[int] = None | |||||
| shuffle: Optional[bool] = None | |||||
| drop_last: Optional[bool] = None | |||||
| res = Res() | |||||
| from fastNLP.core.dataloaders.jittor_dataloader.fdl import _JittorDataset | |||||
| if isinstance(dataloader, JittorDataLoader): | |||||
| # JittorDataLoader 实际上是迭代 dataset 成员的 | |||||
| dataloader = dataloader.dataset | |||||
| if isinstance(dataloader, _JittorDataset): | |||||
| # 获取最原始的 dataset | |||||
| res.dataset = dataloader.dataset | |||||
| else: | |||||
| res.dataset = dataloader | |||||
| # jittor 现在不支持 batch_sampler,所以除了 shuffle 都可以直接获取 | |||||
| res.batch_size = dataloader.batch_size | |||||
| res.drop_last = dataloader.drop_last | |||||
| if dataloader.sampler is None: | |||||
| # sampler 是 None,那么就从 Dataset 的属性中获取 | |||||
| res.shuffle = dataloader.shuffle | |||||
| elif isinstance(list(dataloader.sampler.__iter__())[0], (list,tuple)): | |||||
| # jittor 目前不支持 batch_sampler | |||||
| raise NotImplementedError("Jittor does not support using batch_sampler in `Dataset` now, " | |||||
| "please check if you have set `Dataset.sampler` as `BatchSampler`") | |||||
| else: | |||||
| # sampler 不为 None | |||||
| res.sampler = dataloader.sampler | |||||
| if hasattr(dataloader.sampler, "shuffle"): | |||||
| # 这种情况一般出现在 fastNLP 的 ReproduceSampler 中 | |||||
| res.shuffle = dataloader.sampler.shuffle | |||||
| elif isinstance(dataloader.sampler, JittorRandomSampler): | |||||
| res.shuffle = True | |||||
| else: | |||||
| res.shuffle = False | |||||
| return res | |||||
| @@ -38,6 +38,7 @@ class JittorMPIDriver(JittorDriver): | |||||
| ): | ): | ||||
| super(JittorMPIDriver, self).__init__(model, fp16=fp16, **kwargs) | super(JittorMPIDriver, self).__init__(model, fp16=fp16, **kwargs) | ||||
| raise NotImplementedError("MPI for Jittor is not supported right now.") | |||||
| self.is_pull_by_jittor_run = is_pull_by_jittor_run | self.is_pull_by_jittor_run = is_pull_by_jittor_run | ||||
| self.parallel_device = parallel_device | self.parallel_device = parallel_device | ||||
| @@ -100,22 +101,6 @@ class JittorMPIDriver(JittorDriver): | |||||
| return self._data_device | return self._data_device | ||||
| return self.parallel_device | return self.parallel_device | ||||
| def step(self): | |||||
| # for optimizer in self.optimizers: | |||||
| # self.grad_scaler.step(optimizer) | |||||
| # self.grad_scaler.update() | |||||
| for optimizer in self.optimizers: | |||||
| optimizer.step() | |||||
| def backward(self, loss): | |||||
| # self.grad_scaler.scale(loss).backward() | |||||
| for optimizer in self.optimizers: | |||||
| optimizer.backward(loss) | |||||
| def zero_grad(self): | |||||
| for optimizer in self.optimizers: | |||||
| optimizer.zero_grad() | |||||
| def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | ||||
| if isinstance(batch, Dict) and not self.wo_auto_param_call: | if isinstance(batch, Dict) and not self.wo_auto_param_call: | ||||
| return auto_param_call(fn, batch, signature_fn=signature_fn) | return auto_param_call(fn, batch, signature_fn=signature_fn) | ||||
| @@ -1,14 +1,21 @@ | |||||
| from typing import Dict, Union, Tuple, Callable, Optional | from typing import Dict, Union, Tuple, Callable, Optional | ||||
| from .jittor_driver import JittorDriver | from .jittor_driver import JittorDriver | ||||
| from .utils import replace_batch_sampler, replace_sampler | |||||
| from fastNLP.core.utils import auto_param_call | from fastNLP.core.utils import auto_param_call | ||||
| from fastNLP.core.utils.utils import _get_fun_msg | from fastNLP.core.utils.utils import _get_fun_msg | ||||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | ||||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \ | |||||
| ReproduceBatchSampler | |||||
| from fastNLP.core.samplers import RandomSampler | |||||
| from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
| if _NEED_IMPORT_JITTOR: | if _NEED_IMPORT_JITTOR: | ||||
| import jittor as jt | import jittor as jt | ||||
| from jittor.dataset import ( | |||||
| RandomSampler as JittorRandomSampler, | |||||
| SequentialSampler as JittorSequentialSampler, | |||||
| ) | |||||
| __all__ = [ | __all__ = [ | ||||
| "JittorSingleDriver", | "JittorSingleDriver", | ||||
| @@ -89,31 +96,46 @@ class JittorSingleDriver(JittorDriver): | |||||
| """ | """ | ||||
| return False | return False | ||||
| def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], | |||||
| reproducible: bool = False, sampler_or_batch_sampler=None): | |||||
| # reproducible 的相关功能暂时没有实现 | |||||
| def set_dist_repro_dataloader(self, dataloader, | |||||
| dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None, | |||||
| reproducible: bool = False): | |||||
| # 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load_checkpoint 函数调用; | |||||
| if isinstance(dist, ReproducibleBatchSampler): | if isinstance(dist, ReproducibleBatchSampler): | ||||
| raise NotImplementedError | |||||
| dataloader.batch_sampler = dist_sample | |||||
| if isinstance(dist, ReproducibleSampler): | |||||
| raise NotImplementedError | |||||
| dataloader.batch_sampler.sampler = dist | |||||
| return replace_batch_sampler(dataloader, dist) | |||||
| elif isinstance(dist, ReproducibleSampler): | |||||
| return replace_sampler(dataloader, dist) | |||||
| # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; | |||||
| args = self.get_dataloader_args(dataloader) | |||||
| if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||||
| batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||||
| return replace_batch_sampler(dataloader, batch_sampler) | |||||
| elif isinstance(args.sampler, ReproducibleSampler): | |||||
| sampler = re_instantiate_sampler(args.sampler) | |||||
| return replace_sampler(dataloader, sampler) | |||||
| if reproducible: | if reproducible: | ||||
| raise NotImplementedError | |||||
| if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||||
| return dataloader | |||||
| elif isinstance(dataloader.batch_sampler, RandomBatchSampler): | |||||
| return dataloader | |||||
| else: | |||||
| # TODO | |||||
| batch_sampler = RandomBatchSampler( | |||||
| batch_sampler=dataloader.batch_sampler, | |||||
| batch_size=dataloader.batch_sampler.batch_size, | |||||
| drop_last=dataloader.drop_last | |||||
| ) | |||||
| dataloader.batch_sampler = batch_sampler | |||||
| return dataloader | |||||
| if args.sampler is None: | |||||
| sampler = RandomSampler(args.dataset, args.shuffle) | |||||
| return replace_sampler(dataloader, sampler) | |||||
| elif isinstance(args.sampler, JittorRandomSampler): | |||||
| if getattr(args.sampler, '_num_samples', None) is None \ | |||||
| and getattr(args.sampler, 'rep', False) is False: | |||||
| # 如果本来就是随机的,并且没有定制,直接替换掉吧。 | |||||
| sampler = RandomSampler(args.sampler.dataset, shuffle=True) | |||||
| logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.") | |||||
| return replace_sampler(dataloader, sampler) | |||||
| elif isinstance(args.sampler, JittorSequentialSampler): | |||||
| # 需要替换为不要 shuffle 的。 | |||||
| sampler = RandomSampler(args.sampler.dataset, shuffle=False) | |||||
| logger.debug("Replace jittor SequentialSampler into fastNLP RandomSampler.") | |||||
| return replace_sampler(dataloader, sampler) | |||||
| batch_sampler = ReproduceBatchSampler( | |||||
| batch_sampler=args.batch_sampler, | |||||
| batch_size=args.batch_size, | |||||
| drop_last=args.drop_last | |||||
| ) | |||||
| return replace_batch_sampler(dataloader, batch_sampler) | |||||
| else: | else: | ||||
| return dataloader | return dataloader | ||||
| @@ -1,6 +1,29 @@ | |||||
| import inspect | |||||
| from copy import deepcopy | |||||
| from typing import Union | |||||
| from fastNLP.core.dataloaders import JittorDataLoader | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | ||||
| if _NEED_IMPORT_JITTOR: | if _NEED_IMPORT_JITTOR: | ||||
| import jittor | |||||
| from jittor.dataset import Dataset | |||||
| __all__ = [] | __all__ = [] | ||||
| def replace_batch_sampler(dataloader, batch_sampler): | |||||
| raise NotImplementedError("Jittor does not support using batch_sampler in `Dataset` now, " | |||||
| "please check if you have set `Dataset.sampler` as `BatchSampler`" | |||||
| "or report this bug to us.") | |||||
| def replace_sampler(dataloader: Union["Dataset", "JittorDataLoader"], sampler): | |||||
| if isinstance(dataloader, JittorDataLoader): | |||||
| init_params = dict(inspect.signature(dataloader.__init__).parameters) | |||||
| reconstruct_args = {name: getattr(dataloader, name, p.default) for name, p in init_params.items()} | |||||
| reconstruct_args["dataset"] = replace_sampler(reconstruct_args["dataset"].dataset, reconstruct_args["dataset"].sampler) | |||||
| new_dataloader = type(dataloader)(**reconstruct_args) | |||||
| new_dataloader.dataset.set_attrs(sampler=sampler) | |||||
| else: | |||||
| new_dataloader = deepcopy(dataloader) | |||||
| new_dataloader.set_attrs(sampler=sampler) | |||||
| return new_dataloader | |||||
| @@ -31,7 +31,6 @@ if _NEED_IMPORT_PADDLE: | |||||
| import paddle | import paddle | ||||
| from paddle.io import ( | from paddle.io import ( | ||||
| DataLoader, | DataLoader, | ||||
| IterableDataset, | |||||
| Dataset, | Dataset, | ||||
| Sampler, | Sampler, | ||||
| BatchSampler, | BatchSampler, | ||||
| @@ -97,6 +96,9 @@ class PaddleDriver(Driver): | |||||
| def check_dataloader_legality(self, dataloader): | def check_dataloader_legality(self, dataloader): | ||||
| if not isinstance(dataloader, DataLoader): | if not isinstance(dataloader, DataLoader): | ||||
| raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") | raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") | ||||
| if dataloader.batch_size is None and dataloader.batch_sampler is None: | |||||
| raise ValueError("Please ensure at least one of your dataloader's batch_size and batch_sampler" | |||||
| "is not None") | |||||
| @staticmethod | @staticmethod | ||||
| def _check_optimizer_legality(optimizers): | def _check_optimizer_legality(optimizers): | ||||
| @@ -107,7 +109,7 @@ class PaddleDriver(Driver): | |||||
| """ | """ | ||||
| for each_optimizer in optimizers: | for each_optimizer in optimizers: | ||||
| if not isinstance(each_optimizer, Optimizer): | if not isinstance(each_optimizer, Optimizer): | ||||
| raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, " | |||||
| raise TypeError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, " | |||||
| f"not {type(each_optimizer)}.") | f"not {type(each_optimizer)}.") | ||||
| @staticmethod | @staticmethod | ||||
| @@ -263,9 +265,7 @@ class PaddleDriver(Driver): | |||||
| optimizers_state_dict = {} | optimizers_state_dict = {} | ||||
| for i in range(len(self.optimizers)): | for i in range(len(self.optimizers)): | ||||
| optimizer: Optimizer = self.optimizers[i] | optimizer: Optimizer = self.optimizers[i] | ||||
| optimizer_state = optimizer.state_dict() | |||||
| optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu") | |||||
| optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | |||||
| optimizers_state_dict[f"optimizer{i}"] = optimizer_state_to_device(optimizer.state_dict(), "cpu") | |||||
| return optimizers_state_dict | return optimizers_state_dict | ||||
| @@ -399,6 +399,8 @@ class PaddleDriver(Driver): | |||||
| def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx): | def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx): | ||||
| if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): | if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): | ||||
| dataloader.batch_sampler.set_epoch(cur_epoch_idx) | dataloader.batch_sampler.set_epoch(cur_epoch_idx) | ||||
| elif callable(getattr(dataloader.batch_sampler.sampler, "set_epoch", None)): | |||||
| dataloader.batch_sampler.sampler.set_epoch(cur_epoch_idx) | |||||
| @staticmethod | @staticmethod | ||||
| def get_dataloader_args(dataloader: "DataLoader"): | def get_dataloader_args(dataloader: "DataLoader"): | ||||
| @@ -99,7 +99,7 @@ class TorchDriver(Driver): | |||||
| def _check_optimizer_legality(optimizers): | def _check_optimizer_legality(optimizers): | ||||
| for each_optimizer in optimizers: | for each_optimizer in optimizers: | ||||
| if not isinstance(each_optimizer, Optimizer): | if not isinstance(each_optimizer, Optimizer): | ||||
| raise ValueError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, " | |||||
| raise TypeError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, " | |||||
| f"not {type(each_optimizer)}.") | f"not {type(each_optimizer)}.") | ||||
| @staticmethod | @staticmethod | ||||
| @@ -210,7 +210,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
| self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
| self.during_iter = True | self.during_iter = True | ||||
| indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||||
| indices = list(range(self.num_samples)) | |||||
| if self.shuffle: | if self.shuffle: | ||||
| if self.num_consumed_samples > 0: # 需要先按照原来的排序,删掉多余的 | if self.num_consumed_samples > 0: # 需要先按照原来的排序,删掉多余的 | ||||
| @@ -237,7 +237,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
| if len(indices)%self.batch_size!=0: | if len(indices)%self.batch_size!=0: | ||||
| batches.append(indices[_num_batches*self.batch_size:]) | batches.append(indices[_num_batches*self.batch_size:]) | ||||
| need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas | |||||
| need_pad_num = (self.num_samples-self.num_consumed_samples) % self.num_replicas | |||||
| if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: | if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: | ||||
| if len(batches) > 0: | if len(batches) > 0: | ||||
| if len(batches[-1])<self.batch_size: | if len(batches[-1])<self.batch_size: | ||||
| @@ -290,9 +290,9 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
| @property | @property | ||||
| def batch_idx_in_epoch(self): | def batch_idx_in_epoch(self): | ||||
| if self.drop_last: | if self.drop_last: | ||||
| return getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||||
| return self.num_samples // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||||
| else: | else: | ||||
| return (getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||||
| return (self.num_samples // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||||
| (self.num_left_samples + self.batch_size - 1) // self.batch_size | (self.num_left_samples + self.batch_size - 1) // self.batch_size | ||||
| @property | @property | ||||
| @@ -313,8 +313,12 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| num_consumed_samples = self.num_consumed_samples | num_consumed_samples = self.num_consumed_samples | ||||
| return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \ | |||||
| self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas)) | |||||
| return math.ceil((self.num_samples - num_consumed_samples) / self.num_replicas) if \ | |||||
| self.pad else math.floor(((self.num_samples - num_consumed_samples) / self.num_replicas)) | |||||
| @property | |||||
| def num_samples(self): | |||||
| return getattr(self.dataset, 'total_len', len(self.dataset)) | |||||
| def __len__(self)->int: | def __len__(self)->int: | ||||
| """ | """ | ||||
| @@ -332,7 +336,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
| raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | ||||
| " consumed. ") | " consumed. ") | ||||
| states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | ||||
| 'sampler_type': self.__class__.__name__, 'length': getattr(self.dataset, 'total_len', len(self.dataset)), 'shuffle': self.shuffle, | |||||
| 'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle, | |||||
| 'batch_size': self.batch_size, | 'batch_size': self.batch_size, | ||||
| 'num_replicas': self.num_replicas} | 'num_replicas': self.num_replicas} | ||||
| @@ -347,7 +351,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
| f"we cannot use {self.__class__.__name__} to load it." | f"we cannot use {self.__class__.__name__} to load it." | ||||
| length = states['length'] | length = states['length'] | ||||
| assert length == getattr(self.dataset, 'total_len', len(self.dataset)), "The number of samples is different between the checkpoint record " \ | |||||
| assert length == self.num_samples, "The number of samples is different between the checkpoint record " \ | |||||
| "and current dataset." | "and current dataset." | ||||
| self.seed = states['seed'] | self.seed = states['seed'] | ||||
| self.epoch = states['epoch'] | self.epoch = states['epoch'] | ||||
| @@ -464,8 +468,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| num_consumed_samples = self.num_consumed_samples | num_consumed_samples = self.num_consumed_samples | ||||
| return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \ | |||||
| self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas)) | |||||
| return math.ceil((self.num_samples - num_consumed_samples) / self.num_replicas) if \ | |||||
| self.pad else math.floor(((self.num_samples - num_consumed_samples) / self.num_replicas)) | |||||
| @property | |||||
| def num_samples(self): | |||||
| return getattr(self.dataset, 'total_len', len(self.dataset)) | |||||
| def __len__(self)->int: | def __len__(self)->int: | ||||
| """ | """ | ||||
| @@ -515,7 +523,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
| if len(sorted_indices)%self.batch_size!=0: | if len(sorted_indices)%self.batch_size!=0: | ||||
| batches.append(sorted_indices[_num_batches*self.batch_size:]) | batches.append(sorted_indices[_num_batches*self.batch_size:]) | ||||
| need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas | |||||
| need_pad_num = (self.num_samples-self.num_consumed_samples) % self.num_replicas | |||||
| if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: | if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: | ||||
| if len(batches) > 0: | if len(batches) > 0: | ||||
| if len(batches[-1])<self.batch_size: | if len(batches[-1])<self.batch_size: | ||||
| @@ -593,7 +601,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
| raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | ||||
| " consumed. ") | " consumed. ") | ||||
| states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | ||||
| 'sampler_type': self.__class__.__name__, 'length': getattr(self.dataset, 'total_len', len(self.dataset)), 'shuffle': self.shuffle, | |||||
| 'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle, | |||||
| 'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket, | 'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket, | ||||
| 'num_replicas': self.num_replicas | 'num_replicas': self.num_replicas | ||||
| } | } | ||||
| @@ -609,7 +617,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
| f"we cannot use {self.__class__.__name__} to load it." | f"we cannot use {self.__class__.__name__} to load it." | ||||
| length = states['length'] | length = states['length'] | ||||
| assert length == getattr(self.dataset, 'total_len', len(self.dataset)), "The number of samples is different between the checkpoint record " \ | |||||
| assert length == self.num_samples, "The number of samples is different between the checkpoint record " \ | |||||
| "and current dataset." | "and current dataset." | ||||
| self.seed = states['seed'] | self.seed = states['seed'] | ||||
| self.epoch = states['epoch'] | self.epoch = states['epoch'] | ||||
| @@ -630,7 +638,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
| @property | @property | ||||
| def batch_idx_in_epoch(self): | def batch_idx_in_epoch(self): | ||||
| if self.drop_last: | if self.drop_last: | ||||
| return getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||||
| return self.num_samples // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||||
| else: | else: | ||||
| return (getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||||
| return (self.num_samples // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||||
| (self.num_left_samples + self.batch_size - 1) // self.batch_size | (self.num_left_samples + self.batch_size - 1) // self.batch_size | ||||
| @@ -48,6 +48,10 @@ class ReproducibleSampler: | |||||
| def num_left_samples(self): | def num_left_samples(self): | ||||
| raise NotImplementedError("Each specific sampler should implement its own `num_left_samples` method.") | raise NotImplementedError("Each specific sampler should implement its own `num_left_samples` method.") | ||||
| @property | |||||
| def num_samples(self): | |||||
| raise NotImplementedError("Each specific sampler should implement its own `num_samples` method.") | |||||
| def set_epoch(self, epoch): | def set_epoch(self, epoch): | ||||
| pass | pass | ||||
| @@ -131,19 +135,19 @@ class RandomSampler(ReproducibleSampler): | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| if self.shuffle: | if self.shuffle: | ||||
| indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||||
| indices = list(range(self.num_samples)) | |||||
| seed = self.seed + self.epoch | seed = self.seed + self.epoch | ||||
| rng = np.random.default_rng(abs(seed)) | rng = np.random.default_rng(abs(seed)) | ||||
| rng.shuffle(indices) | rng.shuffle(indices) | ||||
| if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | ||||
| self.epoch -= 1 | self.epoch -= 1 | ||||
| else: | else: | ||||
| indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||||
| indices = list(range(self.num_samples)) | |||||
| return indices | return indices | ||||
| def state_dict(self) -> Dict: | def state_dict(self) -> Dict: | ||||
| states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | ||||
| 'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle} | |||||
| 'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle} | |||||
| return states | return states | ||||
| def load_state_dict(self, states: Dict): | def load_state_dict(self, states: Dict): | ||||
| @@ -155,8 +159,8 @@ class RandomSampler(ReproducibleSampler): | |||||
| f"we cannot use {self.__class__.__name__} to load it." | f"we cannot use {self.__class__.__name__} to load it." | ||||
| length = states['length'] | length = states['length'] | ||||
| assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \ | |||||
| f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})." | |||||
| assert length == self.num_samples, "The number of samples is different between the checkpoint " \ | |||||
| f"record({length}) and current dataset({self.num_samples})." | |||||
| self.seed = states['seed'] | self.seed = states['seed'] | ||||
| self.epoch = states['epoch'] | self.epoch = states['epoch'] | ||||
| self.num_consumed_samples = states['num_consumed_samples'] | self.num_consumed_samples = states['num_consumed_samples'] | ||||
| @@ -208,9 +212,17 @@ class RandomSampler(ReproducibleSampler): | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| num_consumed_samples = self.num_consumed_samples | num_consumed_samples = self.num_consumed_samples | ||||
| return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \ | |||||
| self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas)) | |||||
| return math.ceil((self.num_samples - num_consumed_samples) / self.num_replicas) if \ | |||||
| self.pad else math.floor(((self.num_samples - num_consumed_samples) / self.num_replicas)) | |||||
| @property | |||||
| def num_samples(self): | |||||
| """ | |||||
| 返回样本的总数 | |||||
| :return: | |||||
| """ | |||||
| return getattr(self.dataset, 'total_len', len(self.dataset)) | |||||
| class SequentialSampler(RandomSampler): | class SequentialSampler(RandomSampler): | ||||
| """ | """ | ||||
| @@ -258,12 +270,10 @@ class SequentialSampler(RandomSampler): | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| return list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||||
| return list(range(self.num_samples)) | |||||
| def state_dict(self) -> Dict: | def state_dict(self) -> Dict: | ||||
| states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, | |||||
| 'length': getattr(self.dataset, 'total_len', len(self.dataset)) | |||||
| } | |||||
| states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, 'length': self.num_samples} | |||||
| return states | return states | ||||
| def load_state_dict(self, states: Dict): | def load_state_dict(self, states: Dict): | ||||
| @@ -275,8 +285,8 @@ class SequentialSampler(RandomSampler): | |||||
| f"we cannot use {self.__class__.__name__} to load it." | f"we cannot use {self.__class__.__name__} to load it." | ||||
| length = states['length'] | length = states['length'] | ||||
| assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \ | |||||
| f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})." | |||||
| assert length == self.num_samples, "The number of samples is different between the checkpoint " \ | |||||
| f"record({length}) and current dataset({self.num_samples})." | |||||
| self.num_consumed_samples = states['num_consumed_samples'] | self.num_consumed_samples = states['num_consumed_samples'] | ||||
| if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | ||||
| self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
| @@ -314,9 +324,9 @@ class SortedSampler(SequentialSampler): | |||||
| except BaseException as e: | except BaseException as e: | ||||
| logger.error(f"Cannot use {self.__class__.__name__} as length, since it is not sortable.") | logger.error(f"Cannot use {self.__class__.__name__} as length, since it is not sortable.") | ||||
| assert len(length) == getattr(self.dataset, 'total_len', len(self.dataset)), f"The length of `dataset`({len(dataset)}) and " \ | |||||
| f"`length`({getattr(self.dataset, 'total_len', len(self.dataset))}) should be equal." | |||||
| assert len(self.sorted_indices) == getattr(self.dataset, 'total_len', len(self.dataset)), "The indices and dataset should have equal length." | |||||
| assert len(length) == self.num_samples, f"The length of `dataset`({len(dataset)}) and " \ | |||||
| f"`length`({self.num_samples}) should be equal." | |||||
| assert len(self.sorted_indices) == self.num_samples, "The indices and dataset should have equal length." | |||||
| self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | ||||
| self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的 | self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的 | ||||
| @@ -42,8 +42,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||||
| 返回 sampler 一次完整的迭代过程会产生多少个index。多卡的情况下,只考虑当前rank; | 返回 sampler 一次完整的迭代过程会产生多少个index。多卡的情况下,只考虑当前rank; | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| num_common = getattr(self.dataset, 'total_len', len(self.dataset))//self.num_replicas | |||||
| num_samples = num_common + int(self.rank < (getattr(self.dataset, 'total_len', len(self.dataset))-num_common*self.num_replicas)) | |||||
| num_common = self.num_samples//self.num_replicas | |||||
| num_samples = num_common + int(self.rank < (self.num_samples-num_common*self.num_replicas)) | |||||
| return num_samples | return num_samples | ||||
| def __iter__(self): | def __iter__(self): | ||||
| @@ -63,14 +63,14 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| if self.shuffle: | if self.shuffle: | ||||
| indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||||
| indices = list(range(self.num_samples)) | |||||
| seed = self.seed + self.epoch | seed = self.seed + self.epoch | ||||
| rng = np.random.default_rng(abs(seed)) | rng = np.random.default_rng(abs(seed)) | ||||
| rng.shuffle(indices) | rng.shuffle(indices) | ||||
| if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | ||||
| self.epoch -= 1 | self.epoch -= 1 | ||||
| else: | else: | ||||
| indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||||
| indices = list(range(self.num_samples)) | |||||
| return indices | return indices | ||||
| def set_epoch(self, epoch: int) -> None: | def set_epoch(self, epoch: int) -> None: | ||||
| @@ -84,8 +84,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||||
| :param rank: | :param rank: | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| assert num_replicas<=getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of replicas({num_replicas}) should be lesser than the " \ | |||||
| f"number of samples({getattr(self.dataset, 'total_len', len(self.dataset))})." | |||||
| assert num_replicas<=self.num_samples, f"The number of replicas({num_replicas}) should be lesser than the " \ | |||||
| f"number of samples({self.num_samples})." | |||||
| assert num_replicas>0 and isinstance(num_replicas, int) | assert num_replicas>0 and isinstance(num_replicas, int) | ||||
| assert isinstance(rank, int) and 0<=rank<num_replicas | assert isinstance(rank, int) and 0<=rank<num_replicas | ||||
| # 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态; | # 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态; | ||||
| @@ -94,6 +94,15 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||||
| return self | return self | ||||
| @property | |||||
| def num_samples(self): | |||||
| """ | |||||
| 返回样本的总数 | |||||
| :return: | |||||
| """ | |||||
| return getattr(self.dataset, 'total_len', len(self.dataset)) | |||||
| class UnrepeatedSortedSampler(UnrepeatedRandomSampler): | class UnrepeatedSortedSampler(UnrepeatedRandomSampler): | ||||
| """ | """ | ||||
| @@ -147,5 +156,5 @@ class UnrepeatedSequentialSampler(UnrepeatedRandomSampler): | |||||
| yield index | yield index | ||||
| def generate_indices(self) -> List[int]: | def generate_indices(self) -> List[int]: | ||||
| return list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||||
| return list(range(self.num_samples)) | |||||
| @@ -27,7 +27,7 @@ from paddle.optimizer import Adam | |||||
| from paddle.io import DataLoader | from paddle.io import DataLoader | ||||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | ||||
| from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset | |||||
| from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset | |||||
| from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback | from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback | ||||
| @dataclass | @dataclass | ||||
| @@ -52,12 +52,12 @@ def test_trainer_fleet( | |||||
| optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001) | optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001) | ||||
| train_dataloader = DataLoader( | train_dataloader = DataLoader( | ||||
| dataset=PaddleRandomMaxDataset(20, MNISTTrainFleetConfig.feature_dimension), | |||||
| dataset=PaddleArgMaxDataset(20, MNISTTrainFleetConfig.feature_dimension), | |||||
| batch_size=MNISTTrainFleetConfig.batch_size, | batch_size=MNISTTrainFleetConfig.batch_size, | ||||
| shuffle=True | shuffle=True | ||||
| ) | ) | ||||
| val_dataloader = DataLoader( | val_dataloader = DataLoader( | ||||
| dataset=PaddleRandomMaxDataset(12, MNISTTrainFleetConfig.feature_dimension), | |||||
| dataset=PaddleArgMaxDataset(12, MNISTTrainFleetConfig.feature_dimension), | |||||
| batch_size=MNISTTrainFleetConfig.batch_size, | batch_size=MNISTTrainFleetConfig.batch_size, | ||||
| shuffle=True | shuffle=True | ||||
| ) | ) | ||||
| @@ -24,7 +24,7 @@ from paddle.io import DataLoader | |||||
| import paddle.distributed.fleet as fleet | import paddle.distributed.fleet as fleet | ||||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_2 | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_2 | ||||
| from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset | |||||
| from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset | |||||
| from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback | from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback | ||||
| @dataclass | @dataclass | ||||
| @@ -54,12 +54,12 @@ def test_trainer_fleet( | |||||
| optimizers = fleet.distributed_optimizer(optimizers) | optimizers = fleet.distributed_optimizer(optimizers) | ||||
| train_dataloader = DataLoader( | train_dataloader = DataLoader( | ||||
| dataset=PaddleRandomMaxDataset(20, MNISTTrainFleetConfig.feature_dimension), | |||||
| dataset=PaddleArgMaxDataset(20, MNISTTrainFleetConfig.feature_dimension), | |||||
| batch_size=MNISTTrainFleetConfig.batch_size, | batch_size=MNISTTrainFleetConfig.batch_size, | ||||
| shuffle=True | shuffle=True | ||||
| ) | ) | ||||
| val_dataloader = DataLoader( | val_dataloader = DataLoader( | ||||
| dataset=PaddleRandomMaxDataset(12, MNISTTrainFleetConfig.feature_dimension), | |||||
| dataset=PaddleArgMaxDataset(12, MNISTTrainFleetConfig.feature_dimension), | |||||
| batch_size=MNISTTrainFleetConfig.batch_size, | batch_size=MNISTTrainFleetConfig.batch_size, | ||||
| shuffle=True | shuffle=True | ||||
| ) | ) | ||||
| @@ -46,8 +46,8 @@ class LSTM(Module): | |||||
| def init_hidden(self, x): | def init_hidden(self, x): | ||||
| # batch_first | # batch_first | ||||
| batch_size = x.shape[0] | batch_size = x.shape[0] | ||||
| h0 = jt.randn(1, batch_size, hidden_size) | |||||
| c0 = jt.randn(1, batch_size, hidden_size) | |||||
| h0 = jt.randn(1, batch_size, self.hidden_size) | |||||
| c0 = jt.randn(1, batch_size, self.hidden_size) | |||||
| return h0, c0 | return h0, c0 | ||||
| @@ -1,4 +1,5 @@ | |||||
| import pytest | import pytest | ||||
| from fastNLP.core.callbacks import callback | |||||
| from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
| from fastNLP.core.controllers.trainer import Evaluator | from fastNLP.core.controllers.trainer import Evaluator | ||||
| @@ -14,6 +15,7 @@ if _NEED_IMPORT_JITTOR: | |||||
| else: | else: | ||||
| from fastNLP.core.utils.dummy_class import DummyClass as Module | from fastNLP.core.utils.dummy_class import DummyClass as Module | ||||
| from fastNLP.core.utils.dummy_class import DummyClass as Dataset | from fastNLP.core.utils.dummy_class import DummyClass as Dataset | ||||
| jt.flags.use_cuda=1 | |||||
| class JittorNormalModel_Classification(Module): | class JittorNormalModel_Classification(Module): | ||||
| @@ -68,11 +70,9 @@ class TrainJittorConfig: | |||||
| batch_size: int = 4 | batch_size: int = 4 | ||||
| shuffle: bool = True | shuffle: bool = True | ||||
| @pytest.mark.parametrize("driver", ["jittor"]) | @pytest.mark.parametrize("driver", ["jittor"]) | ||||
| @pytest.mark.parametrize("device", ["cpu", "gpu", "cuda:0"]) | |||||
| @pytest.mark.parametrize("device", ["cpu", "gpu", "cuda", None]) | |||||
| @pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) | @pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) | ||||
| @pytest.mark.jittor | |||||
| def test_trainer_jittor( | def test_trainer_jittor( | ||||
| driver, | driver, | ||||
| device, | device, | ||||
| @@ -15,7 +15,7 @@ if _NEED_IMPORT_PADDLE: | |||||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | ||||
| from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset | |||||
| from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset | |||||
| from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
| @dataclass | @dataclass | ||||
| @@ -44,12 +44,12 @@ def test_trainer_paddle( | |||||
| ) | ) | ||||
| optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001) | optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001) | ||||
| train_dataloader = DataLoader( | train_dataloader = DataLoader( | ||||
| dataset=PaddleRandomMaxDataset(20, TrainPaddleConfig.feature_dimension), | |||||
| dataset=PaddleArgMaxDataset(20, TrainPaddleConfig.feature_dimension), | |||||
| batch_size=TrainPaddleConfig.batch_size, | batch_size=TrainPaddleConfig.batch_size, | ||||
| shuffle=True | shuffle=True | ||||
| ) | ) | ||||
| val_dataloader = DataLoader( | val_dataloader = DataLoader( | ||||
| dataset=PaddleRandomMaxDataset(12, TrainPaddleConfig.feature_dimension), | |||||
| dataset=PaddleArgMaxDataset(12, TrainPaddleConfig.feature_dimension), | |||||
| batch_size=TrainPaddleConfig.batch_size, | batch_size=TrainPaddleConfig.batch_size, | ||||
| shuffle=True | shuffle=True | ||||
| ) | ) | ||||
| @@ -76,7 +76,7 @@ class TestPaddle: | |||||
| from paddle.io import Dataset | from paddle.io import Dataset | ||||
| import paddle | import paddle | ||||
| class PaddleRandomMaxDataset(Dataset): | |||||
| class PaddleArgMaxDataset(Dataset): | |||||
| def __init__(self, num_samples, num_features): | def __init__(self, num_samples, num_features): | ||||
| self.x = paddle.randn((num_samples, num_features)) | self.x = paddle.randn((num_samples, num_features)) | ||||
| self.y = self.x.argmax(axis=-1) | self.y = self.x.argmax(axis=-1) | ||||
| @@ -87,7 +87,7 @@ class TestPaddle: | |||||
| def __getitem__(self, item): | def __getitem__(self, item): | ||||
| return {"x": self.x[item], "y": self.y[item]} | return {"x": self.x[item], "y": self.y[item]} | ||||
| ds = PaddleRandomMaxDataset(100, 2) | |||||
| ds = PaddleArgMaxDataset(100, 2) | |||||
| dl = DataLoader(ds, places=None, collate_fn=Collator(), batch_size=4) | dl = DataLoader(ds, places=None, collate_fn=Collator(), batch_size=4) | ||||
| for batch in dl: | for batch in dl: | ||||
| print(batch) | print(batch) | ||||
| @@ -0,0 +1,45 @@ | |||||
| import pytest | |||||
| from fastNLP.core.drivers import JittorSingleDriver, JittorMPIDriver | |||||
| from fastNLP.core.drivers.jittor_driver.initialize_jittor_driver import initialize_jittor_driver | |||||
| from tests.helpers.models.jittor_model import JittorNormalModel_Classification_1 | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||||
| if _NEED_IMPORT_JITTOR: | |||||
| import jittor as jt | |||||
| @pytest.mark.jittor | |||||
| def test_incorrect_driver(): | |||||
| model = JittorNormalModel_Classification_1(20, 10) | |||||
| with pytest.raises(ValueError): | |||||
| driver = initialize_jittor_driver("torch", 0, model) | |||||
| @pytest.mark.jittor | |||||
| @pytest.mark.parametrize( | |||||
| "device", | |||||
| ["cpu", "gpu", None, "cuda"] | |||||
| ) | |||||
| def test_get_single_device(device): | |||||
| """ | |||||
| 测试正常情况下初始化 JittorSingleDriver 的情况 | |||||
| """ | |||||
| model = JittorNormalModel_Classification_1(20, 10) | |||||
| driver = initialize_jittor_driver("jittor", device, model) | |||||
| assert isinstance(driver, JittorSingleDriver) | |||||
| @pytest.mark.jittor | |||||
| @pytest.mark.parametrize( | |||||
| "device", | |||||
| [[0, 2, 3], 1, 2] | |||||
| ) | |||||
| def test_get_mpi(device): | |||||
| """ | |||||
| 测试 jittor 多卡的初始化情况 | |||||
| """ | |||||
| model = JittorNormalModel_Classification_1(20, 10) | |||||
| with pytest.raises(NotImplementedError): | |||||
| driver = initialize_jittor_driver("jittor", device, model) | |||||
| # assert isinstance(driver, JittorMPIDriver) | |||||
| @@ -1,99 +1,614 @@ | |||||
| import pytest | import pytest | ||||
| import os | |||||
| from copy import deepcopy | |||||
| from pathlib import Path | |||||
| import numpy as np | |||||
| from fastNLP.core.drivers.jittor_driver.single_device import JittorSingleDriver | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||||
| from fastNLP.core.drivers.jittor_driver import JittorSingleDriver | |||||
| from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | |||||
| from fastNLP.core.dataloaders import JittorDataLoader | |||||
| from tests.helpers.models.jittor_model import JittorNormalModel_Classification_1 | |||||
| from tests.helpers.datasets.jittor_data import JittorNormalDataset, JittorNormalXYDataset | |||||
| from tests.helpers.datasets.torch_data import TorchNormalDataset | |||||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
| from fastNLP.envs.distributed import rank_zero_rm | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_TORCH | |||||
| if _NEED_IMPORT_JITTOR: | if _NEED_IMPORT_JITTOR: | ||||
| import jittor as jt # 将 jittor 引入 | |||||
| from jittor import nn, Module # 引入相关的模块 | |||||
| from jittor import init | |||||
| from jittor.dataset import MNIST | |||||
| else: | |||||
| from fastNLP.core.utils.dummy_class import DummyClass as Module | |||||
| import jittor as jt | |||||
| from jittor.dataset import ( | |||||
| BatchSampler as JittorBatchSampler, | |||||
| RandomSampler as JittorRandomSampler, | |||||
| SequentialSampler as JittorSequentialSampler, | |||||
| SubsetRandomSampler as JittorSubsetRandomSampler | |||||
| ) | |||||
| if _NEED_IMPORT_TORCH: | |||||
| import torch | |||||
| def get_dataloader(dataset, use_dataloader, sampler, batch_size, shuffle, drop_last=False): | |||||
| """ | |||||
| :param dataset: | |||||
| :param use_dataloader: 是否使用 JittorDataLoader 包裹 | |||||
| :param sampler: 使用 BatchSampler Samlper 还是不使用 Sampler | |||||
| """ | |||||
| if use_dataloader: | |||||
| dataloader = JittorDataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last) | |||||
| dataloader.dataset.set_attrs(sampler=sampler) | |||||
| else: | |||||
| dataloader = dataset | |||||
| dataloader.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, sampler=sampler) | |||||
| class Model(Module): | |||||
| def __init__ (self): | |||||
| super (Model, self).__init__() | |||||
| self.conv1 = nn.Conv (3, 32, 3, 1) # no padding | |||||
| self.conv2 = nn.Conv (32, 64, 3, 1) | |||||
| self.bn = nn.BatchNorm(64) | |||||
| self.max_pool = nn.Pool (2, 2) | |||||
| self.relu = nn.Relu() | |||||
| self.fc1 = nn.Linear (64 * 12 * 12, 256) | |||||
| self.fc2 = nn.Linear (256, 10) | |||||
| def execute(self, x) : | |||||
| # it's simliar to forward function in Pytorch | |||||
| x = self.conv1 (x) | |||||
| x = self.relu (x) | |||||
| x = self.conv2 (x) | |||||
| x = self.bn (x) | |||||
| x = self.relu (x) | |||||
| x = self.max_pool (x) | |||||
| x = jt.reshape (x, [x.shape[0], -1]) | |||||
| x = self.fc1 (x) | |||||
| x = self.relu(x) | |||||
| x = self.fc2 (x) | |||||
| return x | |||||
| return dataloader | |||||
| ############################################################################ | |||||
| # | |||||
| # 测试基类 JittorDrvier 中的一些简单函数 | |||||
| # | |||||
| ############################################################################ | |||||
| class TestJittorDriverFunctions: | |||||
| """ | |||||
| 使用 JittorSingleDriver 测试基类的函数 | |||||
| """ | |||||
| @classmethod | |||||
| def setup_class(self): | |||||
| model = JittorNormalModel_Classification_1(10, 32) | |||||
| self.driver = JittorSingleDriver(model, device="cpu") | |||||
| @pytest.mark.jittor | |||||
| def test_check_optimizers_legality(self): | |||||
| """ | |||||
| 测试对合法的 optimizers 的检查 | |||||
| """ | |||||
| # 单个 optimizer | |||||
| optimizer = jt.optim.Adam( | |||||
| params=self.driver.model.parameters(), | |||||
| lr=0.01 | |||||
| ) | |||||
| self.driver.set_optimizers(optimizer) | |||||
| # optimizer 列表 | |||||
| optimizers = [ | |||||
| jt.optim.Adam( | |||||
| params=self.driver.model.parameters(), | |||||
| lr=0.01 | |||||
| ) for i in range(10) | |||||
| ] | |||||
| self.driver.set_optimizers(optimizers) | |||||
| @pytest.mark.torchjittor | |||||
| def test_invalid_optimizers(self): | |||||
| """ | |||||
| 测试传入非法的 optimizers | |||||
| """ | |||||
| # 单个 optimizer | |||||
| optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||||
| with pytest.raises(TypeError): | |||||
| self.driver.set_optimizers(optimizer) | |||||
| optimizers = [ | |||||
| torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||||
| ] | |||||
| with pytest.raises(TypeError): | |||||
| self.driver.set_optimizers(optimizers) | |||||
| @pytest.mark.jittor | |||||
| def test_check_dataloader_legality(self): | |||||
| """ | |||||
| 测试 check_dataloader_legality 函数的表现 | |||||
| """ | |||||
| # 使用 JittorDataLoader | |||||
| dataloader = JittorDataLoader(JittorNormalDataset()) | |||||
| self.driver.check_dataloader_legality(dataloader) | |||||
| # 使用 jittor.dataset.Dataset | |||||
| self.driver.check_dataloader_legality(JittorNormalDataset()) | |||||
| @pytest.mark.torchjittor | |||||
| def test_check_dataloader_legality_invalid(self): | |||||
| """ | |||||
| 测试 check_dataloader_legality 函数传入其他类型的表现 | |||||
| """ | |||||
| # 创建 torch 的 dataloader | |||||
| dataloader = torch.utils.data.DataLoader( | |||||
| TorchNormalDataset(), | |||||
| batch_size=32, shuffle=True | |||||
| ) | |||||
| with pytest.raises(TypeError): | |||||
| self.driver.check_dataloader_legality(dataloader) | |||||
    @pytest.mark.jittor
    def test_tensor_to_numeric(self):
        """
        Test the tensor_to_numeric static method: jittor Vars — bare or nested
        inside list/tuple/dict — are converted to plain Python values, while
        non-tensor leaves (int, str) pass through with their type preserved.
        """
        # a single scalar tensor becomes a plain number
        tensor = jt.Var(3)
        res = JittorSingleDriver.tensor_to_numeric(tensor)
        assert res == 3
        # a 2-D tensor becomes a nested list
        tensor = jt.rand(3, 4)
        res = JittorSingleDriver.tensor_to_numeric(tensor)
        assert res == tensor.tolist()
        # a list of tensors
        tensor_list = [jt.rand(6, 4, 2) for i in range(10)]
        res = JittorSingleDriver.tensor_to_numeric(tensor_list)
        assert isinstance(res, list)
        tensor_list = [t.tolist() for t in tensor_list]
        assert res == tensor_list
        # a tuple of tensors keeps its tuple-ness
        tensor_tuple = tuple([jt.rand(6, 4, 2) for i in range(10)])
        res = JittorSingleDriver.tensor_to_numeric(tensor_tuple)
        assert isinstance(res, tuple)
        tensor_tuple = tuple([t.tolist() for t in tensor_tuple])
        assert res == tensor_tuple
        # a (nested) dict mixing tensors with plain Python values
        tensor_dict = {
            "tensor": jt.rand(3, 4),
            "list": [jt.rand(6, 4, 2) for i in range(10)],
            "dict":{
                "list": [jt.rand(6, 4, 2) for i in range(10)],
                "tensor": jt.rand(3, 4)
            },
            "int": 2,
            "string": "test string"
        }
        res = JittorSingleDriver.tensor_to_numeric(tensor_dict)
        assert isinstance(res, dict)
        assert res["tensor"] == tensor_dict["tensor"].tolist()
        assert isinstance(res["list"], list)
        for r, d in zip(res["list"], tensor_dict["list"]):
            assert r == d.tolist()
        # non-tensor leaves keep their original Python types
        assert isinstance(res["int"], int)
        assert isinstance(res["string"], str)
        assert isinstance(res["dict"], dict)
        assert isinstance(res["dict"]["list"], list)
        for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]):
            assert r == d.tolist()
        assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist()
| @pytest.mark.jittor | |||||
| def test_tensor_to_numeric_reduce(self): | |||||
| tensor = jt.Var([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) | |||||
| res_max = JittorSingleDriver.tensor_to_numeric(tensor, reduce="max") | |||||
| res_min = JittorSingleDriver.tensor_to_numeric(tensor, reduce="min") | |||||
| res_sum = JittorSingleDriver.tensor_to_numeric(tensor, reduce="sum") | |||||
| res_mean = JittorSingleDriver.tensor_to_numeric(tensor, reduce="mean") | |||||
| assert res_max == 6 | |||||
| assert res_min == 1 | |||||
| assert res_sum == 21 | |||||
| assert res_mean == 3.5 | |||||
| @pytest.mark.jittor | |||||
| def test_set_model_mode(self): | |||||
| """ | |||||
| 测试 set_model_mode 函数 | |||||
| """ | |||||
| self.driver.set_model_mode("train") | |||||
| assert self.driver.model.is_training() | |||||
| self.driver.set_model_mode("eval") | |||||
| assert not self.driver.model.is_training() | |||||
| # 应该报错 | |||||
| with pytest.raises(AssertionError): | |||||
| self.driver.set_model_mode("test") | |||||
| @pytest.mark.jittor | |||||
| def test_move_model_to_device_cpu(self): | |||||
| """ | |||||
| 测试 move_model_to_device 函数,仅测试能否运行 | |||||
| """ | |||||
| JittorSingleDriver.move_model_to_device(self.driver.model, "cpu") | |||||
| @pytest.mark.jittor | |||||
| def test_move_model_to_device_gpu(self): | |||||
| """ | |||||
| 测试 move_model_to_device 函数,仅测试能否运行 | |||||
| """ | |||||
| JittorSingleDriver.move_model_to_device(self.driver.model, "gpu") | |||||
| @pytest.mark.jittor | |||||
| def test_set_deterministic_dataloader(self): | |||||
| """ | |||||
| 测试 set_deterministic_dataloader,仅测试能否运行 | |||||
| """ | |||||
| # 先确保不影响运行 | |||||
| # TODO:正确性 | |||||
| dataloader = JittorDataLoader(JittorNormalDataset()) | |||||
| self.driver.set_deterministic_dataloader(dataloader) | |||||
| self.driver.set_deterministic_dataloader(JittorNormalDataset()) | |||||
| @pytest.mark.jittor | |||||
| def test_set_sampler_epoch(self): | |||||
| """ | |||||
| 测试 set_sampler_epoch | |||||
| """ | |||||
| # 先确保不影响运行 | |||||
| # TODO:正确性 | |||||
| dataloader = JittorDataLoader(JittorNormalDataset()) | |||||
| self.driver.set_sampler_epoch(dataloader, 0) | |||||
| self.driver.set_sampler_epoch(JittorNormalDataset(), 0) | |||||
| @pytest.mark.jittor | |||||
| @pytest.mark.parametrize("batch_size", [16]) | |||||
| @pytest.mark.parametrize("shuffle", [True, False]) | |||||
| @pytest.mark.parametrize("drop_last", [True, False]) | |||||
| @pytest.mark.parametrize("use_dataloader", [True, False]) | |||||
| def test_get_dataloader_args(self, batch_size, shuffle, drop_last, use_dataloader): | |||||
| """ | |||||
| 测试正常情况下 get_dataloader_args 的表现 | |||||
| """ | |||||
| dataloader = get_dataloader( | |||||
| JittorNormalDataset(), | |||||
| use_dataloader=use_dataloader, | |||||
| sampler=None, | |||||
| batch_size=batch_size, | |||||
| shuffle=shuffle, | |||||
| drop_last=drop_last | |||||
| ) | |||||
| res = JittorSingleDriver.get_dataloader_args(dataloader) | |||||
| assert isinstance(res.dataset, JittorNormalDataset) | |||||
| assert res.sampler is None | |||||
| assert res.shuffle == shuffle | |||||
| assert res.batch_size == batch_size | |||||
| assert res.drop_last == drop_last | |||||
| @pytest.mark.jittor | |||||
| @pytest.mark.parametrize("batch_size", [16]) | |||||
| @pytest.mark.parametrize("shuffle", [True, False]) | |||||
| @pytest.mark.parametrize("drop_last", [True, False]) | |||||
| @pytest.mark.parametrize("use_dataloader", [True, False]) | |||||
| def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last, use_dataloader): | |||||
| """ | |||||
| 测试替换了 sampler 后 get_dataloader_args 的表现 | |||||
| """ | |||||
| dataset = JittorNormalDataset() | |||||
| dataloader = get_dataloader( | |||||
| dataset, | |||||
| use_dataloader=use_dataloader, | |||||
| batch_size=batch_size, | |||||
| sampler=RandomSampler(dataset, shuffle=shuffle), | |||||
| shuffle=shuffle, | |||||
| drop_last=drop_last | |||||
| ) | |||||
| res = JittorSingleDriver.get_dataloader_args(dataloader) | |||||
| assert isinstance(res.dataset, JittorNormalDataset) | |||||
| assert isinstance(res.sampler, RandomSampler) | |||||
| assert res.shuffle == shuffle | |||||
| assert res.batch_size == batch_size | |||||
| assert res.drop_last == drop_last | |||||
| ############################################################################ | |||||
| # | |||||
| # 测试 JittorSingleDrvier 中的一些简单函数 | |||||
| # | |||||
| ############################################################################ | |||||
@pytest.mark.jittor
@pytest.mark.skip("Skip jittor tests now.")
class TestSingleDevice:
    """
    End-to-end training/evaluation of JittorSingleDriver on MNIST.

    Currently skipped via the class-level skip mark. Fixes vs. the previous
    version: removed the dead accumulators (`losses`, `losses_idx`,
    `test_loss`, `correct`) and the per-batch `acc` value that was computed
    and then discarded.
    """

    def test_on_gpu_without_fp16(self):
        # TODO get_dataloader
        batch_size = 64
        learning_rate = 0.1
        epochs = 5
        train_loader = MNIST(train=True, batch_size=batch_size, shuffle=True)
        val_loader = MNIST(train=False, batch_size=1, shuffle=False)
        model = Model()
        driver = JittorSingleDriver(model, device=[1])
        optimizer = nn.SGD(model.parameters(), learning_rate)
        driver.set_optimizers(optimizer)

        # train for a few epochs through the driver API
        for epoch in range(epochs):
            driver.set_model_mode("train")
            for batch_idx, (inputs, targets) in enumerate(train_loader):
                outputs = driver.train_step(inputs)
                loss = nn.cross_entropy_loss(outputs, targets)
                driver.backward(loss)
                driver.step()
                driver.zero_grad()

        # evaluate: accumulate accuracy over the whole validation set
        total_acc = 0
        total_num = 0
        driver.set_model_mode("eval")
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            batch_size = inputs.shape[0]
            outputs = driver.test_step(inputs)
            pred = np.argmax(outputs.data, axis=1)
            total_acc += np.sum(targets.data == pred)
            total_num += batch_size
        assert total_acc / total_num > 0.95

    def test_on_cpu_without_fp16(self):
        pass

    def test_on_gpu_with_fp16(self):
        pass
class TestSingleDeviceFunction:
    """
    Tests for miscellaneous small JittorSingleDriver methods.
    """

    @classmethod
    def setup_class(cls):
        cls.driver = JittorSingleDriver(
            JittorNormalModel_Classification_1(10, 784), device="cpu"
        )

    def test_unwrap_model(self):
        """On a single device, unwrap_model returns the model itself."""
        assert self.driver.unwrap_model() is self.driver.model

    def test_is_distributed(self):
        """A single-device driver is never distributed."""
        assert self.driver.is_distributed() == False

    def test_move_data_to_device(self):
        """Smoke test: moving a tensor to the device must not raise."""
        self.driver.move_data_to_device(jt.rand(32, 64))
| ############################################################################ | |||||
| # | |||||
| # 测试 set_dist_repro_dataloader 函数 | |||||
| # | |||||
| ############################################################################ | |||||
@pytest.mark.jittor
class TestSetDistReproDataloader:
    """
    Tests dedicated to the set_dist_repro_dataloader method.
    """
    def setup_method(self):
        # fresh dataset/driver per test so sampler state never leaks across tests
        self.dataset = JittorNormalDataset(20)
        model = JittorNormalModel_Classification_1(10, 32)
        self.driver = JittorSingleDriver(model, device="cpu")
    @pytest.mark.parametrize("use_dataloader", [True, False])
    def test_with_reproducible_false(self, use_dataloader):
        """
        With reproducible=False and a string `dist`, the original dataloader
        must be returned unchanged.
        """
        dataloader = get_dataloader(self.dataset, use_dataloader, sampler=None, batch_size=2, shuffle=True)
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)
        assert replaced_loader is dataloader
    @pytest.mark.parametrize("shuffle", [True, False])
    @pytest.mark.parametrize("sampler", [None, "random", "sequential"])
    @pytest.mark.parametrize("use_dataloader", [True, False])
    def test_with_reproducible_true(self, shuffle, sampler, use_dataloader):
        """
        With reproducible=True and a string `dist`, a new dataloader is
        returned whose sampler has been replaced by a RandomSampler.
        """
        if sampler == "random":
            # a jittor RandomSampler implies shuffling regardless of `shuffle`
            sampler = JittorRandomSampler(self.dataset)
            _shuffle = True
        elif sampler == "sequential":
            # a sequential sampler implies no shuffling
            sampler = JittorSequentialSampler(self.dataset)
            _shuffle = False
        else:
            _shuffle = shuffle
        dataloader = get_dataloader(self.dataset, use_dataloader, sampler=sampler, batch_size=2, shuffle=shuffle)
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
        assert not (replaced_loader is dataloader)
        assert isinstance(replaced_loader.sampler, RandomSampler)
        assert replaced_loader.sampler.shuffle == _shuffle
        assert replaced_loader.batch_size == dataloader.batch_size
        assert replaced_loader.drop_last == dataloader.drop_last
        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle, use_dataloader)
    @pytest.mark.parametrize("shuffle", ([True, False]))
    @pytest.mark.parametrize("use_dataloader", [True, False])
    def test_with_dist_batch_sampler(self, shuffle, use_dataloader):
        """
        With a ReproducibleBatchSampler passed as `dist`, the batch sampler
        would have to be replaced; jittor does not support this for now, so
        a RuntimeError is expected.
        """
        dataloader = get_dataloader(self.dataset, use_dataloader, sampler=None, batch_size=2, shuffle=not shuffle)
        dist = ReproduceBatchSampler(JittorBatchSampler(JittorRandomSampler(self.dataset), 4, False), 4, False)
        with pytest.raises(RuntimeError):
            replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)
    @pytest.mark.parametrize("shuffle", ([True, False]))
    @pytest.mark.parametrize("use_dataloader", [True, False])
    def test_with_dist_sampler(self, shuffle, use_dataloader):
        """
        With a reproducible sampler passed as `dist`, a new dataloader is
        returned whose sampler is exactly the `dist` object.
        """
        dataloader = get_dataloader(self.dataset, use_dataloader, sampler=None, batch_size=2, shuffle=not shuffle)
        dist = RandomSampler(self.dataset, shuffle=shuffle)
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)
        assert not (replaced_loader is dataloader)
        assert isinstance(replaced_loader.sampler, RandomSampler)
        assert replaced_loader.sampler is dist
        assert replaced_loader.batch_size == dataloader.batch_size
        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle, use_dataloader)
    @pytest.mark.parametrize("shuffle", ([True, False]))
    @pytest.mark.parametrize("use_dataloader", [True, False])
    def test_with_dataloader_reproducible_batch_sampler(self, shuffle, use_dataloader):
        """
        When the incoming dataloader already uses a reproducible *batch*
        sampler: not supported for jittor, so a RuntimeError is expected.
        """
        dataloader = get_dataloader(
            self.dataset,
            use_dataloader=use_dataloader,
            sampler=ReproduceBatchSampler(
                JittorBatchSampler(JittorRandomSampler(self.dataset), 4, False),
                batch_size=4,
                drop_last=False,
            ),
            batch_size=4,
            shuffle=shuffle,
        )
        with pytest.raises(RuntimeError):
            replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)
    @pytest.mark.parametrize("shuffle", ([True, False]))
    @pytest.mark.parametrize("use_dataloader", [True, False])
    def test_with_dataloader_reproducible_sampler(self, shuffle, use_dataloader):
        """
        When the incoming dataloader already uses a reproducible sampler, a
        new dataloader is returned with a fresh sampler but otherwise the
        same settings.
        """
        dataloader = get_dataloader(
            self.dataset,
            use_dataloader=use_dataloader,
            sampler=RandomSampler(self.dataset, shuffle),
            batch_size=2,
            shuffle=shuffle,
        )
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)
        assert not (replaced_loader is dataloader)
        assert not (replaced_loader.sampler is dataloader.sampler)
        assert isinstance(replaced_loader.sampler, RandomSampler)
        assert replaced_loader.batch_size == 2
        assert replaced_loader.shuffle == shuffle
        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle, use_dataloader)
    def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle, use_dataloader):
        """
        Single-card resumption check: consume part of `replaced_loader`, save
        the sampler state, rebuild a loader from that state, and verify the
        remaining indices are exactly the ones not yet seen.
        """
        # consume two batches
        num_consumed_batches = 2
        already_seen_idx = set()
        replaced_loader.sampler.set_epoch(6)
        for idx, batch in enumerate(replaced_loader):
            if idx >= num_consumed_batches:
                break
            already_seen_idx.update(batch.tolist())
        sampler_states = replaced_loader.sampler.state_dict()
        # after reloading, the remaining samples should come out; for
        # JittorNormalDataset the union of both passes must cover the range
        left_idxes = set()
        batch_size = replaced_loader.batch_size
        sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
        # rebuild the dataloader from a copy of the underlying dataset
        if use_dataloader:
            dataset = deepcopy(replaced_loader.dataset.dataset)
        else:
            dataset = deepcopy(replaced_loader)
        new_loader = get_dataloader(
            dataset=dataset,
            use_dataloader=use_dataloader,
            sampler = RandomSampler(dataset, shuffle=shuffle),
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=False
        )
        new_loader.sampler.load_state_dict(sampler_states)
        new_loader.sampler.set_epoch(6)
        for idx, batch in enumerate(new_loader):
            left_idxes.update(batch.tolist())
        print(already_seen_idx)
        print(left_idxes)
        assert len(left_idxes) + len(already_seen_idx) == self.dataset.total_len
        assert len(left_idxes | already_seen_idx) == self.dataset.total_len
| ############################################################################ | |||||
| # | |||||
| # 测试 save 和 load 相关的功能 | |||||
| # | |||||
| ############################################################################ | |||||
def generate_random_driver(labels, features, fp16=False, device="cpu", lr=0.01):
    """Build a ready-to-use JittorSingleDriver around a fresh classification model."""
    model = JittorNormalModel_Classification_1(labels, features)
    driver = JittorSingleDriver(model, device=device, fp16=fp16)
    driver.set_optimizers(jt.optim.Adam(params=model.parameters(), lr=lr))
    driver.setup()
    return driver
@pytest.mark.jittor
@pytest.mark.parametrize("only_state_dict", ([True, False]))
@pytest.mark.parametrize("use_dataloader", [True, False])
def test_save_and_load_model(only_state_dict, use_dataloader):
    """
    save_model followed by load_model must reproduce identical predictions.
    """
    path = "model"
    try:
        dataset = JittorNormalXYDataset(20)
        dataloader = get_dataloader(dataset, sampler=None, use_dataloader=use_dataloader, batch_size=4, shuffle=True)
        driver1 = generate_random_driver(20, 1, device="gpu")
        driver2 = generate_random_driver(20, 1, device="gpu")
        driver1.save_model(path, only_state_dict)
        driver2.load_model(path, only_state_dict)
        # after loading, both drivers must agree on every batch
        for batch in dataloader:
            batch = driver1.move_data_to_device(batch)
            res1 = driver1.model.evaluate_step(**batch)
            res2 = driver2.model.evaluate_step(**batch)
            assert jt.all_(jt.equal(res1["pred"], res2["pred"]))
    finally:
        # always remove the on-disk checkpoint
        rank_zero_rm(path)
@pytest.mark.jittor
@pytest.mark.parametrize("only_state_dict", ([True, False]))
@pytest.mark.parametrize("use_dataloader", [True, False])
def test_save_and_load_with_randomsampler(only_state_dict, use_dataloader):
    """
    Test save_checkpoint/load_checkpoint when the dataloader's sampler was
    replaced by a reproducible RandomSampler: optimizer state, sampler state
    and model weights must all be restored, and resuming must yield exactly
    the not-yet-seen samples.
    """
    try:
        path = "model.ckp"
        driver1, driver2 = generate_random_driver(20, 1, device="gpu", lr=0.01), \
                            generate_random_driver(20, 1, device="gpu", lr=0.001)
        dataset = JittorNormalXYDataset(20)
        dataloader = get_dataloader(
            dataset, use_dataloader,
            sampler = RandomSampler(dataset, True),
            batch_size=4,
            shuffle=True
        )
        # consume two batches before saving, remembering what was seen
        num_consumed_batches = 2
        already_seen_x_set = set()
        already_seen_y_set = set()
        driver1.set_sampler_epoch(dataloader, 7)
        for idx, batch in enumerate(dataloader):
            if idx >= num_consumed_batches:
                break
            already_seen_x_set.update(batch["x"].reshape(-1, ).tolist())
            already_seen_y_set.update(batch["y"].reshape(-1, ).tolist())
        sampler_states = dataloader.sampler.state_dict()
        save_states = {"num_consumed_batches": num_consumed_batches}
        driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
        # reload into driver2
        # deliberately change batch_size so the restored state has to adapt
        dataloader = get_dataloader(
            dataset, use_dataloader,
            sampler=RandomSampler(dataset, True),
            batch_size=2,
            shuffle=True
        )
        load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
        replaced_loader = load_states.pop("dataloader")
        # 1. optimizer state must be restored (lr taken over from driver1)
        assert driver2.optimizers[0].lr == driver1.optimizers[0].lr
        # 2. the sampler must be loaded and replaced correctly
        assert not (replaced_loader is dataloader)
        assert isinstance(replaced_loader.sampler, RandomSampler)
        assert replaced_loader.sampler.seed == sampler_states["seed"]
        assert replaced_loader.sampler.epoch == sampler_states["epoch"]
        assert replaced_loader.sampler.num_consumed_samples == 4 * num_consumed_batches
        assert replaced_loader.sampler.dataset.total_len == sampler_states["length"]
        assert replaced_loader.sampler.shuffle == sampler_states["shuffle"]
        # 4. model parameters are verified below via identical predictions
        # 5. batch_idx must reflect the consumed samples at the new batch size
        start_batch = load_states.pop('batch_idx_in_epoch')
        assert start_batch == 2 * num_consumed_batches
        left_x_batches = set()
        left_y_batches = set()
        driver2.set_sampler_epoch(replaced_loader, 7)
        for idx, batch in enumerate(replaced_loader):
            left_x_batches.update(batch["x"].reshape(-1, ).tolist())
            left_y_batches.update(batch["y"].reshape(-1, ).tolist())
            res1 = driver1.model.evaluate_step(**batch)
            res2 = driver2.model.evaluate_step(**batch)
            assert jt.all_(jt.equal(res1["pred"], res2["pred"]))
        # seen + remaining must partition the full dataset exactly
        assert len(left_x_batches) + len(already_seen_x_set) == dataset.total_len
        assert len(left_x_batches | already_seen_x_set) == dataset.total_len
        assert len(left_y_batches) + len(already_seen_y_set) == dataset.total_len
        assert len(left_y_batches | already_seen_y_set) == dataset.total_len
    finally:
        rank_zero_rm(path)
| @@ -0,0 +1,43 @@ | |||||
| import pytest | |||||
| from fastNLP.core.drivers.jittor_driver.utils import replace_sampler | |||||
| from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | |||||
| from fastNLP.core.dataloaders import JittorDataLoader | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||||
| if _NEED_IMPORT_JITTOR: | |||||
| import jittor as jt | |||||
| from tests.helpers.datasets.jittor_data import JittorNormalDataset | |||||
@pytest.mark.jittor
@pytest.mark.parametrize("dataset", [
    JittorNormalDataset(20, batch_size=10, shuffle=True),
    JittorNormalDataset(20, batch_size=5, drop_last=True),
    JittorNormalDataset(20)
])
def test_replace_sampler_dataset(dataset):
    """
    replace_sampler on a bare jittor dataset returns a new loader that keeps
    the dataset's batching attributes while carrying the new sampler.

    Bug fix: the previous body re-assigned ``dataset = JittorNormalDataset(20)``,
    discarding the parametrized fixture — all three cases silently tested the
    same default dataset, so the custom batch_size/shuffle/drop_last variants
    were never exercised.
    """
    sampler = RandomSampler(dataset)
    replaced_loader = replace_sampler(dataset, sampler)
    assert not (replaced_loader is dataset)
    assert isinstance(replaced_loader.sampler, RandomSampler)
    # batching attributes must be inherited from the parametrized dataset
    assert replaced_loader.batch_size == dataset.batch_size
    assert replaced_loader.drop_last == dataset.drop_last
    assert replaced_loader.shuffle == dataset.shuffle
    assert replaced_loader.total_len == dataset.total_len
@pytest.mark.jittor
def test_replace_sampler_jittordataloader():
    """replace_sampler on a JittorDataLoader keeps the loader-level settings."""
    dataset = JittorNormalDataset(20, batch_size=10, shuffle=True)
    dataloader = JittorDataLoader(dataset, batch_size=8, shuffle=True)
    replaced_loader = replace_sampler(dataloader, RandomSampler(dataset))
    # a brand-new loader wrapping a copied dataset is returned
    assert not (replaced_loader is dataloader)
    assert not (replaced_loader.dataset.dataset is dataloader.dataset.dataset)
    assert isinstance(replaced_loader.sampler, RandomSampler)
    # the loader-level batch settings win over the inner dataset's
    assert replaced_loader.batch_size == 8
    assert replaced_loader.shuffle == True
| @@ -10,7 +10,7 @@ from fastNLP.core.samplers import ( | |||||
| UnrepeatedSequentialSampler, | UnrepeatedSequentialSampler, | ||||
| ) | ) | ||||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | ||||
| from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | |||||
| from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset | |||||
| from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
| from fastNLP.envs.distributed import rank_zero_rm | from fastNLP.envs.distributed import rank_zero_rm | ||||
| from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
| @@ -19,8 +19,8 @@ if _NEED_IMPORT_PADDLE: | |||||
| import paddle.distributed as dist | import paddle.distributed as dist | ||||
| from paddle.io import DataLoader, BatchSampler | from paddle.io import DataLoader, BatchSampler | ||||
| def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"): | |||||
| paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension) | |||||
| def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="only_error"): | |||||
| paddle_model = PaddleNormalModel_Classification_1(labels, features) | |||||
| paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | ||||
| driver = PaddleFleetDriver( | driver = PaddleFleetDriver( | ||||
| model=paddle_model, | model=paddle_model, | ||||
| @@ -465,10 +465,14 @@ class TestSetDistReproDataloader: | |||||
| num_replicas = len(self.device) | num_replicas = len(self.device) | ||||
| num_consumed_batches = 2 | num_consumed_batches = 2 | ||||
| already_seen_idx = set() | already_seen_idx = set() | ||||
| if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | |||||
| sampler_states = replaced_loader.batch_sampler.set_epoch(10) | |||||
| else: | |||||
| sampler_states = replaced_loader.batch_sampler.sampler.set_epoch(10) | |||||
| for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
| if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
| break | break | ||||
| already_seen_idx.update(batch) | |||||
| already_seen_idx.update(batch.tolist()) | |||||
| dist.barrier() | dist.barrier() | ||||
| if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | ||||
| sampler_states = replaced_loader.batch_sampler.state_dict() | sampler_states = replaced_loader.batch_sampler.state_dict() | ||||
| @@ -496,6 +500,7 @@ class TestSetDistReproDataloader: | |||||
| pad=True | pad=True | ||||
| ) | ) | ||||
| new_loader.batch_sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.load_state_dict(sampler_states) | ||||
| new_loader.batch_sampler.set_epoch(10) | |||||
| else: | else: | ||||
| batch_size = replaced_loader.batch_sampler.batch_size | batch_size = replaced_loader.batch_sampler.batch_size | ||||
| sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas | sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas | ||||
| @@ -508,8 +513,9 @@ class TestSetDistReproDataloader: | |||||
| ) | ) | ||||
| new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler) | new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler) | ||||
| new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | ||||
| new_loader.batch_sampler.sampler.set_epoch(10) | |||||
| for idx, batch in enumerate(new_loader): | for idx, batch in enumerate(new_loader): | ||||
| left_idxes.update(batch) | |||||
| left_idxes.update(batch.tolist()) | |||||
| assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas | assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas | ||||
| assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas | assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas | ||||
| @@ -533,7 +539,7 @@ class TestSaveLoad: | |||||
| cls.driver = generate_driver(10, 10, device=[0,1]) | cls.driver = generate_driver(10, 10, device=[0,1]) | ||||
| def setup_method(self): | def setup_method(self): | ||||
| self.dataset = PaddleRandomMaxDataset(20, 10) | |||||
| self.dataset = PaddleNormalXYDataset(40) | |||||
| @magic_argv_env_context | @magic_argv_env_context | ||||
| @pytest.mark.parametrize("only_state_dict", ([True, False])) | @pytest.mark.parametrize("only_state_dict", ([True, False])) | ||||
| @@ -545,12 +551,12 @@ class TestSaveLoad: | |||||
| path = "model" | path = "model" | ||||
| dataloader = DataLoader(self.dataset, batch_size=2) | dataloader = DataLoader(self.dataset, batch_size=2) | ||||
| self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10) | |||||
| self.driver1, self.driver2 = generate_driver(40, 1), generate_driver(40, 1) | |||||
| if only_state_dict: | if only_state_dict: | ||||
| self.driver1.save_model(path, only_state_dict) | self.driver1.save_model(path, only_state_dict) | ||||
| else: | else: | ||||
| self.driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((4, 10))]) | |||||
| self.driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((4, 1))]) | |||||
| # 同步 | # 同步 | ||||
| dist.barrier() | dist.barrier() | ||||
| @@ -594,8 +600,8 @@ class TestSaveLoad: | |||||
| path = "model.ckp" | path = "model.ckp" | ||||
| num_replicas = len(device) | num_replicas = len(device) | ||||
| self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ | |||||
| generate_driver(10, 10, device=device, fp16=False) | |||||
| self.driver1, self.driver2 = generate_driver(40, 1, device=device, fp16=fp16), \ | |||||
| generate_driver(40, 1, device=device, fp16=False) | |||||
| dataloader = DataLoader( | dataloader = DataLoader( | ||||
| dataset=self.dataset, | dataset=self.dataset, | ||||
| batch_sampler=BucketedBatchSampler( | batch_sampler=BucketedBatchSampler( | ||||
| @@ -613,11 +619,12 @@ class TestSaveLoad: | |||||
| already_seen_x_set = set() | already_seen_x_set = set() | ||||
| already_seen_y_set = set() | already_seen_y_set = set() | ||||
| self.driver1.set_sampler_epoch(dataloader, 2) | |||||
| for idx, batch in enumerate(dataloader): | for idx, batch in enumerate(dataloader): | ||||
| if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
| break | break | ||||
| already_seen_x_set.update(batch["x"]) | |||||
| already_seen_y_set.update(batch["y"]) | |||||
| already_seen_x_set.update(batch["x"].reshape((-1, )).tolist()) | |||||
| already_seen_y_set.update(batch["y"].reshape((-1, )).tolist()) | |||||
| # 同步 | # 同步 | ||||
| dist.barrier() | dist.barrier() | ||||
| @@ -669,10 +676,11 @@ class TestSaveLoad: | |||||
| assert start_batch == 2 * num_consumed_batches | assert start_batch == 2 * num_consumed_batches | ||||
| left_x_batches = set() | left_x_batches = set() | ||||
| left_y_batches = set() | left_y_batches = set() | ||||
| self.driver2.set_sampler_epoch(replaced_loader, 2) | |||||
| for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
| left_x_batches.update(batch["x"]) | |||||
| left_y_batches.update(batch["y"]) | |||||
| left_x_batches.update(batch["x"].reshape((-1, )).tolist()) | |||||
| left_y_batches.update(batch["y"].reshape((-1, )).tolist()) | |||||
| res1 = self.driver1.model( | res1 = self.driver1.model( | ||||
| batch, | batch, | ||||
| fastnlp_fn=self.driver1.model._layers.model.evaluate_step, | fastnlp_fn=self.driver1.model._layers.model.evaluate_step, | ||||
| @@ -709,8 +717,8 @@ class TestSaveLoad: | |||||
| num_replicas = len(device) | num_replicas = len(device) | ||||
| self.driver1 = generate_driver(10, 10, device=device, fp16=fp16) | |||||
| self.driver2 = generate_driver(10, 10, device=device, fp16=False) | |||||
| self.driver1 = generate_driver(40, 1, device=device, fp16=fp16) | |||||
| self.driver2 = generate_driver(40, 1, device=device, fp16=False) | |||||
| batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4) | batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4) | ||||
| batch_sampler.sampler = RandomSampler(self.dataset, True) | batch_sampler.sampler = RandomSampler(self.dataset, True) | ||||
| batch_sampler.sampler.set_distributed( | batch_sampler.sampler.set_distributed( | ||||
| @@ -726,11 +734,12 @@ class TestSaveLoad: | |||||
| already_seen_x_set = set() | already_seen_x_set = set() | ||||
| already_seen_y_set = set() | already_seen_y_set = set() | ||||
| self.driver1.set_sampler_epoch(dataloader, 2) | |||||
| for idx, batch in enumerate(dataloader): | for idx, batch in enumerate(dataloader): | ||||
| if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
| break | break | ||||
| already_seen_x_set.update(batch["x"]) | |||||
| already_seen_y_set.update(batch["y"]) | |||||
| already_seen_x_set.update(batch["x"].reshape((-1, )).tolist()) | |||||
| already_seen_y_set.update(batch["y"].reshape((-1, )).tolist()) | |||||
| # 同步 | # 同步 | ||||
| dist.barrier() | dist.barrier() | ||||
| @@ -779,10 +788,11 @@ class TestSaveLoad: | |||||
| assert start_batch == 2 * num_consumed_batches | assert start_batch == 2 * num_consumed_batches | ||||
| left_x_batches = set() | left_x_batches = set() | ||||
| left_y_batches = set() | left_y_batches = set() | ||||
| self.driver2.set_sampler_epoch(replaced_loader, 2) | |||||
| for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
| left_x_batches.update(batch["x"]) | |||||
| left_y_batches.update(batch["y"]) | |||||
| left_x_batches.update(batch["x"].reshape((-1, )).tolist()) | |||||
| left_y_batches.update(batch["y"].reshape((-1, )).tolist()) | |||||
| res1 = self.driver1.model( | res1 = self.driver1.model( | ||||
| batch, | batch, | ||||
| fastnlp_fn=self.driver1.model._layers.model.evaluate_step, | fastnlp_fn=self.driver1.model._layers.model.evaluate_step, | ||||
| @@ -12,7 +12,7 @@ if _NEED_IMPORT_PADDLE: | |||||
| @pytest.mark.paddle | @pytest.mark.paddle | ||||
| def test_incorrect_driver(): | def test_incorrect_driver(): | ||||
| model = PaddleNormalModel_Classification_1(2, 100) | |||||
| model = PaddleNormalModel_Classification_1(20, 10) | |||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| driver = initialize_paddle_driver("torch", 0, model) | driver = initialize_paddle_driver("torch", 0, model) | ||||
| @@ -26,7 +26,7 @@ def test_get_single_device(device): | |||||
| 测试正常情况下初始化 PaddleSingleDriver 的情况 | 测试正常情况下初始化 PaddleSingleDriver 的情况 | ||||
| """ | """ | ||||
| model = PaddleNormalModel_Classification_1(2, 100) | |||||
| model = PaddleNormalModel_Classification_1(20, 10) | |||||
| driver = initialize_paddle_driver("paddle", device, model) | driver = initialize_paddle_driver("paddle", device, model) | ||||
| assert isinstance(driver, PaddleSingleDriver) | assert isinstance(driver, PaddleSingleDriver) | ||||
| @@ -41,7 +41,7 @@ def test_get_fleet(device): | |||||
| 测试 fleet 多卡的初始化情况 | 测试 fleet 多卡的初始化情况 | ||||
| """ | """ | ||||
| model = PaddleNormalModel_Classification_1(64, 10) | |||||
| model = PaddleNormalModel_Classification_1(20, 10) | |||||
| driver = initialize_paddle_driver("paddle", device, model) | driver = initialize_paddle_driver("paddle", device, model) | ||||
| assert isinstance(driver, PaddleFleetDriver) | assert isinstance(driver, PaddleFleetDriver) | ||||
| @@ -56,6 +56,6 @@ def test_device_out_of_range(device): | |||||
| """ | """ | ||||
| 测试传入的device超过范围的情况 | 测试传入的device超过范围的情况 | ||||
| """ | """ | ||||
| model = PaddleNormalModel_Classification_1(2, 100) | |||||
| model = PaddleNormalModel_Classification_1(20, 10) | |||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| driver = initialize_paddle_driver("paddle", device, model) | driver = initialize_paddle_driver("paddle", device, model) | ||||
| @@ -4,14 +4,16 @@ from pathlib import Path | |||||
| from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | ||||
| from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | ||||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | ||||
| from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | |||||
| from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset | |||||
| from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
| from fastNLP.envs.distributed import rank_zero_rm | from fastNLP.envs.distributed import rank_zero_rm | ||||
| from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | ||||
| if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
| import paddle | import paddle | ||||
| from paddle.io import DataLoader, BatchSampler | from paddle.io import DataLoader, BatchSampler | ||||
| if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
| import torch | import torch | ||||
| @@ -31,102 +33,70 @@ class TestPaddleDriverFunctions: | |||||
| model = PaddleNormalModel_Classification_1(10, 32) | model = PaddleNormalModel_Classification_1(10, 32) | ||||
| self.driver = PaddleSingleDriver(model, device="cpu") | self.driver = PaddleSingleDriver(model, device="cpu") | ||||
| @pytest.mark.torchpaddle | |||||
| def test_check_single_optimizer_legality(self): | |||||
| @pytest.mark.paddle | |||||
| def test_check_optimizers_legality(self): | |||||
| """ | """ | ||||
| 测试传入单个 optimizer 时的表现 | |||||
| 测试对合法的 optimizers 的检查 | |||||
| """ | """ | ||||
| # 单个 optimizer | |||||
| optimizer = paddle.optimizer.Adam( | optimizer = paddle.optimizer.Adam( | ||||
| parameters=self.driver.model.parameters(), | parameters=self.driver.model.parameters(), | ||||
| learning_rate=0.01 | learning_rate=0.01 | ||||
| ) | ) | ||||
| self.driver.set_optimizers(optimizer) | self.driver.set_optimizers(optimizer) | ||||
| optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||||
| # 传入torch的optimizer时,应该报错ValueError | |||||
| with pytest.raises(ValueError): | |||||
| self.driver.set_optimizers(optimizer) | |||||
| @pytest.mark.torchpaddle | |||||
| def test_check_optimizers_legality(self): | |||||
| """ | |||||
| 测试传入 optimizer list 的表现 | |||||
| """ | |||||
| # optimizer 列表 | |||||
| optimizers = [ | optimizers = [ | ||||
| paddle.optimizer.Adam( | paddle.optimizer.Adam( | ||||
| parameters=self.driver.model.parameters(), | parameters=self.driver.model.parameters(), | ||||
| learning_rate=0.01 | learning_rate=0.01 | ||||
| ) for i in range(10) | ) for i in range(10) | ||||
| ] | ] | ||||
| self.driver.set_optimizers(optimizers) | self.driver.set_optimizers(optimizers) | ||||
| optimizers += [ | |||||
| @pytest.mark.torchpaddle | |||||
| def test_invalid_optimizers(self): | |||||
| """ | |||||
| 测试传入非法的 optimizers | |||||
| """ | |||||
| # 单个 optimizer | |||||
| optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||||
| with pytest.raises(TypeError): | |||||
| self.driver.set_optimizers(optimizer) | |||||
| optimizers = [ | |||||
| torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | ||||
| ] | ] | ||||
| with pytest.raises(ValueError): | |||||
| with pytest.raises(TypeError): | |||||
| self.driver.set_optimizers(optimizers) | self.driver.set_optimizers(optimizers) | ||||
| @pytest.mark.torchpaddle | |||||
| def test_check_dataloader_legality_in_train(self): | |||||
| @pytest.mark.paddle | |||||
| def test_check_dataloader_legality(self): | |||||
| """ | """ | ||||
| 测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 | |||||
| 测试 check_dataloader_legality 函数的表现 | |||||
| """ | """ | ||||
| dataloader = DataLoader(PaddleNormalDataset()) | dataloader = DataLoader(PaddleNormalDataset()) | ||||
| PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
| self.driver.check_dataloader_legality(dataloader) | |||||
| # batch_size 和 batch_sampler 均为 None 的情形 | # batch_size 和 batch_sampler 均为 None 的情形 | ||||
| dataloader = DataLoader(PaddleNormalDataset(), batch_size=None) | dataloader = DataLoader(PaddleNormalDataset(), batch_size=None) | ||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
| # 创建torch的dataloader | |||||
| dataloader = torch.utils.data.DataLoader( | |||||
| TorchNormalDataset(), | |||||
| batch_size=32, shuffle=True | |||||
| ) | |||||
| with pytest.raises(ValueError): | |||||
| PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
| self.driver.check_dataloader_legality(dataloader) | |||||
| @pytest.mark.torchpaddle | @pytest.mark.torchpaddle | ||||
| def test_check_dataloader_legality_in_test(self): | |||||
| def test_check_dataloader_legality_invalid(self): | |||||
| """ | """ | ||||
| 测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现 | |||||
| 测试 check_dataloader_legality 函数传入其他类型的表现 | |||||
| """ | """ | ||||
| # 此时传入的应该是dict | |||||
| dataloader = { | |||||
| "train": DataLoader(PaddleNormalDataset()), | |||||
| "test":DataLoader(PaddleNormalDataset()) | |||||
| } | |||||
| PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
| # batch_size 和 batch_sampler 均为 None 的情形 | |||||
| dataloader = { | |||||
| "train": DataLoader(PaddleNormalDataset()), | |||||
| "test":DataLoader(PaddleNormalDataset(), batch_size=None) | |||||
| } | |||||
| with pytest.raises(ValueError): | |||||
| PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
| # 传入的不是 dict ,应该报错 | |||||
| dataloader = DataLoader(PaddleNormalDataset()) | |||||
| with pytest.raises(ValueError): | |||||
| PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
| # 创建 torch 的 dataloader | # 创建 torch 的 dataloader | ||||
| train_loader = torch.utils.data.DataLoader( | |||||
| TorchNormalDataset(), | |||||
| batch_size=32, shuffle=True | |||||
| ) | |||||
| test_loader = torch.utils.data.DataLoader( | |||||
| dataloader = torch.utils.data.DataLoader( | |||||
| TorchNormalDataset(), | TorchNormalDataset(), | ||||
| batch_size=32, shuffle=True | batch_size=32, shuffle=True | ||||
| ) | ) | ||||
| dataloader = {"train": train_loader, "test": test_loader} | |||||
| with pytest.raises(ValueError): | |||||
| PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
| with pytest.raises(TypeError): | |||||
| self.driver.check_dataloader_legality(dataloader) | |||||
| @pytest.mark.paddle | @pytest.mark.paddle | ||||
| def test_tensor_to_numeric(self): | def test_tensor_to_numeric(self): | ||||
| @@ -505,10 +475,14 @@ class TestSetDistReproDataloader: | |||||
| # 迭代两个 batch | # 迭代两个 batch | ||||
| num_consumed_batches = 2 | num_consumed_batches = 2 | ||||
| already_seen_idx = set() | already_seen_idx = set() | ||||
| if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||||
| sampler_states = replaced_loader.batch_sampler.set_epoch(5) | |||||
| else: | |||||
| sampler_states = replaced_loader.batch_sampler.sampler.set_epoch(5) | |||||
| for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
| if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
| break | break | ||||
| already_seen_idx.update(batch) | |||||
| already_seen_idx.update(batch.tolist()) | |||||
| if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | ||||
| sampler_states = replaced_loader.batch_sampler.state_dict() | sampler_states = replaced_loader.batch_sampler.state_dict() | ||||
| else: | else: | ||||
| @@ -529,6 +503,7 @@ class TestSetDistReproDataloader: | |||||
| ) | ) | ||||
| ) | ) | ||||
| new_loader.batch_sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.load_state_dict(sampler_states) | ||||
| new_loader.batch_sampler.set_epoch(5) | |||||
| else: | else: | ||||
| batch_size = replaced_loader.batch_sampler.batch_size | batch_size = replaced_loader.batch_sampler.batch_size | ||||
| sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | ||||
| @@ -537,8 +512,9 @@ class TestSetDistReproDataloader: | |||||
| batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle) | batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle) | ||||
| new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler) | new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler) | ||||
| new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | ||||
| new_loader.batch_sampler.sampler.set_epoch(5) | |||||
| for idx, batch in enumerate(new_loader): | for idx, batch in enumerate(new_loader): | ||||
| left_idxes.update(batch) | |||||
| left_idxes.update(batch.tolist()) | |||||
| assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) | assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) | ||||
| assert len(left_idxes | already_seen_idx) == len(self.dataset) | assert len(left_idxes | already_seen_idx) == len(self.dataset) | ||||
| @@ -549,7 +525,7 @@ class TestSetDistReproDataloader: | |||||
| # | # | ||||
| ############################################################################ | ############################################################################ | ||||
| def generate_random_driver(features, labels, fp16=False, device="cpu"): | |||||
| def generate_random_driver(labels, features, fp16=False, device="cpu"): | |||||
| """ | """ | ||||
| 生成driver | 生成driver | ||||
| """ | """ | ||||
| @@ -569,9 +545,9 @@ def test_save_and_load_model(only_state_dict): | |||||
| """ | """ | ||||
| try: | try: | ||||
| path = "model" | path = "model" | ||||
| dataset = PaddleRandomMaxDataset(40, 10) | |||||
| dataset = PaddleNormalXYDataset(20) | |||||
| dataloader = DataLoader(dataset, batch_size=4) | dataloader = DataLoader(dataset, batch_size=4) | ||||
| driver1, driver2 = generate_random_driver(10, 10, device="gpu"), generate_random_driver(10, 10, device="gpu") | |||||
| driver1, driver2 = generate_random_driver(20, 1, device="gpu"), generate_random_driver(20, 1, device="gpu") | |||||
| if only_state_dict: | if only_state_dict: | ||||
| driver1.save_model(path, only_state_dict) | driver1.save_model(path, only_state_dict) | ||||
| @@ -580,6 +556,7 @@ def test_save_and_load_model(only_state_dict): | |||||
| driver2.load_model(path, only_state_dict) | driver2.load_model(path, only_state_dict) | ||||
| for batch in dataloader: | for batch in dataloader: | ||||
| print("?") | |||||
| batch = driver1.move_data_to_device(batch) | batch = driver1.move_data_to_device(batch) | ||||
| res1 = driver1.model.evaluate_step(**batch) | res1 = driver1.model.evaluate_step(**batch) | ||||
| res2 = driver2.model.evaluate_step(**batch) | res2 = driver2.model.evaluate_step(**batch) | ||||
| @@ -604,22 +581,23 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
| try: | try: | ||||
| path = "model.ckp" | path = "model.ckp" | ||||
| dataset = PaddleRandomMaxDataset(40, 10) | |||||
| dataset = PaddleNormalXYDataset(40) | |||||
| dataloader = DataLoader( | dataloader = DataLoader( | ||||
| dataset=dataset, | dataset=dataset, | ||||
| batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) | batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) | ||||
| ) | ) | ||||
| driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu") | |||||
| driver1, driver2 = generate_random_driver(40, 1, fp16, "gpu"), generate_random_driver(40, 1, False, "gpu") | |||||
| num_consumed_batches = 2 | num_consumed_batches = 2 | ||||
| already_seen_x_set = set() | already_seen_x_set = set() | ||||
| already_seen_y_set = set() | already_seen_y_set = set() | ||||
| driver1.set_sampler_epoch(dataloader, 3) | |||||
| for idx, batch in enumerate(dataloader): | for idx, batch in enumerate(dataloader): | ||||
| if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
| break | break | ||||
| already_seen_x_set.update(batch["x"]) | |||||
| already_seen_y_set.update(batch["y"]) | |||||
| already_seen_x_set.update(batch["x"].reshape((-1, )).tolist()) | |||||
| already_seen_y_set.update(batch["y"].reshape((-1, )).tolist()) | |||||
| sampler_states = dataloader.batch_sampler.state_dict() | sampler_states = dataloader.batch_sampler.state_dict() | ||||
| save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
| @@ -656,10 +634,11 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
| assert start_batch == 2 * num_consumed_batches | assert start_batch == 2 * num_consumed_batches | ||||
| left_x_batches = set() | left_x_batches = set() | ||||
| left_y_batches = set() | left_y_batches = set() | ||||
| driver2.set_sampler_epoch(replaced_loader, 3) | |||||
| for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
| left_x_batches.update(batch["x"]) | |||||
| left_y_batches.update(batch["y"]) | |||||
| left_x_batches.update(batch["x"].reshape((-1, )).tolist()) | |||||
| left_y_batches.update(batch["y"].reshape((-1, )).tolist()) | |||||
| res1 = driver1.model.evaluate_step(**batch) | res1 = driver1.model.evaluate_step(**batch) | ||||
| res2 = driver2.model.evaluate_step(**batch) | res2 = driver2.model.evaluate_step(**batch) | ||||
| assert paddle.equal_all(res1["pred"], res2["pred"]) | assert paddle.equal_all(res1["pred"], res2["pred"]) | ||||
| @@ -679,14 +658,14 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
| @pytest.mark.parametrize("fp16", ([True, False])) | @pytest.mark.parametrize("fp16", ([True, False])) | ||||
| def test_save_and_load_with_randomsampler(only_state_dict, fp16): | def test_save_and_load_with_randomsampler(only_state_dict, fp16): | ||||
| """ | """ | ||||
| 测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 | |||||
| 测试save和load函数,主要测试 dataloader 被替换了 sampler 的情况 | |||||
| """ | """ | ||||
| try: | try: | ||||
| path = "model.ckp" | path = "model.ckp" | ||||
| driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu") | |||||
| dataset = PaddleRandomMaxDataset(40, 10) | |||||
| driver1, driver2 = generate_random_driver(40, 1, fp16, "gpu"), generate_random_driver(40, 1, False, "gpu") | |||||
| dataset = PaddleNormalXYDataset(40) | |||||
| batch_sampler = BatchSampler(dataset=dataset, batch_size=4) | batch_sampler = BatchSampler(dataset=dataset, batch_size=4) | ||||
| batch_sampler.sampler = RandomSampler(dataset, True) | batch_sampler.sampler = RandomSampler(dataset, True) | ||||
| dataloader = DataLoader( | dataloader = DataLoader( | ||||
| @@ -697,11 +676,12 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||||
| already_seen_x_set = set() | already_seen_x_set = set() | ||||
| already_seen_y_set = set() | already_seen_y_set = set() | ||||
| driver1.set_sampler_epoch(dataloader, 3) | |||||
| for idx, batch in enumerate(dataloader): | for idx, batch in enumerate(dataloader): | ||||
| if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
| break | break | ||||
| already_seen_x_set.update(batch["x"]) | |||||
| already_seen_y_set.update(batch["y"]) | |||||
| already_seen_x_set.update(batch["x"].reshape((-1, )).tolist()) | |||||
| already_seen_y_set.update(batch["y"].reshape((-1, )).tolist()) | |||||
| sampler_states = dataloader.batch_sampler.sampler.state_dict() | sampler_states = dataloader.batch_sampler.sampler.state_dict() | ||||
| save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
| @@ -743,10 +723,11 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||||
| assert start_batch == 2 * num_consumed_batches | assert start_batch == 2 * num_consumed_batches | ||||
| left_x_batches = set() | left_x_batches = set() | ||||
| left_y_batches = set() | left_y_batches = set() | ||||
| driver1.set_sampler_epoch(replaced_loader, 3) | |||||
| for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
| left_x_batches.update(batch["x"]) | |||||
| left_y_batches.update(batch["y"]) | |||||
| left_x_batches.update(batch["x"].reshape((-1, )).tolist()) | |||||
| left_y_batches.update(batch["y"].reshape((-1, )).tolist()) | |||||
| res1 = driver1.model.evaluate_step(**batch) | res1 = driver1.model.evaluate_step(**batch) | ||||
| res2 = driver2.model.evaluate_step(**batch) | res2 = driver2.model.evaluate_step(**batch) | ||||
| assert paddle.equal_all(res1["pred"], res2["pred"]) | assert paddle.equal_all(res1["pred"], res2["pred"]) | ||||
| @@ -10,7 +10,7 @@ from fastNLP.core.samplers import ( | |||||
| UnrepeatedSequentialSampler, | UnrepeatedSequentialSampler, | ||||
| ) | ) | ||||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
| from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | |||||
| from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset | |||||
| from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
| from fastNLP.envs.distributed import rank_zero_rm | from fastNLP.envs.distributed import rank_zero_rm | ||||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
| @@ -19,8 +19,8 @@ if _NEED_IMPORT_TORCH: | |||||
| import torch.distributed as dist | import torch.distributed as dist | ||||
| from torch.utils.data import DataLoader, BatchSampler | from torch.utils.data import DataLoader, BatchSampler | ||||
| def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="all"): | |||||
| torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension) | |||||
| def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="all"): | |||||
| torch_model = TorchNormalModel_Classification_1(labels, features) | |||||
| torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) | torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) | ||||
| device = [torch.device(i) for i in device] | device = [torch.device(i) for i in device] | ||||
| driver = TorchDDPDriver( | driver = TorchDDPDriver( | ||||
| @@ -504,10 +504,14 @@ class TestSetDistReproDataloader: | |||||
| num_replicas = len(self.device) | num_replicas = len(self.device) | ||||
| num_consumed_batches = 2 | num_consumed_batches = 2 | ||||
| already_seen_idx = set() | already_seen_idx = set() | ||||
| if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | |||||
| sampler_states = replaced_loader.batch_sampler.set_epoch(4) | |||||
| else: | |||||
| sampler_states = replaced_loader.batch_sampler.sampler.set_epoch(4) | |||||
| for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
| if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
| break | break | ||||
| already_seen_idx.update(batch) | |||||
| already_seen_idx.update(batch.tolist()) | |||||
| dist.barrier() | dist.barrier() | ||||
| if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | ||||
| sampler_states = replaced_loader.batch_sampler.state_dict() | sampler_states = replaced_loader.batch_sampler.state_dict() | ||||
| @@ -533,6 +537,7 @@ class TestSetDistReproDataloader: | |||||
| pad=True | pad=True | ||||
| ) | ) | ||||
| new_loader.batch_sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.load_state_dict(sampler_states) | ||||
| new_loader.batch_sampler.set_epoch(4) | |||||
| else: | else: | ||||
| batch_size = replaced_loader.batch_sampler.batch_size | batch_size = replaced_loader.batch_sampler.batch_size | ||||
| sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas | sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas | ||||
| @@ -543,8 +548,9 @@ class TestSetDistReproDataloader: | |||||
| rank=driver.global_rank | rank=driver.global_rank | ||||
| ) | ) | ||||
| new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | ||||
| new_loader.batch_sampler.sampler.set_epoch(4) | |||||
| for idx, batch in enumerate(new_loader): | for idx, batch in enumerate(new_loader): | ||||
| left_idxes.update(batch) | |||||
| left_idxes.update(batch.tolist()) | |||||
| assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas | assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas | ||||
| assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas | assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas | ||||
| @@ -562,7 +568,7 @@ class TestSaveLoad: | |||||
| """ | """ | ||||
| def setup_method(self): | def setup_method(self): | ||||
| self.dataset = TorchArgMaxDataset(10, 20) | |||||
| self.dataset = TorchNormalXYDataset(20) | |||||
| @magic_argv_env_context | @magic_argv_env_context | ||||
| @pytest.mark.parametrize("only_state_dict", ([True, False])) | @pytest.mark.parametrize("only_state_dict", ([True, False])) | ||||
| @@ -574,7 +580,7 @@ class TestSaveLoad: | |||||
| path = "model" | path = "model" | ||||
| dataloader = DataLoader(self.dataset, batch_size=2) | dataloader = DataLoader(self.dataset, batch_size=2) | ||||
| driver1, driver2 = generate_driver(10, 10), generate_driver(10, 10) | |||||
| driver1, driver2 = generate_driver(20, 1), generate_driver(20, 1) | |||||
| driver1.save_model(path, only_state_dict) | driver1.save_model(path, only_state_dict) | ||||
| @@ -618,8 +624,8 @@ class TestSaveLoad: | |||||
| path = "model.ckp" | path = "model.ckp" | ||||
| num_replicas = len(device) | num_replicas = len(device) | ||||
| driver1, driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ | |||||
| generate_driver(10, 10, device=device, fp16=False) | |||||
| driver1, driver2 = generate_driver(20, 1, device=device, fp16=fp16), \ | |||||
| generate_driver(20, 1, device=device, fp16=False) | |||||
| dataloader = dataloader_with_bucketedbatchsampler( | dataloader = dataloader_with_bucketedbatchsampler( | ||||
| self.dataset, | self.dataset, | ||||
| length=[10 for i in range(len(self.dataset))], | length=[10 for i in range(len(self.dataset))], | ||||
| @@ -636,11 +642,12 @@ class TestSaveLoad: | |||||
| already_seen_x_set = set() | already_seen_x_set = set() | ||||
| already_seen_y_set = set() | already_seen_y_set = set() | ||||
| driver1.set_sampler_epoch(dataloader, 4) | |||||
| for idx, batch in enumerate(dataloader): | for idx, batch in enumerate(dataloader): | ||||
| if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
| break | break | ||||
| already_seen_x_set.update(batch["x"]) | |||||
| already_seen_y_set.update(batch["y"]) | |||||
| already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||||
| already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||||
| # 同步 | # 同步 | ||||
| dist.barrier() | dist.barrier() | ||||
| @@ -665,7 +672,6 @@ class TestSaveLoad: | |||||
| pad=True | pad=True | ||||
| ) | ) | ||||
| dist.barrier() | dist.barrier() | ||||
| print("========load=======", driver1.global_rank, driver2.global_rank) | |||||
| load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | ||||
| dist.barrier() | dist.barrier() | ||||
| replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
| @@ -690,10 +696,11 @@ class TestSaveLoad: | |||||
| assert start_batch == 2 * num_consumed_batches | assert start_batch == 2 * num_consumed_batches | ||||
| left_x_batches = set() | left_x_batches = set() | ||||
| left_y_batches = set() | left_y_batches = set() | ||||
| driver2.set_sampler_epoch(replaced_loader, 4) | |||||
| for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
| left_x_batches.update(batch["x"]) | |||||
| left_y_batches.update(batch["y"]) | |||||
| left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||||
| left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||||
| res1 = driver1.model( | res1 = driver1.model( | ||||
| batch, | batch, | ||||
| fastnlp_fn=driver1.model.module.model.evaluate_step, | fastnlp_fn=driver1.model.module.model.evaluate_step, | ||||
| @@ -716,7 +723,6 @@ class TestSaveLoad: | |||||
| dist.barrier() | dist.barrier() | ||||
| finally: | finally: | ||||
| rank_zero_rm(path) | rank_zero_rm(path) | ||||
| print("=======delete======") | |||||
| if dist.is_initialized(): | if dist.is_initialized(): | ||||
| dist.destroy_process_group() | dist.destroy_process_group() | ||||
| @@ -735,8 +741,8 @@ class TestSaveLoad: | |||||
| num_replicas = len(device) | num_replicas = len(device) | ||||
| driver1 = generate_driver(10, 10, device=device, fp16=fp16) | |||||
| driver2 = generate_driver(10, 10, device=device, fp16=False) | |||||
| driver1 = generate_driver(20, 1, device=device, fp16=fp16) | |||||
| driver2 = generate_driver(20, 1, device=device, fp16=False) | |||||
| dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) | dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) | ||||
| dataloader.batch_sampler.sampler.set_distributed( | dataloader.batch_sampler.sampler.set_distributed( | ||||
| @@ -748,11 +754,12 @@ class TestSaveLoad: | |||||
| already_seen_x_set = set() | already_seen_x_set = set() | ||||
| already_seen_y_set = set() | already_seen_y_set = set() | ||||
| driver1.set_sampler_epoch(dataloader, 4) | |||||
| for idx, batch in enumerate(dataloader): | for idx, batch in enumerate(dataloader): | ||||
| if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
| break | break | ||||
| already_seen_x_set.update(batch["x"]) | |||||
| already_seen_y_set.update(batch["y"]) | |||||
| already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||||
| already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||||
| # 同步 | # 同步 | ||||
| dist.barrier() | dist.barrier() | ||||
| @@ -797,10 +804,11 @@ class TestSaveLoad: | |||||
| assert start_batch == 2 * num_consumed_batches | assert start_batch == 2 * num_consumed_batches | ||||
| left_x_batches = set() | left_x_batches = set() | ||||
| left_y_batches = set() | left_y_batches = set() | ||||
| driver2.set_sampler_epoch(replaced_loader, 4) | |||||
| for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
| left_x_batches.update(batch["x"]) | |||||
| left_y_batches.update(batch["y"]) | |||||
| left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||||
| left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||||
| res1 = driver1.model( | res1 = driver1.model( | ||||
| batch, | batch, | ||||
| fastnlp_fn=driver1.model.module.model.evaluate_step, | fastnlp_fn=driver1.model.module.model.evaluate_step, | ||||
| @@ -14,7 +14,7 @@ else: | |||||
| @pytest.mark.torch | @pytest.mark.torch | ||||
| def test_incorrect_driver(): | def test_incorrect_driver(): | ||||
| model = TorchNormalModel_Classification_1(2, 100) | |||||
| model = TorchNormalModel_Classification_1(20, 10) | |||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| driver = initialize_torch_driver("paddle", 0, model) | driver = initialize_torch_driver("paddle", 0, model) | ||||
| @@ -33,7 +33,7 @@ def test_get_single_device(driver, device): | |||||
| 测试正常情况下初始化TorchSingleDriver的情况 | 测试正常情况下初始化TorchSingleDriver的情况 | ||||
| """ | """ | ||||
| model = TorchNormalModel_Classification_1(2, 100) | |||||
| model = TorchNormalModel_Classification_1(20, 10) | |||||
| driver = initialize_torch_driver(driver, device, model) | driver = initialize_torch_driver(driver, device, model) | ||||
| assert isinstance(driver, TorchSingleDriver) | assert isinstance(driver, TorchSingleDriver) | ||||
| @@ -52,7 +52,7 @@ def test_get_ddp(driver, device): | |||||
| 测试 ddp 多卡的初始化情况 | 测试 ddp 多卡的初始化情况 | ||||
| """ | """ | ||||
| model = TorchNormalModel_Classification_1(64, 10) | |||||
| model = TorchNormalModel_Classification_1(20, 10) | |||||
| driver = initialize_torch_driver(driver, device, model) | driver = initialize_torch_driver(driver, device, model) | ||||
| assert isinstance(driver, TorchDDPDriver) | assert isinstance(driver, TorchDDPDriver) | ||||
| @@ -70,6 +70,6 @@ def test_device_out_of_range(driver, device): | |||||
| """ | """ | ||||
| 测试传入的device超过范围的情况 | 测试传入的device超过范围的情况 | ||||
| """ | """ | ||||
| model = TorchNormalModel_Classification_1(2, 100) | |||||
| model = TorchNormalModel_Classification_1(20, 10) | |||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| driver = initialize_torch_driver(driver, device, model) | driver = initialize_torch_driver(driver, device, model) | ||||
| @@ -6,7 +6,7 @@ from pkg_resources import parse_version | |||||
| from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver | from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver | ||||
| from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | ||||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
| from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | |||||
| from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset | |||||
| from tests.helpers.datasets.paddle_data import PaddleNormalDataset | from tests.helpers.datasets.paddle_data import PaddleNormalDataset | ||||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | ||||
| from fastNLP.envs.distributed import rank_zero_rm | from fastNLP.envs.distributed import rank_zero_rm | ||||
| @@ -15,6 +15,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | |||||
| if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
| import torch | import torch | ||||
| from torch.utils.data import DataLoader, BatchSampler | from torch.utils.data import DataLoader, BatchSampler | ||||
| if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
| import paddle | import paddle | ||||
| @@ -67,95 +68,67 @@ class TestTorchDriverFunctions: | |||||
| model = TorchNormalModel_Classification_1(10, 32) | model = TorchNormalModel_Classification_1(10, 32) | ||||
| self.driver = TorchSingleDriver(model, device="cpu") | self.driver = TorchSingleDriver(model, device="cpu") | ||||
| @pytest.mark.torchpaddle | |||||
| def test_check_single_optimizer_legality(self): | |||||
| @pytest.mark.torch | |||||
| def test_check_optimizers_legality(self): | |||||
| """ | """ | ||||
| 测试传入单个 optimizer 时的表现 | |||||
| 测试对合法 optimizers 的检查 | |||||
| """ | """ | ||||
| # 单个 optimizer | |||||
| optimizer = torch.optim.Adam( | optimizer = torch.optim.Adam( | ||||
| params=self.driver.model.parameters(), | params=self.driver.model.parameters(), | ||||
| lr=0.01 | lr=0.01 | ||||
| ) | ) | ||||
| self.driver.set_optimizers(optimizer) | self.driver.set_optimizers(optimizer) | ||||
| optimizer = paddle.optimizer.Adam( | |||||
| parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), | |||||
| learning_rate=0.01, | |||||
| ) | |||||
| # 传入 torch 的 optimize r时,应该报错 ValueError | |||||
| with pytest.raises(ValueError): | |||||
| self.driver.set_optimizers(optimizer) | |||||
| @pytest.mark.torchpaddle | |||||
| def test_check_optimizers_legality(self): | |||||
| """ | |||||
| 测试传入 optimizer list 的表现 | |||||
| """ | |||||
| # 列表 | |||||
| optimizers = [ | optimizers = [ | ||||
| torch.optim.Adam( | torch.optim.Adam( | ||||
| params=self.driver.model.parameters(), | params=self.driver.model.parameters(), | ||||
| lr=0.01 | lr=0.01 | ||||
| ) for i in range(10) | ) for i in range(10) | ||||
| ] | ] | ||||
| self.driver.set_optimizers(optimizers) | self.driver.set_optimizers(optimizers) | ||||
| optimizers += [ | |||||
| @pytest.mark.torchpaddle | |||||
| def test_invalid_optimizers(self): | |||||
| """ | |||||
| 测试传入非法的 optimizers | |||||
| """ | |||||
| optimizer = paddle.optimizer.Adam( | |||||
| parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), | |||||
| learning_rate=0.01, | |||||
| ) | |||||
| with pytest.raises(TypeError): | |||||
| self.driver.set_optimizers(optimizer) | |||||
| optimizers = [ | |||||
| paddle.optimizer.Adam( | paddle.optimizer.Adam( | ||||
| parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), | parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), | ||||
| learning_rate=0.01, | learning_rate=0.01, | ||||
| ) | ) | ||||
| ] | ] | ||||
| with pytest.raises(ValueError): | |||||
| with pytest.raises(TypeError): | |||||
| self.driver.set_optimizers(optimizers) | self.driver.set_optimizers(optimizers) | ||||
| @pytest.mark.torchpaddle | |||||
| def test_check_dataloader_legality_in_train(self): | |||||
| @pytest.mark.torch | |||||
| def test_check_dataloader_legality(self): | |||||
| """ | """ | ||||
| 测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 | |||||
| 测试 check_dataloader_legality 函数的表现 | |||||
| """ | """ | ||||
| dataloader = DataLoader(TorchNormalDataset()) | dataloader = DataLoader(TorchNormalDataset()) | ||||
| TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
| # 创建 paddle 的 dataloader | |||||
| dataloader = paddle.io.DataLoader( | |||||
| PaddleNormalDataset(), | |||||
| batch_size=32, shuffle=True | |||||
| ) | |||||
| with pytest.raises(ValueError): | |||||
| TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
| self.driver.check_dataloader_legality(dataloader) | |||||
| @pytest.mark.torchpaddle | @pytest.mark.torchpaddle | ||||
| def test_check_dataloader_legality_in_test(self): | |||||
| def test_check_dataloader_legality_invalid(self): | |||||
| """ | """ | ||||
| 测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现 | |||||
| 测试 check_dataloader_legality 函数传入其他类型的表现 | |||||
| """ | """ | ||||
| # 此时传入的应该是dict | |||||
| dataloader = { | |||||
| "train": DataLoader(TorchNormalDataset()), | |||||
| "test": DataLoader(TorchNormalDataset()) | |||||
| } | |||||
| TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
| # 传入的不是 dict,应该报错 | |||||
| dataloader = DataLoader(TorchNormalDataset()) | |||||
| with pytest.raises(ValueError): | |||||
| TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
| # 创建 paddle 的 dataloader | # 创建 paddle 的 dataloader | ||||
| train_loader = paddle.io.DataLoader( | |||||
| PaddleNormalDataset(), | |||||
| batch_size=32, shuffle=True | |||||
| ) | |||||
| test_loader = paddle.io.DataLoader( | |||||
| dataloader = paddle.io.DataLoader( | |||||
| PaddleNormalDataset(), | PaddleNormalDataset(), | ||||
| batch_size=32, shuffle=True | batch_size=32, shuffle=True | ||||
| ) | ) | ||||
| dataloader = {"train": train_loader, "test": test_loader} | |||||
| with pytest.raises(ValueError): | |||||
| TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||||
| with pytest.raises(TypeError): | |||||
| self.driver.check_dataloader_legality(dataloader) | |||||
| @pytest.mark.torch | @pytest.mark.torch | ||||
| def test_tensor_to_numeric(self): | def test_tensor_to_numeric(self): | ||||
| @@ -515,10 +488,14 @@ class TestSetDistReproDataloader: | |||||
| # 迭代两个 batch | # 迭代两个 batch | ||||
| num_consumed_batches = 2 | num_consumed_batches = 2 | ||||
| already_seen_idx = set() | already_seen_idx = set() | ||||
| if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||||
| replaced_loader.batch_sampler.set_epoch(3) | |||||
| else: | |||||
| replaced_loader.batch_sampler.sampler.set_epoch(3) | |||||
| for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
| if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
| break | break | ||||
| already_seen_idx.update(batch) | |||||
| already_seen_idx.update(batch.tolist()) | |||||
| if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | ||||
| sampler_states = replaced_loader.batch_sampler.state_dict() | sampler_states = replaced_loader.batch_sampler.state_dict() | ||||
| else: | else: | ||||
| @@ -532,14 +509,16 @@ class TestSetDistReproDataloader: | |||||
| # 重新改造 dataloader | # 重新改造 dataloader | ||||
| new_loader = dataloader_with_randombatchsampler(replaced_loader.dataset, batch_size, shuffle, False) | new_loader = dataloader_with_randombatchsampler(replaced_loader.dataset, batch_size, shuffle, False) | ||||
| new_loader.batch_sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.load_state_dict(sampler_states) | ||||
| new_loader.batch_sampler.set_epoch(3) | |||||
| else: | else: | ||||
| batch_size = replaced_loader.batch_sampler.batch_size | batch_size = replaced_loader.batch_sampler.batch_size | ||||
| sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | ||||
| # 重新构造 dataloader | # 重新构造 dataloader | ||||
| new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, False) | new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, False) | ||||
| new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | ||||
| new_loader.batch_sampler.sampler.set_epoch(3) | |||||
| for idx, batch in enumerate(new_loader): | for idx, batch in enumerate(new_loader): | ||||
| left_idxes.update(batch) | |||||
| left_idxes.update(batch.tolist()) | |||||
| assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) | assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) | ||||
| assert len(left_idxes | already_seen_idx) == len(self.dataset) | assert len(left_idxes | already_seen_idx) == len(self.dataset) | ||||
| @@ -550,7 +529,7 @@ class TestSetDistReproDataloader: | |||||
| # | # | ||||
| ############################################################################ | ############################################################################ | ||||
| def generate_random_driver(features, labels, fp16=False, device="cpu"): | |||||
| def generate_random_driver(labels, features, fp16=False, device="cpu"): | |||||
| """ | """ | ||||
| 生成driver | 生成driver | ||||
| """ | """ | ||||
| @@ -570,9 +549,9 @@ def test_save_and_load_model(only_state_dict): | |||||
| """ | """ | ||||
| try: | try: | ||||
| path = "model" | path = "model" | ||||
| dataset = TorchArgMaxDataset(10, 40) | |||||
| dataset = TorchNormalXYDataset(20) | |||||
| dataloader = DataLoader(dataset, batch_size=4) | dataloader = DataLoader(dataset, batch_size=4) | ||||
| driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) | |||||
| driver1, driver2 = generate_random_driver(20, 1), generate_random_driver(20, 1) | |||||
| driver1.save_model(path, only_state_dict) | driver1.save_model(path, only_state_dict) | ||||
| driver2.load_model(path, only_state_dict) | driver2.load_model(path, only_state_dict) | ||||
| @@ -596,19 +575,20 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
| try: | try: | ||||
| path = "model.ckp" | path = "model.ckp" | ||||
| dataset = TorchArgMaxDataset(10, 40) | |||||
| dataset = TorchNormalXYDataset(20) | |||||
| dataloader = dataloader_with_randombatchsampler(dataset, 4, True, False) | dataloader = dataloader_with_randombatchsampler(dataset, 4, True, False) | ||||
| driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda") | |||||
| driver1, driver2 = generate_random_driver(20, 1, fp16, "cuda"), generate_random_driver(20, 1, False, "cuda") | |||||
| num_consumed_batches = 2 | num_consumed_batches = 2 | ||||
| already_seen_x_set = set() | already_seen_x_set = set() | ||||
| already_seen_y_set = set() | already_seen_y_set = set() | ||||
| driver1.set_sampler_epoch(dataloader, 3) | |||||
| for idx, batch in enumerate(dataloader): | for idx, batch in enumerate(dataloader): | ||||
| if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
| break | break | ||||
| already_seen_x_set.update(batch["x"]) | |||||
| already_seen_y_set.update(batch["y"]) | |||||
| already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||||
| already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||||
| sampler_states = dataloader.batch_sampler.state_dict() | sampler_states = dataloader.batch_sampler.state_dict() | ||||
| save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
| @@ -639,11 +619,12 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
| assert start_batch == 2 * num_consumed_batches | assert start_batch == 2 * num_consumed_batches | ||||
| left_x_batches = set() | left_x_batches = set() | ||||
| left_y_batches = set() | left_y_batches = set() | ||||
| driver1.set_sampler_epoch(replaced_loader, 3) | |||||
| for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
| batch = driver2.move_data_to_device(batch) | batch = driver2.move_data_to_device(batch) | ||||
| left_x_batches.update(batch["x"]) | |||||
| left_y_batches.update(batch["y"]) | |||||
| left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||||
| left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||||
| res1 = driver1.model.evaluate_step(**batch) | res1 = driver1.model.evaluate_step(**batch) | ||||
| res2 = driver2.model.evaluate_step(**batch) | res2 = driver2.model.evaluate_step(**batch) | ||||
| assert torch.equal(res1["preds"], res2["preds"]) | assert torch.equal(res1["preds"], res2["preds"]) | ||||
| @@ -660,24 +641,25 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
| @pytest.mark.parametrize("fp16", ([True, False])) | @pytest.mark.parametrize("fp16", ([True, False])) | ||||
| def test_save_and_load_with_randomsampler(only_state_dict, fp16): | def test_save_and_load_with_randomsampler(only_state_dict, fp16): | ||||
| """ | """ | ||||
| 测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 | |||||
| 测试save和load函数,主要测试 dataloader 被替换了 sampler 的情况 | |||||
| """ | """ | ||||
| try: | try: | ||||
| path = "model.ckp" | path = "model.ckp" | ||||
| driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda") | |||||
| dataset = TorchArgMaxDataset(10, 40) | |||||
| driver1, driver2 = generate_random_driver(40, 1, fp16, "cuda"), generate_random_driver(40, 1, False, "cuda") | |||||
| dataset = TorchNormalXYDataset(40) | |||||
| dataloader = dataloader_with_randomsampler(dataset, 4, True, False) | dataloader = dataloader_with_randomsampler(dataset, 4, True, False) | ||||
| num_consumed_batches = 2 | num_consumed_batches = 2 | ||||
| already_seen_x_set = set() | already_seen_x_set = set() | ||||
| already_seen_y_set = set() | already_seen_y_set = set() | ||||
| driver1.set_sampler_epoch(dataloader, 3) | |||||
| for idx, batch in enumerate(dataloader): | for idx, batch in enumerate(dataloader): | ||||
| if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
| break | break | ||||
| already_seen_x_set.update(batch["x"]) | |||||
| already_seen_y_set.update(batch["y"]) | |||||
| already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||||
| already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||||
| sampler_states = dataloader.batch_sampler.sampler.state_dict() | sampler_states = dataloader.batch_sampler.sampler.state_dict() | ||||
| save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
| @@ -711,11 +693,13 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||||
| assert start_batch == 2 * num_consumed_batches | assert start_batch == 2 * num_consumed_batches | ||||
| left_x_batches = set() | left_x_batches = set() | ||||
| left_y_batches = set() | left_y_batches = set() | ||||
| # set epoch | |||||
| driver2.set_sampler_epoch(replaced_loader, 3) | |||||
| for idx, batch in enumerate(replaced_loader): | for idx, batch in enumerate(replaced_loader): | ||||
| batch = driver2.move_data_to_device(batch) | batch = driver2.move_data_to_device(batch) | ||||
| left_x_batches.update(batch["x"]) | |||||
| left_y_batches.update(batch["y"]) | |||||
| left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||||
| left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||||
| res1 = driver1.model.evaluate_step(**batch) | res1 = driver1.model.evaluate_step(**batch) | ||||
| res2 = driver2.model.evaluate_step(**batch) | res2 = driver2.model.evaluate_step(**batch) | ||||
| assert torch.equal(res1["preds"], res2["preds"]) | assert torch.equal(res1["preds"], res2["preds"]) | ||||
| @@ -0,0 +1,46 @@ | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||||
| if _NEED_IMPORT_JITTOR: | |||||
| import jittor as jt | |||||
| from jittor.dataset import Dataset | |||||
| else: | |||||
| from fastNLP.core.utils.dummy_class import DummyClass as Dataset | |||||
| class JittorNormalDataset(Dataset): | |||||
| def __init__(self, num_of_data=100, **kwargs): | |||||
| super(JittorNormalDataset, self).__init__(**kwargs) | |||||
| self._data = list(range(num_of_data)) | |||||
| self.set_attrs(total_len=num_of_data) | |||||
| def __getitem__(self, item): | |||||
| return self._data[item] | |||||
| class JittorNormalXYDataset(Dataset): | |||||
| """ | |||||
| 可以被输入到分类模型中的普通数据集 | |||||
| """ | |||||
| def __init__(self, num_of_data=1000, **kwargs): | |||||
| super(JittorNormalXYDataset, self).__init__(**kwargs) | |||||
| self.num_of_data = num_of_data | |||||
| self._data = list(range(num_of_data)) | |||||
| self.set_attrs(total_len=num_of_data) | |||||
| def __getitem__(self, item): | |||||
| return { | |||||
| "x": jt.Var([self._data[item]]), | |||||
| "y": jt.Var([self._data[item]]) | |||||
| } | |||||
| class JittorArgMaxDataset(Dataset): | |||||
| def __init__(self, num_samples, num_features, **kwargs): | |||||
| super(JittorArgMaxDataset, self).__init__(**kwargs) | |||||
| self.x = jt.randn(num_samples, num_features) | |||||
| self.y = self.x.argmax(dim=-1) | |||||
| self.set_attrs(total_len=num_samples) | |||||
| def __getitem__(self, item): | |||||
| return {"x": self.x[item], "y": self.y[item]} | |||||
| if __name__ == "__main__": | |||||
| dataset = JittorNormalDataset() | |||||
| print(len(dataset)) | |||||
| @@ -19,8 +19,24 @@ class PaddleNormalDataset(Dataset): | |||||
| def __getitem__(self, item): | def __getitem__(self, item): | ||||
| return self._data[item] | return self._data[item] | ||||
| class PaddleNormalXYDataset(Dataset): | |||||
| """ | |||||
| 可以被输入到分类模型中的普通数据集 | |||||
| """ | |||||
| def __init__(self, num_of_data=1000): | |||||
| self.num_of_data = num_of_data | |||||
| self._data = list(range(num_of_data)) | |||||
| def __len__(self): | |||||
| return self.num_of_data | |||||
| def __getitem__(self, item): | |||||
| return { | |||||
| "x": paddle.to_tensor([self._data[item]], dtype="float32"), | |||||
| "y": paddle.to_tensor([self._data[item]], dtype="float32") | |||||
| } | |||||
| class PaddleRandomMaxDataset(Dataset): | |||||
| class PaddleArgMaxDataset(Dataset): | |||||
| def __init__(self, num_samples, num_features): | def __init__(self, num_samples, num_features): | ||||
| self.x = paddle.randn((num_samples, num_features)) | self.x = paddle.randn((num_samples, num_features)) | ||||
| self.y = self.x.argmax(axis=-1) | self.y = self.x.argmax(axis=-1) | ||||
| @@ -1,4 +1,6 @@ | |||||
| from functools import reduce | from functools import reduce | ||||
| from numpy import dtype | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
| if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
| @@ -19,6 +21,23 @@ class TorchNormalDataset(Dataset): | |||||
| def __getitem__(self, item): | def __getitem__(self, item): | ||||
| return self._data[item] | return self._data[item] | ||||
| class TorchNormalXYDataset(Dataset): | |||||
| """ | |||||
| 可以被输入到分类模型中的普通数据集 | |||||
| """ | |||||
| def __init__(self, num_of_data=1000): | |||||
| self.num_of_data = num_of_data | |||||
| self._data = list(range(num_of_data)) | |||||
| def __len__(self): | |||||
| return self.num_of_data | |||||
| def __getitem__(self, item): | |||||
| return { | |||||
| "x": torch.tensor([self._data[item]], dtype=torch.float), | |||||
| "y": torch.tensor([self._data[item]], dtype=torch.float) | |||||
| } | |||||
| # 该类专门用于为 tests.helpers.models.torch_model.py/ TorchNormalModel_Classification_1 创建数据; | # 该类专门用于为 tests.helpers.models.torch_model.py/ TorchNormalModel_Classification_1 创建数据; | ||||
| class TorchNormalDataset_Classification(Dataset): | class TorchNormalDataset_Classification(Dataset): | ||||
| @@ -0,0 +1,57 @@ | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||||
| if _NEED_IMPORT_JITTOR: | |||||
| from jittor import Module, nn | |||||
| else: | |||||
| from fastNLP.core.utils.dummy_class import DummyClass as Module | |||||
| class JittorNormalModel_Classification_1(Module): | |||||
| """ | |||||
| 基础的 jittor 分类模型 | |||||
| """ | |||||
| def __init__(self, num_labels, feature_dimension): | |||||
| super(JittorNormalModel_Classification_1, self).__init__() | |||||
| self.num_labels = num_labels | |||||
| self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64) | |||||
| self.ac1 = nn.ReLU() | |||||
| self.linear2 = nn.Linear(in_features=64, out_features=32) | |||||
| self.ac2 = nn.ReLU() | |||||
| self.output = nn.Linear(in_features=32, out_features=num_labels) | |||||
| self.loss_fn = nn.CrossEntropyLoss() | |||||
| def execute(self, x): | |||||
| x = self.ac1(self.linear1(x)) | |||||
| x = self.ac2(self.linear2(x)) | |||||
| x = self.output(x) | |||||
| return x | |||||
| def train_step(self, x, y): | |||||
| x = self(x) | |||||
| return {"loss": self.loss_fn(x, y)} | |||||
| def evaluate_step(self, x, y): | |||||
| x = self(x) | |||||
| return {"pred": x, "target": y.reshape((-1,))} | |||||
| class JittorNormalModel_Classification_2(Module): | |||||
| """ | |||||
| 基础的 jittor 分类模型,只实现 execute 函数测试用户自己初始化了分布式的场景 | |||||
| """ | |||||
| def __init__(self, num_labels, feature_dimension): | |||||
| super(JittorNormalModel_Classification_2, self).__init__() | |||||
| self.num_labels = num_labels | |||||
| self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64) | |||||
| self.ac1 = nn.ReLU() | |||||
| self.linear2 = nn.Linear(in_features=64, out_features=32) | |||||
| self.ac2 = nn.ReLU() | |||||
| self.output = nn.Linear(in_features=32, out_features=num_labels) | |||||
| self.loss_fn = nn.CrossEntropyLoss() | |||||
| def execute(self, x, y): | |||||
| x = self.ac1(self.linear1(x)) | |||||
| x = self.ac2(self.linear2(x)) | |||||
| x = self.output(x) | |||||
| return {"loss": self.loss_fn(x, y), "pred": x, "target": y.reshape((-1,))} | |||||
| @@ -8,7 +8,7 @@ else: | |||||
| class PaddleNormalModel_Classification_1(Layer): | class PaddleNormalModel_Classification_1(Layer): | ||||
| """ | """ | ||||
| 基础的paddle分类模型 | |||||
| 基础的 paddle 分类模型 | |||||
| """ | """ | ||||
| def __init__(self, num_labels, feature_dimension): | def __init__(self, num_labels, feature_dimension): | ||||
| super(PaddleNormalModel_Classification_1, self).__init__() | super(PaddleNormalModel_Classification_1, self).__init__() | ||||
| @@ -39,7 +39,7 @@ class PaddleNormalModel_Classification_1(Layer): | |||||
| class PaddleNormalModel_Classification_2(Layer): | class PaddleNormalModel_Classification_2(Layer): | ||||
| """ | """ | ||||
| 基础的paddle分类模型,只实现 forward 函数测试用户自己初始化了分布式的场景 | |||||
| 基础的 paddle 分类模型,只实现 forward 函数测试用户自己初始化了分布式的场景 | |||||
| """ | """ | ||||
| def __init__(self, num_labels, feature_dimension): | def __init__(self, num_labels, feature_dimension): | ||||
| super(PaddleNormalModel_Classification_2, self).__init__() | super(PaddleNormalModel_Classification_2, self).__init__() | ||||
| @@ -56,5 +56,4 @@ class PaddleNormalModel_Classification_2(Layer): | |||||
| x = self.ac1(self.linear1(x)) | x = self.ac1(self.linear1(x)) | ||||
| x = self.ac2(self.linear2(x)) | x = self.ac2(self.linear2(x)) | ||||
| x = self.output(x) | x = self.output(x) | ||||
| loss = self.loss_fn(x, y) | |||||
| return {"loss": self.loss_fn(x, y), "pred": x, "target": y.reshape((-1,))} | return {"loss": self.loss_fn(x, y), "pred": x, "target": y.reshape((-1,))} | ||||
| @@ -33,7 +33,11 @@ class TestPaddle2Torch: | |||||
| """ | """ | ||||
| assert isinstance(tensor, torch.Tensor) | assert isinstance(tensor, torch.Tensor) | ||||
| assert tensor.device == torch.device(device) | |||||
| if device == "cpu": | |||||
| assert not tensor.is_cuda | |||||
| else: | |||||
| assert tensor.is_cuda | |||||
| assert tensor.device.index == torch.device(device).index | |||||
| assert tensor.requires_grad == requires_grad | assert tensor.requires_grad == requires_grad | ||||
| def test_gradient(self): | def test_gradient(self): | ||||
| @@ -261,7 +265,8 @@ class TestJittor2Torch: | |||||
| if device == "cpu": | if device == "cpu": | ||||
| assert not tensor.is_cuda | assert not tensor.is_cuda | ||||
| else: | else: | ||||
| assert tensor.device == torch.device(device) | |||||
| assert tensor.is_cuda | |||||
| assert tensor.device.index == torch.device(device).index | |||||
| assert tensor.requires_grad == requires_grad | assert tensor.requires_grad == requires_grad | ||||
| def test_var_transfer(self): | def test_var_transfer(self): | ||||
| @@ -271,7 +276,10 @@ class TestJittor2Torch: | |||||
| jittor_var = jittor.rand((3, 4, 5)) | jittor_var = jittor.rand((3, 4, 5)) | ||||
| res = jittor2torch(jittor_var) | res = jittor2torch(jittor_var) | ||||
| self.check_torch_tensor(res, "cpu", True) | |||||
| if jittor.flags.use_cuda: | |||||
| self.check_torch_tensor(res, "cuda:0", True) | |||||
| else: | |||||
| self.check_torch_tensor(res, "cpu", True) | |||||
| res = jittor2torch(jittor_var, device="cuda:2", no_gradient=None) | res = jittor2torch(jittor_var, device="cuda:2", no_gradient=None) | ||||
| self.check_torch_tensor(res, "cuda:2", True) | self.check_torch_tensor(res, "cuda:2", True) | ||||
| @@ -291,7 +299,10 @@ class TestJittor2Torch: | |||||
| res = jittor2torch(jittor_list) | res = jittor2torch(jittor_list) | ||||
| assert isinstance(res, list) | assert isinstance(res, list) | ||||
| for t in res: | for t in res: | ||||
| self.check_torch_tensor(t, "cpu", True) | |||||
| if jittor.flags.use_cuda: | |||||
| self.check_torch_tensor(t, "cuda:0", True) | |||||
| else: | |||||
| self.check_torch_tensor(t, "cpu", True) | |||||
| res = jittor2torch(jittor_list, device="cuda:1", no_gradient=False) | res = jittor2torch(jittor_list, device="cuda:1", no_gradient=False) | ||||
| assert isinstance(res, list) | assert isinstance(res, list) | ||||
| @@ -327,17 +338,29 @@ class TestJittor2Torch: | |||||
| } | } | ||||
| res = jittor2torch(jittor_dict) | res = jittor2torch(jittor_dict) | ||||
| assert isinstance(res, dict) | assert isinstance(res, dict) | ||||
| self.check_torch_tensor(res["tensor"], "cpu", True) | |||||
| if jittor.flags.use_cuda: | |||||
| self.check_torch_tensor(res["tensor"], "cuda:0", True) | |||||
| else: | |||||
| self.check_torch_tensor(res["tensor"], "cpu", True) | |||||
| assert isinstance(res["list"], list) | assert isinstance(res["list"], list) | ||||
| for t in res["list"]: | for t in res["list"]: | ||||
| self.check_torch_tensor(t, "cpu", True) | |||||
| if jittor.flags.use_cuda: | |||||
| self.check_torch_tensor(t, "cuda:0", True) | |||||
| else: | |||||
| self.check_torch_tensor(t, "cpu", True) | |||||
| assert isinstance(res["int"], int) | assert isinstance(res["int"], int) | ||||
| assert isinstance(res["string"], str) | assert isinstance(res["string"], str) | ||||
| assert isinstance(res["dict"], dict) | assert isinstance(res["dict"], dict) | ||||
| assert isinstance(res["dict"]["list"], list) | assert isinstance(res["dict"]["list"], list) | ||||
| for t in res["dict"]["list"]: | for t in res["dict"]["list"]: | ||||
| self.check_torch_tensor(t, "cpu", True) | |||||
| self.check_torch_tensor(res["dict"]["tensor"], "cpu", True) | |||||
| if jittor.flags.use_cuda: | |||||
| self.check_torch_tensor(t, "cuda:0", True) | |||||
| else: | |||||
| self.check_torch_tensor(t, "cpu", True) | |||||
| if jittor.flags.use_cuda: | |||||
| self.check_torch_tensor(res["dict"]["tensor"], "cuda:0", True) | |||||
| else: | |||||
| self.check_torch_tensor(res["dict"]["tensor"], "cpu", True) | |||||
| ############################################################################ | ############################################################################ | ||||