|
|
|
|
|
""" |
|
|
Misc functions, including distributed helpers. |
|
|
""" |
|
|
|
|
|
import collections |
|
|
import re |
|
|
|
|
|
from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass |
|
|
from typing import Any, get_args, get_origin, List, Mapping, Optional, Sequence, Union |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
# Annotation for dataclass fields that may hold either a ready-made tensor or
# a plain list still awaiting conversion. `convert_my_tensors` compares each
# field's annotation against this alias to decide which fields to convert.
MyTensor = Union[torch.Tensor, List[Any]]
|
|
|
|
|
|
|
|
def interpolate(
    input, size=None, scale_factor=None, mode="nearest", align_corners=None
):
    """
    Resize ``input`` via ``torch.nn.functional.interpolate``, tolerating empty tensors.

    ``size``, ``scale_factor``, ``mode`` and ``align_corners`` are forwarded
    positionally, unchanged, to ``torch.nn.functional.interpolate``.

    Args:
        input: tensor to resize. When it has no elements, its first two
            dimensions are inspected, so it must have at least two.
        size: forwarded as the target size.
        scale_factor: forwarded as the scale factor.
        mode: forwarded as the interpolation mode (default ``"nearest"``).
        align_corners: forwarded as-is.

    Returns:
        The tensor produced by ``torch.nn.functional.interpolate``. When
        ``input`` is empty with a zero-sized dimension 1, the result has
        dimensions 0 and 1 in the same order as ``input``.

    Raises:
        AssertionError: if ``input`` has no elements and both of its first
            two dimensions are zero.
    """
    resize = torch.nn.functional.interpolate

    if input.numel() == 0:
        assert not (
            input.shape[0] == 0 and input.shape[1] == 0
        ), "At least one of the two first dimensions must be non zero"

        if input.shape[1] == 0:
            # Swap dims 0 and 1 so the zero-sized dimension sits in position 0
            # for the call, then swap back. Presumably needed because the
            # underlying op accepts an empty dim 0 but not an empty dim 1 --
            # confirm against the PyTorch version in use.
            swapped = input.transpose(0, 1)
            resized = resize(swapped, size, scale_factor, mode, align_corners)
            return resized.transpose(0, 1)

    # Non-empty input, or empty input whose dimension 1 is non-zero: call directly.
    return resize(input, size, scale_factor, mode, align_corners)
|
|
|
|
|
|
|
|
@dataclass
class BatchedPointer:
    """
    Container of five ``MyTensor`` fields describing a batch of pointers.

    Each annotated field ``X`` is followed by an un-annotated class attribute
    ``X__type``. Having no annotation, ``X__type`` is not a dataclass field
    (it is not a constructor argument); ``convert_my_tensors`` looks it up by
    name to get the dtype ``X`` is cast to.

    NOTE(review): what each field means (stage / query / object indices, mask,
    pointer kinds) is inferred from the names only -- confirm with the
    producer of these objects.
    """

    # Index-like fields, converted to torch.long.
    stage_ids: MyTensor
    stage_ids__type = torch.long
    query_ids: MyTensor
    query_ids__type = torch.long
    object_ids: MyTensor
    object_ids__type = torch.long

    # Boolean mask, converted to torch.bool.
    ptr_mask: MyTensor
    ptr_mask__type = torch.bool

    # Converted to torch.long.
    ptr_types: MyTensor
    ptr_types__type = torch.long
|
|
|
|
|
|
|
|
@dataclass
class FindStage:
    """
    Inputs of one find stage: image/text ids plus box and point inputs.

    Each ``MyTensor`` field ``X`` is paired with an un-annotated class
    attribute ``X__type`` (not a dataclass field) that ``convert_my_tensors``
    reads as the dtype to cast ``X`` to.

    ``convert_my_tensors`` stacks ``input_boxes`` and ``input_boxes_label``
    along dimension 1 when they are given as lists of tensors; every other
    list-of-tensors field is stacked along dimension 0.

    NOTE(review): tensor shapes and the exact meaning of each field are not
    visible here -- confirm with the code that builds these objects.
    """

    # Ids, converted to torch.long.
    img_ids: MyTensor
    img_ids__type = torch.long
    text_ids: MyTensor
    text_ids__type = torch.long

    # Box inputs with their mask and labels.
    input_boxes: MyTensor
    input_boxes__type = torch.float
    input_boxes_mask: MyTensor
    input_boxes_mask__type = torch.bool
    input_boxes_label: MyTensor
    input_boxes_label__type = torch.long

    # Point inputs with their mask.
    input_points: MyTensor
    input_points__type = torch.float
    input_points_mask: MyTensor
    input_points_mask__type = torch.bool

    # Not annotated as MyTensor and has no `__type` companion, so
    # `convert_my_tensors` leaves it as a nested list (or None).
    object_ids: Optional[List[List]] = None
