import sys
import torch
import math

import numpy as np
import torch.nn.functional as F

from collections import Counter
from omegaconf import OmegaConf


config = OmegaConf.load("/scratch/pranamlab/sgoel/MeMDLM_v2/src/configs/lm.yaml")


# -------# Masking #-------- #
def mask_for_de_novo(sequence_length):
    """Return a fully masked sequence of the given length for de novo generation."""
    return "<mask>" * sequence_length

def mask_for_scaffold(sequence, generate_type, mask_token):
    """Mask either the uppercase or the lowercase residues of a scaffold sequence."""
    if generate_type == "uppercase":
        sequence = ''.join([mask_token if residue.isupper() else residue.upper() for residue in sequence])
    elif generate_type == "lowercase":
        sequence = ''.join([mask_token if residue.islower() else residue for residue in sequence])
    return sequence
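
# Usage sketch for the masking helpers (illustrative values; "<mask>" stands in
# for whatever mask token the downstream tokenizer uses):
#
#   mask_for_de_novo(3)                               -> "<mask><mask><mask>"
#   mask_for_scaffold("aaRKa", "uppercase", "<mask>") -> "AA<mask><mask>A"
#   mask_for_scaffold("aaRKa", "lowercase", "<mask>") -> "<mask><mask>RK<mask>"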


# -------# Generation #-------- #
def evodiff_infill(motif_seq, tokenizer, model, device, batch_size=1):
    """
    Infill the masked (lowercase) positions of a motif, following the evodiff example:
    https://github.com/microsoft/evodiff/blob/main/examples/evodiff.ipynb
    """
    # Manually mask the positions to infill ("#" is the mask token in the evodiff tokenizer)
    motif_seq = ''.join(["#" if aa.islower() else aa for aa in motif_seq])
    tkns = tokenizer.tokenize([motif_seq])
    sample = torch.as_tensor(tkns).to(device)

    # Collect the masked positions and decode them in random order
    loc = torch.arange(0, len(motif_seq)).to(device)[sample == tokenizer.mask_id].cpu().numpy()
    np.random.shuffle(loc)

    sample = sample.unsqueeze(0)  # add a batch dimension

    with torch.no_grad():
        for i in loc:
            timestep = torch.tensor([0] * batch_size).to(device)  # placeholder; not used by the model
            prediction = model(sample, timestep)
            p = prediction[:, i, :len(tokenizer.all_aas) - 6]  # only canonical amino acids
            p = F.softmax(p, dim=1)  # softmax over logits
            p_sample = torch.multinomial(p, num_samples=1)  # sample from the categorical distribution
            sample[:, i] = p_sample.squeeze()
    output = [tokenizer.untokenize(s) for s in sample]
    return output[0]
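
# Usage sketch (hypothetical setup; the OA_DM_38M loader comes from the evodiff
# README and returns model, collater, tokenizer, scheme):
#
#   from evodiff.pretrained import OA_DM_38M
#   model, collater, tokenizer, scheme = OA_DM_38M()
#   model = model.to("cuda").eval()
#   filled = evodiff_infill("MKTayiAKQR", tokenizer, model, device="cuda")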


def dplm_infill(masked_seq, tokenizer, model, device):
    """Denoise a pre-masked sequence with the DPLM unconditional sampler."""
    from src.lm.dplm.unconditional_sampler import UnconditionalSampler as DPLMUnconditionalSampler

    generator = DPLMUnconditionalSampler(tokenizer, model)
    xt = tokenizer(masked_seq, return_tensors='pt')['input_ids'].to(model.device)
    denoised_tokens = generator.sample_unconditional(xt, config.sampling.n_steps)[0].squeeze()
    # Drop spaces and the leading/trailing special tokens (assumed to be the
    # 5-character "<cls>"/"<eos>" markers) from the decoded string
    generated_sequence = tokenizer.decode(denoised_tokens).replace(" ", "")[5:-5]
    return generated_sequence
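
# Usage sketch (hypothetical model/tokenizer objects; assumes the tokenizer's
# mask token matches what mask_for_scaffold inserts):
#
#   masked = mask_for_scaffold("aaRKDEa", "uppercase", tokenizer.mask_token)
#   generated = dplm_infill(masked, tokenizer, model, device="cuda")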


# -------# Metrics #-------- #
def calc_progen_ppl(model, tokenizer, target, device, fp16=True):
    """Compute causal LM perplexity for a tokenized sequence (1-D tensor of token ids)."""
    with torch.no_grad():
        with torch.cuda.amp.autocast(enabled=fp16):
            logits = model(
                input_ids=target,
                attention_mask=torch.ones_like(target)
            ).logits
            # Shift so the logits at position t predict the token at position t+1
            logits = logits[:-1, ...]
            target = target[1:]
            loss = torch.nn.functional.cross_entropy(
                input=logits,
                target=target,
                reduction='mean'
            )
            return torch.exp(loss).item()
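
# Usage sketch (hypothetical HF-style causal model/tokenizer; `target` must be a
# 1-D tensor of token ids so the shift above lines up):
#
#   target = tokenizer("MKTAYIAKQR", return_tensors="pt")["input_ids"][0].to(device)
#   ppl = calc_progen_ppl(model, tokenizer, target, device, fp16=True)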


def calc_ppl(model, tokenizer, generated_sequence, mask_token_indices, model_type):
    """Masked-LM pseudo-perplexity: mask one position at a time and average the losses."""
    total_loss = 0.0
    tensor_input = tokenizer.encode(generated_sequence, return_tensors='pt').to(model.device)

    for i in mask_token_indices:
        masked_input = tensor_input.clone()
        masked_input[0, i] = tokenizer.mask_token_id

        # Only score the masked position; label -100 is ignored by the loss
        labels = torch.full(tensor_input.shape, -100).to(model.device)
        labels[0, i] = tensor_input[0, i]

        with torch.no_grad():
            loss = model(masked_input, labels=labels).loss.item()
            total_loss += loss

    # Average over the scored (masked) positions rather than the full sequence length
    avg_loss = total_loss / len(mask_token_indices)
    perplexity = math.exp(avg_loss)

    return perplexity
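
# Usage sketch (hypothetical ESM-style masked LM; the indices refer to positions
# in the *tokenized* sequence, so offset by 1 if the tokenizer prepends <cls>):
#
#   mask_idx = [i + 1 for i, aa in enumerate(motif_seq) if aa.islower()]
#   pppl = calc_ppl(model, tokenizer, generated_sequence, mask_idx, model_type="esm")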


def calc_entropy(seq):
    """Shannon entropy (in bits) of the residue composition of a sequence."""
    counts = Counter(seq)
    total_len = len(seq)
    entropy = 0.0
    for count in counts.values():
        prob = count / total_len
        entropy -= prob * math.log2(prob)
    return entropy
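
# Example values (Shannon entropy in bits over the residue composition):
#
#   calc_entropy("ACDE") -> 2.0   (uniform over four residues: log2(4))
#   calc_entropy("AAAA") -> 0.0   (a single repeated residue has zero entropy)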