You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

jittor_utils.py 1.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. __all__ = [
  2. 'is_jittor_dataset',
  3. 'jittor_collate_wraps'
  4. ]
  5. from collections.abc import Mapping, Callable
  6. from functools import wraps
  7. from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
  8. if _NEED_IMPORT_JITTOR:
  9. import jittor as jt
  10. from fastNLP.core.dataset import Instance
  11. def is_jittor_dataset(dataset) -> bool:
  12. try:
  13. if isinstance(dataset, jt.dataset.Dataset):
  14. return True
  15. else:
  16. return False
  17. except BaseException:
  18. return False
  19. def jittor_collate_wraps(func, auto_collator: Callable):
  20. """
  21. 对jittor的collate_fn进行wrap封装, 如果数据集为mapping类型,那么采用auto_collator,否则还是采用jittor自带的collate_batch
  22. :param func:
  23. :param auto_collator:
  24. :return:
  25. """
  26. @wraps(func)
  27. def wrapper(batch):
  28. if isinstance(batch[0], Instance):
  29. if auto_collator is not None:
  30. result = auto_collator(batch)
  31. else:
  32. raise ValueError(f"auto_collator is None, but batch exist fastnlp instance!")
  33. elif isinstance(batch[0], Mapping):
  34. if auto_collator is not None:
  35. result = auto_collator(batch)
  36. else:
  37. result = func(batch)
  38. else:
  39. result = func(batch)
  40. return result
  41. return wrapper