|
|
|
|
|
|
|
|
@dataclass
class BatchedFindTarget:
    """
    Target-side tensors for a batch of find queries.

    Each field ``X`` is paired with an un-annotated class attribute
    ``X__type`` (not a dataclass field) that ``convert_my_tensors`` reads as
    the dtype to cast ``X`` to.

    The three ``Optional[MyTensor]`` fields have no default, so they must be
    passed explicitly; ``convert_my_tensors`` skips them when they are ``None``.

    NOTE(review): box formats, tensor shapes and the precise meaning of the
    "padded" / "repeated" variants are not visible here -- confirm with the
    code that fills these in.
    """

    # Box counts, converted to torch.long.
    num_boxes: MyTensor
    num_boxes__type = torch.long

    # Box tensors, all converted to torch.float.
    boxes: MyTensor
    boxes__type = torch.float

    boxes_padded: MyTensor
    boxes_padded__type = torch.float

    repeated_boxes: MyTensor
    repeated_boxes__type = torch.float

    # Optional mask-like fields, converted to torch.bool when present.
    segments: Optional[MyTensor]
    segments__type = torch.bool

    semantic_segments: Optional[MyTensor]
    semantic_segments__type = torch.bool

    is_valid_segment: Optional[MyTensor]
    is_valid_segment__type = torch.bool

    # Boolean flags, converted to torch.bool.
    is_exhaustive: MyTensor
    is_exhaustive__type = torch.bool

    # Object ids, converted to torch.long.
    object_ids: MyTensor
    object_ids__type = torch.long
    object_ids_padded: MyTensor
    object_ids_padded__type = torch.long
|
|
|
|
|
|
|
|
@dataclass
class BatchedInferenceMetadata:
    """
    All metadata required to post-process a find stage.

    Each ``MyTensor`` field ``X`` is paired with an un-annotated class
    attribute ``X__type`` (not a dataclass field) that ``convert_my_tensors``
    reads as the dtype to cast ``X`` to.

    NOTE(review): the meaning and shape of each field are inferred from the
    names only -- confirm with the post-processing code that consumes them.
    """

    coco_image_id: MyTensor
    coco_image_id__type = torch.long

    original_image_id: MyTensor
    original_image_id__type = torch.long

    # The only field in this class converted to torch.int rather than torch.long.
    original_category_id: MyTensor
    original_category_id__type = torch.int

    original_size: MyTensor
    original_size__type = torch.long

    object_id: MyTensor
    object_id__type = torch.long

    frame_index: MyTensor
    frame_index__type = torch.long

    # Not annotated as MyTensor and has no `__type` companion, so
    # `convert_my_tensors` leaves it as a plain list whose entries may be None.
    is_conditioning_only: List[Optional[bool]]
|
|
|
|
|
|
|
|
@dataclass
class BatchedDatapoint:
    """
    One batch: an image tensor, the find texts, and per-stage lists of inputs,
    targets and metadata.

    No field here is annotated as ``MyTensor``, and the stage dataclasses are
    held inside lists. ``convert_my_tensors`` only recurses into attributes
    that are themselves dataclass instances, so calling it on a
    ``BatchedDatapoint`` does not convert the elements of these lists.

    NOTE(review): whether the three ``find_*`` lists are index-aligned (one
    entry per stage) is suggested by the names but not enforced here.
    """

    img_batch: torch.Tensor
    find_text_batch: List[str]

    # Per-stage containers.
    find_inputs: List[FindStage]
    find_targets: List[BatchedFindTarget]
    find_metadatas: List[BatchedInferenceMetadata]

    # Optional; defaults to None when not supplied.
    raw_images: Optional[List[Any]] = None
|
|
|
|
|
|
|
|
def convert_my_tensors(obj):
    """
    Convert every ``MyTensor`` field of a dataclass instance to a tensor, in place.

    For each dataclass field of ``obj``:

    * if the value is itself a dataclass, it is converted recursively;
    * if the annotation is not ``MyTensor`` / ``Optional[MyTensor]``, or the
      value is ``None``, the field is left untouched;
    * if the value is already a ``torch.Tensor``, it is only cast to the
      target dtype, which makes repeated calls on the same object safe;
    * if the value is a non-empty sequence whose first element is a tensor,
      the elements are stacked (along dimension 1 for ``input_boxes`` and
      ``input_boxes_label``, dimension 0 otherwise) and cast;
    * anything else goes through ``torch.as_tensor`` with the target dtype.

    The target dtype of field ``X`` is read from the attribute ``X__type``.

    Args:
        obj: a dataclass instance. It is modified in place.

    Returns:
        The same ``obj``, to allow chaining.

    Raises:
        TypeError: from ``dataclasses.fields`` if ``obj`` is not a dataclass.
        AttributeError: if a field that needs converting has no matching
            ``<name>__type`` attribute.
    """
    none_type = type(None)

    def strip_optional(annotation):
        """Drop ``None`` from a ``Union`` annotation; return anything else unchanged."""
        if get_origin(annotation) is not Union:
            return annotation
        members = get_args(annotation)
        if none_type not in members:
            return annotation
        # Filter NoneType out wherever it sits instead of assuming it is the
        # last member, so `Union[None, X]` is treated like `Optional[X]`.
        return Union[tuple(m for m in members if m is not none_type)]

    for field in fields(obj):
        name = field.name
        value = getattr(obj, name)

        if is_dataclass(value):
            convert_my_tensors(value)
            continue

        if value is None or strip_optional(field.type) != MyTensor:
            continue

        dtype = getattr(obj, name + "__type")

        if isinstance(value, torch.Tensor):
            # Already a tensor (allowed by the MyTensor annotation): only cast.
            # Passing a tensor to torch.stack / len() below would fail.
            converted = value.to(dtype)
        elif len(value) and isinstance(value[0], torch.Tensor):
            stack_dim = 1 if name in ("input_boxes", "input_boxes_label") else 0
            converted = torch.stack(value, dim=stack_dim).to(dtype)
        else:
            converted = torch.as_tensor(value, dtype=dtype)

        setattr(obj, name, converted)
    return obj
|
|
|