Browse Source

[to #42322933] fix bug: run failed in tensor_utils.py

1. 修复default data collator的输入类型为tuple时运行会失败的问题
2. 修复default data collator的输入类型为dict时不兼容BatchEncoding的问题
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9403517

    * Fix bugs: (1) the default data collator failed when the input features were tuples; (2) change the type check from dict to Mapping so that transformers.BatchEncoding (a UserDict subclass, not a dict) is handled correctly
master
yuze.zyz 3 years ago
parent
commit
8d0d6252ca
1 changed file with 4 additions and 3 deletions
  1. +4
    -3
      modelscope/utils/tensor_utils.py

+ 4
- 3
modelscope/utils/tensor_utils.py View File

@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from huggingface/transformers. # Part of the implementation is borrowed from huggingface/transformers.
from collections.abc import Mapping




def torch_nested_numpify(tensors): def torch_nested_numpify(tensors):
@@ -31,7 +32,7 @@ def torch_default_data_collator(features):
# features = [vars(f) for f in features] # features = [vars(f) for f in features]
first = features[0] first = features[0]


if isinstance(first, dict):
if isinstance(first, Mapping):
batch = {} batch = {}
# Special handling for labels. # Special handling for labels.
# Ensure that tensor is created with the correct type # Ensure that tensor is created with the correct type
@@ -65,9 +66,9 @@ def torch_default_data_collator(features):
batch = [] batch = []
for idx in range(len(first)): for idx in range(len(first)):
if isinstance(first[idx], torch.Tensor): if isinstance(first[idx], torch.Tensor):
batch.append(torch.stack([f[k] for f in features]))
batch.append(torch.stack([f[idx] for f in features]))
else: else:
batch.append(torch.tensor([f[k] for f in features]))
batch.append(torch.tensor([f[idx] for f in features]))
else: else:
if isinstance(first, torch.Tensor): if isinstance(first, torch.Tensor):
batch = torch.stack(features) batch = torch.stack(features)


Loading…
Cancel
Save