| | from typing import (Callable, Dict, Iterable, List, Literal, Mapping, Optional, Protocol, Set, Tuple, Union, overload) |
| | import torch |
| | from torch.func import functional_call |
| |
|
@overload
def flatten_bn(x: torch.Tensor) -> torch.Tensor:
    # A single batched tensor flattens to a single tensor.
    ...
| |
|
| |
|
@overload
def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]:
    # A list of batched tensors flattens to a (longer) list of tensors.
    ...
| |
|
| |
|
@overload
def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: Literal[True],
) -> torch.Tensor:
    # With concat=True the result is always a single tensor.
    ...
| |
|
| |
|
@overload
def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    # General fallback: result type depends on the input type and `concat`.
    ...
| |
|
| |
|
def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    """
    Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.

    The input tensor should have shape ``(B, N, ...)``.

    Args:
        x: Either one tensor of shape ``(B, N, ...)`` or a list of
            per-batch tensors, each of shape ``(N, ...)``.
        concat: If True and ``x`` is a list, concatenate the per-batch
            tensors into one tensor instead of returning a list.

    Returns:
        A tensor of shape ``(B * N, ...)`` when the input is a tensor
        (or when ``concat=True``), otherwise a flat list of tensors.
    """
    # A single tensor: merge the leading batch dims.
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)

    # A list with concat requested: join along dim 0.
    if concat:
        return torch.cat(x)

    # Otherwise, flatten the list-of-batches into one flat list.
    flattened: List[torch.Tensor] = []
    for batch in x:
        flattened.extend(batch)
    return flattened
| |
|
| | def _flatten_embeddings(embeddings: torch.Tensor) -> torch.Tensor: |
| | """ |
| | Recursively flattens and concatenates NestedTensors on all but the last |
| | dimension. |
| | """ |
| |
|
| | if isinstance(embeddings, torch.Tensor): |
| | |
| | return embeddings.flatten(0, -2) |
| |
|
| | return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings)) |
| |
|
| | def _embedding_count_expression(embeddings: torch.Tensor) -> str: |
| | """ |
| | Constructs a debugging representation of the number of embeddings in the |
| | Tensors. |
| | """ |
| |
|
| | if isinstance(embeddings, torch.Tensor): |
| | return " x ".join([str(dim) for dim in embeddings.shape[:-1]]) |
| |
|
| | return " + ".join( |
| | _embedding_count_expression(inner) for inner in embeddings) |
| |
|
def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    is_multimodal: torch.Tensor,
    multimodal_embeddings: torch.Tensor,
) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions in ``inputs_embeds`` corresponding to placeholder tokens in
    ``input_ids``.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    # Number of placeholder positions the mask selects.
    expected = is_multimodal.sum().item()
    assert isinstance(expected, int)

    flat_embeds = _flatten_embeddings(multimodal_embeddings)
    actual = flat_embeds.shape[0]
    if actual != expected:
        # Build a human-readable breakdown of where the count came from.
        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {actual} "
            f"multimodal tokens to {expected} placeholders")

    # Boolean-mask assignment scatters the flattened embeddings into the
    # placeholder rows, in order.
    inputs_embeds[is_multimodal] = flat_embeds
    return inputs_embeds
| |
|
def merge_multimodal_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: torch.Tensor,
    placeholder_token_id: Union[int, List[int]],
) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions in ``inputs_embeds`` corresponding to placeholder tokens in
    ``input_ids``.

    ``placeholder_token_id`` can be a list of token ids (e.g, token ids
    of img_start, img_break, and img_end tokens) when needed: This means
    the order of these tokens in the ``input_ids`` MUST MATCH the order of
    their embeddings in ``multimodal_embeddings`` since we need to
    slice-merge instead of individually scattering.

    For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
    - T is text token
    - S is image start token
    - I is image embedding token
    - B is image break token
    - E is image end token.

    Then the image embeddings (that correspond to I's) from vision encoder
    must be padded with embeddings of S, B, and E in the same order of
    input_ids for a correct embedding merge.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    # Build the placeholder mask once, then delegate the actual merge.
    if isinstance(placeholder_token_id, list):
        # Multiple placeholder ids: match any of them at each position.
        placeholder_ids = torch.tensor(placeholder_token_id,
                                       device=input_ids.device)
        placeholder_mask = torch.isin(input_ids, placeholder_ids)
    else:
        placeholder_mask = input_ids == placeholder_token_id

    return _merge_multimodal_embeddings(
        inputs_embeds,
        placeholder_mask,
        multimodal_embeddings,
    )
| |
|