# MadSBM / src/sampling/madsbm_sampler.py
# Shrey Goel
# initial commit
# 94c2704
import random
import torch
import numpy as np
import torch.nn.functional as F
from src.PeptiVerse.inference import PeptiVersePredictor
from src.utils.model_utils import _print
class MadSBMSampler:
    """Step-wise sampler that fills masked tokens using a MadSBM model.

    Each step queries the model for a control field (``'dit'``), token logits
    (``'madsbm'``) and reference logits (``'esm'``). The control field is turned
    into per-position jump probabilities. Masked positions whose jump fires are
    filled with tokens drawn from the temperature-scaled, nucleus-filtered
    logits, optionally re-ranked by PeptiVerse binding affinity. Positions still
    masked after the last step are filled with the argmax of the final logits.
    """

    # Default PeptiVerse artefact locations (cluster-specific; override via __init__).
    DEFAULT_MANIFEST_PATH = "/scratch/pranamlab/sgoel/MadSBM/src/PeptiVerse/best_models.txt"
    DEFAULT_CLASSIFIER_WEIGHT_ROOT = "/scratch/pranamlab/sgoel/MadSBM/src/PeptiVerse"

    def __init__(
        self,
        model,
        config,
        device,
        guidance=None,
        *,
        seed=42,
        manifest_path=DEFAULT_MANIFEST_PATH,
        classifier_weight_root=DEFAULT_CLASSIFIER_WEIGHT_ROOT,
    ):
        """Store the model/config, seed all RNGs and load the PeptiVerse predictor.

        Args:
            model: Callable model exposing ``.tokenizer`` (with ``mask_token_id``).
            config: Config providing ``time_embed.min_time``, ``model.ablate`` and
                the ``sampling`` hyper-parameters read by ``sample``.
            device: Device on which sampling tensors are created.
            guidance: Stored as ``self.guidance``. Note that ``sample`` uses its own
                ``guidance`` argument, not this attribute.
            seed: Seed passed to ``seed_everything``; ``None`` skips seeding.
            manifest_path: PeptiVerse manifest file.
            classifier_weight_root: Directory holding PeptiVerse classifier weights.
        """
        self.config = config
        self.device = device
        self.model = model
        self.tokenizer = model.tokenizer
        self.mask_id = self.tokenizer.mask_token_id
        self.eps = config.time_embed.min_time
        self.seed_everything(seed=seed)
        # Always define the attribute so later reads cannot raise AttributeError.
        self.guidance = guidance
        self.peptiverse = PeptiVerseEPredictor if False else PeptiVersePredictor(
            manifest_path=manifest_path,
            classifier_weight_root=classifier_weight_root,
            device=self.device,
        )
        # Diagnostics (action trajectory, total action, convergence step) from the
        # most recent call to ``sample``; None until ``sample`` has run.
        self.last_sample_stats = None

    @torch.inference_mode()
    def sample(self, xt, num_steps, tracer, target_toks=None, guidance=None):
        """Denoise ``xt`` over ``num_steps`` steps and score the final sequence.

        Args:
            xt: Token-id tensor of shape ``(1, L)``. Positions equal to the mask id
                are filled in. The caller's tensor is not modified (it is cloned).
            num_steps: Number of denoising steps; must be >= 1.
            tracer: Object with ``log_step(xt, step_idx)``. It is called for the
                initial sequence (step 0), after every step (1..num_steps) and
                after the final fill-in (num_steps + 1).
            target_toks: Target ids forwarded to PeptiVerse, used for guidance and
                for the returned affinity.
            guidance: If truthy, each step's proposal is chosen by
                ``binding_guidance``.

        Returns:
            ``(xt, binding_affinity)``: the completed sequence and PeptiVerse's
            ``'affinity'`` prediction for it.

        Raises:
            ValueError: if ``xt`` does not have exactly one row or ``num_steps < 1``.
        """
        xt = xt.clone()
        B, L = xt.shape
        if B != 1:
            raise ValueError("Do only 1 sequence at a time")
        if num_steps < 1:
            # Without at least one step there are no logits to fill from (and dt is 1/0).
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")

        t_max = 1.0 - self.eps
        dt = 1.0 / num_steps
        attn_mask = torch.ones_like(xt, device=self.device)

        action_traj = {}
        tot_action = 0.0
        converge_idx = None  # first step after which no position is masked
        still_masked = self.is_masked(xt)
        tracer.log_step(xt=xt, step_idx=0)

        for k in range(num_steps):
            # t decreases from 1 --> 0 because the model was trained with
            # t=1 --> noise and t=0 --> clean.
            prog = (k + 1) / float(num_steps)
            t_val = t_max - (t_max - self.eps) * prog
            t = torch.full((B,), fill_value=float(t_val), device=self.device)

            outs = self.model(input_ids=xt, attention_mask=attn_mask, t=t)
            u_tilt = outs['dit']            # predicted control field (presumably B, L, V)
            total_logits = outs['madsbm']
            # Ablation: compute the action against a uniform reference instead of ESM.
            esm_logits = None if self.config.model.ablate else outs['esm']

            actional = self.compute_action(u_tilt, esm_logits=esm_logits)
            action_traj[f"action_step_{k+1}"] = actional
            tot_action += actional * dt

            jump_prob = self._jump_probabilities(u_tilt, dt)
            probs = self._nucleus_probs(total_logits)   # (B * L, V)
            # This draw happens even when guidance overrides it, so the RNG stream
            # (and therefore seeded results) stays the same as before.
            candidate_toks = torch.multinomial(probs, 1).view(B, L)

            # Only masked positions whose jump fires this step are updated.
            can_jump = torch.rand(B, L, device=self.device) < jump_prob
            updatable = can_jump & self.is_masked(xt)

            if guidance:
                candidate_toks = self.binding_guidance(probs, target_toks, B, L)
            xt[updatable] = candidate_toks[updatable]
            tracer.log_step(xt=xt, step_idx=k + 1)

            still_masked = self.is_masked(xt)
            if converge_idx is None and not still_masked.any():
                converge_idx = k + 1

        if converge_idx is None:
            converge_idx = num_steps

        # Fill positions that never jumped with the argmax of the last step's logits.
        if still_masked.any():
            final_toks = total_logits.argmax(dim=-1)
            xt[still_masked] = final_toks[still_masked]
        tracer.log_step(xt, num_steps + 1)

        self.last_sample_stats = {
            "action_traj": action_traj,
            "tot_action": tot_action,
            "converge_idx": converge_idx,
        }

        binding_affin = self.peptiverse.predict_binding_affinity(
            mode='wt',
            target_ids=target_toks,
            binder_ids=xt,
        )['affinity']
        return xt, binding_affin

    def _jump_probabilities(self, u_tilt, dt):
        """Return the per-position jump probability ``1 - exp(-R_tot * jump_scale * dt)``.

        Here ``R_tot`` sums ``exp(u_tilt * rate_scale)`` over the last dimension.
        The exponent is clamped to ``[-40, 0]``, which keeps the result in [0, 1].
        """
        r_theta = torch.exp(u_tilt * self.config.sampling.rate_scale)
        R_tot = r_theta.sum(dim=-1)
        rate = (-R_tot * self.config.sampling.jump_scale * dt).clamp(min=-40.0, max=0.0)
        return 1.0 - torch.exp(rate)

    def _nucleus_probs(self, total_logits):
        """Temperature-scale and top-p filter the logits; return flattened probabilities.

        The result has shape ``(-1, V)``, one row per position. ``total_logits``
        itself is not modified.
        """
        logits = total_logits / self.config.sampling.tau
        logits = self.top_p_filter(logits, self.config.sampling.top_p)
        probs = F.softmax(logits, dim=-1)
        return probs.view(-1, probs.size(-1))

    def binding_guidance(self, probs, target_toks, B, L):
        """Draw ``config.sampling.M`` proposals and pick one weighted by predicted affinity.

        Each proposal resamples every position from ``probs`` (rows = positions).
        PeptiVerse scores each proposal against ``target_toks``. One proposal is
        then drawn with probability ``softmax(affinity / config.sampling.tau)``,
        so higher predicted affinity is favoured.

        NOTE(review): proposals are scored as full resamples of all positions, not
        as the sequence that results once only the updatable positions are copied.

        Returns:
            The chosen token tensor of shape ``(B, L)``.

        Raises:
            ValueError: if ``config.sampling.M < 1``.
        """
        M = self.config.sampling.M
        if M < 1:
            raise ValueError(f"config.sampling.M must be >= 1, got {M}")
        # Draw all proposals before scoring any of them (same order as before).
        candidate_toks = [torch.multinomial(probs, 1).view(B, L) for _ in range(M)]
        affinities = torch.tensor(
            [
                self.peptiverse.predict_binding_affinity(
                    mode='wt',
                    target_ids=target_toks,
                    binder_ids=toks.detach(),
                )['affinity']
                for toks in candidate_toks
            ],
            dtype=torch.float32,
        )
        weights = F.softmax(affinities / self.config.sampling.tau, dim=0)
        chosen_idx = torch.multinomial(weights, 1).item()
        return candidate_toks[chosen_idx]

    def compute_action(self, u_tilt, esm_logits=None):
        """Return the mean per-token action functional (used for evaluation only).

        ``psi(u) = exp(u) - u - 1`` is weighted by a reference distribution ``R0``
        over the last dimension. ``R0`` is the softmax of ``esm_logits`` when
        given, otherwise the uniform ``1 / vocab_size``. The weighted sum is
        averaged over all remaining positions.
        """
        if esm_logits is not None:
            R0 = torch.softmax(esm_logits, dim=-1)
        else:
            R0 = 1.0 / self.tokenizer.vocab_size
        psi_u = torch.exp(u_tilt) - u_tilt - 1.0
        action_per_tok = (R0 * psi_u).sum(dim=-1)  # R0 sums to 1 in both cases
        return action_per_tok.mean().item()

    def top_p_filter(self, logits, p_val):
        """Nucleus / top-p filtering along the last dimension.

        Sets to ``-inf`` the logits of tokens outside the smallest highest-probability
        set whose cumulative probability exceeds ``p_val``. The top token is always
        kept. ``logits`` is modified in place and also returned.
        """
        if p_val >= 1.0:
            # Keep everything; float rounding in cumsum could otherwise prune tail tokens.
            return logits
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens whose cumulative probability exceeds the threshold...
        sorted_idx_to_remove = cum_probs > p_val
        # ...shifted right by one, so the first token crossing the threshold is kept.
        sorted_idx_to_remove[..., 1:] = sorted_idx_to_remove[..., :-1].clone()
        sorted_idx_to_remove[..., 0] = 0
        # Map the mask from sorted order back to vocabulary order.
        idx_to_remove = sorted_idx_to_remove.scatter(-1, sorted_indices, sorted_idx_to_remove)
        logits[idx_to_remove] = float('-inf')
        return logits

    def is_masked(self, xt):
        """Return a boolean tensor marking positions equal to the mask token id."""
        return (xt == self.mask_id)

    def seed_everything(self, seed):
        """Seed Python, NumPy and torch (plus CUDA/cuDNN determinism); ``None`` is a no-op."""
        if seed is None:
            return
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)  # if using multi-GPU
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False