You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_utils.py 1.6 kB

12345678910111213141516171819202122232425262728293031323334353637
  1. from fastNLP.core.collators.utils import *
  def test_unpack_batch_mapping():
      """A list of flat dicts is regrouped into one list of values per key."""
      samples = [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}]
      expected = {'a': [[1, 2], [3]], 'b': [1, 2]}
      # Second argument is passed as an empty dict here; its role is defined
      # by unpack_batch_mapping (not visible in this file).
      assert unpack_batch_mapping(samples, {}) == expected
  def test_unpack_batch_nested_mapping():
      """Nested dict leaves are collected under tuple key paths, one list per leaf."""
      cases = [
          # One level of nesting: 'c' -> 'c' becomes the key ('c', 'c').
          (
              [{'a': [1, 2], 'b': 1, 'c': {'c': 1}},
               {'a': [3], 'b': 2, 'c': {'c': 2}}],
              {'a': [[1, 2], [3]], 'b': [1, 2], ('c', 'c'): [1, 2]},
          ),
          # Two levels of nesting: the full path becomes a 3-tuple key.
          (
              [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1}}},
               {'a': [3], 'b': 2, 'c': {'c': {'c': 2}}}],
              {'a': [[1, 2], [3]], 'b': [1, 2], ('c', 'c', 'c'): [1, 2]},
          ),
          # Mixed depths and list-valued leaves of differing lengths.
          (
              [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1, 'd': [1, 1]}, 'd': [1]}},
               {'a': [3], 'b': 2, 'c': {'c': {'c': 2, 'd': [2, 2]}, 'd': [2, 2]}}],
              {'a': [[1, 2], [3]], 'b': [1, 2], ('c', 'c', 'c'): [1, 2],
               ('c', 'c', 'd'): [[1, 1], [2, 2]], ('c', 'd'): [[1], [2, 2]]},
          ),
      ]
      for samples, expected in cases:
          assert unpack_batch_nested_mapping(samples, {}, {}) == expected
  def test_pack_batch_nested_mapping():
      """Tuple key paths are expanded back into nested dicts; plain keys stay flat."""
      flat = {
          'a': [[1, 2], [3]],
          'b': [1, 2],
          ('c', 'c', 'c'): [1, 2],
          ('c', 'c', 'd'): [[1, 1], [2, 2]],
          ('c', 'd'): [[1], [2, 2]],
      }
      nested_expected = {
          'a': [[1, 2], [3]],
          'b': [1, 2],
          'c': {
              'c': {'c': [1, 2], 'd': [[1, 1], [2, 2]]},
              'd': [[1], [2, 2]],
          },
      }
      packed = pack_batch_nested_mapping(flat)
      assert packed == nested_expected
  def test_unpack_batch_sequence():
      """Position i of every sample sequence is gathered under the key '_i'."""
      rows = [[1, 2, 3], [2, 4, 6]]
      by_position = {'_0': [1, 2], '_1': [2, 4], '_2': [3, 6]}
      assert unpack_batch_sequence(rows, {}) == by_position