| | import torch |
| | import math |
| | import sys |
| |
|
| | import torch.nn.functional as F |
| | import pandas as pd |
| | import numpy as np |
| |
|
| | from omegaconf import OmegaConf |
| | from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer |
| |
|
| | from src.lm.memdlm.diffusion_module import MembraneFlow |
| | from src.lm.dplm.diffusion_module import DPLM |
| | from src.utils.model_utils import get_latents, _print |
| | from src.sampling.unconditional_sampler import UnconditionalSampler |
| | from src.lm.dplm.unconditional_sampler import UnconditionalSampler as DPLMUnconditionalSampler |
| |
|
import os

# Model/sampling configuration shared by the infill helpers in this module
# (they read ``config.sampling.n_steps``). The path defaults to the original
# location; set the MEMDLM_LM_CONFIG environment variable to load a different
# file without editing code.
config = OmegaConf.load(
    os.environ.get("MEMDLM_LM_CONFIG", "/home/a03-sgoel/MeMDLM_v2/src/configs/lm.yaml")
)
| |
|
| | |
def mask_for_de_novo(sequence_length, mask_token="<mask>"):
    """Return a fully masked sequence made of ``sequence_length`` mask tokens.

    Args:
        sequence_length: Number of positions to mask.
        mask_token: Token string repeated once per position. Defaults to
            ``"<mask>"``; pass a different token for tokenizers that use
            another mask string (mirrors ``mask_for_scaffold``'s argument).

    Returns:
        The concatenated mask string (empty when ``sequence_length`` <= 0).
    """
    return mask_token * sequence_length
| |
|
def mask_for_scaffold(sequence, generate_type, mask_token):
    """Mask one case-marked region of ``sequence`` for infilling.

    Residue case marks the region to regenerate:

    * ``"uppercase"``: uppercase characters become ``mask_token``; every other
      character is kept but passed through ``str.upper``.
    * ``"lowercase"``: lowercase characters become ``mask_token``; every other
      character is kept unchanged.

    Any other ``generate_type`` returns ``sequence`` as-is.
    """
    if generate_type == "uppercase":
        return "".join(
            mask_token if ch.isupper() else ch.upper()
            for ch in sequence
        )
    if generate_type == "lowercase":
        return "".join(
            mask_token if ch.islower() else ch
            for ch in sequence
        )
    # NOTE(review): an unrecognised generate_type falls through silently,
    # so a typo yields an unmasked sequence.
    return sequence
| |
|
| |
|
| | |
def memflow_infill_uncond(masked_seq, tokenizer, model: MembraneFlow):
    """Fill the mask tokens in ``masked_seq`` with the MembraneFlow sampler.

    Runs ``UnconditionalSampler.sample_unconditional`` for
    ``config.sampling.n_steps`` steps, decodes the first returned item,
    removes spaces, and strips five characters from each end of the decoded
    string (presumably the boundary special tokens; confirm against the
    tokenizer).
    """
    sampler = UnconditionalSampler(tokenizer, model)
    encoded = tokenizer(masked_seq, return_tensors='pt')
    input_ids = encoded['input_ids'].to(model.device)

    sampled = sampler.sample_unconditional(input_ids, config.sampling.n_steps)
    token_ids = sampled[0].squeeze()

    decoded = tokenizer.decode(token_ids).replace(" ", "")
    return decoded[5:-5]
| |
|
| |
|
def evodiff_infill(motif_seq, tokenizer, model, device, batch_size=1):
    """Infill the lowercase positions of ``motif_seq`` with an EvoDiff model.

    Follows the EvoDiff example notebook
    https://github.com/microsoft/evodiff/blob/main/examples/evodiff.ipynb

    Lowercase characters are replaced by ``"#"`` before tokenization. Every
    position whose token equals ``tokenizer.mask_id`` is then filled one at a
    time, in random order, by sampling from the softmax over the first
    ``len(tokenizer.all_aas) - 6`` vocabulary entries.

    Args:
        motif_seq: Sequence whose lowercase characters mark positions to
            regenerate; all other characters are kept.
        tokenizer: EvoDiff tokenizer (``tokenize``, ``untokenize``,
            ``mask_id`` and ``all_aas`` are used).
        model: Callable ``model(tokens, timestep)`` returning logits indexed as
            ``[batch, position, vocab]``.
        device: Device the token and timestep tensors are placed on.
        batch_size: Number of samples drawn in parallel; only the first is
            returned.

    Returns:
        The infilled sequence (first sample) as a string.
    """
    masked = "".join("#" if aa.islower() else aa for aa in motif_seq)
    tokens = torch.as_tensor(tokenizer.tokenize([masked])).long().to(device)

    # Positions to fill, visited in random order.
    positions = torch.nonzero(tokens == tokenizer.mask_id).flatten().cpu().numpy()
    np.random.shuffle(positions)

    # One row per sample so the batch dimension agrees with ``timestep``.
    sample = tokens.unsqueeze(0).repeat(batch_size, 1)
    # Loop-invariant: the timestep is zero at every step.
    timestep = torch.zeros(batch_size, dtype=torch.long, device=device)
    # Sample only from the leading vocabulary entries, excluding the last six
    # entries of ``all_aas`` as in the EvoDiff example.
    n_allowed = len(tokenizer.all_aas) - 6

    with torch.no_grad():
        for i in positions:
            logits = model(sample, timestep)[:, i, :n_allowed]
            probs = F.softmax(logits, dim=1)
            sample[:, i] = torch.multinomial(probs, num_samples=1).squeeze(-1)

    return tokenizer.untokenize(sample[0])
| |
|
| |
|
def dplm_infill(masked_seq, tokenizer, model: DPLM, device):
    """Fill the mask tokens in ``masked_seq`` with the DPLM sampler.

    Runs ``DPLMUnconditionalSampler.sample_unconditional`` for
    ``config.sampling.n_steps`` steps, decodes the first returned item,
    removes spaces, and strips five characters from each end of the decoded
    string (presumably the boundary special tokens; confirm against the
    tokenizer).

    ``device`` is accepted for interface compatibility but not used; inputs
    are moved to ``model.device``.
    """
    sampler = DPLMUnconditionalSampler(tokenizer, model)
    encoded = tokenizer(masked_seq, return_tensors='pt')
    input_ids = encoded['input_ids'].to(model.device)

    sampled = sampler.sample_unconditional(input_ids, config.sampling.n_steps)
    token_ids = sampled[0].squeeze()

    decoded = tokenizer.decode(token_ids).replace(" ", "")
    return decoded[5:-5]
| |
|
| |
|
| | |
def calc_progen_ppl(model, tokenizer, target, device, fp16=True):
    """Return the causal-LM perplexity of ``target`` under ``model``.

    Perplexity is ``exp`` of the mean next-token cross-entropy: the logits at
    position ``t`` are scored against the token at ``t + 1``.

    Args:
        model: Causal LM called as ``model(input_ids=..., attention_mask=...)``
            whose output has a ``.logits`` tensor with the vocabulary on the
            last axis.
        tokenizer: Unused; kept for interface compatibility.
        target: Token ids, either 1-D ``(seq_len,)`` or batched
            ``(batch, seq_len)``, already on the model's device.
        device: Unused; kept for interface compatibility.
        fp16: Run the forward pass under CUDA autocast.

    Returns:
        Perplexity as a Python float.
    """
    with torch.no_grad():
        # Same as the deprecated torch.cuda.amp.autocast(enabled=fp16).
        with torch.autocast(device_type="cuda", enabled=fp16):
            logits = model(
                input_ids=target,
                attention_mask=torch.ones_like(target),
            ).logits

        # Shift along the sequence axis (second-to-last for logits, last for
        # ids) so both 1-D and batched inputs work; slicing the first axis
        # would drop the batch row instead of the last position for 2-D input.
        shifted_logits = logits[..., :-1, :].float()
        shifted_target = target[..., 1:]
        loss = F.cross_entropy(
            shifted_logits.reshape(-1, shifted_logits.size(-1)),
            shifted_target.reshape(-1),
            reduction="mean",
        )
    return torch.exp(loss).item()
