""" Cognitive Proxy - Brain-Steered Language Model Hugging Face Spaces deployment Author: Sandro Andric """ import gradio as gr import torch import torch.nn as nn import numpy as np import pickle import os from pathlib import Path from sklearn.decomposition import PCA from transformers import AutoTokenizer, AutoModelForCausalLM import plotly.graph_objects as go import plotly.express as px import spaces # For ZeroGPU on Hugging Face # --- CONFIG --- import os from pathlib import Path # Get the directory of this script SCRIPT_DIR = Path(__file__).parent if __file__ else Path.cwd() # Try multiple possible locations for the model files if (SCRIPT_DIR / "results" / "final_atlas_256_vocab.pkl").exists(): ATLAS_PATH = str(SCRIPT_DIR / "results" / "final_atlas_256_vocab.pkl") ADAPTER_PATH = str(SCRIPT_DIR / "results" / "tinyllama_adapter_direct.pt") elif (SCRIPT_DIR / "final_atlas_256_vocab.pkl").exists(): ATLAS_PATH = str(SCRIPT_DIR / "final_atlas_256_vocab.pkl") ADAPTER_PATH = str(SCRIPT_DIR / "tinyllama_adapter_direct.pt") else: # Fallback to expected location ATLAS_PATH = "results/final_atlas_256_vocab.pkl" ADAPTER_PATH = "results/tinyllama_adapter_direct.pt" print(f"Atlas path: {ATLAS_PATH}") print(f"Adapter path: {ADAPTER_PATH}") MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # --- ADAPTER CLASS --- class TinyLlamaAdapterDirect(nn.Module): def __init__(self, input_dim=2048, hidden_dim=1024, output_dim=65536): super().__init__() self.net = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(), nn.Dropout(0.1), nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(), nn.Dropout(0.1), nn.Linear(hidden_dim, hidden_dim // 2), nn.LayerNorm(hidden_dim // 2), nn.GELU(), nn.Linear(hidden_dim // 2, output_dim), ) def forward(self, x): return self.net(x) # Global system cache system = None def load_system(): global system if system is not None: return system device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) tokenizer.pad_token = tokenizer.eos_token # Use float32 for CPU, float16 for GPU dtype = torch.float16 if torch.cuda.is_available() else torch.float32 try: # Try new parameter name first model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=dtype).to(device) except TypeError: # Fall back to old parameter name model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=dtype).to(device) model.eval() adapter = TinyLlamaAdapterDirect().to(device).to(dtype) if os.path.exists(ADAPTER_PATH): adapter.load_state_dict(torch.load(ADAPTER_PATH, map_location=device, weights_only=True)) adapter.eval() if os.path.exists(ATLAS_PATH): print(f"Loading atlas from {ATLAS_PATH}") with open(ATLAS_PATH, 'rb') as f: data = pickle.load(f) if isinstance(data, dict): print(f"Atlas data keys: {list(data.keys())[:5]}") if 'means' in data: atlas = data['means'] print(f"Using 'means' key, got {len(atlas) if isinstance(atlas, dict) else 'not a dict'} items") else: atlas = data print(f"Using data directly, got {len(atlas) if isinstance(atlas, dict) else 'not a dict'} items") else: atlas = data print(f"Atlas is not a dict, type: {type(data)}") else: print(f"Atlas file not found at {ATLAS_PATH}") atlas = {} # Ensure atlas is valid if not atlas or not isinstance(atlas, dict): print(f"Warning: Atlas is empty or invalid, using fallback") atlas = {'word1': np.random.randn(256, 256), 'word2': np.random.randn(256, 256)} words = list(atlas.keys()) print(f"Loaded atlas with {len(words)} words") if len(words) < 
2: print(f"Warning: Not enough words in atlas ({len(words)}), using fallback") atlas = {'word1': np.random.randn(256, 256), 'word2': np.random.randn(256, 256)} words = list(atlas.keys()) # Handle both 256x256 and flat arrays first_val = np.array(atlas[words[0]]) if first_val.shape == (256, 256): plv_matrix = np.array([np.array(atlas[w]).flatten() for w in words]) else: plv_matrix = np.array([np.array(atlas[w]) for w in words]) # Ensure matrix is 2D if len(plv_matrix.shape) == 1 or plv_matrix.shape[0] < 2: print(f"Warning: Invalid PLV matrix shape {plv_matrix.shape}, using fallback") plv_matrix = np.random.randn(10, 65536) pca = PCA(n_components=min(10, plv_matrix.shape[0] - 1)) pca.fit(plv_matrix) pc1_axis = pca.components_[0] pc1_axis = pc1_axis / np.linalg.norm(pc1_axis) global_mean = plv_matrix.mean(axis=0) system = { 'model': model, 'tokenizer': tokenizer, 'adapter': adapter, 'axis': torch.tensor(pc1_axis, dtype=torch.float32).to(device), 'global_mean': torch.tensor(global_mean, dtype=torch.float32).to(device), 'device': device } return system @spaces.GPU(duration=60) def generate_variants(prompt, scenario, max_tokens): """Generate all three variants""" sys = load_system() if scenario == "Educational": prompt_formatted = f"<|user|>\n{prompt}\n<|assistant|>\n" alpha_strength = 5.0 elif scenario == "Technical writing": prompt_formatted = f"<|user|>\n{prompt}\n<|assistant|>\n" alpha_strength = 5.0 else: prompt_formatted = prompt alpha_strength = 3.0 outputs = [] for alpha in [-alpha_strength, 0, alpha_strength]: inputs = sys['tokenizer'](prompt_formatted, return_tensors='pt').to(sys['device']) generated_ids = inputs.input_ids.clone() for _ in range(max_tokens): outputs_model = sys['model'](generated_ids, output_hidden_states=True) hidden = outputs_model.hidden_states[-1][:, -1, :] # Ensure proper dtype for adapter adapter_dtype = next(sys['adapter'].parameters()).dtype hidden = hidden.to(adapter_dtype) if alpha != 0: hidden = hidden.detach().requires_grad_(True) plv_pred = sys['adapter'](hidden) score = torch.sum(plv_pred * sys['axis'].to(adapter_dtype)) grad = torch.autograd.grad(score, hidden, retain_graph=False)[0] grad = grad / (grad.norm() + 1e-8) hidden = hidden.detach() + alpha * grad.detach() with torch.no_grad(): logits = sys['model'].lm_head(sys['model'].model.norm(hidden)) probs = torch.softmax(logits / 0.8, dim=-1) next_token = torch.multinomial(probs, num_samples=1) generated_ids = torch.cat([generated_ids, next_token], dim=-1) if next_token.item() == sys['tokenizer'].eos_token_id: break text = sys['tokenizer'].decode(generated_ids[0], skip_special_tokens=True) if "<|assistant|>" in text: text = text.split("<|assistant|>")[-1].strip() outputs.append(text) return outputs[0], outputs[1], outputs[2] @spaces.GPU(duration=30) def analyze_text(text): """Analyze text and return score with visualization""" sys = load_system() with torch.no_grad(): inputs = sys['tokenizer'](text, return_tensors='pt').to(sys['device']) out = sys['model'](**inputs, output_hidden_states=True) last_hidden = out.hidden_states[-1][0, -1, :] # Ensure proper dtype for adapter adapter_dtype = next(sys['adapter'].parameters()).dtype last_hidden = last_hidden.to(adapter_dtype) plv_pred = sys['adapter'](last_hidden.unsqueeze(0)) plv_flat = plv_pred[0] plv_centered = plv_flat - sys['global_mean'].to(adapter_dtype) score = (plv_centered * sys['axis'].to(adapter_dtype)).sum().item() # Create minimal gauge like Streamlit gauge_min = min(-300, score - 50) gauge_max = max(300, score + 50) fig = go.Figure(go.Indicator( 
mode="number+gauge", value=score, gauge={ 'shape': "angular", 'axis': {'range': [gauge_min, gauge_max], 'tickwidth': 0.5, 'tickcolor': '#ccc'}, 'bar': {'color': "#333", 'thickness': 0.15}, 'bgcolor': "white", 'borderwidth': 1, 'bordercolor': "#e0e0e0", 'steps': [ {'range': [gauge_min, -5], 'color': "#e8f5e9"}, {'range': [-5, 5], 'color': "#fafafa"}, {'range': [5, gauge_max], 'color': "#fff3e0"} ], }, number={'font': {'size': 36, 'color': '#000'}} )) fig.update_layout( height=300, width=400, margin={'l': 30, 'r': 30, 't': 50, 'b': 30}, paper_bgcolor='white', font={'color': '#666'} ) if score > 5: interpretation = "**Syntactic dominance** \nText patterns match brain activity during grammatical processing" elif score < -5: interpretation = "**Semantic dominance** \nText patterns match brain activity during meaning comprehension" else: interpretation = "**Balanced** \nMixed patterns - both structure and meaning equally present" # Create PLV matrix heatmap (reshape to 256x256) plv_np = plv_pred[0].cpu().numpy() plv_matrix = plv_np[:65536].reshape(256, 256) fig_plv = px.imshow( plv_matrix, color_continuous_scale='Viridis', aspect='auto' ) fig_plv.update_layout( coloraxis_showscale=True, coloraxis=dict( colorbar=dict( thickness=10, len=0.7, title=dict(text="Synchrony", side="right"), tickfont=dict(size=10) ) ), margin={'l': 0, 'r': 40, 't': 10, 'b': 0}, height=300 ) fig_plv.update_xaxes(visible=False) fig_plv.update_yaxes(visible=False) return fig, interpretation, score, fig_plv @spaces.GPU(duration=60) def generate_steered(prompt, alpha, max_tokens): """Generate with custom steering""" sys = load_system() inputs = sys['tokenizer'](prompt, return_tensors='pt').to(sys['device']) generated_ids = inputs.input_ids.clone() for _ in range(max_tokens): outputs_model = sys['model'](generated_ids, output_hidden_states=True) hidden = outputs_model.hidden_states[-1][:, -1, :] # Ensure proper dtype for adapter adapter_dtype = next(sys['adapter'].parameters()).dtype hidden = hidden.to(adapter_dtype) if alpha != 0: hidden = hidden.detach().requires_grad_(True) plv_pred = sys['adapter'](hidden) score = torch.sum(plv_pred * sys['axis'].to(adapter_dtype)) grad = torch.autograd.grad(score, hidden, retain_graph=False)[0] grad = grad / (grad.norm() + 1e-8) hidden = hidden.detach() + alpha * grad.detach() with torch.no_grad(): logits = sys['model'].lm_head(sys['model'].model.norm(hidden)) probs = torch.softmax(logits / 0.8, dim=-1) next_token = torch.multinomial(probs, num_samples=1) generated_ids = torch.cat([generated_ids, next_token], dim=-1) if next_token.item() == sys['tokenizer'].eos_token_id: break return sys['tokenizer'].decode(generated_ids[0], skip_special_tokens=True) # Custom CSS to match Streamlit minimal design custom_css = """ """ # Create interface DEFAULT_PROMPTS = { "Technical writing": "Draft a short SMS to the customer informing them their payment has failed.", "Educational": "Explain in 2 sentences what the butterfly effect is.", "Free form": "Brainstorm creative uses of brain-steered language models in five bullet points." 
# Custom CSS to match Streamlit minimal design
custom_css = """
"""

