| | |
| |
|
| | from dataclasses import asdict, dataclass, field |
| | from typing import Dict, List, Optional, Tuple |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import PretrainedConfig, PreTrainedModel |
| |
|
| |
|
@dataclass
class RotaryEmbeddingConfig:
    """
    Configuration of the rotary positional embeddings (RoPE).

    Args:
        max_seq_len: The number of positions to encode and cache.
        dim: Dimension of RoPE, i.e. how many leading features of each
            attention head are rotated.
        theta: Base used to derive the rotation frequencies
            (``inv_freq = 1 / theta ** (2 * i / dim)``); it is not an angle.
    """

    max_seq_len: int
    dim: int
    theta: float
| |
|
| |
|
@dataclass
class PerceiverResamplerConfig:
    """
    Parameters to initialize a PerceiverResampler model.

    Args:
        emb_layer_norm_before: Whether to use layer norm before the first attention
            layer.
        attention_heads: Number of attention heads.
        key_size: The dimension of the query, key, and values within each attention
            head. If not specified, it is set to embed_dim // attention_heads.
            It can be useful to set a custom key size if we want to impose the size
            of the query, key and value tensor (for example, tensors shaped with
            power of 2 are more efficiently handled on TPUs).
            Note: Parametrizing the model with a custom key size has been done in:
            Brown, Tom, et al. "Language models are few-shot learners."
            Advances in neural information processing systems 33 (2020): 1877-1901.
        embed_dim: Embedding dimension.
        ffn_embed_dim: Feed forward embedding dimension.
        num_layers: Number of attention blocks.
        add_bias_kv: Add bias in attention layer.
        add_bias_ffn: Add bias in feed forward network block.
        ffn_activation_name: Activation function to be used in FFN block. Supported
            names are "gelu", "relu", "swish" (the default is "gelu-no-approx").
        use_glu_in_ffn: Whether to use Gated Linear Unit (GLU) in Feed
            Forward Network (FFN) block. To do a swiGLU (gated-swish) put this arg
            to True and use swish as ffn_activation_name.
            Same principle for a gated-relu. To keep the same number of parameters
            in the FFN block, one should multiply by 2/3 the ffn_embed_dim when
            using GLU. See https://arxiv.org/pdf/2002.05202.pdf for more details.
        resampled_length: Length of the resampled output of the module.
        use_gradient_checkpointing: Whether to use gradient checkpointing
            (checkpoint gradients in the forward pass to reduce the computation in
            the backward).

    Raises:
        ValueError: If key_size is not given and embed_dim is not divisible by
            attention_heads.
    """

    # architecture
    emb_layer_norm_before: bool = False
    attention_heads: int = 20
    key_size: Optional[int] = None
    embed_dim: int = 1280
    ffn_embed_dim: int = 5120
    num_layers: int = 24
    add_bias_kv: bool = False
    add_bias_ffn: bool = True
    ffn_activation_name: str = "gelu-no-approx"
    use_glu_in_ffn: bool = False
    resampled_length: int = 64

    # performance
    use_gradient_checkpointing: bool = False

    def __post_init__(self) -> None:
        """
        Checks that the given values are compatible and derives key_size when it
        was not provided.
        """
        if self.key_size is not None:
            return

        if self.embed_dim % self.attention_heads != 0:
            raise ValueError(
                f"When no key size is provided, the embedding dimension should be "
                f"divisible by the number of heads, however provided embedding "
                f"dimension is {self.embed_dim} and the number of heads is "
                f"{self.attention_heads}."
            )
        self.key_size = self.embed_dim // self.attention_heads
| |
|
| |
|
@dataclass
class GptConfig:
    """
    Parameters to initialize a Gpt model.

    NOTE: the pad token is not defined

    Args:
        vocab_size: Token vocabulary.
        eos_token_id: used to stop sentence generation
        embed_dim: Embedding dimension.
        ffn_embed_dim: Feed forward embedding dimension.
        num_heads: Number of attention heads.
        num_kv_heads: Number of key and value heads to support Grouped-Query and
            Multi-Query Attention. If None, the number of key and value heads is
            equal to the number of attention heads. When given, it must be a
            positive divisor of num_heads.
        num_layers: Number of Decoder layer_stack
        rope_config: The configuration for the rotary positional embeddings. A
            plain dict of RotaryEmbeddingConfig fields is accepted and converted.
        add_bias_ffn: Add bias in feed forward network block.
        ffn_activation_name: Activation function to be used in FFN block. Supported
            names are "gelu", "gelu-no-approx", "relu", "swish".
        use_glu_in_ffn: whether to use Gated Linear Unit (GLU) in Feed
            Forward Network (FFN) block.
            example: To do a swiGLU (gated-swish) put this arg
            to True and use swish as ffn_activation_name.
            Same principle for a gated-relu.
        add_bias_lm_head: whether to use bias in the final LM layer
        norm_type: The type of norm used (pre normalization scheme). Can be
            one of ["layer_norm", "RMS_norm"]
        rms_norm_eps: Epsilon passed to the RMS norm layers.
        parallel_attention_ff: Whether to do the attention and the MLP in parallel,
            and then sum up the results as it is done in Gpt-NeoX :
            Black, Sid, et al. "Gpt-neox-20b: An open-source autoregressive
            language model." arXiv preprint arXiv:2204.06745 (2022).
            It is said to improve the training time of 15% when compiling with JAX
        use_gradient_checkpointing: Whether to use gradient checkpointing
            (checkpoint gradients in the forward pass to reduce the computation in
            the backward).
        add_bias_attn: Add bias to the attention mechanism (key, query, value, and
            output projections).

    Raises:
        ValueError: If the head/embedding/rotary dimensions are incompatible.
    """

    # vocabulary
    vocab_size: int
    eos_token_id: int

    # architecture
    embed_dim: int = 16
    ffn_embed_dim: int = 64
    num_heads: int = 2
    num_kv_heads: Optional[int] = None
    num_layers: int = 2
    rope_config: "RotaryEmbeddingConfig" = field(
        default_factory=lambda: RotaryEmbeddingConfig(
            max_seq_len=512, dim=8, theta=10000.0
        )
    )
    add_bias_ffn: bool = False
    ffn_activation_name: str = "swish"
    use_glu_in_ffn: bool = True
    add_bias_lm_head: bool = False
    norm_type: str = "RMS_norm"
    rms_norm_eps: float = 1e-6
    parallel_attention_ff: bool = True

    # inference / backward behavior
    use_gradient_checkpointing: bool = False

    # architecture params with default values
    add_bias_attn: bool = False

    def __post_init__(self) -> None:
        """
        Checks that the given values are compatible.
        """
        # A config restored from JSON carries the nested config as a plain dict.
        if isinstance(self.rope_config, dict):
            self.rope_config = RotaryEmbeddingConfig(**self.rope_config)

        if self.embed_dim % self.num_heads != 0:
            raise ValueError(
                f"The embedding dimension should be "
                f"divisible by the number of heads, however provided embedding "
                f"dimension is {self.embed_dim} and the number of heads is "
                f"{self.num_heads}."
            )

        head_dim = self.embed_dim // self.num_heads

        # Rotary embeddings rotate features in pairs, hence at least 2 per head.
        if head_dim < 2:
            raise ValueError(
                "embed_dim // num_heads must be at least 2 to apply rotary embeddings"
            )

        if head_dim < self.rope_config.dim:
            raise ValueError(
                "embed_dim // num_heads must be at least rope_config.dim "
                "to apply rotary embeddings"
            )

        # Query heads are shared evenly between key/value heads.
        if self.num_kv_heads is not None and (
            self.num_kv_heads <= 0 or self.num_heads % self.num_kv_heads != 0
        ):
            raise ValueError(
                f"num_heads ({self.num_heads}) must be a positive multiple of "
                f"num_kv_heads ({self.num_kv_heads})."
            )

    def to_dict(self):
        """
        Returns the configuration as a plain dict.

        `asdict` already recurses into nested dataclasses, so `rope_config` is
        returned as a dict of its own fields.
        """
        return asdict(self)
