| | import gradio as gr |
| | import pandas as pd |
| | import torch |
| | from transformers import AutoTokenizer, AutoModelForMaskedLM |
| | import torch.nn.functional as F |
| | import logging |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | import seaborn as sns |
| | from io import BytesIO |
| | from PIL import Image |
| | from contextlib import contextmanager |
| | import warnings |
| | import sys |
| | import os |
| | import zipfile |
| |
|
# Only let ERROR-level (and above) records through from transformers' model-loading logger.
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

# Prefer CUDA when torch reports it available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# FusOn-pLM checkpoint. trust_remote_code=True allows the custom model code that
# ships with the checkpoint repository to be executed by transformers.
model_name = "ChatterjeeLab/FusOn-pLM"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Module.to() and Module.eval() both return the module itself, so this chain
# yields the same object, moved to `device` and switched to inference mode.
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True).to(device).eval()
| |
|
@contextmanager
def suppress_output():
    """Temporarily discard everything written to ``sys.stdout``.

    While the ``with`` body runs, ``sys.stdout`` points at ``os.devnull``.
    The previous stream is restored on exit, even if the body raises.
    ``sys.stderr`` is not touched.
    """
    saved_stdout = sys.stdout
    with open(os.devnull, 'w') as sink:
        sys.stdout = sink
        try:
            yield
        finally:
            sys.stdout = saved_stdout
| |
|
# Amino-acid letters accepted in an input sequence. The list order matches the
# vocabulary ids hard-coded below (L -> 4, A -> 5, ..., C -> 23); these ids are
# assumed to match the FusOn-pLM tokenizer vocabulary -- confirm if the
# checkpoint changes.
_AA_TOKENS = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D',
              'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
# Letter -> vocabulary id.
_AA_TOKEN_INDICES = {aa: token_id for token_id, aa in enumerate(_AA_TOKENS, start=4)}
# Vocabulary ids of the 20 amino-acid tokens, in _AA_TOKENS order (4..23).
_AA_VOCAB_IDS = list(range(4, 23 + 1))
# A position is flagged as conserved when the model gives its original residue
# a softmax probability strictly greater than this value.
_CONSERVATION_THRESHOLD = 0.7


def _validate_inputs(sequence, domain_bounds, n):
    """Check the UI inputs and return the 1-based, inclusive domain bounds.

    Args:
        sequence: Protein sequence, already stripped of surrounding whitespace.
        domain_bounds: Table with 'start' and 'end' columns (the Gradio
            Dataframe input); only the first row is read.
        n: Number of top tokens to report; must not be None.

    Returns:
        (start, end) as ints with 1 <= start < end <= len(sequence).

    Raises:
        gr.Error: with a user-facing message for the first invalid input.
    """
    if not sequence:
        raise gr.Error("Error: The sequence input is empty. Please enter a valid protein sequence.")
    if any(char not in _AA_TOKEN_INDICES for char in sequence):
        raise gr.Error("Error: The sequence input contains non-amino acid characters. Please enter a valid protein sequence.")

    try:
        start = int(domain_bounds['start'][0])
        end = int(domain_bounds['end'][0])
    except (TypeError, ValueError):
        # ValueError: NaN or non-numeric text; TypeError: e.g. a None cell.
        raise gr.Error("Error: Start and end indices must be integers.") from None
    if start >= end:
        raise gr.Error("Start index must be smaller than end index.")
    if start == 0 and end != 0:
        raise gr.Error("Indexing starts at 1. Please enter valid domain bounds.")
    if start <= 0 or end <= 0:
        raise gr.Error("Domain bounds must be positive integers. Please enter valid domain bounds.")
    if start > len(sequence) or end > len(sequence):
        raise gr.Error("Domain bounds exceed sequence length.")

    if n is None:
        raise gr.Error("Choose Top N Tokens from the dropdown menu.")
    return start, end


