| @@ -1,7 +0,0 @@ | |||
| fastNLP.io.base\_loader | |||
| ======================= | |||
| .. automodule:: fastNLP.io.base_loader | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| @@ -0,0 +1,7 @@ | |||
| fastNLP.io.data\_bundle | |||
| ======================= | |||
| .. automodule:: fastNLP.io.data_bundle | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| @@ -20,7 +20,7 @@ Submodules | |||
| .. toctree:: | |||
| fastNLP.io.base_loader | |||
| fastNLP.io.data_bundle | |||
| fastNLP.io.dataset_loader | |||
| fastNLP.io.embed_loader | |||
| fastNLP.io.file_utils | |||
| @@ -12,10 +12,9 @@ | |||
| 这些类的使用方法如下: | |||
| """ | |||
| __all__ = [ | |||
| 'EmbedLoader', | |||
| 'DataBundle', | |||
| 'DataSetLoader', | |||
| 'EmbedLoader', | |||
| 'YelpLoader', | |||
| 'YelpFullLoader', | |||
| @@ -69,7 +68,7 @@ __all__ = [ | |||
| ] | |||
| from .embed_loader import EmbedLoader | |||
| from .base_loader import DataBundle, DataSetLoader | |||
| from .data_bundle import DataBundle | |||
| from .dataset_loader import CSVLoader, JsonLoader | |||
| from .model_io import ModelLoader, ModelSaver | |||
| @@ -1,313 +0,0 @@ | |||
| """ | |||
| 用于读入和处理和保存 config 文件 | |||
| .. todo:: | |||
| 这个模块中的类可能被抛弃? | |||
| """ | |||
| __all__ = [ | |||
| "ConfigLoader", | |||
| "ConfigSection", | |||
| "ConfigSaver" | |||
| ] | |||
| import configparser | |||
| import json | |||
| import os | |||
| from .base_loader import BaseLoader | |||
| class ConfigLoader(BaseLoader): | |||
| """ | |||
| 别名::class:`fastNLP.io.ConfigLoader` :class:`fastNLP.io.config_io.ConfigLoader` | |||
| 读取配置文件的Loader | |||
| :param str data_path: 配置文件的路径 | |||
| """ | |||
| def __init__(self, data_path=None): | |||
| super(ConfigLoader, self).__init__() | |||
| if data_path is not None: | |||
| self.config = self.parse(super(ConfigLoader, self).load(data_path)) | |||
| @staticmethod | |||
| def parse(string): | |||
| raise NotImplementedError | |||
| @staticmethod | |||
| def load_config(file_path, sections): | |||
| """ | |||
| 把配置文件的section 存入提供的 ``sections`` 中 | |||
| :param str file_path: 配置文件的路径 | |||
| :param dict sections: 符合如下键值对组成的字典 `section_name(string)` : :class:`~fastNLP.io.ConfigSection` | |||
| Example:: | |||
| test_args = ConfigSection() | |||
| ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||
| """ | |||
| assert isinstance(sections, dict) | |||
| cfg = configparser.ConfigParser() | |||
| if not os.path.exists(file_path): | |||
| raise FileNotFoundError("config file {} not found. ".format(file_path)) | |||
| cfg.read(file_path) | |||
| for s in sections: | |||
| attr_list = [i for i in sections[s].__dict__.keys() if | |||
| not callable(getattr(sections[s], i)) and not i.startswith("__")] | |||
| if s not in cfg: | |||
| print('section %s not found in config file' % (s)) | |||
| continue | |||
| gen_sec = cfg[s] | |||
| for attr in gen_sec.keys(): | |||
| try: | |||
| val = json.loads(gen_sec[attr]) | |||
| # print(s, attr, val, type(val)) | |||
| if attr in attr_list: | |||
| assert type(val) == type(getattr(sections[s], attr)), \ | |||
| 'type not match, except %s but got %s' % \ | |||
| (type(getattr(sections[s], attr)), type(val)) | |||
| """ | |||
| if attr in attr_list then check its type and | |||
| update its value. | |||
| else add a new attr in sections[s] | |||
| """ | |||
| setattr(sections[s], attr, val) | |||
| except Exception as e: | |||
| print("cannot load attribute %s in section %s" | |||
| % (attr, s)) | |||
| pass | |||
| class ConfigSection(object): | |||
| """ | |||
| 别名::class:`fastNLP.io.ConfigSection` :class:`fastNLP.io.config_io.ConfigSection` | |||
| ConfigSection是一个存储了一个section中所有键值对的数据结构,推荐使用此类的实例来配合 :meth:`ConfigLoader.load_config` 使用 | |||
| """ | |||
| def __init__(self): | |||
| super(ConfigSection, self).__init__() | |||
| def __getitem__(self, key): | |||
| """ | |||
| :param key: str, the name of the attribute | |||
| :return attr: the value of this attribute | |||
| if key not in self.__dict__.keys(): | |||
| return self[key] | |||
| else: | |||
| raise AttributeError | |||
| """ | |||
| if key in self.__dict__.keys(): | |||
| return getattr(self, key) | |||
| raise AttributeError("do NOT have attribute %s" % key) | |||
| def __setitem__(self, key, value): | |||
| """ | |||
| :param key: str, the name of the attribute | |||
| :param value: the value of this attribute | |||
| if key not in self.__dict__.keys(): | |||
| self[key] will be added | |||
| else: | |||
| self[key] will be updated | |||
| """ | |||
| if key in self.__dict__.keys(): | |||
| if not isinstance(value, type(getattr(self, key))): | |||
| raise AttributeError("attr %s except %s but got %s" % | |||
| (key, str(type(getattr(self, key))), str(type(value)))) | |||
| setattr(self, key, value) | |||
| def __contains__(self, item): | |||
| """ | |||
| :param item: The key of item. | |||
| :return: True if the key in self.__dict__.keys() else False. | |||
| """ | |||
| return item in self.__dict__.keys() | |||
| def __eq__(self, other): | |||
| """Overwrite the == operator | |||
| :param other: Another ConfigSection() object which to be compared. | |||
| :return: True if value of each key in each ConfigSection() object are equal to the other, else False. | |||
| """ | |||
| for k in self.__dict__.keys(): | |||
| if k not in other.__dict__.keys(): | |||
| return False | |||
| if getattr(self, k) != getattr(self, k): | |||
| return False | |||
| for k in other.__dict__.keys(): | |||
| if k not in self.__dict__.keys(): | |||
| return False | |||
| if getattr(self, k) != getattr(self, k): | |||
| return False | |||
| return True | |||
| def __ne__(self, other): | |||
| """Overwrite the != operator | |||
| :param other: | |||
| :return: | |||
| """ | |||
| return not self.__eq__(other) | |||
| @property | |||
| def data(self): | |||
| return self.__dict__ | |||
| class ConfigSaver(object): | |||
| """ | |||
| 别名::class:`fastNLP.io.ConfigSaver` :class:`fastNLP.io.config_io.ConfigSaver` | |||
| ConfigSaver 是用来存储配置文件并解决相关冲突的类 | |||
| :param str file_path: 配置文件的路径 | |||
| """ | |||
| def __init__(self, file_path): | |||
| self.file_path = file_path | |||
| if not os.path.exists(self.file_path): | |||
| raise FileNotFoundError("file {} NOT found!".__format__(self.file_path)) | |||
| def _get_section(self, sect_name): | |||
| """ | |||
| This is the function to get the section with the section name. | |||
| :param sect_name: The name of section what wants to load. | |||
| :return: The section. | |||
| """ | |||
| sect = ConfigSection() | |||
| ConfigLoader().load_config(self.file_path, {sect_name: sect}) | |||
| return sect | |||
| def _read_section(self): | |||
| """ | |||
| This is the function to read sections from the config file. | |||
| :return: sect_list, sect_key_list | |||
| sect_list: A list of ConfigSection(). | |||
| sect_key_list: A list of names in sect_list. | |||
| """ | |||
| sect_name = None | |||
| sect_list = {} | |||
| sect_key_list = [] | |||
| single_section = {} | |||
| single_section_key = [] | |||
| with open(self.file_path, 'r') as f: | |||
| lines = f.readlines() | |||
| for line in lines: | |||
| if line.startswith('[') and line.endswith(']\n'): | |||
| if sect_name is None: | |||
| pass | |||
| else: | |||
| sect_list[sect_name] = single_section, single_section_key | |||
| single_section = {} | |||
| single_section_key = [] | |||
| sect_key_list.append(sect_name) | |||
| sect_name = line[1: -2] | |||
| continue | |||
| if line.startswith('#'): | |||
| single_section[line] = '#' | |||
| single_section_key.append(line) | |||
| continue | |||
| if line.startswith('\n'): | |||
| single_section_key.append('\n') | |||
| continue | |||
| if '=' not in line: | |||
| raise RuntimeError("can NOT load config file {}".__format__(self.file_path)) | |||
| key = line.split('=', maxsplit=1)[0].strip() | |||
| value = line.split('=', maxsplit=1)[1].strip() + '\n' | |||
| single_section[key] = value | |||
| single_section_key.append(key) | |||
| if sect_name is not None: | |||
| sect_list[sect_name] = single_section, single_section_key | |||
| sect_key_list.append(sect_name) | |||
| return sect_list, sect_key_list | |||
| def _write_section(self, sect_list, sect_key_list): | |||
| """ | |||
| This is the function to write config file with section list and name list. | |||
| :param sect_list: A list of ConfigSection() need to be writen into file. | |||
| :param sect_key_list: A list of name of sect_list. | |||
| :return: | |||
| """ | |||
| with open(self.file_path, 'w') as f: | |||
| for sect_key in sect_key_list: | |||
| single_section, single_section_key = sect_list[sect_key] | |||
| f.write('[' + sect_key + ']\n') | |||
| for key in single_section_key: | |||
| if key == '\n': | |||
| f.write('\n') | |||
| continue | |||
| if single_section[key] == '#': | |||
| f.write(key) | |||
| continue | |||
| f.write(key + ' = ' + single_section[key]) | |||
| f.write('\n') | |||
| def save_config_file(self, section_name, section): | |||
| """ | |||
| 这个方法可以用来修改并保存配置文件中单独的一个 section | |||
| :param str section_name: 需要保存的 section 的名字. | |||
| :param section: 你需要修改并保存的 section, :class:`~fastNLP.io.ConfigSaver` 类型 | |||
| """ | |||
| section_file = self._get_section(section_name) | |||
| if len(section_file.__dict__.keys()) == 0: # the section not in the file before | |||
| # append this section to config file | |||
| with open(self.file_path, 'a') as f: | |||
| f.write('[' + section_name + ']\n') | |||
| for k in section.__dict__.keys(): | |||
| f.write(k + ' = ') | |||
| if isinstance(section[k], str): | |||
| f.write('\"' + str(section[k]) + '\"\n\n') | |||
| else: | |||
| f.write(str(section[k]) + '\n\n') | |||
| else: | |||
| # the section exists | |||
| change_file = False | |||
| for k in section.__dict__.keys(): | |||
| if k not in section_file: | |||
| # find a new key in this section | |||
| change_file = True | |||
| break | |||
| if section_file[k] != section[k]: | |||
| change_file = True | |||
| break | |||
| if not change_file: | |||
| return | |||
| sect_list, sect_key_list = self._read_section() | |||
| if section_name not in sect_key_list: | |||
| raise AttributeError() | |||
| sect, sect_key = sect_list[section_name] | |||
| for k in section.__dict__.keys(): | |||
| if k not in sect_key: | |||
| if sect_key[-1] != '\n': | |||
| sect_key.append('\n') | |||
| sect_key.append(k) | |||
| sect[k] = str(section[k]) | |||
| if isinstance(section[k], str): | |||
| sect[k] = "\"" + sect[k] + "\"" | |||
| sect[k] = sect[k] + "\n" | |||
| sect_list[section_name] = sect, sect_key | |||
| self._write_section(sect_list, sect_key_list) | |||
| @@ -1,7 +1,5 @@ | |||
| __all__ = [ | |||
| "BaseLoader", | |||
| 'DataBundle', | |||
| 'DataSetLoader', | |||
| ] | |||
| import _pickle as pickle | |||
| @@ -1,11 +1,11 @@ | |||
| from ...core.dataset import DataSet | |||
| from ...core.instance import Instance | |||
| from ..base_loader import DataSetLoader | |||
| from ..data_bundle import DataSetLoader | |||
| from ..file_reader import _read_conll | |||
| from typing import Union, Dict | |||
| from ..utils import check_loader_paths | |||
| from ..base_loader import DataBundle | |||
| from ..data_bundle import DataBundle | |||
| class ConllLoader(DataSetLoader): | |||
| """ | |||
| @@ -2,7 +2,7 @@ | |||
| from typing import Union, Dict | |||
| from ..embed_loader import EmbeddingOption, EmbedLoader | |||
| from ..base_loader import DataSetLoader, DataBundle | |||
| from ..data_bundle import DataSetLoader, DataBundle | |||
| from ...core.vocabulary import VocabularyOption, Vocabulary | |||
| from ...core.dataset import DataSet | |||
| from ...core.instance import Instance | |||
| @@ -4,7 +4,7 @@ from typing import Union, Dict, List | |||
| from ...core.const import Const | |||
| from ...core.vocabulary import Vocabulary | |||
| from ..base_loader import DataBundle, DataSetLoader | |||
| from ..data_bundle import DataBundle, DataSetLoader | |||
| from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||
| from ...modules.encoder.bert import BertTokenizer | |||
| @@ -1,7 +1,7 @@ | |||
| from typing import Union, Dict | |||
| from ..base_loader import DataBundle | |||
| from ..data_bundle import DataBundle | |||
| from ..dataset_loader import CSVLoader | |||
| from ...core.vocabulary import Vocabulary, VocabularyOption | |||
| from ...core.const import Const | |||
| @@ -1,5 +1,5 @@ | |||
| from ..base_loader import DataSetLoader | |||
| from ..data_bundle import DataSetLoader | |||
| from ...core.dataset import DataSet | |||
| from ...core.instance import Instance | |||
| from ...core.const import Const | |||
| @@ -2,7 +2,7 @@ | |||
| from typing import Union, Dict | |||
| from nltk import Tree | |||
| from ..base_loader import DataBundle, DataSetLoader | |||
| from ..data_bundle import DataBundle, DataSetLoader | |||
| from ..dataset_loader import CSVLoader | |||
| from ...core.vocabulary import VocabularyOption, Vocabulary | |||
| from ...core.dataset import DataSet | |||
| @@ -6,7 +6,7 @@ from ...core.const import Const | |||
| from ...core.dataset import DataSet | |||
| from ...core.instance import Instance | |||
| from ...core.vocabulary import VocabularyOption, Vocabulary | |||
| from ..base_loader import DataBundle, DataSetLoader | |||
| from ..data_bundle import DataBundle, DataSetLoader | |||
| from typing import Union, Dict | |||
| from ..utils import check_loader_paths, get_tokenizer | |||
| @@ -26,7 +26,7 @@ __all__ = [ | |||
| from ..core.dataset import DataSet | |||
| from ..core.instance import Instance | |||
| from .file_reader import _read_csv, _read_json | |||
| from .base_loader import DataSetLoader | |||
| from .data_bundle import DataSetLoader | |||
| class JsonLoader(DataSetLoader): | |||
| @@ -9,7 +9,7 @@ import warnings | |||
| import numpy as np | |||
| from ..core.vocabulary import Vocabulary | |||
| from .base_loader import BaseLoader | |||
| from .data_bundle import BaseLoader | |||
| from ..core.utils import Option | |||
| @@ -44,6 +44,8 @@ fastNLP 目前提供了如下的 Loader | |||
| """ | |||
| __all__ = [ | |||
| 'Loader', | |||
| 'YelpLoader', | |||
| 'YelpFullLoader', | |||
| 'YelpPolarityLoader', | |||
| @@ -57,7 +59,6 @@ __all__ = [ | |||
| 'OntoNotesNERLoader', | |||
| 'CTBLoader', | |||
| 'Loader', | |||
| 'CSVLoader', | |||
| 'JsonLoader', | |||
| @@ -7,6 +7,7 @@ import random | |||
| import shutil | |||
| import numpy as np | |||
| class YelpLoader(Loader): | |||
| """ | |||
| 别名::class:`fastNLP.io.YelpLoader` :class:`fastNLP.io.loader.YelpLoader` | |||
| @@ -14,6 +15,7 @@ class YelpLoader(Loader): | |||
| 原始数据中内容应该为, 每一行为一个sample,第一个逗号之前为target,第一个逗号之后为文本内容。 | |||
| Example:: | |||
| "1","I got 'new' tires from the..." | |||
| "1","Don't waste your time..." | |||
| @@ -28,11 +30,11 @@ class YelpLoader(Loader): | |||
| "...", "..." | |||
| """ | |||
| def __init__(self): | |||
| super(YelpLoader, self).__init__() | |||
| def _load(self, path: str=None): | |||
| def _load(self, path: str = None): | |||
| ds = DataSet() | |||
| with open(path, 'r', encoding='utf-8') as f: | |||
| for line in f: | |||
| @@ -69,12 +71,12 @@ class YelpFullLoader(YelpLoader): | |||
| :param int seed: 划分dev时的随机数种子 | |||
| :return: str, 数据集的目录地址 | |||
| """ | |||
| dataset_name = 'yelp-review-full' | |||
| data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
| if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否需要重新下载 | |||
| re_download = True | |||
| if dev_ratio>0: | |||
| if dev_ratio > 0: | |||
| dev_line_count = 0 | |||
| tr_line_count = 0 | |||
| with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \ | |||
| @@ -83,14 +85,14 @@ class YelpFullLoader(YelpLoader): | |||
| tr_line_count += 1 | |||
| for line in f2: | |||
| dev_line_count += 1 | |||
| if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): | |||
| if not np.isclose(dev_line_count, dev_ratio * (tr_line_count + dev_line_count), rtol=0.005): | |||
| re_download = True | |||
| else: | |||
| re_download = False | |||
| if re_download: | |||
| shutil.rmtree(data_dir) | |||
| data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
| if not os.path.exists(os.path.join(data_dir, 'dev.csv')): | |||
| if dev_ratio > 0: | |||
| assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | |||
| @@ -109,7 +111,7 @@ class YelpFullLoader(YelpLoader): | |||
| finally: | |||
| if os.path.exists(os.path.join(data_dir, 'middle_file.csv')): | |||
| os.remove(os.path.join(data_dir, 'middle_file.csv')) | |||
| return data_dir | |||
| @@ -131,7 +133,7 @@ class YelpPolarityLoader(YelpLoader): | |||
| data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
| if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否符合比例要求 | |||
| re_download = True | |||
| if dev_ratio>0: | |||
| if dev_ratio > 0: | |||
| dev_line_count = 0 | |||
| tr_line_count = 0 | |||
| with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \ | |||
| @@ -140,14 +142,14 @@ class YelpPolarityLoader(YelpLoader): | |||
| tr_line_count += 1 | |||
| for line in f2: | |||
| dev_line_count += 1 | |||
| if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): | |||
| if not np.isclose(dev_line_count, dev_ratio * (tr_line_count + dev_line_count), rtol=0.005): | |||
| re_download = True | |||
| else: | |||
| re_download = False | |||
| if re_download: | |||
| shutil.rmtree(data_dir) | |||
| data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
| if not os.path.exists(os.path.join(data_dir, 'dev.csv')): | |||
| if dev_ratio > 0: | |||
| assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | |||
| @@ -166,7 +168,7 @@ class YelpPolarityLoader(YelpLoader): | |||
| finally: | |||
| if os.path.exists(os.path.join(data_dir, 'middle_file.csv')): | |||
| os.remove(os.path.join(data_dir, 'middle_file.csv')) | |||
| return data_dir | |||
| @@ -185,10 +187,10 @@ class IMDBLoader(Loader): | |||
| "...", "..." | |||
| """ | |||
| def __init__(self): | |||
| super(IMDBLoader, self).__init__() | |||
| def _load(self, path: str): | |||
| dataset = DataSet() | |||
| with open(path, 'r', encoding="utf-8") as f: | |||
| @@ -201,12 +203,12 @@ class IMDBLoader(Loader): | |||
| words = parts[1] | |||
| if words: | |||
| dataset.append(Instance(raw_words=words, target=target)) | |||
| if len(dataset) == 0: | |||
| raise RuntimeError(f"{path} has no valid data.") | |||
| return dataset | |||
| def download(self, dev_ratio: float = 0.1, seed: int = 0): | |||
| """ | |||
| 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||
| @@ -221,9 +223,9 @@ class IMDBLoader(Loader): | |||
| """ | |||
| dataset_name = 'aclImdb' | |||
| data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
| if os.path.exists(os.path.join(data_dir, 'dev.txt')): # 存在dev的话,check是否符合比例要求 | |||
| if os.path.exists(os.path.join(data_dir, 'dev.txt')): # 存在dev的话,check是否符合比例要求 | |||
| re_download = True | |||
| if dev_ratio>0: | |||
| if dev_ratio > 0: | |||
| dev_line_count = 0 | |||
| tr_line_count = 0 | |||
| with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f1, \ | |||
| @@ -232,14 +234,14 @@ class IMDBLoader(Loader): | |||
| tr_line_count += 1 | |||
| for line in f2: | |||
| dev_line_count += 1 | |||
| if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): | |||
| if not np.isclose(dev_line_count, dev_ratio * (tr_line_count + dev_line_count), rtol=0.005): | |||
| re_download = True | |||
| else: | |||
| re_download = False | |||
| if re_download: | |||
| shutil.rmtree(data_dir) | |||
| data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
| if not os.path.exists(os.path.join(data_dir, 'dev.csv')): | |||
| if dev_ratio > 0: | |||
| assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | |||
| @@ -258,7 +260,7 @@ class IMDBLoader(Loader): | |||
| finally: | |||
| if os.path.exists(os.path.join(data_dir, 'middle_file.txt')): | |||
| os.remove(os.path.join(data_dir, 'middle_file.txt')) | |||
| return data_dir | |||
| @@ -278,10 +280,10 @@ class SSTLoader(Loader): | |||
| raw_words列是str。 | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| def _load(self, path: str): | |||
| """ | |||
| 从path读取SST文件 | |||
| @@ -296,7 +298,7 @@ class SSTLoader(Loader): | |||
| if line: | |||
| ds.append(Instance(raw_words=line)) | |||
| return ds | |||
| def download(self): | |||
| """ | |||
| 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||
| @@ -323,10 +325,10 @@ class SST2Loader(Loader): | |||
| test的DataSet没有target列。 | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| def _load(self, path: str): | |||
| """ | |||
| 从path读取SST2文件 | |||
| @@ -335,7 +337,7 @@ class SST2Loader(Loader): | |||
| :return: DataSet | |||
| """ | |||
| ds = DataSet() | |||
| with open(path, 'r', encoding='utf-8') as f: | |||
| f.readline() # 跳过header | |||
| if 'test' in os.path.split(path)[1]: | |||
| @@ -356,7 +358,7 @@ class SST2Loader(Loader): | |||
| if raw_words: | |||
| ds.append(Instance(raw_words=raw_words, target=target)) | |||
| return ds | |||
| def download(self): | |||
| """ | |||
| 自动下载数据集,如果你使用了该数据集,请引用以下的文章 | |||
| @@ -2,17 +2,21 @@ from ...core.dataset import DataSet | |||
| from .. import DataBundle | |||
| from ..utils import check_loader_paths | |||
| from typing import Union, Dict | |||
| import os | |||
| from ..file_utils import _get_dataset_url, get_cache_path, cached_path | |||
| class Loader: | |||
| """ | |||
| 各种数据 Loader 的基类,提供了 API 的参考. | |||
| """ | |||
| def __init__(self): | |||
| pass | |||
| def _load(self, path:str) -> DataSet: | |||
| def _load(self, path: str) -> DataSet: | |||
| raise NotImplementedError | |||
| def load(self, paths: Union[str, Dict[str, str]]=None) -> DataBundle: | |||
| def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: | |||
| """ | |||
| 从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。 | |||
| @@ -22,31 +26,25 @@ class Loader: | |||
| (0) 如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。 | |||
| (1) 传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件 | |||
| 名包含'train'、 'dev'、 'test'则会报错 | |||
| Example:: | |||
| 名包含'train'、 'dev'、 'test'则会报错:: | |||
| data_bundle = ConllLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train、 | |||
| # dev、 test等有所变化,可以通过以下的方式取出DataSet | |||
| tr_data = data_bundle.datasets['train'] | |||
| te_data = data_bundle.datasets['test'] # 如果目录下有文件包含test这个字段 | |||
| (2) 传入文件路径 | |||
| Example:: | |||
| (2) 传入文件路径:: | |||
| data_bundle = ConllLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' | |||
| tr_data = data_bundle.datasets['train'] # 可以通过以下的方式取出DataSet | |||
| (3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test | |||
| Example:: | |||
| (3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test:: | |||
| paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} | |||
| data_bundle = ConllLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" | |||
| dev_data = data_bundle.datasets['dev'] | |||
| :return: 返回的:class:`~fastNLP.io.DataBundle` | |||
| :return: 返回的 :class:`~fastNLP.io.DataBundle` | |||
| """ | |||
| if paths is None: | |||
| paths = self.download() | |||
| @@ -54,10 +52,10 @@ class Loader: | |||
| datasets = {name: self._load(path) for name, path in paths.items()} | |||
| data_bundle = DataBundle(datasets=datasets) | |||
| return data_bundle | |||
| def download(self): | |||
| raise NotImplementedError(f"{self.__class__} cannot download data automatically.") | |||
| def _get_dataset_path(self, dataset_name): | |||
| """ | |||
| 传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存 | |||
| @@ -65,11 +63,9 @@ class Loader: | |||
| :param str dataset_name: 数据集的名称 | |||
| :return: str, 数据集的目录地址。直接到该目录下读取相应的数据即可。 | |||
| """ | |||
| default_cache_path = get_cache_path() | |||
| url = _get_dataset_url(dataset_name) | |||
| output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset') | |||
| return output_dir | |||
| @@ -203,7 +203,8 @@ class QNLILoader(JsonLoader): | |||
| """ | |||
| 如果您的实验使用到了该数据,请引用 | |||
| TODO 补充 | |||
| .. todo:: | |||
| 补充 | |||
| :return: | |||
| """ | |||
| @@ -8,7 +8,7 @@ __all__ = [ | |||
| import torch | |||
| from .base_loader import BaseLoader | |||
| from .data_bundle import BaseLoader | |||
| class ModelLoader(BaseLoader): | |||
| @@ -1,6 +1,6 @@ | |||
| from nltk import Tree | |||
| from ..base_loader import DataBundle | |||
| from ..data_bundle import DataBundle | |||
| from ...core.vocabulary import Vocabulary | |||
| from ...core.const import Const | |||
| from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader | |||
| @@ -1,188 +1,188 @@ | |||
| import pickle | |||
| import numpy as np | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| from fastNLP.io.base_loader import DataBundle | |||
| from fastNLP.io.dataset_loader import JsonLoader | |||
| from fastNLP.core.const import Const | |||
| from tools.logger import * | |||
| WORD_PAD = "[PAD]" | |||
| WORD_UNK = "[UNK]" | |||
| DOMAIN_UNK = "X" | |||
| TAG_UNK = "X" | |||
| class SummarizationLoader(JsonLoader): | |||
| """ | |||
| 读取summarization数据集,读取的DataSet包含fields:: | |||
| text: list(str),document | |||
| summary: list(str), summary | |||
| text_wd: list(list(str)),tokenized document | |||
| summary_wd: list(list(str)), tokenized summary | |||
| labels: list(int), | |||
| flatten_label: list(int), 0 or 1, flatten labels | |||
| domain: str, optional | |||
| tag: list(str), optional | |||
| 数据来源: CNN_DailyMail Newsroom DUC | |||
| """ | |||
| def __init__(self): | |||
| super(SummarizationLoader, self).__init__() | |||
| def _load(self, path): | |||
| ds = super(SummarizationLoader, self)._load(path) | |||
| def _lower_text(text_list): | |||
| return [text.lower() for text in text_list] | |||
| def _split_list(text_list): | |||
| return [text.split() for text in text_list] | |||
| def _convert_label(label, sent_len): | |||
| np_label = np.zeros(sent_len, dtype=int) | |||
| if label != []: | |||
| np_label[np.array(label)] = 1 | |||
| return np_label.tolist() | |||
| ds.apply(lambda x: _lower_text(x['text']), new_field_name='text') | |||
| ds.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') | |||
| ds.apply(lambda x:_split_list(x['text']), new_field_name='text_wd') | |||
| ds.apply(lambda x:_split_list(x['summary']), new_field_name='summary_wd') | |||
| ds.apply(lambda x:_convert_label(x["label"], len(x["text"])), new_field_name="flatten_label") | |||
| return ds | |||
| def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab_file=True): | |||
| """ | |||
| :param paths: dict path for each dataset | |||
| :param vocab_size: int max_size for vocab | |||
| :param vocab_path: str vocab path | |||
| :param sent_max_len: int max token number of the sentence | |||
| :param doc_max_timesteps: int max sentence number of the document | |||
| :param domain: bool build vocab for publication, use 'X' for unknown | |||
| :param tag: bool build vocab for tag, use 'X' for unknown | |||
| :param load_vocab_file: bool build vocab (False) or load vocab (True) | |||
| :return: DataBundle | |||
| datasets: dict keys correspond to the paths dict | |||
| vocabs: dict key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True) | |||
| embeddings: optional | |||
| """ | |||
| def _pad_sent(text_wd): | |||
| pad_text_wd = [] | |||
| for sent_wd in text_wd: | |||
| if len(sent_wd) < sent_max_len: | |||
| pad_num = sent_max_len - len(sent_wd) | |||
| sent_wd.extend([WORD_PAD] * pad_num) | |||
| else: | |||
| sent_wd = sent_wd[:sent_max_len] | |||
| pad_text_wd.append(sent_wd) | |||
| return pad_text_wd | |||
| def _token_mask(text_wd): | |||
| token_mask_list = [] | |||
| for sent_wd in text_wd: | |||
| token_num = len(sent_wd) | |||
| if token_num < sent_max_len: | |||
| mask = [1] * token_num + [0] * (sent_max_len - token_num) | |||
| else: | |||
| mask = [1] * sent_max_len | |||
| token_mask_list.append(mask) | |||
| return token_mask_list | |||
| def _pad_label(label): | |||
| text_len = len(label) | |||
| if text_len < doc_max_timesteps: | |||
| pad_label = label + [0] * (doc_max_timesteps - text_len) | |||
| else: | |||
| pad_label = label[:doc_max_timesteps] | |||
| return pad_label | |||
| def _pad_doc(text_wd): | |||
| text_len = len(text_wd) | |||
| if text_len < doc_max_timesteps: | |||
| padding = [WORD_PAD] * sent_max_len | |||
| pad_text = text_wd + [padding] * (doc_max_timesteps - text_len) | |||
| else: | |||
| pad_text = text_wd[:doc_max_timesteps] | |||
| return pad_text | |||
| def _sent_mask(text_wd): | |||
| text_len = len(text_wd) | |||
| if text_len < doc_max_timesteps: | |||
| sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len) | |||
| else: | |||
| sent_mask = [1] * doc_max_timesteps | |||
| return sent_mask | |||
| datasets = {} | |||
| train_ds = None | |||
| for key, value in paths.items(): | |||
| ds = self.load(value) | |||
| # pad sent | |||
| ds.apply(lambda x:_pad_sent(x["text_wd"]), new_field_name="pad_text_wd") | |||
| ds.apply(lambda x:_token_mask(x["text_wd"]), new_field_name="pad_token_mask") | |||
| # pad document | |||
| ds.apply(lambda x:_pad_doc(x["pad_text_wd"]), new_field_name="pad_text") | |||
| ds.apply(lambda x:_sent_mask(x["pad_text_wd"]), new_field_name="seq_len") | |||
| ds.apply(lambda x:_pad_label(x["flatten_label"]), new_field_name="pad_label") | |||
| # rename field | |||
| ds.rename_field("pad_text", Const.INPUT) | |||
| ds.rename_field("seq_len", Const.INPUT_LEN) | |||
| ds.rename_field("pad_label", Const.TARGET) | |||
| # set input and target | |||
| ds.set_input(Const.INPUT, Const.INPUT_LEN) | |||
| ds.set_target(Const.TARGET, Const.INPUT_LEN) | |||
| datasets[key] = ds | |||
| if "train" in key: | |||
| train_ds = datasets[key] | |||
| vocab_dict = {} | |||
| if load_vocab_file == False: | |||
| logger.info("[INFO] Build new vocab from training dataset!") | |||
| if train_ds == None: | |||
| raise ValueError("Lack train file to build vocabulary!") | |||
| vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||
| vocabs.from_dataset(train_ds, field_name=["text_wd","summary_wd"]) | |||
| vocab_dict["vocab"] = vocabs | |||
| else: | |||
| logger.info("[INFO] Load existing vocab from %s!" % vocab_path) | |||
| word_list = [] | |||
| with open(vocab_path, 'r', encoding='utf8') as vocab_f: | |||
| cnt = 2 # pad and unk | |||
| for line in vocab_f: | |||
| pieces = line.split("\t") | |||
| word_list.append(pieces[0]) | |||
| cnt += 1 | |||
| if cnt > vocab_size: | |||
| break | |||
| vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||
| vocabs.add_word_lst(word_list) | |||
| vocabs.build_vocab() | |||
| vocab_dict["vocab"] = vocabs | |||
| if domain == True: | |||
| domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK) | |||
| domaindict.from_dataset(train_ds, field_name="publication") | |||
| vocab_dict["domain"] = domaindict | |||
| if tag == True: | |||
| tagdict = Vocabulary(padding=None, unknown=TAG_UNK) | |||
| tagdict.from_dataset(train_ds, field_name="tag") | |||
| vocab_dict["tag"] = tagdict | |||
| for ds in datasets.values(): | |||
| vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT) | |||
| return DataBundle(vocabs=vocab_dict, datasets=datasets) | |||
| import pickle | |||
| import numpy as np | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| from fastNLP.io.data_bundle import DataBundle | |||
| from fastNLP.io.dataset_loader import JsonLoader | |||
| from fastNLP.core.const import Const | |||
| from tools.logger import * | |||
| WORD_PAD = "[PAD]" | |||
| WORD_UNK = "[UNK]" | |||
| DOMAIN_UNK = "X" | |||
| TAG_UNK = "X" | |||
| class SummarizationLoader(JsonLoader): | |||
| """ | |||
| 读取summarization数据集,读取的DataSet包含fields:: | |||
| text: list(str),document | |||
| summary: list(str), summary | |||
| text_wd: list(list(str)),tokenized document | |||
| summary_wd: list(list(str)), tokenized summary | |||
| labels: list(int), | |||
| flatten_label: list(int), 0 or 1, flatten labels | |||
| domain: str, optional | |||
| tag: list(str), optional | |||
| 数据来源: CNN_DailyMail Newsroom DUC | |||
| """ | |||
| def __init__(self): | |||
| super(SummarizationLoader, self).__init__() | |||
| def _load(self, path): | |||
| ds = super(SummarizationLoader, self)._load(path) | |||
| def _lower_text(text_list): | |||
| return [text.lower() for text in text_list] | |||
| def _split_list(text_list): | |||
| return [text.split() for text in text_list] | |||
| def _convert_label(label, sent_len): | |||
| np_label = np.zeros(sent_len, dtype=int) | |||
| if label != []: | |||
| np_label[np.array(label)] = 1 | |||
| return np_label.tolist() | |||
| ds.apply(lambda x: _lower_text(x['text']), new_field_name='text') | |||
| ds.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') | |||
| ds.apply(lambda x:_split_list(x['text']), new_field_name='text_wd') | |||
| ds.apply(lambda x:_split_list(x['summary']), new_field_name='summary_wd') | |||
| ds.apply(lambda x:_convert_label(x["label"], len(x["text"])), new_field_name="flatten_label") | |||
| return ds | |||
| def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab_file=True): | |||
| """ | |||
| :param paths: dict path for each dataset | |||
| :param vocab_size: int max_size for vocab | |||
| :param vocab_path: str vocab path | |||
| :param sent_max_len: int max token number of the sentence | |||
| :param doc_max_timesteps: int max sentence number of the document | |||
| :param domain: bool build vocab for publication, use 'X' for unknown | |||
| :param tag: bool build vocab for tag, use 'X' for unknown | |||
| :param load_vocab_file: bool build vocab (False) or load vocab (True) | |||
| :return: DataBundle | |||
| datasets: dict keys correspond to the paths dict | |||
| vocabs: dict key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True) | |||
| embeddings: optional | |||
| """ | |||
| def _pad_sent(text_wd): | |||
| pad_text_wd = [] | |||
| for sent_wd in text_wd: | |||
| if len(sent_wd) < sent_max_len: | |||
| pad_num = sent_max_len - len(sent_wd) | |||
| sent_wd.extend([WORD_PAD] * pad_num) | |||
| else: | |||
| sent_wd = sent_wd[:sent_max_len] | |||
| pad_text_wd.append(sent_wd) | |||
| return pad_text_wd | |||
| def _token_mask(text_wd): | |||
| token_mask_list = [] | |||
| for sent_wd in text_wd: | |||
| token_num = len(sent_wd) | |||
| if token_num < sent_max_len: | |||
| mask = [1] * token_num + [0] * (sent_max_len - token_num) | |||
| else: | |||
| mask = [1] * sent_max_len | |||
| token_mask_list.append(mask) | |||
| return token_mask_list | |||
| def _pad_label(label): | |||
| text_len = len(label) | |||
| if text_len < doc_max_timesteps: | |||
| pad_label = label + [0] * (doc_max_timesteps - text_len) | |||
| else: | |||
| pad_label = label[:doc_max_timesteps] | |||
| return pad_label | |||
| def _pad_doc(text_wd): | |||
| text_len = len(text_wd) | |||
| if text_len < doc_max_timesteps: | |||
| padding = [WORD_PAD] * sent_max_len | |||
| pad_text = text_wd + [padding] * (doc_max_timesteps - text_len) | |||
| else: | |||
| pad_text = text_wd[:doc_max_timesteps] | |||
| return pad_text | |||
| def _sent_mask(text_wd): | |||
| text_len = len(text_wd) | |||
| if text_len < doc_max_timesteps: | |||
| sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len) | |||
| else: | |||
| sent_mask = [1] * doc_max_timesteps | |||
| return sent_mask | |||
| datasets = {} | |||
| train_ds = None | |||
| for key, value in paths.items(): | |||
| ds = self.load(value) | |||
| # pad sent | |||
| ds.apply(lambda x:_pad_sent(x["text_wd"]), new_field_name="pad_text_wd") | |||
| ds.apply(lambda x:_token_mask(x["text_wd"]), new_field_name="pad_token_mask") | |||
| # pad document | |||
| ds.apply(lambda x:_pad_doc(x["pad_text_wd"]), new_field_name="pad_text") | |||
| ds.apply(lambda x:_sent_mask(x["pad_text_wd"]), new_field_name="seq_len") | |||
| ds.apply(lambda x:_pad_label(x["flatten_label"]), new_field_name="pad_label") | |||
| # rename field | |||
| ds.rename_field("pad_text", Const.INPUT) | |||
| ds.rename_field("seq_len", Const.INPUT_LEN) | |||
| ds.rename_field("pad_label", Const.TARGET) | |||
| # set input and target | |||
| ds.set_input(Const.INPUT, Const.INPUT_LEN) | |||
| ds.set_target(Const.TARGET, Const.INPUT_LEN) | |||
| datasets[key] = ds | |||
| if "train" in key: | |||
| train_ds = datasets[key] | |||
| vocab_dict = {} | |||
| if load_vocab_file == False: | |||
| logger.info("[INFO] Build new vocab from training dataset!") | |||
| if train_ds == None: | |||
| raise ValueError("Lack train file to build vocabulary!") | |||
| vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||
| vocabs.from_dataset(train_ds, field_name=["text_wd","summary_wd"]) | |||
| vocab_dict["vocab"] = vocabs | |||
| else: | |||
| logger.info("[INFO] Load existing vocab from %s!" % vocab_path) | |||
| word_list = [] | |||
| with open(vocab_path, 'r', encoding='utf8') as vocab_f: | |||
| cnt = 2 # pad and unk | |||
| for line in vocab_f: | |||
| pieces = line.split("\t") | |||
| word_list.append(pieces[0]) | |||
| cnt += 1 | |||
| if cnt > vocab_size: | |||
| break | |||
| vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||
| vocabs.add_word_lst(word_list) | |||
| vocabs.build_vocab() | |||
| vocab_dict["vocab"] = vocabs | |||
| if domain == True: | |||
| domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK) | |||
| domaindict.from_dataset(train_ds, field_name="publication") | |||
| vocab_dict["domain"] = domaindict | |||
| if tag == True: | |||
| tagdict = Vocabulary(padding=None, unknown=TAG_UNK) | |||
| tagdict.from_dataset(train_ds, field_name="tag") | |||
| vocab_dict["tag"] = tagdict | |||
| for ds in datasets.values(): | |||
| vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT) | |||
| return DataBundle(vocabs=vocab_dict, datasets=datasets) | |||
| @@ -3,7 +3,7 @@ from datetime import timedelta | |||
| from fastNLP.io.dataset_loader import JsonLoader | |||
| from fastNLP.modules.encoder._bert import BertTokenizer | |||
| from fastNLP.io.base_loader import DataBundle | |||
| from fastNLP.io.data_bundle import DataBundle | |||
| from fastNLP.core.const import Const | |||
| class BertData(JsonLoader): | |||
| @@ -1,7 +1,7 @@ | |||
| from fastNLP.io.dataset_loader import JsonLoader,DataSet,Instance | |||
| from fastNLP.io.file_reader import _read_json | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| from fastNLP.io.base_loader import DataBundle | |||
| from fastNLP.io.data_bundle import DataBundle | |||
| from reproduction.coreference_resolution.model.config import Config | |||
| import reproduction.coreference_resolution.model.preprocess as preprocess | |||
| @@ -1,6 +1,6 @@ | |||
| from fastNLP.io.base_loader import DataSetLoader, DataBundle | |||
| from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||
| from fastNLP.io.data_loader import ConllLoader | |||
| import numpy as np | |||
| @@ -9,7 +9,7 @@ from typing import Union, Dict | |||
| from fastNLP.core.const import Const | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| from fastNLP.io.base_loader import DataBundle, DataSetLoader | |||
| from fastNLP.io.data_bundle import DataBundle, DataSetLoader | |||
| from fastNLP.io.dataset_loader import JsonLoader, CSVLoader | |||
| from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||
| from fastNLP.modules.encoder._bert import BertTokenizer | |||
| @@ -1,6 +1,6 @@ | |||
| from fastNLP.io.base_loader import DataSetLoader, DataBundle | |||
| from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||
| from fastNLP.io import ConllLoader | |||
| from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2 | |||
| from fastNLP import Const | |||
| @@ -1,7 +1,7 @@ | |||
| from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader | |||
| from fastNLP.core.vocabulary import VocabularyOption | |||
| from fastNLP.io.base_loader import DataSetLoader, DataBundle | |||
| from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||
| from typing import Union, Dict, List, Iterator | |||
| from fastNLP import DataSet | |||
| from fastNLP import Instance | |||
| @@ -1,6 +1,6 @@ | |||
| from fastNLP.core.vocabulary import VocabularyOption | |||
| from fastNLP.io.base_loader import DataSetLoader, DataBundle | |||
| from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||
| from typing import Union, Dict | |||
| from fastNLP import Vocabulary | |||
| from fastNLP import Const | |||
| @@ -1,5 +1,5 @@ | |||
| from fastNLP.core.vocabulary import VocabularyOption | |||
| from fastNLP.io.base_loader import DataSetLoader, DataBundle | |||
| from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||
| from typing import Union, Dict | |||
| from fastNLP import DataSet | |||
| from fastNLP import Vocabulary | |||
| @@ -1,6 +1,6 @@ | |||
| from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader | |||
| from fastNLP.core.vocabulary import VocabularyOption | |||
| from fastNLP.io.base_loader import DataSetLoader, DataBundle | |||
| from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||
| from typing import Union, Dict, List, Iterator | |||
| from fastNLP import DataSet | |||
| from fastNLP import Instance | |||
| @@ -1,6 +1,6 @@ | |||
| from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader | |||
| from fastNLP.core.vocabulary import VocabularyOption | |||
| from fastNLP.io.base_loader import DataSetLoader, DataBundle | |||
| from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||
| from typing import Union, Dict, List, Iterator | |||
| from fastNLP import DataSet | |||
| from fastNLP import Instance | |||
| @@ -1,6 +1,6 @@ | |||
| from typing import Iterable | |||
| from nltk import Tree | |||
| from fastNLP.io.base_loader import DataBundle, DataSetLoader | |||
| from fastNLP.io.data_bundle import DataBundle, DataSetLoader | |||
| from fastNLP.core.vocabulary import VocabularyOption, Vocabulary | |||
| from fastNLP import DataSet | |||
| from fastNLP import Instance | |||
| @@ -4,7 +4,7 @@ from typing import Iterable | |||
| from fastNLP import DataSet, Instance, Vocabulary | |||
| from fastNLP.core.vocabulary import VocabularyOption | |||
| from fastNLP.io import JsonLoader | |||
| from fastNLP.io.base_loader import DataBundle,DataSetLoader | |||
| from fastNLP.io.data_bundle import DataBundle,DataSetLoader | |||
| from fastNLP.io.embed_loader import EmbeddingOption | |||
| from fastNLP.io.file_reader import _read_json | |||
| from typing import Union, Dict | |||