import random

import numpy as np
import torch
import torch.nn.functional as F

from src.PeptiVerse.inference import PeptiVersePredictor
from src.utils.model_utils import _print


class MadSBMSampler:
    """Generates sequences from a trained MadSBM model by iteratively unmasking
    tokens along the learned jump process, with optional PeptiVerse
    binding-affinity guidance."""

    def __init__(self, model, config, device, guidance=None):
        self.config = config
        self.device = device
        self.model = model
        self.tokenizer = model.tokenizer
        self.mask_id = self.tokenizer.mask_token_id
        self.eps = config.time_embed.min_time
        self.seed_everything(seed=42)

        # Guidance is optional; keep the predictor handle as None when guidance
        # is disabled so later uses can be guarded explicitly.
        self.guidance = guidance
        self.peptiverse = None
        if guidance:
            self.peptiverse = PeptiVersePredictor(
                manifest_path="/scratch/pranamlab/sgoel/MadSBM/src/PeptiVerse/best_models.txt",
                classifier_weight_root="/scratch/pranamlab/sgoel/MadSBM/src/PeptiVerse",
                device=self.device,
            )

    @torch.inference_mode()
    def sample(self, xt, num_steps, tracer, target_toks=None, guidance=None):
        """Runs num_steps of sampling, sweeping t from 1 - eps down to eps, and
        returns the final sequence together with its predicted binding affinity
        (None when no PeptiVerse predictor was initialised)."""
        xt = xt.clone()
        B, L = xt.shape
        assert B == 1, "The sampler currently supports only one sequence at a time"

        t_max = 1.0 - self.eps
        dt = 1.0 / num_steps
        attn_mask = torch.ones_like(xt, device=self.device)

        action_traj = {}
        tot_action = 0.0

        tracer.log_step(xt=xt, step_idx=0)

        converge_idx = num_steps
        converged = False

        for k in range(num_steps):
            # Time sweeps from t_max down to eps as sampling progresses.
            prog = (k + 1) / float(num_steps)
            t_val = t_max - (t_max - self.eps) * prog
            t = torch.full((B,), fill_value=float(t_val), device=self.device)

            outs = self.model(input_ids=xt, attention_mask=attn_mask, t=t)

            u_tilt = outs['dit']           # rate tilts from the DiT head
            total_logits = outs['madsbm']  # denoising logits used for candidate sampling
            esm_logits = outs['esm']       # ESM head logits (reference model)

            # Action functional, accumulated for evaluation only; the ablation
            # replaces the ESM reference with a uniform one.
            if self.config.model.ablate:
                actional = self.compute_action(u_tilt, esm_logits=None)
            else:
                actional = self.compute_action(u_tilt, esm_logits=esm_logits)

            action_traj[f"action_step_{k+1}"] = actional
            tot_action += actional * dt

            # Per-position probability of at least one jump during this step:
            # 1 - exp(-R_tot * jump_scale * dt), with the exponent clamped for
            # numerical stability.
            r_theta = torch.exp(u_tilt * self.config.sampling.rate_scale)
            R_tot = r_theta.sum(dim=-1)
            rate = (-R_tot * self.config.sampling.jump_scale * dt).clamp(min=-40.0, max=0.0)
            jump_prob = 1.0 - torch.exp(rate)

            # Temperature scaling and nucleus filtering before drawing a
            # candidate token for every position.
            logits = total_logits.clone()
            logits /= self.config.sampling.tau
            logits = self.top_p_filter(logits, self.config.sampling.top_p)

            probs = F.softmax(logits, dim=-1)
            probs = probs.view(-1, probs.size(-1))
            sample = torch.multinomial(probs, 1)
            candidate_toks = sample.view(B, L)

            # Only positions that are still masked and that fire a jump in this
            # step are updated.
            rand = torch.rand(B, L, device=self.device)
            can_jump = (rand < jump_prob)
            updatable = can_jump & self.is_masked(xt)

            if guidance:
                chosen_candidate = self.binding_guidance(probs, target_toks, B, L)
                xt[updatable] = chosen_candidate[updatable]
            else:
                xt[updatable] = candidate_toks[updatable]

            tracer.log_step(xt=xt, step_idx=k + 1)

            # Keep the last step's logits to resolve anything that never
            # unmasked, and record when the sequence first became fully unmasked.
            if k == num_steps - 1:
                final_logits = total_logits
                still_masked = self.is_masked(xt)

            if not converged and not self.is_masked(xt).any():
                converge_idx = k + 1
                converged = True

        # After the loop: greedily fill any positions that are still masked
        # using the argmax of the final-step logits.
        if still_masked.any():
            final_toks = final_logits.argmax(dim=-1)
            xt[still_masked] = final_toks[still_masked]

        tracer.log_step(xt=xt, step_idx=num_steps + 1)

        # Score the final sequence with PeptiVerse when a predictor is available.
        binding_affin = None
        if self.peptiverse is not None:
            binding_affin = self.peptiverse.predict_binding_affinity(
                mode='wt',
                target_ids=target_toks,
                binder_ids=xt,
            )['affinity']

        return xt, binding_affin

    def binding_guidance(self, probs, target_toks, B, L):
        """Best-of-M guidance: draws M candidate sequences from probs, scores
        each with PeptiVerse, and picks one with probability proportional to a
        softmax over the predicted affinities."""
        M = self.config.sampling.M
        candidate_toks = []
        affinities = []

        for _ in range(M):
            ith_sample = torch.multinomial(probs, 1).view(B, L)
            candidate_toks.append(ith_sample)

        for toks in candidate_toks:
            pred = self.peptiverse.predict_binding_affinity(
                mode='wt',
                target_ids=target_toks,
                binder_ids=toks.detach(),
            )['affinity']
            affinities.append(pred)

        affinities = torch.tensor(affinities, dtype=torch.float32)
        weights = F.softmax(affinities / self.config.sampling.tau, dim=0)
        chosen_idx = torch.multinomial(weights, 1).item()

        return candidate_toks[chosen_idx]

    def compute_action(self, u_tilt, esm_logits=None):
        """Computes the action functional used for evals:
        sum_j R0_j * psi(u_j) with psi(u) = exp(u) - u - 1, where R0 is the
        ESM reference distribution or, if esm_logits is None, a uniform
        reference over the vocabulary."""
        if esm_logits is not None:
            R0 = torch.softmax(esm_logits, dim=-1)
        else:
            R0 = 1.0 / self.tokenizer.vocab_size

        psi_u = torch.exp(u_tilt) - u_tilt - 1.0
        action_per_tok = (R0 * psi_u).sum(dim=-1)

        return action_per_tok.mean().item()
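
    # Illustrative note: psi(u) = exp(u) - u - 1 is convex with psi(0) = 0, so
    # the action is non-negative and vanishes only where the tilt u is zero
    # (at positions where R0 > 0).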

    def top_p_filter(self, logits, p_val):
        """
        Nucleus (top-p) filtering.
        Masks out tokens in the low-probability tail that falls outside the
        smallest set whose cumulative probability exceeds p_val.
        """
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_idx_to_remove = cum_probs > p_val

        # Shift right so the first token that crosses the threshold is kept.
        sorted_idx_to_remove[..., 1:] = sorted_idx_to_remove[..., :-1].clone()
        sorted_idx_to_remove[..., 0] = 0

        # Map the removal mask back from sorted order to vocabulary order.
        idx_to_remove = sorted_idx_to_remove.scatter(-1, sorted_indices, sorted_idx_to_remove)
        logits[idx_to_remove] = float('-inf')
        return logits
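
    # Example (illustrative): with p_val = 0.9 and sorted probabilities
    # [0.5, 0.3, 0.15, 0.05], the cumulative sums are [0.5, 0.8, 0.95, 1.0];
    # after the right-shift only the 0.05 token is masked to -inf, so the top
    # three tokens remain available for sampling.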

    def is_masked(self, xt):
        return (xt == self.mask_id)

    def seed_everything(self, seed):
        if seed is None:
            return
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
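

# Minimal usage sketch (illustrative only; the model, config, tracer, and
# target_ids objects are assumed to come from the rest of the MadSBM codebase
# and are not defined in this file):
#
#     sampler = MadSBMSampler(model, config, device="cuda", guidance=True)
#     xt = torch.full((1, seq_len), sampler.mask_id, dtype=torch.long, device="cuda")
#     seq, affinity = sampler.sample(
#         xt, num_steps=128, tracer=tracer, target_toks=target_ids, guidance=True,
#     )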