Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved | |
| from collections import defaultdict | |
| from dataclasses import fields, is_dataclass | |
| from typing import Any, Mapping, Protocol, runtime_checkable | |
| import torch | |
def _is_named_tuple(x) -> bool:
    """Return True when ``x`` looks like a ``namedtuple`` instance.

    The check is structural: ``x`` must be a ``tuple`` and must expose both
    the ``_asdict`` and ``_fields`` attributes.
    """
    # Guard clause: anything that is not a tuple cannot be a named tuple.
    if not isinstance(x, tuple):
        return False
    # Both marker attributes must be present (checked in this order).
    return all(hasattr(x, marker) for marker in ("_asdict", "_fields"))
@runtime_checkable
class _CopyableData(Protocol):
    """Structural type for objects that expose a ``to(device, ...)`` method.

    The class must be decorated with ``runtime_checkable`` because
    ``copy_data_to_device`` tests values against it with ``isinstance``;
    without the decorator that check raises ``TypeError``. The runtime check
    only verifies that a ``to`` member exists, not its signature.
    """

    def to(self, device: torch.device, *args: Any, **kwargs: Any):
        """Copy data to the specified device"""
        ...
def copy_data_to_device(data, device: torch.device, *args: Any, **kwargs: Any):
    """Recursively move the contents of ``data`` onto ``device``.

    Containers are rebuilt with the same type and their elements are
    processed recursively. Handled, in this order: named tuples, lists and
    tuples, ``defaultdict``, other ``Mapping`` types, dataclass instances,
    and objects matching ``_CopyableData`` (i.e. exposing ``to``). Any other
    value is returned unchanged.

    Args:
        data: The data to copy to device
        device: The device to which the data should be copied
        args: positional arguments that will be passed to the `to` call
        kwargs: keyword arguments that will be passed to the `to` call

    Returns:
        The data on the correct device
    """

    def _move(item):
        # Recurse with the same target device and pass-through arguments.
        return copy_data_to_device(item, device, *args, **kwargs)

    container_cls = type(data)

    # Named tuples are rebuilt by field name from their dict form.
    if _is_named_tuple(data):
        return container_cls(**_move(data._asdict()))

    # Plain sequences: rebuild the same sequence type from moved elements.
    if isinstance(data, (list, tuple)):
        return container_cls(_move(item) for item in data)

    # defaultdict must be checked before the generic Mapping branch so the
    # default factory is carried over to the new instance.
    if isinstance(data, defaultdict):
        factory = data.default_factory
        moved_items = {key: _move(value) for key, value in data.items()}
        return container_cls(factory, moved_items)

    if isinstance(data, Mapping):
        return container_cls({key: _move(value) for key, value in data.items()})

    # Dataclass instances only (dataclass types themselves are excluded).
    if is_dataclass(data) and not isinstance(data, type):
        init_values = {}
        deferred_names = []
        for spec in fields(data):
            if spec.init:
                init_values[spec.name] = _move(getattr(data, spec.name))
            else:
                # Fields with init=False cannot go through the constructor;
                # they are assigned on the new instance afterwards.
                deferred_names.append(spec.name)
        rebuilt = container_cls(**init_values)
        for name in deferred_names:
            setattr(rebuilt, name, _move(getattr(data, name)))
        return rebuilt

    # Leaf objects that know how to move themselves.
    if isinstance(data, _CopyableData):
        return data.to(device, *args, **kwargs)

    return data