lhallee committed on
Commit
0747bfd
·
verified ·
1 Parent(s): 0d3c647

Upload modeling_esm_plusplus.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +119 -79
modeling_esm_plusplus.py CHANGED
@@ -391,23 +391,38 @@ except ImportError:
391
  flex_attention = None
392
 
393
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
394
 
395
- def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
396
- assert create_block_mask is not None, "Flex attention block mask requires create_block_mask."
397
- token_valid = attention_mask_2d.bool()
398
- batch_size, seq_len = token_valid.shape
399
-
400
- def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
401
- return token_valid[batch_idx, q_idx] & token_valid[batch_idx, kv_idx]
402
-
403
- return create_block_mask(
404
- mask_mod,
405
- batch_size,
406
- 1,
407
- seq_len,
408
- seq_len,
409
- device=attention_mask_2d.device,
410
- )
411
 
412
 
413
  class ESMplusplusConfig(PretrainedConfig):
@@ -702,14 +717,15 @@ class MultiHeadAttention(nn.Module):
702
  def forward(
703
  self,
704
  x: torch.Tensor,
705
- attention_mask: Optional[torch.Tensor] = None,
706
- flex_block_mask: Optional[object] = None,
707
  output_attentions: bool = False,
708
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
709
  """
710
  Args:
711
  x: Input tensor
712
- attention_mask: Optional attention mask
 
713
  output_attentions: Whether to return attention weights
714
 
715
  Returns:
@@ -727,24 +743,15 @@ class MultiHeadAttention(nn.Module):
727
  scale = 1 / math.sqrt(self.d_head)
728
 
729
  if output_attentions: # Manual attention computation
730
- b, h, l, _ = query_BHLD.shape
731
- attn_bias = torch.zeros(b, h, l, l, dtype=query_BLD.dtype, device=query_BLD.device)
732
- if attention_mask is not None:
733
- attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
734
  attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
735
- attn_weights += attn_bias
736
  attn_weights = F.softmax(attn_weights, dim=-1)
737
  context_BHLD = torch.matmul(attn_weights, value_BHLD)
738
  else:
739
  if self.attn_backend == "flex":
740
  assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
741
- assert query_BHLD.dtype in (torch.float16, torch.bfloat16), (
742
- f"Flex attention backend requires float16 or bfloat16, got {query_BHLD.dtype}."
743
- )
744
- if attention_mask is not None:
745
- assert flex_block_mask is not None, (
746
- "Flex attention backend requires a block mask when attention_mask is provided."
747
- )
748
  context_BHLD = flex_attention(
749
  query_BHLD,
750
  key_BHLD,
@@ -753,15 +760,11 @@ class MultiHeadAttention(nn.Module):
753
  scale=scale,
754
  )
755
  else:
756
- sdpa_mask = None
757
- if attention_mask is not None:
758
- sdpa_mask = torch.zeros_like(attention_mask, dtype=query_BHLD.dtype)
759
- sdpa_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
760
  context_BHLD = F.scaled_dot_product_attention(
761
  query_BHLD,
762
  key_BHLD,
763
  value_BHLD,
764
- attn_mask=sdpa_mask,
765
  scale=scale,
766
  )
767
 
@@ -820,14 +823,15 @@ class UnifiedTransformerBlock(nn.Module):
820
  def forward(
821
  self,
822
  x: torch.Tensor,
823
- attention_mask: Optional[torch.Tensor] = None,
824
- flex_block_mask: Optional[object] = None,
825
  output_attentions: bool = False,
826
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
827
  """
828
  Args:
829
  x: Input tensor
830
- attention_mask: Optional attention mask
 
831
  output_attentions: Whether to return attention weights
832
 
833
  Returns:
@@ -902,13 +906,13 @@ class TransformerStack(nn.Module):
902
  self,
903
  x: torch.Tensor,
904
  attention_mask: Optional[torch.Tensor] = None,
905
- output_hidden_states: bool = False,
906
- output_attentions: bool = False,
907
  ) -> TransformerOutput:
