You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_loader.py 2.9 kB

7 years ago
7 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. import configparser
  2. import json
  3. import os
  4. import unittest
  5. from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
  6. from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, POSDatasetLoader, LMDatasetLoader
  7. class TestConfigLoader(unittest.TestCase):
  8. def test_case_ConfigLoader(self):
  9. def read_section_from_config(config_path, section_name):
  10. dict = {}
  11. if not os.path.exists(config_path):
  12. raise FileNotFoundError("config file {} NOT found.".format(config_path))
  13. cfg = configparser.ConfigParser()
  14. cfg.read(config_path)
  15. if section_name not in cfg:
  16. raise AttributeError("config file {} do NOT have section {}".format(
  17. config_path, section_name
  18. ))
  19. gen_sec = cfg[section_name]
  20. for s in gen_sec.keys():
  21. try:
  22. val = json.loads(gen_sec[s])
  23. dict[s] = val
  24. except Exception as e:
  25. raise AttributeError("json can NOT load {} in section {}, config file {}".format(
  26. s, section_name, config_path
  27. ))
  28. return dict
  29. test_arg = ConfigSection()
  30. ConfigLoader("config").load_config(os.path.join("./test/loader", "config"), {"test": test_arg})
  31. # ConfigLoader("config").load_config("/home/ygxu/github/fastNLP_testing/fastNLP/test/loader/config",
  32. # {"test": test_arg})
  33. #dict = read_section_from_config("/home/ygxu/github/fastNLP_testing/fastNLP/test/loader/config", "test")
  34. dict = read_section_from_config(os.path.join("./test/loader", "config"), "test")
  35. for sec in dict:
  36. if (sec not in test_arg) or (dict[sec] != test_arg[sec]):
  37. raise AttributeError("ERROR")
  38. for sec in test_arg.__dict__.keys():
  39. if (sec not in dict) or (dict[sec] != test_arg[sec]):
  40. raise AttributeError("ERROR")
  41. try:
  42. not_exist = test_arg["NOT EXIST"]
  43. except Exception as e:
  44. pass
  45. print("pass config test!")
  46. class TestDatasetLoader(unittest.TestCase):
  47. def test_case_TokenizeDatasetLoader(self):
  48. loader = TokenizeDatasetLoader("./test/data_for_tests/cws_pku_utf_8")
  49. data = loader.load_pku(max_seq_len=32)
  50. print("pass TokenizeDatasetLoader test!")
  51. def test_case_POSDatasetLoader(self):
  52. loader = POSDatasetLoader("./test/data_for_tests/people.txt")
  53. data = loader.load()
  54. datas = loader.load_lines()
  55. print("pass POSDatasetLoader test!")
  56. def test_case_LMDatasetLoader(self):
  57. loader = LMDatasetLoader("./test/data_for_tests/cws_pku_utf_8")
  58. data = loader.load()
  59. datas = loader.load_lines()
  60. print("pass TokenizeDatasetLoader test!")

一款轻量级的自然语言处理（NLP）工具包，目标是减少用户项目中的工程型代码，例如数据处理循环、训练循环、多卡运行等。