# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import inspect
from functools import wraps
from typing import Callable, TypeVar, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from torch.utils._pytree import tree_map_only

# Type variables for better type hinting
T = TypeVar("T")
Module = TypeVar("Module", bound=nn.Module)


def activation_ckpt_wrapper(module: Union[nn.Module, Callable]) -> Callable:
"""
Wraps a given module to enable or disable activation checkpointing.
Activation checkpointing (gradient checkpointing) trades compute for memory by
recomputing intermediate activations during the backward pass instead of storing
them in memory during the forward pass.
When activation checkpointing is enabled, the wrapper expects only keyword arguments,
and it maps these to positional arguments based on the module's signature.
Args:
module: The module or function to wrap with activation checkpointing
Returns:
A wrapped callable that supports activation checkpointing
Usage:
The returned wrapper function can be called with the same arguments as the
original module, with an additional `act_ckpt_enable` keyword argument to control
activation checkpointing and optional `use_reentrant` parameter.
Example:
```python
wrapped_module = activation_ckpt_wrapper(my_module)
output = wrapped_module(x=input_tensor, y=another_tensor, act_ckpt_enable=True)
```
"""
@wraps(module)
def act_ckpt_wrapper(
*args, act_ckpt_enable: bool = True, use_reentrant: bool = False, **kwargs
):
if act_ckpt_enable:
if len(args) > 0:
raise ValueError(
"This wrapper expects keyword arguments only when `act_ckpt_enable=True`"
)
# Get the signature of the target function/module
callable_fn = module.forward if isinstance(module, nn.Module) else module
sig = inspect.signature(callable_fn)
# Create a mapping of parameter names to their default values
param_defaults = {
name: param.default for name, param in sig.parameters.items()
}
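            # Rebuild the positional argument list from kwargs, in signature order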
args = []
for p_name in param_defaults.keys():
if p_name in kwargs:
args.append(kwargs.pop(p_name))
                elif param_defaults[p_name] is not inspect.Parameter.empty:
                    # Fall back to the default value when the argument is absent
                    # from kwargs. Useful for primitive types or args that
                    # default to None.
                    args.append(param_defaults[p_name])
                elif (
                    sig.parameters[p_name].kind is not inspect.Parameter.VAR_KEYWORD
                ):  # Skip the **kwargs parameter
                    raise ValueError(f"Missing required argument: {p_name}")
            # Scan remaining kwargs for torch.Tensor values
            remaining_keys = list(kwargs.keys())
            for key in remaining_keys:
                if isinstance(kwargs[key], torch.Tensor):
                    # Replace the tensor with a string sentinel (not an actual
                    # removal) so it is not tracked by checkpointing, assuming
                    # the module does not need it. If it is required, the
                    # module's signature should be changed to accept it as an
                    # explicit positional or keyword argument.
                    kwargs[key] = "_REMOVED_BY_ACT_CKPT_WRAPPER_"
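            # Non-reentrant checkpointing (use_reentrant=False) accepts keyword
            # arguments and recomputes activations during the backward pass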
ret = checkpoint.checkpoint(
module, *args, use_reentrant=use_reentrant, **kwargs
)
else:
ret = module(*args, **kwargs)
return ret
return act_ckpt_wrapper
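

# A usage sketch (illustrative, not part of the original module): wrapping a
# toy nn.Sequential and toggling checkpointing per call. `nn.Sequential.forward`
# takes a single `input` parameter, so the wrapper maps the `input=` keyword to
# a positional argument before calling `torch.utils.checkpoint.checkpoint`:
#
#   mlp = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))
#   wrapped_mlp = activation_ckpt_wrapper(mlp)
#   x = torch.randn(4, 16, requires_grad=True)
#   out = wrapped_mlp(input=x, act_ckpt_enable=True)  # checkpointed forward
#   out.sum().backward()  # activations are recomputed here, not stored
#   out = wrapped_mlp(x, act_ckpt_enable=False)  # plain forward; positional OK

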
def clone_output_wrapper(f: Callable[..., T]) -> Callable[..., T]:
"""
Clone the CUDA output tensors of a function to avoid in-place operations.
This wrapper is useful when working with torch.compile to prevent errors
related to in-place operations on tensors.
Args:
f: The function whose CUDA tensor outputs should be cloned
Returns:
A wrapped function that clones any CUDA tensor outputs
"""
@wraps(f)
def wrapped(*args, **kwargs):
outputs = f(*args, **kwargs)
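        # Clone only the CUDA tensors anywhere in the (possibly nested) outputs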
return tree_map_only(
torch.Tensor, lambda t: t.clone() if t.is_cuda else t, outputs
)
return wrapped
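

# A minimal runnable sketch for `clone_output_wrapper` (illustrative; `double`
# is a hypothetical function, not part of the original module). Cloning breaks
# aliasing between a function's CUDA outputs and buffers that torch.compile may
# reuse, so later in-place updates on the outputs are safe. On CPU tensors the
# wrapper is a no-op by design.
if __name__ == "__main__":

    @clone_output_wrapper
    def double(t: torch.Tensor) -> dict:
        return {"out": t * 2}

    device = "cuda" if torch.cuda.is_available() else "cpu"
    y = double(torch.ones(3, device=device))
    y["out"].add_(1)  # safe: on CUDA, `y["out"]` is a fresh clone
    print(y["out"])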