| @@ -0,0 +1,10 @@ | |||||
# Make sure FASTNLP_GLOBAL_RANK is set correctly before anything else is imported.
from fastNLP.envs.set_env_on_import import set_env_on_import
set_env_on_import()

# Then configure the compute backend (e.g. torch / paddle / jittor).
from fastNLP.envs.set_backend import _set_backend
_set_backend()

from fastNLP.core import Trainer, Evaluator
| @@ -0,0 +1,22 @@ | |||||
| __all__ = [ | |||||
| "TorchSingleDriver", | |||||
| "TorchDDPDriver", | |||||
| "PaddleSingleDriver", | |||||
| "PaddleFleetDriver", | |||||
| "JittorSingleDriver", | |||||
| "JittorMPIDriver", | |||||
| "TorchPaddleDriver", | |||||
| "paddle_to", | |||||
| "get_paddle_gpu_str", | |||||
| "get_paddle_device_id", | |||||
| "paddle_move_data_to_device", | |||||
| "torch_paddle_move_data_to_device", | |||||
| ] | |||||
# TODO: refactor these imports later — each submodule should first import its own
# classes/functions internally, and the outer module should then import directly
# from the submodules.
| from fastNLP.core.controllers.trainer import Trainer | |||||
| from fastNLP.core.controllers.evaluator import Evaluator | |||||
| from fastNLP.core.dataloaders.torch_dataloader import * | |||||
| from .drivers import * | |||||
| from .utils import * | |||||
| @@ -32,8 +32,8 @@ __all__ = [ | |||||
| ] | ] | ||||
| from fastNLP.core.log.handler import StdoutStreamHandler, TqdmLoggingHandler | from fastNLP.core.log.handler import StdoutStreamHandler, TqdmLoggingHandler | ||||
| from fastNLP.core.envs import FASTNLP_LOG_LEVEL, FASTNLP_GLOBAL_RANK, FASTNLP_LAUNCH_TIME, FASTNLP_BACKEND_LAUNCH | |||||
| from fastNLP.core.envs import is_cur_env_distributed | |||||
| from fastNLP.envs.env import FASTNLP_LOG_LEVEL, FASTNLP_GLOBAL_RANK, FASTNLP_LAUNCH_TIME, FASTNLP_BACKEND_LAUNCH | |||||
| from fastNLP.envs.distributed import is_cur_env_distributed | |||||
| ROOT_NAME = 'fastNLP' | ROOT_NAME = 'fastNLP' | ||||
| @@ -10,7 +10,7 @@ __all__ = [ | |||||
| 'all_rank_call' | 'all_rank_call' | ||||
| ] | ] | ||||
| from fastNLP.core.envs import FASTNLP_GLOBAL_RANK | |||||
| from fastNLP.envs.env import FASTNLP_GLOBAL_RANK | |||||
| def is_cur_env_distributed() -> bool: | def is_cur_env_distributed() -> bool: | ||||
| @@ -3,8 +3,8 @@ import os | |||||
| import operator | import operator | ||||
| from fastNLP.core.envs.env import FASTNLP_BACKEND | |||||
| from fastNLP.core.envs.utils import _module_available, _compare_version | |||||
| from fastNLP.envs.env import FASTNLP_BACKEND | |||||
| from fastNLP.envs.utils import _module_available, _compare_version | |||||
| SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor'] | SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor'] | ||||
| @@ -8,9 +8,9 @@ import sys | |||||
| from collections import defaultdict | from collections import defaultdict | ||||
| from fastNLP.core.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED | |||||
| from fastNLP.core.envs import SUPPORT_BACKENDS | |||||
| from fastNLP.core.envs.utils import _module_available | |||||
| from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED | |||||
| from fastNLP.envs.imports import SUPPORT_BACKENDS | |||||
| from fastNLP.envs.utils import _module_available | |||||
| def _set_backend(): | def _set_backend(): | ||||
| @@ -0,0 +1,17 @@ | |||||
| import os | |||||
| from fastNLP.envs.set_env import dump_fastnlp_backend | |||||
| from tests.helpers.utils import Capturing | |||||
| from fastNLP.core import synchronize_safe_rm | |||||
def test_dump_fastnlp_envs():
    """``dump_fastnlp_backend()`` should write a backend-env json file under
    ``~/.fastNLP/envs/<conda-env>.json`` and print that path to stdout.

    The file is removed again in the ``finally`` clause so the test leaves
    no state behind.
    """
    filepath = None
    try:
        with Capturing() as output:
            dump_fastnlp_backend()
        # NOTE(review): assumes the test runs inside a conda environment;
        # os.environ['CONDA_DEFAULT_ENV'] raises KeyError otherwise — confirm.
        filepath = os.path.join(os.path.expanduser('~'), '.fastNLP', 'envs',
                                os.environ['CONDA_DEFAULT_ENV'] + '.json')
        assert filepath in output[0]
        assert os.path.exists(filepath)
    finally:
        # Only clean up if the path was actually computed: if dump_fastnlp_backend()
        # raised before `filepath` was assigned, calling synchronize_safe_rm(None)
        # would mask the original exception with a spurious error.
        if filepath is not None:
            synchronize_safe_rm(filepath)