| @@ -34,7 +34,8 @@ class MatchingLoader(DataSetLoader): | |||
| def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None, | |||
| to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, | |||
| cut_text: int = None, get_index=True, set_input: Union[list, str, bool]=True, | |||
| cut_text: int = None, get_index=True, auto_pad_length: int=None, | |||
| auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True, | |||
| set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo: | |||
| """ | |||
| :param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, | |||
| @@ -49,6 +50,8 @@ class MatchingLoader(DataSetLoader): | |||
| :param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径 | |||
| :param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。 | |||
| :param bool get_index: 是否需要根据词表将文本转为index | |||
| :param int auto_pad_length: 是否需要将文本自动pad到一定长度(超过这个长度的文本将会被截掉),默认为不会自动pad | |||
| :param str auto_pad_token: 自动pad的内容 | |||
| :param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False | |||
| 则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input, | |||
| 于此同时其他field不会被设置为input。默认值为True。 | |||
| @@ -169,6 +172,9 @@ class MatchingLoader(DataSetLoader): | |||
| data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), | |||
| new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) | |||
| if auto_pad_length is not None: | |||
| cut_text = min(auto_pad_length, cut_text if cut_text is not None else 0) | |||
| if cut_text is not None: | |||
| for data_name, data_set in data_info.datasets.items(): | |||
| for fields in data_set.get_field_names(): | |||
| @@ -180,7 +186,7 @@ class MatchingLoader(DataSetLoader): | |||
| assert len(data_set_list) > 0, f'There are NO data sets in data info!' | |||
| if bert_tokenizer is None: | |||
| words_vocab = Vocabulary() | |||
| words_vocab = Vocabulary(padding=auto_pad_token) | |||
| words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], | |||
| field_name=[n for n in data_set_list[0].get_field_names() | |||
| if (Const.INPUT in n)], | |||
| @@ -202,6 +208,17 @@ class MatchingLoader(DataSetLoader): | |||
| data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET, | |||
| is_input=auto_set_input, is_target=auto_set_target) | |||
| if auto_pad_length is not None: | |||
| for data_name, data_set in data_info.datasets.items(): | |||
| for fields in data_set.get_field_names(): | |||
| if Const.INPUT in fields: | |||
| data_set.apply(lambda x: x[fields] + [words_vocab.padding] * (auto_pad_length - len(x[fields])), | |||
| new_field_name=fields, is_input=auto_set_input) | |||
| elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'): | |||
| data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] * | |||
| (auto_pad_length - len(x[fields])), new_field_name=fields, | |||
| is_input=auto_set_input) | |||
| for data_name, data_set in data_info.datasets.items(): | |||
| if isinstance(set_input, list): | |||
| data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()]) | |||