| @@ -0,0 +1,86 @@ | |||||
| from collections import defaultdict | |||||
| import torch | |||||
class Batch(object):
    """Iterate over a dataset in mini-batches.

    Each iteration yields a ``(batch_x, batch_y)`` pair, where both are
    dicts mapping field names to tensors stacked along dim 0.  Sequence
    fields are padded (by the dataset) to the longest length occurring in
    the current batch.

    :param dataset: object exposing ``get_length() -> {field: [lengths]}``
                    and ``to_tensor(idx, padding_length) -> (x, y)`` dicts
                    of per-field tensors.
    :param sampler: callable mapping the dataset to a list of dataset
                    indices defining the iteration order for one epoch.
    :param batch_size: number of instances per batch.
    """

    def __init__(self, dataset, sampler, batch_size):
        self.dataset = dataset
        self.sampler = sampler
        self.batch_size = batch_size
        self.idx_list = None   # sampled order of dataset indices, set per epoch
        self.curidx = 0        # current position within idx_list

    def __iter__(self):
        # Re-sample the iteration order at the start of every epoch.
        self.idx_list = self.sampler(self.dataset)
        self.curidx = 0
        # lengths[field][i] is the length of `field` in dataset[i].
        self.lengths = self.dataset.get_length()
        return self

    def __next__(self):
        if self.curidx >= len(self.idx_list):
            raise StopIteration
        endidx = min(self.curidx + self.batch_size, len(self.idx_list))
        # Dataset indices making up this batch, in sampled order.
        # Bug fix: positions were previously used as dataset indices
        # directly, silently ignoring the sampler's permutation.
        batch_indices = self.idx_list[self.curidx:endidx]
        # Pad each field to the longest occurrence within this batch only.
        padding_length = {field_name: max(field_length[i] for i in batch_indices)
                          for field_name, field_length in self.lengths.items()}
        batch_x, batch_y = defaultdict(list), defaultdict(list)
        for idx in batch_indices:
            x, y = self.dataset.to_tensor(idx, padding_length)
            for name, tensor in x.items():
                batch_x[name].append(tensor)
            for name, tensor in y.items():
                batch_y[name].append(tensor)
        # Stack per-instance tensors into (batch, ...) tensors.
        # (A leftover debug print of every tensor list was removed here.)
        for batch in (batch_x, batch_y):
            for name, tensor_list in batch.items():
                batch[name] = torch.stack(tensor_list, dim=0)
        # Bug fix: was `self.curidx += endidx`, which jumps past — and thus
        # silently drops — instances once more than two batches are drawn.
        self.curidx = endidx
        return batch_x, batch_y
if __name__ == "__main__":
    """simple running example
    """
    from field import TextField, LabelField
    from instance import Instance
    from dataset import DataSet

    texts = ["i am a cat",
             "this is a test of new batch",
             "haha"
             ]
    labels = [0, 1, 0]

    # prepare vocabulary: first occurrence order assigns ids 0, 1, 2, ...
    vocab = {}
    for sentence in texts:
        for token in sentence.split():
            vocab.setdefault(token, len(vocab))

    # prepare input dataset
    data = DataSet()
    for sentence, target in zip(texts, labels):
        data.append(Instance(text=TextField(sentence.split(), False),
                             label=LabelField(target, is_target=True)))

    # use vocabulary to index data
    data.index_field("text", vocab)

    # define a naive sequential sampler for the Batch class
    class SeqSampler:
        def __call__(self, dataset):
            return list(range(len(dataset)))

    # use Batch to iterate the dataset
    batcher = Batch(data, SeqSampler(), 2)
    for epoch in range(3):
        for batch_x, batch_y in batcher:
            print(batch_x)
            print(batch_y)
            # do stuff
| @@ -0,0 +1,29 @@ | |||||
| from collections import defaultdict | |||||
class DataSet(list):
    """A list of Instance objects with batching-oriented helpers.

    :param name: optional human-readable name for the dataset.
    :param instances: optional iterable of instances used to seed the list.
    """

    def __init__(self, name="", instances=None):
        # Bug fix: was `list.__init__([])`, which initialised a throwaway
        # temporary list instead of `self`.
        super().__init__()
        self.name = name
        if instances is not None:
            self.extend(instances)

    def index_all(self, vocab):
        """Index every field of every instance with `vocab`."""
        for ins in self:
            ins.index_all(vocab)

    def index_field(self, field_name, vocab):
        """Index one named field of every instance with `vocab`."""
        for ins in self:
            ins.index_field(field_name, vocab)

    def to_tensor(self, idx: int, padding_length: dict):
        """Convert instance `idx` to (x, y) tensor dicts, padded per field."""
        ins = self[idx]
        return ins.to_tensor(padding_length)

    def get_length(self):
        """Return {field_name: [length of that field in each instance]}."""
        lengths = defaultdict(list)
        for ins in self:
            for field_name, field_length in ins.get_length().items():
                lengths[field_name].append(field_length)
        return lengths
