lhallee committed · verified
Commit f7c04a4 · Parent(s): e250ef1

Upload modeling_esm_plusplus.py with huggingface_hub

Files changed (1):
  modeling_esm_plusplus.py  +15 -11
modeling_esm_plusplus.py CHANGED
@@ -406,17 +406,21 @@ def get_attention_mask(
     if attn_backend == "flex":
         assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
 
-        def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
-            return token_attention_mask[batch_idx, q_idx] & token_attention_mask[batch_idx, kv_idx]
-
-        flex_block_mask = create_block_mask(
-            mask_mod,
-            batch_size,
-            1,
-            seq_len,
-            seq_len,
-            device=device,
-        )
+        if attention_mask is None:
+            flex_block_mask = None
+        else:
+            sequence_ids = torch.where(token_attention_mask, 1, -1)
+            def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+                return (sequence_ids[batch_idx, q_idx] == sequence_ids[batch_idx, kv_idx]) & (sequence_ids[batch_idx, q_idx] != -1)
+
+            flex_block_mask = create_block_mask(
+                mask_mod,
+                batch_size,
+                1,
+                seq_len,
+                seq_len,
+                device=device,
+            )
         extended_attention_mask = None
     else:
         flex_block_mask = None
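
For context, below is a minimal, self-contained sketch (not part of the commit) of what the new branch builds. It assumes a recent PyTorch (2.5+) with torch.nn.attention.flex_attention available, and a boolean token_attention_mask of shape (batch_size, seq_len) where True marks real tokens, matching how the surrounding function appears to use it. All sizes and inputs here are hypothetical.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Hypothetical sizes for illustration only.
batch_size, num_heads, seq_len, head_dim = 2, 4, 128, 16
device = "cuda" if torch.cuda.is_available() else "cpu"

# True marks real tokens, False marks padding; pad the tail of the second sequence.
token_attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
token_attention_mask[1, 96:] = False

# Real tokens get id 1, padding gets -1. Two positions may attend to each other
# only if their ids match AND are valid (!= -1), so padding matches nothing,
# not even other padding tokens.
sequence_ids = torch.where(token_attention_mask, 1, -1)

def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
    return (sequence_ids[batch_idx, q_idx] == sequence_ids[batch_idx, kv_idx]) & (
        sequence_ids[batch_idx, q_idx] != -1
    )

# H=1 broadcasts the same mask across every attention head.
block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device)

q = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
k = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
v = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
out = flex_attention(q, k, v, block_mask=block_mask)  # (2, 4, 128, 16)

The other visible change in the diff is the attention_mask is None guard: an unpadded batch now skips block-mask construction entirely (flex_block_mask = None), so flex attention runs dense with no masking overhead.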