|
|
|
@@ -1,21 +1,31 @@ |
|
|
|
import os |
|
|
|
import random |
|
|
|
from typing import Union, Optional, Callable, Dict |
|
|
|
from typing import Union, Optional, Dict |
|
|
|
from pathlib import Path |
|
|
|
from functools import partial |
|
|
|
from dataclasses import dataclass |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from .utils import _build_fp16_env |
|
|
|
from .utils import _build_fp16_env, optimizer_state_to_device |
|
|
|
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE |
|
|
|
from fastNLP.core.drivers.driver import Driver |
|
|
|
from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device |
|
|
|
from fastNLP.envs import rank_zero_call |
|
|
|
from fastNLP.envs import FASTNLP_SEED_WORKERS |
|
|
|
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME |
|
|
|
from fastNLP.core.log import logger |
|
|
|
from fastNLP.core.samplers import ReproducibleBatchSampler |
|
|
|
|
|
|
|
if _NEED_IMPORT_PADDLE: |
|
|
|
import paddle |
|
|
|
from paddle.io import DataLoader, IterableDataset |
|
|
|
from paddle.io import ( |
|
|
|
DataLoader, |
|
|
|
IterableDataset, |
|
|
|
Dataset, |
|
|
|
Sampler, |
|
|
|
BatchSampler, |
|
|
|
RandomSampler, |
|
|
|
) |
|
|
|
from paddle.optimizer import Optimizer |
|
|
|
|
|
|
|
_reduces = { |
|
|
|
@@ -69,6 +79,8 @@ class PaddleDriver(Driver): |
|
|
|
# TODO 我们先禁止 dataloader 的 dataset 是 IterableDataset 种类; |
|
|
|
if isinstance(dataloader.dataset, IterableDataset): |
|
|
|
raise TypeError("`IterableDataset` is not allowed.") |
|
|
|
if dataloader.batch_sampler is None and dataloader.batch_size is None: |
|
|
|
raise ValueError(f"At least one of `{dataloader_name}`'s `batch_sampler` and `batch_size` should be set.") |
|
|
|
else: |
|
|
|
if not isinstance(dataloader, Dict): |
|
|
|
raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.") |
|
|
|
@@ -79,6 +91,9 @@ class PaddleDriver(Driver): |
|
|
|
f"type, not {type(each_dataloader)}.") |
|
|
|
if isinstance(each_dataloader.dataset, IterableDataset): |
|
|
|
raise TypeError("`IterableDataset` is not allowed.") |
|
|
|
if dataloader.batch_sampler is None and dataloader.batch_size is None: |
|
|
|
raise ValueError(f"For each dataloader of parameter `{dataloader_name}`, at least one of " |
|
|
|
f"`batch_sampler` and `batch_size` should be set.") |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _check_optimizer_legality(optimizers): |
|
|
|
@@ -153,45 +168,53 @@ class PaddleDriver(Driver): |
|
|
|
getattr(self.model, mode)() |
|
|
|
|
|
|
|
@rank_zero_call |
|
|
|
def save_model(self, filepath: str, only_state_dict: bool = True, model_save_fn: Optional[Callable]=None, **kwargs): |
|
|
|
def save_model(self, filepath: str, only_state_dict: bool = True, **kwargs): |
|
|
|
r""" |
|
|
|
保存模型的函数;注意函数 `save` 是用来进行断点重训的函数; |
|
|
|
如果 `model_save_fn` 是一个可调用的函数,那么我们会直接运行该函数; |
|
|
|
|
|
|
|
:param filepath: 保存文件的文件位置(需要包括文件名); |
|
|
|
:param only_state_dict: 是否只保存模型的 `state_dict`;注意该参数仅当 `model_save_fn` 为 None 时有效; |
|
|
|
:param model_save_fn: 用户传入的用来代替该函数本身保存逻辑的函数;如果该参数不为 None,那么我们会调用 model_save_fn(path); |
|
|
|
:param only_state_dict: 是否只保存模型的 `state_dict`; |
|
|
|
:param kwargs: |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if model_save_fn is not None: |
|
|
|
model_save_fn(filepath) |
|
|
|
model = self.unwrap_model() |
|
|
|
|
|
|
|
if only_state_dict: |
|
|
|
states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} |
|
|
|
paddle.save(states, filepath) |
|
|
|
else: |
|
|
|
model = self.unwrap_model() |
|
|
|
if only_state_dict: |
|
|
|
paddle.save(model.state_dict(), filepath) |
|
|
|
# paddle 在保存整个模型时需要传入额外参数 |
|
|
|
input_spec = kwargs.get("input_spec", None) |
|
|
|
if input_spec is None: |
|
|
|
raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.") |
|
|
|
if self.model_device is not None: |
|
|
|
if not self.is_distributed(): |
|
|
|
self.move_model_to_device(model, "cpu") |
|
|
|
paddle.jit.save(model, filepath, input_spec) |
|
|
|
if not self.is_distributed(): |
|
|
|
self.move_model_to_device(model, self.model_device) |
|
|
|
else: |
|
|
|
input_spec = kwargs.get("input_spec", None) |
|
|
|
if input_spec is None: |
|
|
|
raise Exception("To save the whole Paddle Layer, parameter 'input_spec' is needed.") |
|
|
|
paddle.jit.save(model, filepath, input_spec) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
@rank_zero_call |
|
|
|
def load_model(filepath: str, load_dict: bool = True): |
|
|
|
def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs): |
|
|
|
r""" |
|
|
|
加载模型的函数;注意函数 `load` 是用来进行断点重训的函数; |
|
|
|
|
|
|
|
:param filepath: 需要被加载的对象的文件位置(需要包括文件名); |
|
|
|
:param load_dict: 是否加载state_dict,默认为True。当用户在save_model时将only_state_dict设置为False时, |
|
|
|
即保存了整个模型时,这个参数必须也为False |
|
|
|
:return: 返回加载指定文件后的结果; |
|
|
|
:param kwargs: |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if load_dict: |
|
|
|
return paddle.load(filepath) |
|
|
|
model = self.unwrap_model() |
|
|
|
if only_state_dict: |
|
|
|
model.load_dict(paddle.load(filepath)) |
|
|
|
else: |
|
|
|
return paddle.jit.load(filepath) |
|
|
|
model.load_dict(paddle.jit.load(filepath).state_dict()) |
|
|
|
|
|
|
|
@rank_zero_call |
|
|
|
def save(self, folder, states: Dict): |
|
|
|
def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): |
|
|
|
r""" |
|
|
|
断点重训的保存函数,该函数会负责保存模型和 optimizers 的 state_dict; |
|
|
|
需要注意 driver 应当是无状态的,即不管什么时候调用 driver 的接口函数,其返回的结果应该都是一样的;因此,断点重训不需要保存 driver |
|
|
|
@@ -203,48 +226,110 @@ class PaddleDriver(Driver): |
|
|
|
:param states: 由 trainer 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态,Driver 应该只需要保存 |
|
|
|
该对象即可, Driver 应该不需要理解该对象,同时在 driver.load() 的时候,需要将 states 返回回去,load()返回的值与这里的 |
|
|
|
传入的值保持一致。 |
|
|
|
:param dataloader: 正在使用的 dataloader,需要保存里面的状态使得之后可以从当前迭代的位置恢复。 |
|
|
|
:param only_state_dict: 是否只保存模型的参数,当 should_save_model 为 False ,该参数无效。 |
|
|
|
:param should_save_model: 是否应该保存模型,如果为False,Driver 将不负责 model 的保存。 |
|
|
|
:return: |
|
|
|
""" |
|
|
|
# 1. 保存模型的状态; |
|
|
|
model = self.unwrap_model() |
|
|
|
model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} |
|
|
|
# 对于单卡的 driver 来讲,我们实际上(现在)不应该考虑用户在DDP环境下使用单卡模式,从而造成效率损失; |
|
|
|
states["model_state_dict"] = model_state_dict |
|
|
|
# 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变 |
|
|
|
# trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; |
|
|
|
|
|
|
|
# 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch; |
|
|
|
# paddle 的 DataLoader 在初始化之后 batch_sampler 可能为 None,也可能为用户设置的 batch_sampler |
|
|
|
dataloader_args = self.get_dataloader_args(dataloader) |
|
|
|
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): |
|
|
|
sampler = dataloader_args.batch_sampler |
|
|
|
elif dataloader_args.sampler: |
|
|
|
sampler = dataloader_args.sampler |
|
|
|
else: |
|
|
|
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") |
|
|
|
|
|
|
|
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): |
|
|
|
states['sampler_states'] = sampler.state_dict() |
|
|
|
else: |
|
|
|
raise RuntimeError( |
|
|
|
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') |
|
|
|
|
|
|
|
# 2. 保存 optimizers 的状态; |
|
|
|
# 2. 保存模型的状态; |
|
|
|
if should_save_model: |
|
|
|
model = self.unwrap_model() |
|
|
|
if only_state_dict: |
|
|
|
model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} |
|
|
|
paddle.save(model_state_dict, folder.joinpath(FASTNLP_MODEL_FILENAME)) |
|
|
|
logger.debug("Save model state dict") |
|
|
|
else: |
|
|
|
input_spec = kwargs.get("input_spec", None) |
|
|
|
if input_spec is None: |
|
|
|
raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.") |
|
|
|
paddle.jit.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME), input_spec) |
|
|
|
logger.debug("Save model") |
|
|
|
|
|
|
|
# 3. 保存 optimizers 的状态; |
|
|
|
optimizers_state_dict = {} |
|
|
|
for i in range(len(self.optimizers)): |
|
|
|
optimizer: Optimizer = self.optimizers[i] |
|
|
|
optimizer_state = optimizer.state_dict() |
|
|
|
optimizer_state = {name: param.cpu().detach().clone() for name, param in optimizer_state.items()} |
|
|
|
optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu") |
|
|
|
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; |
|
|
|
states["optimizers_state_dict"] = optimizers_state_dict |
|
|
|
|
|
|
|
paddle.save(states, folder) |
|
|
|
|
|
|
|
def load(self, filepath) -> Dict: |
|
|
|
r""" |
|
|
|
断点重训的加载函数,注意该函数会负责读取数据,并且恢复模型和 optimizers 的 state_dict 等; |
|
|
|
driver 实例需要在该函数中先加载模型和 optimizers 的 state_dict,然后将一个 state 字典返回给 trainer 。 |
|
|
|
因此 save 函数和 load 函数的接受和返回值应该是对应的; |
|
|
|
|
|
|
|
该函数需要在所有 rank 上执行。 |
|
|
|
logger.debug("Save optimizer state dict") |
|
|
|
states["optimizers_state_dict"] = optimizers_state_dict |
|
|
|
paddle.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) |
|
|
|
|
|
|
|
:param filepath: 保存断点重训的状态的文件名; |
|
|
|
:return: 需要返回 save 函数输入的 states 内容; |
|
|
|
""" |
|
|
|
states = paddle.load(filepath) |
|
|
|
def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: |
|
|
|
|
|
|
|
states = paddle.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) |
|
|
|
|
|
|
|
# 1. 加载 optimizers 的状态; |
|
|
|
optimizers_state_dict = states["optimizers_state_dict"] |
|
|
|
for i in range(len(self.optimizers)): |
|
|
|
optimizer: paddle.optimizer.Optimizer = self.optimizers[i] |
|
|
|
optimizer: Optimizer = self.optimizers[i] |
|
|
|
optimizer.set_state_dict(optimizers_state_dict[f"optimizer{i}"]) |
|
|
|
logger.debug("Load optimizer state dict.") |
|
|
|
|
|
|
|
# 2. 加载模型状态; |
|
|
|
model = self.unwrap_model() |
|
|
|
model.load_dict(states["model_state_dict"]) |
|
|
|
if should_load_model: |
|
|
|
model = self.unwrap_model() |
|
|
|
if only_state_dict: |
|
|
|
res = paddle.load(folder.joinpath(FASTNLP_MODEL_FILENAME)) |
|
|
|
model.load_dict(res) |
|
|
|
logger.debug("Load model state dict.") |
|
|
|
else: |
|
|
|
model.load_dict(paddle.jit.load(folder.joinpath(FASTNLP_MODEL_FILENAME)).state_dict()) |
|
|
|
logger.debug("Load model.") |
|
|
|
|
|
|
|
# 3. 恢复 sampler 的状态; |
|
|
|
dataloader_args = self.get_dataloader_args(dataloader) |
|
|
|
sampler = dataloader_args.sampler |
|
|
|
if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)): |
|
|
|
# 说明这里需要使用 ReproduceSampler 来弄一下了 |
|
|
|
if self.is_distributed(): |
|
|
|
raise RuntimeError( |
|
|
|
"It is not allowed to use single device checkpoint retraining before but ddp now.") |
|
|
|
sampler = ReproducibleBatchSampler( |
|
|
|
batch_sampler=sampler, |
|
|
|
batch_size=dataloader_args.batch_sampler.batch_size, |
|
|
|
drop_last=dataloader_args.drop_last |
|
|
|
) |
|
|
|
sampler.load_state_dict(states['sampler_states']) |
|
|
|
|
|
|
|
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) |
|
|
|
|
|
|
|
# 4. 修改 trainer_state.batch_idx_in_epoch |
|
|
|
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; |
|
|
|
if not isinstance(sampler, ReproducibleBatchSampler): |
|
|
|
if dataloader_args.drop_last: |
|
|
|
batch_idx_in_epoch = len( |
|
|
|
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size |
|
|
|
else: |
|
|
|
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ |
|
|
|
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size |
|
|
|
# sampler 是 batch_sampler; |
|
|
|
else: |
|
|
|
batch_idx_in_epoch = sampler.batch_idx_in_epoch |
|
|
|
|
|
|
|
states["batch_idx_in_epoch"] = batch_idx_in_epoch |
|
|
|
|
|
|
|
self.barrier() |
|
|
|
return states |
|
|
|
|
|
|
|
def get_evaluate_context(self): |
|
|
|
@@ -313,3 +398,53 @@ class PaddleDriver(Driver): |
|
|
|
""" |
|
|
|
if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): |
|
|
|
dataloader.batch_sampler.set_epoch(cur_epoch_idx) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def get_dataloader_args(dataloader: "DataLoader"): |
|
|
|
""" |
|
|
|
获取 dataloader 的 shuffle 和 drop_last 属性; |
|
|
|
""" |
|
|
|
|
|
|
|
@dataclass |
|
|
|
class Res: |
|
|
|
dataset: Optional[Dataset] = None |
|
|
|
batch_sampler: Optional[BatchSampler] = None |
|
|
|
sampler: Optional[Sampler] = None |
|
|
|
batch_size: Optional[int] = None |
|
|
|
shuffle: Optional[bool] = None |
|
|
|
drop_last: Optional[bool] = None |
|
|
|
|
|
|
|
res = Res() |
|
|
|
|
|
|
|
# paddle 的 DataLoader 一定会有 dataset 属性; |
|
|
|
res.dataset = dataloader.dataset |
|
|
|
|
|
|
|
if dataloader.batch_sampler is not None: |
|
|
|
res.batch_sampler = dataloader.batch_sampler |
|
|
|
if hasattr(dataloader.batch_sampler, "batch_size"): |
|
|
|
res.batch_size = getattr(dataloader.batch_sampler, "batch_size") |
|
|
|
# 用户使用的是自己的 batch_sampler 并且其没有 "batch_size" 属性; |
|
|
|
else: |
|
|
|
dataloader_iter = iter(dataloader) |
|
|
|
pre_sample = next(dataloader_iter) |
|
|
|
res.batch_size = pre_sample.shape[0] |
|
|
|
|
|
|
|
if hasattr(dataloader.batch_sampler, "sampler"): |
|
|
|
res.sampler = dataloader.batch_sampler.sampler |
|
|
|
if hasattr(dataloader.batch_sampler.sampler, "shuffle"): |
|
|
|
res.shuffle = dataloader.batch_sampler.sampler.shuffle |
|
|
|
elif isinstance(dataloader.batch_sampler.sampler, RandomSampler): |
|
|
|
res.shuffle = True |
|
|
|
else: |
|
|
|
res.shuffle = False |
|
|
|
else: |
|
|
|
res.sampler = None |
|
|
|
res.shuffle = False |
|
|
|
|
|
|
|
if hasattr(dataloader.batch_sampler, "drop_last"): |
|
|
|
res.drop_last = getattr(dataloader.batch_sampler, "drop_last") |
|
|
|
# 用户使用的是自己的 batch_sampler 并且其没有 "drop_last" 属性; |
|
|
|
else: |
|
|
|
res.drop_last = False |
|
|
|
|
|
|
|
return res |