def _masked_position_logits(sequence, i):
    """Mask residue ``i`` (0-based) and return the model's logits at the mask.

    The result is ``logits[0, mask_positions, :]``: one row per ``<mask>``
    token found in the tokenized input and one column per vocabulary entry.

    Raises:
        gr.Error: if truncation to 2000 tokens removed the masked position.
    """
    masked_seq = sequence[:i] + '<mask>' + sequence[i + 1:]
    inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=2000)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
    if mask_token_index.numel() == 0:
        # Without this guard, an empty selection crashes later with an IndexError.
        raise gr.Error(f"Error: Residue {i + 1} lies beyond the 2000-token input limit. Please use a shorter sequence.")
    return logits[0, mask_token_index, :]


def _top_n_residues(position_logits, n):
    """Return up to ``n`` amino-acid letters ranked by decreasing logit at one position.

    Vocabulary entries that do not decode to one of the 20 amino-acid letters
    are skipped.
    """
    ranked_ids = torch.argsort(position_logits.squeeze(0), dim=0, descending=True)
    residues = []
    for token_id in ranked_ids:
        decoded_token = tokenizer.decode([token_id.item()])
        if decoded_token in _AA_TOKEN_INDICES:
            residues.append(decoded_token)
            if len(residues) == n:
                break
    return residues


def _domain_ticks(start, end, start_index, end_index):
    """Return x-tick positions (0-based sequence indices) and their 1-based labels."""
    # NOTE: end - start is one less than the number of residues in the
    # inclusive domain; the step thresholds below are tuned on this value.
    domain_len = end - start
    if 500 > domain_len > 100:
        step_size = 50
    elif domain_len >= 500:
        step_size = 100
    elif domain_len < 10:
        step_size = 1
    else:
        step_size = 10
    positions = np.arange(start_index, end_index, step_size)
    labels = [str(pos + 1) for pos in positions]
    return positions, labels


def _figure_to_image():
    """Render the current pyplot figure to a 300-dpi in-memory PNG, close it, and return it as a PIL image."""
    buf = BytesIO()
    plt.savefig(buf, format='png', dpi=300)
    buf.seek(0)
    plt.close()
    return Image.open(buf)


def _plot_conservation(original_probs, conserved_flags, start_index, tick_positions, tick_labels):
    """Draw the two-row heatmap (original-residue probability, conservation flag) and return it as an image."""
    combined_array = np.vstack((
        np.asarray(original_probs).reshape(1, -1),
        np.asarray(conserved_flags).reshape(1, -1),
    ))
    plt.figure(figsize=(15, 5))
    plt.rcParams.update({'font.size': 16.5})
    sns.heatmap(combined_array, cmap='viridis', xticklabels=tick_labels,
                yticklabels=['Residue \nLogits', 'Residue \nConservation'], cbar=True)
    # Heatmap cells are centred at column + 0.5; convert sequence index to column.
    plt.xticks(tick_positions - start_index + 0.5, tick_labels, rotation=0)
    plt.title('Original Residue Probability and Conservation')
    plt.xlabel('Residue Index')
    # Deliberately no plt.show(): a blocking show() closes the figure, so a
    # subsequent savefig would write an empty image.
    return _figure_to_image()


def _plot_token_probabilities(aa_logits, start_index, tick_positions, tick_labels):
    """Draw the per-position probability heatmap over the 20 amino acids and return it as an image."""
    aa_labels = [tokenizer.decode([token_id]) for token_id in _AA_VOCAB_IDS]
    # One row per domain position, one column per amino-acid logit.
    stacked_logits = np.vstack(aa_logits)
    # Softmax over the 20 amino acids only, then transpose so each column is a
    # position (and sums to 1) and each row is an amino acid.
    probs = F.softmax(torch.tensor(stacked_logits), dim=-1).numpy().T

    plt.figure(figsize=(15, 8))
    plt.rcParams.update({'font.size': 16.5})
    sns.heatmap(probs, cmap='plasma', xticklabels=tick_labels, yticklabels=aa_labels)
    plt.title('Token Probability')
    plt.ylabel('Amino Acid')
    plt.xlabel('Residue Index')
    plt.yticks(rotation=0)
    plt.xticks(tick_positions - start_index + 0.5, tick_labels, rotation=0)
    return _figure_to_image()