908
  """
909
  Args:
910
  x: Input tensor
911
- attention_mask: Optional attention mask
912
  output_hidden_states: Whether to return all hidden states
913
  output_attentions: Whether to return attention weights
914
 
@@ -918,33 +922,31 @@ class TransformerStack(nn.Module):
918
  hidden_states = () if output_hidden_states else None
919
  attentions = () if output_attentions else None
920
 
921
- if attention_mask is not None:
922
- assert attention_mask.ndim == 2, f"Expected 2D token attention mask, got shape {attention_mask.shape}."
923
- token_attention_mask = attention_mask.bool()
924
- if self.attn_backend == "flex" and not output_attentions:
925
- assert create_block_mask is not None, (
926
- "Flex attention backend requested but torch.create_block_mask is unavailable."
927
- )
928
- flex_block_mask = _create_pad_block_mask(token_attention_mask)
929
- attention_mask = None
930
- else:
931
- pairwise_attention_mask = token_attention_mask.unsqueeze(-1) & token_attention_mask.unsqueeze(-2)
932
- attention_mask = pairwise_attention_mask.unsqueeze(1)
933
- flex_block_mask = None
934
- else:
935
- flex_block_mask = None
936
 
937
  for block in self.blocks:
938
  if self.gradient_checkpointing and self.training:
939
  x, attn_weights = self._gradient_checkpointing_func(
940
  block.__call__,
941
- x,
942
- attention_mask,
943
- flex_block_mask,
944
- output_attentions,
945
  )
946
  else:
947
- x, attn_weights = block(x, attention_mask, flex_block_mask, output_attentions)
 
 
 
 
 
948
 
949
  if attentions is not None:
950
  attentions += (attn_weights,)
@@ -952,9 +954,13 @@ class TransformerStack(nn.Module):
952
  if output_hidden_states:
953
  assert hidden_states is not None
954
  hidden_states += (x,)
 
 
 
 
955
 
956
  return TransformerOutput(
957
- last_hidden_state=self.norm(x),
958
  hidden_states=hidden_states,
959
  attentions=attentions
960
  )
@@ -1048,7 +1054,12 @@ class ESMplusplusModel(PreTrainedESMplusplusModel, EmbeddingMixin):
1048
 
1049
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
1050
  x = self.embed(input_ids)
1051
- return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state
 
 
 
 
 
1052
 
1053
  def forward(
1054
  self,
@@ -1072,11 +1083,20 @@ class ESMplusplusModel(PreTrainedESMplusplusModel, EmbeddingMixin):
1072
  Returns:
1073
  TransformerOutput containing last hidden state and optionally all hidden states and attention weights
1074
  """
 
 
 
1075
  if inputs_embeds is None:
1076
  x = self.embed(input_ids)
1077
  else:
1078
  x = inputs_embeds
1079
- return self.transformer(x, attention_mask, output_hidden_states, output_attentions)
 
 
 
 
 
 
1080
 
1081
 
1082
  class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin):
@@ -1116,7 +1136,12 @@ class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin):
1116
 
1117
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
1118
  x = self.embed(input_ids)
1119
- return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state
 
 
 
 
 
1120
 
1121
  def forward(
1122
  self,
@@ -1146,16 +1171,24 @@ class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin):
1146
  x = self.embed(input_ids)
1147
  else:
1148
  x = inputs_embeds
1149
- output = self.transformer(x, attention_mask, output_hidden_states, output_attentions)
1150
- x = output.last_hidden_state
1151
- logits = self.sequence_head(x)
 
 
 
 
 
 
 
1152
  loss = None
1153
  if labels is not None:
1154
  loss = self.ce_loss(logits.view(-1, self.vocab_size), labels.view(-1))
 
1155
  return ESMplusplusOutput(
1156
  loss=loss,
1157
  logits=logits,
1158
- last_hidden_state=x,
1159
  hidden_states=output.hidden_states,
1160
  attentions=output.attentions,
1161
  )
@@ -1185,7 +1218,12 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixi
1185
 
1186
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
1187
  x = self.embed(input_ids)
1188
- return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state
 
 
 
 
 
1189
 
1190
  def forward(
1191
  self,
@@ -1219,9 +1257,11 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixi
1219
  output_attentions=output_attentions,
1220
  output_hidden_states=output_hidden_states
1221
  )
1222
- x = output.last_hidden_state
1223
- features = self.pooler(x, attention_mask)
 
1224
  logits = self.classifier(features)
 
1225
  loss = None
1226
  if labels is not None:
1227
  labels = labels.to(logits.device)
@@ -1246,7 +1286,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixi
1246
  return ESMplusplusOutput(
1247
  loss=loss,
1248
  logits=logits,
1249
- last_hidden_state=x,
1250
  hidden_states=output.hidden_states,
1251
  attentions=output.attentions,
1252
  )
@@ -1302,15 +1342,17 @@ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
1302
  output_attentions=output_attentions,
1303
  output_hidden_states=output_hidden_states
1304
  )
1305
- x = output.last_hidden_state
1306
- logits = self.classifier(x)
 
1307
  loss = None
1308
  if labels is not None:
1309
  loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
1310
  return ESMplusplusOutput(
1311
  loss=loss,
1312
  logits=logits,
1313
- last_hidden_state=x,
1314
  hidden_states=output.hidden_states,
1315
  attentions=output.attentions,
1316
  )
@@ -1487,5 +1529,3 @@ class EsmSequenceTokenizer(PreTrainedTokenizerFast):
1487
  @property
1488
  def special_token_ids(self):
1489
  return self.all_special_ids
