bellmake's picture
SAM3 Video Segmentation - Clean deployment
14114e8
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from collections import defaultdict
from dataclasses import fields, is_dataclass
from typing import Any, Mapping, Protocol, runtime_checkable
import torch
def _is_named_tuple(x) -> bool:
return isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields")
@runtime_checkable
class _CopyableData(Protocol):
def to(self, device: torch.device, *args: Any, **kwargs: Any):
"""Copy data to the specified device"""
...
def copy_data_to_device(data, device: torch.device, *args: Any, **kwargs: Any):
    """Function that recursively copies data to a torch.device.

    Containers are checked in this order: named tuples, lists/tuples,
    ``defaultdict``, other ``Mapping`` types, and dataclass instances. Each is
    rebuilt via ``type(data)`` from recursively copied members. Any other
    object that exposes a ``to`` method is copied with
    ``data.to(device, *args, **kwargs)``. Everything else is returned as is.

    Dataclass instances are re-created through their constructor using the
    ``init=True`` fields; ``init=False`` fields are then copied onto the new
    instance. This also works for frozen dataclasses.

    Args:
        data: The data to copy to device
        device: The device to which the data should be copied
        args: positional arguments that will be passed to the `to` call
        kwargs: keyword arguments that will be passed to the `to` call

    Returns:
        The data on the correct device
    """
    # Named tuples must be tested before plain tuples: they are rebuilt by
    # keyword from their field dict rather than from a positional iterable.
    if _is_named_tuple(data):
        return type(data)(
            **copy_data_to_device(data._asdict(), device, *args, **kwargs)
        )

    if isinstance(data, (list, tuple)):
        return type(data)(
            copy_data_to_device(e, device, *args, **kwargs) for e in data
        )

    # defaultdict must be tested before the generic Mapping branch so that its
    # default_factory is carried over to the copy.
    if isinstance(data, defaultdict):
        return type(data)(
            data.default_factory,
            {
                k: copy_data_to_device(v, device, *args, **kwargs)
                for k, v in data.items()
            },
        )

    if isinstance(data, Mapping):
        return type(data)(
            {
                k: copy_data_to_device(v, device, *args, **kwargs)
                for k, v in data.items()
            }
        )

    # Only dataclass *instances* are rebuilt; dataclass types fall through.
    if is_dataclass(data) and not isinstance(data, type):
        # Imported locally because only this branch needs it.
        from dataclasses import FrozenInstanceError

        new_data_class = type(data)(
            **{
                field.name: copy_data_to_device(
                    getattr(data, field.name), device, *args, **kwargs
                )
                for field in fields(data)
                if field.init
            }
        )
        # Fields excluded from __init__ cannot be passed to the constructor,
        # so they are copied onto the new instance afterwards.
        for field in fields(data):
            if field.init:
                continue
            value = copy_data_to_device(
                getattr(data, field.name), device, *args, **kwargs
            )
            try:
                setattr(new_data_class, field.name, value)
            except FrozenInstanceError:
                # Frozen dataclasses reject normal attribute assignment;
                # write the field through object.__setattr__ instead.
                object.__setattr__(new_data_class, field.name, value)
        return new_data_class

    if isinstance(data, _CopyableData):
        return data.to(device, *args, **kwargs)

    return data