| @@ -1,4 +1,5 @@ | |||||
| numpy>=1.14.2 | numpy>=1.14.2 | ||||
| http://download.pytorch.org/whl/cpu/torch-0.4.1-cp35-cp35m-linux_x86_64.whl | http://download.pytorch.org/whl/cpu/torch-0.4.1-cp35-cp35m-linux_x86_64.whl | ||||
| torchvision>=0.1.8 | torchvision>=0.1.8 | ||||
| sphinx-rtd-theme==0.4.1 | |||||
| sphinx-rtd-theme==0.4.1 | |||||
| tensorboardX>=1.4 | |||||
| @@ -92,8 +92,40 @@ class ConfigSection(object): | |||||
| setattr(self, key, value) | setattr(self, key, value) | ||||
| def __contains__(self, item): | def __contains__(self, item): | ||||
| """ | |||||
| :param item: The key of item. | |||||
| :return: True if the key in self.__dict__.keys() else False. | |||||
| """ | |||||
| return item in self.__dict__.keys() | return item in self.__dict__.keys() | ||||
| def __eq__(self, other): | |||||
| """Overwrite the == operator | |||||
| :param other: Another ConfigSection() object which to be compared. | |||||
| :return: True if value of each key in each ConfigSection() object are equal to the other, else False. | |||||
| """ | |||||
| for k in self.__dict__.keys(): | |||||
| if k not in other.__dict__.keys(): | |||||
| return False | |||||
| if getattr(self, k) != getattr(self, k): | |||||
| return False | |||||
| for k in other.__dict__.keys(): | |||||
| if k not in self.__dict__.keys(): | |||||
| return False | |||||
| if getattr(self, k) != getattr(self, k): | |||||
| return False | |||||
| return True | |||||
| def __ne__(self, other): | |||||
| """Overwrite the != operator | |||||
| :param other: | |||||
| :return: | |||||
| """ | |||||
| return not self.__eq__(other) | |||||
| @property | @property | ||||
| def data(self): | def data(self): | ||||
| return self.__dict__ | return self.__dict__ | ||||
| @@ -0,0 +1,147 @@ | |||||
| import os | |||||
| from fastNLP.loader.config_loader import ConfigSection, ConfigLoader | |||||
| from fastNLP.saver.logger import create_logger | |||||
| class ConfigSaver(object): | |||||
| def __init__(self, file_path): | |||||
| self.file_path = file_path | |||||
| if not os.path.exists(self.file_path): | |||||
| raise FileNotFoundError("file {} NOT found!".__format__(self.file_path)) | |||||
| def _get_section(self, sect_name): | |||||
| """This is the function to get the section with the section name. | |||||
| :param sect_name: The name of section what wants to load. | |||||
| :return: The section. | |||||
| """ | |||||
| sect = ConfigSection() | |||||
| ConfigLoader(self.file_path).load_config(self.file_path, {sect_name: sect}) | |||||
| return sect | |||||
| def _read_section(self): | |||||
| """This is the function to read sections from the config file. | |||||
| :return: sect_list, sect_key_list | |||||
| sect_list: A list of ConfigSection(). | |||||
| sect_key_list: A list of names in sect_list. | |||||
| """ | |||||
| sect_name = None | |||||
| sect_list = {} | |||||
| sect_key_list = [] | |||||
| single_section = {} | |||||
| single_section_key = [] | |||||
| with open(self.file_path, 'r') as f: | |||||
| lines = f.readlines() | |||||
| for line in lines: | |||||
| if line.startswith('[') and line.endswith(']\n'): | |||||
| if sect_name is None: | |||||
| pass | |||||
| else: | |||||
| sect_list[sect_name] = single_section, single_section_key | |||||
| single_section = {} | |||||
| single_section_key = [] | |||||
| sect_key_list.append(sect_name) | |||||
| sect_name = line[1: -2] | |||||
| continue | |||||
| if line.startswith('#'): | |||||
| single_section[line] = '#' | |||||
| single_section_key.append(line) | |||||
| continue | |||||
| if line.startswith('\n'): | |||||
| single_section_key.append('\n') | |||||
| continue | |||||
| if '=' not in line: | |||||
| log = create_logger(__name__, './config_saver.log') | |||||
| log.error("can NOT load config file [%s]" % self.file_path) | |||||
| raise RuntimeError("can NOT load config file {}".__format__(self.file_path)) | |||||
| key = line.split('=', maxsplit=1)[0].strip() | |||||
| value = line.split('=', maxsplit=1)[1].strip() + '\n' | |||||
| single_section[key] = value | |||||
| single_section_key.append(key) | |||||
| if sect_name is not None: | |||||
| sect_list[sect_name] = single_section, single_section_key | |||||
| sect_key_list.append(sect_name) | |||||
| return sect_list, sect_key_list | |||||
| def _write_section(self, sect_list, sect_key_list): | |||||
| """This is the function to write config file with section list and name list. | |||||
| :param sect_list: A list of ConfigSection() need to be writen into file. | |||||
| :param sect_key_list: A list of name of sect_list. | |||||
| :return: | |||||
| """ | |||||
| with open(self.file_path, 'w') as f: | |||||
| for sect_key in sect_key_list: | |||||
| single_section, single_section_key = sect_list[sect_key] | |||||
| f.write('[' + sect_key + ']\n') | |||||
| for key in single_section_key: | |||||
| if key == '\n': | |||||
| f.write('\n') | |||||
| continue | |||||
| if single_section[key] == '#': | |||||
| f.write(key) | |||||
| continue | |||||
| f.write(key + ' = ' + single_section[key]) | |||||
| f.write('\n') | |||||
| def save_config_file(self, section_name, section): | |||||
| """This is the function to be called to change the config file with a single section and its name. | |||||
| :param section_name: The name of section what needs to be changed and saved. | |||||
| :param section: The section with key and value what needs to be changed and saved. | |||||
| :return: | |||||
| """ | |||||
| section_file = self._get_section(section_name) | |||||
| if len(section_file.__dict__.keys()) == 0:#the section not in file before | |||||
| with open(self.file_path, 'a') as f: | |||||
| f.write('[' + section_name + ']\n') | |||||
| for k in section.__dict__.keys(): | |||||
| f.write(k + ' = ') | |||||
| if isinstance(section[k], str): | |||||
| f.write('\"' + str(section[k]) + '\"\n\n') | |||||
| else: | |||||
| f.write(str(section[k]) + '\n\n') | |||||
| else: | |||||
| change_file = False | |||||
| for k in section.__dict__.keys(): | |||||
| if k not in section_file: | |||||
| change_file = True | |||||
| break | |||||
| if section_file[k] != section[k]: | |||||
| logger = create_logger(__name__, "./config_loader.log") | |||||
| logger.warning("section [%s] in config file [%s] has been changed" % ( | |||||
| section_name, self.file_path | |||||
| )) | |||||
| change_file = True | |||||
| break | |||||
| if not change_file: | |||||
| return | |||||
| sect_list, sect_key_list = self._read_section() | |||||
| if section_name not in sect_key_list: | |||||
| raise AttributeError() | |||||
| sect, sect_key = sect_list[section_name] | |||||
| for k in section.__dict__.keys(): | |||||
| if k not in sect_key: | |||||
| if sect_key[-1] != '\n': | |||||
| sect_key.append('\n') | |||||
| sect_key.append(k) | |||||
| sect[k] = str(section[k]) | |||||
| if isinstance(section[k], str): | |||||
| sect[k] = "\"" + sect[k] + "\"" | |||||
| sect[k] = sect[k] + "\n" | |||||
| sect_list[section_name] = sect, sect_key | |||||
| self._write_section(sect_list, sect_key_list) | |||||
| @@ -1,7 +1,18 @@ | |||||
| [test] | [test] | ||||
| x = 1 | x = 1 | ||||
| y = 2 | y = 2 | ||||
| z = 3 | z = 3 | ||||
| #this is an example | |||||
| input = [1,2,3] | input = [1,2,3] | ||||
| text = "this is text" | text = "this is text" | ||||
| doubles = 0.5 | doubles = 0.5 | ||||
| [t] | |||||
| x = "this is an test section" | |||||
| [test-case-2] | |||||
| double = 0.5 | |||||
| @@ -33,18 +33,16 @@ class TestConfigLoader(unittest.TestCase): | |||||
| test_arg = ConfigSection() | test_arg = ConfigSection() | ||||
| ConfigLoader("config").load_config(os.path.join("./test/loader", "config"), {"test": test_arg}) | ConfigLoader("config").load_config(os.path.join("./test/loader", "config"), {"test": test_arg}) | ||||
| # ConfigLoader("config").load_config("/home/ygxu/github/fastNLP_testing/fastNLP/test/loader/config", | |||||
| # {"test": test_arg}) | |||||
| #dict = read_section_from_config("/home/ygxu/github/fastNLP_testing/fastNLP/test/loader/config", "test") | |||||
| dict = read_section_from_config(os.path.join("./test/loader", "config"), "test") | |||||
| section = read_section_from_config(os.path.join("./test/loader", "config"), "test") | |||||
| for sec in dict: | |||||
| if (sec not in test_arg) or (dict[sec] != test_arg[sec]): | |||||
| for sec in section: | |||||
| if (sec not in test_arg) or (section[sec] != test_arg[sec]): | |||||
| raise AttributeError("ERROR") | raise AttributeError("ERROR") | ||||
| for sec in test_arg.__dict__.keys(): | for sec in test_arg.__dict__.keys(): | ||||
| if (sec not in dict) or (dict[sec] != test_arg[sec]): | |||||
| if (sec not in section) or (section[sec] != test_arg[sec]): | |||||
| raise AttributeError("ERROR") | raise AttributeError("ERROR") | ||||
| try: | try: | ||||
| @@ -71,4 +69,4 @@ class TestDatasetLoader(unittest.TestCase): | |||||
| loader = LMDatasetLoader("./test/data_for_tests/cws_pku_utf_8") | loader = LMDatasetLoader("./test/data_for_tests/cws_pku_utf_8") | ||||
| data = loader.load() | data = loader.load() | ||||
| datas = loader.load_lines() | datas = loader.load_lines() | ||||
| print("pass TokenizeDatasetLoader test!") | |||||
| print("pass TokenizeDatasetLoader test!") | |||||
| @@ -0,0 +1,82 @@ | |||||
| import os | |||||
| import unittest | |||||
| import configparser | |||||
| import json | |||||
| from fastNLP.loader.config_loader import ConfigSection, ConfigLoader | |||||
| from fastNLP.saver.config_saver import ConfigSaver | |||||
| class TestConfigSaver(unittest.TestCase): | |||||
| def test_case_1(self): | |||||
| config_file_dir = "./test/loader/" | |||||
| config_file_name = "config" | |||||
| config_file_path = os.path.join(config_file_dir, config_file_name) | |||||
| tmp_config_file_path = os.path.join(config_file_dir, "tmp_config") | |||||
| with open(config_file_path, "r") as f: | |||||
| lines = f.readlines() | |||||
| standard_section = ConfigSection() | |||||
| t_section = ConfigSection() | |||||
| ConfigLoader(config_file_path).load_config(config_file_path, {"test": standard_section, "t": t_section}) | |||||
| config_saver = ConfigSaver(config_file_path) | |||||
| section = ConfigSection() | |||||
| section["doubles"] = 0.8 | |||||
| section["tt"] = 0.5 | |||||
| section["test"] = 105 | |||||
| section["str"] = "this is a str" | |||||
| test_case_2_section = section | |||||
| test_case_2_section["double"] = 0.5 | |||||
| for k in section.__dict__.keys(): | |||||
| standard_section[k] = section[k] | |||||
| config_saver.save_config_file("test", section) | |||||
| config_saver.save_config_file("another-test", section) | |||||
| config_saver.save_config_file("one-another-test", section) | |||||
| config_saver.save_config_file("test-case-2", section) | |||||
| test_section = ConfigSection() | |||||
| at_section = ConfigSection() | |||||
| another_test_section = ConfigSection() | |||||
| one_another_test_section = ConfigSection() | |||||
| a_test_case_2_section = ConfigSection() | |||||
| ConfigLoader(config_file_path).load_config(config_file_path, {"test": test_section, | |||||
| "another-test": another_test_section, | |||||
| "t": at_section, | |||||
| "one-another-test": one_another_test_section, | |||||
| "test-case-2": a_test_case_2_section}) | |||||
| assert test_section == standard_section | |||||
| assert at_section == t_section | |||||
| assert another_test_section == section | |||||
| assert one_another_test_section == section | |||||
| assert a_test_case_2_section == test_case_2_section | |||||
| config_saver.save_config_file("test", section) | |||||
| with open(config_file_path, "w") as f: | |||||
| f.writelines(lines) | |||||
| with open(tmp_config_file_path, "w") as f: | |||||
| f.write('[test]\n') | |||||
| f.write('this is an fault example\n') | |||||
| tmp_config_saver = ConfigSaver(tmp_config_file_path) | |||||
| try: | |||||
| tmp_config_saver._read_section() | |||||
| except Exception as e: | |||||
| pass | |||||
| os.remove(tmp_config_file_path) | |||||
| try: | |||||
| tmp_config_saver = ConfigSaver("file-NOT-exist") | |||||
| except Exception as e: | |||||
| pass | |||||