| @@ -126,7 +126,8 @@ class Callback: | |||
| :param trainer: `fastNLP.Trainer` | |||
| :param batch: batch 的数据,已经经过 input_mapping (如果有) 以及 移动到指定设备 。 | |||
| :param list[int] indices: 当前的 batch 是 dataset 中的哪些数据 | |||
| :param list[int] indices: 当前的 batch 是 dataset 中的哪些数据。仅在 DataLoader 支持得到当前 batch index 的时候有值, | |||
| 其它时候为 None 。 | |||
| """ | |||
| pass | |||
| @@ -94,20 +94,21 @@ class LoadBestModelCallback(HasMonitorCallback): | |||
| else: | |||
| self.buffer.seek(0) | |||
| trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) | |||
| self._delete_after_after(trainer) | |||
| def _delete_after_after(self, trainer): | |||
| trainer.driver.barrier() | |||
| if self.delete_after_after: | |||
| if self.real_save_folder: | |||
| logger.info(f"Deleting {self.real_save_folder}...") | |||
| shutil.rmtree(self.real_save_folder, ignore_errors=True) | |||
| try: | |||
| # 如果是 empty 的,就会被删除掉 | |||
| os.rmdir(self.save_folder) | |||
| except: | |||
| pass | |||
| elif hasattr(self, 'buffer'): | |||
| self.buffer.close() | |||
| del self.buffer | |||
| trainer.driver.barrier() | |||
| self._delete_folder() | |||
| trainer.driver.barrier() | |||
| def _delete_folder(self): | |||
| if self.real_save_folder: | |||
| logger.info(f"Deleting {self.real_save_folder}...") | |||
| shutil.rmtree(self.real_save_folder, ignore_errors=True) | |||
| try: | |||
| # 如果是 empty 的,就会被删除掉 | |||
| os.rmdir(self.save_folder) | |||
| logger.debug(f"Since {self.save_folder} is an empty folder, it has been removed.") | |||
| except: | |||
| pass | |||
| elif hasattr(self, 'buffer'): | |||
| self.buffer.close() | |||
| del self.buffer | |||
| @@ -6,7 +6,7 @@ from .padders.get_padder import get_padder | |||
| import re | |||
| from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch_nested_mapping, unpack_batch_sequence, \ | |||
| pack_batch_sequence, NESTED_DICT_SEPARATOR | |||
| pack_batch_sequence | |||
| sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 | |||
| SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', None] | |||
| @@ -16,10 +16,11 @@ class Collator: | |||
| def __init__(self, backend='torch'): | |||
| """ | |||
| 用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。 | |||
| 可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。 | |||
| 可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。Collator 在第一次进行 pad 的 | |||
| 时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。 | |||
| :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None], | |||
| 若为 None ,则不进行 padding 。 | |||
| :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None]。 | |||
| 若为 None ,则不进行 padding 。该参数对本身就不能进行 pad 的数据没有影响,不能 pad 的数据返回一定是 list 。 | |||
| """ | |||
| self.unpack_batch_func = None | |||
| self.pack_batch_func = None | |||
| @@ -54,22 +55,25 @@ class Collator: | |||
| else: | |||
| self.batch_data_type = 's' | |||
| logger.debug(f"Since batch[0] has type:{type(batch[0])}, so the batch_data_type " | |||
| f"is {self.batch_data_type}") | |||
| f"is `{self.batch_data_type}`.") | |||
| if self.batch_data_type == 's': | |||
| self.unpack_batch_func = lambda x:{'_single': x} # 不需要做任何调整 | |||
| self.pack_batch_func = lambda x:x['_single'] | |||
| self.unpack_batch_func = lambda batch, ignore_fields: {'_single': batch} # 不需要做任何调整 | |||
| self.pack_batch_func = lambda x: x['_single'] | |||
| elif self.batch_data_type == 'l': | |||
| self.unpack_batch_func = unpack_batch_sequence | |||
| self.pack_batch_func = pack_batch_sequence | |||
| elif self.batch_data_type == 'd': | |||
| if any([isinstance(v, Mapping) for v in batch[0].values()]): # 可能存在 nested 的dict。{'a': {'b': xx}}->{'a@@b': value} | |||
| if any([isinstance(v, Mapping) for v in batch[0].values()]): # 可能存在 nested 的dict。{'a': {'b': xx}}->{('a', 'b'): value} | |||
| self.unpack_batch_func = unpack_batch_nested_mapping | |||
| self.pack_batch_func = pack_batch_nested_mapping | |||
| else: | |||
| self.unpack_batch_func = unpack_batch_mapping | |||
| self.pack_batch_func = lambda x:x | |||
| unpack_batch:Dict = self.unpack_batch_func(batch) # 将各自 field 组成 batch 形式。 | |||
| if self.unpack_batch_func is unpack_batch_nested_mapping: # 比较特殊,需要防止继续往下延伸 | |||
| unpack_batch: Dict = self.unpack_batch_func(batch, self.ignore_fields, set(self.input_fields.keys())) | |||
| else: | |||
| unpack_batch:Dict = self.unpack_batch_func(batch, self.ignore_fields) # 将各自 field 组成 batch 形式。 | |||
| pad_batch = {} | |||
| if len(self.padders)==0: # 第一次运行,准备 padder | |||
| @@ -96,13 +100,13 @@ class Collator: | |||
| return self.pack_batch_func(pad_batch) # 根据情况恢复成与输入一致的类型 | |||
| def set_pad(self, field_name:str, pad_val:Union[int, float, None]=0, dtype=None, backend=None, | |||
| def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, | |||
| pad_fn:Callable=None) -> "Collator": | |||
| """ | |||
| 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | |||
| :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
| field 的 key 来表示,如果是 nested 的 dict,可以使用 @@ 来连接不同层次的 key,例如 {'a': {'b': 1}} 中的使用 a@@b; | |||
| field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); | |||
| 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | |||
| 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | |||
| :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||
| @@ -126,11 +130,11 @@ class Collator: | |||
| f"index, but other field is set as dict mode." | |||
| elif self.batch_data_type == 'l': | |||
| assert sequence_idx_str.match(field_name) is not None, f"Other field is set as list mode. But the new " \ | |||
| f"field name is {field_name}" | |||
| f"field name is {field_name}." | |||
| if field_name == '_single': | |||
| self.batch_data_type = 's' | |||
| elif sequence_idx_str.match(field_name): | |||
| elif isinstance(field_name, str) and sequence_idx_str.match(field_name): | |||
| self.batch_data_type = 'l' | |||
| else: | |||
| self.batch_data_type = 'd' | |||
| @@ -165,8 +169,8 @@ class Collator: | |||
| collator.set_ignore('field1', 'field2') | |||
| :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
| field 的 key 来表示,如果是 nested 的 dict,可以使用 @@ 来连接不同层次的 key,例如 {'a': {'b': 1}} 中的使用 a@@b; | |||
| 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | |||
| field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 | |||
| __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | |||
| :return: 返回 Collator 自身 | |||
| """ | |||
| for field_name in field_names: | |||
| @@ -149,6 +149,7 @@ def is_number(dtype): | |||
| if dtype in (float, int, complex, bool) and not is_numpy_generic_class(dtype) \ | |||
| and not is_numpy_number_dtype(dtype): | |||
| return True | |||
| return False | |||
| except: | |||
| return False | |||
| @@ -161,6 +162,7 @@ if __name__ == '__main__': | |||
| # print(type(b[0])) | |||
| # print(b) | |||
| # import torch | |||
| print(is_number(type('a'))) | |||
| print(is_number_or_numpy_number(type(3))) # True | |||
| print(is_number_or_numpy_number(type(3.1))) # True | |||
| print(is_number_or_numpy_number(type('3'))) # False | |||
| @@ -2,54 +2,58 @@ from collections import defaultdict | |||
| from functools import reduce | |||
| from typing import Sequence, Mapping, Dict | |||
| NESTED_DICT_SEPARATOR = '@@' | |||
| def unpack_batch_mapping(batch:Sequence[Mapping])->Dict: | |||
| def unpack_batch_mapping(batch:Sequence[Mapping], ignore_fields:set)->Dict: | |||
| """ | |||
| 将 Sequence[Mapping] 转为 Dict 。例如 [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] -> {'a': [[1, 2], [3]], 'b': [1, 2]} | |||
| :param batch: | |||
| :param ignore_fields: | |||
| :return: | |||
| """ | |||
| dict_batch = defaultdict(list) | |||
| for sample in batch: | |||
| for key, value in sample.items(): | |||
| if key in ignore_fields: | |||
| continue | |||
| dict_batch[key].append(value) | |||
| return dict_batch | |||
| def unpack_batch_nested_mapping(batch:Sequence[Mapping], _parent='')->Dict: | |||
| def unpack_batch_nested_mapping(batch:Sequence[Mapping], ignore_fields:set, stop_deep_fields:set)->Dict: | |||
| """ | |||
| 将 nested 的 dict 中的内容展开到一个 flat dict 中 | |||
| :param batch: | |||
| :param _parent: 内部使用 | |||
| :param ignore_fields: 需要忽略的 field 。 | |||
| :param stop_deep_fields: 不需要继续往下延伸的 field | |||
| :return: | |||
| """ | |||
| dict_batch = defaultdict(list) | |||
| if _parent != '': | |||
| _parent += NESTED_DICT_SEPARATOR | |||
| for sample in batch: | |||
| for key, value in sample.items(): | |||
| if isinstance(value, Mapping): | |||
| _dict_batch = _unpack_batch_nested_mapping(value, _parent=_parent + key) | |||
| if key in ignore_fields: | |||
| continue | |||
| if isinstance(value, Mapping) and key not in stop_deep_fields: | |||
| _dict_batch = _unpack_batch_nested_mapping(value, ignore_fields, stop_deep_fields, _parent=(key,)) | |||
| for key, value in _dict_batch.items(): | |||
| dict_batch[key].append(value) | |||
| else: | |||
| dict_batch[_parent + key].append(value) | |||
| dict_batch[key].append(value) | |||
| return dict_batch | |||
| def _unpack_batch_nested_mapping(value, _parent)->Dict: | |||
| def _unpack_batch_nested_mapping(value, ignore_fields, stop_deep_fields, _parent)->Dict: | |||
| _dict = {} | |||
| _parent += NESTED_DICT_SEPARATOR | |||
| for k, v in value.items(): | |||
| if isinstance(v, Mapping): | |||
| __dict = _unpack_batch_nested_mapping(v, _parent=_parent + k) | |||
| _k = _parent + (k,) | |||
| if _k in ignore_fields: | |||
| continue | |||
| if isinstance(v, Mapping) and _k not in stop_deep_fields: | |||
| __dict = _unpack_batch_nested_mapping(v, ignore_fields, stop_deep_fields, _parent=_k) | |||
| _dict.update(__dict) | |||
| else: | |||
| _dict[_parent + k] = v | |||
| _dict[_k] = v | |||
| return _dict | |||
| @@ -63,10 +67,11 @@ def pack_batch_nested_mapping(batch:Mapping) -> Dict: | |||
| dicts = [] | |||
| for key, value in batch.items(): | |||
| keys = key.split(NESTED_DICT_SEPARATOR) | |||
| d = {keys[-1]: value} | |||
| for key in keys[:-1:][::-1]: | |||
| d = {key: d} | |||
| if not isinstance(key, tuple): | |||
| key = [key] | |||
| d = {key[-1]: value} | |||
| for k in key[:-1:][::-1]: | |||
| d = {k: d} | |||
| dicts.append(d) | |||
| return reduce(_merge_dict, dicts) | |||
| @@ -85,17 +90,21 @@ def _merge_dict(a, b, path=None): | |||
| return a | |||
| def unpack_batch_sequence(batch:Sequence[Sequence])->Dict: | |||
| def unpack_batch_sequence(batch:Sequence[Sequence], ignore_fields)->Dict: | |||
| """ | |||
| 将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [1, 2]} | |||
| :param batch: | |||
| :param ignore_fields: 需要忽略的 field | |||
| :return: | |||
| """ | |||
| dict_batch = defaultdict(list) | |||
| for sample in batch: | |||
| for i, content in enumerate(sample): | |||
| dict_batch[f'_{i}'].append(content) | |||
| field_name = f'_{i}' | |||
| if field_name in ignore_fields: | |||
| continue | |||
| dict_batch[field_name].append(content) | |||
| return dict_batch | |||
| @@ -1,7 +0,0 @@ | |||
| __all__ = [ | |||
| 'FDataLoader' | |||
| ] | |||
| class FDataLoader: | |||
| pass | |||
| @@ -17,7 +17,7 @@ if _NEED_IMPORT_TORCH: | |||
| from torch.utils.data import DataLoader, Sampler | |||
| from torch.utils.data._utils.collate import default_collate | |||
| else: | |||
| from ..fdataloader import FDataLoader as DataLoader | |||
| from fastNLP.core.utils.dummy_class import DummyClass as DataLoader | |||
| class _FDataSet: | |||
| @@ -10,7 +10,7 @@ class TestNumpyNumberPadder: | |||
| def test_run(self): | |||
| padder = NumpyNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||
| a = [1, 2, 3] | |||
| assert isinstance(a, np.ndarray) | |||
| assert isinstance(padder(a), np.ndarray) | |||
| assert (padder(a) == np.array(a)).sum() == 3 | |||
| @@ -158,7 +158,7 @@ class TestCollator: | |||
| # 测试 ignore | |||
| collator = Collator(backend='raw') | |||
| collator.set_ignore('str', 'int', 'lst_int', 'nested_dict@@a') | |||
| collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'a')) | |||
| raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} | |||
| findDictDiff(raw_pad_batch, collator(dict_batch)) | |||
| @@ -171,7 +171,7 @@ class TestCollator: | |||
| # 测试设置 pad 值 | |||
| collator = Collator(backend='raw') | |||
| collator.set_pad('nest_lst_int', pad_val=100) | |||
| collator.set_ignore('str', 'int', 'lst_int', 'nested_dict@@a') | |||
| collator.set_ignore('str', 'int', 'lst_int', ('nested_dict','a')) | |||
| raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], | |||
| 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} | |||
| findDictDiff(raw_pad_batch, collator(dict_batch)) | |||
| @@ -217,6 +217,72 @@ class TestCollator: | |||
| collator.set_pad('_single') | |||
| findListDiff(list_batch, collator(list_batch)) | |||
| def test_nest_ignore(self): | |||
| dict_batch = [{ | |||
| 'str': '1', | |||
| 'lst_str': ['1'], | |||
| 'int': 1, | |||
| 'lst_int': [1], | |||
| 'nest_lst_int': [[1]], | |||
| 'float': 1.1, | |||
| 'lst_float': [1.1], | |||
| 'bool': True, | |||
| 'numpy': np.ones(1), | |||
| 'dict': {'1': '1'}, | |||
| 'set': {'1'}, | |||
| 'nested_dict': {'int': 1, 'lst_int':[1, 2], 'c': {'int': 1}} | |||
| }, | |||
| { | |||
| 'str': '2', | |||
| 'lst_str': ['2', '2'], | |||
| 'int': 2, | |||
| 'lst_int': [1, 2], | |||
| 'nest_lst_int': [[1], [1, 2]], | |||
| 'float': 2.1, | |||
| 'lst_float': [2.1], | |||
| 'bool': False, | |||
| 'numpy': np.zeros(1), | |||
| 'dict': {'1': '2'}, | |||
| 'set': {'2'}, | |||
| 'nested_dict': {'int': 1, 'lst_int': [1, 2], 'c': {'int': 1}} | |||
| } | |||
| ] | |||
| # 测试 ignore | |||
| collator = Collator(backend='raw') | |||
| collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'int')) | |||
| raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | |||
| 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], | |||
| 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, | |||
| 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], | |||
| 'c': {'int':[1, 1]}}} | |||
| findDictDiff(raw_pad_batch, collator(dict_batch)) | |||
| collator = Collator(backend='raw') | |||
| collator.set_pad(('nested_dict', 'c'), pad_val=None) | |||
| collator.set_ignore('str', 'int', 'lst_int') | |||
| raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | |||
| 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], | |||
| 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, | |||
| 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], | |||
| 'c': [{'int':1}, {'int':1}]}} | |||
| pad_batch = collator(dict_batch) | |||
| findDictDiff(raw_pad_batch, pad_batch) | |||
| collator = Collator(backend='raw') | |||
| collator.set_pad(('nested_dict', 'c'), pad_val=1) | |||
| with pytest.raises(BaseException): | |||
| collator(dict_batch) | |||
| collator = Collator(backend='raw') | |||
| collator.set_ignore('str', 'int', 'lst_int') | |||
| collator.set_pad(('nested_dict', 'c'), pad_fn=lambda x: [d['int'] for d in x]) | |||
| pad_batch = collator(dict_batch) | |||
| raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | |||
| 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], | |||
| 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, | |||
| 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], | |||
| 'c': [1, 1]}} | |||
| findDictDiff(raw_pad_batch, pad_batch) | |||
| @@ -4,25 +4,25 @@ from fastNLP.core.collators.utils import * | |||
| def test_unpack_batch_mapping(): | |||
| batch = [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] | |||
| assert unpack_batch_mapping(batch)=={'a': [[1, 2], [3]], 'b': [1, 2]} | |||
| assert unpack_batch_mapping(batch, {})=={'a': [[1, 2], [3]], 'b': [1, 2]} | |||
| def test_unpack_batch_nested_mapping(): | |||
| batch = [{'a': [1, 2], 'b': 1, 'c': {'c': 1}}, {'a': [3], 'b': 2, 'c': {'c': 2}}] | |||
| assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c': [1, 2]} | |||
| assert unpack_batch_nested_mapping(batch, {}, {}) == {'a': [[1, 2], [3]], 'b': [1, 2], ('c','c'): [1, 2]} | |||
| batch = [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1}}}, {'a': [3], 'b': 2, 'c': {'c': {'c': 2}}}] | |||
| assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2]} | |||
| assert unpack_batch_nested_mapping(batch, {}, {}) == {'a': [[1, 2], [3]], 'b': [1, 2], ('c', 'c', 'c'): [1, 2]} | |||
| batch = [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1, 'd':[1, 1]}, 'd': [1]}}, | |||
| {'a': [3], 'b': 2, 'c': {'c': {'c': 2, 'd': [2, 2]}, 'd': [2, 2]}}] | |||
| assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2], | |||
| 'c@@c@@d':[[1, 1], [2, 2]], 'c@@d': [[1], [2, 2]]} | |||
| assert unpack_batch_nested_mapping(batch, {}, {}) == {'a': [[1, 2], [3]], 'b': [1, 2], ('c', 'c', 'c'): [1, 2], | |||
| ('c','c', 'd'):[[1, 1], [2, 2]], ('c', 'd'): [[1], [2, 2]]} | |||
| def test_pack_batch_nested_mapping(): | |||
| batch = {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2], | |||
| 'c@@c@@d':[[1, 1], [2, 2]], 'c@@d': [[1], [2, 2]]} | |||
| batch = {'a': [[1, 2], [3]], 'b': [1, 2], ('c', 'c', 'c'): [1, 2], | |||
| ('c', 'c', 'd'):[[1, 1], [2, 2]], ('c', 'd'): [[1], [2, 2]]} | |||
| new_batch = pack_batch_nested_mapping(batch) | |||
| assert new_batch == {'a': [[1, 2], [3]], 'b': [1, 2], | |||
| 'c': {'c':{'c': [1, 2], 'd': [[1, 1], [2, 2]]}, 'd':[[1], [2, 2]]}} | |||
| @@ -30,7 +30,7 @@ def test_pack_batch_nested_mapping(): | |||
| def test_unpack_batch_sequence(): | |||
| batch = [[1, 2, 3], [2, 4, 6]] | |||
| new_batch = unpack_batch_sequence(batch) | |||
| new_batch = unpack_batch_sequence(batch, {}) | |||
| assert new_batch == {'_0': [1, 2], '_1': [2, 4], '_2': [3, 6]} | |||