
deepspeed driver init

tags/v1.0.0alpha
x54-729 3 years ago
parent commit 55d8738def
3 changed files with 252 additions and 1 deletions
  1. +165 -0 fastNLP/core/drivers/torch_driver/deepspeed.py
  2. +86 -1 fastNLP/core/drivers/torch_driver/utils.py
  3. +1 -0 fastNLP/envs/imports.py

+ 165 - 0 fastNLP/core/drivers/torch_driver/deepspeed.py

@@ -0,0 +1,165 @@
from typing import Optional, Union, Callable, Dict, Tuple
from .torch_driver import TorchDriver
from .utils import _create_default_config, replace_sampler, replace_batch_sampler
from fastNLP.core.utils import auto_param_call
from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \
    ReproduceBatchSampler, RandomSampler
from fastNLP.core.log import logger
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_DEEPSPEED

if _NEED_IMPORT_TORCH:
    import torch
    # torch's own samplers are aliased so that set_dist_repro_dataloader below
    # can distinguish them from the fastNLP reproducible samplers.
    from torch.utils.data import BatchSampler as TorchBatchSampler, \
        RandomSampler as TorchRandomSampler, SequentialSampler as TorchSequentialSampler
if _NEED_IMPORT_DEEPSPEED:
    import deepspeed

__all__ = [
"DeepSpeedDriver",
]

class DeepSpeedDriver(TorchDriver):
    def __init__(self, model, fp16, strategy, **kwargs):
        super(DeepSpeedDriver, self).__init__(model, fp16)

        self.strategy = strategy

    def setup(self):
        # Map the strategy name to a default ZeRO config (see _create_default_config in utils.py).
        if self.strategy == "deepspeed":
            self.config = _create_default_config(stage=2)
        elif self.strategy == "deepspeed_stage_1":
            self.config = _create_default_config(stage=1)
        elif self.strategy == "deepspeed_stage_2":
            self.config = _create_default_config(stage=2)
        elif self.strategy == "deepspeed_stage_2_offload":
            self.config = _create_default_config(stage=2, offload_optimizer=True)
        elif self.strategy == "deepspeed_stage_3":
            self.config = _create_default_config(stage=3)
        elif self.strategy == "deepspeed_stage_3_offload":
            self.config = _create_default_config(
                stage=3,
                offload_optimizer=True,
                offload_parameters=True,
            )
        elif self.strategy == "deepspeed_stage_3_offload_nvme":
            self.config = _create_default_config(
                stage=3,
                offload_optimizer=True,
                offload_parameters=True,
                offload_params_device="nvme",
                offload_optimizer_device="nvme",
            )
        else:
            raise ValueError(f"Unknown deepspeed strategy {self.strategy!r}.")

        for i, optimizer in enumerate(self.optimizers):
            # TODO: support multiple optimizers; deepspeed.initialize re-wraps
            # the model on every iteration, so only one optimizer is usable yet.
            engine, optimizer_ds, _, _ = deepspeed.initialize(
                model=self.model,
                optimizer=optimizer,
                config=self.config
            )
            self._optimizers[i] = optimizer_ds
        self.model = engine

        self._set_deepspeed_activation_checkpointing()

    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
        if isinstance(batch, Dict) and not self.wo_auto_param_call:
            return auto_param_call(fn, batch, signature_fn=signature_fn)
        else:
            return fn(batch)

    def get_model_call_fn(self, fn: str) -> Tuple:
        if hasattr(self.model, fn):
            fn = getattr(self.model, fn)
            if not callable(fn):
                raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
            logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
            return fn, None
        elif fn in {"train_step", "evaluate_step"}:
            logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
            return self.model, self.model.forward
        else:
            raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")

    def set_dist_repro_dataloader(self, dataloader,
                                  dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None,
                                  reproducible: bool = False):
        # TODO: reproducible sampling is not enabled for deepspeed yet, so the
        # dataloader is returned unchanged; the replacement logic below is kept
        # for a later revision.
        return dataloader
        # If dist is a ReproducibleBatchSampler / ReproducibleSampler, we are
        # being called from driver.load_checkpoint while resuming training;
        if isinstance(dist, ReproducibleBatchSampler):
            return replace_batch_sampler(dataloader, dist)
        elif isinstance(dist, ReproducibleSampler):
            return replace_sampler(dataloader, dist)

        # If dist is a str or None, we are being called during trainer initialization;
        args = self.get_dataloader_args(dataloader)
        if isinstance(args.batch_sampler, ReproducibleBatchSampler):
            batch_sampler = re_instantiate_sampler(args.batch_sampler)
            return replace_batch_sampler(dataloader, batch_sampler)
        elif isinstance(args.sampler, ReproducibleSampler):
            sampler = re_instantiate_sampler(args.sampler)
            return replace_sampler(dataloader, sampler)

        if reproducible:
            if type(args.batch_sampler) is TorchBatchSampler:
                if type(args.sampler) is TorchRandomSampler:
                    if getattr(args.sampler, '_num_samples', None) is None \
                            and getattr(args.sampler, 'replacement', False) is False \
                            and getattr(args.sampler, 'generator', None) is None:
                        # A plain, uncustomized random sampler can simply be replaced.
                        sampler = RandomSampler(args.sampler.data_source, shuffle=True)
                        logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
                        return replace_sampler(dataloader, sampler)
                elif type(args.sampler) is TorchSequentialSampler:
                    # Replace with a sampler that does not shuffle.
                    sampler = RandomSampler(args.sampler.data_source, shuffle=False)
                    logger.debug("Replace torch SequentialSampler into fastNLP RandomSampler.")
                    return replace_sampler(dataloader, sampler)
            batch_sampler = ReproduceBatchSampler(
                batch_sampler=args.batch_sampler,
                batch_size=args.batch_size,
                drop_last=args.drop_last
            )
            return replace_batch_sampler(dataloader, batch_sampler)
        else:
            return dataloader

    def unwrap_model(self):
        r"""
        :return: the original model, with wrappers such as ``DeepSpeedEngine``
            or ``DataParallel`` stripped off;
        """
        if isinstance(self.model, deepspeed.DeepSpeedEngine):
            return self.model.module
        if isinstance(self.model, torch.nn.DataParallel) or \
                isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            return self.model.module
        else:
            return self.model

    @property
    def data_device(self):
        r"""
        Note that in single-card mode ``driver.data_device`` is equivalent to ``driver.model_device``;
        """
        return self.model_device

    def is_distributed(self):
        r"""
        :return: whether the current driver is distributed; ``DeepSpeedDriver`` currently returns ``False``;
        """
        return False

    def _set_deepspeed_activation_checkpointing(self):
        if self.config.get("activation_checkpointing"):
            checkpoint_config = self.config["activation_checkpointing"]
            deepspeed.checkpointing.configure(
                mpu_=None,
                partition_activations=checkpoint_config.get("partition_activations"),
                contiguous_checkpointing=checkpoint_config.get("contiguous_memory_optimization"),
                checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"),
                profile=checkpoint_config.get("profile"),
            )
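
The intended call pattern is easiest to see in a short sketch. The following is a hypothetical usage example, not part of this commit: it assumes the ``set_optimizers`` helper inherited from the base driver and a DeepSpeed-capable launch environment (e.g. started via the ``deepspeed`` or ``torchrun`` launcher), since ``deepspeed.initialize`` needs an initialized distributed backend.

import torch
from fastNLP.core.drivers.torch_driver.deepspeed import DeepSpeedDriver

# Hypothetical sketch: hand a model and a strategy name to the driver and let
# setup() build the ZeRO config and wrap the model in a DeepSpeedEngine.
model = torch.nn.Linear(10, 2)
driver = DeepSpeedDriver(model, fp16=False, strategy="deepspeed_stage_2")
driver.set_optimizers(torch.optim.Adam(model.parameters(), lr=1e-3))
driver.setup()  # calls deepspeed.initialize with the generated config

# unwrap_model() strips the engine wrapper again and returns the raw module.
assert driver.unwrap_model() is model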

+ 86 - 1 fastNLP/core/drivers/torch_driver/utils.py

@@ -1,6 +1,6 @@
import os

-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
from enum import IntEnum
import contextlib
import random
@@ -292,3 +292,88 @@ def _check_dataloader_args_for_distributed(args, controller='Trainer'):
            f"``{substitution}``. The customized sampler should set for distributed running "
            f"before initializing ``{controller}`` , and then set the "
            f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.")

def _create_default_config(
        zero_optimization: bool = True,
        zero_allow_untested_optimizer: bool = True,
        logging_batch_size_per_gpu: Union[str, int] = "auto",
        partition_activations: bool = False,
        cpu_checkpointing: bool = False,
        contiguous_memory_optimization: bool = False,
        synchronize_checkpoint_boundary: bool = False,
        offload_optimizer: bool = False,
        offload_parameters: bool = False,
        offload_params_device: str = "cpu",
        nvme_path: str = "/local_nvme",
        params_buffer_count: int = 5,
        params_buffer_size: int = 100_000_000,
        max_in_cpu: int = 1_000_000_000,
        offload_optimizer_device: str = "cpu",
        optimizer_buffer_count: int = 4,
        pin_memory: bool = False,
        block_size: int = 1048576,
        queue_depth: int = 8,
        single_submit: bool = False,
        overlap_events: bool = True,
        thread_count: int = 1,
        stage: int = 2,
        contiguous_gradients: bool = True,
        overlap_comm: bool = True,
        allgather_partitions: bool = True,
        reduce_scatter: bool = True,
        allgather_bucket_size: int = 200_000_000,
        reduce_bucket_size: int = 200_000_000,
        sub_group_size: int = 1_000_000_000_000,
) -> Dict:
    cfg = {
        "activation_checkpointing": {
            "partition_activations": partition_activations,
            "cpu_checkpointing": cpu_checkpointing,
            "contiguous_memory_optimization": contiguous_memory_optimization,
            "synchronize_checkpoint_boundary": synchronize_checkpoint_boundary,
        },
        "aio": {
            "block_size": block_size,
            "queue_depth": queue_depth,
            "single_submit": single_submit,
            "overlap_events": overlap_events,
            "thread_count": thread_count,
        },
    }
    zero_kwargs = {
        "stage": stage,
        "contiguous_gradients": contiguous_gradients,
        "overlap_comm": overlap_comm,
        "allgather_partitions": allgather_partitions,
        "reduce_scatter": reduce_scatter,
        "allgather_bucket_size": allgather_bucket_size,
        "reduce_bucket_size": reduce_bucket_size,
        "sub_group_size": sub_group_size,
    }
    if zero_optimization:
        zero_config = zero_kwargs

        if offload_optimizer:
            zero_config["offload_optimizer"] = {
                "device": offload_optimizer_device,
                "nvme_path": nvme_path,
                "buffer_count": optimizer_buffer_count,
                "pin_memory": pin_memory,
            }
        if offload_parameters:
            zero_config["offload_param"] = {
                "device": offload_params_device,
                "nvme_path": nvme_path,
                "buffer_count": params_buffer_count,
                "buffer_size": params_buffer_size,
                "max_in_cpu": max_in_cpu,
                "pin_memory": pin_memory,
            }
        cfg = {
            "zero_allow_untested_optimizer": zero_allow_untested_optimizer,
            "zero_optimization": zero_config,
            **cfg,
        }
    if logging_batch_size_per_gpu != "auto":
        cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg}
    return cfg
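
To make the strategy-to-config mapping concrete, here is a small sketch of what the helper returns for the stage-2 optimizer-offload case; the keys follow directly from the code above.

from fastNLP.core.drivers.torch_driver.utils import _create_default_config

cfg = _create_default_config(stage=2, offload_optimizer=True)
# The ZeRO section carries the stage and the optimizer-offload target:
assert cfg["zero_optimization"]["stage"] == 2
assert cfg["zero_optimization"]["offload_optimizer"]["device"] == "cpu"
# The activation-checkpointing and aio sections are always emitted:
assert cfg["activation_checkpointing"]["cpu_checkpointing"] is False
assert cfg["aio"]["block_size"] == 1048576
# The default "auto" batch size keeps train_micro_batch_size_per_gpu out of the config:
assert "train_micro_batch_size_per_gpu" not in cfg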

+ 1 - 0 fastNLP/envs/imports.py

@@ -22,5 +22,6 @@ _NEED_IMPORT_FAIRSCALE = not _IS_WINDOWS and _module_available("fairscale") and
_NEED_IMPORT_TORCH = _module_available("torch") and 'torch' in need_import
_NEED_IMPORT_JITTOR = _module_available("jittor") and 'jittor' in need_import
_NEED_IMPORT_PADDLE = _module_available("paddle") and 'paddle' in need_import
+_NEED_IMPORT_DEEPSPEED = _module_available("deepspeed") and 'deepspeed' in need_import


_TORCH_GREATER_EQUAL_1_8 = _NEED_IMPORT_TORCH and _compare_version("torch", operator.ge, "1.8.0")
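
The new flag follows the same lazy-import convention as the existing torch/jittor/paddle flags: call sites guard the actual import behind it, as the new deepspeed.py above does. A minimal sketch of the pattern:

from fastNLP.envs.imports import _NEED_IMPORT_DEEPSPEED

if _NEED_IMPORT_DEEPSPEED:
    # Only imported when deepspeed is installed and selected as a backend.
    import deepspeed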
