Upload modeling_esm_plusplus.py with huggingface_hub
modeling_esm_plusplus.py (+15 -11)
```diff
@@ -406,17 +406,21 @@ def get_attention_mask(
     if attn_backend == "flex":
         assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
 
-        sequence_ids = torch.where(token_attention_mask, 1, -1)
-        def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
-            return (sequence_ids[batch_idx, q_idx] == sequence_ids[batch_idx, kv_idx]) & (sequence_ids[batch_idx, q_idx] != -1)
-        flex_block_mask = create_block_mask(
-            mask_mod,
-            batch_size,
-            1,
-            seq_len,
-            seq_len,
-            device=device,
-        )
+        if attention_mask is None:
+            flex_block_mask = None
+        else:
+            sequence_ids = torch.where(token_attention_mask, 1, -1)
+            def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+                return (sequence_ids[batch_idx, q_idx] == sequence_ids[batch_idx, kv_idx]) & (sequence_ids[batch_idx, q_idx] != -1)
+
+            flex_block_mask = create_block_mask(
+                mask_mod,
+                batch_size,
+                1,
+                seq_len,
+                seq_len,
+                device=device,
+            )
         extended_attention_mask = None
     else:
         flex_block_mask = None
```
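For context on the changed hunk: `create_block_mask` precomputes a sparse block mask from a per-position predicate, and the new `attention_mask is None` guard skips that work for unpadded batches, since `flex_attention` treats `block_mask=None` as dense attention. Below is a minimal, self-contained sketch of the guarded path, assuming PyTorch >= 2.5 with `torch.nn.attention.flex_attention`; the `padding_mask` tensor and the toy shapes are illustrative stand-ins, not values from the model file.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

batch_size, seq_len, n_heads, head_dim = 2, 128, 4, 16
# CUDA is the primary target for FlexAttention; CPU needs a recent PyTorch.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical padding mask: True for real tokens, False for padding
# (this plays the role of token_attention_mask in the commit).
padding_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
padding_mask[0, 100:] = False  # pad out the tail of the first sequence

# Padding positions get id -1 so they can never attend or be attended to.
sequence_ids = torch.where(padding_mask, 1, -1)

def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
    # Allow attention only between non-padding positions sharing an id.
    return (sequence_ids[batch_idx, q_idx] == sequence_ids[batch_idx, kv_idx]) & (
        sequence_ids[batch_idx, q_idx] != -1
    )

# H=1 lets the same block mask broadcast across all attention heads.
block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device)

q = torch.randn(batch_size, n_heads, seq_len, head_dim, device=device)
k, v = torch.randn_like(q), torch.randn_like(q)

# Eager flex_attention is a reference path; wrap it in torch.compile for speed.
out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([2, 4, 128, 16])
```

Mapping padding to id -1 and comparing ids, rather than indexing the boolean mask directly, also keeps `mask_mod` compatible with packed inputs, where each packed document would get its own positive id.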