|
|
|
@@ -1,5 +1,6 @@ |
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates. |
|
|
|
# Part of the implementation is borrowed from huggingface/transformers. |
|
|
|
from collections.abc import Mapping |
|
|
|
|
|
|
|
|
|
|
|
def torch_nested_numpify(tensors): |
|
|
|
@@ -31,7 +32,7 @@ def torch_default_data_collator(features): |
|
|
|
# features = [vars(f) for f in features] |
|
|
|
first = features[0] |
|
|
|
|
|
|
|
if isinstance(first, dict): |
|
|
|
if isinstance(first, Mapping): |
|
|
|
batch = {} |
|
|
|
# Special handling for labels. |
|
|
|
# Ensure that tensor is created with the correct type |
|
|
|
@@ -65,9 +66,9 @@ def torch_default_data_collator(features): |
|
|
|
batch = [] |
|
|
|
for idx in range(len(first)): |
|
|
|
if isinstance(first[idx], torch.Tensor): |
|
|
|
batch.append(torch.stack([f[k] for f in features])) |
|
|
|
batch.append(torch.stack([f[idx] for f in features])) |
|
|
|
else: |
|
|
|
batch.append(torch.tensor([f[k] for f in features])) |
|
|
|
batch.append(torch.tensor([f[idx] for f in features])) |
|
|
|
else: |
|
|
|
if isinstance(first, torch.Tensor): |
|
|
|
batch = torch.stack(features) |
|
|
|
|