|
|
|
@@ -70,12 +70,12 @@ class MsIterableDataset(torch.utils.data.IterableDataset): |
|
|
|
for idx in range(iter_start, iter_end): |
|
|
|
item_dict = self.dataset[idx] |
|
|
|
res = { |
|
|
|
k: np.array(item_dict[k]) |
|
|
|
k: torch.tensor(item_dict[k]) |
|
|
|
for k in self.columns if k in self.retained_columns |
|
|
|
} |
|
|
|
for preprocessor in self.preprocessor_list: |
|
|
|
res.update({ |
|
|
|
k: np.array(v) |
|
|
|
k: torch.tensor(v) |
|
|
|
for k, v in preprocessor(item_dict).items() |
|
|
|
if k in self.retained_columns |
|
|
|
}) |
|
|
|
|