From ddf7b7a3e1ead6fe15b00b2e8040ca7b19df86e6 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Sat, 11 Nov 2023 00:37:41 +0800 Subject: [PATCH] [ENH] create ListData --- abl/structures/list_data.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/abl/structures/list_data.py b/abl/structures/list_data.py index 0feed6c..97660b4 100644 --- a/abl/structures/list_data.py +++ b/abl/structures/list_data.py @@ -6,6 +6,8 @@ from typing import Any, List, Union import numpy as np import torch +from ..utils import flatten as flatten_list +from ..utils import to_hashable from .base_data_element import BaseDataElement BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor] @@ -293,6 +295,22 @@ class ListData(BaseDataElement): new_data[k] = new_values return new_data # type:ignore + def flatten(self, item: IndexType) -> List: + """Flatten self[item]. + + Returns: + list: Flattened data fields. + """ + return flatten_list(self[item]) + + def elements_num(self, item: IndexType) -> int: + """int: The number of elements in self[item].""" + return len(self.flatten(item)) + + def to_tuple(self, item: IndexType) -> tuple: + """tuple: The data fields in self[item] converted to tuple.""" + return to_hashable(self[item]) + def __len__(self) -> int: """int: The length of ListData.""" if len(self._data_fields) > 0: