| | |
| | |
| | |
| | |
| |
|
| | from collections import OrderedDict |
| | import math |
| | import requests |
| | from io import BytesIO |
| | from functools import partial |
| | import pickle |
| | from typing import Callable, Optional, Sequence, Tuple, List |
| | import numpy as np |
| | import os |
| | import torch |
| | from torch import nn |
| | from torch.nn import functional as F |
| | from torch.nn.init import trunc_normal_ |
| | from torchvision import transforms |
| | from torchvision.transforms import InterpolationMode |
| |
|
| | class GLU(nn.Module): |
| | def __init__(self,hidden_size): |
| | super().__init__() |
| | self.linear_proj = nn.Linear(hidden_size,hidden_size,bias=False) |
| | self.norm1 = nn.LayerNorm(hidden_size) |
| | self.act1 = nn.GELU() |
| | self.act2 = nn.functional.silu |
| | self.dense_h_to_4h = nn.Linear(hidden_size,hidden_size*4,bias=False) |
| | self.gate_proj = nn.Linear(hidden_size,hidden_size*4,bias=False) |
| | self.dense_4h_to_h = nn.Linear(hidden_size*4,hidden_size,bias=False) |
| |
|
| | def forward(self,x): |
| | x = self.linear_proj(x) |
| | x = self.act1(self.norm1(x)) |
| | x = self.act2(self.gate_proj(x))*self.dense_h_to_4h(x) |
| | x = self.dense_4h_to_h(x) |
| | return x |
def swiglu(x):
    """SwiGLU activation: split the last dim in half and gate the second
    half by SiLU of the first. Output has half the last dimension of x."""
    gate, value = torch.chunk(x, 2, dim=-1)
    return nn.functional.silu(gate) * value
| |
|
class GLU_new(nn.Module):
    """SwiGLU feed-forward block with dropout.

    The input is expanded to ``2 * intermediate_size``, split and gated by
    ``swiglu`` (halving the last dim back to ``intermediate_size``), then
    projected back to ``hidden_size``.

    Args:
        hidden_size: input/output feature dimension.
        dropout: dropout probability applied to the output.
        intermediate_size: gated hidden width. Defaults to 1280 to remain
            backward compatible. (The original code computed a rounded
            ``8/3 * hidden_size`` width and then unconditionally overwrote
            it with the hard-coded 1280 — that dead computation is removed
            and the constant is now an overridable parameter.)
    """

    def __init__(self, hidden_size, dropout=0.1, intermediate_size=1280):
        super().__init__()
        self.act = swiglu
        # "* 2" because swiglu splits the projection into gate and value halves.
        self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size * 2, bias=False)
        self.dense_4h_to_h = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Apply the SwiGLU MLP; shape (..., hidden_size) is preserved."""
        x = self.dense_h_to_4h(x)
        x = self.act(x)
        x = self.dense_4h_to_h(x)
        x = self.dropout(x)
        return x
| |
|
| |
|
# Module-level default for the number of latent query tokens.
# NOTE(review): Resampler and StructureTransformer take their own n_queries
# arguments; presumably this constant is a fallback/config default — verify
# it is actually referenced before removing.
n_queries = 32
def get_abs_pos(abs_pos, tgt_size):
    """Resize a square grid of absolute position embeddings.

    Args:
        abs_pos: (S*S, C) tensor of embeddings laid out on an S x S grid;
            S is inferred as floor(sqrt(len(abs_pos))).
        tgt_size: target token count; target grid side is floor(sqrt(tgt_size)).

    Returns:
        (T*T, C) tensor bicubically resized to the target grid, cast back to
        the input dtype — or the input tensor unchanged when the grid sides
        already match.
    """
    src_side = int(math.sqrt(abs_pos.size(0)))
    tgt_side = int(math.sqrt(tgt_size))
    if src_side == tgt_side:
        return abs_pos

    orig_dtype = abs_pos.dtype
    # (S*S, C) -> (1, C, S, S) for spatial interpolation.
    grid = abs_pos.float().reshape(1, src_side, src_side, -1).permute(0, 3, 1, 2)
    resized = F.interpolate(
        grid,
        size=(tgt_side, tgt_side),
        mode="bicubic",
        align_corners=False,
    )
    # (1, C, T, T) -> (T*T, C), restoring the original dtype.
    return resized.permute(0, 2, 3, 1).flatten(0, 2).to(dtype=orig_dtype)
| |
|
| | from einops import rearrange, repeat |
| |
|
def get_1d_sincos_pos_embed(embed_dim, pos):
    """Build 1D sine/cosine positional embeddings.

    Args:
        embed_dim: output dimension per position (must be even).
        pos: array of positions, any shape; flattened to (M,).

    Returns:
        (M, embed_dim) array laid out as [sin | cos] halves over a
        geometric frequency ladder with base 10000.
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Frequencies 10000^(-2i/embed_dim) for i in [0, half).
    freqs = 1. / 10000 ** (np.arange(half, dtype=np.float32) / (embed_dim / 2.))
    angles = np.outer(pos.reshape(-1), freqs)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
