| @@ -24,7 +24,6 @@ from fastNLP.core.dataset import DataSet as FDataSet | |||
| class _JittorDataset(Dataset): | |||
| """ | |||
| 对用户传的dataset进行封装,以便JittorDataLoader能够支持使用自定义的dataset | |||
| """ | |||
| def __init__(self, dataset) -> None: | |||
| @@ -83,7 +82,7 @@ class JittorDataLoader: | |||
| # TODO 验证支持replacesampler (以后完成) 增加Sampler | |||
| # 将内部dataset批次设置为1 | |||
| if isinstance(dataset, Dataset): | |||
| dataset.set_attrs(batch_size=1) | |||
| dataset.set_attrs(batch_size=1, shuffle=False, endless=False) | |||
| # FastNLP Datset, collate_fn not None | |||
| if isinstance(dataset, FDataSet) and collate_fn is None: | |||
| @@ -115,6 +114,12 @@ class JittorDataLoader: | |||
| self.cur_batch_indices = None | |||
| def __getattr__(self, attr): | |||
| if attr in ["batch_size", "shuffle", "drop_last", "num_workers", "buffer_size", "stop_grad", | |||
| "keep_numpy_array", "endless", "sampler"]: | |||
| return getattr(self.dataset, attr) | |||
| raise AttributeError(f"{self} has not attribute '{attr}'") | |||
| def __iter__(self): | |||
| # TODO 第一次迭代后不能设置collate_fn,设置是无效的 | |||
| if self.cur_batch_indices is None: | |||
| @@ -10,7 +10,7 @@ if _NEED_IMPORT_JITTOR: | |||
| __all__ = [] | |||
| def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], model: jittor.Module, **kwargs) -> JittorDriver: | |||
| def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], model: "jittor.Module", **kwargs) -> JittorDriver: | |||
| r""" | |||
| 用来根据参数 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例然后返回回去。 | |||
| @@ -30,7 +30,7 @@ def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], mo | |||
| raise ValueError("Parameter `driver` can only be one of these values: ['jittor'].") | |||
| # TODO 实现更详细的判断 | |||
| if device in ["cpu", "gpu", "cuda", "cuda:0", 0, None]: | |||
| if device in ["cpu", "gpu", "cuda", None]: | |||
| return JittorSingleDriver(model, device, **kwargs) | |||
| elif type(device) is int: | |||
| return JittorMPIDriver(model, device, **kwargs) | |||
| @@ -1,23 +1,31 @@ | |||
| import os | |||
| import random | |||
| from pathlib import Path | |||
| from typing import Union, Optional | |||
| from functools import partial | |||
| import numpy as np | |||
| from typing import Union, Optional, Dict | |||
| from contextlib import nullcontext | |||
| from dataclasses import dataclass | |||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
| from fastNLP.core.drivers.driver import Driver | |||
| from fastNLP.core.dataloaders import JittorDataLoader | |||
| from fastNLP.core.samplers import ReproducibleSampler, RandomSampler | |||
| from fastNLP.core.log import logger | |||
| from fastNLP.core.utils import apply_to_collection | |||
| from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_SEED_WORKERS | |||
| from fastNLP.envs import ( | |||
| FASTNLP_MODEL_FILENAME, | |||
| FASTNLP_CHECKPOINT_FILENAME, | |||
| ) | |||
| if _NEED_IMPORT_JITTOR: | |||
| import jittor as jt | |||
| from jittor import Module | |||
| from jittor.optim import Optimizer | |||
| from jittor.dataset import Dataset | |||
| from jittor.dataset import ( | |||
| BatchSampler as JittorBatchSampler, | |||
| Sampler as JittorSampler, | |||
| RandomSampler as JittorRandomSampler, | |||
| SequentialSampler as JittorSequentialSampler | |||
| ) | |||
| _reduces = { | |||
| 'max': jt.max, | |||
| @@ -56,6 +64,7 @@ class JittorDriver(Driver): | |||
| else: | |||
| jt.flags.auto_mixed_precision_level = 0 | |||
| self.fp16 = fp16 | |||
| self._auto_cast = nullcontext | |||
| # 用来设置是否关闭 auto_param_call 中的参数匹配问题; | |||
| self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | |||
| @@ -68,7 +77,7 @@ class JittorDriver(Driver): | |||
| def _check_optimizer_legality(optimizers): | |||
| for each_optimizer in optimizers: | |||
| if not isinstance(each_optimizer, Optimizer): | |||
| raise ValueError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, " | |||
| raise TypeError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, " | |||
| f"not {type(each_optimizer)}.") | |||
| def step(self): | |||
| @@ -117,30 +126,118 @@ class JittorDriver(Driver): | |||
| model = self.unwrap_model() | |||
| model.load(filepath) | |||
| def save_checkpoint(self): | |||
| ... | |||
| def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||
| dataloader_args = self.get_dataloader_args(dataloader) | |||
| if dataloader_args.sampler: | |||
| sampler = dataloader_args.sampler | |||
| else: | |||
| raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | |||
| num_consumed_batches = states.pop('num_consumed_batches') | |||
| if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | |||
| sampler_states = sampler.state_dict() | |||
| # 需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||
| # 会造成多余实际消耗的问题。因为 | |||
| num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||
| if num_consumed_samples_array is not None: | |||
| if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||
| if dataloader_args.batch_size is not None: | |||
| num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||
| else: # 有可能 batch_size 为 None,就只有损失精度了 | |||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
| "it may cause missing some samples when reload.") | |||
| num_consumed_batches = sampler_states['num_consumed_samples'] | |||
| sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||
| assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||
| else: | |||
| if dataloader_args.batch_size is not None: | |||
| sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||
| * num_consumed_batches | |||
| else: | |||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
| "it may cause missing some samples when reload.") | |||
| states['sampler_states'] = sampler_states | |||
| else: | |||
| raise RuntimeError('The sampler has no `state_dict()` method, fastNLP cannot save the training ' | |||
| 'state.') | |||
| # 2. 保存模型的状态; | |||
| if should_save_model: | |||
| if not os.path.exists(folder): | |||
| os.mkdir(folder) | |||
| model_path = folder.joinpath(FASTNLP_MODEL_FILENAME) | |||
| self.save_model(model_path, only_state_dict=only_state_dict) | |||
| # 3. 保存 optimizers 的状态; | |||
| states["optimizers_state_dict"] = self.get_optimizer_state() | |||
| # 4. 保存fp16的状态 | |||
| logger.debug("Save optimizer state dict") | |||
| jt.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) | |||
| def get_optimizer_state(self): | |||
| # optimizers_state_dict = {} | |||
| # for i in range(len(self.optimizers)): | |||
| # optimizer: torch.optim.Optimizer = self.optimizers[i] | |||
| # optimizer_state = optimizer.state_dict() | |||
| # optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu")) | |||
| # optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | |||
| # return optimizers_state_dict | |||
| ... | |||
| optimizers_state_dict = {} | |||
| for i in range(len(self.optimizers)): | |||
| optimizer: Optimizer = self.optimizers[i] | |||
| optimizers_state_dict[f"optimizer{i}"] = optimizer.state_dict() # 注意这里没有使用 deepcopy,测试是不需要的; | |||
| return optimizers_state_dict | |||
| def load_optimizer_state(self, states): | |||
| # assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \ | |||
| # f"checkpoint it is:{len(states)}" | |||
| # for i in range(len(self.optimizers)): | |||
| # optimizer: torch.optim.Optimizer = self.optimizers[i] | |||
| # optimizer.load_state_dict(states[f"optimizer{i}"]) | |||
| # logger.debug("Load optimizer state dict.") | |||
| ... | |||
| assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \ | |||
| f"checkpoint it is:{len(states)}" | |||
| for i in range(len(self.optimizers)): | |||
| optimizer: Optimizer = self.optimizers[i] | |||
| optimizer.load_state_dict(states[f"optimizer{i}"]) | |||
| logger.debug("Load optimizer state dict.") | |||
| def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||
| states = jt.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) | |||
| # 1. 加载 optimizers 的状态; | |||
| optimizers_state_dict = states.pop("optimizers_state_dict") | |||
| self.load_optimizer_state(optimizers_state_dict) | |||
| # 2. 加载模型状态; | |||
| if should_load_model: | |||
| self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict) | |||
| # 3. 加载fp16的状态 | |||
| # 4. 恢复 sampler 的状态; | |||
| dataloader_args = self.get_dataloader_args(dataloader) | |||
| if dataloader_args.sampler is None: | |||
| sampler = RandomSampler(dataloader_args.sampler.dataset, shuffle=dataloader_args.shuffle) | |||
| elif isinstance(dataloader_args.sampler, ReproducibleSampler): | |||
| sampler = dataloader_args.sampler | |||
| elif isinstance(dataloader_args.sampler, JittorRandomSampler): | |||
| sampler = RandomSampler(dataloader_args.sampler.dataset) | |||
| logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.") | |||
| elif isinstance(dataloader_args.sampler, JittorSequentialSampler): | |||
| sampler = RandomSampler(dataloader_args.sampler.dataset, shuffle=False) | |||
| logger.debug("Replace jittor Sampler into fastNLP RandomSampler without shuffle.") | |||
| elif self.is_distributed(): | |||
| raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our" | |||
| "`ReproducibleSampler`.") | |||
| else: | |||
| raise RuntimeError(f"Jittor sampler {type(dataloader_args.sampler)} is not supported now.") | |||
| sampler.load_state_dict(states.pop('sampler_states')) | |||
| states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) | |||
| # 4. 修改 trainer_state.batch_idx_in_epoch | |||
| # sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; | |||
| if dataloader_args.drop_last: | |||
| batch_idx_in_epoch = len( | |||
| sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size | |||
| else: | |||
| batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ | |||
| (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size | |||
| def load_checkpoint(self): | |||
| ... | |||
| states["batch_idx_in_epoch"] = batch_idx_in_epoch | |||
| return states | |||
| def get_evaluate_context(self): | |||
| return jt.no_grad | |||
| @@ -198,26 +295,8 @@ class JittorDriver(Driver): | |||
| """ | |||
| return batch | |||
| @staticmethod | |||
| def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover | |||
| global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) | |||
| process_seed = jt.get_seed() | |||
| # back out the base seed so we can use all the bits | |||
| base_seed = process_seed - worker_id | |||
| ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) | |||
| # use 128 bits (4 x 32-bit words) | |||
| np.random.seed(ss.generate_state(4)) | |||
| # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module | |||
| jittor_ss, stdlib_ss = ss.spawn(2) | |||
| jt.set_global_seed(jittor_ss.generate_state(1, dtype=np.uint64)[0]) | |||
| # use 128 bits expressed as an integer | |||
| stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum() | |||
| random.seed(stdlib_seed) | |||
| def set_deterministic_dataloader(self, dataloader: Union["JittorDataLoader", "Dataset"]): | |||
| if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None: | |||
| dataloader.worker_init_fn = partial(self.worker_init_function, | |||
| rank=int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))) | |||
| ... | |||
| def set_sampler_epoch(self, dataloader: Union["JittorDataLoader", "Dataset"], cur_epoch_idx: int): | |||
| # 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的; | |||
| @@ -226,4 +305,45 @@ class JittorDriver(Driver): | |||
| @staticmethod | |||
| def get_dataloader_args(dataloader: Union["JittorDataLoader", "Dataset"]): | |||
| pass | |||
| @dataclass | |||
| class Res: | |||
| dataset: Optional[Dataset] = None | |||
| batch_sampler: Optional[JittorBatchSampler] = None | |||
| sampler: Optional[JittorSampler] = None | |||
| batch_size: Optional[int] = None | |||
| shuffle: Optional[bool] = None | |||
| drop_last: Optional[bool] = None | |||
| res = Res() | |||
| from fastNLP.core.dataloaders.jittor_dataloader.fdl import _JittorDataset | |||
| if isinstance(dataloader, JittorDataLoader): | |||
| # JittorDataLoader 实际上是迭代 dataset 成员的 | |||
| dataloader = dataloader.dataset | |||
| if isinstance(dataloader, _JittorDataset): | |||
| # 获取最原始的 dataset | |||
| res.dataset = dataloader.dataset | |||
| else: | |||
| res.dataset = dataloader | |||
| # jittor 现在不支持 batch_sampler,所以除了 shuffle 都可以直接获取 | |||
| res.batch_size = dataloader.batch_size | |||
| res.drop_last = dataloader.drop_last | |||
| if dataloader.sampler is None: | |||
| # sampler 是 None,那么就从 Dataset 的属性中获取 | |||
| res.shuffle = dataloader.shuffle | |||
| elif isinstance(list(dataloader.sampler.__iter__())[0], (list,tuple)): | |||
| # jittor 目前不支持 batch_sampler | |||
| raise NotImplementedError("Jittor does not support using batch_sampler in `Dataset` now, " | |||
| "please check if you have set `Dataset.sampler` as `BatchSampler`") | |||
| else: | |||
| # sampler 不为 None | |||
| res.sampler = dataloader.sampler | |||
| if hasattr(dataloader.sampler, "shuffle"): | |||
| # 这种情况一般出现在 fastNLP 的 ReproduceSampler 中 | |||
| res.shuffle = dataloader.sampler.shuffle | |||
| elif isinstance(dataloader.sampler, JittorRandomSampler): | |||
| res.shuffle = True | |||
| else: | |||
| res.shuffle = False | |||
| return res | |||
| @@ -38,6 +38,7 @@ class JittorMPIDriver(JittorDriver): | |||
| ): | |||
| super(JittorMPIDriver, self).__init__(model, fp16=fp16, **kwargs) | |||
| raise NotImplementedError("MPI for Jittor is not supported right now.") | |||
| self.is_pull_by_jittor_run = is_pull_by_jittor_run | |||
| self.parallel_device = parallel_device | |||
| @@ -100,22 +101,6 @@ class JittorMPIDriver(JittorDriver): | |||
| return self._data_device | |||
| return self.parallel_device | |||
| def step(self): | |||
| # for optimizer in self.optimizers: | |||
| # self.grad_scaler.step(optimizer) | |||
| # self.grad_scaler.update() | |||
| for optimizer in self.optimizers: | |||
| optimizer.step() | |||
| def backward(self, loss): | |||
| # self.grad_scaler.scale(loss).backward() | |||
| for optimizer in self.optimizers: | |||
| optimizer.backward(loss) | |||
| def zero_grad(self): | |||
| for optimizer in self.optimizers: | |||
| optimizer.zero_grad() | |||
| def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | |||
| if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||
| return auto_param_call(fn, batch, signature_fn=signature_fn) | |||
| @@ -1,14 +1,21 @@ | |||
| from typing import Dict, Union, Tuple, Callable, Optional | |||
| from .jittor_driver import JittorDriver | |||
| from .utils import replace_batch_sampler, replace_sampler | |||
| from fastNLP.core.utils import auto_param_call | |||
| from fastNLP.core.utils.utils import _get_fun_msg | |||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \ | |||
| ReproduceBatchSampler | |||
| from fastNLP.core.samplers import RandomSampler | |||
| from fastNLP.core.log import logger | |||
| if _NEED_IMPORT_JITTOR: | |||
| import jittor as jt | |||
| from jittor.dataset import ( | |||
| RandomSampler as JittorRandomSampler, | |||
| SequentialSampler as JittorSequentialSampler, | |||
| ) | |||
| __all__ = [ | |||
| "JittorSingleDriver", | |||
| @@ -89,31 +96,46 @@ class JittorSingleDriver(JittorDriver): | |||
| """ | |||
| return False | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], | |||
| reproducible: bool = False, sampler_or_batch_sampler=None): | |||
| # reproducible 的相关功能暂时没有实现 | |||
| def set_dist_repro_dataloader(self, dataloader, | |||
| dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None, | |||
| reproducible: bool = False): | |||
| # 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load_checkpoint 函数调用; | |||
| if isinstance(dist, ReproducibleBatchSampler): | |||
| raise NotImplementedError | |||
| dataloader.batch_sampler = dist_sample | |||
| if isinstance(dist, ReproducibleSampler): | |||
| raise NotImplementedError | |||
| dataloader.batch_sampler.sampler = dist | |||
| return replace_batch_sampler(dataloader, dist) | |||
| elif isinstance(dist, ReproducibleSampler): | |||
| return replace_sampler(dataloader, dist) | |||
| # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; | |||
| args = self.get_dataloader_args(dataloader) | |||
| if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||
| batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||
| return replace_batch_sampler(dataloader, batch_sampler) | |||
| elif isinstance(args.sampler, ReproducibleSampler): | |||
| sampler = re_instantiate_sampler(args.sampler) | |||
| return replace_sampler(dataloader, sampler) | |||
| if reproducible: | |||
| raise NotImplementedError | |||
| if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||
| return dataloader | |||
| elif isinstance(dataloader.batch_sampler, RandomBatchSampler): | |||
| return dataloader | |||
| else: | |||
| # TODO | |||
| batch_sampler = RandomBatchSampler( | |||
| batch_sampler=dataloader.batch_sampler, | |||
| batch_size=dataloader.batch_sampler.batch_size, | |||
| drop_last=dataloader.drop_last | |||
| ) | |||
| dataloader.batch_sampler = batch_sampler | |||
| return dataloader | |||
| if args.sampler is None: | |||
| sampler = RandomSampler(args.dataset, args.shuffle) | |||
| return replace_sampler(dataloader, sampler) | |||
| elif isinstance(args.sampler, JittorRandomSampler): | |||
| if getattr(args.sampler, '_num_samples', None) is None \ | |||
| and getattr(args.sampler, 'rep', False) is False: | |||
| # 如果本来就是随机的,并且没有定制,直接替换掉吧。 | |||
| sampler = RandomSampler(args.sampler.dataset, shuffle=True) | |||
| logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.") | |||
| return replace_sampler(dataloader, sampler) | |||
| elif isinstance(args.sampler, JittorSequentialSampler): | |||
| # 需要替换为不要 shuffle 的。 | |||
| sampler = RandomSampler(args.sampler.dataset, shuffle=False) | |||
| logger.debug("Replace jittor SequentialSampler into fastNLP RandomSampler.") | |||
| return replace_sampler(dataloader, sampler) | |||
| batch_sampler = ReproduceBatchSampler( | |||
| batch_sampler=args.batch_sampler, | |||
| batch_size=args.batch_size, | |||
| drop_last=args.drop_last | |||
| ) | |||
| return replace_batch_sampler(dataloader, batch_sampler) | |||
| else: | |||
| return dataloader | |||
| @@ -1,6 +1,29 @@ | |||
| import inspect | |||
| from copy import deepcopy | |||
| from typing import Union | |||
| from fastNLP.core.dataloaders import JittorDataLoader | |||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
| if _NEED_IMPORT_JITTOR: | |||
| import jittor | |||
| from jittor.dataset import Dataset | |||
| __all__ = [] | |||
| def replace_batch_sampler(dataloader, batch_sampler): | |||
| raise NotImplementedError("Jittor does not support using batch_sampler in `Dataset` now, " | |||
| "please check if you have set `Dataset.sampler` as `BatchSampler`" | |||
| "or report this bug to us.") | |||
| def replace_sampler(dataloader: Union["Dataset", "JittorDataLoader"], sampler): | |||
| if isinstance(dataloader, JittorDataLoader): | |||
| init_params = dict(inspect.signature(dataloader.__init__).parameters) | |||
| reconstruct_args = {name: getattr(dataloader, name, p.default) for name, p in init_params.items()} | |||
| reconstruct_args["dataset"] = replace_sampler(reconstruct_args["dataset"].dataset, reconstruct_args["dataset"].sampler) | |||
| new_dataloader = type(dataloader)(**reconstruct_args) | |||
| new_dataloader.dataset.set_attrs(sampler=sampler) | |||
| else: | |||
| new_dataloader = deepcopy(dataloader) | |||
| new_dataloader.set_attrs(sampler=sampler) | |||
| return new_dataloader | |||
| @@ -31,7 +31,6 @@ if _NEED_IMPORT_PADDLE: | |||
| import paddle | |||
| from paddle.io import ( | |||
| DataLoader, | |||
| IterableDataset, | |||
| Dataset, | |||
| Sampler, | |||
| BatchSampler, | |||
| @@ -97,6 +96,9 @@ class PaddleDriver(Driver): | |||
| def check_dataloader_legality(self, dataloader): | |||
| if not isinstance(dataloader, DataLoader): | |||
| raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") | |||
| if dataloader.batch_size is None and dataloader.batch_sampler is None: | |||
| raise ValueError("Please ensure at least one of your dataloader's batch_size and batch_sampler" | |||
| "is not None") | |||
| @staticmethod | |||
| def _check_optimizer_legality(optimizers): | |||
| @@ -107,7 +109,7 @@ class PaddleDriver(Driver): | |||
| """ | |||
| for each_optimizer in optimizers: | |||
| if not isinstance(each_optimizer, Optimizer): | |||
| raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, " | |||
| raise TypeError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, " | |||
| f"not {type(each_optimizer)}.") | |||
| @staticmethod | |||
| @@ -263,9 +265,7 @@ class PaddleDriver(Driver): | |||
| optimizers_state_dict = {} | |||
| for i in range(len(self.optimizers)): | |||
| optimizer: Optimizer = self.optimizers[i] | |||
| optimizer_state = optimizer.state_dict() | |||
| optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu") | |||
| optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | |||
| optimizers_state_dict[f"optimizer{i}"] = optimizer_state_to_device(optimizer.state_dict(), "cpu") | |||
| return optimizers_state_dict | |||
| @@ -399,6 +399,8 @@ class PaddleDriver(Driver): | |||
| def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx): | |||
| if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): | |||
| dataloader.batch_sampler.set_epoch(cur_epoch_idx) | |||
| elif callable(getattr(dataloader.batch_sampler.sampler, "set_epoch", None)): | |||
| dataloader.batch_sampler.sampler.set_epoch(cur_epoch_idx) | |||
| @staticmethod | |||
| def get_dataloader_args(dataloader: "DataLoader"): | |||
| @@ -99,7 +99,7 @@ class TorchDriver(Driver): | |||
| def _check_optimizer_legality(optimizers): | |||
| for each_optimizer in optimizers: | |||
| if not isinstance(each_optimizer, Optimizer): | |||
| raise ValueError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, " | |||
| raise TypeError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, " | |||
| f"not {type(each_optimizer)}.") | |||
| @staticmethod | |||
| @@ -210,7 +210,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
| self.num_consumed_samples = 0 | |||
| self.during_iter = True | |||
| indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||
| indices = list(range(self.num_samples)) | |||
| if self.shuffle: | |||
| if self.num_consumed_samples > 0: # 需要先按照原来的排序,删掉多余的 | |||
| @@ -237,7 +237,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
| if len(indices)%self.batch_size!=0: | |||
| batches.append(indices[_num_batches*self.batch_size:]) | |||
| need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas | |||
| need_pad_num = (self.num_samples-self.num_consumed_samples) % self.num_replicas | |||
| if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: | |||
| if len(batches) > 0: | |||
| if len(batches[-1])<self.batch_size: | |||
| @@ -290,9 +290,9 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
| @property | |||
| def batch_idx_in_epoch(self): | |||
| if self.drop_last: | |||
| return getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||
| return self.num_samples // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||
| else: | |||
| return (getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||
| return (self.num_samples // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||
| (self.num_left_samples + self.batch_size - 1) // self.batch_size | |||
| @property | |||
| @@ -313,8 +313,12 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
| :return: | |||
| """ | |||
| num_consumed_samples = self.num_consumed_samples | |||
| return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \ | |||
| self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas)) | |||
| return math.ceil((self.num_samples - num_consumed_samples) / self.num_replicas) if \ | |||
| self.pad else math.floor(((self.num_samples - num_consumed_samples) / self.num_replicas)) | |||
| @property | |||
| def num_samples(self): | |||
| return getattr(self.dataset, 'total_len', len(self.dataset)) | |||
| def __len__(self)->int: | |||
| """ | |||
| @@ -332,7 +336,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
| raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | |||
| " consumed. ") | |||
| states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | |||
| 'sampler_type': self.__class__.__name__, 'length': getattr(self.dataset, 'total_len', len(self.dataset)), 'shuffle': self.shuffle, | |||
| 'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle, | |||
| 'batch_size': self.batch_size, | |||
| 'num_replicas': self.num_replicas} | |||
| @@ -347,7 +351,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
| f"we cannot use {self.__class__.__name__} to load it." | |||
| length = states['length'] | |||
| assert length == getattr(self.dataset, 'total_len', len(self.dataset)), "The number of samples is different between the checkpoint record " \ | |||
| assert length == self.num_samples, "The number of samples is different between the checkpoint record " \ | |||
| "and current dataset." | |||
| self.seed = states['seed'] | |||
| self.epoch = states['epoch'] | |||
| @@ -464,8 +468,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
| :return: | |||
| """ | |||
| num_consumed_samples = self.num_consumed_samples | |||
| return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \ | |||
| self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas)) | |||
| return math.ceil((self.num_samples - num_consumed_samples) / self.num_replicas) if \ | |||
| self.pad else math.floor(((self.num_samples - num_consumed_samples) / self.num_replicas)) | |||
| @property | |||
| def num_samples(self): | |||
| return getattr(self.dataset, 'total_len', len(self.dataset)) | |||
| def __len__(self)->int: | |||
| """ | |||
| @@ -515,7 +523,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
| if len(sorted_indices)%self.batch_size!=0: | |||
| batches.append(sorted_indices[_num_batches*self.batch_size:]) | |||
| need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas | |||
| need_pad_num = (self.num_samples-self.num_consumed_samples) % self.num_replicas | |||
| if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: | |||
| if len(batches) > 0: | |||
| if len(batches[-1])<self.batch_size: | |||
| @@ -593,7 +601,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
| raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | |||
| " consumed. ") | |||
| states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | |||
| 'sampler_type': self.__class__.__name__, 'length': getattr(self.dataset, 'total_len', len(self.dataset)), 'shuffle': self.shuffle, | |||
| 'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle, | |||
| 'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket, | |||
| 'num_replicas': self.num_replicas | |||
| } | |||
| @@ -609,7 +617,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
| f"we cannot use {self.__class__.__name__} to load it." | |||
| length = states['length'] | |||
| assert length == getattr(self.dataset, 'total_len', len(self.dataset)), "The number of samples is different between the checkpoint record " \ | |||
| assert length == self.num_samples, "The number of samples is different between the checkpoint record " \ | |||
| "and current dataset." | |||
| self.seed = states['seed'] | |||
| self.epoch = states['epoch'] | |||
| @@ -630,7 +638,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
| @property | |||
| def batch_idx_in_epoch(self): | |||
| if self.drop_last: | |||
| return getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||
| return self.num_samples // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||
| else: | |||
| return (getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||
| return (self.num_samples // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||
| (self.num_left_samples + self.batch_size - 1) // self.batch_size | |||
| @@ -48,6 +48,10 @@ class ReproducibleSampler: | |||
| def num_left_samples(self): | |||
| raise NotImplementedError("Each specific sampler should implement its own `num_left_samples` method.") | |||
| @property | |||
| def num_samples(self): | |||
| raise NotImplementedError("Each specific sampler should implement its own `num_samples` method.") | |||
| def set_epoch(self, epoch): | |||
| pass | |||
| @@ -131,19 +135,19 @@ class RandomSampler(ReproducibleSampler): | |||
| :return: | |||
| """ | |||
| if self.shuffle: | |||
| indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||
| indices = list(range(self.num_samples)) | |||
| seed = self.seed + self.epoch | |||
| rng = np.random.default_rng(abs(seed)) | |||
| rng.shuffle(indices) | |||
| if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | |||
| self.epoch -= 1 | |||
| else: | |||
| indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||
| indices = list(range(self.num_samples)) | |||
| return indices | |||
| def state_dict(self) -> Dict: | |||
| states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | |||
| 'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle} | |||
| 'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle} | |||
| return states | |||
| def load_state_dict(self, states: Dict): | |||
| @@ -155,8 +159,8 @@ class RandomSampler(ReproducibleSampler): | |||
| f"we cannot use {self.__class__.__name__} to load it." | |||
| length = states['length'] | |||
| assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \ | |||
| f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})." | |||
| assert length == self.num_samples, "The number of samples is different between the checkpoint " \ | |||
| f"record({length}) and current dataset({self.num_samples})." | |||
| self.seed = states['seed'] | |||
| self.epoch = states['epoch'] | |||
| self.num_consumed_samples = states['num_consumed_samples'] | |||
| @@ -208,9 +212,17 @@ class RandomSampler(ReproducibleSampler): | |||
| :return: | |||
| """ | |||
| num_consumed_samples = self.num_consumed_samples | |||
| return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \ | |||
| self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas)) | |||
| return math.ceil((self.num_samples - num_consumed_samples) / self.num_replicas) if \ | |||
| self.pad else math.floor(((self.num_samples - num_consumed_samples) / self.num_replicas)) | |||
| @property | |||
| def num_samples(self): | |||
| """ | |||
| 返回样本的总数 | |||
| :return: | |||
| """ | |||
| return getattr(self.dataset, 'total_len', len(self.dataset)) | |||
| class SequentialSampler(RandomSampler): | |||
| """ | |||
| @@ -258,12 +270,10 @@ class SequentialSampler(RandomSampler): | |||
| :return: | |||
| """ | |||
| return list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||
| return list(range(self.num_samples)) | |||
| def state_dict(self) -> Dict: | |||
| states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, | |||
| 'length': getattr(self.dataset, 'total_len', len(self.dataset)) | |||
| } | |||
| states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, 'length': self.num_samples} | |||
| return states | |||
| def load_state_dict(self, states: Dict): | |||
| @@ -275,8 +285,8 @@ class SequentialSampler(RandomSampler): | |||
| f"we cannot use {self.__class__.__name__} to load it." | |||
| length = states['length'] | |||
| assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \ | |||
| f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})." | |||
| assert length == self.num_samples, "The number of samples is different between the checkpoint " \ | |||
| f"record({length}) and current dataset({self.num_samples})." | |||
| self.num_consumed_samples = states['num_consumed_samples'] | |||
| if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | |||
| self.num_consumed_samples = 0 | |||
| @@ -314,9 +324,9 @@ class SortedSampler(SequentialSampler): | |||
| except BaseException as e: | |||
| logger.error(f"Cannot use {self.__class__.__name__} as length, since it is not sortable.") | |||
| assert len(length) == getattr(self.dataset, 'total_len', len(self.dataset)), f"The length of `dataset`({len(dataset)}) and " \ | |||
| f"`length`({getattr(self.dataset, 'total_len', len(self.dataset))}) should be equal." | |||
| assert len(self.sorted_indices) == getattr(self.dataset, 'total_len', len(self.dataset)), "The indices and dataset should have equal length." | |||
| assert len(length) == self.num_samples, f"The length of `dataset`({len(dataset)}) and " \ | |||
| f"`length`({self.num_samples}) should be equal." | |||
| assert len(self.sorted_indices) == self.num_samples, "The indices and dataset should have equal length." | |||
| self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | |||
| self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的 | |||
| @@ -42,8 +42,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||
| 返回 sampler 一次完整的迭代过程会产生多少个index。多卡的情况下,只考虑当前rank; | |||
| :return: | |||
| """ | |||
| num_common = getattr(self.dataset, 'total_len', len(self.dataset))//self.num_replicas | |||
| num_samples = num_common + int(self.rank < (getattr(self.dataset, 'total_len', len(self.dataset))-num_common*self.num_replicas)) | |||
| num_common = self.num_samples//self.num_replicas | |||
| num_samples = num_common + int(self.rank < (self.num_samples-num_common*self.num_replicas)) | |||
| return num_samples | |||
| def __iter__(self): | |||
| @@ -63,14 +63,14 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||
| :return: | |||
| """ | |||
| if self.shuffle: | |||
| indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||
| indices = list(range(self.num_samples)) | |||
| seed = self.seed + self.epoch | |||
| rng = np.random.default_rng(abs(seed)) | |||
| rng.shuffle(indices) | |||
| if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | |||
| self.epoch -= 1 | |||
| else: | |||
| indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||
| indices = list(range(self.num_samples)) | |||
| return indices | |||
| def set_epoch(self, epoch: int) -> None: | |||
| @@ -84,8 +84,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||
| :param rank: | |||
| :return: | |||
| """ | |||
| assert num_replicas<=getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of replicas({num_replicas}) should be lesser than the " \ | |||
| f"number of samples({getattr(self.dataset, 'total_len', len(self.dataset))})." | |||
| assert num_replicas<=self.num_samples, f"The number of replicas({num_replicas}) should be lesser than the " \ | |||
| f"number of samples({self.num_samples})." | |||
| assert num_replicas>0 and isinstance(num_replicas, int) | |||
| assert isinstance(rank, int) and 0<=rank<num_replicas | |||
| # 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态; | |||
| @@ -94,6 +94,15 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||
| return self | |||
| @property | |||
| def num_samples(self): | |||
| """ | |||
| 返回样本的总数 | |||
| :return: | |||
| """ | |||
| return getattr(self.dataset, 'total_len', len(self.dataset)) | |||
| class UnrepeatedSortedSampler(UnrepeatedRandomSampler): | |||
| """ | |||
| @@ -147,5 +156,5 @@ class UnrepeatedSequentialSampler(UnrepeatedRandomSampler): | |||
| yield index | |||
| def generate_indices(self) -> List[int]: | |||
| return list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||
| return list(range(self.num_samples)) | |||
| @@ -27,7 +27,7 @@ from paddle.optimizer import Adam | |||
| from paddle.io import DataLoader | |||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||
| from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset | |||
| from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset | |||
| from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback | |||
| @dataclass | |||
| @@ -52,12 +52,12 @@ def test_trainer_fleet( | |||
| optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001) | |||
| train_dataloader = DataLoader( | |||
| dataset=PaddleRandomMaxDataset(20, MNISTTrainFleetConfig.feature_dimension), | |||
| dataset=PaddleArgMaxDataset(20, MNISTTrainFleetConfig.feature_dimension), | |||
| batch_size=MNISTTrainFleetConfig.batch_size, | |||
| shuffle=True | |||
| ) | |||
| val_dataloader = DataLoader( | |||
| dataset=PaddleRandomMaxDataset(12, MNISTTrainFleetConfig.feature_dimension), | |||
| dataset=PaddleArgMaxDataset(12, MNISTTrainFleetConfig.feature_dimension), | |||
| batch_size=MNISTTrainFleetConfig.batch_size, | |||
| shuffle=True | |||
| ) | |||
| @@ -24,7 +24,7 @@ from paddle.io import DataLoader | |||
| import paddle.distributed.fleet as fleet | |||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_2 | |||
| from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset | |||
| from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset | |||
| from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback | |||
| @dataclass | |||
| @@ -54,12 +54,12 @@ def test_trainer_fleet( | |||
| optimizers = fleet.distributed_optimizer(optimizers) | |||
| train_dataloader = DataLoader( | |||
| dataset=PaddleRandomMaxDataset(20, MNISTTrainFleetConfig.feature_dimension), | |||
| dataset=PaddleArgMaxDataset(20, MNISTTrainFleetConfig.feature_dimension), | |||
| batch_size=MNISTTrainFleetConfig.batch_size, | |||
| shuffle=True | |||
| ) | |||
| val_dataloader = DataLoader( | |||
| dataset=PaddleRandomMaxDataset(12, MNISTTrainFleetConfig.feature_dimension), | |||
| dataset=PaddleArgMaxDataset(12, MNISTTrainFleetConfig.feature_dimension), | |||
| batch_size=MNISTTrainFleetConfig.batch_size, | |||
| shuffle=True | |||
| ) | |||
| @@ -46,8 +46,8 @@ class LSTM(Module): | |||
| def init_hidden(self, x): | |||
| # batch_first | |||
| batch_size = x.shape[0] | |||
| h0 = jt.randn(1, batch_size, hidden_size) | |||
| c0 = jt.randn(1, batch_size, hidden_size) | |||
| h0 = jt.randn(1, batch_size, self.hidden_size) | |||
| c0 = jt.randn(1, batch_size, self.hidden_size) | |||
| return h0, c0 | |||
| @@ -1,4 +1,5 @@ | |||
| import pytest | |||
| from fastNLP.core.callbacks import callback | |||
| from fastNLP.core.controllers.trainer import Trainer | |||
| from fastNLP.core.controllers.trainer import Evaluator | |||
| @@ -14,6 +15,7 @@ if _NEED_IMPORT_JITTOR: | |||
| else: | |||
| from fastNLP.core.utils.dummy_class import DummyClass as Module | |||
| from fastNLP.core.utils.dummy_class import DummyClass as Dataset | |||
| jt.flags.use_cuda=1 | |||
| class JittorNormalModel_Classification(Module): | |||
| @@ -68,11 +70,9 @@ class TrainJittorConfig: | |||
| batch_size: int = 4 | |||
| shuffle: bool = True | |||
| @pytest.mark.parametrize("driver", ["jittor"]) | |||
| @pytest.mark.parametrize("device", ["cpu", "gpu", "cuda:0"]) | |||
| @pytest.mark.parametrize("device", ["cpu", "gpu", "cuda", None]) | |||
| @pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) | |||
| @pytest.mark.jittor | |||
| def test_trainer_jittor( | |||
| driver, | |||
| device, | |||
| @@ -15,7 +15,7 @@ if _NEED_IMPORT_PADDLE: | |||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||
| from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset | |||
| from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset | |||
| from tests.helpers.utils import magic_argv_env_context | |||
| @dataclass | |||
| @@ -44,12 +44,12 @@ def test_trainer_paddle( | |||
| ) | |||
| optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001) | |||
| train_dataloader = DataLoader( | |||
| dataset=PaddleRandomMaxDataset(20, TrainPaddleConfig.feature_dimension), | |||
| dataset=PaddleArgMaxDataset(20, TrainPaddleConfig.feature_dimension), | |||
| batch_size=TrainPaddleConfig.batch_size, | |||
| shuffle=True | |||
| ) | |||
| val_dataloader = DataLoader( | |||
| dataset=PaddleRandomMaxDataset(12, TrainPaddleConfig.feature_dimension), | |||
| dataset=PaddleArgMaxDataset(12, TrainPaddleConfig.feature_dimension), | |||
| batch_size=TrainPaddleConfig.batch_size, | |||
| shuffle=True | |||
| ) | |||
| @@ -76,7 +76,7 @@ class TestPaddle: | |||
| from paddle.io import Dataset | |||
| import paddle | |||
| class PaddleRandomMaxDataset(Dataset): | |||
| class PaddleArgMaxDataset(Dataset): | |||
| def __init__(self, num_samples, num_features): | |||
| self.x = paddle.randn((num_samples, num_features)) | |||
| self.y = self.x.argmax(axis=-1) | |||
| @@ -87,7 +87,7 @@ class TestPaddle: | |||
| def __getitem__(self, item): | |||
| return {"x": self.x[item], "y": self.y[item]} | |||
| ds = PaddleRandomMaxDataset(100, 2) | |||
| ds = PaddleArgMaxDataset(100, 2) | |||
| dl = DataLoader(ds, places=None, collate_fn=Collator(), batch_size=4) | |||
| for batch in dl: | |||
| print(batch) | |||
| @@ -0,0 +1,45 @@ | |||
| import pytest | |||
| from fastNLP.core.drivers import JittorSingleDriver, JittorMPIDriver | |||
| from fastNLP.core.drivers.jittor_driver.initialize_jittor_driver import initialize_jittor_driver | |||
| from tests.helpers.models.jittor_model import JittorNormalModel_Classification_1 | |||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
| if _NEED_IMPORT_JITTOR: | |||
| import jittor as jt | |||
| @pytest.mark.jittor | |||
| def test_incorrect_driver(): | |||
| model = JittorNormalModel_Classification_1(20, 10) | |||
| with pytest.raises(ValueError): | |||
| driver = initialize_jittor_driver("torch", 0, model) | |||
| @pytest.mark.jittor | |||
| @pytest.mark.parametrize( | |||
| "device", | |||
| ["cpu", "gpu", None, "cuda"] | |||
| ) | |||
| def test_get_single_device(device): | |||
| """ | |||
| 测试正常情况下初始化 JittorSingleDriver 的情况 | |||
| """ | |||
| model = JittorNormalModel_Classification_1(20, 10) | |||
| driver = initialize_jittor_driver("jittor", device, model) | |||
| assert isinstance(driver, JittorSingleDriver) | |||
| @pytest.mark.jittor | |||
| @pytest.mark.parametrize( | |||
| "device", | |||
| [[0, 2, 3], 1, 2] | |||
| ) | |||
| def test_get_mpi(device): | |||
| """ | |||
| 测试 jittor 多卡的初始化情况 | |||
| """ | |||
| model = JittorNormalModel_Classification_1(20, 10) | |||
| with pytest.raises(NotImplementedError): | |||
| driver = initialize_jittor_driver("jittor", device, model) | |||
| # assert isinstance(driver, JittorMPIDriver) | |||
| @@ -1,99 +1,614 @@ | |||
| import pytest | |||
| import os | |||
| from copy import deepcopy | |||
| from pathlib import Path | |||
| import numpy as np | |||
| from fastNLP.core.drivers.jittor_driver.single_device import JittorSingleDriver | |||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
| from fastNLP.core.drivers.jittor_driver import JittorSingleDriver | |||
| from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | |||
| from fastNLP.core.dataloaders import JittorDataLoader | |||
| from tests.helpers.models.jittor_model import JittorNormalModel_Classification_1 | |||
| from tests.helpers.datasets.jittor_data import JittorNormalDataset, JittorNormalXYDataset | |||
| from tests.helpers.datasets.torch_data import TorchNormalDataset | |||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||
| from fastNLP.envs.distributed import rank_zero_rm | |||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_JITTOR: | |||
| import jittor as jt # 将 jittor 引入 | |||
| from jittor import nn, Module # 引入相关的模块 | |||
| from jittor import init | |||
| from jittor.dataset import MNIST | |||
| else: | |||
| from fastNLP.core.utils.dummy_class import DummyClass as Module | |||
| import jittor as jt | |||
| from jittor.dataset import ( | |||
| BatchSampler as JittorBatchSampler, | |||
| RandomSampler as JittorRandomSampler, | |||
| SequentialSampler as JittorSequentialSampler, | |||
| SubsetRandomSampler as JittorSubsetRandomSampler | |||
| ) | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| def get_dataloader(dataset, use_dataloader, sampler, batch_size, shuffle, drop_last=False): | |||
| """ | |||
| :param dataset: | |||
| :param use_dataloader: 是否使用 JittorDataLoader 包裹 | |||
| :param sampler: 使用 BatchSampler Samlper 还是不使用 Sampler | |||
| """ | |||
| if use_dataloader: | |||
| dataloader = JittorDataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last) | |||
| dataloader.dataset.set_attrs(sampler=sampler) | |||
| else: | |||
| dataloader = dataset | |||
| dataloader.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, sampler=sampler) | |||
| class Model(Module): | |||
| def __init__ (self): | |||
| super (Model, self).__init__() | |||
| self.conv1 = nn.Conv (3, 32, 3, 1) # no padding | |||
| self.conv2 = nn.Conv (32, 64, 3, 1) | |||
| self.bn = nn.BatchNorm(64) | |||
| self.max_pool = nn.Pool (2, 2) | |||
| self.relu = nn.Relu() | |||
| self.fc1 = nn.Linear (64 * 12 * 12, 256) | |||
| self.fc2 = nn.Linear (256, 10) | |||
| def execute(self, x) : | |||
| # it's simliar to forward function in Pytorch | |||
| x = self.conv1 (x) | |||
| x = self.relu (x) | |||
| x = self.conv2 (x) | |||
| x = self.bn (x) | |||
| x = self.relu (x) | |||
| x = self.max_pool (x) | |||
| x = jt.reshape (x, [x.shape[0], -1]) | |||
| x = self.fc1 (x) | |||
| x = self.relu(x) | |||
| x = self.fc2 (x) | |||
| return x | |||
| return dataloader | |||
| ############################################################################ | |||
| # | |||
| # 测试基类 JittorDrvier 中的一些简单函数 | |||
| # | |||
| ############################################################################ | |||
| class TestJittorDriverFunctions: | |||
| """ | |||
| 使用 JittorSingleDriver 测试基类的函数 | |||
| """ | |||
| @classmethod | |||
| def setup_class(self): | |||
| model = JittorNormalModel_Classification_1(10, 32) | |||
| self.driver = JittorSingleDriver(model, device="cpu") | |||
| @pytest.mark.jittor | |||
| def test_check_optimizers_legality(self): | |||
| """ | |||
| 测试对合法的 optimizers 的检查 | |||
| """ | |||
| # 单个 optimizer | |||
| optimizer = jt.optim.Adam( | |||
| params=self.driver.model.parameters(), | |||
| lr=0.01 | |||
| ) | |||
| self.driver.set_optimizers(optimizer) | |||
| # optimizer 列表 | |||
| optimizers = [ | |||
| jt.optim.Adam( | |||
| params=self.driver.model.parameters(), | |||
| lr=0.01 | |||
| ) for i in range(10) | |||
| ] | |||
| self.driver.set_optimizers(optimizers) | |||
| @pytest.mark.torchjittor | |||
| def test_invalid_optimizers(self): | |||
| """ | |||
| 测试传入非法的 optimizers | |||
| """ | |||
| # 单个 optimizer | |||
| optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||
| with pytest.raises(TypeError): | |||
| self.driver.set_optimizers(optimizer) | |||
| optimizers = [ | |||
| torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||
| ] | |||
| with pytest.raises(TypeError): | |||
| self.driver.set_optimizers(optimizers) | |||
| @pytest.mark.jittor | |||
| def test_check_dataloader_legality(self): | |||
| """ | |||
| 测试 check_dataloader_legality 函数的表现 | |||
| """ | |||
| # 使用 JittorDataLoader | |||
| dataloader = JittorDataLoader(JittorNormalDataset()) | |||
| self.driver.check_dataloader_legality(dataloader) | |||
| # 使用 jittor.dataset.Dataset | |||
| self.driver.check_dataloader_legality(JittorNormalDataset()) | |||
| @pytest.mark.torchjittor | |||
| def test_check_dataloader_legality_invalid(self): | |||
| """ | |||
| 测试 check_dataloader_legality 函数传入其他类型的表现 | |||
| """ | |||
| # 创建 torch 的 dataloader | |||
| dataloader = torch.utils.data.DataLoader( | |||
| TorchNormalDataset(), | |||
| batch_size=32, shuffle=True | |||
| ) | |||
| with pytest.raises(TypeError): | |||
| self.driver.check_dataloader_legality(dataloader) | |||
| @pytest.mark.jittor | |||
| def test_tensor_to_numeric(self): | |||
| """ | |||
| 测试 tensor_to_numeric 函数 | |||
| """ | |||
| # 单个张量 | |||
| tensor = jt.Var(3) | |||
| res = JittorSingleDriver.tensor_to_numeric(tensor) | |||
| assert res == 3 | |||
| tensor = jt.rand(3, 4) | |||
| res = JittorSingleDriver.tensor_to_numeric(tensor) | |||
| assert res == tensor.tolist() | |||
| # 张量list | |||
| tensor_list = [jt.rand(6, 4, 2) for i in range(10)] | |||
| res = JittorSingleDriver.tensor_to_numeric(tensor_list) | |||
| assert isinstance(res, list) | |||
| tensor_list = [t.tolist() for t in tensor_list] | |||
| assert res == tensor_list | |||
| # 张量tuple | |||
| tensor_tuple = tuple([jt.rand(6, 4, 2) for i in range(10)]) | |||
| res = JittorSingleDriver.tensor_to_numeric(tensor_tuple) | |||
| assert isinstance(res, tuple) | |||
| tensor_tuple = tuple([t.tolist() for t in tensor_tuple]) | |||
| assert res == tensor_tuple | |||
| # 张量dict | |||
| tensor_dict = { | |||
| "tensor": jt.rand(3, 4), | |||
| "list": [jt.rand(6, 4, 2) for i in range(10)], | |||
| "dict":{ | |||
| "list": [jt.rand(6, 4, 2) for i in range(10)], | |||
| "tensor": jt.rand(3, 4) | |||
| }, | |||
| "int": 2, | |||
| "string": "test string" | |||
| } | |||
| res = JittorSingleDriver.tensor_to_numeric(tensor_dict) | |||
| assert isinstance(res, dict) | |||
| assert res["tensor"] == tensor_dict["tensor"].tolist() | |||
| assert isinstance(res["list"], list) | |||
| for r, d in zip(res["list"], tensor_dict["list"]): | |||
| assert r == d.tolist() | |||
| assert isinstance(res["int"], int) | |||
| assert isinstance(res["string"], str) | |||
| assert isinstance(res["dict"], dict) | |||
| assert isinstance(res["dict"]["list"], list) | |||
| for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]): | |||
| assert r == d.tolist() | |||
| assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() | |||
| @pytest.mark.jittor | |||
| def test_tensor_to_numeric_reduce(self): | |||
| tensor = jt.Var([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) | |||
| res_max = JittorSingleDriver.tensor_to_numeric(tensor, reduce="max") | |||
| res_min = JittorSingleDriver.tensor_to_numeric(tensor, reduce="min") | |||
| res_sum = JittorSingleDriver.tensor_to_numeric(tensor, reduce="sum") | |||
| res_mean = JittorSingleDriver.tensor_to_numeric(tensor, reduce="mean") | |||
| assert res_max == 6 | |||
| assert res_min == 1 | |||
| assert res_sum == 21 | |||
| assert res_mean == 3.5 | |||
| @pytest.mark.jittor | |||
| def test_set_model_mode(self): | |||
| """ | |||
| 测试 set_model_mode 函数 | |||
| """ | |||
| self.driver.set_model_mode("train") | |||
| assert self.driver.model.is_training() | |||
| self.driver.set_model_mode("eval") | |||
| assert not self.driver.model.is_training() | |||
| # 应该报错 | |||
| with pytest.raises(AssertionError): | |||
| self.driver.set_model_mode("test") | |||
| @pytest.mark.jittor | |||
| def test_move_model_to_device_cpu(self): | |||
| """ | |||
| 测试 move_model_to_device 函数,仅测试能否运行 | |||
| """ | |||
| JittorSingleDriver.move_model_to_device(self.driver.model, "cpu") | |||
| @pytest.mark.jittor | |||
| def test_move_model_to_device_gpu(self): | |||
| """ | |||
| 测试 move_model_to_device 函数,仅测试能否运行 | |||
| """ | |||
| JittorSingleDriver.move_model_to_device(self.driver.model, "gpu") | |||
| @pytest.mark.jittor | |||
| def test_set_deterministic_dataloader(self): | |||
| """ | |||
| 测试 set_deterministic_dataloader,仅测试能否运行 | |||
| """ | |||
| # 先确保不影响运行 | |||
| # TODO:正确性 | |||
| dataloader = JittorDataLoader(JittorNormalDataset()) | |||
| self.driver.set_deterministic_dataloader(dataloader) | |||
| self.driver.set_deterministic_dataloader(JittorNormalDataset()) | |||
| @pytest.mark.jittor | |||
| def test_set_sampler_epoch(self): | |||
| """ | |||
| 测试 set_sampler_epoch | |||
| """ | |||
| # 先确保不影响运行 | |||
| # TODO:正确性 | |||
| dataloader = JittorDataLoader(JittorNormalDataset()) | |||
| self.driver.set_sampler_epoch(dataloader, 0) | |||
| self.driver.set_sampler_epoch(JittorNormalDataset(), 0) | |||
| @pytest.mark.jittor | |||
| @pytest.mark.parametrize("batch_size", [16]) | |||
| @pytest.mark.parametrize("shuffle", [True, False]) | |||
| @pytest.mark.parametrize("drop_last", [True, False]) | |||
| @pytest.mark.parametrize("use_dataloader", [True, False]) | |||
| def test_get_dataloader_args(self, batch_size, shuffle, drop_last, use_dataloader): | |||
| """ | |||
| 测试正常情况下 get_dataloader_args 的表现 | |||
| """ | |||
| dataloader = get_dataloader( | |||
| JittorNormalDataset(), | |||
| use_dataloader=use_dataloader, | |||
| sampler=None, | |||
| batch_size=batch_size, | |||
| shuffle=shuffle, | |||
| drop_last=drop_last | |||
| ) | |||
| res = JittorSingleDriver.get_dataloader_args(dataloader) | |||
| assert isinstance(res.dataset, JittorNormalDataset) | |||
| assert res.sampler is None | |||
| assert res.shuffle == shuffle | |||
| assert res.batch_size == batch_size | |||
| assert res.drop_last == drop_last | |||
| @pytest.mark.jittor | |||
| @pytest.mark.parametrize("batch_size", [16]) | |||
| @pytest.mark.parametrize("shuffle", [True, False]) | |||
| @pytest.mark.parametrize("drop_last", [True, False]) | |||
| @pytest.mark.parametrize("use_dataloader", [True, False]) | |||
| def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last, use_dataloader): | |||
| """ | |||
| 测试替换了 sampler 后 get_dataloader_args 的表现 | |||
| """ | |||
| dataset = JittorNormalDataset() | |||
| dataloader = get_dataloader( | |||
| dataset, | |||
| use_dataloader=use_dataloader, | |||
| batch_size=batch_size, | |||
| sampler=RandomSampler(dataset, shuffle=shuffle), | |||
| shuffle=shuffle, | |||
| drop_last=drop_last | |||
| ) | |||
| res = JittorSingleDriver.get_dataloader_args(dataloader) | |||
| assert isinstance(res.dataset, JittorNormalDataset) | |||
| assert isinstance(res.sampler, RandomSampler) | |||
| assert res.shuffle == shuffle | |||
| assert res.batch_size == batch_size | |||
| assert res.drop_last == drop_last | |||
| ############################################################################ | |||
| # | |||
| # 测试 JittorSingleDrvier 中的一些简单函数 | |||
| # | |||
| ############################################################################ | |||
| @pytest.mark.jittor | |||
| @pytest.mark.skip("Skip jittor tests now.") | |||
| class TestSingleDevice: | |||
| def test_on_gpu_without_fp16(self): | |||
| # TODO get_dataloader | |||
| batch_size = 64 | |||
| learning_rate = 0.1 | |||
| epochs = 5 | |||
| losses = [] | |||
| losses_idx = [] | |||
| train_loader = MNIST(train=True, batch_size=batch_size, shuffle=True) | |||
| val_loader = MNIST(train=False, batch_size=1, shuffle=False) | |||
| model = Model() | |||
| driver = JittorSingleDriver(model, device=[1]) | |||
| optimizer = nn.SGD(model.parameters(), learning_rate) | |||
| driver.set_optimizers(optimizer) | |||
| for epoch in range(epochs): | |||
| driver.set_model_mode("train") | |||
| lens = len(train_loader) | |||
| for batch_idx, (inputs, targets) in enumerate(train_loader): | |||
| outputs =driver.train_step(inputs) | |||
| loss = nn.cross_entropy_loss(outputs, targets) | |||
| driver.backward(loss) | |||
| driver.step() | |||
| driver.zero_grad() | |||
| losses.append(loss.data[0]) | |||
| losses_idx.append(epoch * lens + batch_idx) | |||
| test_loss = 0 | |||
| correct = 0 | |||
| total_acc = 0 | |||
| total_num = 0 | |||
| driver.set_model_mode("eval") | |||
| for batch_idx, (inputs, targets) in enumerate(val_loader): | |||
| batch_size = inputs.shape[0] | |||
| outputs = driver.test_step(inputs) | |||
| pred = np.argmax(outputs.data, axis=1) | |||
| acc = np.sum(targets.data==pred) | |||
| total_acc += acc | |||
| total_num += batch_size | |||
| acc = acc / batch_size | |||
| assert total_acc / total_num > 0.95 | |||
| def test_on_cpu_without_fp16(self): | |||
| pass | |||
| def test_on_gpu_with_fp16(self): | |||
| pass | |||
| class TestSingleDeviceFunction: | |||
| """ | |||
| 测试其它函数的测试例 | |||
| """ | |||
| @classmethod | |||
| def setup_class(cls): | |||
| model = JittorNormalModel_Classification_1(10, 784) | |||
| cls.driver = JittorSingleDriver(model, device="cpu") | |||
| def test_unwrap_model(self): | |||
| """ | |||
| 测试能否运行 | |||
| """ | |||
| res = self.driver.unwrap_model() | |||
| assert res is self.driver.model | |||
| def test_is_distributed(self): | |||
| assert self.driver.is_distributed() == False | |||
| def test_move_data_to_device(self): | |||
| self.driver.move_data_to_device(jt.rand(32, 64)) | |||
| ############################################################################ | |||
| # | |||
| # 测试 set_dist_repro_dataloader 函数 | |||
| # | |||
| ############################################################################ | |||
| @pytest.mark.jittor | |||
| class TestSetDistReproDataloader: | |||
| """ | |||
| 专门测试 set_dist_repro_dataloader 函数的类 | |||
| """ | |||
| def setup_method(self): | |||
| self.dataset = JittorNormalDataset(20) | |||
| model = JittorNormalModel_Classification_1(10, 32) | |||
| self.driver = JittorSingleDriver(model, device="cpu") | |||
| @pytest.mark.parametrize("use_dataloader", [True, False]) | |||
| def test_with_reproducible_false(self, use_dataloader): | |||
| """ | |||
| 测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现 | |||
| 当dist为字符串时,此时应该返回原来的 dataloader | |||
| """ | |||
| dataloader = get_dataloader(self.dataset, use_dataloader, sampler=None, batch_size=2, shuffle=True) | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||
| assert replaced_loader is dataloader | |||
| @pytest.mark.parametrize("shuffle", [True, False]) | |||
| @pytest.mark.parametrize("sampler", [None, "random", "sequential"]) | |||
| @pytest.mark.parametrize("use_dataloader", [True, False]) | |||
| def test_with_reproducible_true(self, shuffle, sampler, use_dataloader): | |||
| """ | |||
| 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 | |||
| 当dist为字符串时,此时应该返回新的 dataloader,会替换 sampler 为 RandomSampler | |||
| """ | |||
| if sampler == "random": | |||
| sampler = JittorRandomSampler(self.dataset) | |||
| _shuffle = True | |||
| elif sampler == "sequential": | |||
| sampler = JittorSequentialSampler(self.dataset) | |||
| _shuffle = False | |||
| else: | |||
| _shuffle = shuffle | |||
| dataloader = get_dataloader(self.dataset, use_dataloader, sampler=sampler, batch_size=2, shuffle=shuffle) | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | |||
| assert not (replaced_loader is dataloader) | |||
| assert isinstance(replaced_loader.sampler, RandomSampler) | |||
| assert replaced_loader.sampler.shuffle == _shuffle | |||
| assert replaced_loader.batch_size == dataloader.batch_size | |||
| assert replaced_loader.drop_last == dataloader.drop_last | |||
| self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle, use_dataloader) | |||
| @pytest.mark.parametrize("shuffle", ([True, False])) | |||
| @pytest.mark.parametrize("use_dataloader", [True, False]) | |||
| def test_with_dist_batch_sampler(self, shuffle, use_dataloader): | |||
| """ | |||
| 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler | |||
| 应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler | |||
| jittor 暂时不支持这种情况,会报错 | |||
| """ | |||
| dataloader = get_dataloader(self.dataset, use_dataloader, sampler=None, batch_size=2, shuffle=not shuffle) | |||
| dist = ReproduceBatchSampler(JittorBatchSampler(JittorRandomSampler(self.dataset), 4, False), 4, False) | |||
| with pytest.raises(RuntimeError): | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | |||
| @pytest.mark.parametrize("shuffle", ([True, False])) | |||
| @pytest.mark.parametrize("use_dataloader", [True, False]) | |||
| def test_with_dist_sampler(self, shuffle, use_dataloader): | |||
| """ | |||
| 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现 | |||
| 应该返回新的 dataloader,并将 sampler 替换为 dist 对应的 Sampler | |||
| """ | |||
| dataloader = get_dataloader(self.dataset, use_dataloader, sampler=None, batch_size=2, shuffle=not shuffle) | |||
| dist = RandomSampler(self.dataset, shuffle=shuffle) | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | |||
| assert not (replaced_loader is dataloader) | |||
| assert isinstance(replaced_loader.sampler, RandomSampler) | |||
| assert replaced_loader.sampler is dist | |||
| assert replaced_loader.batch_size == dataloader.batch_size | |||
| self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle, use_dataloader) | |||
| @pytest.mark.parametrize("shuffle", ([True, False])) | |||
| @pytest.mark.parametrize("use_dataloader", [True, False]) | |||
| def test_with_dataloader_reproducible_batch_sampler(self, shuffle, use_dataloader): | |||
| """ | |||
| 测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 | |||
| 应该返回新的 dataloader,且其余各项设置和原来相同 | |||
| """ | |||
| dataloader = get_dataloader( | |||
| self.dataset, | |||
| use_dataloader=use_dataloader, | |||
| sampler=ReproduceBatchSampler( | |||
| JittorBatchSampler(JittorRandomSampler(self.dataset), 4, False), | |||
| batch_size=4, | |||
| drop_last=False, | |||
| ), | |||
| batch_size=4, | |||
| shuffle=shuffle, | |||
| ) | |||
| with pytest.raises(RuntimeError): | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||
| @pytest.mark.parametrize("shuffle", ([True, False])) | |||
| @pytest.mark.parametrize("use_dataloader", [True, False]) | |||
| def test_with_dataloader_reproducible_sampler(self, shuffle, use_dataloader): | |||
| """ | |||
| 测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 | |||
| 应该返回新的 dataloader,且其余各项设置和原来相同 | |||
| """ | |||
| dataloader = get_dataloader( | |||
| self.dataset, | |||
| use_dataloader=use_dataloader, | |||
| sampler=RandomSampler(self.dataset, shuffle), | |||
| batch_size=2, | |||
| shuffle=shuffle, | |||
| ) | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||
| assert not (replaced_loader is dataloader) | |||
| assert not (replaced_loader.sampler is dataloader.sampler) | |||
| assert isinstance(replaced_loader.sampler, RandomSampler) | |||
| assert replaced_loader.batch_size == 2 | |||
| assert replaced_loader.shuffle == shuffle | |||
| self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle, use_dataloader) | |||
| def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle, use_dataloader): | |||
| """ | |||
| 测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确 | |||
| """ | |||
| # 迭代两个 batch | |||
| num_consumed_batches = 2 | |||
| already_seen_idx = set() | |||
| replaced_loader.sampler.set_epoch(6) | |||
| for idx, batch in enumerate(replaced_loader): | |||
| if idx >= num_consumed_batches: | |||
| break | |||
| already_seen_idx.update(batch.tolist()) | |||
| sampler_states = replaced_loader.sampler.state_dict() | |||
| # 重新加载,应该可以输出剩下的内容,且对于 JittorNormalDataset 来说,排序后应该是一个 range | |||
| left_idxes = set() | |||
| batch_size = replaced_loader.batch_size | |||
| sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | |||
| # 重新构造 dataloader | |||
| if use_dataloader: | |||
| dataset = deepcopy(replaced_loader.dataset.dataset) | |||
| else: | |||
| dataset = deepcopy(replaced_loader) | |||
| new_loader = get_dataloader( | |||
| dataset=dataset, | |||
| use_dataloader=use_dataloader, | |||
| sampler = RandomSampler(dataset, shuffle=shuffle), | |||
| batch_size=batch_size, | |||
| shuffle=shuffle, | |||
| drop_last=False | |||
| ) | |||
| new_loader.sampler.load_state_dict(sampler_states) | |||
| new_loader.sampler.set_epoch(6) | |||
| for idx, batch in enumerate(new_loader): | |||
| left_idxes.update(batch.tolist()) | |||
| print(already_seen_idx) | |||
| print(left_idxes) | |||
| assert len(left_idxes) + len(already_seen_idx) == self.dataset.total_len | |||
| assert len(left_idxes | already_seen_idx) == self.dataset.total_len | |||
| ############################################################################ | |||
| # | |||
| # 测试 save 和 load 相关的功能 | |||
| # | |||
| ############################################################################ | |||
| def generate_random_driver(labels, features, fp16=False, device="cpu", lr=0.01): | |||
| """ | |||
| 生成driver | |||
| """ | |||
| model = JittorNormalModel_Classification_1(labels, features) | |||
| opt = jt.optim.Adam(params=model.parameters(), lr=lr) | |||
| driver = JittorSingleDriver(model, device=device, fp16=fp16) | |||
| driver.set_optimizers(opt) | |||
| driver.setup() | |||
| return driver | |||
| @pytest.mark.jittor | |||
| @pytest.mark.parametrize("only_state_dict", ([True, False])) | |||
| @pytest.mark.parametrize("use_dataloader", [True, False]) | |||
| def test_save_and_load_model(only_state_dict, use_dataloader): | |||
| """ | |||
| 测试 save_model 和 load_model 函数 | |||
| """ | |||
| try: | |||
| path = "model" | |||
| dataset = JittorNormalXYDataset(20) | |||
| dataloader = get_dataloader(dataset, sampler=None, use_dataloader=use_dataloader, batch_size=4, shuffle=True) | |||
| driver1, driver2 = generate_random_driver(20, 1, device="gpu"), generate_random_driver(20, 1, device="gpu") | |||
| driver1.save_model(path, only_state_dict) | |||
| driver2.load_model(path, only_state_dict) | |||
| for batch in dataloader: | |||
| batch = driver1.move_data_to_device(batch) | |||
| res1 = driver1.model.evaluate_step(**batch) | |||
| res2 = driver2.model.evaluate_step(**batch) | |||
| assert jt.all_(jt.equal(res1["pred"], res2["pred"])) | |||
| finally: | |||
| rank_zero_rm(path) | |||
| @pytest.mark.jittor | |||
| @pytest.mark.parametrize("only_state_dict", ([True, False])) | |||
| @pytest.mark.parametrize("use_dataloader", [True, False]) | |||
| def test_save_and_load_with_randomsampler(only_state_dict, use_dataloader): | |||
| """ | |||
| 测试save和load函数,主要测试 dataloader 被替换了 sampler 的情况 | |||
| """ | |||
| try: | |||
| path = "model.ckp" | |||
| driver1, driver2 = generate_random_driver(20, 1, device="gpu", lr=0.01), \ | |||
| generate_random_driver(20, 1, device="gpu", lr=0.001) | |||
| dataset = JittorNormalXYDataset(20) | |||
| dataloader = get_dataloader( | |||
| dataset, use_dataloader, | |||
| sampler = RandomSampler(dataset, True), | |||
| batch_size=4, | |||
| shuffle=True | |||
| ) | |||
| num_consumed_batches = 2 | |||
| already_seen_x_set = set() | |||
| already_seen_y_set = set() | |||
| driver1.set_sampler_epoch(dataloader, 7) | |||
| for idx, batch in enumerate(dataloader): | |||
| if idx >= num_consumed_batches: | |||
| break | |||
| already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||
| already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||
| sampler_states = dataloader.sampler.state_dict() | |||
| save_states = {"num_consumed_batches": num_consumed_batches} | |||
| driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||
| # 加载 | |||
| # 更改 batch_size | |||
| dataloader = get_dataloader( | |||
| dataset, use_dataloader, | |||
| sampler=RandomSampler(dataset, True), | |||
| batch_size=2, | |||
| shuffle=True | |||
| ) | |||
| load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||
| replaced_loader = load_states.pop("dataloader") | |||
| # 1. 检查 optimizer 的状态 | |||
| assert driver2.optimizers[0].lr == driver1.optimizers[0].lr | |||
| # 2. 检查 sampler 是否被正确地加载和替换 | |||
| assert not (replaced_loader is dataloader) | |||
| assert isinstance(replaced_loader.sampler, RandomSampler) | |||
| assert replaced_loader.sampler.seed == sampler_states["seed"] | |||
| assert replaced_loader.sampler.epoch == sampler_states["epoch"] | |||
| assert replaced_loader.sampler.num_consumed_samples == 4 * num_consumed_batches | |||
| assert replaced_loader.sampler.dataset.total_len == sampler_states["length"] | |||
| assert replaced_loader.sampler.shuffle == sampler_states["shuffle"] | |||
| # 4. 检查 model 的参数是否正确 | |||
| # 5. 检查 batch_idx | |||
| start_batch = load_states.pop('batch_idx_in_epoch') | |||
| assert start_batch == 2 * num_consumed_batches | |||
| left_x_batches = set() | |||
| left_y_batches = set() | |||
| driver2.set_sampler_epoch(replaced_loader, 7) | |||
| for idx, batch in enumerate(replaced_loader): | |||
| left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||
| left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||
| res1 = driver1.model.evaluate_step(**batch) | |||
| res2 = driver2.model.evaluate_step(**batch) | |||
| assert jt.all_(jt.equal(res1["pred"], res2["pred"])) | |||
| assert len(left_x_batches) + len(already_seen_x_set) == dataset.total_len | |||
| assert len(left_x_batches | already_seen_x_set) == dataset.total_len | |||
| assert len(left_y_batches) + len(already_seen_y_set) == dataset.total_len | |||
| assert len(left_y_batches | already_seen_y_set) == dataset.total_len | |||
| finally: | |||
| rank_zero_rm(path) | |||
| @@ -0,0 +1,43 @@ | |||
| import pytest | |||
| from fastNLP.core.drivers.jittor_driver.utils import replace_sampler | |||
| from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | |||
| from fastNLP.core.dataloaders import JittorDataLoader | |||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
| if _NEED_IMPORT_JITTOR: | |||
| import jittor as jt | |||
| from tests.helpers.datasets.jittor_data import JittorNormalDataset | |||
| @pytest.mark.jittor | |||
| @pytest.mark.parametrize("dataset", [ | |||
| JittorNormalDataset(20, batch_size=10, shuffle=True), | |||
| JittorNormalDataset(20, batch_size=5, drop_last=True), | |||
| JittorNormalDataset(20) | |||
| ]) | |||
| def test_replace_sampler_dataset(dataset): | |||
| dataset = JittorNormalDataset(20) | |||
| sampler = RandomSampler(dataset) | |||
| replaced_loader = replace_sampler(dataset, sampler) | |||
| assert not (replaced_loader is dataset) | |||
| assert isinstance(replaced_loader.sampler, RandomSampler) | |||
| assert replaced_loader.batch_size == dataset.batch_size | |||
| assert replaced_loader.drop_last == dataset.drop_last | |||
| assert replaced_loader.shuffle == dataset.shuffle | |||
| assert replaced_loader.total_len == dataset.total_len | |||
| @pytest.mark.jittor | |||
| def test_replace_sampler_jittordataloader(): | |||
| dataset = JittorNormalDataset(20, batch_size=10, shuffle=True) | |||
| dataloader = JittorDataLoader(dataset, batch_size=8, shuffle=True) | |||
| sampler = RandomSampler(dataset) | |||
| replaced_loader = replace_sampler(dataloader, sampler) | |||
| assert not (replaced_loader is dataloader) | |||
| assert not (replaced_loader.dataset.dataset is dataloader.dataset.dataset) | |||
| assert isinstance(replaced_loader.sampler, RandomSampler) | |||
| assert replaced_loader.batch_size == 8 | |||
| assert replaced_loader.shuffle == True | |||
| @@ -10,7 +10,7 @@ from fastNLP.core.samplers import ( | |||
| UnrepeatedSequentialSampler, | |||
| ) | |||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||
| from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | |||
| from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset | |||
| from tests.helpers.utils import magic_argv_env_context | |||
| from fastNLP.envs.distributed import rank_zero_rm | |||
| from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
| @@ -19,8 +19,8 @@ if _NEED_IMPORT_PADDLE: | |||
| import paddle.distributed as dist | |||
| from paddle.io import DataLoader, BatchSampler | |||
| def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"): | |||
| paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension) | |||
| def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="only_error"): | |||
| paddle_model = PaddleNormalModel_Classification_1(labels, features) | |||
| paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | |||
| driver = PaddleFleetDriver( | |||
| model=paddle_model, | |||
| @@ -465,10 +465,14 @@ class TestSetDistReproDataloader: | |||
| num_replicas = len(self.device) | |||
| num_consumed_batches = 2 | |||
| already_seen_idx = set() | |||
| if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | |||
| sampler_states = replaced_loader.batch_sampler.set_epoch(10) | |||
| else: | |||
| sampler_states = replaced_loader.batch_sampler.sampler.set_epoch(10) | |||
| for idx, batch in enumerate(replaced_loader): | |||
| if idx >= num_consumed_batches: | |||
| break | |||
| already_seen_idx.update(batch) | |||
| already_seen_idx.update(batch.tolist()) | |||
| dist.barrier() | |||
| if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | |||
| sampler_states = replaced_loader.batch_sampler.state_dict() | |||
| @@ -496,6 +500,7 @@ class TestSetDistReproDataloader: | |||
| pad=True | |||
| ) | |||
| new_loader.batch_sampler.load_state_dict(sampler_states) | |||
| new_loader.batch_sampler.set_epoch(10) | |||
| else: | |||
| batch_size = replaced_loader.batch_sampler.batch_size | |||
| sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas | |||
| @@ -508,8 +513,9 @@ class TestSetDistReproDataloader: | |||
| ) | |||
| new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler) | |||
| new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | |||
| new_loader.batch_sampler.sampler.set_epoch(10) | |||
| for idx, batch in enumerate(new_loader): | |||
| left_idxes.update(batch) | |||
| left_idxes.update(batch.tolist()) | |||
| assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas | |||
| assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas | |||
| @@ -533,7 +539,7 @@ class TestSaveLoad: | |||
| cls.driver = generate_driver(10, 10, device=[0,1]) | |||
| def setup_method(self): | |||
| self.dataset = PaddleRandomMaxDataset(20, 10) | |||
| self.dataset = PaddleNormalXYDataset(40) | |||
| @magic_argv_env_context | |||
| @pytest.mark.parametrize("only_state_dict", ([True, False])) | |||
| @@ -545,12 +551,12 @@ class TestSaveLoad: | |||
| path = "model" | |||
| dataloader = DataLoader(self.dataset, batch_size=2) | |||
| self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10) | |||
| self.driver1, self.driver2 = generate_driver(40, 1), generate_driver(40, 1) | |||
| if only_state_dict: | |||
| self.driver1.save_model(path, only_state_dict) | |||
| else: | |||
| self.driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((4, 10))]) | |||
| self.driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((4, 1))]) | |||
| # 同步 | |||
| dist.barrier() | |||
| @@ -594,8 +600,8 @@ class TestSaveLoad: | |||
| path = "model.ckp" | |||
| num_replicas = len(device) | |||
| self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ | |||
| generate_driver(10, 10, device=device, fp16=False) | |||
| self.driver1, self.driver2 = generate_driver(40, 1, device=device, fp16=fp16), \ | |||
| generate_driver(40, 1, device=device, fp16=False) | |||
| dataloader = DataLoader( | |||
| dataset=self.dataset, | |||
| batch_sampler=BucketedBatchSampler( | |||
| @@ -613,11 +619,12 @@ class TestSaveLoad: | |||
| already_seen_x_set = set() | |||
| already_seen_y_set = set() | |||
| self.driver1.set_sampler_epoch(dataloader, 2) | |||
| for idx, batch in enumerate(dataloader): | |||
| if idx >= num_consumed_batches: | |||
| break | |||
| already_seen_x_set.update(batch["x"]) | |||
| already_seen_y_set.update(batch["y"]) | |||
| already_seen_x_set.update(batch["x"].reshape((-1, )).tolist()) | |||
| already_seen_y_set.update(batch["y"].reshape((-1, )).tolist()) | |||
| # 同步 | |||
| dist.barrier() | |||
| @@ -669,10 +676,11 @@ class TestSaveLoad: | |||
| assert start_batch == 2 * num_consumed_batches | |||
| left_x_batches = set() | |||
| left_y_batches = set() | |||
| self.driver2.set_sampler_epoch(replaced_loader, 2) | |||
| for idx, batch in enumerate(replaced_loader): | |||
| left_x_batches.update(batch["x"]) | |||
| left_y_batches.update(batch["y"]) | |||
| left_x_batches.update(batch["x"].reshape((-1, )).tolist()) | |||
| left_y_batches.update(batch["y"].reshape((-1, )).tolist()) | |||
| res1 = self.driver1.model( | |||
| batch, | |||
| fastnlp_fn=self.driver1.model._layers.model.evaluate_step, | |||
| @@ -709,8 +717,8 @@ class TestSaveLoad: | |||
| num_replicas = len(device) | |||
| self.driver1 = generate_driver(10, 10, device=device, fp16=fp16) | |||
| self.driver2 = generate_driver(10, 10, device=device, fp16=False) | |||
| self.driver1 = generate_driver(40, 1, device=device, fp16=fp16) | |||
| self.driver2 = generate_driver(40, 1, device=device, fp16=False) | |||
| batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4) | |||
| batch_sampler.sampler = RandomSampler(self.dataset, True) | |||
| batch_sampler.sampler.set_distributed( | |||
| @@ -726,11 +734,12 @@ class TestSaveLoad: | |||
| already_seen_x_set = set() | |||
| already_seen_y_set = set() | |||
| self.driver1.set_sampler_epoch(dataloader, 2) | |||
| for idx, batch in enumerate(dataloader): | |||
| if idx >= num_consumed_batches: | |||
| break | |||
| already_seen_x_set.update(batch["x"]) | |||
| already_seen_y_set.update(batch["y"]) | |||
| already_seen_x_set.update(batch["x"].reshape((-1, )).tolist()) | |||
| already_seen_y_set.update(batch["y"].reshape((-1, )).tolist()) | |||
| # 同步 | |||
| dist.barrier() | |||
| @@ -779,10 +788,11 @@ class TestSaveLoad: | |||
| assert start_batch == 2 * num_consumed_batches | |||
| left_x_batches = set() | |||
| left_y_batches = set() | |||
| self.driver2.set_sampler_epoch(replaced_loader, 2) | |||
| for idx, batch in enumerate(replaced_loader): | |||
| left_x_batches.update(batch["x"]) | |||
| left_y_batches.update(batch["y"]) | |||
| left_x_batches.update(batch["x"].reshape((-1, )).tolist()) | |||
| left_y_batches.update(batch["y"].reshape((-1, )).tolist()) | |||
| res1 = self.driver1.model( | |||
| batch, | |||
| fastnlp_fn=self.driver1.model._layers.model.evaluate_step, | |||
| @@ -12,7 +12,7 @@ if _NEED_IMPORT_PADDLE: | |||
| @pytest.mark.paddle | |||
| def test_incorrect_driver(): | |||
| model = PaddleNormalModel_Classification_1(2, 100) | |||
| model = PaddleNormalModel_Classification_1(20, 10) | |||
| with pytest.raises(ValueError): | |||
| driver = initialize_paddle_driver("torch", 0, model) | |||
| @@ -26,7 +26,7 @@ def test_get_single_device(device): | |||
| 测试正常情况下初始化 PaddleSingleDriver 的情况 | |||
| """ | |||
| model = PaddleNormalModel_Classification_1(2, 100) | |||
| model = PaddleNormalModel_Classification_1(20, 10) | |||
| driver = initialize_paddle_driver("paddle", device, model) | |||
| assert isinstance(driver, PaddleSingleDriver) | |||
| @@ -41,7 +41,7 @@ def test_get_fleet(device): | |||
| 测试 fleet 多卡的初始化情况 | |||
| """ | |||
| model = PaddleNormalModel_Classification_1(64, 10) | |||
| model = PaddleNormalModel_Classification_1(20, 10) | |||
| driver = initialize_paddle_driver("paddle", device, model) | |||
| assert isinstance(driver, PaddleFleetDriver) | |||
| @@ -56,6 +56,6 @@ def test_device_out_of_range(device): | |||
| """ | |||
| 测试传入的device超过范围的情况 | |||
| """ | |||
| model = PaddleNormalModel_Classification_1(2, 100) | |||
| model = PaddleNormalModel_Classification_1(20, 10) | |||
| with pytest.raises(ValueError): | |||
| driver = initialize_paddle_driver("paddle", device, model) | |||
| @@ -4,14 +4,16 @@ from pathlib import Path | |||
| from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | |||
| from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | |||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||
| from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | |||
| from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset | |||
| from tests.helpers.datasets.torch_data import TorchNormalDataset | |||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||
| from fastNLP.envs.distributed import rank_zero_rm | |||
| from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_PADDLE: | |||
| import paddle | |||
| from paddle.io import DataLoader, BatchSampler | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| @@ -31,102 +33,70 @@ class TestPaddleDriverFunctions: | |||
| model = PaddleNormalModel_Classification_1(10, 32) | |||
| self.driver = PaddleSingleDriver(model, device="cpu") | |||
| @pytest.mark.torchpaddle | |||
| def test_check_single_optimizer_legality(self): | |||
| @pytest.mark.paddle | |||
| def test_check_optimizers_legality(self): | |||
| """ | |||
| 测试传入单个 optimizer 时的表现 | |||
| 测试对合法的 optimizers 的检查 | |||
| """ | |||
| # 单个 optimizer | |||
| optimizer = paddle.optimizer.Adam( | |||
| parameters=self.driver.model.parameters(), | |||
| learning_rate=0.01 | |||
| ) | |||
| self.driver.set_optimizers(optimizer) | |||
| optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||
| # 传入torch的optimizer时,应该报错ValueError | |||
| with pytest.raises(ValueError): | |||
| self.driver.set_optimizers(optimizer) | |||
| @pytest.mark.torchpaddle | |||
| def test_check_optimizers_legality(self): | |||
| """ | |||
| 测试传入 optimizer list 的表现 | |||
| """ | |||
| # optimizer 列表 | |||
| optimizers = [ | |||
| paddle.optimizer.Adam( | |||
| parameters=self.driver.model.parameters(), | |||
| learning_rate=0.01 | |||
| ) for i in range(10) | |||
| ] | |||
| self.driver.set_optimizers(optimizers) | |||
| optimizers += [ | |||
| @pytest.mark.torchpaddle | |||
| def test_invalid_optimizers(self): | |||
| """ | |||
| 测试传入非法的 optimizers | |||
| """ | |||
| # 单个 optimizer | |||
| optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||
| with pytest.raises(TypeError): | |||
| self.driver.set_optimizers(optimizer) | |||
| optimizers = [ | |||
| torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||
| ] | |||
| with pytest.raises(ValueError): | |||
| with pytest.raises(TypeError): | |||
| self.driver.set_optimizers(optimizers) | |||
| @pytest.mark.torchpaddle | |||
| def test_check_dataloader_legality_in_train(self): | |||
| @pytest.mark.paddle | |||
| def test_check_dataloader_legality(self): | |||
| """ | |||
| 测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 | |||
| 测试 check_dataloader_legality 函数的表现 | |||
| """ | |||
| dataloader = DataLoader(PaddleNormalDataset()) | |||
| PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
| self.driver.check_dataloader_legality(dataloader) | |||
| # batch_size 和 batch_sampler 均为 None 的情形 | |||
| dataloader = DataLoader(PaddleNormalDataset(), batch_size=None) | |||
| with pytest.raises(ValueError): | |||
| PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
| # 创建torch的dataloader | |||
| dataloader = torch.utils.data.DataLoader( | |||
| TorchNormalDataset(), | |||
| batch_size=32, shuffle=True | |||
| ) | |||
| with pytest.raises(ValueError): | |||
| PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
| self.driver.check_dataloader_legality(dataloader) | |||
| @pytest.mark.torchpaddle | |||
| def test_check_dataloader_legality_in_test(self): | |||
| def test_check_dataloader_legality_invalid(self): | |||
| """ | |||
| 测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现 | |||
| 测试 check_dataloader_legality 函数传入其他类型的表现 | |||
| """ | |||
| # 此时传入的应该是dict | |||
| dataloader = { | |||
| "train": DataLoader(PaddleNormalDataset()), | |||
| "test":DataLoader(PaddleNormalDataset()) | |||
| } | |||
| PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
| # batch_size 和 batch_sampler 均为 None 的情形 | |||
| dataloader = { | |||
| "train": DataLoader(PaddleNormalDataset()), | |||
| "test":DataLoader(PaddleNormalDataset(), batch_size=None) | |||
| } | |||
| with pytest.raises(ValueError): | |||
| PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
| # 传入的不是 dict ,应该报错 | |||
| dataloader = DataLoader(PaddleNormalDataset()) | |||
| with pytest.raises(ValueError): | |||
| PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
| # 创建 torch 的 dataloader | |||
| train_loader = torch.utils.data.DataLoader( | |||
| TorchNormalDataset(), | |||
| batch_size=32, shuffle=True | |||
| ) | |||
| test_loader = torch.utils.data.DataLoader( | |||
| dataloader = torch.utils.data.DataLoader( | |||
| TorchNormalDataset(), | |||
| batch_size=32, shuffle=True | |||
| ) | |||
| dataloader = {"train": train_loader, "test": test_loader} | |||
| with pytest.raises(ValueError): | |||
| PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
| with pytest.raises(TypeError): | |||
| self.driver.check_dataloader_legality(dataloader) | |||
| @pytest.mark.paddle | |||
| def test_tensor_to_numeric(self): | |||
| @@ -505,10 +475,14 @@ class TestSetDistReproDataloader: | |||
| # 迭代两个 batch | |||
| num_consumed_batches = 2 | |||
| already_seen_idx = set() | |||
| if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||
| sampler_states = replaced_loader.batch_sampler.set_epoch(5) | |||
| else: | |||
| sampler_states = replaced_loader.batch_sampler.sampler.set_epoch(5) | |||
| for idx, batch in enumerate(replaced_loader): | |||
| if idx >= num_consumed_batches: | |||
| break | |||
| already_seen_idx.update(batch) | |||
| already_seen_idx.update(batch.tolist()) | |||
| if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||
| sampler_states = replaced_loader.batch_sampler.state_dict() | |||
| else: | |||
| @@ -529,6 +503,7 @@ class TestSetDistReproDataloader: | |||
| ) | |||
| ) | |||
| new_loader.batch_sampler.load_state_dict(sampler_states) | |||
| new_loader.batch_sampler.set_epoch(5) | |||
| else: | |||
| batch_size = replaced_loader.batch_sampler.batch_size | |||
| sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | |||
| @@ -537,8 +512,9 @@ class TestSetDistReproDataloader: | |||
| batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle) | |||
| new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler) | |||
| new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | |||
| new_loader.batch_sampler.sampler.set_epoch(5) | |||
| for idx, batch in enumerate(new_loader): | |||
| left_idxes.update(batch) | |||
| left_idxes.update(batch.tolist()) | |||
| assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) | |||
| assert len(left_idxes | already_seen_idx) == len(self.dataset) | |||
| @@ -549,7 +525,7 @@ class TestSetDistReproDataloader: | |||
| # | |||
| ############################################################################ | |||
| def generate_random_driver(features, labels, fp16=False, device="cpu"): | |||
| def generate_random_driver(labels, features, fp16=False, device="cpu"): | |||
| """ | |||
| 生成driver | |||
| """ | |||
| @@ -569,9 +545,9 @@ def test_save_and_load_model(only_state_dict): | |||
| """ | |||
| try: | |||
| path = "model" | |||
| dataset = PaddleRandomMaxDataset(40, 10) | |||
| dataset = PaddleNormalXYDataset(20) | |||
| dataloader = DataLoader(dataset, batch_size=4) | |||
| driver1, driver2 = generate_random_driver(10, 10, device="gpu"), generate_random_driver(10, 10, device="gpu") | |||
| driver1, driver2 = generate_random_driver(20, 1, device="gpu"), generate_random_driver(20, 1, device="gpu") | |||
| if only_state_dict: | |||
| driver1.save_model(path, only_state_dict) | |||
| @@ -580,6 +556,7 @@ def test_save_and_load_model(only_state_dict): | |||
| driver2.load_model(path, only_state_dict) | |||
| for batch in dataloader: | |||
| print("?") | |||
| batch = driver1.move_data_to_device(batch) | |||
| res1 = driver1.model.evaluate_step(**batch) | |||
| res2 = driver2.model.evaluate_step(**batch) | |||
| @@ -604,22 +581,23 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||
| try: | |||
| path = "model.ckp" | |||
| dataset = PaddleRandomMaxDataset(40, 10) | |||
| dataset = PaddleNormalXYDataset(40) | |||
| dataloader = DataLoader( | |||
| dataset=dataset, | |||
| batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) | |||
| ) | |||
| driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu") | |||
| driver1, driver2 = generate_random_driver(40, 1, fp16, "gpu"), generate_random_driver(40, 1, False, "gpu") | |||
| num_consumed_batches = 2 | |||
| already_seen_x_set = set() | |||
| already_seen_y_set = set() | |||
| driver1.set_sampler_epoch(dataloader, 3) | |||
| for idx, batch in enumerate(dataloader): | |||
| if idx >= num_consumed_batches: | |||
| break | |||
| already_seen_x_set.update(batch["x"]) | |||
| already_seen_y_set.update(batch["y"]) | |||
| already_seen_x_set.update(batch["x"].reshape((-1, )).tolist()) | |||
| already_seen_y_set.update(batch["y"].reshape((-1, )).tolist()) | |||
| sampler_states = dataloader.batch_sampler.state_dict() | |||
| save_states = {"num_consumed_batches": num_consumed_batches} | |||
| @@ -656,10 +634,11 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||
| assert start_batch == 2 * num_consumed_batches | |||
| left_x_batches = set() | |||
| left_y_batches = set() | |||
| driver2.set_sampler_epoch(replaced_loader, 3) | |||
| for idx, batch in enumerate(replaced_loader): | |||
| left_x_batches.update(batch["x"]) | |||
| left_y_batches.update(batch["y"]) | |||
| left_x_batches.update(batch["x"].reshape((-1, )).tolist()) | |||
| left_y_batches.update(batch["y"].reshape((-1, )).tolist()) | |||
| res1 = driver1.model.evaluate_step(**batch) | |||
| res2 = driver2.model.evaluate_step(**batch) | |||
| assert paddle.equal_all(res1["pred"], res2["pred"]) | |||
| @@ -679,14 +658,14 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||
| @pytest.mark.parametrize("fp16", ([True, False])) | |||
| def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||
| """ | |||
| 测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 | |||
| 测试save和load函数,主要测试 dataloader 被替换了 sampler 的情况 | |||
| """ | |||
| try: | |||
| path = "model.ckp" | |||
| driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu") | |||
| dataset = PaddleRandomMaxDataset(40, 10) | |||
| driver1, driver2 = generate_random_driver(40, 1, fp16, "gpu"), generate_random_driver(40, 1, False, "gpu") | |||
| dataset = PaddleNormalXYDataset(40) | |||
| batch_sampler = BatchSampler(dataset=dataset, batch_size=4) | |||
| batch_sampler.sampler = RandomSampler(dataset, True) | |||
| dataloader = DataLoader( | |||
| @@ -697,11 +676,12 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||
| already_seen_x_set = set() | |||
| already_seen_y_set = set() | |||
| driver1.set_sampler_epoch(dataloader, 3) | |||
| for idx, batch in enumerate(dataloader): | |||
| if idx >= num_consumed_batches: | |||
| break | |||
| already_seen_x_set.update(batch["x"]) | |||
| already_seen_y_set.update(batch["y"]) | |||
| already_seen_x_set.update(batch["x"].reshape((-1, )).tolist()) | |||
| already_seen_y_set.update(batch["y"].reshape((-1, )).tolist()) | |||
| sampler_states = dataloader.batch_sampler.sampler.state_dict() | |||
| save_states = {"num_consumed_batches": num_consumed_batches} | |||
| @@ -743,10 +723,11 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||
| assert start_batch == 2 * num_consumed_batches | |||
| left_x_batches = set() | |||
| left_y_batches = set() | |||
| driver1.set_sampler_epoch(replaced_loader, 3) | |||
| for idx, batch in enumerate(replaced_loader): | |||
| left_x_batches.update(batch["x"]) | |||
| left_y_batches.update(batch["y"]) | |||
| left_x_batches.update(batch["x"].reshape((-1, )).tolist()) | |||
| left_y_batches.update(batch["y"].reshape((-1, )).tolist()) | |||
| res1 = driver1.model.evaluate_step(**batch) | |||
| res2 = driver2.model.evaluate_step(**batch) | |||
| assert paddle.equal_all(res1["pred"], res2["pred"]) | |||
| @@ -10,7 +10,7 @@ from fastNLP.core.samplers import ( | |||
| UnrepeatedSequentialSampler, | |||
| ) | |||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||
| from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | |||
| from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset | |||
| from tests.helpers.utils import magic_argv_env_context | |||
| from fastNLP.envs.distributed import rank_zero_rm | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| @@ -19,8 +19,8 @@ if _NEED_IMPORT_TORCH: | |||
| import torch.distributed as dist | |||
| from torch.utils.data import DataLoader, BatchSampler | |||
| def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="all"): | |||
| torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension) | |||
| def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="all"): | |||
| torch_model = TorchNormalModel_Classification_1(labels, features) | |||
| torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) | |||
| device = [torch.device(i) for i in device] | |||
| driver = TorchDDPDriver( | |||
| @@ -504,10 +504,14 @@ class TestSetDistReproDataloader: | |||
| num_replicas = len(self.device) | |||
| num_consumed_batches = 2 | |||
| already_seen_idx = set() | |||
| if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | |||
| sampler_states = replaced_loader.batch_sampler.set_epoch(4) | |||
| else: | |||
| sampler_states = replaced_loader.batch_sampler.sampler.set_epoch(4) | |||
| for idx, batch in enumerate(replaced_loader): | |||
| if idx >= num_consumed_batches: | |||
| break | |||
| already_seen_idx.update(batch) | |||
| already_seen_idx.update(batch.tolist()) | |||
| dist.barrier() | |||
| if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | |||
| sampler_states = replaced_loader.batch_sampler.state_dict() | |||
| @@ -533,6 +537,7 @@ class TestSetDistReproDataloader: | |||
| pad=True | |||
| ) | |||
| new_loader.batch_sampler.load_state_dict(sampler_states) | |||
| new_loader.batch_sampler.set_epoch(4) | |||
| else: | |||
| batch_size = replaced_loader.batch_sampler.batch_size | |||
| sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas | |||
| @@ -543,8 +548,9 @@ class TestSetDistReproDataloader: | |||
| rank=driver.global_rank | |||
| ) | |||
| new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | |||
| new_loader.batch_sampler.sampler.set_epoch(4) | |||
| for idx, batch in enumerate(new_loader): | |||
| left_idxes.update(batch) | |||
| left_idxes.update(batch.tolist()) | |||
| assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas | |||
| assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas | |||
| @@ -562,7 +568,7 @@ class TestSaveLoad: | |||
| """ | |||
| def setup_method(self): | |||
| self.dataset = TorchArgMaxDataset(10, 20) | |||
| self.dataset = TorchNormalXYDataset(20) | |||
| @magic_argv_env_context | |||
| @pytest.mark.parametrize("only_state_dict", ([True, False])) | |||
| @@ -574,7 +580,7 @@ class TestSaveLoad: | |||
| path = "model" | |||
| dataloader = DataLoader(self.dataset, batch_size=2) | |||
| driver1, driver2 = generate_driver(10, 10), generate_driver(10, 10) | |||
| driver1, driver2 = generate_driver(20, 1), generate_driver(20, 1) | |||
| driver1.save_model(path, only_state_dict) | |||
| @@ -618,8 +624,8 @@ class TestSaveLoad: | |||
| path = "model.ckp" | |||
| num_replicas = len(device) | |||
| driver1, driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ | |||
| generate_driver(10, 10, device=device, fp16=False) | |||
| driver1, driver2 = generate_driver(20, 1, device=device, fp16=fp16), \ | |||
| generate_driver(20, 1, device=device, fp16=False) | |||
| dataloader = dataloader_with_bucketedbatchsampler( | |||
| self.dataset, | |||
| length=[10 for i in range(len(self.dataset))], | |||
| @@ -636,11 +642,12 @@ class TestSaveLoad: | |||
| already_seen_x_set = set() | |||
| already_seen_y_set = set() | |||
| driver1.set_sampler_epoch(dataloader, 4) | |||
| for idx, batch in enumerate(dataloader): | |||
| if idx >= num_consumed_batches: | |||
| break | |||
| already_seen_x_set.update(batch["x"]) | |||
| already_seen_y_set.update(batch["y"]) | |||
| already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||
| already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||
| # 同步 | |||
| dist.barrier() | |||
| @@ -665,7 +672,6 @@ class TestSaveLoad: | |||
| pad=True | |||
| ) | |||
| dist.barrier() | |||
| print("========load=======", driver1.global_rank, driver2.global_rank) | |||
| load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||
| dist.barrier() | |||
| replaced_loader = load_states.pop("dataloader") | |||
| @@ -690,10 +696,11 @@ class TestSaveLoad: | |||
| assert start_batch == 2 * num_consumed_batches | |||
| left_x_batches = set() | |||
| left_y_batches = set() | |||
| driver2.set_sampler_epoch(replaced_loader, 4) | |||
| for idx, batch in enumerate(replaced_loader): | |||
| left_x_batches.update(batch["x"]) | |||
| left_y_batches.update(batch["y"]) | |||
| left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||
| left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||
| res1 = driver1.model( | |||
| batch, | |||
| fastnlp_fn=driver1.model.module.model.evaluate_step, | |||
| @@ -716,7 +723,6 @@ class TestSaveLoad: | |||
| dist.barrier() | |||
| finally: | |||
| rank_zero_rm(path) | |||
| print("=======delete======") | |||
| if dist.is_initialized(): | |||
| dist.destroy_process_group() | |||
| @@ -735,8 +741,8 @@ class TestSaveLoad: | |||
| num_replicas = len(device) | |||
| driver1 = generate_driver(10, 10, device=device, fp16=fp16) | |||
| driver2 = generate_driver(10, 10, device=device, fp16=False) | |||
| driver1 = generate_driver(20, 1, device=device, fp16=fp16) | |||
| driver2 = generate_driver(20, 1, device=device, fp16=False) | |||
| dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) | |||
| dataloader.batch_sampler.sampler.set_distributed( | |||
| @@ -748,11 +754,12 @@ class TestSaveLoad: | |||
| already_seen_x_set = set() | |||
| already_seen_y_set = set() | |||
| driver1.set_sampler_epoch(dataloader, 4) | |||
| for idx, batch in enumerate(dataloader): | |||
| if idx >= num_consumed_batches: | |||
| break | |||
| already_seen_x_set.update(batch["x"]) | |||
| already_seen_y_set.update(batch["y"]) | |||
| already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||
| already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||
| # 同步 | |||
| dist.barrier() | |||
| @@ -797,10 +804,11 @@ class TestSaveLoad: | |||
| assert start_batch == 2 * num_consumed_batches | |||
| left_x_batches = set() | |||
| left_y_batches = set() | |||
| driver2.set_sampler_epoch(replaced_loader, 4) | |||
| for idx, batch in enumerate(replaced_loader): | |||
| left_x_batches.update(batch["x"]) | |||
| left_y_batches.update(batch["y"]) | |||
| left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||
| left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||
| res1 = driver1.model( | |||
| batch, | |||
| fastnlp_fn=driver1.model.module.model.evaluate_step, | |||
| @@ -14,7 +14,7 @@ else: | |||
| @pytest.mark.torch | |||
| def test_incorrect_driver(): | |||
| model = TorchNormalModel_Classification_1(2, 100) | |||
| model = TorchNormalModel_Classification_1(20, 10) | |||
| with pytest.raises(ValueError): | |||
| driver = initialize_torch_driver("paddle", 0, model) | |||
| @@ -33,7 +33,7 @@ def test_get_single_device(driver, device): | |||
| 测试正常情况下初始化TorchSingleDriver的情况 | |||
| """ | |||
| model = TorchNormalModel_Classification_1(2, 100) | |||
| model = TorchNormalModel_Classification_1(20, 10) | |||
| driver = initialize_torch_driver(driver, device, model) | |||
| assert isinstance(driver, TorchSingleDriver) | |||
| @@ -52,7 +52,7 @@ def test_get_ddp(driver, device): | |||
| 测试 ddp 多卡的初始化情况 | |||
| """ | |||
| model = TorchNormalModel_Classification_1(64, 10) | |||
| model = TorchNormalModel_Classification_1(20, 10) | |||
| driver = initialize_torch_driver(driver, device, model) | |||
| assert isinstance(driver, TorchDDPDriver) | |||
| @@ -70,6 +70,6 @@ def test_device_out_of_range(driver, device): | |||
| """ | |||
| 测试传入的device超过范围的情况 | |||
| """ | |||
| model = TorchNormalModel_Classification_1(2, 100) | |||
| model = TorchNormalModel_Classification_1(20, 10) | |||
| with pytest.raises(ValueError): | |||
| driver = initialize_torch_driver(driver, device, model) | |||
| @@ -6,7 +6,7 @@ from pkg_resources import parse_version | |||
| from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver | |||
| from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | |||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||
| from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | |||
| from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset | |||
| from tests.helpers.datasets.paddle_data import PaddleNormalDataset | |||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||
| from fastNLP.envs.distributed import rank_zero_rm | |||
| @@ -15,6 +15,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| from torch.utils.data import DataLoader, BatchSampler | |||
| if _NEED_IMPORT_PADDLE: | |||
| import paddle | |||
| @@ -67,95 +68,67 @@ class TestTorchDriverFunctions: | |||
| model = TorchNormalModel_Classification_1(10, 32) | |||
| self.driver = TorchSingleDriver(model, device="cpu") | |||
| @pytest.mark.torchpaddle | |||
| def test_check_single_optimizer_legality(self): | |||
| @pytest.mark.torch | |||
| def test_check_optimizers_legality(self): | |||
| """ | |||
| 测试传入单个 optimizer 时的表现 | |||
| 测试对合法 optimizers 的检查 | |||
| """ | |||
| # 单个 optimizer | |||
| optimizer = torch.optim.Adam( | |||
| params=self.driver.model.parameters(), | |||
| lr=0.01 | |||
| ) | |||
| self.driver.set_optimizers(optimizer) | |||
| optimizer = paddle.optimizer.Adam( | |||
| parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), | |||
| learning_rate=0.01, | |||
| ) | |||
| # 传入 torch 的 optimize r时,应该报错 ValueError | |||
| with pytest.raises(ValueError): | |||
| self.driver.set_optimizers(optimizer) | |||
| @pytest.mark.torchpaddle | |||
| def test_check_optimizers_legality(self): | |||
| """ | |||
| 测试传入 optimizer list 的表现 | |||
| """ | |||
| # 列表 | |||
| optimizers = [ | |||
| torch.optim.Adam( | |||
| params=self.driver.model.parameters(), | |||
| lr=0.01 | |||
| ) for i in range(10) | |||
| ] | |||
| self.driver.set_optimizers(optimizers) | |||
| optimizers += [ | |||
| @pytest.mark.torchpaddle | |||
| def test_invalid_optimizers(self): | |||
| """ | |||
| 测试传入非法的 optimizers | |||
| """ | |||
| optimizer = paddle.optimizer.Adam( | |||
| parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), | |||
| learning_rate=0.01, | |||
| ) | |||
| with pytest.raises(TypeError): | |||
| self.driver.set_optimizers(optimizer) | |||
| optimizers = [ | |||
| paddle.optimizer.Adam( | |||
| parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), | |||
| learning_rate=0.01, | |||
| ) | |||
| ] | |||
| with pytest.raises(ValueError): | |||
| with pytest.raises(TypeError): | |||
| self.driver.set_optimizers(optimizers) | |||
| @pytest.mark.torchpaddle | |||
| def test_check_dataloader_legality_in_train(self): | |||
| @pytest.mark.torch | |||
| def test_check_dataloader_legality(self): | |||
| """ | |||
| 测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 | |||
| 测试 check_dataloader_legality 函数的表现 | |||
| """ | |||
| dataloader = DataLoader(TorchNormalDataset()) | |||
| TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
| # 创建 paddle 的 dataloader | |||
| dataloader = paddle.io.DataLoader( | |||
| PaddleNormalDataset(), | |||
| batch_size=32, shuffle=True | |||
| ) | |||
| with pytest.raises(ValueError): | |||
| TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
| self.driver.check_dataloader_legality(dataloader) | |||
| @pytest.mark.torchpaddle | |||
| def test_check_dataloader_legality_in_test(self): | |||
| def test_check_dataloader_legality_invalid(self): | |||
| """ | |||
| 测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现 | |||
| 测试 check_dataloader_legality 函数传入其他类型的表现 | |||
| """ | |||
| # 此时传入的应该是dict | |||
| dataloader = { | |||
| "train": DataLoader(TorchNormalDataset()), | |||
| "test": DataLoader(TorchNormalDataset()) | |||
| } | |||
| TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
| # 传入的不是 dict,应该报错 | |||
| dataloader = DataLoader(TorchNormalDataset()) | |||
| with pytest.raises(ValueError): | |||
| TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
| # 创建 paddle 的 dataloader | |||
| train_loader = paddle.io.DataLoader( | |||
| PaddleNormalDataset(), | |||
| batch_size=32, shuffle=True | |||
| ) | |||
| test_loader = paddle.io.DataLoader( | |||
| dataloader = paddle.io.DataLoader( | |||
| PaddleNormalDataset(), | |||
| batch_size=32, shuffle=True | |||
| ) | |||
| dataloader = {"train": train_loader, "test": test_loader} | |||
| with pytest.raises(ValueError): | |||
| TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") | |||
| with pytest.raises(TypeError): | |||
| self.driver.check_dataloader_legality(dataloader) | |||
| @pytest.mark.torch | |||
| def test_tensor_to_numeric(self): | |||
| @@ -515,10 +488,14 @@ class TestSetDistReproDataloader: | |||
| # 迭代两个 batch | |||
| num_consumed_batches = 2 | |||
| already_seen_idx = set() | |||
| if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||
| replaced_loader.batch_sampler.set_epoch(3) | |||
| else: | |||
| replaced_loader.batch_sampler.sampler.set_epoch(3) | |||
| for idx, batch in enumerate(replaced_loader): | |||
| if idx >= num_consumed_batches: | |||
| break | |||
| already_seen_idx.update(batch) | |||
| already_seen_idx.update(batch.tolist()) | |||
| if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||
| sampler_states = replaced_loader.batch_sampler.state_dict() | |||
| else: | |||
| @@ -532,14 +509,16 @@ class TestSetDistReproDataloader: | |||
| # 重新改造 dataloader | |||
| new_loader = dataloader_with_randombatchsampler(replaced_loader.dataset, batch_size, shuffle, False) | |||
| new_loader.batch_sampler.load_state_dict(sampler_states) | |||
| new_loader.batch_sampler.set_epoch(3) | |||
| else: | |||
| batch_size = replaced_loader.batch_sampler.batch_size | |||
| sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | |||
| # 重新构造 dataloader | |||
| new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, False) | |||
| new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | |||
| new_loader.batch_sampler.sampler.set_epoch(3) | |||
| for idx, batch in enumerate(new_loader): | |||
| left_idxes.update(batch) | |||
| left_idxes.update(batch.tolist()) | |||
| assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) | |||
| assert len(left_idxes | already_seen_idx) == len(self.dataset) | |||
| @@ -550,7 +529,7 @@ class TestSetDistReproDataloader: | |||
| # | |||
| ############################################################################ | |||
| def generate_random_driver(features, labels, fp16=False, device="cpu"): | |||
| def generate_random_driver(labels, features, fp16=False, device="cpu"): | |||
| """ | |||
| 生成driver | |||
| """ | |||
| @@ -570,9 +549,9 @@ def test_save_and_load_model(only_state_dict): | |||
| """ | |||
| try: | |||
| path = "model" | |||
| dataset = TorchArgMaxDataset(10, 40) | |||
| dataset = TorchNormalXYDataset(20) | |||
| dataloader = DataLoader(dataset, batch_size=4) | |||
| driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) | |||
| driver1, driver2 = generate_random_driver(20, 1), generate_random_driver(20, 1) | |||
| driver1.save_model(path, only_state_dict) | |||
| driver2.load_model(path, only_state_dict) | |||
| @@ -596,19 +575,20 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||
| try: | |||
| path = "model.ckp" | |||
| dataset = TorchArgMaxDataset(10, 40) | |||
| dataset = TorchNormalXYDataset(20) | |||
| dataloader = dataloader_with_randombatchsampler(dataset, 4, True, False) | |||
| driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda") | |||
| driver1, driver2 = generate_random_driver(20, 1, fp16, "cuda"), generate_random_driver(20, 1, False, "cuda") | |||
| num_consumed_batches = 2 | |||
| already_seen_x_set = set() | |||
| already_seen_y_set = set() | |||
| driver1.set_sampler_epoch(dataloader, 3) | |||
| for idx, batch in enumerate(dataloader): | |||
| if idx >= num_consumed_batches: | |||
| break | |||
| already_seen_x_set.update(batch["x"]) | |||
| already_seen_y_set.update(batch["y"]) | |||
| already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||
| already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||
| sampler_states = dataloader.batch_sampler.state_dict() | |||
| save_states = {"num_consumed_batches": num_consumed_batches} | |||
| @@ -639,11 +619,12 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||
| assert start_batch == 2 * num_consumed_batches | |||
| left_x_batches = set() | |||
| left_y_batches = set() | |||
| driver1.set_sampler_epoch(replaced_loader, 3) | |||
| for idx, batch in enumerate(replaced_loader): | |||
| batch = driver2.move_data_to_device(batch) | |||
| left_x_batches.update(batch["x"]) | |||
| left_y_batches.update(batch["y"]) | |||
| left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||
| left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||
| res1 = driver1.model.evaluate_step(**batch) | |||
| res2 = driver2.model.evaluate_step(**batch) | |||
| assert torch.equal(res1["preds"], res2["preds"]) | |||
| @@ -660,24 +641,25 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||
| @pytest.mark.parametrize("fp16", ([True, False])) | |||
| def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||
| """ | |||
| 测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 | |||
| 测试save和load函数,主要测试 dataloader 被替换了 sampler 的情况 | |||
| """ | |||
| try: | |||
| path = "model.ckp" | |||
| driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda") | |||
| dataset = TorchArgMaxDataset(10, 40) | |||
| driver1, driver2 = generate_random_driver(40, 1, fp16, "cuda"), generate_random_driver(40, 1, False, "cuda") | |||
| dataset = TorchNormalXYDataset(40) | |||
| dataloader = dataloader_with_randomsampler(dataset, 4, True, False) | |||
| num_consumed_batches = 2 | |||
| already_seen_x_set = set() | |||
| already_seen_y_set = set() | |||
| driver1.set_sampler_epoch(dataloader, 3) | |||
| for idx, batch in enumerate(dataloader): | |||
| if idx >= num_consumed_batches: | |||
| break | |||
| already_seen_x_set.update(batch["x"]) | |||
| already_seen_y_set.update(batch["y"]) | |||
| already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||
| already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||
| sampler_states = dataloader.batch_sampler.sampler.state_dict() | |||
| save_states = {"num_consumed_batches": num_consumed_batches} | |||
| @@ -711,11 +693,13 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||
| assert start_batch == 2 * num_consumed_batches | |||
| left_x_batches = set() | |||
| left_y_batches = set() | |||
| # set epoch | |||
| driver2.set_sampler_epoch(replaced_loader, 3) | |||
| for idx, batch in enumerate(replaced_loader): | |||
| batch = driver2.move_data_to_device(batch) | |||
| left_x_batches.update(batch["x"]) | |||
| left_y_batches.update(batch["y"]) | |||
| left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||
| left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||
| res1 = driver1.model.evaluate_step(**batch) | |||
| res2 = driver2.model.evaluate_step(**batch) | |||
| assert torch.equal(res1["preds"], res2["preds"]) | |||
| @@ -0,0 +1,46 @@ | |||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
| if _NEED_IMPORT_JITTOR: | |||
| import jittor as jt | |||
| from jittor.dataset import Dataset | |||
| else: | |||
| from fastNLP.core.utils.dummy_class import DummyClass as Dataset | |||
| class JittorNormalDataset(Dataset): | |||
| def __init__(self, num_of_data=100, **kwargs): | |||
| super(JittorNormalDataset, self).__init__(**kwargs) | |||
| self._data = list(range(num_of_data)) | |||
| self.set_attrs(total_len=num_of_data) | |||
| def __getitem__(self, item): | |||
| return self._data[item] | |||
| class JittorNormalXYDataset(Dataset): | |||
| """ | |||
| 可以被输入到分类模型中的普通数据集 | |||
| """ | |||
| def __init__(self, num_of_data=1000, **kwargs): | |||
| super(JittorNormalXYDataset, self).__init__(**kwargs) | |||
| self.num_of_data = num_of_data | |||
| self._data = list(range(num_of_data)) | |||
| self.set_attrs(total_len=num_of_data) | |||
| def __getitem__(self, item): | |||
| return { | |||
| "x": jt.Var([self._data[item]]), | |||
| "y": jt.Var([self._data[item]]) | |||
| } | |||
| class JittorArgMaxDataset(Dataset): | |||
| def __init__(self, num_samples, num_features, **kwargs): | |||
| super(JittorArgMaxDataset, self).__init__(**kwargs) | |||
| self.x = jt.randn(num_samples, num_features) | |||
| self.y = self.x.argmax(dim=-1) | |||
| self.set_attrs(total_len=num_samples) | |||
| def __getitem__(self, item): | |||
| return {"x": self.x[item], "y": self.y[item]} | |||
| if __name__ == "__main__": | |||
| dataset = JittorNormalDataset() | |||
| print(len(dataset)) | |||
| @@ -19,8 +19,24 @@ class PaddleNormalDataset(Dataset): | |||
| def __getitem__(self, item): | |||
| return self._data[item] | |||
| class PaddleNormalXYDataset(Dataset): | |||
| """ | |||
| 可以被输入到分类模型中的普通数据集 | |||
| """ | |||
| def __init__(self, num_of_data=1000): | |||
| self.num_of_data = num_of_data | |||
| self._data = list(range(num_of_data)) | |||
| def __len__(self): | |||
| return self.num_of_data | |||
| def __getitem__(self, item): | |||
| return { | |||
| "x": paddle.to_tensor([self._data[item]], dtype="float32"), | |||
| "y": paddle.to_tensor([self._data[item]], dtype="float32") | |||
| } | |||
| class PaddleRandomMaxDataset(Dataset): | |||
| class PaddleArgMaxDataset(Dataset): | |||
| def __init__(self, num_samples, num_features): | |||
| self.x = paddle.randn((num_samples, num_features)) | |||
| self.y = self.x.argmax(axis=-1) | |||
| @@ -1,4 +1,6 @@ | |||
| from functools import reduce | |||
| from numpy import dtype | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_TORCH: | |||
| @@ -19,6 +21,23 @@ class TorchNormalDataset(Dataset): | |||
| def __getitem__(self, item): | |||
| return self._data[item] | |||
| class TorchNormalXYDataset(Dataset): | |||
| """ | |||
| 可以被输入到分类模型中的普通数据集 | |||
| """ | |||
| def __init__(self, num_of_data=1000): | |||
| self.num_of_data = num_of_data | |||
| self._data = list(range(num_of_data)) | |||
| def __len__(self): | |||
| return self.num_of_data | |||
| def __getitem__(self, item): | |||
| return { | |||
| "x": torch.tensor([self._data[item]], dtype=torch.float), | |||
| "y": torch.tensor([self._data[item]], dtype=torch.float) | |||
| } | |||
| # 该类专门用于为 tests.helpers.models.torch_model.py/ TorchNormalModel_Classification_1 创建数据; | |||
| class TorchNormalDataset_Classification(Dataset): | |||
| @@ -0,0 +1,57 @@ | |||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
| if _NEED_IMPORT_JITTOR: | |||
| from jittor import Module, nn | |||
| else: | |||
| from fastNLP.core.utils.dummy_class import DummyClass as Module | |||
| class JittorNormalModel_Classification_1(Module): | |||
| """ | |||
| 基础的 jittor 分类模型 | |||
| """ | |||
| def __init__(self, num_labels, feature_dimension): | |||
| super(JittorNormalModel_Classification_1, self).__init__() | |||
| self.num_labels = num_labels | |||
| self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64) | |||
| self.ac1 = nn.ReLU() | |||
| self.linear2 = nn.Linear(in_features=64, out_features=32) | |||
| self.ac2 = nn.ReLU() | |||
| self.output = nn.Linear(in_features=32, out_features=num_labels) | |||
| self.loss_fn = nn.CrossEntropyLoss() | |||
| def execute(self, x): | |||
| x = self.ac1(self.linear1(x)) | |||
| x = self.ac2(self.linear2(x)) | |||
| x = self.output(x) | |||
| return x | |||
| def train_step(self, x, y): | |||
| x = self(x) | |||
| return {"loss": self.loss_fn(x, y)} | |||
| def evaluate_step(self, x, y): | |||
| x = self(x) | |||
| return {"pred": x, "target": y.reshape((-1,))} | |||
| class JittorNormalModel_Classification_2(Module): | |||
| """ | |||
| 基础的 jittor 分类模型,只实现 execute 函数测试用户自己初始化了分布式的场景 | |||
| """ | |||
| def __init__(self, num_labels, feature_dimension): | |||
| super(JittorNormalModel_Classification_2, self).__init__() | |||
| self.num_labels = num_labels | |||
| self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64) | |||
| self.ac1 = nn.ReLU() | |||
| self.linear2 = nn.Linear(in_features=64, out_features=32) | |||
| self.ac2 = nn.ReLU() | |||
| self.output = nn.Linear(in_features=32, out_features=num_labels) | |||
| self.loss_fn = nn.CrossEntropyLoss() | |||
| def execute(self, x, y): | |||
| x = self.ac1(self.linear1(x)) | |||
| x = self.ac2(self.linear2(x)) | |||
| x = self.output(x) | |||
| return {"loss": self.loss_fn(x, y), "pred": x, "target": y.reshape((-1,))} | |||
| @@ -8,7 +8,7 @@ else: | |||
| class PaddleNormalModel_Classification_1(Layer): | |||
| """ | |||
| 基础的paddle分类模型 | |||
| 基础的 paddle 分类模型 | |||
| """ | |||
| def __init__(self, num_labels, feature_dimension): | |||
| super(PaddleNormalModel_Classification_1, self).__init__() | |||
| @@ -39,7 +39,7 @@ class PaddleNormalModel_Classification_1(Layer): | |||
| class PaddleNormalModel_Classification_2(Layer): | |||
| """ | |||
| 基础的paddle分类模型,只实现 forward 函数测试用户自己初始化了分布式的场景 | |||
| 基础的 paddle 分类模型,只实现 forward 函数测试用户自己初始化了分布式的场景 | |||
| """ | |||
| def __init__(self, num_labels, feature_dimension): | |||
| super(PaddleNormalModel_Classification_2, self).__init__() | |||
| @@ -56,5 +56,4 @@ class PaddleNormalModel_Classification_2(Layer): | |||
| x = self.ac1(self.linear1(x)) | |||
| x = self.ac2(self.linear2(x)) | |||
| x = self.output(x) | |||
| loss = self.loss_fn(x, y) | |||
| return {"loss": self.loss_fn(x, y), "pred": x, "target": y.reshape((-1,))} | |||
| @@ -33,7 +33,11 @@ class TestPaddle2Torch: | |||
| """ | |||
| assert isinstance(tensor, torch.Tensor) | |||
| assert tensor.device == torch.device(device) | |||
| if device == "cpu": | |||
| assert not tensor.is_cuda | |||
| else: | |||
| assert tensor.is_cuda | |||
| assert tensor.device.index == torch.device(device).index | |||
| assert tensor.requires_grad == requires_grad | |||
| def test_gradient(self): | |||
| @@ -261,7 +265,8 @@ class TestJittor2Torch: | |||
| if device == "cpu": | |||
| assert not tensor.is_cuda | |||
| else: | |||
| assert tensor.device == torch.device(device) | |||
| assert tensor.is_cuda | |||
| assert tensor.device.index == torch.device(device).index | |||
| assert tensor.requires_grad == requires_grad | |||
| def test_var_transfer(self): | |||
| @@ -271,7 +276,10 @@ class TestJittor2Torch: | |||
| jittor_var = jittor.rand((3, 4, 5)) | |||
| res = jittor2torch(jittor_var) | |||
| self.check_torch_tensor(res, "cpu", True) | |||
| if jittor.flags.use_cuda: | |||
| self.check_torch_tensor(res, "cuda:0", True) | |||
| else: | |||
| self.check_torch_tensor(res, "cpu", True) | |||
| res = jittor2torch(jittor_var, device="cuda:2", no_gradient=None) | |||
| self.check_torch_tensor(res, "cuda:2", True) | |||
| @@ -291,7 +299,10 @@ class TestJittor2Torch: | |||
| res = jittor2torch(jittor_list) | |||
| assert isinstance(res, list) | |||
| for t in res: | |||
| self.check_torch_tensor(t, "cpu", True) | |||
| if jittor.flags.use_cuda: | |||
| self.check_torch_tensor(t, "cuda:0", True) | |||
| else: | |||
| self.check_torch_tensor(t, "cpu", True) | |||
| res = jittor2torch(jittor_list, device="cuda:1", no_gradient=False) | |||
| assert isinstance(res, list) | |||
| @@ -327,17 +338,29 @@ class TestJittor2Torch: | |||
| } | |||
| res = jittor2torch(jittor_dict) | |||
| assert isinstance(res, dict) | |||
| self.check_torch_tensor(res["tensor"], "cpu", True) | |||
| if jittor.flags.use_cuda: | |||
| self.check_torch_tensor(res["tensor"], "cuda:0", True) | |||
| else: | |||
| self.check_torch_tensor(res["tensor"], "cpu", True) | |||
| assert isinstance(res["list"], list) | |||
| for t in res["list"]: | |||
| self.check_torch_tensor(t, "cpu", True) | |||
| if jittor.flags.use_cuda: | |||
| self.check_torch_tensor(t, "cuda:0", True) | |||
| else: | |||
| self.check_torch_tensor(t, "cpu", True) | |||
| assert isinstance(res["int"], int) | |||
| assert isinstance(res["string"], str) | |||
| assert isinstance(res["dict"], dict) | |||
| assert isinstance(res["dict"]["list"], list) | |||
| for t in res["dict"]["list"]: | |||
| self.check_torch_tensor(t, "cpu", True) | |||
| self.check_torch_tensor(res["dict"]["tensor"], "cpu", True) | |||
| if jittor.flags.use_cuda: | |||
| self.check_torch_tensor(t, "cuda:0", True) | |||
| else: | |||
| self.check_torch_tensor(t, "cpu", True) | |||
| if jittor.flags.use_cuda: | |||
| self.check_torch_tensor(res["dict"]["tensor"], "cuda:0", True) | |||
| else: | |||
| self.check_torch_tensor(res["dict"]["tensor"], "cpu", True) | |||
| ############################################################################ | |||