mirror of
https://gitcode.com/gh_mirrors/eas/EasyFace.git
synced 2025-12-30 13:02:29 +00:00
38 lines
1.2 KiB
Python
38 lines
1.2 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
from collections.abc import Mapping
|
|
|
|
import torch
|
|
|
|
from modelscope.outputs import ModelOutputBase
|
|
|
|
|
|
def to_device(batch, device, non_blocking=False):
    """Recursively move a batch of data to the target device.

    Called on the dataloader output just before the forward pass so that
    every tensor nested inside the batch ends up on the compute device.

    Args:
        batch: The batch data out of the dataloader. May be a
            ``ModelOutputBase``, a mapping, a tuple/list, a ``torch.Tensor``,
            or any other object (returned unchanged).
        device (str | torch.device): The target device for the data.
        non_blocking (bool): Forwarded to ``torch.Tensor.to`` so copies from
            pinned host memory can overlap with computation.

    Returns:
        The batch with the same structure, with all contained tensors moved
        to ``device``.
    """
    if isinstance(batch, ModelOutputBase):
        # ModelOutputBase supports integer indexing/assignment; update in place.
        for idx in range(len(batch)):
            # BUGFIX: propagate non_blocking (previously dropped on recursion,
            # so it only applied when the top-level argument was a tensor).
            batch[idx] = to_device(batch[idx], device, non_blocking)
        return batch
    elif isinstance(batch, Mapping):
        # dict is itself a Mapping, so one isinstance check covers both.
        if hasattr(batch, '__setitem__'):
            # Reuse mini-batch to keep attributes for prediction.
            for k, v in batch.items():
                batch[k] = to_device(v, device, non_blocking)
            return batch
        else:
            # Immutable mapping: rebuild an instance of the same type.
            return type(batch)(
                {k: to_device(v, device, non_blocking)
                 for k, v in batch.items()})
    elif isinstance(batch, (tuple, list)):
        return type(batch)(to_device(v, device, non_blocking) for v in batch)
    elif isinstance(batch, torch.Tensor):
        return batch.to(device, non_blocking=non_blocking)
    else:
        # Non-tensor leaf (int, str, None, ...): leave untouched.
        return batch