| |
|
| |
|
def calc_ppl(model, tokenizer, generated_sequence, mask_token_indices, model_type):
    """Return the masked-LM pseudo-perplexity of ``generated_sequence``.

    Each index in ``mask_token_indices`` is masked in turn and scored on its
    own; the result is ``exp`` of the mean cross-entropy over those positions.

    Args:
        model: For ``model_type == 'esm'``, called as ``model(ids, labels=...)``
            and must return an object with ``.loss``. For ``'flow'``,
            ``model.forward(ids, attention_mask=...)`` must return logits with
            the vocabulary on the last axis. Both need ``model.device``.
        tokenizer: Provides ``encode(..., return_tensors='pt')`` and
            ``mask_token_id``.
        generated_sequence: Sequence to score.
        mask_token_indices: Positions in the encoded ids (which may include
            special tokens added by the tokenizer) to score.
        model_type: ``'esm'`` or ``'flow'``.

    Returns:
        Pseudo-perplexity as a float; 1.0 when there are no positions to score.

    Raises:
        ValueError: If ``model_type`` is not ``'esm'`` or ``'flow'``.
    """
    if model_type not in ("esm", "flow"):
        raise ValueError(f"Unknown model_type {model_type!r}; expected 'esm' or 'flow'.")

    indices = list(mask_token_indices)
    tensor_input = tokenizer.encode(generated_sequence, return_tensors='pt').to(model.device)
    attn_mask = torch.ones_like(tensor_input).to(model.device)

    total_loss = 0.0
    for i in indices:
        masked_input = tensor_input.clone()
        masked_input[0, i] = tokenizer.mask_token_id

        # Only position i carries a label; -100 entries are ignored by the loss.
        labels = torch.full(tensor_input.shape, -100).to(model.device)
        labels[0, i] = tensor_input[0, i]

        with torch.no_grad():
            if model_type == 'esm':
                loss = model(masked_input, labels=labels).loss.item()
            else:
                logits = model.forward(masked_input, attention_mask=attn_mask)
                # Batch size is 1, so flattened index i is sequence position i.
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    labels.view(-1),
                    reduction='none',
                    ignore_index=-100,
                )[i].item()
        total_loss += loss

    if not indices:
        # Nothing was scored: exp(0) == 1.
        return 1.0
    # Average over the scored positions, not the full sequence length.
    return math.exp(total_loss / len(indices))
| |
|
| |
|
def calc_blosum_score(og_seq, gen_seq, indices, matrix=None):
    """Return the mean substitution score between two sequences at ``indices``.

    Args:
        og_seq: Reference sequence.
        gen_seq: Generated sequence, compared at the same positions.
        indices: Positions to compare (any iterable of ints).
        matrix: Optional nested mapping with ``matrix[a][b]`` giving the score
            for residues ``a`` and ``b``. When omitted, ``blosum.BLOSUM(62)``
            is loaded.

    Returns:
        Mean score over ``indices``. A residue pair that raises ``KeyError``
        scores -4. Returns 0 when ``indices`` is empty.
    """
    # Materialise once: supports generators, and avoids the ambiguous
    # truth value of numpy arrays / tensors.
    indices = list(indices)
    if not indices:
        return 0
    if matrix is None:
        import blosum as bl
        matrix = bl.BLOSUM(62)

    total = 0
    for i in indices:
        # Residues are uppercased because this module marks scaffold regions
        # with lowercase letters; matrix lookups are assumed to use uppercase
        # residue codes.
        og_res, gen_res = og_seq[i].upper(), gen_seq[i].upper()
        try:
            total += matrix[og_res][gen_res]
        except KeyError:
            # Pair not present in the matrix: fixed penalty.
            total += -4
    return total / len(indices)
| |
|
| |
|
def calc_cos_sim(original_sequence, generated_sequence, tokenizer, esm_model, device):
    """Return the mean cosine similarity between embeddings of two sequences.

    The original sequence is uppercased before embedding. Both sequences are
    embedded with ``get_latents``; similarity is computed along the last axis
    and then averaged into a single Python float.
    """
    reference = original_sequence.upper()
    ref_latents = get_latents(esm_model, tokenizer, reference, device)
    gen_latents = get_latents(esm_model, tokenizer, generated_sequence, device)
    similarities = torch.nn.functional.cosine_similarity(ref_latents, gen_latents, dim=-1)
    return similarities.mean().item()