| |
|
| |
|
@dataclass
class NucleotideTransformerConfig:
    """
    Parameters to initialize an NT model.

    Args:
        alphabet_size: Token vocabulary.
        pad_token_id: ID of pad token.
        mask_token_id: ID of mask token.
        max_positions: Maximum sequence length.
        embed_scale: Correction ratio applied to the embeddings to make up for the
            norm difference between the input during training and inference.
        emb_layer_norm_before: Whether to use layer norm before the first attention
            layer.
        attention_heads: Number of attention heads.
        key_size: The dimension of the query, key, and values within each attention
            head. If not specified, it is set to embed_dim // attention_heads.
            It can be useful to set a custom key size if we want to impose the size
            of the query, key and value tensor (for example, tensors shaped with
            power of 2 are more efficiently handled on TPUs).
            Note: Parametrizing the model with a custom key size has been done in:
            Brown, Tom, et al. "Language models are few-shot learners."
            Advances in neural information processing systems 33 (2020): 1877-1901.
        embed_dim: Embedding dimension.
        ffn_embed_dim: Feed forward embedding dimension.
        num_layers: Number of attention blocks.
        positional_embedding: Type of positional embedding to use before the first
            attention layer. Options: "learned", "learned_standard", "sinusoidal",
            "alibi_dnabert_2" or None.
            NOTE: "learned" is the positional embedding of ESM, and
            "learned_standard" is a more standard one, used for example in DNAbert.
        lm_head: type of language model head. Options: "simple", "roberta" or None.
        add_bias_kv: Add bias in attention layer. Not compatible with rotary
            embeddings nor with the "alibi_dnabert_2" positional embedding.
        add_bias_ffn: Add bias in feed forward network block.
        use_rotary_embedding: Whether to use rotary embeddings. Requires:
            positional_embedding = None.
        rescaling_factor: Scaling factor to use for rotary embeddings.
        ffn_activation_name: Activation function to be used in FFN block. Supported
            names are "gelu", "relu", "swish" (the default is "gelu-no-approx").
        use_glu_in_ffn: Whether to use Gated Linear Unit (GLU) in Feed
            Forward Network (FFN) block. To do a swiGLU (gated-swish) put this arg
            to True and use swish as ffn_activation_name.
            Same principle for a gated-relu. To keep the same number of parameters
            in the FFN block, one should multiply by 2/3 the ffn_embed_dim when
            using GLU. See https://arxiv.org/pdf/2002.05202.pdf for more details.
        mask_before_attention: Use mask before attention layers.
        layer_norm_eps: the eps factor in the different layer norms of the model
            (refer to layer norm implementation)
        token_dropout: Token dropout.
        masking_ratio: Masking ratio (used if token dropout is enabled).
        masking_prob: Masking probability (used if token dropout is enabled).
        use_gradient_checkpointing: Whether to use gradient checkpointing
            (checkpoint gradients in the forward pass to reduce the computation in
            the backward).

    Raises:
        TypeError: If positional_embedding or lm_head is neither None nor a str.
        ValueError: If an option is unknown or options are mutually incompatible.
    """

    alphabet_size: int
    pad_token_id: int
    mask_token_id: int

    max_positions: int = 1024
    embed_scale: float = 1.0

    # architecture
    emb_layer_norm_before: bool = False
    attention_heads: int = 20
    key_size: Optional[int] = None
    embed_dim: int = 1280
    ffn_embed_dim: int = 5120
    num_layers: int = 24
    positional_embedding: Optional[str] = "learned"
    lm_head: Optional[str] = "simple"
    add_bias_kv: bool = False
    add_bias_ffn: bool = True
    use_rotary_embedding: bool = False
    rescaling_factor: Optional[float] = None
    ffn_activation_name: str = "gelu-no-approx"
    use_glu_in_ffn: bool = False
    mask_before_attention: bool = False
    layer_norm_eps: float = 1e-5
    # NOTE(review): the two flags below are not described in the original
    # docstring; presumably they select pre-layer-norm blocks and a bias on the
    # word embedding - confirm against the encoder implementation.
    pre_layer_norm: bool = True
    bias_word_embedding: bool = False

    # dropout
    token_dropout: bool = False
    masking_ratio: float = 0.1
    masking_prob: float = 0.8

    # logging
    use_gradient_checkpointing: bool = False

    # return
    embeddings_layers_to_save: List[int] = field(default_factory=list)
    attention_maps_to_save: List[Tuple[int, int]] = field(default_factory=list)

    def __post_init__(self) -> None:
        """
        Checks that the given values are compatible and derives key_size when it
        was not provided.
        """
        if self.key_size is None:
            if self.embed_dim % self.attention_heads != 0:
                raise ValueError(
                    f"When no key size is provided, the embedding dimension should be "
                    f"divisible by the number of heads, however provided embedding "
                    f"dimension is {self.embed_dim} and the number of heads is "
                    f"{self.attention_heads}."
                )
            self.key_size = self.embed_dim // self.attention_heads

        if self.positional_embedding is not None:
            if not isinstance(self.positional_embedding, str):
                raise TypeError(
                    "positional_embedding must be a str or None, got "
                    f"{type(self.positional_embedding).__name__}."
                )

            if self.positional_embedding not in (
                "learned",
                "sinusoidal",
                "learned_standard",
                "alibi_dnabert_2",
            ):
                raise ValueError(
                    "The positional_embedding argument should either be None, "
                    "`learned`, `sinusoidal`, 'learned_standard' or 'alibi_dnabert_2'."
                )

        if self.lm_head is not None:
            if not isinstance(self.lm_head, str):
                raise TypeError(
                    f"lm_head must be a str or None, got {type(self.lm_head).__name__}."
                )

            if self.lm_head not in ("simple", "roberta"):
                raise ValueError(
                    "The lm_head argument should either be None, "
                    "`simple` or `roberta`."
                )

        if self.use_rotary_embedding and self.positional_embedding is not None:
            raise ValueError(
                "When using rotary embedding, positional_embedding must be set to none"
            )

        if self.add_bias_kv and self.use_rotary_embedding:
            raise ValueError(
                "Biases on key and values are not compatible with Rotary embeddings."
            )

        # Raised explicitly (rather than asserted) so that the check survives
        # running Python with -O.
        if self.positional_embedding == "alibi_dnabert_2" and self.add_bias_kv:
            raise ValueError(
                "Biases on key and values are not compatible with the "
                "'alibi_dnabert_2' positional embedding."
            )
| |
|
| |
|
# NOTE(review): @dataclass is kept from the original. The class declares no
# annotated fields of its own, so the decorator presumably only contributes a
# generated __repr__/__eq__ - confirm this is intended before removing it.
@dataclass
class ChatNTConfig(PretrainedConfig):
    """
    Configuration of the ChatNT model.

    Recognised keyword arguments (all optional):
        gpt_config: GptConfig of the English decoder.
            Defaults to GptConfig(32000, 3).
        nt_config: NucleotideTransformerConfig of the bio encoder.
            Defaults to NucleotideTransformerConfig(4000, 1, 4).
        perceiver_resampler_config: PerceiverResamplerConfig of the projection.
            Defaults to PerceiverResamplerConfig().
        seq_token_id: Index of the SEQ token. Defaults to 32000.
        bio_pad_token_id: Pad token id of the bio tokenizer. Defaults to 1.
        english_pad_token_id: Pad token id of the English tokenizer. Defaults to 2.

    All keyword arguments are also forwarded to PretrainedConfig.__init__.
    """

    model_type = "ChatNT"

    def __init__(self, **kwargs):
        # Defaults are only instantiated when the caller did not provide a value.
        self.gpt_config: GptConfig = (
            kwargs["gpt_config"] if "gpt_config" in kwargs else GptConfig(32000, 3)
        )
        self.nt_config: NucleotideTransformerConfig = (
            kwargs["nt_config"]
            if "nt_config" in kwargs
            else NucleotideTransformerConfig(4000, 1, 4)
        )
        self.perceiver_resampler_config: PerceiverResamplerConfig = (
            kwargs["perceiver_resampler_config"]
            if "perceiver_resampler_config" in kwargs
            else PerceiverResamplerConfig()
        )
        self.seq_token_id: int = kwargs.get("seq_token_id", 32000)
        self.bio_pad_token_id: int = kwargs.get("bio_pad_token_id", 1)
        self.english_pad_token_id: int = kwargs.get("english_pad_token_id", 2)
        super().__init__(**kwargs)

    def to_dict(self):
        """
        Returns the configuration as a dict with the three sub-configurations
        converted to plain dicts.
        """
        output = super().to_dict()

        def serialize(obj):
            # Sub-configs are plain dicts when the config was restored from JSON
            # and not yet converted back to dataclasses.
            if isinstance(obj, dict):
                return dict(obj)
            if hasattr(obj, "to_dict"):
                return obj.to_dict()
            # Copy so that callers editing the result cannot alter the config.
            return dict(vars(obj))

        output["gpt_config"] = serialize(self.gpt_config)
        output["nt_config"] = serialize(self.nt_config)
        output["perceiver_resampler_config"] = serialize(
            self.perceiver_resampler_config
        )
        return output