1490
-
1491
-
 
391
  flex_attention = None
392
 
393
 
394
+ def get_attention_mask(
395
+ attn_backend: str,
396
+ batch_size: int,
397
+ seq_len: int,
398
+ device: torch.device,
399
+ attention_mask: Optional[torch.Tensor] = None
400
+ ) -> torch.Tensor:
401
+ if attention_mask is None:
402
+ token_attention_mask = torch.ones((batch_size, seq_len), device=device).bool()
403
+ else:
404
+ token_attention_mask = attention_mask.bool()
405
+
406
+ if attn_backend == "flex":
407
+ assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
408
+
409
+ def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
410
+ return token_attention_mask[batch_idx, q_idx] & token_attention_mask[batch_idx, kv_idx]
411
+
412
+ flex_block_mask = create_block_mask(
413
+ mask_mod,
414
+ batch_size,
415
+ 1,
416
+ seq_len,
417
+ seq_len,
418
+ device=device,
419
+ )
420
+ extended_attention_mask = None
421
+ else:
422
+ flex_block_mask = None
423
+ extended_attention_mask = token_attention_mask[:, None, :, None] & token_attention_mask[:, None, None, :]
424
 
425
+ return extended_attention_mask, flex_block_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
426
 
427
 
428
  class ESMplusplusConfig(PretrainedConfig):
 
717
  def forward(
718
  self,
719
  x: torch.Tensor,
720
+ attention_mask: torch.Tensor,
721
+ flex_block_mask: object,
722
  output_attentions: bool = False,
723
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
724
  """
725
  Args:
726
  x: Input tensor
727
+ attention_mask: 4D attention mask
728
+ flex_block_mask: Flex attention block mask
729
  output_attentions: Whether to return attention weights
730
 
731
  Returns:
 
743
  scale = 1 / math.sqrt(self.d_head)
744
 
745
  if output_attentions: # Manual attention computation
 
 
 
 
746
  attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
747
+ attn_weights = attn_weights.masked_fill(attention_mask.logical_not(), float('-inf'))
748
  attn_weights = F.softmax(attn_weights, dim=-1)
749
  context_BHLD = torch.matmul(attn_weights, value_BHLD)
750
  else:
751
  if self.attn_backend == "flex":
752
  assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
753
+ assert query_BHLD.dtype in (torch.float16, torch.bfloat16), f"Flex attention backend requires float16 or bfloat16, got {query_BHLD.dtype}."
754
+ assert flex_block_mask is not None, "Flex attention backend requires a block mask when attention_mask is provided."
 
 
 
 
 
755
  context_BHLD = flex_attention(
756
  query_BHLD,
757
  key_BHLD,
 
760
  scale=scale,
761
  )
762
  else:
 
 
 
 
763
  context_BHLD = F.scaled_dot_product_attention(
764
  query_BHLD,
765
  key_BHLD,
766
  value_BHLD,
767
+ attn_mask=attention_mask,
768
  scale=scale,
769
  )
770
 
 
823
  def forward(
824
  self,
825
  x: torch.Tensor,
826
+ attention_mask: torch.Tensor,
827
+ flex_block_mask: object,
828
  output_attentions: bool = False,
829
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
830
  """
831
  Args:
832
  x: Input tensor
833
+ attention_mask: 4D attention mask
834
+ flex_block_mask: Flex attention block mask
835
  output_attentions: Whether to return attention weights
836
 
837
  Returns:
 
906
  self,
907
  x: torch.Tensor,
908
  attention_mask: Optional[torch.Tensor] = None,
909
+ output_hidden_states: Optional[bool] = False,
910
+ output_attentions: Optional[bool] = False,
911
  ) -> TransformerOutput:
