import os
import configparser
import json
import unittest

from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, POSDatasetLoader, LMDatasetLoader


def _read_section_from_config(config_path, section_name):
    """Read one section of an INI-style config file, JSON-decoding every value.

    This is an independent reference reader built on ``configparser`` and
    ``json``; the config test compares ``ConfigLoader``'s result against it.

    :param config_path: path of the config file to read
    :param section_name: name of the section to extract
    :return: dict mapping each option name in the section to its decoded value
    :raises FileNotFoundError: if ``config_path`` does not exist
    :raises AttributeError: if the section is missing or a value is not valid JSON
    """
    # ConfigParser.read() silently skips files it cannot open, so the
    # existence check has to be explicit.
    if not os.path.exists(config_path):
        raise FileNotFoundError("config file {} NOT found.".format(config_path))

    cfg = configparser.ConfigParser()
    cfg.read(config_path)
    if section_name not in cfg:
        raise AttributeError("config file {} do NOT have section {}".format(
            config_path, section_name
        ))

    section = cfg[section_name]
    parsed = {}
    for key in section:
        raw_value = section[key]
        try:
            parsed[key] = json.loads(raw_value)
        except (TypeError, ValueError) as err:
            # Keep the AttributeError type, but chain the decoding error so
            # the real cause is visible in the traceback.
            raise AttributeError("json can NOT load {} in section {}, config file {}".format(
                key, section_name, config_path
            )) from err
    return parsed


class TestConfigLoader(unittest.TestCase):
    """Checks ConfigLoader against an independent configparser/json read."""

    def test_case_ConfigLoader(self):
        """The "test" section loaded by ConfigLoader matches the reference read."""
        # Relative path: resolved against the current working directory.
        config_path = os.path.join("./loader", "config")

        test_arg = ConfigSection()
        ConfigLoader("config", "").load_config(config_path, {"test": test_arg})
        expected = _read_section_from_config(config_path, "test")

        # Every option in the file must have been loaded with the same value ...
        for key in expected:
            self.assertIn(key, test_arg)
            self.assertEqual(expected[key], test_arg[key])

        # ... and the loaded section must not carry anything the file lacks.
        for key in vars(test_arg):
            self.assertIn(key, expected)
            self.assertEqual(expected[key], test_arg[key])

        # Looking up a missing key must not abort the test.
        # NOTE(review): the exception type ConfigSection raises here is not
        # visible in this file, so the probe stays lenient; tighten it to
        # assertRaises(<type>) once that is confirmed.
        try:
            test_arg["NOT EXIST"]
        except Exception:
            pass

        print("pass config test!")


class TestDatasetLoader(unittest.TestCase):
    """Smoke tests: each dataset loader reads its fixture without raising."""

    def test_case_TokenizeDatasetLoader(self):
        """TokenizeDatasetLoader.load_pku runs on the cws_pku_utf_8 fixture."""
        loader = TokenizeDatasetLoader("cws_pku_utf_8", "./data_for_tests/cws_pku_utf_8")
        loader.load_pku(max_seq_len=32)
        print("pass TokenizeDatasetLoader test!")

    def test_case_POSDatasetLoader(self):
        """POSDatasetLoader.load and load_lines run on the people.txt fixture."""
        loader = POSDatasetLoader("people", "./data_for_tests/people.txt")
        loader.load()
        loader.load_lines()
        print("pass POSDatasetLoader test!")

    def test_case_LMDatasetLoader(self):
        """LMDatasetLoader.load and load_lines run on the cws_pku_utf_8 fixture."""
        loader = LMDatasetLoader("cws_pku_utf_8", "./data_for_tests/cws_pku_utf_8")
        loader.load()
        loader.load_lines()
        # Fixed copy-pasted message: this test covers LMDatasetLoader.
        print("pass LMDatasetLoader test!")