Browse Source

修复processor输出torch.tensor时被转为numpy的异常

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10021802

    * fix to_torch_dataset
master
feiwu.yfw 3 years ago
parent
commit
e365023862
1 changed files with 2 additions and 2 deletions
  1. +2
    -2
      modelscope/msdatasets/ms_dataset.py

+ 2
- 2
modelscope/msdatasets/ms_dataset.py View File

@@ -70,12 +70,12 @@ class MsIterableDataset(torch.utils.data.IterableDataset):
for idx in range(iter_start, iter_end):
item_dict = self.dataset[idx]
res = {
k: np.array(item_dict[k])
k: torch.tensor(item_dict[k])
for k in self.columns if k in self.retained_columns
}
for preprocessor in self.preprocessor_list:
res.update({
k: np.array(v)
k: torch.tensor(v)
for k, v in preprocessor(item_dict).items()
if k in self.retained_columns
})


Loading…
Cancel
Save