diff --git a/modelscope/msdatasets/ms_dataset.py b/modelscope/msdatasets/ms_dataset.py index 28a95643..691db4fe 100644 --- a/modelscope/msdatasets/ms_dataset.py +++ b/modelscope/msdatasets/ms_dataset.py @@ -70,12 +70,12 @@ class MsIterableDataset(torch.utils.data.IterableDataset): for idx in range(iter_start, iter_end): item_dict = self.dataset[idx] res = { - k: np.array(item_dict[k]) + k: torch.tensor(item_dict[k]) for k in self.columns if k in self.retained_columns } for preprocessor in self.preprocessor_list: res.update({ - k: np.array(v) + k: torch.tensor(v) for k, v in preprocessor(item_dict).items() if k in self.retained_columns })