| @@ -0,0 +1,70 @@ | |||||
| import torch | |||||
class Field(object):
    """Abstract base class for a single attribute of an Instance.

    :param is_target: True if this field is a training target (y) rather
                      than a model input (x).
    """

    def __init__(self, is_target: bool):
        self.is_target = is_target

    def index(self, vocab):
        """Convert raw content to indices using `vocab`."""
        # Bug fix: was a silent `pass` (returning None), which hides a
        # subclass that forgot to override this method.
        raise NotImplementedError

    def get_length(self):
        """Return the number of positions this field occupies."""
        raise NotImplementedError

    def to_tensor(self, padding_length):
        """Return a tensor representation padded to `padding_length`."""
        raise NotImplementedError
class TextField(Field):
    """A field holding a sequence of tokens.

    :param list text: list of token strings.
    :param is_target: True if this field is a training target.
    """

    def __init__(self, text: list, is_target):
        super(TextField, self).__init__(is_target)
        self.text = text
        self._index = None  # list[int] of vocab ids after index()

    def index(self, vocab):
        """Map each token through `vocab`; idempotent (cached on repeat).

        Bug fix: re-indexing previously printed a bare 'error' to stdout;
        now the cached indices are returned silently.
        """
        if self._index is None:
            self._index = [vocab[token] for token in self.text]
        return self._index

    def get_length(self):
        """Return the number of tokens."""
        return len(self.text)

    def to_tensor(self, padding_length: int):
        """Return a LongTensor of shape (padding_length,), right-padded with 0.

        :raises RuntimeError: if index() has not been called yet.
            (Bug fix: previously printed 'error' and then crashed with a
            TypeError on `None + pads`.)
        """
        if self._index is None:
            raise RuntimeError("TextField must be indexed before to_tensor()")
        pad_count = max(padding_length - self.get_length(), 0)
        # (length, )
        return torch.LongTensor(self._index + [0] * pad_count)
class LabelField(Field):
    """A field holding a single classification label.

    :param label: the raw label value.
    :param is_target: True (default) if this field is a training target.
    """

    def __init__(self, label, is_target=True):
        super(LabelField, self).__init__(is_target)
        self.label = label
        self._index = None  # vocab id after index(), else None

    def get_length(self):
        """A label always occupies exactly one position."""
        return 1

    def index(self, vocab):
        """Look the raw label up in `vocab`; idempotent (cached)."""
        if self._index is None:
            self._index = vocab[self.label]
        return self._index

    def to_tensor(self, padding_length):
        """Return a 1-element LongTensor; falls back to the raw label
        when the field has not been indexed."""
        value = self.label if self._index is None else self._index
        return torch.LongTensor([value])
| if __name__ == "__main__": | |||||
| tf = TextField("test the code".split()) | |||||
| @@ -0,0 +1,38 @@ | |||||
class Instance(object):
    """A single training example: a named collection of Field objects."""

    def __init__(self, **fields):
        self.fields = fields
        self.has_index = False  # True once index_all() has run
        self.indexes = {}       # field_name -> indexed representation

    def add_field(self, field_name, field):
        """Attach (or replace) a field under `field_name`."""
        self.fields[field_name] = field

    def get_length(self):
        """Return {field_name: length} for every field."""
        return {name: field.get_length() for name, field in self.fields.items()}

    def index_field(self, field_name, vocab):
        """use `vocab` to index certain field
        """
        self.indexes[field_name] = self.fields[field_name].index(vocab)

    def index_all(self, vocab):
        """use `vocab` to index all fields

        Idempotent: repeated calls return the cached result.
        (Bug fix: `has_index` was never set to True, so the caching guard
        was dead code and printed a placeholder 'error' message.)
        """
        if self.has_index:
            return self.indexes
        self.indexes = {name: field.index(vocab)
                        for name, field in self.fields.items()}
        self.has_index = True
        return self.indexes

    def to_tensor(self, padding_length: dict):
        """Split fields into (inputs, targets) tensor dicts, padded per field.

        :param padding_length: {field_name: padded length for that field}.
        :return: (tensorX, tensorY) — non-target and target fields.
        """
        tensorX = {}
        tensorY = {}
        for name, field in self.fields.items():
            bucket = tensorY if field.is_target else tensorX
            bucket[name] = field.to_tensor(padding_length[name])
        return tensorX, tensorY