# Create interface
DEFAULT_PROMPTS = {
    "Technical writing": "Draft a short SMS to the customer informing them their payment has failed.",
    "Educational": "Explain in 2 sentences what the butterfly effect is.",
    "Free form": "Brainstorm creative uses of brain-steered language models in five bullet points.",
}

SCENARIO_AXIS_TEXT = {
    "Technical writing": {
        "left_label": "Semantic / Content (meaning-heavy, concrete) [empathetic/actionable tone]",
        "baseline_label": "Baseline",
        "right_label": "Syntactic / Function (structure-heavy, abstract) [formal/policy tone]",
        "left_caption": "*Steered toward meaning (brain semantic side)*",
        "baseline_caption": "*No brain steering*",
        "right_caption": "*Steered toward structure (brain syntactic side)*",
    },
    "Educational": {
        "left_label": "Semantic / Content (meaning-heavy, concrete) [analogy/concrete style]",
        "baseline_label": "Baseline",
        "right_label": "Syntactic / Function (structure-heavy, abstract) [definition/logical style]",
        "left_caption": "*Steered toward meaning (brain semantic side)*",
        "baseline_caption": "*No brain steering*",
        "right_caption": "*Steered toward structure (brain syntactic side)*",
    },
    "Free form": {
        "left_label": "Semantic / Content (meaning-heavy, concrete)",
        "baseline_label": "Baseline",
        "right_label": "Syntactic / Function (structure-heavy, abstract)",
        "left_caption": "*Steered toward meaning (brain semantic side)*",
        "baseline_caption": "*No brain steering*",
        "right_caption": "*Steered toward structure (brain syntactic side)*",
    },
}

with gr.Blocks(
    title="Cognitive Proxy",
    theme=gr.themes.Base(
        primary_hue="gray",
        neutral_hue="gray",
        text_size="md",
        spacing_size="lg",
        radius_size="none",
    ),
    css=custom_css,
) as demo:
    # Header
    gr.HTML("""