| |
|
| |
|
class TorchBioBrainDecoder(nn.Module):
    """
    GPT decoder that consumes English tokens in which each SEQ token marks where
    the projected embeddings of a bio sequence must be inserted.
    """

    def __init__(
        self,
        gpt_config: "GptConfig",
        seq_token_id: int,
    ):
        """
        Initializes the BioBrain decoder, using a GPT model for text generation with
        bio embeddings.

        Args:
            gpt_config: Configuration for the GPT model
            seq_token_id: Index of the SEQ token
        """
        super().__init__()
        self.gpt_config = gpt_config
        self.seq_token_id = seq_token_id

        # GPT model used for text generation
        self.gpt_model = TorchGptDecoder(self.gpt_config)

    def forward(
        self,
        english_token_ids: torch.Tensor,
        projected_bio_embeddings: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Forward pass through the model.

        Args:
            english_token_ids: Tensor of English token IDs with shape
                (batch_size, num_english_tokens).
            projected_bio_embeddings: Optional tensor of bio embeddings with shape
                (batch_size, num_bio_sequences, ?, embed_dim).

        Returns:
            torch.Tensor: The logits from the GPT model,
                shaped (batch_size, num_english_tokens, vocab_size).
        """
        # Embed the English tokens
        tokens_embeddings = self.gpt_model.token_embed(english_token_ids)

        if projected_bio_embeddings is not None:
            _, num_bio_sequences, resampled_length, _ = projected_bio_embeddings.shape

            # Insert the bio embeddings at the successive SEQ token positions
            processed_tokens_ids = english_token_ids.clone()
            for bio_seq_num in range(num_bio_sequences):
                tokens_embeddings, processed_tokens_ids = self.insert_embeddings(
                    processed_tokens_ids,
                    tokens_embeddings,
                    projected_bio_embeddings[:, bio_seq_num, :, :],
                    bio_seq_num=bio_seq_num,
                )

        # Run the decoder stack
        embeddings = self.gpt_model.apply_transformer_layers(tokens_embeddings)
        embeddings = self.gpt_model.final_norm(embeddings)

        # Project to vocabulary logits
        logits = self.gpt_model.lm_head(embeddings)

        if projected_bio_embeddings is not None:
            # Drop the logits located at the positions of the inserted embeddings
            processed_tokens_ids = english_token_ids.clone()
            for _ in range(num_bio_sequences):
                logits, processed_tokens_ids = self.cleanup_logits(
                    tokens=processed_tokens_ids,
                    logits=logits,
                    resampled_length=resampled_length,
                )

        return logits

    def _insert_single(
        self,
        tokens_1d: torch.Tensor,
        input_embeddings_1d: torch.Tensor,
        resampled_embeddings_1d: torch.Tensor,
        bio_seq_num: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Inserts the resampled embeddings of one batch row; mutates tokens_1d.

        Args:
            tokens_1d: Shape (num_tokens,)
            input_embeddings_1d: Shape (num_tokens, embed_dim)
            resampled_embeddings_1d: Shape (bio_sequence_length, embed_dim)
            bio_seq_num: Index of the bio sequence being inserted.
        """
        seq_positions = torch.where(tokens_1d == self.seq_token_id)[0]
        if seq_positions.numel() == 0:
            # No SEQ token in this row: return the row itself (not the whole
            # batch) so that the rows can be stacked back together.
            return input_embeddings_1d, tokens_1d

        idx = seq_positions[0].item()
        # Embeddings inserted for previous bio sequences shifted the row, while
        # the token ids were left untouched.
        insertion_pos = idx + resampled_embeddings_1d.shape[-2] * bio_seq_num
        merged = torch.cat(
            [
                input_embeddings_1d[:insertion_pos, :],
                resampled_embeddings_1d,
                input_embeddings_1d[insertion_pos:, :],
            ],
            dim=0,
        )
        # Truncate back to the original number of tokens.
        merged = merged[: tokens_1d.shape[0] + 1, :][:-1, :]
        tokens_1d[idx] = -1
        return merged, tokens_1d

    def insert_embeddings(
        self,
        tokens: torch.Tensor,
        input_embeddings: torch.Tensor,
        resampled_embeddings: torch.Tensor,
        bio_seq_num: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Inserts resampled embeddings in input_embeddings, starting at the SEQ token

        Args:
            tokens (torch.Tensor): Shape (batch_size, num_tokens)
            input_embeddings (torch.Tensor): Shape (batch_size, num_tokens, embed_dim)
            resampled_embeddings (torch.Tensor):
                Shape (batch_size, bio_sequence_length, embed_dim)
            bio_seq_num (int): Index of the bio sequence being inserted.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - input_embeddings with resampled_embeddings inserted at the SEQ token
                - tokens with the SEQ token set to -1
        """
        tokens_acc = []
        embeddings_acc = []

        for i in range(tokens.shape[0]):
            embeddings_out, tokens_out = self._insert_single(
                tokens[i].clone(),  # cloned: the helper edits the ids in place
                input_embeddings[i],
                resampled_embeddings[i],
                bio_seq_num,
            )
            tokens_acc.append(tokens_out)
            embeddings_acc.append(embeddings_out)

        return torch.stack(embeddings_acc), torch.stack(tokens_acc)

    def _clean_single(
        self, token: torch.Tensor, logit: torch.Tensor, resampled_length: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Removes resampled_length logits after the first SEQ token of one batch
        row and zero-pads the end; mutates token.
        """
        seq_positions = torch.where(token == self.seq_token_id)[0]
        if seq_positions.numel() == 0:
            return logit, token

        idx = seq_positions[0].item()

        # True for the output positions located after the SEQ token: those take
        # their value from resampled_length positions further right.
        take_shifted = (
            torch.arange(logit.shape[0] - resampled_length, device=logit.device) > idx
        ).unsqueeze(1)
        compacted = torch.where(
            take_shifted, logit[resampled_length:], logit[:-resampled_length]
        )

        # Pad with zeros to keep the sequence length unchanged
        padding = torch.zeros(
            (resampled_length, logit.shape[1]),
            dtype=logit.dtype,
            device=logit.device,
        )
        logit = torch.cat((compacted, padding))

        # Mark the SEQ token as processed
        token[idx] = -1

        return logit, token

    def cleanup_logits(
        self, tokens: torch.Tensor, logits: torch.Tensor, resampled_length: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Removes the logits corresponding to the unused embeddings.

        Args:
            tokens: Input english tokens.
            logits: Input logits.
            resampled_length: Number of embeddings inserted per bio sequence.

        Returns:
            Cleaned logits (last values will be equal to 0) and the tokens with
            the processed SEQ token set to -1.
        """
        tokens_acc = []
        logits_acc = []

        for i in range(tokens.shape[0]):
            logits_out, tokens_out = self._clean_single(
                tokens[i].clone(), logits[i], resampled_length
            )
            tokens_acc.append(tokens_out)
            logits_acc.append(logits_out)

        return torch.stack(logits_acc), torch.stack(tokens_acc)
| |
|
| |
|
class TorchMultiOmicsModel(PreTrainedModel):
    """
    Full ChatNT model: bio encoder, perceiver-resampler projection and GPT decoder.
    """

    config_class = ChatNTConfig

    @staticmethod
    def _as_config(value, config_cls):
        """Builds config_cls from a dict of its fields; other values pass through."""
        return config_cls(**value) if isinstance(value, dict) else value

    @staticmethod
    def _as_gpt_config(value):
        """Builds a GptConfig from a (possibly nested) dict; other values pass through."""
        if not isinstance(value, dict):
            return value
        # Work on a copy so that the caller's dict is left untouched.
        params = dict(value)
        if isinstance(params.get("rope_config"), dict):
            params["rope_config"] = RotaryEmbeddingConfig(**params["rope_config"])
        return GptConfig(**params)

    def __init__(self, config: ChatNTConfig) -> None:
        """
        Args:
            config: A ChatNTConfig, or a dict of its keyword arguments. The
                sub-configurations may be given as dataclasses or as plain dicts
                (as obtained when the configuration is restored from JSON).
        """
        cls = TorchMultiOmicsModel
        if isinstance(config, dict):
            # Shallow copy: the caller's dict must not be modified.
            config = dict(config)
            config["gpt_config"] = cls._as_gpt_config(config["gpt_config"])
            config["nt_config"] = cls._as_config(
                config["nt_config"], NucleotideTransformerConfig
            )
            config["perceiver_resampler_config"] = cls._as_config(
                config["perceiver_resampler_config"], PerceiverResamplerConfig
            )
            config = ChatNTConfig(**config)
        else:
            config.gpt_config = cls._as_gpt_config(config.gpt_config)
            config.nt_config = cls._as_config(
                config.nt_config, NucleotideTransformerConfig
            )
            config.perceiver_resampler_config = cls._as_config(
                config.perceiver_resampler_config, PerceiverResamplerConfig
            )

        super().__init__(config=config)
        self.gpt_config = config.gpt_config
        self.nt_config = config.nt_config
        self.perceiver_resampler_config = config.perceiver_resampler_config
        self.seq_token_id = config.seq_token_id
        self.bio_pad_token_id = config.bio_pad_token_id
        self.english_pad_token_id = config.english_pad_token_id

        # NOTE(review): forward() remaps token id `vocab_size` to
        # `vocab_size - 1`; the SEQ token id is presumably shifted by one here to
        # follow that remapping - confirm against the tokenizer.
        self.seq_token_id -= 1

        self.biobrain_encoder = TorchBioBrainEncoder(nt_config=self.nt_config)
        self.biobrain_decoder = TorchBioBrainDecoder(
            gpt_config=self.gpt_config, seq_token_id=self.seq_token_id
        )
        self.projection_model = TorchMultiModalPerceiverResamplerProjection(
            perceiver_resampler_config=self.perceiver_resampler_config,
            input_embed_dim=self.nt_config.embed_dim,
            embed_dim=self.gpt_config.embed_dim,
            english_vocab_size=self.gpt_config.vocab_size,
            bio_pad_token_id=self.bio_pad_token_id,
            english_pad_token_id=self.english_pad_token_id,
        )

    def forward(
        self,
        multi_omics_tokens_ids: Tuple[torch.Tensor, Optional[torch.Tensor]],
        projection_english_tokens_ids: torch.Tensor,
        projected_bio_embeddings: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            multi_omics_tokens_ids (Tuple[torch.Tensor, torch.Tensor]):
                english_tokens_ids: Represents the prompt tokens (english tokens)
                    Shape (batch_size, num_english_tokens)
                bio_tokens_ids: Represents the bio sequences tokens, or None
                    Shape (batch_size, num_bio_sequences, num_bio_tokens)
            projection_english_tokens_ids (torch.Tensor):
                Shape (batch_size, num_english_tokens)
            projected_bio_embeddings (torch.Tensor, optional):
                Shape (batch_size, num_bio_sequences, ?, embed_dim).
                Defaults to None, in which case they are computed from
                bio_tokens_ids.

        Returns:
            dict[str, torch.Tensor] containing:
                - logits:
                    Shape (batch_size, num_tokens, vocab_size)
                - projected_bio_embeddings:
                    Shape (batch_size, num_bio_sequences, ?, embed_dim)
        """
        english_token_ids, bio_token_ids = multi_omics_tokens_ids

        # Clone everything that is edited below so the inputs stay untouched.
        english_token_ids = english_token_ids.clone()
        projection_english_tokens_ids = projection_english_tokens_ids.clone()
        if bio_token_ids is not None:
            bio_token_ids = bio_token_ids.clone()
        if projected_bio_embeddings is not None:
            projected_bio_embeddings = projected_bio_embeddings.clone()

        # Remap the two highest token ids. The order matters: id `vocab_size - 1`
        # must be moved to 0 before id `vocab_size` takes its place.
        # NOTE(review): the rationale for this remapping is not visible in this
        # file - presumably a tokenizer/vocabulary offset; confirm.
        vocab_size = self.gpt_config.vocab_size
        for ids in (english_token_ids, projection_english_tokens_ids):
            ids[ids == vocab_size - 1] = 0
            ids[ids == vocab_size] = vocab_size - 1

        if bio_token_ids is None:
            projected_bio_embeddings = None
        elif projected_bio_embeddings is None:
            num_bio_sequences = bio_token_ids.shape[1]

            # Encode each bio sequence
            bio_embeddings_list = [
                self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
                for bio_seq_num in range(num_bio_sequences)
            ]

            # Project each bio sequence into the decoder embedding space
            projected = [
                self.projection_model(
                    bio_token_ids=bio_token_ids[:, bio_seq_num],
                    bio_embeddings=bio_embeddings,
                    english_token_ids=projection_english_tokens_ids,
                )
                for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
            ]
            projected_bio_embeddings = torch.stack(projected, dim=1)

        # Decode
        logits = self.biobrain_decoder(
            english_token_ids=english_token_ids,
            projected_bio_embeddings=projected_bio_embeddings,
        )
        logits = logits.to(torch.float32)

        return {"logits": logits, "projected_bio_embeddings": projected_bio_embeddings}
| |
|
| |
|
class TorchRotaryEmbedding(torch.nn.Module):
    """
    Rotary positional embeddings (RoPE) applied to the first `dim` features of
    each attention head of the keys and queries.
    """

    def __init__(self, config: "RotaryEmbeddingConfig"):
        """
        Args:
            config: Provides max_seq_len, dim and theta.
        """
        super().__init__()

        self.max_seq_len = config.max_seq_len
        self.dim = config.dim
        self.theta = config.theta
        # Built lazily on first use, on the device of the inputs.
        self.sincos_cache = None

    def _create_sinusoidal_positions(self, device: torch.device) -> torch.Tensor:
        """
        Create the sines and cosines for the RoPE.

        Returns:
            Sinusoidal positions of shape (self.max_seq_len, self.dim): sines in
            the first half of the last dimension, cosines in the second half.
        """
        # Inverse frequencies
        inv_freq = 1.0 / (
            self.theta
            ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
        )

        # Angle of every (position, frequency) pair
        sinusoid_inp = torch.einsum(
            "i,j->ij", torch.arange(self.max_seq_len, device=device).float(), inv_freq
        )

        sin, cos = sinusoid_inp.sin(), sinusoid_inp.cos()

        sincos = torch.zeros(
            (self.max_seq_len, self.dim), dtype=torch.float32, device=device
        )

        # Split point between the sine and cosine halves
        sentinel = self.dim // 2 + self.dim % 2
        sincos[:, :sentinel] = sin
        sincos[:, sentinel:] = cos

        return sincos

    def _rotate_every_two(self, x: torch.Tensor) -> torch.Tensor:
        """
        Prepare a tensor to apply the RoPE mechanism.

        Args:
            x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
                typically this is the key or query tensor.

        Returns:
            Tensor of the same shape where every pair (x0, x1) of the last
            dimension is replaced by (-x1, x0).
        """
        rotate_half = torch.stack((-x[..., 1::2], x[..., ::2]), dim=-1)

        # Merge the pair axis back into the feature axis
        rotate_half = rotate_half.view(rotate_half.shape[:-2] + (-1,))
        return rotate_half

    def _apply_rotary_pos_emb(
        self, x: torch.Tensor, sincos: torch.Tensor
    ) -> torch.Tensor:
        """
        Applies rotary embeddings to x.

        Args:
            x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
                typically this is the key or query tensor.
            sincos: Tuple of sine and cosine tensors for position encoding.

        Returns:
            RoPE embeddings tensor.
        """
        sin_pos, cos_pos = sincos

        # Add a head axis and duplicate each value so it covers a feature pair
        sin_pos = torch.repeat_interleave(sin_pos.unsqueeze(2), repeats=2, dim=-1)
        cos_pos = torch.repeat_interleave(cos_pos.unsqueeze(2), repeats=2, dim=-1)

        return (x * cos_pos) + (self._rotate_every_two(x) * sin_pos)

    def forward(
        self, k: torch.Tensor, q: torch.Tensor, positions: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Applies rotary embeddings to k and q.

        Args:
            k: key tensor of shape (batch_size, seq_len, num_heads, head_dim),
            q: query tensor of shape (batch_size, seq_len, num_heads, head_dim),
            positions: optional positions offset useful when caching,

        Returns:
            The rotated keys and the rotated queries, in that order.
        """
        # (Re)build the cache when missing or when the inputs moved to another
        # device, since the cache is not a registered buffer.
        if self.sincos_cache is None or self.sincos_cache.device != k.device:
            self.sincos_cache = self._create_sinusoidal_positions(device=k.device)

        batch_size, seq_len, _, _ = k.shape

        position_ids = (
            torch.arange(seq_len, device=k.device).unsqueeze(0).expand(batch_size, -1)
        )

        if positions is not None:
            # Out-of-place add: `expand` returns a view whose rows share memory,
            # which cannot be written to in place.
            position_ids = position_ids + positions

        # Gather the sines/cosines of the requested positions and split them
        sincos = torch.chunk(self.sincos_cache[position_ids], 2, dim=-1)

        # Only the first `dim` features of each head are rotated
        k_rot = self._apply_rotary_pos_emb(k[..., : self.dim], sincos)
        q_rot = self._apply_rotary_pos_emb(q[..., : self.dim], sincos)

        keys = torch.cat([k_rot, k[..., self.dim :]], dim=-1)
        queries = torch.cat([q_rot, q[..., self.dim :]], dim=-1)

        return keys, queries
| |
|
| |
|
class TorchGptGroupedQueryAttention(nn.Module):
    """
    Grouped-query self-attention with rotary positional embeddings: the
    `num_kv_heads` key/value heads are each shared by
    `num_heads // num_kv_heads` query heads.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        rope_config: "RotaryEmbeddingConfig",
        num_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
        add_bias_attn: bool = False,
    ) -> None:
        """
        Args:
            embed_dim: Embedding dimension of the inputs and of the output.
            num_heads: Number of query heads.
            rope_config: Configuration of the rotary positional embeddings.
            num_kv_heads: Number of key/value heads; defaults to num_heads.
            head_dim: Dimension of each head; defaults to embed_dim // num_heads.
            add_bias_attn: Add bias to the query, key, value and output projections.

        Raises:
            ValueError: If num_heads is not a positive multiple of num_kv_heads.
        """
        super().__init__()
        kv_heads = num_kv_heads or num_heads
        if kv_heads <= 0 or num_heads % kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be a positive multiple of "
                f"num_kv_heads ({kv_heads})."
            )

        self.num_heads = num_heads
        self.num_kv_heads = kv_heads
        self.embed_dim = embed_dim
        self.head_dim = head_dim or (embed_dim // num_heads)
        self.add_bias_attn = add_bias_attn
        self.rope = TorchRotaryEmbedding(rope_config)

        self.query_linear = nn.Linear(
            embed_dim, self.num_heads * self.head_dim, bias=add_bias_attn
        )
        self.key_linear = nn.Linear(
            embed_dim, self.num_kv_heads * self.head_dim, bias=add_bias_attn
        )
        self.value_linear = nn.Linear(
            embed_dim, self.num_kv_heads * self.head_dim, bias=add_bias_attn
        )
        self.out_linear = nn.Linear(
            self.num_heads * self.head_dim, embed_dim, bias=add_bias_attn
        )

    def forward(
        self,
        query_inputs: torch.Tensor,
        key_inputs: torch.Tensor,
        value_inputs: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            query_inputs: Shape (batch_size, seq_len, embed_dim).
            key_inputs: Same shape as query_inputs (the sequence length of the
                queries is used to reshape keys and values).
            value_inputs: Same shape as query_inputs.
            attention_mask: Optional mask broadcastable to
                (batch_size, num_heads, seq_len, seq_len); positions equal to 0
                are masked out.

        Returns:
            Tensor of shape (batch_size, seq_len, embed_dim).
        """
        batch_size, seq_len, _ = query_inputs.shape

        queries = self.query_linear(query_inputs).view(
            batch_size, seq_len, self.num_heads, self.head_dim
        )
        keys = self.key_linear(key_inputs).view(
            batch_size, seq_len, self.num_kv_heads, self.head_dim
        )
        values = self.value_linear(value_inputs).view(
            batch_size, seq_len, self.num_kv_heads, self.head_dim
        )

        keys, queries = self.rope(keys, queries)

        # Repeat each key/value head so that it lines up with its query heads
        n_rep = self.num_heads // self.num_kv_heads
        keys = keys.repeat_interleave(n_rep, dim=2)
        values = values.repeat_interleave(n_rep, dim=2)

        attention_logits = torch.einsum("bthd,bThd->bhtT", queries, keys) / (
            self.head_dim**0.5
        )

        if attention_mask is not None:
            attention_logits = attention_logits.masked_fill(
                attention_mask == 0, float("-inf")
            )

        attention_weights = nn.functional.softmax(attention_logits, dim=-1)
        attention_weights = attention_weights.to(values.dtype)

        context = torch.einsum("bhtT,bThd->bthd", attention_weights, values)
        context = context.contiguous().view(batch_size, seq_len, -1)

        return self.out_linear(context)
| |
|
| |
|
class TorchGptDecoder(nn.Module):
    """
    Decoder-only transformer: token embedding, a stack of TorchGptDecoderLayer
    blocks, a final normalisation layer and a linear language-model head.
    """

    def __init__(self, config: GptConfig, name: Optional[str] = None):
        """
        Args:
            config: Model hyperparameters (vocabulary size, widths, norm type,
                number of layers, attention/FFN options).
            name: Accepted for interface compatibility; not used here.
        Raises:
            ValueError: If ``config.norm_type`` is neither "layer_norm" nor
                "RMS_norm".
        """
        super().__init__()
        self.config = config

        self.token_embed = nn.Embedding(config.vocab_size, config.embed_dim)

        norm_type = config.norm_type
        if norm_type == "layer_norm":
            self.final_norm = nn.LayerNorm(config.embed_dim)
        elif norm_type == "RMS_norm":
            self.final_norm = TorchRMSNorm(config.embed_dim, eps=config.rms_norm_eps)
        else:
            raise ValueError(f"unrecognized norm_type in config {config.norm_type}")

        # Every decoder layer is built from the same hyperparameters.
        layer_kwargs = dict(
            embed_dim=config.embed_dim,
            ffn_embed_dim=config.ffn_embed_dim,
            num_heads=config.num_heads,
            rope_config=config.rope_config,
            norm_type=config.norm_type,
            parallel_attention_ff=config.parallel_attention_ff,
            add_bias_ffn=config.add_bias_ffn,
            ffn_activation_name=config.ffn_activation_name,
            use_glu_in_ffn=config.use_glu_in_ffn,
            num_kv_heads=config.num_kv_heads,
            add_bias_attn=config.add_bias_attn,
            rms_norm_eps=config.rms_norm_eps,
        )
        self.layers = nn.ModuleList(
            TorchGptDecoderLayer(**layer_kwargs) for _ in range(config.num_layers)
        )

        self.lm_head = TorchSimpleLMHead(
            embed_dim=config.embed_dim,
            alphabet_size=config.vocab_size,
            add_bias_lm_head=config.add_bias_lm_head,
        )

    def apply_transformer_layers(
        self, embeddings: torch.Tensor, attention_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Runs ``embeddings`` through every decoder layer in order.

        When no mask is given, a causal mask with batch dimension 1 is built
        from ``embeddings.shape[1]`` on the embeddings' device.
        """
        if attention_mask is None:
            attention_mask = build_causal_attention_mask(
                1, embeddings.shape[1], device=embeddings.device
            )

        hidden = embeddings
        for block in self.layers:
            hidden = block(hidden, attention_mask)
        return hidden

    def forward(
        self, token_ids: torch.Tensor, attention_mask: torch.Tensor = None
    ) -> dict[str, torch.Tensor]:
        """
        Args:
            token_ids: Token ids; ``shape[1]`` is used as the sequence length.
            attention_mask: Optional mask forwarded to every layer; defaults to
                a causal mask.
        Returns:
            Dictionary with the normalised final "embeddings" and the "logits"
            produced by the language-model head.
        """
        if attention_mask is None:
            attention_mask = build_causal_attention_mask(
                1, token_ids.shape[1], device=token_ids.device
            )

        hidden = self.apply_transformer_layers(
            self.token_embed(token_ids), attention_mask=attention_mask
        )
        embeddings = self.final_norm(hidden)
        return {"embeddings": embeddings, "logits": self.lm_head(embeddings)}
| |
|
| |
|
class TorchSimpleLMHead(nn.Module):
    """
    Language-model head made of a single linear projection from the embedding
    dimension to the alphabet (vocabulary) size.
    """

    def __init__(
        self, embed_dim: int, alphabet_size: int, add_bias_lm_head: bool = True
    ) -> None:
        """
        Args:
            embed_dim: Size of the last dimension of the input.
            alphabet_size: Number of output logits per position.
            add_bias_lm_head: Whether the projection has a bias term.
        """
        super().__init__()
        self.fc = nn.Linear(
            in_features=embed_dim,
            out_features=alphabet_size,
            bias=add_bias_lm_head,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Projects ``x`` to logits over the alphabet."""
        logits = self.fc(x)
        return logits
| |
|
| |
|
class TorchGptDecoderLayer(nn.Module):
    """
    Single GPT decoder block: grouped-query self-attention followed by (or, in
    parallel mode, alongside) a feed-forward network, with pre-normalisation
    and residual connections.
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_embed_dim: int,
        num_heads: int,
        rope_config: RotaryEmbeddingConfig,
        norm_type: str,
        parallel_attention_ff: bool,
        add_bias_ffn: bool,
        ffn_activation_name: str,
        use_glu_in_ffn: bool,
        num_kv_heads: int,
        add_bias_attn: bool,
        rms_norm_eps: float = 1e-6,
    ) -> None:
        """
        Args:
            embed_dim: Model width.
            ffn_embed_dim: Hidden width of the feed-forward network.
            num_heads: Number of query heads.
            rope_config: Rotary embedding configuration passed to the attention.
            norm_type: "layer_norm" or "RMS_norm".
            parallel_attention_ff: If True, attention and FFN both read the same
                normalised input and their outputs are summed with the residual.
            add_bias_ffn: Whether the FFN linear layers have biases.
            ffn_activation_name: Name resolved through ``get_activation_fn``.
            use_glu_in_ffn: If True, the first FFN projection is twice as wide
                and is split into an activated half and a gating half.
            num_kv_heads: Number of key/value heads.
            add_bias_attn: Whether the attention projections have biases.
            rms_norm_eps: Epsilon used when ``norm_type == "RMS_norm"``.
        Raises:
            ValueError: If ``norm_type`` is not recognised.
        """
        super().__init__()
        self.num_heads = num_heads
        self.parallel_attention_ff = parallel_attention_ff
        self.use_glu_in_ffn = use_glu_in_ffn

        self.self_attn = TorchGptGroupedQueryAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            rope_config=rope_config,
            add_bias_attn=add_bias_attn,
        )

        def make_norm() -> nn.Module:
            if norm_type == "layer_norm":
                return nn.LayerNorm(embed_dim)
            if norm_type == "RMS_norm":
                return TorchRMSNorm(embed_dim, eps=rms_norm_eps)
            raise ValueError(f"unrecognized norm_type: {norm_type}")

        self.attn_norm = make_norm()
        # In parallel mode the FFN shares attn_norm, so no second norm exists.
        if not self.parallel_attention_ff:
            self.ffn_norm = make_norm()

        self.activation = get_activation_fn(ffn_activation_name)
        ffn_hidden_dim = ffn_embed_dim * (2 if use_glu_in_ffn else 1)
        self.fc1 = nn.Linear(embed_dim, ffn_hidden_dim, bias=add_bias_ffn)
        self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_ffn)

    def forward(
        self, embeddings: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            embeddings: Input hidden states.
            attention_mask: Mask forwarded to the self-attention module.
        Returns:
            Hidden states after attention, FFN and residual additions.
        """
        normed = self.attn_norm(embeddings)

        # The attention module's forward takes ``attention_mask`` and returns a
        # single tensor. The parallel branch previously passed ``attn_mask=``
        # and unpacked a 2-tuple, which raised at call time.
        attn_output = self.self_attn(
            normed,
            normed,
            normed,
            attention_mask=attention_mask,
        )

        if self.parallel_attention_ff:
            ffn_output = self.mlp(normed)
            return embeddings + attn_output + ffn_output

        after_attention = embeddings + attn_output
        ffn_output = self.mlp(self.ffn_norm(after_attention))
        return after_attention + ffn_output

    def mlp(self, x: torch.Tensor) -> torch.Tensor:
        """Applies the feedforward network (MLP) with optional GLU."""
        hidden = self.fc1(x)

        if self.use_glu_in_ffn:
            # First half is activated, second half is the multiplicative gate.
            activated_half, gate_half = hidden.chunk(2, dim=-1)
            hidden = self.activation(activated_half) * gate_half
        else:
            hidden = self.activation(hidden)

        return self.fc2(hidden)
| |
|
| |
|
class TorchRMSNorm(nn.Module):
    """
    Root-mean-square normalisation over the last dimension with a learned
    per-feature scale (initialised to ones) and no offset.
    """

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        """
        Args:
            dim: Size of the normalised (last) dimension.
            eps: Constant added to the mean square before the square root.
        """
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Returns ``x * scale / sqrt(mean(x**2, dim=-1) + eps)``."""
        mean_square = torch.mean(x**2, dim=-1, keepdim=True)
        rms = torch.sqrt(mean_square + self.eps)
        scaled = x * self.scale
        return scaled / rms
| |
|
| |
|
def get_activation_fn(activation_name: str):
    """
    Maps an activation name to the matching ``torch.nn.functional`` callable.

    Recognised names are "gelu", "relu", "swish" and "silu" (the last two both
    give ``silu``). Any other name silently falls back to ``relu``.
    """
    silu = nn.functional.silu
    known = {
        "gelu": nn.functional.gelu,
        "relu": nn.functional.relu,
        "swish": silu,
        "silu": silu,
    }
    try:
        return known[activation_name]
    except KeyError:
        return nn.functional.relu
| |
|
| |
|
def build_causal_attention_mask(
    batch_size: int, seq_len: int, device: torch.device
) -> torch.Tensor:
    """
    Builds a batch of causal (lower-triangular) masks for an attention layer.

    Args:
        batch_size: Batch size.
        seq_len: Length of the sequences.
        device: Device on which the mask is allocated.
    Returns:
        Tensor of shape (batch_size, 1, seq_len, seq_len) holding ones on and
        below the diagonal of the last two dimensions and zeros above it.
    """
    mask_shape = (batch_size, 1, seq_len, seq_len)
    return torch.ones(mask_shape, device=device).tril()
| |
|
| |
|
@dataclass
class RotaryEmbeddingConfigBis:
    """
    Parameters to initialize the RotaryEmbeddingBis layer.

    The rescaling factor makes it possible to adapt the rotary embeddings to
    sequences longer than those used for training. One such strategy is
    presented in the YaRN paper: https://arxiv.org/pdf/2309.00071.pdf.

    Args:
        rescaling_factor: If None, the rotary layer uses its base frequency
            unchanged; otherwise the base is multiplied by
            ``rescaling_factor ** (dim / (dim - 2))``.
    """

    rescaling_factor: Optional[float]
| |
|
| |
|
class RotaryEmbeddingBis(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).
    Queries and keys are rotated by position-dependent angles, so that their
    dot product depends on their relative positions.
    """

    def __init__(self, dim: int, rotary_embedding_config: RotaryEmbeddingConfigBis):
        """
        Args:
            dim: Size of the head dimension the rotation is applied to.
            rotary_embedding_config: Provides ``rescaling_factor``.
        """
        super().__init__()

        self.rescaling_factor = rotary_embedding_config.rescaling_factor
        self.upper_freq = 10000
        self.dim = dim

        # Filled in on every forward pass; they hold the last computed tables.
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _apply_rotary_pos_emb(
        self,
        heads: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """
        Rotates ``heads`` by treating the two halves of its last dimension as
        the two coordinates of each rotated pair.
        """
        half = heads.shape[-1] // 2
        lower, upper = heads[..., :half], heads[..., half:]

        rotated_lower = lower * cos - upper * sin
        rotated_upper = upper * cos + lower * sin

        return torch.cat((rotated_lower, rotated_upper), dim=-1)

    def _compute_cos_sin_tables(
        self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Computes cosine and sine tables of shape (1, seq_len, 1, len(inv_freq))
        for the positions along ``seq_dimension`` of ``x`` and stores them on
        the instance.
        """
        seq_len = x.shape[seq_dimension]
        self._seq_len_cached = seq_len

        positions = torch.arange(seq_len, device=x.device).type_as(inv_freq)
        # Outer product: one angle per (position, frequency) pair.
        angles = positions[:, None] * inv_freq[None, :]

        self._cos_cached = torch.cos(angles)[None, :, None, :]
        self._sin_cached = torch.sin(angles)[None, :, None, :]

        return self._cos_cached, self._sin_cached

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            q: Query heads; the third-to-last dimension is the sequence axis.
            k: Key heads, rotated with the tables computed from ``q``.
        Returns:
            The rotated (queries, keys).
        """
        base = self.upper_freq
        if self.rescaling_factor is not None:
            base = self.upper_freq * (
                self.rescaling_factor ** (self.dim / (self.dim - 2))
            )

        exponents = torch.arange(0, self.dim, 2, device=q.device).float() / self.dim
        inv_freq = 1.0 / (base**exponents)

        cos, sin = self._compute_cos_sin_tables(q, inv_freq, seq_dimension=-3)
        self._cos_cached, self._sin_cached = cos, sin

        rotated_q = self._apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached)
        rotated_k = self._apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached)
        return rotated_q, rotated_k
| |
|
| |
|
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention with separate query/key/value projections, optional
    rotary embeddings on queries and keys, and an output projection.
    """

    def __init__(
        self,
        num_heads: int,
        key_size: int,
        rotary_embedding_config: Optional[RotaryEmbeddingConfigBis] = None,
        add_bias_kv: bool = False,
        value_size: Optional[int] = None,
        model_size: Optional[int] = None,
        name: Optional[str] = None,
    ):
        """
        Args:
            num_heads: Number of attention heads.
            key_size: Per-head width of queries and keys.
            rotary_embedding_config: If set, a RotaryEmbeddingBis layer is
                created and applied to queries and keys.
            add_bias_kv: Stored on the instance; not otherwise used here.
            value_size: Per-head width of values; defaults to ``key_size``.
            model_size: Input/output width; defaults to
                ``key_size * num_heads``.
            name: Stored on the instance; not otherwise used here.
        """
        super().__init__()
        # Falsy sizes (None or 0) are replaced by values derived from key_size.
        self.model_size = model_size if model_size else key_size * num_heads
        self.key_size = key_size
        self.value_size = value_size if value_size else key_size
        self.add_bias_kv = add_bias_kv
        self.name = name
        self.num_heads = num_heads
        self._rotary_embedding_config = rotary_embedding_config

        qk_width = self.num_heads * self.key_size
        v_width = self.num_heads * self.value_size
        self.w_k = nn.Linear(self.model_size, qk_width)
        self.w_q = nn.Linear(self.model_size, qk_width)
        self.w_v = nn.Linear(self.model_size, v_width)
        self.output = nn.Linear(v_width, self.model_size)
        if self._rotary_embedding_config:
            self._rotary_embedding = RotaryEmbeddingBis(
                self.key_size, self._rotary_embedding_config
            )

    def apply_rotary_embeddings(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Applies the rotary embedding layer to the query and key heads."""
        rotated_query, rotated_key = self._rotary_embedding(query, key)
        return rotated_query, rotated_key

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Args:
            query: Query inputs; last dimension is ``model_size``.
            key: Key inputs; last dimension is ``model_size``.
            value: Value inputs; last dimension is ``model_size``.
            attention_mask: Optional boolean mask; logits where it is False are
                replaced by -1e30 before the softmax.
            attention_weight_bias: Optional bias added to the logits just
                before the softmax.
        Returns:
            Dictionary containing the "attention_weights" (after softmax) and
            the output "embeddings".
        """

        def split_heads(projected: torch.Tensor, head_width: int) -> torch.Tensor:
            # (..., heads * width) -> (..., heads, width)
            return projected.reshape(
                (*projected.shape[:-1], self.num_heads, head_width)
            )

        key_heads = split_heads(self.w_k(key), self.key_size)
        query_heads = split_heads(self.w_q(query), self.key_size)
        value_heads = split_heads(self.w_v(value), self.value_size)

        if self._rotary_embedding_config:
            query_heads, key_heads = self.apply_rotary_embeddings(
                query_heads, key_heads
            )

        logits = torch.einsum("...thd, ...Thd -> ...htT", query_heads, key_heads)
        logits = logits / np.sqrt(self.key_size)
        if attention_mask is not None:
            logits = torch.where(attention_mask, logits, -1e30)

        logits = logits.to(value_heads.dtype)
        if attention_weight_bias is not None:
            logits = logits + attention_weight_bias
        attention_weights = F.softmax(logits, dim=-1)

        context = torch.einsum(
            "...htT, ...Thd->...thd", attention_weights, value_heads
        )
        # Merge the head and per-head dimensions back together.
        context = context.reshape((*context.shape[:-2], -1))

        return {
            "attention_weights": attention_weights,
            "embeddings": self.output(context),
        }
| |
|
| |
|
class SelfAttentionBlock(nn.Module):
    """
    Transformer encoder block: multi-head self-attention followed by a
    feed-forward network, each with a residual connection, in either pre- or
    post-layer-norm arrangement.
    """

    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        ffn_embed_dim: int,
        key_size: Optional[int] = None,
        add_bias_kv: bool = False,
        add_bias_fnn: bool = True,
        ffn_activation_name: str = "gelu-no-approx",
        use_glu_in_ffn: bool = False,
        layer_norm_eps: float = 1e-5,
        pre_layer_norm: bool = True,
        name: Optional[str] = None,
        rotary_embedding_config: Optional[RotaryEmbeddingConfigBis] = None,
    ):
        """
        Args:
            num_heads: Number of attention heads.
            embed_dim: Model width.
            ffn_embed_dim: Hidden width of the feed-forward network.
            key_size: Per-head width; defaults to ``embed_dim // num_heads``.
            add_bias_kv: Forwarded to MultiHeadAttention.
            add_bias_fnn: Whether the FFN linear layers have biases.
            ffn_activation_name: "swish", "gelu-no-approx", the name of a
                function in ``torch.nn.functional`` (e.g. "gelu", "relu") or
                the name of a module class in ``torch.nn`` (e.g. "ReLU").
            use_glu_in_ffn: If True, the first FFN projection is twice as wide
                and split into an activated half and a gating half.
            layer_norm_eps: Epsilon of both layer norms.
            pre_layer_norm: If True, norms are applied before attention/FFN;
                otherwise after each residual addition.
            name: Accepted for interface compatibility; not used here.
            rotary_embedding_config: Forwarded to MultiHeadAttention.
        Raises:
            ValueError: If ``key_size`` is None and ``embed_dim`` is not
                divisible by ``num_heads``, or if the activation name cannot
                be resolved.
        """
        super().__init__()
        if key_size is None:
            if embed_dim % num_heads != 0:
                raise ValueError(
                    f"The embedding dimension should be divisible by the number of "
                    f"heads, however provided embedding dimension is {embed_dim} and "
                    f"the number of heads is {num_heads}."
                )
            key_size = embed_dim // num_heads

        self._pre_layer_norm = pre_layer_norm
        self._use_glu_in_fnn = use_glu_in_ffn

        # With GLU the first projection produces both the activated half and
        # the gate, hence twice the hidden width.
        fc1_width = int(2 * ffn_embed_dim) if use_glu_in_ffn else ffn_embed_dim
        self.fc1 = nn.Linear(embed_dim, fc1_width, bias=add_bias_fnn)
        self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_fnn)

        # layer_norm_eps was previously accepted but never forwarded; the
        # default (1e-5) matches nn.LayerNorm's own default.
        self.layer_norm_self_attention = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.layer_norm_mlp = nn.LayerNorm(embed_dim, eps=layer_norm_eps)

        self._ffn_activation_fn = self._resolve_ffn_activation(ffn_activation_name)

        self.mha = MultiHeadAttention(
            num_heads=num_heads,
            key_size=key_size,
            add_bias_kv=add_bias_kv,
            model_size=embed_dim,
            name="self_attention",
            rotary_embedding_config=rotary_embedding_config,
        )

    @staticmethod
    def _resolve_ffn_activation(ffn_activation_name: str):
        """Returns a callable activation for ``ffn_activation_name``."""
        if ffn_activation_name == "swish":
            return nn.SiLU()
        if ffn_activation_name == "gelu-no-approx":
            # NOTE(review): the name suggests the exact GELU, yet the tanh
            # approximation is used here. Kept unchanged; confirm the intent.
            return nn.GELU(approximate="tanh")

        # The previous fallback stored a class from torch.nn (or raised
        # AttributeError for functional names such as "relu"), so calling it
        # on a tensor never produced an activation.
        functional_fn = getattr(nn.functional, ffn_activation_name, None)
        if callable(functional_fn):
            return functional_fn
        module_cls = getattr(nn, ffn_activation_name, None)
        if isinstance(module_cls, type) and issubclass(module_cls, nn.Module):
            return module_cls()
        raise ValueError(f"Unsupported ffn_activation_name: {ffn_activation_name!r}")

    def mlp(self, embed: torch.Tensor) -> torch.Tensor:
        """
        Applies the feed-forward network.

        In pre-layer-norm mode the input is normalised first and the bare FFN
        output is returned (the caller adds the residual). In post-layer-norm
        mode the residual is added here and the sum is normalised.
        """
        hidden = self.layer_norm_mlp(embed) if self._pre_layer_norm else embed

        hidden = self.fc1(hidden)
        if self._use_glu_in_fnn:
            activated_half, gate_half = hidden.chunk(2, dim=-1)
            hidden = self._ffn_activation_fn(activated_half) * gate_half
        else:
            hidden = self._ffn_activation_fn(hidden)
        hidden = self.fc2(hidden)

        if not self._pre_layer_norm:
            hidden = self.layer_norm_mlp(hidden + embed)
        return hidden

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Args:
            x: Input hidden states.
            attention_mask: Optional mask forwarded to the attention layer.
            attention_weight_bias: Optional bias forwarded to the attention
                layer.
        Returns:
            The attention layer's output dictionary with "embeddings" replaced
            by the output of the whole block.
        """
        residual = x
        attention_input = (
            self.layer_norm_self_attention(x) if self._pre_layer_norm else x
        )

        output: dict[str, torch.Tensor] = self.mha(
            attention_input,
            attention_input,
            attention_input,
            attention_mask=attention_mask,
            attention_weight_bias=attention_weight_bias,
        )

        if self._pre_layer_norm:
            hidden = residual + output["embeddings"]
            hidden = hidden + self.mlp(hidden)
        else:
            hidden = self.layer_norm_self_attention(output["embeddings"] + residual)
            # In post-norm mode mlp() adds its own residual and normalises.
            hidden = self.mlp(hidden)

        output["embeddings"] = hidden
        return output
| |
|
| |
|
class RobertaLMHead(nn.Module):
    """
    Roberta Language Model head. Turns the output of the last attention layer
    into logits over the alphabet at each position.
    """

    def __init__(self, embed_dim: int, alphabet_size: int):
        """
        Args:
            embed_dim: Embedding dimension.
            alphabet_size: Number of tokens in the alphabet.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.alphabet_size = alphabet_size

        self._first_layer_norm = nn.LayerNorm(embed_dim, elementwise_affine=True)
        self._fc1 = nn.Linear(embed_dim, embed_dim)
        self._second_layer_norm = nn.LayerNorm(embed_dim, elementwise_affine=True)
        self._final_fc = nn.Linear(embed_dim, alphabet_size)

    def forward(self, x: torch.Tensor) -> dict:
        """
        Returns:
            Dictionary with "embeddings" (the input after the first layer
            norm) and "logits" (linear -> GELU -> layer norm -> linear applied
            to those embeddings).
        """
        embeddings = self._first_layer_norm(x)

        hidden = nn.functional.gelu(self._fc1(embeddings))
        hidden = self._second_layer_norm(hidden)

        return {"embeddings": embeddings, "logits": self._final_fc(hidden)}
| |
|
| |
|
class TorchNucleotideTransformer(nn.Module):
    """
    Nucleotide Transformer encoder: token embedding, a stack of
    SelfAttentionBlock layers with rotary embeddings, and a Roberta LM head
    whose normalised embeddings are returned.
    """

    def __init__(
        self,
        nt_config: NucleotideTransformerConfig,
    ):
        """
        Args:
            nt_config: Model configuration. Only the option combination checked
                by the assertions below is supported by this implementation.
        """
        super().__init__()
        self.nt_config = nt_config

        # Guard against configurations this port does not implement.
        assert nt_config.positional_embedding is None
        assert nt_config.lm_head == "roberta"
        assert nt_config.use_rotary_embedding is True
        assert nt_config.token_dropout is False
        assert nt_config.emb_layer_norm_before is False
        assert nt_config.mask_before_attention is False
        assert nt_config.bias_word_embedding is False
        assert nt_config.use_gradient_checkpointing is False

        self.embed_layer = nn.Embedding(nt_config.alphabet_size, nt_config.embed_dim)

        self.lm_head = RobertaLMHead(
            embed_dim=nt_config.embed_dim,
            alphabet_size=nt_config.alphabet_size,
        )

        self.rotary_embedding_config = RotaryEmbeddingConfigBis(
            rescaling_factor=nt_config.rescaling_factor
        )

        blocks = []
        for _ in range(nt_config.num_layers):
            blocks.append(
                SelfAttentionBlock(
                    num_heads=nt_config.attention_heads,
                    embed_dim=nt_config.embed_dim,
                    key_size=nt_config.key_size,
                    ffn_embed_dim=nt_config.ffn_embed_dim,
                    add_bias_kv=nt_config.add_bias_kv,
                    add_bias_fnn=nt_config.add_bias_ffn,
                    ffn_activation_name=nt_config.ffn_activation_name,
                    use_glu_in_ffn=nt_config.use_glu_in_ffn,
                    rotary_embedding_config=self.rotary_embedding_config,
                    layer_norm_eps=nt_config.layer_norm_eps,
                    pre_layer_norm=nt_config.pre_layer_norm,
                )
            )
        self.attention_blocks = nn.ModuleList(blocks)

    def forward(
        self, tokens: torch.Tensor, attention_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Computes the embeddings based on the input tokens.

        Args:
            tokens: Input tokens out of the tokenizer of shape (batch_size, seq_len).
            attention_mask: Attention mask of shape (batch_size, 1, seq_len, seq_len).
                If no mask is provided, a padding mask is built from ``tokens``
                and ``nt_config.pad_token_id``.
        Returns:
            The "embeddings" entry of the LM head output (a tensor, not a
            dictionary); the logits computed by the head are discarded.
        """
        hidden = self.nt_config.embed_scale * self.embed_layer(tokens)

        if attention_mask is None:
            attention_mask = build_padding_attention_mask(
                tokens=tokens, pad_token_id=self.nt_config.pad_token_id
            )

        for block in self.attention_blocks:
            hidden = block(hidden, attention_mask)["embeddings"]

        assert self.nt_config.lm_head == "roberta"
        head_output = self.lm_head(hidden)
        return head_output["embeddings"]
| |
|
| |
|
def build_padding_attention_mask(
    tokens: torch.Tensor, pad_token_id: int
) -> torch.Tensor:
    """
    Builds a padding mask from a sequence of tokens by masking <pad> in the attention.

    Args:
        tokens: Batch of sequences of shape (batch_size, seq_len).
        pad_token_id: Int corresponding to the <pad> token to mask.
    Returns:
        Boolean tensor of shape (batch_size, 1, seq_len, seq_len) that is True
        only where both the query position and the key position are not <pad>.
    """
    not_pad = (tokens != pad_token_id).unsqueeze(1)
    # Pairwise AND between query positions (rows) and key positions (columns).
    return not_pad.unsqueeze(-1) & not_pad.unsqueeze(-2)
| |
|
| |
|
class TorchBioBrainEncoder(nn.Module):
    """Thin wrapper that encodes biological tokens with a TorchNucleotideTransformer."""

    def __init__(
        self,
        nt_config: NucleotideTransformerConfig,
    ):
        """
        Args:
            nt_config: Configuration used to build the wrapped
                TorchNucleotideTransformer.
        """
        super().__init__()
        self.nt_config = nt_config
        self.nt_model = TorchNucleotideTransformer(self.nt_config)

    def forward(
        self,
        bio_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            bio_token_ids (torch.Tensor):
                Shape (batch_size, num_bio_tokens)

        Returns:
            torch.Tensor:
                Shape (batch_size, num_bio_tokens, embed_dim)
        """
        return self.nt_model(tokens=bio_token_ids)
| |
|
| |
|
class TorchMultiModalPerceiverResamplerBlock(nn.Module):
    """
    Perceiver resampler block: two successive cross-attentions (one per
    modality), each with pre-layer-norm and a residual connection, followed by
    a pre-normalised feed-forward network with a residual connection.
    """

    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        ffn_embed_dim: int,
        key_size: Optional[int] = None,
        add_bias_kv: bool = False,
        add_bias_ffn: bool = True,
        ffn_activation_name: str = "gelu",
        use_glu_in_ffn: bool = False,
    ):
        """
        Args:
            num_heads: Number of attention heads.
            embed_dim: Model width.
            ffn_embed_dim: Hidden width of the feed-forward network.
            key_size: Per-head width; defaults to ``embed_dim // num_heads``.
            add_bias_kv: Forwarded to both MultiHeadAttention layers.
            add_bias_ffn: Whether the FFN linear layers have biases.
            ffn_activation_name: "swish" or the name of a function in
                ``torch.nn.functional``; unknown names fall back to gelu.
            use_glu_in_ffn: If True, the first FFN projection is twice as wide
                and split into an activated half and a gating half.
        Raises:
            ValueError: If ``key_size`` is None and ``embed_dim`` is not
                divisible by ``num_heads``.
        """
        super().__init__()

        if key_size is None:
            if embed_dim % num_heads != 0:
                raise ValueError(
                    f"Embedding dimension {embed_dim} should be divisible by "
                    f"num_heads {num_heads}."
                )
            key_size = embed_dim // num_heads

        self.num_heads = num_heads
        self.embed_dim = embed_dim
        # Kept as before: this attribute is the width of fc1's output, i.e.
        # doubled when GLU is enabled.
        self.ffn_embed_dim = ffn_embed_dim * 2 if use_glu_in_ffn else ffn_embed_dim
        self.use_glu_in_ffn = use_glu_in_ffn

        # model_size is passed explicitly so that the projections match
        # embed_dim even when key_size * num_heads != embed_dim (identical to
        # the previous behaviour whenever the two are equal).
        self.cross_attention_1 = MultiHeadAttention(
            num_heads=num_heads,
            key_size=key_size,
            add_bias_kv=add_bias_kv,
            model_size=embed_dim,
        )
        self.cross_attention_2 = MultiHeadAttention(
            num_heads=num_heads,
            key_size=key_size,
            add_bias_kv=add_bias_kv,
            model_size=embed_dim,
        )

        self.norm_cross_attention_1 = nn.LayerNorm(embed_dim)
        self.norm_cross_attention_2 = nn.LayerNorm(embed_dim)
        self.norm_mlp = nn.LayerNorm(embed_dim)

        self._build_ffn(ffn_embed_dim, add_bias_ffn, ffn_activation_name)

    def _build_ffn(
        self, ffn_embed_dim: int, add_bias_ffn: bool, ffn_activation_name: str
    ) -> None:
        """Creates ``fc1``, ``fc2`` and ``activation_fn`` of the feed-forward network."""
        fc1_width = ffn_embed_dim * 2 if self.use_glu_in_ffn else ffn_embed_dim
        self.fc1 = nn.Linear(self.embed_dim, fc1_width, bias=add_bias_ffn)
        # After gating only ffn_embed_dim features remain, so fc2 must read
        # ffn_embed_dim inputs; it previously expected the doubled width and
        # failed with a shape mismatch whenever GLU was enabled.
        self.fc2 = nn.Linear(ffn_embed_dim, self.embed_dim, bias=add_bias_ffn)

        if ffn_activation_name == "swish":
            # torch.nn.functional has no "swish"; it used to fall back to gelu.
            self.activation_fn = nn.functional.silu
        else:
            self.activation_fn = getattr(
                nn.functional, ffn_activation_name, nn.functional.gelu
            )

    def mlp(self, x: torch.Tensor) -> torch.Tensor:
        """Applies layer norm followed by the (optionally gated) feed-forward network."""
        x = self.norm_mlp(x)
        if self.use_glu_in_ffn:
            activated_half, gate_half = torch.chunk(self.fc1(x), 2, dim=-1)
            x = self.activation_fn(activated_half) * gate_half
        else:
            x = self.activation_fn(self.fc1(x))
        return self.fc2(x)

    def forward(
        self,
        x: torch.Tensor,
        cross_attention_embeddings_1: torch.Tensor,
        cross_attention_embeddings_2: torch.Tensor,
        attention_mask_1: Optional[torch.Tensor] = None,
        attention_mask_2: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            x: Query states (the latent queries of the resampler).
            cross_attention_embeddings_1: Keys/values of the first
                cross-attention.
            cross_attention_embeddings_2: Keys/values of the second
                cross-attention.
            attention_mask_1: Optional mask for the first cross-attention.
            attention_mask_2: Optional mask for the second cross-attention.
        Returns:
            Dictionary with the block output under "embeddings".
        """
        attended = self.cross_attention_1(
            query=self.norm_cross_attention_1(x),
            key=cross_attention_embeddings_1,
            value=cross_attention_embeddings_1,
            attention_mask=attention_mask_1,
        )["embeddings"]
        x = x + attended

        attended = self.cross_attention_2(
            query=self.norm_cross_attention_2(x),
            key=cross_attention_embeddings_2,
            value=cross_attention_embeddings_2,
            attention_mask=attention_mask_2,
        )["embeddings"]
        x = x + attended

        x = x + self.mlp(x)

        return {"embeddings": x}
| |
|
| |
|
class TorchMultiModalPerceiverResampler(nn.Module):
    """
    Perceiver Resampler model, made of successive PerceiverResamplerBlocks
    that let a fixed set of learned latent queries attend to two inputs.
    """

    def __init__(
        self,
        config: PerceiverResamplerConfig,
        name: Optional[str] = None,
    ):
        """
        Initialize a Perceiver Resampler model.

        Args:
            config: Dataclass containing model hyperparameters.
            name: Stored on the instance; not otherwise used here.
        """
        super().__init__()
        self.config = config
        self.name = name

        blocks = []
        for _ in range(self.config.num_layers):
            blocks.append(
                TorchMultiModalPerceiverResamplerBlock(
                    num_heads=self.config.attention_heads,
                    embed_dim=self.config.embed_dim,
                    key_size=self.config.key_size,
                    ffn_embed_dim=self.config.ffn_embed_dim,
                    add_bias_kv=self.config.add_bias_kv,
                    add_bias_ffn=self.config.add_bias_ffn,
                    ffn_activation_name=self.config.ffn_activation_name,
                    use_glu_in_ffn=self.config.use_glu_in_ffn,
                )
            )
        self.layers = nn.ModuleList(blocks)

        # Latent queries: standard normal samples scaled by 1 / sqrt(embed_dim).
        init_scale = 1.0 / torch.sqrt(
            torch.tensor(self.config.embed_dim, dtype=torch.float32)
        )
        initial_latents = torch.randn(
            self.config.resampled_length, self.config.embed_dim
        )
        self.latent_queries = torch.nn.Parameter(initial_latents * init_scale)

    def apply_attention_blocks(
        self,
        x: torch.Tensor,
        xf_1: torch.Tensor,
        xf_2: torch.Tensor,
        outs: Dict[str, torch.Tensor],
        attention_mask_1: Optional[torch.Tensor] = None,
        attention_mask_2: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Applies every resampler block in order. Before each block the current
        latents are appended (along dim 1) to each input, so the latents attend
        to the inputs and to themselves.

        Returns:
            The final latents and ``outs`` (returned unchanged).
        """
        latents = x
        for block in self.layers:
            keys_values_1 = torch.cat([xf_1, latents], dim=1)
            keys_values_2 = torch.cat([xf_2, latents], dim=1)

            block_output = block(
                x=latents,
                cross_attention_embeddings_1=keys_values_1,
                cross_attention_embeddings_2=keys_values_2,
                attention_mask_1=attention_mask_1,
                attention_mask_2=attention_mask_2,
            )
            latents = block_output["embeddings"]

        return latents, outs

    def forward(
        self,
        input_embeddings_1: torch.Tensor,
        input_embeddings_2: torch.Tensor,
        attention_mask_1: Optional[torch.Tensor] = None,
        attention_mask_2: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Resamples the two input embedding sequences into ``resampled_length``
        latent vectors.

        Returns:
            Dictionary with the final latents under "embeddings".
        """
        expected_dim = self.config.embed_dim
        assert (
            input_embeddings_1.shape[-1] == expected_dim
        ), "The input embedding dim should match the model embed dim"
        assert (
            input_embeddings_2.shape[-1] == expected_dim
        ), "The input embedding dim should match the model embed dim"

        batch_size = input_embeddings_1.shape[0]
        # One copy of the learned latents per batch element.
        batched_latents = self.latent_queries.unsqueeze(0).repeat(batch_size, 1, 1)

        latents, outs = self.apply_attention_blocks(
            x=batched_latents,
            xf_1=input_embeddings_1,
            xf_2=input_embeddings_2,
            outs={},
            attention_mask_1=attention_mask_1,
            attention_mask_2=attention_mask_2,
        )

        outs["embeddings"] = latents
        return outs
| |
|
| |
|
class TorchMultiModalPerceiverResamplerProjection(nn.Module):
    """
    Projects biological-sequence embeddings and English token embeddings into
    a common width and resamples them with a TorchMultiModalPerceiverResampler.
    """

    def __init__(
        self,
        perceiver_resampler_config: PerceiverResamplerConfig,
        input_embed_dim: int,
        embed_dim: int,
        bio_pad_token_id: int,
        english_pad_token_id: int,
        english_vocab_size: int,
    ):
        """
        Args:
            perceiver_resampler_config: Configuration of the resampler.
            input_embed_dim: Width of the incoming bio embeddings.
            embed_dim: Width after projection / of the English embeddings.
            bio_pad_token_id: Pad token id used to mask bio tokens.
            english_pad_token_id: Pad token id used to mask English tokens.
            english_vocab_size: Size of the English token embedding table.
        """
        super().__init__()
        self.config = perceiver_resampler_config
        self.input_embed_dim = input_embed_dim
        self.embed_dim = embed_dim
        self.bio_pad_token_id = bio_pad_token_id
        self.english_pad_token_id = english_pad_token_id
        self.english_vocab_size = english_vocab_size

        self.bio_projection = nn.Linear(input_embed_dim, embed_dim)
        self.token_embedding = nn.Embedding(english_vocab_size, embed_dim)
        self.perceiver_resampler = TorchMultiModalPerceiverResampler(config=self.config)

    def forward(
        self,
        bio_token_ids: torch.Tensor,
        bio_embeddings: torch.Tensor,
        english_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            bio_token_ids (torch.Tensor):
                Shape (batch_size, num_bio_tokens)
            bio_embeddings (torch.Tensor):
                Shape (batch_size, num_bio_tokens, embed_dim)
            english_token_ids (torch.Tensor):
                Shape (batch_size, num_english_tokens)

        Returns:
            The "embeddings" entry of the perceiver resampler output.
        """
        bio_inputs = self.bio_projection(bio_embeddings)
        english_inputs = self.token_embedding(english_token_ids)

        num_latents = self.config.resampled_length
        bio_mask = build_perceiver_padding_attention_mask(
            bio_token_ids, num_latents, self.bio_pad_token_id
        )
        english_mask = build_perceiver_padding_attention_mask(
            english_token_ids, num_latents, self.english_pad_token_id
        )

        resampled = self.perceiver_resampler(
            input_embeddings_1=bio_inputs,
            attention_mask_1=bio_mask,
            input_embeddings_2=english_inputs,
            attention_mask_2=english_mask,
        )
        return resampled["embeddings"]
| |
|
| |
|
def build_perceiver_padding_attention_mask(
    tokens: torch.Tensor, resampled_length: int, pad_token_id: int
) -> torch.Tensor:
    """
    Builds the key mask used by the perceiver resampler cross-attention.

    Args:
        tokens: Token ids of shape (batch_size, seq_len).
        resampled_length: Number of latent positions appended after the tokens.
        pad_token_id: Token id that must be masked out.
    Returns:
        Boolean tensor of shape
        (batch_size, 1, resampled_length, seq_len + resampled_length): False on
        pad tokens, True on every other token and on all latent positions.
    """
    batch_size, _ = tokens.shape

    token_is_visible = tokens != pad_token_id
    # The appended latent positions are always visible.
    latents_visible = torch.ones(
        (batch_size, resampled_length), dtype=torch.bool, device=tokens.device
    )
    key_mask = torch.cat([token_is_visible, latents_visible], dim=1)

    # Same key mask for every one of the resampled_length query rows.
    return key_mask[:, None, None, :].repeat(1, 1, resampled_length, 1)
| | |