| | import os |
| | import math |
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | from src.utils.model_utils import _print |
| | from src.guidance.solubility_module import SolubilityClassifier |
| | from src.sampling.unconditional_sampler import UnconditionalSampler |
| |
|
| |
|
class GuidedSampler:
    """Solubility-guided resampling of a sequence denoised by the MeMDLM diffusion LM.

    The sampler denoises a (possibly masked) sequence, scores it with a
    solubility classifier on ESM embeddings, and at selected positions mixes
    the LM token distribution with a prior. The mixing weight is derived from
    classifier saliency and ESM attention (LaMBO-2 style guidance).
    """

    # Directory containing ``<config.wandb.name>/best_model.ckpt`` classifier
    # checkpoints. Override on a subclass/instance to load from elsewhere.
    CKPT_ROOT = "/home/a03-sgoel/MeMDLM_v2/checkpoints"

    # Residue groups used by the "boltzmann" prior: hydrophilic tokens get a
    # +1 bias, hydrophobic tokens -1, every other token 0.
    HYDROPHILIC = ("D", "E", "K", "R", "N", "Q", "H", "S", "T", "Y")
    HYDROPHOBIC = ("L", "I", "V", "F", "W", "M", "A", "C", "G", "P")

    def __init__(self, config, esm_model, tokenizer, diffusion, device):
        """
        Args:
            config: experiment config; reads ``config.wandb.name`` (checkpoint
                sub-directory) and the ``config.guidance.*`` hyperparameters.
            esm_model: ESM model used for embeddings and attention maps.
            tokenizer: tokenizer of the diffusion LM (its token ids are also
                fed to ``esm_model``).
            diffusion: the MeMDLM diffusion model used for denoising.
            device: device for the classifier and the guidance tensors.
        """
        self.config = config
        self.device = device

        self.esm = esm_model
        self.memdlm = diffusion
        self.tokenizer = tokenizer
        self.uncond_generator = UnconditionalSampler(self.tokenizer, self.memdlm)

        # Load the trained solubility classifier and put it in eval mode.
        ckpt_path = os.path.join(self.CKPT_ROOT, config.wandb.name, "best_model.ckpt")
        self.classifier_model = SolubilityClassifier(config)
        state_dict = self.classifier_model.get_state_dict(ckpt_path)
        self.classifier_model.load_state_dict(state_dict)
        self.classifier_model.eval().to(self.device)

        guidance = self.config.guidance
        self.top_p = guidance.top_p
        self.alpha = guidance.alpha
        self.gamma = guidance.gamma
        self.saliency_eps = guidance.saliency_eps
        self.saliency_t = guidance.saliency_t
        self.sampling_t = guidance.sampling_t
        self.boltzmann_t = guidance.boltzmann_t

    def embed_sequence(self, input_ids, attention_masks):
        """Run ESM without gradients.

        Returns:
            ``(embeds, attn_matrix)``: the last hidden-state layer and the
            per-layer attention tuple reported by the model.
        """
        with torch.no_grad():
            outs = self.esm(
                input_ids=input_ids,
                attention_mask=attention_masks,
                output_hidden_states=True,
                output_attentions=True,
            )
        embeds = outs.hidden_states[-1]
        attn_matrix = outs.attentions
        return embeds, attn_matrix

    def sample_from_categorical(self, logits, temperature, noise_scale=1.0):
        """Gumbel-max sampling over the last dimension.

        The tempered logits are perturbed with (scaled) Gumbel noise; the
        argmax of the perturbed logits is the sample.

        Returns:
            ``(tokens, log_probs)`` where ``log_probs`` is the log-softmax of
            the *noise-perturbed* logits, not of the clean logits.
        """
        # The 1e-8 terms keep both logs finite when rand returns 0.
        gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-8) + 1e-8)
        logits = (logits / temperature) + (noise_scale * gumbel_noise)

        log_probs = F.log_softmax(logits, dim=-1)
        _, tokens = log_probs.max(dim=-1)

        return tokens, log_probs

    def denoise_sequence(self, input_ids, attn_masks):
        """Compute the current and prior sequences' log prob distributions.

        If ``input_ids`` contains mask tokens, they are first filled by
        ``config.guidance.n_steps`` unconditional sampling steps, whose logits
        become the prior logits. Otherwise the diffusion model is called on
        ``input_ids`` directly to obtain the prior logits. One further
        unconditional step on the prior sequence gives the logits that
        ``x0`` is drawn from.

        Returns:
            ``(x0, logp_lm, logits_prior)``: sampled tokens and their
            Gumbel-perturbed log-probs (both squeezed), plus the prior logits.
        """
        has_masks = (input_ids == self.tokenizer.mask_token_id).any()

        if has_masks:
            xt_prior, logits_prior = self.uncond_generator.sample_unconditional(
                xt=input_ids,
                num_steps=self.config.guidance.n_steps,
                tau=self.sampling_t,
                return_logits=True,
            )
        else:
            xt_prior = input_ids
            logits_prior = self.memdlm(input_ids=input_ids, attention_mask=attn_masks)

        _, logits = self.uncond_generator.sample_unconditional(
            xt=xt_prior,
            num_steps=1,
            tau=self.sampling_t,
            return_logits=True,
        )

        x0, logp_lm = self.sample_from_categorical(logits, temperature=self.sampling_t)

        return x0.squeeze(), logp_lm.squeeze(), logits_prior

    def get_prior(self, logits_prior, solubility_logits):
        """Build the prior log-distribution used by the guidance mixture.

        Args:
            logits_prior: LM logits whose last dimension is the vocabulary.
            solubility_logits: classifier logits; assumed to hold one value
                per sequence position so that ``unsqueeze(-1)`` broadcasts
                against the vocabulary axis -- TODO confirm shape.

        Returns:
            Prior log-probabilities with singleton dimensions squeezed.

        Raises:
            ValueError: if ``config.guidance.prior`` is neither "boltzmann"
                nor "lm_probs".
        """
        prior = self.config.guidance.prior

        if prior == "boltzmann":
            # Token bias: +1 hydrophilic, -1 hydrophobic, 0 for everything else.
            bias = torch.zeros(logits_prior.size(-1), device=self.device)
            hydrophilic_idxs = [self.tokenizer.convert_tokens_to_ids(aa) for aa in self.HYDROPHILIC]
            hydrophobic_idxs = [self.tokenizer.convert_tokens_to_ids(aa) for aa in self.HYDROPHOBIC]
            bias[hydrophilic_idxs] = 1.0
            bias[hydrophobic_idxs] = -1.0

            # Scale the bias at each position by the predicted solubility score.
            sol_scores = torch.sigmoid(solubility_logits)
            token_bias = sol_scores.unsqueeze(-1) * bias

            # Tilt the tempered LM distribution by a Boltzmann factor, renormalize.
            lm_probs = F.softmax(logits_prior / self.sampling_t, dim=-1)
            boltz_weight = torch.exp(token_bias / self.boltzmann_t)
            p_prior = lm_probs * boltz_weight
            p_prior = p_prior / p_prior.sum(dim=-1, keepdim=True)
            logp_prior = torch.log(p_prior)

        elif prior == "lm_probs":
            # NOTE: sample_from_categorical adds Gumbel noise before the
            # log-softmax, so this prior is a noise-perturbed LM distribution.
            _, logp_prior = self.sample_from_categorical(logits_prior, temperature=self.sampling_t)

        else:
            raise ValueError(
                f"Unknown guidance prior {prior!r}; expected 'boltzmann' or 'lm_probs'"
            )

        return logp_prior.squeeze()

    def compute_saliency_map(self, embeds, solubility_logits):
        """Compute a saliency map as in LaMBO-2 (https://arxiv.org/abs/2305.20009) Eq. 5.

        Args:
            embeds: leaf tensor with ``requires_grad=True`` from which
                ``solubility_logits`` was computed.
            solubility_logits: classifier output; summed to a scalar before
                differentiation.

        Returns:
            Min-max normalised per-position saliency (values in [0, 1)),
            squeezed.
        """
        # autograd.grad differentiates w.r.t. ``embeds`` only, so no gradient
        # is accumulated into the classifier parameters' ``.grad`` on each call.
        (grads,) = torch.autograd.grad(solubility_logits.sum(), embeds, retain_graph=True)
        # L1 norm of the gradient over the last (feature) dimension.
        grads = grads.abs().sum(dim=-1)
        saliency = grads.pow(1.0 / self.saliency_t).clamp(min=self.saliency_eps).to(self.device)
        saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-6)
        return saliency.squeeze()

    def determine_edit_positions(self, saliency_map, soluble_indices, solubility_logits):
        """Return the indices of positions that guidance is allowed to edit.

        Positions are held fixed (excluded) when they are in
        ``soluble_indices``. When no indices are given, the positions the
        classifier predicts as soluble (sigmoid > 0.5) are held fixed instead.
        In addition, the ~10% most salient positions (at least one) are always
        conserved.

        Args:
            saliency_map: 1-D per-position saliency.
            soluble_indices: positions to keep, or ``None``/empty.
            solubility_logits: classifier logits, one per position (any
                leading singleton dims are flattened away).

        Returns:
            1-D long tensor of editable position indices.
        """
        seq_len = saliency_map.shape[0]
        edit_mask = torch.ones(seq_len, dtype=torch.bool, device=self.device)

        # Test for None before calling len() on it.
        if soluble_indices is not None and len(soluble_indices) > 0:
            edit_mask[soluble_indices] = False
        else:
            # Flatten so that e.g. a (1, L) classifier output can mask the (L,) edit mask.
            solubility_preds = torch.sigmoid(solubility_logits).reshape(-1)
            edit_mask[solubility_preds > 0.5] = False

        # Also conserve the most salient positions.
        num_conserved = max(1, int(0.1 * edit_mask.sum()))
        _, topk_idxs = torch.topk(saliency_map, min(num_conserved, seq_len))
        edit_mask[topk_idxs] = False

        return edit_mask.nonzero(as_tuple=True)[0]

    def create_neighborhood(self, edit_pos, attn_matrix, top_p):
        """Select a dynamic "neighborhood" of tokens for an edit position via top-p.

        Attention scores of the edit position's row identify relevant tokens,
        which avoids blind updates of the individual token. The first and last
        positions and the edit position itself are never part of the
        neighborhood.

        Returns:
            1-D long tensor of neighbor indices, ordered by decreasing
            attention. It is empty when no eligible position exists.
        """
        row = attn_matrix[edit_pos].reshape(-1)
        n = row.size(0)

        excluded = torch.tensor([0, edit_pos, n - 1], device=row.device)
        eligible = torch.ones(n, dtype=torch.bool, device=row.device)
        eligible[excluded] = False
        if not eligible.any():
            # Nothing to attend to; also avoids log(1) == 0 and an all -inf softmax.
            return torch.empty(0, dtype=torch.long, device=row.device)

        row = row.index_fill(dim=0, index=excluded, value=float("-inf"))

        # Sharpen with temperature 1 / log(n); n >= 3 here, so log(n) > 0.
        temp = 1.0 / math.log(n)
        attn_probs = F.softmax(row / temp, dim=0)
        sorted_probs, sorted_idxs = torch.sort(attn_probs, descending=True)
        cum_probs = sorted_probs.cumsum(dim=0)
        cutoff = (cum_probs <= top_p).nonzero(as_tuple=True)[0]

        # Keep the prefix whose cumulative mass stays <= top_p (at least one token).
        final_idx = cutoff[-1].item() + 1 if cutoff.numel() > 0 else 1
        neighborhood = sorted_idxs[:final_idx]
        # With top_p >= 1 the zero-probability excluded positions could be included.
        return neighborhood[eligible[neighborhood]]

    def compute_saliency_weight(self, edit_pos, attn_mat, saliency_map, neighborhood):
        """Blend the saliency of the neighborhood's tokens and the token at the edit position.

        Returns:
            ``saliency[edit_pos] + gamma * sum(a_i * saliency[i])``, where the
            ``a_i`` are the edit position's attention weights over the
            neighborhood, renormalised to sum to 1.
        """
        neighborhood_attns = attn_mat[edit_pos, neighborhood]
        total = neighborhood_attns.sum()
        # An empty or all-zero neighborhood contributes 0 instead of NaN (0 / 0).
        if total > 0:
            neighborhood_attns = neighborhood_attns / total

        neighborhood_saliencies = saliency_map[neighborhood]
        neighborhood_weight = torch.sum(neighborhood_attns * neighborhood_saliencies)
        ctxt_aware_saliency = saliency_map[edit_pos] + (self.gamma * neighborhood_weight)

        return ctxt_aware_saliency

    def compute_guidance_dist(self, logp_lm, logp_prior, saliency_weight):
        """Mix the current LM distribution with the prior.

        ``w = sigmoid(alpha * saliency_weight)`` is the weight of the prior:
        ``p = (1 - w) * p_lm + w * p_prior``.

        Returns:
            Log-probabilities of the mixture (1e-12 added before the log).
        """
        w = torch.sigmoid(saliency_weight * self.alpha)
        p_lm = torch.exp(logp_lm)
        p_prior = torch.exp(logp_prior)
        mixed_probs = (1 - w) * p_lm + w * p_prior
        guidance_dist = torch.log(mixed_probs + 1e-12)
        return guidance_dist

    def check_scaffold(self, seq1, seq2, idxs):
        """Log whether any token at ``idxs`` differs between ``seq1`` and ``seq2``."""
        changed = (seq1[idxs] != seq2[idxs])
        if changed.any():
            _print('soluble residues changed')
        else:
            _print('no soluble residue changes')

    def optimize_sequence(self, input_ids, attn_masks, soluble_indices):
        """Denoise ``input_ids`` and resample the editable positions under guidance.

        Args:
            input_ids: token ids, possibly containing mask tokens.
            attn_masks: attention mask passed to the LM, ESM and the classifier.
            soluble_indices: positions that must keep their denoised token, or
                ``None``/empty to let the classifier choose.

        Returns:
            Optimized token ids. Every position outside the computed edit
            positions keeps its denoised token ``x0``.
        """
        _print(f'soluble idx: {soluble_indices}')

        # Denoise: x0 plus its (Gumbel-perturbed) LM log-probs and the prior logits.
        x0, logp_lm, logits_prior = self.denoise_sequence(input_ids, attn_masks)
        _print(f'og tokens: {x0}')
        _print(f'og tokens: {x0.shape}')
        _print(f'og log probs: {logp_lm.shape}')

        # ESM embeddings become a fresh leaf so classifier gradients can be taken
        # w.r.t. them; attention = last layer, averaged over dim 1 (presumably heads), item 0.
        embeds, attn_mats = self.embed_sequence(x0.unsqueeze(0), attn_masks)
        embeds = embeds.detach().clone().requires_grad_(True)
        attn_matrix = attn_mats[-1].mean(dim=1)[0].squeeze(0)

        batch = {"embeds": embeds, "attention_mask": attn_masks}
        solubility_logits = self.classifier_model(batch)

        saliency_map = self.compute_saliency_map(embeds, solubility_logits)
        _print(f'saliency map: {saliency_map}')
        edit_positions = self.determine_edit_positions(saliency_map, soluble_indices, solubility_logits)
        _print(f'edit positions: {edit_positions}')

        logp_prior = self.get_prior(logits_prior, solubility_logits)
        _print(f'prior log probs: {logp_prior.shape}')

        # Replace the LM distribution at every edit position with the guided mixture.
        for edit_pos in edit_positions.tolist():
            neighborhood = self.create_neighborhood(
                edit_pos,
                attn_matrix,
                self.top_p,
            )
            _print(f'neighborhood: {neighborhood}')

            ctxt_aware_saliency = self.compute_saliency_weight(
                edit_pos,
                attn_matrix,
                saliency_map,
                neighborhood,
            )
            _print(f'ctx aware saliency: {ctxt_aware_saliency}')

            logp_lm_prime = self.compute_guidance_dist(
                logp_lm[edit_pos],
                logp_prior[edit_pos],
                ctxt_aware_saliency,
            )
            logp_lm[edit_pos] = logp_lm_prime

            tot = torch.exp(logp_lm_prime).sum()
            one = torch.tensor(1.0, dtype=tot.dtype, device=tot.device)
            assert torch.isclose(tot, one, atol=1e-4), f"Invalid prob distribution. Sum = {tot:5f}"

        # The categorical draw resamples every position ...
        x0_prime = torch.distributions.Categorical(logits=logp_lm).sample()

        if soluble_indices is not None:
            self.check_scaffold(x0, x0_prime, soluble_indices)

        # ... so restore the denoised token wherever editing was not allowed
        # (given soluble residues, classifier-predicted soluble residues and
        # the conserved high-saliency residues).
        keep = torch.ones_like(x0_prime, dtype=torch.bool)
        keep[edit_positions.to(keep.device)] = False
        x0_prime[keep] = x0[keep]

        if soluble_indices is not None:
            self.check_scaffold(x0, x0_prime, soluble_indices)

        return x0_prime