You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_loader.py 2.6 kB

7 years ago
7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. import configparser
  2. import json
  3. import os
  4. import unittest
  5. from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
  6. from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, POSDatasetLoader, LMDatasetLoader
  7. class TestConfigLoader(unittest.TestCase):
  8. def test_case_ConfigLoader(self):
  9. def read_section_from_config(config_path, section_name):
  10. dict = {}
  11. if not os.path.exists(config_path):
  12. raise FileNotFoundError("config file {} NOT found.".format(config_path))
  13. cfg = configparser.ConfigParser()
  14. cfg.read(config_path)
  15. if section_name not in cfg:
  16. raise AttributeError("config file {} do NOT have section {}".format(
  17. config_path, section_name
  18. ))
  19. gen_sec = cfg[section_name]
  20. for s in gen_sec.keys():
  21. try:
  22. val = json.loads(gen_sec[s])
  23. dict[s] = val
  24. except Exception as e:
  25. raise AttributeError("json can NOT load {} in section {}, config file {}".format(
  26. s, section_name, config_path
  27. ))
  28. return dict
  29. test_arg = ConfigSection()
  30. ConfigLoader("config").load_config(os.path.join("./test/loader", "config"), {"test": test_arg})
  31. section = read_section_from_config(os.path.join("./test/loader", "config"), "test")
  32. for sec in section:
  33. if (sec not in test_arg) or (section[sec] != test_arg[sec]):
  34. raise AttributeError("ERROR")
  35. for sec in test_arg.__dict__.keys():
  36. if (sec not in section) or (section[sec] != test_arg[sec]):
  37. raise AttributeError("ERROR")
  38. try:
  39. not_exist = test_arg["NOT EXIST"]
  40. except Exception as e:
  41. pass
  42. print("pass config test!")
  43. class TestDatasetLoader(unittest.TestCase):
  44. def test_case_TokenizeDatasetLoader(self):
  45. loader = TokenizeDatasetLoader("./test/data_for_tests/cws_pku_utf_8")
  46. data = loader.load_pku(max_seq_len=32)
  47. print("pass TokenizeDatasetLoader test!")
  48. def test_case_POSDatasetLoader(self):
  49. loader = POSDatasetLoader("./test/data_for_tests/people.txt")
  50. data = loader.load()
  51. datas = loader.load_lines()
  52. print("pass POSDatasetLoader test!")
  53. def test_case_LMDatasetLoader(self):
  54. loader = LMDatasetLoader("./test/data_for_tests/cws_pku_utf_8")
  55. data = loader.load()
  56. datas = loader.load_lines()
  57. print("pass TokenizeDatasetLoader test!")

一款轻量级的自然语言处理（NLP）工具包，目标是减少用户项目中的工程型代码，例如数据处理循环、训练循环、多卡运行等。