912
  """
913
  Args:
914
  x: Input tensor
915
+ attention_mask: Optional 2D attention mask
916
  output_hidden_states: Whether to return all hidden states
917
  output_attentions: Whether to return attention weights
918
 
 
922
  hidden_states = () if output_hidden_states else None
923
  attentions = () if output_attentions else None
924
 
925
+ # move to 4D attention mask or flex block mask
926
+ attention_mask, flex_block_mask = get_attention_mask(
927
+ attn_backend=self.attn_backend,
928
+ batch_size=x.shape[0],
929
+ seq_len=x.shape[1],
930
+ device=x.device,
931
+ attention_mask=attention_mask,
932
+ )
 
 
 
 
 
 
 
933
 
934
  for block in self.blocks:
935
  if self.gradient_checkpointing and self.training:
936
  x, attn_weights = self._gradient_checkpointing_func(
937
  block.__call__,
938
+ x=x,
939
+ attention_mask=attention_mask,
940
+ flex_block_mask=flex_block_mask,
941
+ output_attentions=output_attentions,
942
  )
943
  else:
944
+ x, attn_weights = block(
945
+ x=x,
946
+ attention_mask=attention_mask,
947
+ flex_block_mask=flex_block_mask,
948
+ output_attentions=output_attentions,
949
+ )
950
 
951
  if attentions is not None:
952
  attentions += (attn_weights,)
 
954
  if output_hidden_states:
955
  assert hidden_states is not None
956
  hidden_states += (x,)
957
+
958
+ last_hidden_state = self.norm(x)
959
+ if output_hidden_states:
960
+ hidden_states += (last_hidden_state,)
961
 
962
  return TransformerOutput(
963
+ last_hidden_state=last_hidden_state,
964
  hidden_states=hidden_states,
965
  attentions=attentions
966
  )
 
1054
 
1055
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
1056
  x = self.embed(input_ids)
1057
+ return self.transformer(
1058
+ x=x,
1059
+ attention_mask=attention_mask,
1060
+ output_hidden_states=False,
1061
+ output_attentions=False,
1062
+ ).last_hidden_state
1063
 
1064
  def forward(
1065
  self,
 
1083
  Returns:
1084
  TransformerOutput containing last hidden state and optionally all hidden states and attention weights
1085
  """
1086
+ assert input_ids is not None or inputs_embeds is not None, "You have to specify either input_ids or inputs_embeds"
1087
+ assert not (input_ids is not None and inputs_embeds is not None), "You cannot specify both input_ids and inputs_embeds at the same time"
1088
+
1089
  if inputs_embeds is None:
1090
  x = self.embed(input_ids)
1091
  else:
1092
  x = inputs_embeds
1093
+
1094
+ return self.transformer(
1095
+ x=x,
1096
+ attention_mask=attention_mask,
1097
+ output_hidden_states=output_hidden_states,
1098
+ output_attentions=output_attentions,
1099
+ ).last_hidden_state
1100
 
1101
 
1102
  class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin):
 
1136
 
1137
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
1138
  x = self.embed(input_ids)
1139
+ return self.transformer(
1140
+ x=x,
1141
+ attention_mask=attention_mask,
1142
+ output_hidden_states=False,
1143
+ output_attentions=False,
1144
+ ).last_hidden_state
1145
 
1146
  def forward(
1147
  self,
 
1171
  x = self.embed(input_ids)
1172
  else:
1173
  x = inputs_embeds
1174
+
1175
+ output = self.transformer(
1176
+ x=x,
1177
+ attention_mask=attention_mask,
1178
+ output_hidden_states=output_hidden_states,
1179
+ output_attentions=output_attentions,
1180
+ )
1181
+
1182
+ last_hidden_state = output.last_hidden_state
1183
+ logits = self.sequence_head(last_hidden_state)
1184
  loss = None
1185
  if labels is not None:
1186
  loss = self.ce_loss(logits.view(-1, self.vocab_size), labels.view(-1))
1187
+
1188
  return ESMplusplusOutput(
1189
  loss=loss,
1190
  logits=logits,
1191
+ last_hidden_state=last_hidden_state,
1192
  hidden_states=output.hidden_states,
1193
  attentions=output.attentions,
1194
  )
 
1218
 
1219
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
1220
  x = self.embed(input_ids)
1221
+ return self.transformer(
1222
+ x=x,
1223
+ attention_mask=attention_mask,
1224
+ output_hidden_states=False,
1225
+ output_attentions=False,
1226
+ ).last_hidden_state
1227
 
1228
  def forward(
1229
  self,
 
1257
  output_attentions=output_attentions,
1258
  output_hidden_states=output_hidden_states
1259
  )
1260
+
1261
+ last_hidden_state = output.last_hidden_state
1262
+ features = self.pooler(last_hidden_state, attention_mask) # pooler expects 2d attention mask
1263
  logits = self.classifier(features)
1264
+
1265
  loss = None
1266
  if labels is not None:
1267
  labels = labels.to(logits.device)
 
1286
  return ESMplusplusOutput(
1287
  loss=loss,
1288
  logits=logits,
1289
+ last_hidden_state=last_hidden_state,
1290
  hidden_states=output.hidden_states,
1291
  attentions=output.attentions,
1292
  )
 
1342
  output_attentions=output_attentions,
1343
  output_hidden_states=output_hidden_states
1344
  )
1345
+
1346
+ last_hidden_state = output.last_hidden_state
1347
+ logits = self.classifier(last_hidden_state)
1348
  loss = None
1349
  if labels is not None:
1350
  loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1351
+
1352
  return ESMplusplusOutput(
1353
  loss=loss,
1354
  logits=logits,
1355
+ last_hidden_state=last_hidden_state,
1356
  hidden_states=output.hidden_states,
1357
  attentions=output.attentions,
1358
  )
 
1529
  @property
1530
  def special_token_ids(self):
1531
  return self.all_special_ids