import random

import numpy as np
import torch
import torch.nn.functional as F

from src.PeptiVerse.inference import PeptiVersePredictor
from src.utils.model_utils import _print


class MadSBMSampler:
    """Jump-process sampler for a MadSBM masked discrete-diffusion model.

    Starting from a (partially) masked token sequence, :meth:`sample`
    integrates a reverse-time jump process: at each step the model predicts a
    control field and token logits, masked positions jump to freshly drawn
    tokens with probability ``1 - exp(-rate * dt)``, and any positions still
    masked at the end are filled greedily from the final step's logits.
    Optionally, candidates are re-weighted by predicted binding affinity
    (PeptiVerse guidance).
    """

    def __init__(self, model, config, device, guidance=None):
        self.config = config
        self.device = device
        self.model = model
        self.tokenizer = model.tokenizer
        self.mask_id = self.tokenizer.mask_token_id
        # Smallest time the time-embedding was trained on; sampling integrates
        # t from (1 - eps) down to eps rather than exactly over [1, 0].
        self.eps = config.time_embed.min_time
        self.seed_everything(seed=42)
        if guidance:
            self.guidance = guidance
        # NOTE(review): hard-coded cluster paths — consider lifting into config.
        self.peptiverse = PeptiVersePredictor(
            manifest_path="/scratch/pranamlab/sgoel/MadSBM/src/PeptiVerse/best_models.txt",
            classifier_weight_root="/scratch/pranamlab/sgoel/MadSBM/src/PeptiVerse",
            device=self.device,
        )

    @torch.inference_mode()
    def sample(self, xt, num_steps, tracer, target_toks=None, guidance=None):
        """Run the reverse jump process and return the unmasked sequence.

        Args:
            xt: (1, L) LongTensor of token ids; masked positions hold
                ``mask_token_id``. Cloned internally, so the caller's tensor
                is untouched.
            num_steps: number of integration steps; must be >= 1.
            tracer: logger exposing ``log_step(xt=..., step_idx=...)``.
            target_toks: target token ids forwarded to PeptiVerse.
            guidance: if truthy, use affinity-guided candidate selection.

        Returns:
            ``(xt, binding_affin)``: the final sequence and its predicted
            binding affinity against ``target_toks``.
        """
        if num_steps < 1:
            # Guards against ZeroDivisionError on dt and NameError on
            # loop-carried variables (still_masked / final_logits) below.
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")

        xt = xt.clone()
        B, L = xt.shape
        assert B == 1, "Do only 1 sequence at a time"

        t_max = 1.0 - self.eps
        dt = 1.0 / num_steps
        attn_mask = torch.ones_like(xt, device=self.device)

        # Diagnostics only: per-step action values and their time integral.
        action_traj = {}
        tot_action = 0.0

        tracer.log_step(xt=xt, step_idx=0)
        converge_idx = num_steps  # first step index at which no mask remained
        converged = False

        for k in range(num_steps):
            # t decreases from 1 --> 0 as our model was trained that
            # t=1 --> noise and t=0 --> clean.
            prog = (k + 1) / float(num_steps)
            t_val = t_max - (t_max - self.eps) * prog
            t = torch.full((B,), fill_value=float(t_val), device=self.device)  # B = 1 during sampling

            # predicted control field --> B, L, V
            outs = self.model(input_ids=xt, attention_mask=attn_mask, t=t)
            u_tilt = outs['dit']
            total_logits = outs['madsbm']
            esm_logits = outs['esm']

            # Under ablation, score the action against the uniform reference
            # instead of the ESM distribution.
            if self.config.model.ablate:
                actional = self.compute_action(u_tilt, esm_logits=None)
            else:
                actional = self.compute_action(u_tilt, esm_logits=esm_logits)
            action_traj[f"action_step_{k+1}"] = actional
            tot_action += actional * dt

            # Compute jump rates and jump probs: P(jump) = 1 - exp(-rate * dt).
            r_theta = torch.exp(u_tilt * self.config.sampling.rate_scale)
            R_tot = r_theta.sum(dim=-1)  # 1, L
            # Clamp the exponent to [-40, 0] so exp() cannot underflow to a
            # denormal nor produce a "probability" above 1.
            rate = (-R_tot * self.config.sampling.jump_scale * dt).clamp(min=-40.0, max=0.0)
            jump_prob = 1.0 - torch.exp(rate)

            # Scale and filter logits with nucleus sampling.
            logits = total_logits.clone()
            logits /= self.config.sampling.tau
            logits = self.top_p_filter(logits, self.config.sampling.top_p)

            # Sample one candidate token per position (rows of probs are the
            # per-position distributions, shape (B*L, V)).
            probs = F.softmax(logits, dim=-1)
            probs = probs.view(-1, probs.size(-1))
            draw = torch.multinomial(probs, 1)
            candidate_toks = draw.view(B, L)

            # Only still-masked positions that win the jump coin-flip change.
            rand = torch.rand(B, L, device=self.device)
            can_jump = rand < jump_prob
            updatable = can_jump & self.is_masked(xt)

            # Update the sequence.
            if guidance:
                chosen_candidate = self.binding_guidance(probs, target_toks, B, L)
                xt[updatable] = chosen_candidate[updatable]
            else:
                xt[updatable] = candidate_toks[updatable]

            tracer.log_step(xt=xt, step_idx=k + 1)

            if k == num_steps - 1:
                final_logits = total_logits
            still_masked = self.is_masked(xt)
            if not converged and not still_masked.any():
                converge_idx = k + 1
                converged = True

        # Copy over remaining tokens greedily from the last step's logits.
        if still_masked.any():
            final_toks = final_logits.argmax(dim=-1)
            xt[still_masked] = final_toks[still_masked]
        tracer.log_step(xt=xt, step_idx=num_steps + 1)

        binding_affin = self.peptiverse.predict_binding_affinity(
            mode='wt',
            target_ids=target_toks,
            binder_ids=xt,
        )['affinity']
        return xt, binding_affin

    def binding_guidance(self, probs, target_toks, B, L):
        """Pick one of M candidate sequences, weighted by predicted affinity.

        Draws M full sequences from ``probs`` (rows are per-position
        distributions of shape (B*L, V)), scores each with PeptiVerse, and
        samples one with probability ``softmax(affinity / tau)`` — higher
        predicted affinity is favored.
        """
        M = self.config.sampling.M
        candidate_toks = [torch.multinomial(probs, 1).view(B, L) for _ in range(M)]

        affinities = []
        for toks in candidate_toks:
            # presumably 'affinity' is a scalar score; verify against
            # PeptiVersePredictor.predict_binding_affinity.
            pred = self.peptiverse.predict_binding_affinity(
                mode='wt',
                target_ids=target_toks,
                binder_ids=toks.detach(),
            )['affinity']
            affinities.append(pred)

        affinities = torch.tensor(affinities, dtype=torch.float32)
        weights = F.softmax(affinities / self.config.sampling.tau, dim=0)
        chosen_idx = torch.multinomial(weights, 1).item()
        return candidate_toks[chosen_idx]

    def compute_action(self, u_tilt, esm_logits=None):
        """Computes the action functional for evals.

        psi(u) = exp(u) - u - 1 is averaged under a reference distribution R0:
        the ESM softmax when ``esm_logits`` is given, otherwise the uniform
        distribution 1/V. R0 sums to 1 over the vocab in both cases.
        """
        if esm_logits is not None:
            R0 = torch.softmax(esm_logits, dim=-1)
        else:
            R0 = 1.0 / self.tokenizer.vocab_size
        psi_u = torch.exp(u_tilt) - u_tilt - 1.0
        action_per_tok = (R0 * psi_u).sum(dim=-1)  # R0 goes to 1 in both cases
        return action_per_tok.mean().item()

    def top_p_filter(self, logits, p_val):
        """Implementation of nucleus / top-p sampling.

        Masks out tokens that contribute to the bottom (1 - p) cumulative
        probability by setting their logits to -inf. Note: mutates ``logits``
        in place (callers pass a clone).
        """
        # Sort logits and get cumulative probabilities.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cum prob > p-val thresh.
        sorted_idx_to_remove = cum_probs > p_val
        # Shift the indices to the right to keep also the first token above
        # the threshold.
        sorted_idx_to_remove[..., 1:] = sorted_idx_to_remove[..., :-1].clone()
        sorted_idx_to_remove[..., 0] = 0

        # Map the removal mask from sorted order back to vocab order.
        idx_to_remove = sorted_idx_to_remove.scatter(-1, sorted_indices, sorted_idx_to_remove)
        logits[idx_to_remove] = float('-inf')
        return logits

    def is_masked(self, xt):
        """Boolean mask of positions still holding the mask token."""
        return (xt == self.mask_id)

    def seed_everything(self, seed):
        """Seed python, numpy and torch RNGs; no-op when seed is None."""
        if seed is None:
            return
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)  # if using multi-GPU
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False