Files
EasyFace/modelscope/utils/data_utils.py
2023-03-02 11:17:26 +08:00

38 lines
1.2 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
from collections.abc import Mapping
import torch
from modelscope.outputs import ModelOutputBase
def to_device(batch, device, non_blocking=False):
    """Recursively move a (possibly nested) batch to ``device``.

    Intended to be called on the data just before the forward function.

    Args:
        batch: The batch data out of the dataloader. May be a
            ``ModelOutputBase``, a mapping, a tuple/list (including
            namedtuples), a ``torch.Tensor``, or any other object;
            containers are traversed recursively.
        device (str | torch.device): The target device for the data.
        non_blocking (bool): Forwarded to ``torch.Tensor.to`` for every
            tensor found in ``batch``, at any nesting depth.

    Returns:
        The data on the target device. ``ModelOutputBase`` instances and
        mappings that support item assignment are updated in place and
        returned; tuples, lists and read-only mappings are rebuilt with the
        same type; tensors are moved with ``Tensor.to``; any other object is
        returned unchanged.
    """
    if isinstance(batch, ModelOutputBase):
        # Updated in place so the caller's output object (and its type) is kept.
        for idx in range(len(batch)):
            batch[idx] = to_device(
                batch[idx], device, non_blocking=non_blocking)
        return batch
    elif isinstance(batch, Mapping):
        # Note: ``dict`` is a registered ``Mapping``, so this covers plain dicts.
        if hasattr(batch, '__setitem__'):
            # Reuse mini-batch to keep attributes for prediction.
            # Iterate over a snapshot of the keys so the assignments below
            # cannot interfere with iterating the mapping itself.
            for k in list(batch.keys()):
                batch[k] = to_device(batch[k], device, non_blocking=non_blocking)
            return batch
        else:
            # Read-only mapping: rebuild it with the same type.
            return type(batch)({
                k: to_device(v, device, non_blocking=non_blocking)
                for k, v in batch.items()
            })
    elif isinstance(batch, (tuple, list)):
        moved = [to_device(v, device, non_blocking=non_blocking) for v in batch]
        if isinstance(batch, tuple) and hasattr(batch, '_fields'):
            # namedtuple constructors take their fields positionally,
            # not as a single iterable.
            return type(batch)(*moved)
        return type(batch)(moved)
    elif isinstance(batch, torch.Tensor):
        return batch.to(device, non_blocking=non_blocking)
    else:
        return batch