| @@ -23,7 +23,6 @@ class TestRawSequencePadder: | |||||
| assert (a == b).sum().item() == shape[0]*shape[1] | assert (a == b).sum().item() == shape[0]*shape[1] | ||||
| def test_dtype_check(self): | def test_dtype_check(self): | ||||
| with pytest.raises(DtypeError): | |||||
| padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||||
| padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||||
| with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
| padder = RawSequencePadder(pad_val=-1, ele_dtype=str, dtype=int) | padder = RawSequencePadder(pad_val=-1, ele_dtype=str, dtype=int) | ||||
| @@ -3,6 +3,7 @@ import pytest | |||||
| import subprocess | import subprocess | ||||
| from io import StringIO | from io import StringIO | ||||
| import sys | import sys | ||||
| sys.path.append(os.path.join(os.path.dirname(__file__), '../../..')) | |||||
| from fastNLP.core.utils.cache_results import cache_results | from fastNLP.core.utils.cache_results import cache_results | ||||
| from fastNLP.core import rank_zero_rm | from fastNLP.core import rank_zero_rm | ||||
| @@ -1,4 +1,5 @@ | |||||
| import os | import os | ||||
| import pytest | |||||
| from fastNLP.envs.set_backend import dump_fastnlp_backend | from fastNLP.envs.set_backend import dump_fastnlp_backend | ||||
| from tests.helpers.utils import Capturing | from tests.helpers.utils import Capturing | ||||
| @@ -9,7 +10,7 @@ def test_dump_fastnlp_envs(): | |||||
| filepath = None | filepath = None | ||||
| try: | try: | ||||
| with Capturing() as output: | with Capturing() as output: | ||||
| dump_fastnlp_backend() | |||||
| dump_fastnlp_backend(backend="torch") | |||||
| filepath = os.path.join(os.path.expanduser('~'), '.fastNLP', 'envs', os.environ['CONDA_DEFAULT_ENV']+'.json') | filepath = os.path.join(os.path.expanduser('~'), '.fastNLP', 'envs', os.environ['CONDA_DEFAULT_ENV']+'.json') | ||||
| assert filepath in output[0] | assert filepath in output[0] | ||||
| assert os.path.exists(filepath) | assert os.path.exists(filepath) | ||||