From 51a3e901f005333b04ec2b0aad9f5e2c2e9e0a0f Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Mon, 2 May 2022 11:05:29 +0000 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E9=83=A8=E5=88=86=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/collators/padders/test_raw_padder.py | 3 +-- tests/core/utils/test_cache_results.py | 1 + tests/envs/test_set_backend.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/core/collators/padders/test_raw_padder.py b/tests/core/collators/padders/test_raw_padder.py index 9742bc9a..9cb38766 100644 --- a/tests/core/collators/padders/test_raw_padder.py +++ b/tests/core/collators/padders/test_raw_padder.py @@ -23,7 +23,6 @@ class TestRawSequencePadder: assert (a == b).sum().item() == shape[0]*shape[1] def test_dtype_check(self): - with pytest.raises(DtypeError): - padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) + padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) with pytest.raises(DtypeError): padder = RawSequencePadder(pad_val=-1, ele_dtype=str, dtype=int) \ No newline at end of file diff --git a/tests/core/utils/test_cache_results.py b/tests/core/utils/test_cache_results.py index 5657ae81..77c618bb 100644 --- a/tests/core/utils/test_cache_results.py +++ b/tests/core/utils/test_cache_results.py @@ -3,6 +3,7 @@ import pytest import subprocess from io import StringIO import sys +sys.path.append(os.path.join(os.path.dirname(__file__), '../../..')) from fastNLP.core.utils.cache_results import cache_results from fastNLP.core import rank_zero_rm diff --git a/tests/envs/test_set_backend.py b/tests/envs/test_set_backend.py index 395c854d..170110ce 100644 --- a/tests/envs/test_set_backend.py +++ b/tests/envs/test_set_backend.py @@ -1,4 +1,5 @@ import os +import pytest from fastNLP.envs.set_backend import dump_fastnlp_backend from tests.helpers.utils import Capturing @@ -9,7 +10,7 @@ def test_dump_fastnlp_envs(): filepath = None try: with Capturing() as output: - dump_fastnlp_backend() + dump_fastnlp_backend(backend="torch") filepath = os.path.join(os.path.expanduser('~'), '.fastNLP', 'envs', os.environ['CONDA_DEFAULT_ENV']+'.json') assert filepath in output[0] assert os.path.exists(filepath)