def _write_outputs(df, img_1, img_2):
    """Write the CSV and both heatmaps, bundle them into outputs.zip, and return its path.

    NOTE(review): the file names are fixed and relative to the current working
    directory, so a second request handled concurrently would overwrite them.
    """
    df.to_csv("predicted_tokens.csv", index=False)
    img_1.save("heatmap.png", dpi=(300, 300))
    img_2.save("heatmap_2.png", dpi=(300, 300))
    zip_path = "outputs.zip"
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        for file_name in ("predicted_tokens.csv", "heatmap.png", "heatmap_2.png"):
            zipf.write(file_name)
    return zip_path


def process_sequence(sequence, domain_bounds, n):
    """Predict likely residues for every position of a protein domain.

    Each position in the domain is masked in turn and scored by the masked
    language model.

    Args:
        sequence: Protein sequence of one-letter amino-acid codes; surrounding
            whitespace is ignored.
        domain_bounds: Table with 'start' and 'end' columns giving the 1-based,
            inclusive domain; only the first row is read.
        n: How many top-ranked amino acids to report per position.

    Returns:
        (df, img_1, img_2, zip_path):
        - df: a DataFrame of predicted residues per position.
        - img_1: the token-probability heatmap.
        - img_2: the residue-conservation heatmap.
        - zip_path: the path of a zip holding the CSV and both PNGs. These
          files are also written to the current working directory.

    Raises:
        gr.Error: for invalid input (displayed to the user by Gradio).
    """
    sequence = (sequence or "").strip()
    start, end = _validate_inputs(sequence, domain_bounds, n)
    start_index, end_index = start - 1, end  # 0-based, end-exclusive

    top_n_mutations = {}
    aa_logits = []        # per position: (rows, 20) logits of the amino-acid tokens
    original_probs = []   # per position: softmax probability of the original residue
    conserved_flags = []  # per position: 1 if that probability > threshold, else 0

    for i in range(start_index, end_index):
        original_residue = sequence[i]
        mask_token_logits = _masked_position_logits(sequence, i)

        top_n_mutations[(original_residue, i)] = _top_n_residues(mask_token_logits, n)
        aa_logits.append(mask_token_logits.cpu().numpy()[:, _AA_VOCAB_IDS])

        # Softmax over the full vocabulary (not just the 20 amino acids).
        probs = np.squeeze(F.softmax(mask_token_logits.cpu(), dim=-1).numpy())
        original_prob = probs[_AA_TOKEN_INDICES[original_residue]]
        original_probs.append(original_prob)
        conserved_flags.append(1 if original_prob > _CONSERVATION_THRESHOLD else 0)

    tick_positions, tick_labels = _domain_ticks(start, end, start_index, end_index)
    img_2 = _plot_conservation(original_probs, conserved_flags, start_index, tick_positions, tick_labels)
    img_1 = _plot_token_probabilities(aa_logits, start_index, tick_positions, tick_labels)

    df = pd.DataFrame({
        'Original Residue': [residue for residue, _ in top_n_mutations],
        'Predicted Residues': list(top_n_mutations.values()),
        'Position': [position + 1 for _, position in top_n_mutations],
    })
    zip_path = _write_outputs(df, img_1, img_2)
    return df, img_1, img_2, zip_path
| |
|
| | |
# Gradio wiring. The three inputs map positionally onto process_sequence's
# (sequence, domain_bounds, n); its four return values map onto the outputs.
_input_components = [
    gr.Textbox(label="Sequence", placeholder="Enter the protein sequence here"),
    gr.Dataframe(
        headers=["start", "end"],
        datatype=["number", "number"],
        value=[[1, 1]],
        row_count=(1, "fixed"),
        col_count=(2, "fixed"),
        label="Domain Bounds",
    ),
    gr.Dropdown(list(range(1, 21)), label="Top N Tokens"),
]
_output_components = [
    gr.Dataframe(label="Predicted Tokens (in order of decreasing likelihood)"),
    gr.Image(type="pil", label="Probability Distribution for All Tokens"),
    gr.Image(type="pil", label="Residue Conservation"),
    gr.File(label="Download Outputs"),
]

demo = gr.Interface(fn=process_sequence, inputs=_input_components, outputs=_output_components)

if __name__ == "__main__":
    # Launch with stdout redirected to os.devnull via suppress_output().
    with suppress_output():
        demo.launch()