You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_prepare_dataloader.py 355 B

12345678910111213
  1. import pytest
  2. from fastNLP import prepare_dataloader
  3. from fastNLP import DataSet
  4. @pytest.mark.torch
  5. def test_torch():
  6. import torch
  7. ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
  8. dl = prepare_dataloader(ds, batch_size=2, shuffle=True)
  9. for batch in dl:
  10. assert isinstance(batch['x'], torch.Tensor)