From e365023862995b921f74d902a69667933fa58060 Mon Sep 17 00:00:00 2001 From: "feiwu.yfw" Date: Mon, 5 Sep 2022 19:36:46 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dprocessor=E8=BE=93=E5=87=BAto?= =?UTF-8?q?rch.tensor=E6=97=B6=E8=A2=AB=E8=BD=AC=E4=B8=BAnumpy=E7=9A=84?= =?UTF-8?q?=E5=BC=82=E5=B8=B8=20=20=20=20=20=20=20=20=20Link:=20https://co?= =?UTF-8?q?de.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10021802?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix to_torch_dataset --- modelscope/msdatasets/ms_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modelscope/msdatasets/ms_dataset.py b/modelscope/msdatasets/ms_dataset.py index 28a95643..691db4fe 100644 --- a/modelscope/msdatasets/ms_dataset.py +++ b/modelscope/msdatasets/ms_dataset.py @@ -70,12 +70,12 @@ class MsIterableDataset(torch.utils.data.IterableDataset): for idx in range(iter_start, iter_end): item_dict = self.dataset[idx] res = { - k: np.array(item_dict[k]) + k: torch.tensor(item_dict[k]) for k in self.columns if k in self.retained_columns } for preprocessor in self.preprocessor_list: res.update({ - k: np.array(v) + k: torch.tensor(v) for k, v in preprocessor(item_dict).items() if k in self.retained_columns })