| @@ -864,9 +864,13 @@ class DataSet: | |||||
| results = [ins for ins in self if not func(ins)] | results = [ins for ins in self if not func(ins)] | ||||
| if len(results) != 0: | if len(results) != 0: | ||||
| dataset = DataSet(results) | dataset = DataSet(results) | ||||
| return dataset | |||||
| else: | else: | ||||
| return DataSet() | |||||
| dataset = DataSet() | |||||
| for name in self.field_arrays.keys(): | |||||
| empty_field = FieldArray(name, [None]) | |||||
| empty_field.content = [] | |||||
| dataset.field_arrays[name] = empty_field | |||||
| return dataset | |||||
| def split(self, ratio: float, shuffle=True): | def split(self, ratio: float, shuffle=True): | ||||
| r""" | r""" | ||||
| @@ -47,6 +47,10 @@ class FieldArray: | |||||
| """ | """ | ||||
| self.content.pop(index) | self.content.pop(index) | ||||
| def __iter__(self): | |||||
| for idx in range(len(self)): | |||||
| yield self[idx] | |||||
| def __getitem__(self, indices: Union[int, List[int]]): | def __getitem__(self, indices: Union[int, List[int]]): | ||||
| return self.get(indices) | return self.get(indices) | ||||
| @@ -354,6 +354,40 @@ class DataBundle: | |||||
| progress_bar=progress_bar, progress_desc=progress_desc) | progress_bar=progress_bar, progress_desc=progress_desc) | ||||
| return res | return res | ||||
| def add_seq_len(self, field_name: str, new_field_name='seq_len', ignore_miss_dataset: bool = True): | |||||
| r""" | |||||
| 将使用 :func:`len` 直接对每个 dataset 的 ``field_name`` 中每个元素作用,将其结果作为 sequence length, 并放入 | |||||
| ``new_field_name`` 这个 field。 | |||||
| :param field_name: 需要处理的 field_name | |||||
| :param new_field_name: 新的 field_name | |||||
| :param ignore_miss_dataset: 如果为 ``True`` ,则当 ``field_name`` 在某个 dataset 内不存在时,直接忽略该 dataset, | |||||
| 如果为 ``False`` 则会报错。 | |||||
| :return: | |||||
| """ | |||||
| return self.apply_field(len, field_name, new_field_name=new_field_name, ignore_miss_dataset=ignore_miss_dataset) | |||||
| def drop(self, func: Callable, inplace=True): | |||||
| r""" | |||||
| 删除某些 Instance。 需要注意的是 ``func`` 接受一个 Instance ,返回 bool 值。返回值为 ``True`` 时, | |||||
| 该 Instance 会被移除或者不会包含在返回的 DataBundle 中。 | |||||
| :param func: 接受一个 Instance 作为参数,返回 bool 值。为 ``True`` 时删除该 instance | |||||
| :param inplace: 是否在当前 DataBundle 中直接删除 instance;如果为 False,将返回一个新的 DataBundle。 | |||||
| :return: DataSet | |||||
| """ | |||||
| if inplace: | |||||
| for name, dataset in self.datasets.items(): | |||||
| dataset.drop(func, inplace) | |||||
| return self | |||||
| else: | |||||
| data_bundle = DataBundle(vocabs=self.vocabs) | |||||
| for name, dataset in self.datasets.items(): | |||||
| res = dataset.drop(func, inplace) | |||||
| data_bundle.set_dataset(res, name) | |||||
| return data_bundle | |||||
| def set_pad(self, field_name, pad_val=0, dtype=None, backend=None, pad_fn=None) -> "DataBundle": | def set_pad(self, field_name, pad_val=0, dtype=None, backend=None, pad_fn=None) -> "DataBundle": | ||||
| """ | """ | ||||
| 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | ||||
| @@ -0,0 +1,64 @@ | |||||
| import pytest | |||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.io.data_bundle import DataBundle | |||||
| def test_add_seq_len(): | |||||
| dataset1 = DataSet({ | |||||
| "x": [[0,1,2], [5,3,2,3], [5,21,5,10], [3,6,8,1]] | |||||
| }) | |||||
| dataset2 = DataSet({ | |||||
| "x": [[0,1,2,3,4], [5,3,2,3], [5,20,45,1,98], [3,6,8,3,6,31]] | |||||
| }) | |||||
| dataset3 = DataSet({ | |||||
| "x": [[0,1,2,7,5,2], [5,3], [0], [3,6,8]] | |||||
| }) | |||||
| data_bundle = DataBundle(datasets={ | |||||
| "dataset1": dataset1, | |||||
| "dataset2": dataset2, | |||||
| "dataset3": dataset3 | |||||
| }) | |||||
| data_bundle.add_seq_len("x") | |||||
| print(data_bundle.get_dataset("dataset1")) | |||||
| for i, data in enumerate(data_bundle.get_dataset("dataset1")): | |||||
| print(data["seq_len"], dataset1["x"][i]) | |||||
| assert data["seq_len"] == len(dataset1["x"][i]) | |||||
| for i, data in enumerate(data_bundle.get_dataset("dataset2")): | |||||
| assert data["seq_len"] == len(dataset2["x"][i]) | |||||
| for i, data in enumerate(data_bundle.get_dataset("dataset3")): | |||||
| assert data["seq_len"] == len(dataset3["x"][i]) | |||||
| @pytest.mark.parametrize("inplace", [True, False]) | |||||
| def test_drop(inplace): | |||||
| dataset1 = DataSet({ | |||||
| "x": [0, 1, 1, 4, 2, 1, 0, 1, 1, 6, 7, 1] | |||||
| }) | |||||
| dataset2 = DataSet({ | |||||
| "x": [0, 0, 0, 0, 0] | |||||
| }) | |||||
| dataset3 = DataSet({ | |||||
| "x": [1, 1, 1, 1, 1, 2, 3, 4] | |||||
| }) | |||||
| data_bundle = DataBundle(datasets={ | |||||
| "dataset1": dataset1, | |||||
| "dataset2": dataset2, | |||||
| "dataset3": dataset3 | |||||
| }) | |||||
| res = data_bundle.drop(lambda x: x["x"] == 0, inplace) | |||||
| if inplace: | |||||
| assert res is data_bundle | |||||
| else: | |||||
| assert not (res is data_bundle) | |||||
| assert data_bundle.get_dataset("dataset1")["x"] == dataset1["x"] | |||||
| assert data_bundle.get_dataset("dataset2")["x"] == dataset2["x"] | |||||
| assert data_bundle.get_dataset("dataset3")["x"] == dataset3["x"] | |||||
| dataset1_drop = [1, 1, 4, 2, 1, 1, 1, 6, 7, 1] | |||||
| for i, data in enumerate(res.get_dataset("dataset1")["x"]): | |||||
| assert data == dataset1_drop[i] | |||||
| dataset2_drop = [] | |||||
| for i, data in enumerate(res.get_dataset("dataset2")["x"]): | |||||
| assert data == dataset2_drop[i] | |||||
| dataset3_drop = [1, 1, 1, 1, 1, 2, 3, 4] | |||||
| for i, data in enumerate(res.get_dataset("dataset3")["x"]): | |||||
| assert data == dataset3_drop[i] | |||||