| @@ -14,14 +14,17 @@ import json | |||||
| from typing import Union | from typing import Union | ||||
| import numpy as np | import numpy as np | ||||
| import torch | |||||
| import torch.nn as nn | |||||
| from .embedding import TokenEmbedding | from .embedding import TokenEmbedding | ||||
| from ...core import logger | from ...core import logger | ||||
| from ...core.vocabulary import Vocabulary | from ...core.vocabulary import Vocabulary | ||||
| from ...io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path | from ...io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path | ||||
| from ...io.file_utils import _get_file_name_base_on_postfix | from ...io.file_utils import _get_file_name_base_on_postfix | ||||
| from ...envs.imports import _NEED_IMPORT_TORCH | |||||
| if _NEED_IMPORT_TORCH: | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| VOCAB_FILENAME = 'vocab.txt' | VOCAB_FILENAME = 'vocab.txt' | ||||
| @@ -7,12 +7,18 @@ __all__ = [ | |||||
| "LSTM" | "LSTM" | ||||
| ] | ] | ||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.utils.rnn as rnn | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
| if _NEED_IMPORT_TORCH: | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.utils.rnn as rnn | |||||
| from torch.nn import Module | |||||
| else: | |||||
| from fastNLP.core.utils.dummy_class import DummyClass as Module | |||||
| class LSTM(nn.Module): | |||||
| class LSTM(Module): | |||||
| r""" | r""" | ||||
| LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化 | LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化 | ||||
| 为1; 且可以应对DataParallel中LSTM的使用问题。 | 为1; 且可以应对DataParallel中LSTM的使用问题。 | ||||
| @@ -15,7 +15,6 @@ from fastNLP.envs.distributed import rank_zero_rm | |||||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
| from tests.helpers.datasets.torch_data import TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchArgMaxDataset | ||||
| from tests.helpers.utils import Capturing | from tests.helpers.utils import Capturing | ||||
| from torchmetrics import Accuracy | |||||
| from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
| @@ -23,6 +22,7 @@ if _NEED_IMPORT_TORCH: | |||||
| from torch.utils.data import DataLoader | from torch.utils.data import DataLoader | ||||
| from torch.optim import SGD | from torch.optim import SGD | ||||
| import torch.distributed as dist | import torch.distributed as dist | ||||
| from torchmetrics import Accuracy | |||||
| @dataclass | @dataclass | ||||
| class ArgMaxDatasetConfig: | class ArgMaxDatasetConfig: | ||||
| @@ -18,14 +18,15 @@ from tests.helpers.utils import magic_argv_env_context | |||||
| from fastNLP.envs.distributed import rank_zero_rm | from fastNLP.envs.distributed import rank_zero_rm | ||||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
| from tests.helpers.datasets.torch_data import TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchArgMaxDataset | ||||
| from torchmetrics import Accuracy | |||||
| from fastNLP.core.metrics import Metric | from fastNLP.core.metrics import Metric | ||||
| from fastNLP.core.callbacks import MoreEvaluateCallback | from fastNLP.core.callbacks import MoreEvaluateCallback | ||||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
| if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
| from torch.utils.data import DataLoader | from torch.utils.data import DataLoader | ||||
| from torch.optim import SGD | from torch.optim import SGD | ||||
| import torch.distributed as dist | import torch.distributed as dist | ||||
| from torchmetrics import Accuracy | |||||
| @dataclass | @dataclass | ||||
| class ArgMaxDatasetConfig: | class ArgMaxDatasetConfig: | ||||
| @@ -19,18 +19,20 @@ for folder in list(folders[::-1]): | |||||
| path = os.sep.join(folders) | path = os.sep.join(folders) | ||||
| sys.path.extend([path, os.path.join(path, 'fastNLP')]) | sys.path.extend([path, os.path.join(path, 'fastNLP')]) | ||||
| import torch | |||||
| from torch.nn.parallel import DistributedDataParallel | |||||
| from torch.utils.data import DataLoader | |||||
| from torch.optim import SGD | |||||
| import torch.distributed as dist | |||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||
| from torchmetrics import Accuracy | |||||
| from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_2 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_2 | ||||
| from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | ||||
| if _NEED_IMPORT_TORCH: | |||||
| import torch | |||||
| from torch.nn.parallel import DistributedDataParallel | |||||
| from torch.utils.data import DataLoader | |||||
| from torch.optim import SGD | |||||
| import torch.distributed as dist | |||||
| from torchmetrics import Accuracy | |||||
| @dataclass | @dataclass | ||||
| class NormalClassificationTrainTorchConfig: | class NormalClassificationTrainTorchConfig: | ||||
| @@ -19,17 +19,18 @@ for folder in list(folders[::-1]): | |||||
| path = os.sep.join(folders) | path = os.sep.join(folders) | ||||
| sys.path.extend([path, os.path.join(path, 'fastNLP')]) | sys.path.extend([path, os.path.join(path, 'fastNLP')]) | ||||
| from torch.utils.data import DataLoader | |||||
| from torch.optim import SGD | |||||
| import torch.distributed as dist | |||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||
| from torchmetrics import Accuracy | |||||
| from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
| from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | ||||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
| if _NEED_IMPORT_TORCH: | |||||
| from torch.utils.data import DataLoader | |||||
| from torch.optim import SGD | |||||
| import torch.distributed as dist | |||||
| from torchmetrics import Accuracy | |||||
| @dataclass | @dataclass | ||||
| class NormalClassificationTrainTorchConfig: | class NormalClassificationTrainTorchConfig: | ||||
| @@ -5,7 +5,6 @@ import pytest | |||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||
| from typing import Any | from typing import Any | ||||
| from torchmetrics import Accuracy | |||||
| from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
| @@ -18,6 +17,7 @@ if _NEED_IMPORT_TORCH: | |||||
| from torch.optim import SGD | from torch.optim import SGD | ||||
| from torch.utils.data import DataLoader | from torch.utils.data import DataLoader | ||||
| import torch.distributed as dist | import torch.distributed as dist | ||||
| from torchmetrics import Accuracy | |||||
| @dataclass | @dataclass | ||||
| @@ -1,16 +1,18 @@ | |||||
| import os | import os | ||||
| import pytest | import pytest | ||||
| import torch | |||||
| import torch.distributed as dist | |||||
| import numpy as np | import numpy as np | ||||
| # print(isinstance((1,), tuple)) | # print(isinstance((1,), tuple)) | ||||
| # exit() | # exit() | ||||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
| from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object | from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object | ||||
| from tests.helpers.utils import re_run_current_cmd_for_torch, magic_argv_env_context | from tests.helpers.utils import re_run_current_cmd_for_torch, magic_argv_env_context | ||||
| if _NEED_IMPORT_TORCH: | |||||
| import torch | |||||
| import torch.distributed as dist | |||||
| @pytest.mark.torch | @pytest.mark.torch | ||||
| @magic_argv_env_context | @magic_argv_env_context | ||||
| @@ -5,10 +5,11 @@ from fastNLP.core.drivers.torch_driver.utils import ( | |||||
| replace_sampler, | replace_sampler, | ||||
| ) | ) | ||||
| from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | ||||
| from torch.utils.data import DataLoader, BatchSampler | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
| from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
| if _NEED_IMPORT_TORCH: | |||||
| from torch.utils.data import DataLoader, BatchSampler | |||||
| @pytest.mark.torch | @pytest.mark.torch | ||||
| def test_replace_batch_sampler(): | def test_replace_batch_sampler(): | ||||
| @@ -1,13 +1,16 @@ | |||||
| from array import array | from array import array | ||||
| import torch | |||||
| from torch.utils.data import DataLoader | |||||
| import pytest | import pytest | ||||
| from fastNLP.core.samplers import ReproduceBatchSampler | from fastNLP.core.samplers import ReproduceBatchSampler | ||||
| from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | ||||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
| from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
| if _NEED_IMPORT_TORCH: | |||||
| import torch | |||||
| from torch.utils.data import DataLoader | |||||
| @pytest.mark.torch | @pytest.mark.torch | ||||
| class TestReproducibleBatchSamplerTorch: | class TestReproducibleBatchSamplerTorch: | ||||
| @@ -1,9 +1,9 @@ | |||||
| import torch | |||||
| from functools import reduce | from functools import reduce | ||||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
| if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
| from torch.utils.data import Dataset, DataLoader, DistributedSampler | |||||
| from torch.utils.data.sampler import SequentialSampler, BatchSampler | |||||
| import torch | |||||
| from torch.utils.data import Dataset | |||||
| else: | else: | ||||
| from fastNLP.core.utils.dummy_class import DummyClass as Dataset | from fastNLP.core.utils.dummy_class import DummyClass as Dataset | ||||