File size: 1,312 Bytes
94c2704
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46

import torch
import torch.nn.functional as F


class ProbabilityPathTracer:
    """Record how plausible partially-masked sequences look to an oracle model.

    Each logged step scores a token sequence ``xt`` by the mean log-likelihood
    the oracle assigns to the positions that are currently revealed (i.e. not
    equal to the tokenizer's mask token id).
    """

    def __init__(self, oracle_model, tokenizer, device):
        """
        Args:
            oracle_model: Callable invoked as
                ``oracle_model(input_ids=..., attention_mask=...)`` whose result
                exposes a ``.logits`` tensor of shape ``(*xt.shape, vocab)``.
            tokenizer: Object providing ``mask_token_id``.
            device: Stored as ``self.device``; none of the methods in this class
                use it.
        """
        self.oracle = oracle_model
        self.tokenizer = tokenizer
        self.device = device
        self.mask_id = tokenizer.mask_token_id
        # Maps "trace_step_<step_idx>" -> float score recorded by log_step.
        self.history = {}

    @torch.inference_mode()
    def compute_loglikeli(self, xt):
        """Return the mean log-likelihood of the revealed tokens in ``xt``.

        Args:
            xt: Integer token ids of shape ``(seq_len,)`` or ``(batch, seq_len)``.
                Positions equal to ``self.mask_id`` are treated as masked and
                excluded from the score.

        Returns:
            float: For each sequence, the average of ``log p(token)`` over its
            revealed positions. These per-sequence values are then averaged over
            the batch; a single sequence simply returns its own score. Higher
            is better. Returns ``0.0`` without calling the oracle when no
            position is revealed. NOTE: 0.0 is the maximum attainable
            log-likelihood, so callers should not read it as "good".
        """
        # Accept a single unbatched sequence by adding a batch dimension.
        if xt.dim() == 1:
            xt = xt.unsqueeze(0)

        is_revealed = xt != self.mask_id
        if not is_revealed.any():
            return 0.0

        # Oracle forward pass (presumably an ESM-style masked LM — confirm).
        # The revealed tokens are visible in the input while being scored.
        # NOTE(review): the attention mask is all ones. Any padding or special
        # tokens present in xt are attended to and counted as "revealed".
        logits = self.oracle(
            input_ids=xt,
            attention_mask=torch.ones_like(xt),
        ).logits

        # Per-token NLL.
        # - reshape (not view) tolerates non-contiguous inputs.
        # - .float() avoids precision loss with half-precision logits.
        nll = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)).float(),
            xt.reshape(-1),
            reduction="none",
        ).reshape(xt.shape)

        revealed = is_revealed.float()
        # Lower NLL = better, so negate to get log-likelihood (higher = better).
        # The clamp keeps rows with nothing revealed from dividing by zero;
        # such rows score 0.0.
        per_seq_ll = -(nll * revealed).sum(dim=1) / revealed.sum(dim=1).clamp(min=1)

        return per_seq_ll.mean().item()

    def log_step(self, xt, step_idx):
        """Score ``xt`` and store the result under ``"trace_step_<step_idx>"``.

        Re-logging the same ``step_idx`` overwrites the earlier entry.
        """
        score = self.compute_loglikeli(xt)
        self.history[f"trace_step_{step_idx}"] = score

    def get_trace(self):
        """Return the recorded scores.

        This is the live internal dict, not a copy, so mutating it changes the
        tracer's history.
        """
        return self.history