| |
|
class Resampler(nn.Module):
    """Perceiver-style resampler: a fixed set of learned latent queries
    cross-attends to variable-length encoder features.

    Maps (B, L, kv_dim) features to (B, n_queries, embed_dim).

    Args:
        kv_dim: feature dimension of the incoming keys/values.
        embed_dim: output (and attention) dimension.
        num_heads: attention heads for the pooling attention.
        n_queries: number of learned latent queries.
        max_seqlen: maximum key sequence length (positional table size).
        perceiver_resampler_positional_emb: add 1D sin/cos positions to
            both queries (strided) and keys (dense).
        use_GLU: use a GLU_new MLP instead of a plain linear output proj.
        bos_init: not implemented (see below); must be False.
        dropout: attention (and GLU) dropout probability.
    """

    def __init__(
        self,
        kv_dim,
        embed_dim,
        num_heads=8,
        n_queries=64,
        max_seqlen=1024,
        perceiver_resampler_positional_emb=True,
        use_GLU=False,
        bos_init=False,
        dropout=0.0
    ):
        super().__init__()
        self.perceiver_resampler_positional_emb = perceiver_resampler_positional_emb

        if self.perceiver_resampler_positional_emb:
            assert n_queries <= max_seqlen
            # Queries reuse every `stride`-th position embedding so the
            # query positions span the full key sequence.
            self.stride = max_seqlen // n_queries
            pos = np.arange(max_seqlen, dtype=np.float32)
            self.register_buffer(
                "pos_embed",
                torch.from_numpy(get_1d_sincos_pos_embed(embed_dim, pos)).float()
            )
        self.latents = nn.Parameter(torch.randn(n_queries, embed_dim))
        if bos_init:
            # BUG(fixed): the original called `self.latents.load('')`, which
            # always raised AttributeError (torch.Tensor has no .load).
            # Fail fast with an explicit error until a real pretrained
            # initializer is wired in.
            raise NotImplementedError(
                "bos_init=True is not implemented: no latent initializer available"
            )
        else:
            nn.init.trunc_normal_(self.latents, std=1e-3)

        self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, dropout=dropout)
        self.ln_q = nn.LayerNorm(embed_dim)
        self.ln_kv = nn.LayerNorm(embed_dim)
        self.ln_post = nn.LayerNorm(embed_dim)
        if use_GLU:
            print('GLU *********************************')
            self.proj = GLU_new(embed_dim, dropout=dropout)
        else:
            self.proj = nn.Linear(embed_dim, embed_dim, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Small-std truncated-normal init for Linear; standard LayerNorm init."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=1e-3)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, struc_x):
        """
        Args:
            struc_x: dict with
                "encoder_out": (B, L, kv_dim) float features,
                "encoder_padding_mask": (B, L) boolean mask; assumed
                    True/1 for VALID positions — it is inverted before
                    being passed as key_padding_mask (True = ignore).
                    TODO(review): confirm mask polarity against callers.
        Returns:
            (B, n_queries, embed_dim) resampled features.
        """
        x = struc_x["encoder_out"]
        mask = struc_x["encoder_padding_mask"]

        # Zero out NaNs (e.g. positions an upstream encoder could not
        # embed) so they cannot poison the attention output.
        nan_mask = torch.isnan(x)
        if nan_mask.any():
            x = x.masked_fill(nan_mask, 0.0)

        x = self.kv_proj(x)
        x = self.ln_kv(x)

        b, seqlen = x.shape[:2]

        latents = self.ln_q(self.latents)
        if self.perceiver_resampler_positional_emb:
            # Strided positions for the queries, dense positions for keys.
            latents = latents + self.pos_embed[::self.stride].contiguous()
            pos_emb = self.pos_embed[:seqlen].unsqueeze(0)
            x = x + pos_emb.contiguous()

        # (n, d) -> (b, n, d). `expand` is a zero-copy broadcast view and
        # removes the einops `repeat` dependency; attention only reads the
        # query tensor, so the shared view is safe.
        latents = latents.unsqueeze(0).expand(b, -1, -1)
        out = self.attn(latents, x, x, key_padding_mask=~mask)[0]

        out = self.ln_post(out)
        out = self.proj(out)

        return out
| |
|
class StructureTransformer(nn.Module):
    """Pools precomputed per-residue protein-structure embeddings into a
    fixed number of tokens via a Resampler attention pool.

    Embeddings are loaded from pickle files under
    ``structure_emb_path_prefix``, padded/masked to ``max_seqlen``, then
    cross-attended by ``n_queries`` latent queries.
    """

    def __init__(
        self,
        width: int = 640,
        n_queries: int = 32,
        output_dim: int = 4096,
        embedding_keys=set(["mpnn_emb"]),
        max_seqlen: int=1024,
        num_heads: int=8,
        structure_emb_path_prefix='structure_emb',
        **kwargs
    ):
        # NOTE(review): `embedding_keys` uses a mutable default (a set) that
        # is shared across instances — safe only if never mutated; verify.
        super().__init__()

        # Directory holding the pickled per-protein embedding files.
        self.structure_emb_path_prefix = structure_emb_path_prefix
        self.embedding_keys = embedding_keys  # which sample keys to concatenate
        self.max_seqlen = max_seqlen          # residues kept per structure
        self.width = width                    # per-residue feature dim (kv_dim)
        self.n_queries = n_queries            # pooled output token count

        # Attention pool: (B, max_seqlen, width) -> (B, n_queries, output_dim).
        self.attn_pool = Resampler(
            embed_dim=output_dim,
            kv_dim=width,
            n_queries=n_queries,
            max_seqlen=max_seqlen,
            num_heads=num_heads,
            **kwargs
        )

    def prepare_structure(self, sample):
        """Assemble one padded embedding matrix and its validity mask.

        Args:
            sample: dict of per-residue embeddings keyed by names from
                ``self.embedding_keys``; values may be tensors or lists of
                tensors (lists are concatenated along dim 0). May also carry
                a "pifold_mask" marking resolved residues.

        Returns:
            Tuple of ``emb_pad`` (max_seqlen, width) float tensor and
            ``emb_mask`` (max_seqlen,) bool tensor, True where rows are filled.
        """
        emb_pad = torch.zeros((self.max_seqlen, self.width))
        emb_mask = torch.zeros((self.max_seqlen), dtype=bool)

        if "pifold_emb" in self.embedding_keys and "pifold_mask" in sample:
            # Scatter the compact pifold embedding back onto full residue
            # positions; unresolved positions become NaN (the Resampler
            # zero-fills NaNs downstream). Mutates `sample` in place.
            mask = sample["pifold_mask"]
            pifold_emb = sample["pifold_emb"]
            new_pifold_emb = pifold_emb.new_zeros(mask.shape[0], pifold_emb.shape[1]).fill_(float("nan"))
            new_pifold_emb[mask > 0] = pifold_emb
            sample["pifold_emb"] = new_pifold_emb

        # Concatenate all requested embedding sources along the feature dim.
        # NOTE(review): `embedding_keys` is a set, so concatenation order is
        # not guaranteed when more than one key is used — confirm intended.
        emb = []
        for ek in self.embedding_keys:
            if ek in sample:
                if isinstance( sample[ek], List):
                    emb.append(torch.cat(sample[ek]))
                else:
                    emb.append(sample[ek])

        emb = torch.cat(emb, dim=-1)

        emb_pad[:len(emb)] = emb
        emb_mask[:len(emb)] = 1
        return emb_pad, emb_mask

    def forward(self, x):
        """Run the attention pool.

        Args:
            x: dict with "encoder_out" / "encoder_padding_mask" as expected
                by the Resampler.
        Returns:
            (B, n_queries, output_dim) pooled features.
        """
        x = self.attn_pool(x)

        return x

    def encode(self, structure_paths: List[torch.Tensor]):
        """Load, pad, and pool a batch of pickled structure embeddings.

        Args:
            structure_paths: batch of integer tensors, each holding the
                character codes (0-padded) of a relative file name.
        Returns:
            (B, n_queries, output_dim) pooled embeddings, or None as a
            best-effort fallback when any file is missing.
        """
        structure_embs = []
        structure_mask = []

        for structure_path in structure_paths:
            # Decode character codes (dropping 0 padding) into a file name.
            # NOTE(review): truncating to self.n_queries characters looks
            # accidental — confirm path codes are really capped at n_queries.
            structure_path = [chr(s) for s in structure_path[:self.n_queries].tolist() if s > 0]
            structure_path = os.path.join(self.structure_emb_path_prefix, ''.join(structure_path))
            if not os.path.exists(structure_path):
                print('no structure found')
                return None

            # NOTE(review): pickle.load executes arbitrary code on load —
            # the embedding directory must be trusted.
            with open(structure_path, 'rb') as f:
                structure, struc_mask = self.prepare_structure(pickle.load(f))

            structure_embs.append(structure)
            structure_mask.append(struc_mask)

        # Move the batch onto the pool's device/dtype (mask keeps its dtype).
        structure_embs = torch.stack(structure_embs, dim=0).to(
            device=next(self.attn_pool.parameters()).device,
            dtype=next(self.attn_pool.parameters()).dtype)
        structure_mask = torch.stack(structure_mask, dim=0).to(
            device=next(self.attn_pool.parameters()).device)

        return self({
            'encoder_out': structure_embs,
            'encoder_padding_mask': structure_mask
        })
| |
|