You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

utils.py 2.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
import os
from typing import Union, Dict
  3. def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
  4. """
  5. 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果
  6. {
  7. 'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。
  8. 'test': 'xxx' # 可能有,也可能没有
  9. ...
  10. }
  11. 如果paths为不合法的,将直接进行raise相应的错误
  12. :param paths: 路径
  13. :return:
  14. """
  15. if isinstance(paths, str):
  16. if os.path.isfile(paths):
  17. return {'train': paths}
  18. elif os.path.isdir(paths):
  19. train_fp = os.path.join(paths, 'train.txt')
  20. if not os.path.isfile(train_fp):
  21. raise FileNotFoundError(f"train.txt is not found in folder {paths}.")
  22. files = {'train': train_fp}
  23. for filename in ['test.txt', 'dev.txt']:
  24. fp = os.path.join(paths, filename)
  25. if os.path.isfile(fp):
  26. files[filename.split('.')[0]] = fp
  27. return files
  28. else:
  29. raise FileNotFoundError(f"{paths} is not a valid file path.")
  30. elif isinstance(paths, dict):
  31. if paths:
  32. if 'train' not in paths:
  33. raise KeyError("You have to include `train` in your dict.")
  34. for key, value in paths.items():
  35. if isinstance(key, str) and isinstance(value, str):
  36. if not os.path.isfile(value):
  37. raise TypeError(f"{value} is not a valid file.")
  38. else:
  39. raise TypeError("All keys and values in paths should be str.")
  40. return paths
  41. else:
  42. raise ValueError("Empty paths is not allowed.")
  43. else:
  44. raise TypeError(f"paths only supports str and dict. not {type(paths)}.")