| @@ -1,3 +1,3 @@ | |||||
| from fastNLP.envs import * | from fastNLP.envs import * | ||||
| from fastNLP.core import Trainer, Evaluator | |||||
| from fastNLP.core import * | |||||
| @@ -57,9 +57,37 @@ __all__ = [ | |||||
| "TorchPaddleDriver", | "TorchPaddleDriver", | ||||
| # log | # log | ||||
| "logger" | |||||
| "logger", | |||||
| "print", | |||||
| # | |||||
| # metrics | |||||
| "Metric", | |||||
| "Accuracy", | |||||
| 'SpanFPreRecMetric', | |||||
| 'ClassifyFPreRecMetric', | |||||
| # samplers | |||||
| 'ReproducibleSampler', | |||||
| 'RandomSampler', | |||||
| "SequentialSampler", | |||||
| "SortedSampler", | |||||
| 'UnrepeatedSampler', | |||||
| 'UnrepeatedRandomSampler', | |||||
| "UnrepeatedSortedSampler", | |||||
| "UnrepeatedSequentialSampler", | |||||
| "ReproduceBatchSampler", | |||||
| "BucketedBatchSampler", | |||||
| "ReproducibleBatchSampler", | |||||
| "RandomBatchSampler", | |||||
| # utils | |||||
| "cache_results", | |||||
| "f_rich_progress", | |||||
| "auto_param_call", | |||||
| "seq_len_to_mask", | |||||
| # vocabulary.py | |||||
| 'Vocabulary' | |||||
| ] | ] | ||||
| from .callbacks import * | from .callbacks import * | ||||
| from .collators import * | from .collators import * | ||||
| @@ -68,4 +96,7 @@ from .dataloaders import * | |||||
| from .dataset import * | from .dataset import * | ||||
| from .drivers import * | from .drivers import * | ||||
| from .log import * | from .log import * | ||||
| from .utils import * | |||||
| from .metrics import * | |||||
| from .samplers import * | |||||
| from .utils import * | |||||
| from .vocabulary import Vocabulary | |||||
| @@ -7,7 +7,7 @@ from copy import deepcopy | |||||
| from pathlib import Path | from pathlib import Path | ||||
| from typing import Optional, Dict, Tuple, Callable, Union | from typing import Optional, Dict, Tuple, Callable, Union | ||||
| from fastNLP.core.utils import rank_zero_rm | |||||
| from ...envs.distributed import rank_zero_rm | |||||
| from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
| from fastNLP.envs import FASTNLP_LAUNCH_TIME | from fastNLP.envs import FASTNLP_LAUNCH_TIME | ||||
| from fastNLP.envs import rank_zero_call | from fastNLP.envs import rank_zero_call | ||||
| @@ -8,6 +8,7 @@ __all__ = [ | |||||
| "NullPadder", | "NullPadder", | ||||
| "RawNumberPadder", | "RawNumberPadder", | ||||
| "RawSequencePadder", | "RawSequencePadder", | ||||
| "RawTensorPadder", | |||||
| 'TorchNumberPadder', | 'TorchNumberPadder', | ||||
| 'TorchSequencePadder', | 'TorchSequencePadder', | ||||
| 'TorchTensorPadder', | 'TorchTensorPadder', | ||||
| @@ -67,7 +67,7 @@ def _get_backend() -> str: | |||||
| # 方式 (2) | # 方式 (2) | ||||
| for backend in CHECK_BACKEND: | for backend in CHECK_BACKEND: | ||||
| if backend in sys.modules: | if backend in sys.modules: | ||||
| logger.debug(f"sys.modules contains backend:{catch_backend[0]}.") | |||||
| logger.debug(f"sys.modules contains backend:{backend}.") | |||||
| return backend | return backend | ||||
| for key, module in sys.modules.items(): | for key, module in sys.modules.items(): | ||||
| catch_backend = _check_module(module) | catch_backend = _check_module(module) | ||||
| @@ -9,6 +9,7 @@ __all__ = [ | |||||
| "RawNumberPadder", | "RawNumberPadder", | ||||
| "RawSequencePadder", | "RawSequencePadder", | ||||
| "RawTensorPadder", | |||||
| 'TorchNumberPadder', | 'TorchNumberPadder', | ||||
| 'TorchSequencePadder', | 'TorchSequencePadder', | ||||
| @@ -79,7 +79,7 @@ class NumpyTensorPadder(Padder): | |||||
| def pad(batch_field, pad_val, dtype): | def pad(batch_field, pad_val, dtype): | ||||
| try: | try: | ||||
| if not isinstance(batch_field[0], np.ndarray): | if not isinstance(batch_field[0], np.ndarray): | ||||
| batch_field = [np.array(field.tolist()) for field in batch_field] | |||||
| batch_field = [np.array(field.tolist(), dtype=dtype) for field in batch_field] | |||||
| except AttributeError: | except AttributeError: | ||||
| raise RuntimeError(f"If the field is not a np.ndarray (it is {type(batch_field[0])}), " | raise RuntimeError(f"If the field is not a np.ndarray (it is {type(batch_field[0])}), " | ||||
| f"it must have tolist() method.") | f"it must have tolist() method.") | ||||
| @@ -131,7 +131,7 @@ class PaddleTensorPadder(Padder): | |||||
| def pad(batch_field, pad_val, dtype): | def pad(batch_field, pad_val, dtype): | ||||
| try: | try: | ||||
| if not isinstance(batch_field[0], paddle.Tensor): | if not isinstance(batch_field[0], paddle.Tensor): | ||||
| batch_field = [paddle.to_tensor(field.tolist()) for field in batch_field] | |||||
| batch_field = [paddle.to_tensor(field.tolist(), dtype=dtype) for field in batch_field] | |||||
| except AttributeError: | except AttributeError: | ||||
| raise RuntimeError(f"If the field is not a paddle.Tensor (it is {type(batch_field[0])}), " | raise RuntimeError(f"If the field is not a paddle.Tensor (it is {type(batch_field[0])}), " | ||||
| f"it must have tolist() method.") | f"it must have tolist() method.") | ||||
| @@ -143,8 +143,6 @@ class PaddleTensorPadder(Padder): | |||||
| tensor = paddle.full(max_shape, fill_value=pad_val, dtype=dtype) | tensor = paddle.full(max_shape, fill_value=pad_val, dtype=dtype) | ||||
| for i, field in enumerate(batch_field): | for i, field in enumerate(batch_field): | ||||
| slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | ||||
| if isinstance(field, np.ndarray): | |||||
| field = paddle.to_tensor(field) | |||||
| tensor[slices] = field | tensor[slices] = field | ||||
| return tensor | return tensor | ||||
| @@ -114,7 +114,7 @@ class TorchTensorPadder(Padder): | |||||
| def pad(batch_field, pad_val, dtype): | def pad(batch_field, pad_val, dtype): | ||||
| try: | try: | ||||
| if not isinstance(batch_field[0], torch.Tensor): | if not isinstance(batch_field[0], torch.Tensor): | ||||
| batch_field = [torch.tensor(field.tolist()) for field in batch_field] | |||||
| batch_field = [torch.tensor(field.tolist(), dtype=dtype) for field in batch_field] | |||||
| except AttributeError: | except AttributeError: | ||||
| raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), " | raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), " | ||||
| f"it must have tolist() method.") | f"it must have tolist() method.") | ||||
| @@ -124,8 +124,6 @@ class TorchTensorPadder(Padder): | |||||
| tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype) | tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype) | ||||
| for i, field in enumerate(batch_field): | for i, field in enumerate(batch_field): | ||||
| slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | ||||
| if isinstance(field, np.ndarray): | |||||
| field = torch.from_numpy(field) | |||||
| tensor[slices] = field | tensor[slices] = field | ||||
| return tensor | return tensor | ||||
| @@ -16,8 +16,10 @@ from fastNLP.core.utils import ( | |||||
| auto_param_call, | auto_param_call, | ||||
| check_user_specific_params, | check_user_specific_params, | ||||
| is_in_paddle_dist, | is_in_paddle_dist, | ||||
| rank_zero_rm | |||||
| rank_zero_rm, | |||||
| is_in_paddle_dist, | |||||
| ) | ) | ||||
| from fastNLP.envs.distributed import rank_zero_rm | |||||
| from fastNLP.core.samplers import ( | from fastNLP.core.samplers import ( | ||||
| ReproduceBatchSampler, | ReproduceBatchSampler, | ||||
| ReproducibleSampler, | ReproducibleSampler, | ||||
| @@ -1,6 +1,8 @@ | |||||
| __all__ = [ | __all__ = [ | ||||
| 'logger' | |||||
| 'logger', | |||||
| "print" | |||||
| ] | ] | ||||
| from .logger import logger | from .logger import logger | ||||
| from .print import print | |||||
| @@ -1,16 +1,11 @@ | |||||
| __all__ = [ | __all__ = [ | ||||
| "Metric", | "Metric", | ||||
| "Accuracy", | "Accuracy", | ||||
| 'Backend', | |||||
| 'AutoBackend', | |||||
| 'PaddleBackend', | |||||
| 'TorchBackend', | |||||
| 'SpanFPreRecMetric', | 'SpanFPreRecMetric', | ||||
| 'ClassifyFPreRecMetric', | 'ClassifyFPreRecMetric', | ||||
| ] | ] | ||||
| from .metric import Metric | from .metric import Metric | ||||
| from .accuracy import Accuracy | from .accuracy import Accuracy | ||||
| from .backend import Backend, AutoBackend, PaddleBackend, TorchBackend | |||||
| from .span_f1_pre_rec_metric import SpanFPreRecMetric | from .span_f1_pre_rec_metric import SpanFPreRecMetric | ||||
| from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric | from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric | ||||
| @@ -24,8 +24,6 @@ __all__ = [ | |||||
| 'Option', | 'Option', | ||||
| 'deprecated', | 'deprecated', | ||||
| 'seq_len_to_mask', | 'seq_len_to_mask', | ||||
| 'rank_zero_rm', | |||||
| 'rank_zero_mkdir' | |||||
| ] | ] | ||||
| from .cache_results import cache_results | from .cache_results import cache_results | ||||
| @@ -37,7 +35,6 @@ from .torch_paddle_utils import torch_paddle_move_data_to_device | |||||
| from .torch_utils import torch_move_data_to_device | from .torch_utils import torch_move_data_to_device | ||||
| from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ | from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ | ||||
| dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ | dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ | ||||
| deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir | |||||
| from ..dataloaders.utils import indice_collate_wrapper | |||||
| deprecated, seq_len_to_mask | |||||
| @@ -22,8 +22,6 @@ import numpy as np | |||||
| from pathlib import Path | from pathlib import Path | ||||
| from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
| from fastNLP.envs import FASTNLP_GLOBAL_RANK | |||||
| __all__ = [ | __all__ = [ | ||||
| 'get_fn_arg_names', | 'get_fn_arg_names', | ||||
| @@ -37,8 +35,6 @@ __all__ = [ | |||||
| 'Option', | 'Option', | ||||
| 'deprecated', | 'deprecated', | ||||
| 'seq_len_to_mask', | 'seq_len_to_mask', | ||||
| 'rank_zero_rm', | |||||
| 'rank_zero_mkdir' | |||||
| ] | ] | ||||
| @@ -609,54 +605,6 @@ def wait_filepath(path, exist=True): | |||||
| logger.warning(f"Waiting path:{path} to {msg} for {count*0.01} seconds...") | logger.warning(f"Waiting path:{path} to {msg} for {count*0.01} seconds...") | ||||
| def rank_zero_rm(path: Optional[Union[str, Path]]): | |||||
| """ | |||||
| 这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候 | |||||
| 在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件; | |||||
| 该函数会保证所有进程都检测到 path 删除之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。 | |||||
| :param path: | |||||
| :return: | |||||
| """ | |||||
| if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||||
| if path is None: | |||||
| return | |||||
| if isinstance(path, str): | |||||
| path = Path(path) | |||||
| if not path.exists(): | |||||
| return | |||||
| _recursive_rm(path) | |||||
| def _recursive_rm(path: Path): | |||||
| if path.is_file() or path.is_symlink(): | |||||
| if path.exists(): | |||||
| try: | |||||
| path.unlink() | |||||
| except Exception: | |||||
| pass | |||||
| return | |||||
| for sub_path in list(path.iterdir()): | |||||
| _recursive_rm(sub_path) | |||||
| path.rmdir() | |||||
| def rank_zero_mkdir(path: Optional[Union[str, Path]]): | |||||
| """ | |||||
| 注意该函数是用来创建文件夹,如果需要创建一个文件,不要使用该函数; | |||||
| 该函数会保证所有进程都检测到 path 创建之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。 | |||||
| """ | |||||
| if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||||
| if path is None: | |||||
| return | |||||
| if isinstance(path, str): | |||||
| path = Path(path) | |||||
| path.mkdir(parents=True, exist_ok=True) | |||||
| def get_class_that_defined_method(method): | def get_class_that_defined_method(method): | ||||
| """ | """ | ||||
| 给定一个method,返回这个 method 的 class 的对象 | 给定一个method,返回这个 method 的 class 的对象 | ||||
| @@ -3,12 +3,17 @@ r""" | |||||
| """ | """ | ||||
| __all__ = [ | __all__ = [ | ||||
| 'dump_fastnlp_backend', | 'dump_fastnlp_backend', | ||||
| 'is_cur_env_distributed', | |||||
| 'get_global_rank', | |||||
| # utils | |||||
| 'get_gpu_count', | |||||
| # distributed | |||||
| "rank_zero_rm", | |||||
| 'rank_zero_call', | 'rank_zero_call', | ||||
| 'get_global_rank', | |||||
| 'fastnlp_no_sync_context', | |||||
| 'all_rank_call_context', | 'all_rank_call_context', | ||||
| 'get_gpu_count', | |||||
| 'fastnlp_no_sync_context' | |||||
| 'is_cur_env_distributed', | |||||
| ] | ] | ||||
| @@ -1,6 +1,7 @@ | |||||
| import os | import os | ||||
| from functools import wraps | from functools import wraps | ||||
| from typing import Callable, Any, Optional | |||||
| from pathlib import Path | |||||
| from typing import Callable, Any, Optional, Union | |||||
| from contextlib import contextmanager | from contextlib import contextmanager | ||||
| __all__ = [ | __all__ = [ | ||||
| @@ -8,7 +9,8 @@ __all__ = [ | |||||
| 'get_global_rank', | 'get_global_rank', | ||||
| 'rank_zero_call', | 'rank_zero_call', | ||||
| 'all_rank_call_context', | 'all_rank_call_context', | ||||
| 'fastnlp_no_sync_context' | |||||
| 'fastnlp_no_sync_context', | |||||
| "rank_zero_rm" | |||||
| ] | ] | ||||
| from fastNLP.envs.env import FASTNLP_GLOBAL_RANK, FASTNLP_NO_SYNC | from fastNLP.envs.env import FASTNLP_GLOBAL_RANK, FASTNLP_NO_SYNC | ||||
| @@ -96,3 +98,35 @@ def all_rank_call_context(): | |||||
| os.environ[FASTNLP_GLOBAL_RANK] = old_fastnlp_global_rank | os.environ[FASTNLP_GLOBAL_RANK] = old_fastnlp_global_rank | ||||
| else: | else: | ||||
| os.environ.pop(FASTNLP_GLOBAL_RANK) | os.environ.pop(FASTNLP_GLOBAL_RANK) | ||||
| def rank_zero_rm(path: Optional[Union[str, Path]]): | |||||
| """ | |||||
| 这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候 | |||||
| 在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件; | |||||
| 该函数会保证所有进程都检测到 path 删除之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。 | |||||
| :param path: | |||||
| :return: | |||||
| """ | |||||
| if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||||
| if path is None: | |||||
| return | |||||
| if isinstance(path, str): | |||||
| path = Path(path) | |||||
| if not path.exists(): | |||||
| return | |||||
| _recursive_rm(path) | |||||
| def _recursive_rm(path: Path): | |||||
| if path.is_file() or path.is_symlink(): | |||||
| if path.exists(): | |||||
| try: | |||||
| path.unlink() | |||||
| except Exception: | |||||
| pass | |||||
| return | |||||
| for sub_path in list(path.iterdir()): | |||||
| _recursive_rm(sub_path) | |||||
| path.rmdir() | |||||
| @@ -22,7 +22,7 @@ FASTNLP_GLOBAL_RANK = "FASTNLP_GLOBAL_RANK" | |||||
| FASTNLP_LOG_LEVEL = "FASTNLP_LOG_LEVEL" | FASTNLP_LOG_LEVEL = "FASTNLP_LOG_LEVEL" | ||||
| # todo 每一个分布式的 driver 都应当正确地设立该值;具体可见 ddp; | |||||
| # 每一个分布式的 driver 都应当正确地设立该值;具体可见 ddp; | |||||
| # FASTNLP_LAUNCH_TIME 记录了当前 fastNLP 脚本启动的时间。 | # FASTNLP_LAUNCH_TIME 记录了当前 fastNLP 脚本启动的时间。 | ||||
| FASTNLP_LAUNCH_TIME = "FASTNLP_LAUNCH_TIME" | FASTNLP_LAUNCH_TIME = "FASTNLP_LAUNCH_TIME" | ||||
| @@ -42,7 +42,7 @@ USER_CUDA_VISIBLE_DEVICES = 'USER_CUDA_VISIBLE_DEVICES' | |||||
| # 用于在 torch.distributed.launch 时移除传入的 rank ,在 pytorch 中有使用。值的可选为 [0, 1] | # 用于在 torch.distributed.launch 时移除传入的 rank ,在 pytorch 中有使用。值的可选为 [0, 1] | ||||
| FASTNLP_REMOVE_LOCAL_RANK = 'FASTNLP_REMOVE_LOCAL_RANK' | FASTNLP_REMOVE_LOCAL_RANK = 'FASTNLP_REMOVE_LOCAL_RANK' | ||||
| # todo 注释 | |||||
| # 检测到当前脚本是通过类似 python -m torch.launch 启动的话设置这个变量为1 | |||||
| FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" | FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" | ||||
| # fastNLP中用于关闭 fastNLP 1.barrier 与 2.gather/broadcast 。默认为 '0' 表示不关闭;为 '1' 表示 fastNLP 的 barrier 不执行; | # fastNLP中用于关闭 fastNLP 1.barrier 与 2.gather/broadcast 。默认为 '0' 表示不关闭;为 '1' 表示 fastNLP 的 barrier 不执行; | ||||
| @@ -11,7 +11,7 @@ from fastNLP.core.controllers.trainer import Trainer | |||||
| from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | ||||
| from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
| from fastNLP.core import rank_zero_rm | |||||
| from fastNLP.envs.distributed import rank_zero_rm | |||||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
| from tests.helpers.datasets.torch_data import TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchArgMaxDataset | ||||
| from torchmetrics import Accuracy | from torchmetrics import Accuracy | ||||
| @@ -20,7 +20,7 @@ from fastNLP.core.controllers.trainer import Trainer | |||||
| from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | ||||
| from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
| from fastNLP.core import rank_zero_rm | |||||
| from fastNLP.envs.distributed import rank_zero_rm | |||||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
| from tests.helpers.datasets.torch_data import TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchArgMaxDataset | ||||
| from torchmetrics import Accuracy | from torchmetrics import Accuracy | ||||
| @@ -83,7 +83,7 @@ class TestCollator: | |||||
| assert raw_pad_batch == collator(dict_batch) | assert raw_pad_batch == collator(dict_batch) | ||||
| collator = Collator(backend='raw') | collator = Collator(backend='raw') | ||||
| raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | ||||
| [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], | |||||
| [1.1, 2.1], [[1.1], [2.1]], [True, False], [[1, 0], [1, 1]], [{'1': '1'}, {'2': '2'}], | |||||
| [{'1'}, {'2'}]] | [{'1'}, {'2'}]] | ||||
| findListDiff(raw_pad_lst, collator(list_batch)) | findListDiff(raw_pad_lst, collator(list_batch)) | ||||
| @@ -194,7 +194,7 @@ class TestCollator: | |||||
| collator.set_ignore('_0', '_3', '_1') | collator.set_ignore('_0', '_3', '_1') | ||||
| collator.set_pad('_4', pad_val=None) | collator.set_pad('_4', pad_val=None) | ||||
| raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]], | raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]], | ||||
| [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], | |||||
| [1.1, 2.1], [[1.1], [2.1]], [True, False], [[1, 0], [1, 1]], [{'1': '1'}, {'2': '2'}], | |||||
| [{'1'}, {'2'}]] | [{'1'}, {'2'}]] | ||||
| findListDiff(raw_pad_lst, collator(list_batch)) | findListDiff(raw_pad_lst, collator(list_batch)) | ||||
| @@ -210,7 +210,7 @@ class TestCollator: | |||||
| collator.set_pad('_2', backend='numpy') | collator.set_pad('_2', backend='numpy') | ||||
| collator.set_pad('_4', backend='numpy', pad_val=100) | collator.set_pad('_4', backend='numpy', pad_val=100) | ||||
| raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]), | raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]), | ||||
| [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], | |||||
| [1.1, 2.1], [[1.1], [2.1]], [True, False], [[1, 0], [1, 1]], [{'1': '1'}, {'2': '2'}], | |||||
| [{'1'}, {'2'}]] | [{'1'}, {'2'}]] | ||||
| findListDiff(raw_pad_lst, collator(list_batch)) | findListDiff(raw_pad_lst, collator(list_batch)) | ||||
| @@ -13,7 +13,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||||
| from tests.helpers.callbacks.helper_callbacks import RecordLossCallback | from tests.helpers.callbacks.helper_callbacks import RecordLossCallback | ||||
| from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch | from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch | ||||
| from tests.helpers.utils import magic_argv_env_context, Capturing | from tests.helpers.utils import magic_argv_env_context, Capturing | ||||
| from fastNLP.core import rank_zero_rm | |||||
| from fastNLP.envs.distributed import rank_zero_rm | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
| if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
| import torch.distributed as dist | import torch.distributed as dist | ||||
| @@ -12,7 +12,7 @@ from fastNLP.core.samplers import ( | |||||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | ||||
| from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | ||||
| from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
| from fastNLP.core import rank_zero_rm | |||||
| from fastNLP.envs.distributed import rank_zero_rm | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
| if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
| import paddle | import paddle | ||||
| @@ -7,7 +7,7 @@ from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||||
| from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | ||||
| from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
| from fastNLP.core import rank_zero_rm | |||||
| from fastNLP.envs.distributed import rank_zero_rm | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | ||||
| if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
| import paddle | import paddle | ||||
| @@ -12,7 +12,7 @@ from fastNLP.core.samplers import ( | |||||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
| from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | ||||
| from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
| from fastNLP.core import rank_zero_rm | |||||
| from fastNLP.envs.distributed import rank_zero_rm | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
| if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
| import torch | import torch | ||||
| @@ -7,7 +7,7 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
| from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | ||||
| from tests.helpers.datasets.paddle_data import PaddleNormalDataset | from tests.helpers.datasets.paddle_data import PaddleNormalDataset | ||||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | ||||
| from fastNLP.core import rank_zero_rm | |||||
| from fastNLP.envs.distributed import rank_zero_rm | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | ||||
| if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
| import torch | import torch | ||||
| @@ -7,7 +7,7 @@ import re | |||||
| import pytest | import pytest | ||||
| from fastNLP.envs.env import FASTNLP_LAUNCH_TIME | from fastNLP.envs.env import FASTNLP_LAUNCH_TIME | ||||
| from fastNLP.core import rank_zero_rm | |||||
| from fastNLP.envs.distributed import rank_zero_rm | |||||
| from fastNLP.core.log.logger import logger | from fastNLP.core.log.logger import logger | ||||
| from tests.helpers.utils import magic_argv_env_context, recover_logger | from tests.helpers.utils import magic_argv_env_context, recover_logger | ||||
| @@ -6,7 +6,7 @@ import sys | |||||
| sys.path.append(os.path.join(os.path.dirname(__file__), '../../..')) | sys.path.append(os.path.join(os.path.dirname(__file__), '../../..')) | ||||
| from fastNLP.core.utils.cache_results import cache_results | from fastNLP.core.utils.cache_results import cache_results | ||||
| from fastNLP.core import rank_zero_rm | |||||
| from fastNLP.envs.distributed import rank_zero_rm | |||||
| def get_subprocess_results(cmd): | def get_subprocess_results(cmd): | ||||
| @@ -3,7 +3,7 @@ import pytest | |||||
| from fastNLP.envs.set_backend import dump_fastnlp_backend | from fastNLP.envs.set_backend import dump_fastnlp_backend | ||||
| from tests.helpers.utils import Capturing | from tests.helpers.utils import Capturing | ||||
| from fastNLP.core import rank_zero_rm | |||||
| from fastNLP.envs.distributed import rank_zero_rm | |||||
| def test_dump_fastnlp_envs(): | def test_dump_fastnlp_envs(): | ||||
| @@ -9,7 +9,7 @@ import numpy as np | |||||
| from fastNLP.modules.mix_modules.mix_module import MixModule | from fastNLP.modules.mix_modules.mix_module import MixModule | ||||
| from fastNLP.modules.mix_modules.utils import paddle2torch, torch2paddle | from fastNLP.modules.mix_modules.utils import paddle2torch, torch2paddle | ||||
| from fastNLP.core import rank_zero_rm | |||||
| from fastNLP.envs.distributed import rank_zero_rm | |||||
| ############################################################################ | ############################################################################ | ||||