| @@ -1,5 +1,5 @@ | |||||
| import os | import os | ||||
| from typing import Optional, Union | |||||
| from typing import Optional, Union, Callable, Dict, Tuple | |||||
| from .jittor_driver import JittorDriver | from .jittor_driver import JittorDriver | ||||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | ||||
| @@ -61,14 +61,11 @@ class JittorMPIDriver(JittorDriver): | |||||
| return self._data_device | return self._data_device | ||||
| return self.model_device | return self.model_device | ||||
| def train_step(self, batch): | |||||
| return self._train_step(batch) | |||||
| def validate_step(self, batch): | |||||
| return self._validate_step(batch) | |||||
| def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | |||||
| pass | |||||
| def test_step(self, batch): | |||||
| return self._test_step(batch) | |||||
| def get_model_call_fn(self, fn: str) -> Tuple: | |||||
| pass | |||||
| def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]], | def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]], | ||||
| reproducible: bool = False, sampler_or_batch_sampler=None): | reproducible: bool = False, sampler_or_batch_sampler=None): | ||||
| @@ -1,9 +1,11 @@ | |||||
| from typing import Dict, Union | |||||
| from typing import Dict, Union, Tuple, Callable, Optional | |||||
| from .jittor_driver import JittorDriver | from .jittor_driver import JittorDriver | ||||
| from fastNLP.core.utils import auto_param_call | from fastNLP.core.utils import auto_param_call | ||||
| from fastNLP.core.utils.utils import _get_fun_msg | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | ||||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | ||||
| from fastNLP.core.log import logger | |||||
| if _NEED_IMPORT_JITTOR: | if _NEED_IMPORT_JITTOR: | ||||
| import jittor | import jittor | ||||
| @@ -27,42 +29,6 @@ class JittorSingleDriver(JittorDriver): | |||||
| self.global_rank = 0 | self.global_rank = 0 | ||||
| self.world_size = 1 | self.world_size = 1 | ||||
| if hasattr(self.model, "train_step"): | |||||
| self._train_step = self.model.train_step | |||||
| self._train_signature_fn = None | |||||
| else: | |||||
| self._train_step = self.model | |||||
| model = self.unwrap_model() | |||||
| self._train_signature_fn = model.execute | |||||
| if hasattr(self.model, "evaluate_step"): | |||||
| self._validate_step = self.model.evaluate_step | |||||
| self._validate_signature_fn = None | |||||
| elif hasattr(self.model, "test_step"): | |||||
| self._validate_step = self.model.test_step | |||||
| self._validate_signature_fn = self.model.test_step | |||||
| else: | |||||
| self._validate_step = self.model | |||||
| model = self.unwrap_model() | |||||
| self._validate_signature_fn = model.execute | |||||
| if hasattr(self.model, "test_step"): | |||||
| self._test_step = self.model.test_step | |||||
| self._test_signature_fn = None | |||||
| elif hasattr(self.model, "evaluate_step"): | |||||
| self._test_step = self.model.evaluate_step | |||||
| self._test_signature_fn = self.model.evaluate_step | |||||
| else: | |||||
| self._test_step = self.model | |||||
| model = self.unwrap_model() | |||||
| self._test_signature_fn = model.execute | |||||
| def train_step(self, batch) -> Dict: | |||||
| if isinstance(batch, Dict): | |||||
| return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | |||||
| else: | |||||
| return self._train_step(batch) | |||||
| def step(self): | def step(self): | ||||
| """ | """ | ||||
| jittor optimizers 的step函数可以传入参数loss | jittor optimizers 的step函数可以传入参数loss | ||||
| @@ -80,18 +46,24 @@ class JittorSingleDriver(JittorDriver): | |||||
| for optimizer in self.optimizers: | for optimizer in self.optimizers: | ||||
| optimizer.zero_grad() | optimizer.zero_grad() | ||||
| def validate_step(self, batch): | |||||
| if isinstance(batch, Dict): | |||||
| return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | |||||
| def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | |||||
| if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
| return auto_param_call(fn, batch, signature_fn=signature_fn) | |||||
| else: | else: | ||||
| return self._validate_step(batch) | |||||
| def test_step(self, batch): | |||||
| if isinstance(batch, Dict): | |||||
| return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | |||||
| return fn(batch) | |||||
| def get_model_call_fn(self, fn: str) -> Tuple: | |||||
| if hasattr(self.model, fn): | |||||
| fn = getattr(self.model, fn) | |||||
| if not callable(fn): | |||||
| raise RuntimeError(f"The `{fn}` attribute is not `Callable`.") | |||||
| logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...') | |||||
| return fn, None | |||||
| elif fn in {"train_step", "evaluate_step"}: | |||||
| logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...') | |||||
| return self.model, self.model.forward | |||||
| else: | else: | ||||
| return self._test_step(batch) | |||||
| raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") | |||||
| def unwrap_model(self): | def unwrap_model(self): | ||||
| return self.model | return self.model | ||||
| @@ -0,0 +1,376 @@ | |||||
| import io | |||||
| import pickle | |||||
| _pickler = pickle.Pickler | |||||
| _unpickler = pickle.Unpickler | |||||
| from typing import Any, List | |||||
| from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 | |||||
| from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
| if _NEED_IMPORT_TORCH: | |||||
| import torch | |||||
| from torch import distributed as dist | |||||
| if _TORCH_GREATER_EQUAL_1_8: | |||||
| try: | |||||
| from torch._C._distributed_c10d import ProcessGroupGloo | |||||
| from torch._C._distributed_c10d import _ProcessGroupWrapper | |||||
| except ImportError: | |||||
| pass | |||||
| from fastNLP.core.utils import apply_to_collection | |||||
| def _validate_output_list_for_rank(my_rank, dst, gather_list): | |||||
| if dst == my_rank: | |||||
| if not gather_list: | |||||
| raise ValueError( | |||||
| "Argument ``gather_list`` must be specified on destination rank." | |||||
| ) | |||||
| elif gather_list: | |||||
| raise ValueError( | |||||
| "Argument ``gather_list`` must NOT be specified " | |||||
| "on non-destination ranks." | |||||
| ) | |||||
| def fastnlp_paddle_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP): | |||||
| """ | |||||
| 从其它 rank gather 东西到 dst rank 。 | |||||
| Gathers picklable objects from the whole group in a single process. | |||||
| Similar to :func:`gather`, but Python objects can be passed in. Note that the | |||||
| object must be picklable in order to be gathered. | |||||
| Args: | |||||
| obj (Any): Input object. Must be picklable. | |||||
| object_gather_list (list[Any]): Output list. On the ``dst`` rank, it | |||||
| should be correctly sized as the size of the group for this | |||||
| collective and will contain the output. Must be ``None`` on non-dst | |||||
| ranks. (default is ``None``) | |||||
| dst (int, optional): Destination rank. (default is 0) | |||||
| group: (ProcessGroup, optional): The process group to work on. If None, | |||||
| the default process group will be used. Default is ``None``. | |||||
| Returns: | |||||
| None. On the ``dst`` rank, ``object_gather_list`` will contain the | |||||
| output of the collective. | |||||
| .. note:: Note that this API differs slightly from the gather collective | |||||
| since it does not provide an async_op handle and thus will be a blocking | |||||
| call. | |||||
| .. note:: Note that this API is not supported when using the NCCL backend. | |||||
| .. warning:: | |||||
| :func:`gather_object` uses ``pickle`` module implicitly, which is | |||||
| known to be insecure. It is possible to construct malicious pickle data | |||||
| which will execute arbitrary code during unpickling. Only call this | |||||
| function with data you trust. | |||||
| Example:: | |||||
| >>> # Note: Process group initialization omitted on each rank. | |||||
| >>> import torch.distributed as dist | |||||
| >>> # Assumes world_size of 3. | |||||
| >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object | |||||
| >>> output = [None for _ in gather_objects] | |||||
| >>> dist.gather_object( | |||||
| gather_objects[dist.get_rank()], | |||||
| output if dist.get_rank() == 0 else None, | |||||
| dst=0 | |||||
| ) | |||||
| >>> # On rank 0 | |||||
| >>> output | |||||
| ['foo', 12, {1: 2}] | |||||
| """ | |||||
| if group is None: | |||||
| group = DEFAULT_TORCH_GROUP | |||||
| if dist.distributed_c10d._rank_not_in_group(group): | |||||
| return | |||||
| # Ensure object_gather_list is specified appopriately. | |||||
| my_rank = dist.get_rank() | |||||
| _validate_output_list_for_rank(my_rank, dst, object_gather_list) | |||||
| # 防止 unpickle 的时候出现在了发送的 gpu 上。 | |||||
| obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | |||||
| input_tensor, local_size = _object_to_tensor(obj) | |||||
| group_backend = dist.get_backend(group) | |||||
| current_device = torch.device("cpu") | |||||
| is_nccl_backend = group_backend == dist.Backend.NCCL | |||||
| if is_nccl_backend: | |||||
| current_device = torch.device('cuda', torch.cuda.current_device()) | |||||
| input_tensor = input_tensor.to(current_device) | |||||
| local_size = local_size.to(current_device) | |||||
| # Gather all local sizes. This is so that we can find the max size, and index | |||||
| # until the correct size when deserializing the tensors. | |||||
| group_size = dist.get_world_size(group=group) | |||||
| object_sizes_tensor = torch.zeros(group_size, dtype=torch.long, device=current_device) | |||||
| object_size_list = [ | |||||
| object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) | |||||
| ] | |||||
| # Allgather tensor sizes. An all-gather is needed here despite this being a | |||||
| # gather, since each rank needs to broadcast a tensor of the same (maximal) | |||||
| # size. | |||||
| dist.all_gather(object_size_list, local_size, group=group) | |||||
| max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] | |||||
| # Resize tensor to max size across all ranks. | |||||
| input_tensor.resize_(max_object_size) | |||||
| # Avoid populating output tensors if the result won't be gathered on this rank. | |||||
| if my_rank == dst: | |||||
| coalesced_output_tensor = torch.empty( | |||||
| max_object_size * group_size, dtype=torch.uint8, device=current_device | |||||
| ) | |||||
| # Output tensors are nonoverlapping views of coalesced_output_tensor | |||||
| output_tensors = [ | |||||
| coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] | |||||
| for i in range(group_size) | |||||
| ] | |||||
| # All ranks call gather with equal-sized tensors. | |||||
| dist.gather( | |||||
| input_tensor, | |||||
| gather_list=output_tensors if my_rank == dst else None, | |||||
| dst=dst, | |||||
| group=group, | |||||
| ) | |||||
| if my_rank != dst: | |||||
| return | |||||
| for i, tensor in enumerate(output_tensors): | |||||
| tensor = tensor.type(torch.uint8) # type: ignore[call-overload] | |||||
| tensor_size = object_size_list[i] | |||||
| object_gather_list[i] = _tensor_to_object(tensor, tensor_size) | |||||
| def _object_to_tensor(obj, device=None): | |||||
| f = io.BytesIO() | |||||
| _pickler(f).dump(obj) | |||||
| byte_storage = torch.ByteStorage.from_buffer(f.getvalue()) # type: ignore[attr-defined] | |||||
| # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype. | |||||
| # Otherwise, it will casue 100X slowdown. | |||||
| # See: https://github.com/pytorch/pytorch/issues/65696 | |||||
| byte_tensor = torch.ByteTensor(byte_storage) | |||||
| local_size = torch.LongTensor([byte_tensor.numel()]) | |||||
| if device is not None: | |||||
| byte_tensor = byte_tensor.to(device) | |||||
| local_size = local_size.to(device) | |||||
| return byte_tensor, local_size | |||||
| def _tensor_to_object(tensor, tensor_size): | |||||
| buf = tensor.detach().cpu().numpy().tobytes()[:tensor_size] | |||||
| return _unpickler(io.BytesIO(buf)).load() | |||||
| def send_recv_object(obj, src, cur_rank, device, group=None, tag=0): | |||||
| # src rank send to all other ranks | |||||
| size = torch.LongTensor([0]).to(device) | |||||
| if cur_rank == src: | |||||
| world_size = dist.get_world_size(group=group) | |||||
| tensor, size = _object_to_tensor(obj) | |||||
| tensor = tensor.to(device) | |||||
| size = size.to(device) | |||||
| # 首先同步 obj 的 size 的信息; | |||||
| dist.broadcast(size, src, group=group) | |||||
| for subrank in range(world_size): | |||||
| if subrank != src: | |||||
| dist.send(tensor=tensor, dst=subrank, group=group, tag=tag) | |||||
| else: | |||||
| dist.broadcast(size, src, group=group) | |||||
| tensor = torch.ByteTensor([0] * size).to(device) | |||||
| dist.recv(tensor=tensor, src=src, group=group, tag=tag) | |||||
| return _tensor_to_object(tensor.cpu(), size) | |||||
| def fastnlp_paddle_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) ->List: | |||||
| """ | |||||
| 实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 | |||||
| example: | |||||
| obj = { | |||||
| 'a': [1, 1], | |||||
| 'b': [[1, 2], [1, 2]], | |||||
| 'c': { | |||||
| 'd': [1, 2] | |||||
| } | |||||
| } | |||||
| -> | |||||
| [ | |||||
| {'a': 1, 'b':[1, 2], 'c':{'d': 1}}, | |||||
| {'a': 1, 'b':[1, 2], 'c':{'d': 2}} | |||||
| ] | |||||
| :param obj: 任意结构的数据,如果为 tensor ,需要保证每个显卡上的 tensor 的形状是一样的。如果传入的是非 tensor 对象都将直接进行 | |||||
| 序列化之后进行传输。 | |||||
| :param device: 当前该参数无意义。 | |||||
| :param group: | |||||
| :return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 | |||||
| """ | |||||
| if group is None: | |||||
| group = DEFAULT_TORCH_GROUP | |||||
| if isinstance(obj, torch.Tensor): | |||||
| objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))] | |||||
| dist.all_gather(objs, obj, group=group) | |||||
| else: | |||||
| objs = [None for _ in range(dist.get_world_size(group))] | |||||
| # 防止 unpickle 的时候弄到发送的 gpu 上了 | |||||
| obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | |||||
| if _TORCH_GREATER_EQUAL_1_8: | |||||
| dist.all_gather_object(objs, obj, group=group) | |||||
| else: | |||||
| objs = all_gather_object(objs, obj, group=group) | |||||
| return objs | |||||
| def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GROUP): | |||||
| """ | |||||
| 将 src 上的 obj 对象广播到其它 rank 上。 | |||||
| :param obj: | |||||
| :param src: | |||||
| :param device: | |||||
| :param group: | |||||
| :return: | |||||
| """ | |||||
| if group is None: | |||||
| group = DEFAULT_TORCH_GROUP | |||||
| cur_rank = dist.get_rank(group) | |||||
| if cur_rank == src: | |||||
| # 如果有 tensor 全部移动到 cpu 上,方便 pickle , 不然 unpickle 的时候可能会 pickle 到发送过来的卡那里 | |||||
| obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | |||||
| if _TORCH_GREATER_EQUAL_1_8: | |||||
| if cur_rank!=src: | |||||
| get_obj = [None] | |||||
| dist.broadcast_object_list(get_obj, src=src, group=group) | |||||
| return get_obj[0] | |||||
| else: | |||||
| dist.broadcast_object_list([obj], src=src, group=group) | |||||
| return obj | |||||
| if device is None: | |||||
| device = torch.cuda.current_device() | |||||
| if cur_rank == src: | |||||
| tensor, size = _object_to_tensor(obj, device=device) | |||||
| else: | |||||
| size = torch.LongTensor([0]).to(device) | |||||
| dist.broadcast(size, src=src, group=group) | |||||
| if cur_rank != src: | |||||
| tensor = torch.empty( | |||||
| size.int().item(), # type: ignore[arg-type] | |||||
| dtype=torch.uint8, | |||||
| device=device | |||||
| ) | |||||
| dist.broadcast(tensor, src=src, group=group) | |||||
| return _tensor_to_object(tensor, tensor_size=size.item()) | |||||
| def _check_for_nccl_backend(group): | |||||
| pg = group or dist.distributed_c10d._get_default_group() | |||||
| # It is not expected for PG to be wrapped many times, but support it just | |||||
| # in case | |||||
| while isinstance(pg, _ProcessGroupWrapper): | |||||
| pg = pg.wrapped_pg | |||||
| return ( | |||||
| dist.is_nccl_available() and | |||||
| isinstance(pg, dist.ProcessGroupNCCL) | |||||
| ) | |||||
| def all_gather_object(object_list, obj, group=None): | |||||
| """ | |||||
| 复制 pytorch 的代码,使得可以版本兼容低版本的 pytorch 。 | |||||
| Gathers picklable objects from the whole group into a list. Similar to | |||||
| :func:`all_gather`, but Python objects can be passed in. Note that the object | |||||
| must be picklable in order to be gathered. | |||||
| Args: | |||||
| object_list (list[Any]): Output list. It should be correctly sized as the | |||||
| size of the group for this collective and will contain the output. | |||||
| object (Any): Pickable Python object to be broadcast from current process. | |||||
| group (ProcessGroup, optional): The process group to work on. If None, | |||||
| the default process group will be used. Default is ``None``. | |||||
| Returns: | |||||
| None. If the calling rank is part of this group, the output of the | |||||
| collective will be populated into the input ``object_list``. If the | |||||
| calling rank is not part of the group, the passed in ``object_list`` will | |||||
| be unmodified. | |||||
| .. note:: Note that this API differs slightly from the :func:`all_gather` | |||||
| collective since it does not provide an ``async_op`` handle and thus | |||||
| will be a blocking call. | |||||
| .. note:: For NCCL-based processed groups, internal tensor representations | |||||
| of objects must be moved to the GPU device before communication takes | |||||
| place. In this case, the device used is given by | |||||
| ``torch.cuda.current_device()`` and it is the user's responsiblity to | |||||
| ensure that this is set so that each rank has an individual GPU, via | |||||
| ``torch.cuda.set_device()``. | |||||
| .. warning:: | |||||
| :func:`all_gather_object` uses ``pickle`` module implicitly, which is | |||||
| known to be insecure. It is possible to construct malicious pickle data | |||||
| which will execute arbitrary code during unpickling. Only call this | |||||
| function with data you trust. | |||||
| Example:: | |||||
| >>> # Note: Process group initialization omitted on each rank. | |||||
| >>> import torch.distributed as dist | |||||
| >>> # Assumes world_size of 3. | |||||
| >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object | |||||
| >>> output = [None for _ in gather_objects] | |||||
| >>> dist.all_gather_object(output, gather_objects[dist.get_rank()]) | |||||
| >>> output | |||||
| ['foo', 12, {1: 2}] | |||||
| """ | |||||
| if dist.distributed_c10d._rank_not_in_group(group): | |||||
| return | |||||
| if _TORCH_GREATER_EQUAL_1_8: | |||||
| current_device = torch.device("cpu") | |||||
| is_nccl_backend = _check_for_nccl_backend(group) | |||||
| if is_nccl_backend: | |||||
| # See note about using torch.cuda.current_device() here in docstring. | |||||
| # We cannot simply use my_rank since rank == device is not necessarily | |||||
| # true. | |||||
| current_device = torch.device("cuda", torch.cuda.current_device()) | |||||
| else: | |||||
| current_device = torch.cuda.current_device() | |||||
| input_tensor, local_size = _object_to_tensor(obj, device=current_device) | |||||
| # Gather all local sizes. This is so that we can find the max size, and index | |||||
| # until the correct size when deserializing the tensors. | |||||
| group_size = dist.get_world_size(group=group) | |||||
| object_sizes_tensor = torch.zeros( | |||||
| group_size, dtype=torch.long, device=current_device | |||||
| ) | |||||
| object_size_list = [ | |||||
| object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) | |||||
| ] | |||||
| # Allgather tensor sizes | |||||
| dist.all_gather(object_size_list, local_size, group=group) | |||||
| max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] | |||||
| # Resize tensor to max size across all ranks. | |||||
| input_tensor.resize_(max_object_size) | |||||
| coalesced_output_tensor = torch.empty( | |||||
| max_object_size * group_size, dtype=torch.uint8, device=current_device | |||||
| ) | |||||
| # Output tensors are nonoverlapping views of coalesced_output_tensor | |||||
| output_tensors = [ | |||||
| coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] | |||||
| for i in range(group_size) | |||||
| ] | |||||
| dist.all_gather(output_tensors, input_tensor, group=group) | |||||
| # Deserialize outputs back to object. | |||||
| for i, tensor in enumerate(output_tensors): | |||||
| tensor = tensor.type(torch.uint8) | |||||
| if tensor.device != torch.device("cpu"): | |||||
| tensor = tensor.cpu() | |||||
| tensor_size = object_size_list[i] | |||||
| object_list[i] = _tensor_to_object(tensor, tensor_size) | |||||
| return object_list | |||||
| @@ -1,13 +1,12 @@ | |||||
| import os | import os | ||||
| import shutil | |||||
| from functools import partial | from functools import partial | ||||
| from typing import List, Union, Optional, Dict | |||||
| from typing import List, Union, Optional, Dict, Tuple, Callable | |||||
| from .paddle_driver import PaddleDriver | from .paddle_driver import PaddleDriver | ||||
| from .fleet_launcher import FleetLauncher | from .fleet_launcher import FleetLauncher | ||||
| from .utils import ( | from .utils import ( | ||||
| _FleetWrappingModel, | _FleetWrappingModel, | ||||
| ForwardState, | |||||
| _MODE_PARAMETER, | |||||
| get_device_from_visible, | get_device_from_visible, | ||||
| reset_seed, | reset_seed, | ||||
| replace_sampler, | replace_sampler, | ||||
| @@ -47,8 +46,7 @@ if _NEED_IMPORT_PADDLE: | |||||
| __all__ = [ | __all__ = [ | ||||
| "PaddleFleetDriver", | "PaddleFleetDriver", | ||||
| ] | ] | ||||
| # if os.path.exists(self.gloo_rendezvous_dir): | |||||
| # shutil.rmtree(self.gloo_rendezvous_dir) | |||||
| class PaddleFleetDriver(PaddleDriver): | class PaddleFleetDriver(PaddleDriver): | ||||
| def __init__( | def __init__( | ||||
| self, | self, | ||||
| @@ -104,34 +102,6 @@ class PaddleFleetDriver(PaddleDriver): | |||||
| # 我们就直接将 model_device 置为 None; | # 我们就直接将 model_device 置为 None; | ||||
| self._model_device = None | self._model_device = None | ||||
| def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call): | |||||
| if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
| return auto_param_call(step_fn, batch, signature_fn=signature_fn) | |||||
| else: | |||||
| return self._validate_step(batch) | |||||
| model = model._layers | |||||
| if hasattr(model, "train_step"): | |||||
| logger.warning( | |||||
| "Notice your model is a `paddle.DataParallel` model. And your " | |||||
| "model also implements the `train_step` method, which we can not call actually, we will" | |||||
| " call `forward` function instead of `train_step` and you should note that.") | |||||
| self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||||
| if hasattr(model, "evaluate_step"): | |||||
| logger.warning( | |||||
| "Notice your model is a `paddle.DataParallel` model. And your " | |||||
| "model also implements the `evaluate_step` method, which we can not call actually, " | |||||
| "we will call `forward` function instead of `evaluate_step` and you should note that.") | |||||
| self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||||
| if hasattr(model, "test_step"): | |||||
| logger.warning( | |||||
| "Notice your model is a `paddle.DataParallel` model. And your " | |||||
| "model also implements the `test_step` method, which we can not call actually, we will" | |||||
| " call `forward` function instead of `test_step` and you should note that.") | |||||
| self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||||
| # 当参数 `device` 为 None 时并且该参数不为 None,表示将对应的数据移到指定的机器上; | # 当参数 `device` 为 None 时并且该参数不为 None,表示将对应的数据移到指定的机器上; | ||||
| self._data_device = kwargs.get("data_device", None) | self._data_device = kwargs.get("data_device", None) | ||||
| if self._data_device is not None: | if self._data_device is not None: | ||||
| @@ -150,8 +120,6 @@ class PaddleFleetDriver(PaddleDriver): | |||||
| self.world_size = None | self.world_size = None | ||||
| self.global_rank = 0 | self.global_rank = 0 | ||||
| self._configured = False # 防止重复调用 configure_ddp() 函数使用 | |||||
| self._has_setup = False # 防止重复调用 setup() 函数 | |||||
| self._fleet_kwargs = kwargs.get("paddle_fleet_kwargs", {}) | self._fleet_kwargs = kwargs.get("paddle_fleet_kwargs", {}) | ||||
| check_user_specific_params(self._fleet_kwargs, DataParallel.__init__) | check_user_specific_params(self._fleet_kwargs, DataParallel.__init__) | ||||
| @@ -173,6 +141,9 @@ class PaddleFleetDriver(PaddleDriver): | |||||
| os.makedirs(name=self.output_from_new_proc, exist_ok=True) | os.makedirs(name=self.output_from_new_proc, exist_ok=True) | ||||
| self.output_from_new_proc = os.path.abspath(self.output_from_new_proc) | self.output_from_new_proc = os.path.abspath(self.output_from_new_proc) | ||||
| self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的; | |||||
| self._has_fleetwrapped = False # 判断传入的模型是否经过 _has_fleetwrapped 包裹; | |||||
| def setup(self): | def setup(self): | ||||
| """ | """ | ||||
| 在主进程拉起其它子进程,将主进程作为rank 0 | 在主进程拉起其它子进程,将主进程作为rank 0 | ||||
| @@ -268,17 +239,17 @@ class PaddleFleetDriver(PaddleDriver): | |||||
| dist.barrier() | dist.barrier() | ||||
| def configure_fleet(self): | def configure_fleet(self): | ||||
| if not self._configured and not isinstance(self.model, DataParallel): | |||||
| if not self._has_fleetwrapped and not isinstance(self.model, DataParallel): | |||||
| self.model = DataParallel( | self.model = DataParallel( | ||||
| _FleetWrappingModel(self.model), | _FleetWrappingModel(self.model), | ||||
| **self._fleet_kwargs | **self._fleet_kwargs | ||||
| ) | ) | ||||
| self._has_fleetwrapped = True | |||||
| self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call) | |||||
| self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call) | |||||
| self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call) | |||||
| self._configured = True | |||||
| def on_exception(self): | |||||
| if os.path.exists(self.gloo_rendezvous_dir): | |||||
| shutil.rmtree(self.gloo_rendezvous_dir) | |||||
| super().on_exception() | |||||
| @property | @property | ||||
| def world_size(self) -> int: | def world_size(self) -> int: | ||||
| @@ -310,14 +281,39 @@ class PaddleFleetDriver(PaddleDriver): | |||||
| return self._data_device | return self._data_device | ||||
| return self.model_device | return self.model_device | ||||
| def train_step(self, batch): | |||||
| return self._train_step(batch) | |||||
| def validate_step(self, batch): | |||||
| return self._validate_step(batch) | |||||
| def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | |||||
| if self._has_fleetwrapped: | |||||
| return self.model(batch, fastnlp_fn=fn, fastnlp_signature_fn=signature_fn, | |||||
| wo_auto_param_call=self.wo_auto_param_call) | |||||
| else: | |||||
| if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
| return auto_param_call(fn, batch, signature_fn=signature_fn) | |||||
| else: | |||||
| return fn(batch) | |||||
| def get_model_call_fn(self, fn: str) -> Tuple: | |||||
| model = self.unwrap_model() | |||||
| if self._has_fleetwrapped: | |||||
| if hasattr(model, fn): | |||||
| fn = getattr(model, fn) | |||||
| if not callable(fn): | |||||
| raise RuntimeError(f"The `{fn}` attribute of model is not `Callable`.") | |||||
| return fn, None | |||||
| elif fn in {"train_step", "evaluate_step"}: | |||||
| return model, model.forward | |||||
| else: | |||||
| raise RuntimeError(f"There is no `{fn}` method in your model.") | |||||
| else: | |||||
| if hasattr(model, fn): | |||||
| logger.warning("Notice your model is a `DistributedDataParallel` model. And your model also implements " | |||||
| f"the `{fn}` method, which we can not call actually, we will" | |||||
| " call `forward` function instead of `train_step` and you should note that.") | |||||
| elif fn not in {"train_step", "evaluate_step"}: | |||||
| raise RuntimeError(f"There is no `{fn}` method in your model. And also notice that your model is a " | |||||
| "`DistributedDataParallel` model, which means that we will only call model.forward " | |||||
| "function when we are in forward propagation.") | |||||
| def test_step(self, batch): | |||||
| return self._test_step(batch) | |||||
| return self.model, model.forward | |||||
| def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]], | def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]], | ||||
| reproducible: bool = False, sampler_or_batch_sampler=None): | reproducible: bool = False, sampler_or_batch_sampler=None): | ||||
| @@ -406,14 +402,6 @@ class PaddleFleetDriver(PaddleDriver): | |||||
| else: | else: | ||||
| raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | ||||
| def backward(self, loss): | |||||
| self.grad_scaler.scale(loss).backward() | |||||
| def step(self): | |||||
| for optimizer in self.optimizers: | |||||
| self.grad_scaler.step(optimizer) | |||||
| self.grad_scaler.update() | |||||
| def is_global_zero(self): | def is_global_zero(self): | ||||
| return self.global_rank == 0 | return self.global_rank == 0 | ||||
| @@ -450,3 +438,45 @@ class PaddleFleetDriver(PaddleDriver): | |||||
| if not isinstance(each_optimizer, (Optimizer, DistribuedOptimizer)): | if not isinstance(each_optimizer, (Optimizer, DistribuedOptimizer)): | ||||
| raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, " | raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, " | ||||
| f"not {type(each_optimizer)}.") | f"not {type(each_optimizer)}.") | ||||
| def broadcast_object(self, obj, src:int=0, group=None, **kwargs): | |||||
| """ | |||||
| 从 src 端将 obj 对象(可能是 tensor ,可能是 object )发送到 dst 处。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行 | |||||
| 传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。 | |||||
| :param obj: obj,可能是 Tensor 或 嵌套类型的数据 | |||||
| :param int src: source 的 global rank 。 | |||||
| :param int dst: target 的 global rank,可以是多个目标 rank | |||||
| :param group: 所属的 group | |||||
| :param kwargs: | |||||
| :return: 如果当前不是分布式 driver 直接返回输入的 obj 。如果当前 rank 是接收端(其 global rank 包含在了 dst 中),则返回 | |||||
| 接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。 | |||||
| """ | |||||
| return | |||||
| return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group) | |||||
| def all_gather(self, obj, group) -> List: | |||||
| """ | |||||
| 将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过 | |||||
| pickle 进行序列化,接收到之后再反序列化。 | |||||
| example: | |||||
| obj = { | |||||
| 'a': [1, 1], | |||||
| 'b': [[1, 2], [1, 2]], | |||||
| 'c': { | |||||
| 'd': [1, 2] | |||||
| } | |||||
| } | |||||
| -> | |||||
| [ | |||||
| {'a': 1, 'b':[1, 2], 'c':{'d': 1}}, | |||||
| {'a': 1, 'b':[1, 2], 'c':{'d': 2}} | |||||
| ] | |||||
| :param obj: 需要传输的对象,在每个rank上都应该保持相同的结构。 | |||||
| :param group: | |||||
| :return: | |||||
| """ | |||||
| return | |||||
| return fastnlp_paddle_all_gather(obj, group=group) | |||||
| @@ -71,6 +71,14 @@ class PaddleDriver(Driver): | |||||
| for optimizer in self.optimizers: | for optimizer in self.optimizers: | ||||
| optimizer.clear_grad() | optimizer.clear_grad() | ||||
| def backward(self, loss): | |||||
| self.grad_scaler.scale(loss).backward() | |||||
| def step(self): | |||||
| for optimizer in self.optimizers: | |||||
| self.grad_scaler.step(optimizer) | |||||
| self.grad_scaler.update() | |||||
| @staticmethod | @staticmethod | ||||
| def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | ||||
| r""" | r""" | ||||
| @@ -115,28 +123,6 @@ class PaddleDriver(Driver): | |||||
| raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, " | raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, " | ||||
| f"not {type(each_optimizer)}.") | f"not {type(each_optimizer)}.") | ||||
| def check_evaluator_mode(self, mode: str): | |||||
| r""" | |||||
| 因为我们在具体的 driver 的 evaluate_step 和 test_step 的逻辑是如果模型没有实现本函数,那么就去检测模型是否实现了另一个函数; | |||||
| 因此如果用户的 evaluator evaluate_fn 是 validate,但是传入的 model 却没有实现 evaluate_step 函数,而是实现了 test_step 函数,那么 | |||||
| 我们应当提醒用户这一行为; | |||||
| """ | |||||
| model = self.unwrap_model() | |||||
| if mode == "validate": | |||||
| if not hasattr(model, "evaluate_step"): | |||||
| if hasattr(model, "test_step"): | |||||
| logger.warning( | |||||
| "Your model does not have 'evaluate_step' method but has 'test_step' method, but you" | |||||
| "are using 'Evaluator.validate', we are going to use 'test_step' to substitute for" | |||||
| "'evaluate_step'.") | |||||
| else: | |||||
| if not hasattr(model, "test_step"): | |||||
| if hasattr(model, "evaluate_step"): | |||||
| logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" | |||||
| "are using 'Evaluator.test', we are going to use 'evaluate_step' to substitute for" | |||||
| "'test_step'.") | |||||
| @staticmethod | @staticmethod | ||||
| def tensor_to_numeric(tensor, reduce=None): | def tensor_to_numeric(tensor, reduce=None): | ||||
| r""" | r""" | ||||
| @@ -268,10 +254,10 @@ class PaddleDriver(Driver): | |||||
| except: # 有可能 batch_size 为 None,就只有损失精度了 | except: # 有可能 batch_size 为 None,就只有损失精度了 | ||||
| pass | pass | ||||
| assert sampler_states["num_consumed_samples"] != -1, "This is a bug, please report." | assert sampler_states["num_consumed_samples"] != -1, "This is a bug, please report." | ||||
| states["sampler_states"] = sampler_states | |||||
| else: | else: | ||||
| raise RuntimeError( | raise RuntimeError( | ||||
| "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") | "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") | ||||
| states["sampler_states"] = sampler_states | |||||
| # 2. 保存模型的状态; | # 2. 保存模型的状态; | ||||
| if should_save_model: | if should_save_model: | ||||
| @@ -1,5 +1,5 @@ | |||||
| import os | import os | ||||
| from typing import Optional, Dict, Union | |||||
| from typing import Optional, Dict, Union, Callable, Tuple | |||||
| from .paddle_driver import PaddleDriver | from .paddle_driver import PaddleDriver | ||||
| from .utils import replace_batch_sampler, replace_sampler, get_device_from_visible | from .utils import replace_batch_sampler, replace_sampler, get_device_from_visible | ||||
| @@ -11,16 +11,19 @@ from fastNLP.core.utils import ( | |||||
| get_paddle_device_id, | get_paddle_device_id, | ||||
| paddle_move_data_to_device, | paddle_move_data_to_device, | ||||
| ) | ) | ||||
| from fastNLP.core.utils.utils import _get_fun_msg | |||||
| from fastNLP.core.samplers import ( | from fastNLP.core.samplers import ( | ||||
| ReproducibleBatchSampler, | ReproducibleBatchSampler, | ||||
| RandomBatchSampler, | RandomBatchSampler, | ||||
| ReproducibleSampler, | ReproducibleSampler, | ||||
| RandomSampler, | |||||
| re_instantiate_sampler, | re_instantiate_sampler, | ||||
| ) | ) | ||||
| from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
| if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
| import paddle | import paddle | ||||
| from paddle import DataParallel | |||||
| from paddle.fluid.reader import _DatasetKind | from paddle.fluid.reader import _DatasetKind | ||||
| __all__ = [ | __all__ = [ | ||||
| @@ -28,109 +31,57 @@ __all__ = [ | |||||
| ] | ] | ||||
| class PaddleSingleDriver(PaddleDriver): | class PaddleSingleDriver(PaddleDriver): | ||||
| def __init__(self, model, device: str, fp16: Optional[bool] = False, **kwargs): | |||||
| def __init__(self, model, device: Union[str, int], fp16: Optional[bool] = False, **kwargs): | |||||
| if isinstance(model, DataParallel): | |||||
| raise ValueError("`paddle.DataParallel` is not supported in `PaddleSingleDriver`") | |||||
| cuda_visible_devices = os.environ.get(USER_CUDA_VISIBLE_DEVICES, None) | |||||
| if cuda_visible_devices == "": | |||||
| device = "cpu" | |||||
| logger.info("You have set `CUDA_VISIBLE_DEVICES` to '' in system environment variable, and we are gonna to" | |||||
| "use `cpu` instead of `gpu` device.") | |||||
| super(PaddleSingleDriver, self).__init__(model, fp16=fp16, **kwargs) | super(PaddleSingleDriver, self).__init__(model, fp16=fp16, **kwargs) | ||||
| if device is None: | if device is None: | ||||
| raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.") | raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.") | ||||
| if device != "cpu": | |||||
| if isinstance(device, int): | |||||
| device_id = device | |||||
| else: | |||||
| device_id = get_paddle_device_id(device) | |||||
| os.environ["CUDA_VISIBLE_DEVICES"] = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id] | |||||
| self.model_device = get_paddle_gpu_str(device) | self.model_device = get_paddle_gpu_str(device) | ||||
| self.local_rank = 0 | self.local_rank = 0 | ||||
| self.global_rank = 0 | self.global_rank = 0 | ||||
| self.world_size = 1 | self.world_size = 1 | ||||
| if isinstance(model, paddle.DataParallel): | |||||
| # 注意这里的 unwrap_model 调用的是具体子类的方法; | |||||
| model = self.unwrap_model() | |||||
| if hasattr(model, "train_step"): | |||||
| logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also " | |||||
| "implements the `train_step` method, which we can not call actually, we will " | |||||
| " call `forward` function instead of `train_step` and you should note that.") | |||||
| self._train_step = self.model | |||||
| self._train_signature_fn = model.forward | |||||
| if hasattr(model, "evaluate_step"): | |||||
| logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also " | |||||
| "implements the `evaluate_step` method, which we can not call actually, we " | |||||
| "will call `forward` function instead of `evaluate_step` and you should note that.") | |||||
| self._validate_step = self.model | |||||
| self._validate_signature_fn = model.forward | |||||
| if hasattr(model, "test_step"): | |||||
| logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also " | |||||
| "implements the `test_step` method, which we can not call actually, we will " | |||||
| "call `forward` function instead of `test_step` and you should note that.") | |||||
| self._test_step = self.model | |||||
| self._test_signature_fn = model.forward | |||||
| else: | |||||
| if hasattr(self.model, "train_step"): | |||||
| self._train_step = self.model.train_step | |||||
| self._train_signature_fn = None | |||||
| else: | |||||
| self._train_step = self.model | |||||
| # 输入的模型是 `DataParallel`,我们需要保证其 signature_fn 是正确的; | |||||
| model = self.unwrap_model() | |||||
| self._train_signature_fn = model.forward | |||||
| if hasattr(self.model, "evaluate_step"): | |||||
| self._validate_step = self.model.evaluate_step | |||||
| self._validate_signature_fn = None | |||||
| elif hasattr(self.model, "test_step"): | |||||
| self._validate_step = self.model.test_step | |||||
| self._validate_signature_fn = self.model.test_step | |||||
| else: | |||||
| self._validate_step = self.model | |||||
| model = self.unwrap_model() | |||||
| self._validate_signature_fn = model.forward | |||||
| if hasattr(self.model, "test_step"): | |||||
| self._test_step = self.model.test_step | |||||
| self._test_signature_fn = None | |||||
| elif hasattr(self.model, "evaluate_step"): | |||||
| self._test_step = self.model.evaluate_step | |||||
| self._test_signature_fn = self.model.evaluate_step | |||||
| else: | |||||
| self._test_step = self.model | |||||
| model = self.unwrap_model() | |||||
| self._test_signature_fn = model.forward | |||||
| def setup(self): | def setup(self): | ||||
| device = self.model_device | device = self.model_device | ||||
| if device != "cpu": | |||||
| device_id = get_paddle_device_id(device) | |||||
| device_id = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id] | |||||
| os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) | |||||
| device = get_device_from_visible(device, output_type=str) | |||||
| device = get_device_from_visible(device, output_type=str) | |||||
| paddle.device.set_device(device) | paddle.device.set_device(device) | ||||
| self.model.to(device) | self.model.to(device) | ||||
| def train_step(self, batch) -> Dict: | |||||
| # 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; | |||||
| def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | |||||
| if isinstance(batch, Dict) and not self.wo_auto_param_call: | if isinstance(batch, Dict) and not self.wo_auto_param_call: | ||||
| return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | |||||
| return auto_param_call(fn, batch, signature_fn=signature_fn) | |||||
| else: | else: | ||||
| return self._train_step(batch) | |||||
| def backward(self, loss): | |||||
| self.grad_scaler.scale(loss).backward() | |||||
| def step(self): | |||||
| for optimizer in self.optimizers: | |||||
| self.grad_scaler.step(optimizer) | |||||
| self.grad_scaler.update() | |||||
| def validate_step(self, batch) -> Dict: | |||||
| if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
| return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | |||||
| return fn(batch) | |||||
| def get_model_call_fn(self, fn: str) -> Tuple: | |||||
| if hasattr(self.model, fn): | |||||
| fn = getattr(self.model, fn) | |||||
| if not callable(fn): | |||||
| raise RuntimeError(f"The `{fn}` attribute is not `Callable`.") | |||||
| logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...') | |||||
| return fn, None | |||||
| elif fn in {"train_step", "evaluate_step"}: | |||||
| logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...') | |||||
| return self.model, self.model.forward | |||||
| else: | else: | ||||
| return self._validate_step(batch) | |||||
| def test_step(self, batch) -> Dict: | |||||
| if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
| return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | |||||
| else: | |||||
| return self._test_step(batch) | |||||
| raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") | |||||
| def move_data_to_device(self, batch: 'paddle.Tensor'): | def move_data_to_device(self, batch: 'paddle.Tensor'): | ||||
| r""" | r""" | ||||
| @@ -164,12 +115,18 @@ class PaddleSingleDriver(PaddleDriver): | |||||
| return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
| if reproducible: | if reproducible: | ||||
| batch_sampler = RandomBatchSampler( | |||||
| batch_sampler=args.batch_sampler, | |||||
| batch_size=args.batch_size, | |||||
| drop_last=args.drop_last | |||||
| ) | |||||
| return replace_batch_sampler(dataloader, batch_sampler) | |||||
| if isinstance(args.sampler, paddle.io.RandomSampler): | |||||
| # 如果本来就是随机的,直接替换 | |||||
| sampler = RandomSampler(args.sampler.data_source) | |||||
| logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") | |||||
| return replace_sampler(dataloader, sampler) | |||||
| else: | |||||
| batch_sampler = RandomBatchSampler( | |||||
| batch_sampler=args.batch_sampler, | |||||
| batch_size=args.batch_size, | |||||
| drop_last=args.drop_last | |||||
| ) | |||||
| return replace_batch_sampler(dataloader, batch_sampler) | |||||
| else: | else: | ||||
| return dataloader | return dataloader | ||||
| @@ -11,7 +11,6 @@ from typing import Dict, Optional, Union | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
| from fastNLP.core.utils import get_paddle_device_id, auto_param_call, paddle_to | from fastNLP.core.utils import get_paddle_device_id, auto_param_call, paddle_to | ||||
| from fastNLP.core.samplers import RandomSampler | |||||
| from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES | from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES | ||||
| from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
| @@ -87,8 +86,6 @@ class ForwardState(IntEnum): | |||||
| TEST = 2 | TEST = 2 | ||||
| PREDICT = 3 | PREDICT = 3 | ||||
| _MODE_PARAMETER = "forward_state" | |||||
| class _FleetWrappingModel(Layer): | class _FleetWrappingModel(Layer): | ||||
| """ | """ | ||||
| 参考_DDPWrappingModel,paddle的分布式训练也需要用paddle.nn.DataParallel进行包装,采用和 | 参考_DDPWrappingModel,paddle的分布式训练也需要用paddle.nn.DataParallel进行包装,采用和 | ||||
| @@ -98,83 +95,16 @@ class _FleetWrappingModel(Layer): | |||||
| super(_FleetWrappingModel, self).__init__() | super(_FleetWrappingModel, self).__init__() | ||||
| self.model = model | self.model = model | ||||
| if isinstance(model, paddle.DataParallel): | |||||
| model = model._layers | |||||
| if hasattr(model, "train_step"): | |||||
| logger.warning( | |||||
| "Notice your model is a `paddle.DataParallel` model. And your " | |||||
| "model also implements the `train_step` method, which we can not call actually, we will" | |||||
| " call `forward` function instead of `train_step` and you should note that.") | |||||
| self._train_step = self.model | |||||
| self._train_signature_fn = model.forward | |||||
| if hasattr(model, "evaluate_step"): | |||||
| logger.warning( | |||||
| "Notice your model is a `paddle.DataParallel` model. And your " | |||||
| "model also implements the `evaluate_step` method, which we can not call actually, " | |||||
| "we will call `forward` function instead of `evaluate_step` and you should note that.") | |||||
| self._validate_step = self.model | |||||
| self._validate_signature_fn = model.forward | |||||
| if hasattr(model, "test_step"): | |||||
| logger.warning( | |||||
| "Notice your model is a `paddle.DataParallel` model. And your " | |||||
| "model also implements the `test_step` method, which we can not call actually, we will" | |||||
| " call `forward` function instead of `test_step` and you should note that.") | |||||
| self._test_step = self.model | |||||
| self._test_signature_fn = model.forward | |||||
| else: | |||||
| if hasattr(model, "train_step"): | |||||
| self._train_step = model.train_step | |||||
| self._train_signature_fn = None | |||||
| else: | |||||
| self._train_step = model | |||||
| self._train_signature_fn = model.forward | |||||
| if hasattr(model, "evaluate_step"): | |||||
| self._validate_step = model.validate_step | |||||
| self._validate_signature_fn = None | |||||
| elif hasattr(model, "test_step"): | |||||
| self._validate_step = model.test_step | |||||
| self._validate_signature_fn = None | |||||
| else: | |||||
| self._validate_step = model | |||||
| self._validate_signature_fn = model.forward | |||||
| if hasattr(model, "test_step"): | |||||
| self._test_step = model.test_step | |||||
| self._test_signature_fn = None | |||||
| elif hasattr(model, "evaluate_step"): | |||||
| self._test_step = model.validate_step | |||||
| self._test_signature_fn = None | |||||
| else: | |||||
| self._test_step = model | |||||
| self._test_signature_fn = model.forward | |||||
| def forward(self, batch, **kwargs) -> Dict: | def forward(self, batch, **kwargs) -> Dict: | ||||
| forward_state = kwargs.pop(_MODE_PARAMETER) | |||||
| fn = kwargs.pop("fastnlp_fn") | |||||
| signature_fn = kwargs.pop("fastnlp_signature_fn") | |||||
| wo_auto_param_call = kwargs.pop("wo_auto_param_call") | wo_auto_param_call = kwargs.pop("wo_auto_param_call") | ||||
| if forward_state == ForwardState.TRAIN: | |||||
| if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
| return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | |||||
| else: | |||||
| return self._train_step(batch) | |||||
| elif forward_state == ForwardState.VALIDATE: | |||||
| if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
| return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | |||||
| else: | |||||
| return self._validate_step(batch) | |||||
| elif forward_state == ForwardState.TEST: | |||||
| if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
| return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | |||||
| else: | |||||
| return self._test_step(batch) | |||||
| elif forward_state == ForwardState.PREDICT: | |||||
| raise NotImplementedError("'PREDICT' evaluate_fn has not been implemented.") | |||||
| if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
| return auto_param_call(fn, batch, signature_fn=signature_fn) | |||||
| else: | else: | ||||
| raise NotImplementedError("You should direct a concrete evaluate_fn.") | |||||
| return fn(batch) | |||||
| class DummyGradScaler: | class DummyGradScaler: | ||||
| """ | """ | ||||
| @@ -1,6 +1,7 @@ | |||||
| from typing import Optional, Dict, Union, Callable | |||||
| from typing import Optional, Dict, Union, Callable, Tuple | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | ||||
| from fastNLP.core.utils.utils import _get_fun_msg | |||||
| if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
| @@ -48,33 +49,6 @@ class TorchPaddleDriver(Driver): | |||||
| elif self._data_device is not None: | elif self._data_device is not None: | ||||
| raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | ||||
| if hasattr(self.model, "train_step"): | |||||
| self._train_step = self.model.train_step | |||||
| self._train_signature_fn = None | |||||
| else: | |||||
| self._train_step = self.model | |||||
| self._train_signature_fn = self.model.forward | |||||
| if hasattr(self.model, "evaluate_step"): | |||||
| self._validate_step = self.model.evaluate_step | |||||
| self._validate_signature_fn = None | |||||
| elif hasattr(self.model, "test_step"): | |||||
| self._validate_step = self.model.test_step | |||||
| self._validate_signature_fn = self.model.forward | |||||
| else: | |||||
| self._validate_step = self.model | |||||
| self._validate_signature_fn = self.model.forward | |||||
| if hasattr(self.model, "test_step"): | |||||
| self._test_step = self.model.test_step | |||||
| self._test_signature_fn = None | |||||
| elif hasattr(self.model, "evaluate_step"): | |||||
| self._test_step = self.model.evaluate_step | |||||
| self._test_signature_fn = self.model.forward | |||||
| else: | |||||
| self._test_step = self.model | |||||
| self._test_signature_fn = self.model.forward | |||||
| def setup(self): | def setup(self): | ||||
| if self.model_device is not None: | if self.model_device is not None: | ||||
| paddle.device.set_device(self.model_device.replace("cuda", "gpu")) | paddle.device.set_device(self.model_device.replace("cuda", "gpu")) | ||||
| @@ -103,12 +77,6 @@ class TorchPaddleDriver(Driver): | |||||
| f"'torch.optim.Optimizer' or 'paddle.optimizers.Optimizer' type, " | f"'torch.optim.Optimizer' or 'paddle.optimizers.Optimizer' type, " | ||||
| f"not {type(each_optimizer)}.") | f"not {type(each_optimizer)}.") | ||||
| def train_step(self, batch) -> Dict: | |||||
| if isinstance(batch, Dict): | |||||
| return auto_param_call(self._train_step, batch) | |||||
| else: | |||||
| return self._train_step(batch) | |||||
| def step(self): | def step(self): | ||||
| for optimizer in self.optimizers: | for optimizer in self.optimizers: | ||||
| optimizer.step() | optimizer.step() | ||||
| @@ -125,17 +93,24 @@ class TorchPaddleDriver(Driver): | |||||
| else: | else: | ||||
| raise ValueError("Unknown optimizers type.") | raise ValueError("Unknown optimizers type.") | ||||
| def validate_step(self, batch): | |||||
| if isinstance(batch, Dict): | |||||
| return auto_param_call(self._validate_step, batch) | |||||
| def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | |||||
| if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
| return auto_param_call(fn, batch, signature_fn=signature_fn) | |||||
| else: | else: | ||||
| return self._validate_step(batch) | |||||
| def test_step(self, batch): | |||||
| if isinstance(batch, Dict): | |||||
| return auto_param_call(self._test_step, batch) | |||||
| return fn(batch) | |||||
| def get_model_call_fn(self, fn: str) -> Tuple: | |||||
| if hasattr(self.model, fn): | |||||
| fn = getattr(self.model, fn) | |||||
| if not callable(fn): | |||||
| raise RuntimeError(f"The `{fn}` attribute is not `Callable`.") | |||||
| logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...') | |||||
| return fn, None | |||||
| elif fn in {"train_step", "evaluate_step"}: | |||||
| logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...') | |||||
| return self.model, self.model.forward | |||||
| else: | else: | ||||
| return self._test_step(batch) | |||||
| raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") | |||||
| def predict_step(self, batch): | def predict_step(self, batch): | ||||
| if isinstance(batch, Dict): | if isinstance(batch, Dict): | ||||
| @@ -85,7 +85,7 @@ class MixModule: | |||||
| def test_step(self, batch): | def test_step(self, batch): | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| def validate_step(self, batch): | |||||
| def evaluate_step(self, batch): | |||||
| raise NotImplementedError | raise NotImplementedError | ||||
| def train(self): | def train(self): | ||||
| @@ -1,13 +1,11 @@ | |||||
| import pytest | import pytest | ||||
| import os | import os | ||||
| os.environ["FASTNLP_BACKEND"] = "paddle" | os.environ["FASTNLP_BACKEND"] = "paddle" | ||||
| from typing import Any | |||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||
| from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
| from fastNLP.core.metrics.accuracy import Accuracy | from fastNLP.core.metrics.accuracy import Accuracy | ||||
| from fastNLP.core.callbacks.progress_callback import RichCallback | from fastNLP.core.callbacks.progress_callback import RichCallback | ||||
| from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK | |||||
| from paddle.optimizer import Adam | from paddle.optimizer import Adam | ||||
| from paddle.io import DataLoader | from paddle.io import DataLoader | ||||
| @@ -19,40 +17,18 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordM | |||||
| from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
| @dataclass | @dataclass | ||||
| class MNISTTrainPaddleConfig: | |||||
| class TrainPaddleConfig: | |||||
| num_labels: int = 10 | num_labels: int = 10 | ||||
| feature_dimension: int = 784 | |||||
| feature_dimension: int = 10 | |||||
| batch_size: int = 32 | |||||
| batch_size: int = 2 | |||||
| shuffle: bool = True | shuffle: bool = True | ||||
| validate_every = -5 | |||||
| evaluate_every = 2 | |||||
| driver: str = "paddle" | |||||
| device = "gpu" | |||||
| @dataclass | |||||
| class MNISTTrainFleetConfig: | |||||
| num_labels: int = 10 | |||||
| feature_dimension: int = 784 | |||||
| batch_size: int = 32 | |||||
| shuffle: bool = True | |||||
| validate_every = -5 | |||||
| @dataclass | |||||
| class TrainerParameters: | |||||
| model: Any = None | |||||
| optimizers: Any = None | |||||
| train_dataloader: Any = None | |||||
| validate_dataloaders: Any = None | |||||
| input_mapping: Any = None | |||||
| output_mapping: Any = None | |||||
| metrics: Any = None | |||||
| @pytest.mark.parametrize("driver,device", [("paddle", "cpu")("paddle", 1)]) | |||||
| @pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)]) | |||||
| # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])]) | # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])]) | ||||
| @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.7, larger_better=True), | |||||
| RichCallback(5), RecordLossCallback(loss_threshold=0.3)]]) | |||||
| @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True), | |||||
| RichCallback(5)]]) | |||||
| @magic_argv_env_context | @magic_argv_env_context | ||||
| def test_trainer_paddle( | def test_trainer_paddle( | ||||
| driver, | driver, | ||||
| @@ -60,38 +36,36 @@ def test_trainer_paddle( | |||||
| callbacks, | callbacks, | ||||
| n_epochs=2, | n_epochs=2, | ||||
| ): | ): | ||||
| trainer_params = TrainerParameters() | |||||
| trainer_params.model = PaddleNormalModel_Classification_1( | |||||
| num_labels=MNISTTrainPaddleConfig.num_labels, | |||||
| feature_dimension=MNISTTrainPaddleConfig.feature_dimension | |||||
| model = PaddleNormalModel_Classification_1( | |||||
| num_labels=TrainPaddleConfig.num_labels, | |||||
| feature_dimension=TrainPaddleConfig.feature_dimension | |||||
| ) | ) | ||||
| trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001) | |||||
| optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001) | |||||
| train_dataloader = DataLoader( | train_dataloader = DataLoader( | ||||
| dataset=PaddleRandomMaxDataset(6400, 10), | |||||
| batch_size=MNISTTrainPaddleConfig.batch_size, | |||||
| dataset=PaddleRandomMaxDataset(20, 10), | |||||
| batch_size=TrainPaddleConfig.batch_size, | |||||
| shuffle=True | shuffle=True | ||||
| ) | ) | ||||
| val_dataloader = DataLoader( | val_dataloader = DataLoader( | ||||
| dataset=PaddleRandomMaxDataset(1000, 10), | |||||
| batch_size=MNISTTrainPaddleConfig.batch_size, | |||||
| dataset=PaddleRandomMaxDataset(20, 10), | |||||
| batch_size=TrainPaddleConfig.batch_size, | |||||
| shuffle=True | shuffle=True | ||||
| ) | ) | ||||
| trainer_params.train_dataloader = train_dataloader | |||||
| trainer_params.validate_dataloaders = val_dataloader | |||||
| trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every | |||||
| trainer_params.metrics = {"acc": Accuracy(backend="paddle")} | |||||
| train_dataloader = train_dataloader | |||||
| evaluate_dataloaders = val_dataloader | |||||
| evaluate_every = TrainPaddleConfig.evaluate_every | |||||
| metrics = {"acc": Accuracy(backend="paddle")} | |||||
| trainer = Trainer( | trainer = Trainer( | ||||
| model=trainer_params.model, | |||||
| model=model, | |||||
| driver=driver, | driver=driver, | ||||
| device=device, | device=device, | ||||
| optimizers=trainer_params.optimizers, | |||||
| train_dataloader=trainer_params.train_dataloader, | |||||
| validate_dataloaders=trainer_params.validate_dataloaders, | |||||
| validate_every=trainer_params.validate_every, | |||||
| input_mapping=trainer_params.input_mapping, | |||||
| output_mapping=trainer_params.output_mapping, | |||||
| metrics=trainer_params.metrics, | |||||
| optimizers=optimizers, | |||||
| train_dataloader=train_dataloader, | |||||
| evaluate_dataloaders=evaluate_dataloaders, | |||||
| evaluate_every=evaluate_every, | |||||
| input_mapping=None, | |||||
| output_mapping=None, | |||||
| metrics=metrics, | |||||
| n_epochs=n_epochs, | n_epochs=n_epochs, | ||||
| callbacks=callbacks, | callbacks=callbacks, | ||||
| @@ -56,34 +56,57 @@ def test_save_and_load_with_randombatchsampler(only_state_dict): | |||||
| dataset=dataset, | dataset=dataset, | ||||
| batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) | batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) | ||||
| ) | ) | ||||
| num_consumed_batches = 2 | |||||
| # TODO 断点重训完善后在这里迭代几次 | # TODO 断点重训完善后在这里迭代几次 | ||||
| already_seen_set = set() | |||||
| for idx, batch in enumerate(dataloader): | |||||
| if idx >= num_consumed_batches: | |||||
| break | |||||
| already_seen_set.update(batch) | |||||
| sampler_states = dataloader.batch_sampler.state_dict() | sampler_states = dataloader.batch_sampler.state_dict() | ||||
| save_states = {"num_consumed_batches": num_consumed_batches} | |||||
| if only_state_dict: | if only_state_dict: | ||||
| driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True) | |||||
| driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
| else: | else: | ||||
| driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
| states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
| driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
| # 加载 | |||||
| # 更改 batch_size | |||||
| dataloader = DataLoader( | |||||
| dataset=dataset, | |||||
| batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2), 2, False) | |||||
| ) | |||||
| load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
| replaced_loader = load_states.pop("dataloader") | |||||
| # 1. 检查 optimizer 的状态 | # 1. 检查 optimizer 的状态 | ||||
| # TODO optimizer 的 state_dict 总是为空 | # TODO optimizer 的 state_dict 总是为空 | ||||
| # 2. 检查 batch_sampler 是否被正确地加载和替换 | # 2. 检查 batch_sampler 是否被正确地加载和替换 | ||||
| replaced_loader = states["dataloader"] | |||||
| assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | ||||
| assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] | assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] | ||||
| assert replaced_loader.batch_sampler.data_idx == sampler_states["data_idx"] | assert replaced_loader.batch_sampler.data_idx == sampler_states["data_idx"] | ||||
| # 3. 检查 model 的参数是否被正确加载 | # 3. 检查 model 的参数是否被正确加载 | ||||
| for batch in dataloader: | for batch in dataloader: | ||||
| res1 = driver1.validate_step(batch) | |||||
| res2 = driver2.validate_step(batch) | |||||
| res1 = driver1.model.evaluate_step(**batch) | |||||
| res2 = driver2.model.evaluate_step(**batch) | |||||
| assert paddle.equal_all(res1["pred"], res2["pred"]) | assert paddle.equal_all(res1["pred"], res2["pred"]) | ||||
| # 4. 检查 batch_idx | # 4. 检查 batch_idx | ||||
| # TODO | |||||
| start_batch = load_states.pop('batch_idx_in_epoch') | |||||
| assert start_batch == 2 * num_consumed_batches | |||||
| left_batches = set() | |||||
| for idx, batch in enumerate(replaced_loader): | |||||
| left_batches.update(batch) | |||||
| assert len(left_batches) + len(already_seen_set) == len(dataset) | |||||
| assert len(left_batches | already_seen_set) == len(dataset) | |||||
| finally: | finally: | ||||
| synchronize_safe_rm(path) | synchronize_safe_rm(path) | ||||
| @@ -104,21 +127,36 @@ def test_save_and_load_with_randomsampler(only_state_dict): | |||||
| dataset, | dataset, | ||||
| batch_sampler=batch_sampler | batch_sampler=batch_sampler | ||||
| ) | ) | ||||
| num_consumed_batches = 2 | |||||
| # TODO 断点重训完善后在这里迭代几次 | # TODO 断点重训完善后在这里迭代几次 | ||||
| already_seen_set = set() | |||||
| for idx, batch in enumerate(dataloader): | |||||
| if idx >= num_consumed_batches: | |||||
| break | |||||
| already_seen_set.update(batch) | |||||
| sampler_states = dataloader.batch_sampler.sampler.state_dict() | sampler_states = dataloader.batch_sampler.sampler.state_dict() | ||||
| save_states = {"num_consumed_batches": num_consumed_batches} | |||||
| if only_state_dict: | if only_state_dict: | ||||
| driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True) | |||||
| driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
| else: | else: | ||||
| driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
| states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
| driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
| # 加载 | |||||
| # 更改 batch_size | |||||
| dataloader = DataLoader( | |||||
| dataset=dataset, | |||||
| batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2), 2, False) | |||||
| ) | |||||
| load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
| replaced_loader = load_states.pop("dataloader") | |||||
| # 1. 检查 optimizer 的状态 | # 1. 检查 optimizer 的状态 | ||||
| # TODO optimizer 的 state_dict 总是为空 | # TODO optimizer 的 state_dict 总是为空 | ||||
| # 2. 检查 sampler 是否被正确地加载和替换 | # 2. 检查 sampler 是否被正确地加载和替换 | ||||
| replaced_loader = states["dataloader"] | |||||
| replaced_loader = load_states["dataloader"] | |||||
| assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | ||||
| assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] | assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] | ||||
| @@ -129,60 +167,51 @@ def test_save_and_load_with_randomsampler(only_state_dict): | |||||
| # 3. 检查 model 的参数是否被正确加载 | # 3. 检查 model 的参数是否被正确加载 | ||||
| for batch in dataloader: | for batch in dataloader: | ||||
| res1 = driver1.validate_step(batch) | |||||
| res2 = driver2.validate_step(batch) | |||||
| res1 = driver1.model.evaluate_step(**batch) | |||||
| res2 = driver2.model.evaluate_step(**batch) | |||||
| assert paddle.equal_all(res1["pred"], res2["pred"]) | assert paddle.equal_all(res1["pred"], res2["pred"]) | ||||
| # 4. 检查 batch_idx | # 4. 检查 batch_idx | ||||
| # TODO | |||||
| finally: | |||||
| synchronize_safe_rm(path) | |||||
| def test_save_and_load_state_dict(prepare_test_save_load): | |||||
| """ | |||||
| 测试save和load函数 | |||||
| TODO optimizer的state_dict为空,暂时不测试 | |||||
| """ | |||||
| try: | |||||
| path = "dict" | |||||
| driver1, driver2, dataloader = prepare_test_save_load | |||||
| driver1.save_model(path) | |||||
| driver2.load_model(path) | |||||
| for batch in dataloader: | |||||
| batch = driver1.move_data_to_device(batch) | |||||
| res1 = driver1.validate_step(batch) | |||||
| res2 = driver2.validate_step(batch) | |||||
| start_batch = load_states.pop('batch_idx_in_epoch') | |||||
| assert start_batch == 2 * num_consumed_batches | |||||
| left_batches = set() | |||||
| for idx, batch in enumerate(replaced_loader): | |||||
| left_batches.update(batch) | |||||
| assert paddle.equal_all(res1["pred"], res2["pred"]) | |||||
| assert len(left_batches) + len(already_seen_set) == len(dataset) | |||||
| assert len(left_batches | already_seen_set) == len(dataset) | |||||
| finally: | finally: | ||||
| synchronize_safe_rm(path) | synchronize_safe_rm(path) | ||||
| def test_save_and_load_whole_model(prepare_test_save_load): | |||||
| @pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
| def test_save_and_load_model(prepare_test_save_load, only_state_dict): | |||||
| """ | """ | ||||
| 测试save和load函数 | |||||
| TODO optimizer的state_dict为空,暂时不测试 | |||||
| 测试 save_model 和 load_model 函数 | |||||
| """ | """ | ||||
| try: | try: | ||||
| path = "model" | path = "model" | ||||
| driver1, driver2, dataloader = prepare_test_save_load | driver1, driver2, dataloader = prepare_test_save_load | ||||
| driver1.save_model(path, only_state_dict=False, input_spec=[paddle.ones((32, 10))]) | |||||
| driver2.load_model(path, only_state_dict=False) | |||||
| if only_state_dict: | |||||
| driver1.save_model(path, only_state_dict) | |||||
| else: | |||||
| driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((32, 10))]) | |||||
| driver2.load_model(path, only_state_dict) | |||||
| for batch in dataloader: | for batch in dataloader: | ||||
| batch = driver1.move_data_to_device(batch) | batch = driver1.move_data_to_device(batch) | ||||
| res1 = driver1.validate_step(batch) | |||||
| res2 = driver2.validate_step(batch) | |||||
| res1 = driver1.model.evaluate_step(**batch) | |||||
| res2 = driver2.model.evaluate_step(**batch) | |||||
| assert paddle.equal_all(res1["pred"], res2["pred"]) | assert paddle.equal_all(res1["pred"], res2["pred"]) | ||||
| finally: | finally: | ||||
| synchronize_safe_rm(path + ".pdiparams") | |||||
| synchronize_safe_rm(path + ".pdiparams.info") | |||||
| synchronize_safe_rm(path + ".pdmodel") | |||||
| if only_state_dict: | |||||
| synchronize_safe_rm(path) | |||||
| else: | |||||
| synchronize_safe_rm(path + ".pdiparams") | |||||
| synchronize_safe_rm(path + ".pdiparams.info") | |||||
| synchronize_safe_rm(path + ".pdmodel") | |||||
| class TestSingleDeviceFunction: | class TestSingleDeviceFunction: | ||||
| """ | """ | ||||
| @@ -199,13 +228,7 @@ class TestSingleDeviceFunction: | |||||
| 测试能否运行 | 测试能否运行 | ||||
| """ | """ | ||||
| res = self.driver.unwrap_model() | res = self.driver.unwrap_model() | ||||
| def test_check_evaluator_mode(self): | |||||
| """ | |||||
| 这两个函数没有返回值和抛出异常,仅检查是否有import错误等影响运行的因素 | |||||
| """ | |||||
| self.driver.check_evaluator_mode("validate") | |||||
| self.driver.check_evaluator_mode("test") | |||||
| assert res is self.driver.model | |||||
| def test_is_distributed(self): | def test_is_distributed(self): | ||||
| assert self.driver.is_distributed() == False | assert self.driver.is_distributed() == False | ||||
| @@ -237,21 +260,30 @@ class TestSetDistReproDataloder: | |||||
| assert replaced_loader is dataloader | assert replaced_loader is dataloader | ||||
| def test_set_dist_repro_dataloader_with_reproducible_true(self): | |||||
| @pytest.mark.parametrize("shuffle", [True, False]) | |||||
| def test_set_dist_repro_dataloader_with_reproducible_true(self, shuffle): | |||||
| """ | """ | ||||
| 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 | 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 | ||||
| 当dist为字符串时,此时应该返回新的 dataloader,且 batch_sampler 为 RandomBatchSampler | |||||
| 当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 paddle.io.RandomSampler(shuffle=True), | |||||
| 只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 RandomBatchSampler | |||||
| """ | """ | ||||
| dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) | |||||
| dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | |||||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | ||||
| assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
| assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
| assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) | |||||
| if shuffle: | |||||
| # 此时会替换 sampler | |||||
| assert isinstance(replaced_loader.batch_sampler, paddle.io.BatchSampler) | |||||
| assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
| assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
| else: | |||||
| # 此时会替换 batch_sampler | |||||
| assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
| assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) | |||||
| assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | ||||
| assert replaced_loader.drop_last == dataloader.drop_last | assert replaced_loader.drop_last == dataloader.drop_last | ||||
| # self.check_set_dist_repro_dataloader(dataloader, replaced_loader) | |||||
| self.check_set_dist_repro_dataloader(dataloader, replaced_loader) | |||||
| def test_set_dist_repro_dataloader_with_dist_batch_sampler(self): | def test_set_dist_repro_dataloader_with_dist_batch_sampler(self): | ||||
| """ | """ | ||||
| @@ -72,7 +72,7 @@ class RecordTrainerEventTriggerCallback(Callback): | |||||
| print("on_train_end") | print("on_train_end") | ||||
| def on_train_epoch_begin(self, trainer): | def on_train_epoch_begin(self, trainer): | ||||
| if trainer.current_epoch_idx >= 1: | |||||
| if trainer.cur_epoch_idx >= 1: | |||||
| # 触发 on_exception; | # 触发 on_exception; | ||||
| raise Exception | raise Exception | ||||
| print("on_train_epoch_begin") | print("on_train_epoch_begin") | ||||
| @@ -26,7 +26,7 @@ class PaddleNormalModel_Classification_1(paddle.nn.Layer): | |||||
| x = self(x) | x = self(x) | ||||
| return {"loss": self.loss_fn(x, y)} | return {"loss": self.loss_fn(x, y)} | ||||
| def validate_step(self, x, y): | |||||
| def evaluate_step(self, x, y): | |||||
| x = self(x) | x = self(x) | ||||
| return {"pred": x, "target": y.reshape((-1,))} | return {"pred": x, "target": y.reshape((-1,))} | ||||