FunctionGemma 270M Sparse Autoencoders

Sparse Autoencoders (SAEs) trained on all 18 layers of google/functiongemma-270m-it.

Architecture

  • Base Model: google/functiongemma-270m-it
  • Layers: 18 (decoder-only)
  • Hidden Size: 640
  • SAE Dimension: 4096 (6.4x expansion)
  • Hook Point: self_attn.o_proj (output projection of self-attention; see the shape sketch below)
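
The shapes involved, as a rough sketch with random stand-in weights (the real weights come from the checkpoints described below):

import torch

d_in, d_sae = 640, 4096                  # hidden size -> SAE dimension (6.4x expansion)
x = torch.randn(1, 8, d_in)              # stand-in for a self_attn.o_proj output, (batch, seq, 640)
W_enc = torch.randn(d_in, d_sae)         # encoder weight, random for illustration only
W_dec = torch.randn(d_sae, d_in)         # decoder weight, random for illustration only
features = torch.relu(x @ W_enc)         # (1, 8, 4096) sparse feature activations
recon = features @ W_dec                 # (1, 8, 640) reconstruction of the hook-point activation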

Training

  • Epochs: 5 per layer
  • Batch Size: 1
  • Learning Rate: 1e-4
  • Optimizer: AdamW
  • Loss: MSE + 0.01 * L1 regularization (see the training-step sketch after this list)
  • Activation Clipping: [-10, 10]
  • Gradient Clipping: max_norm=1.0
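
Putting these settings together, one training step looks roughly like the sketch below. This is a reconstruction from the listed hyperparameters, not the original training script; in particular, the exact form of the L1 term (here summed over features and averaged over tokens) is an assumption.

import torch

d_in, d_sae, l1_coeff = 640, 4096, 0.01
enc = torch.nn.Linear(d_in, d_sae)   # W_enc, b_enc
dec = torch.nn.Linear(d_sae, d_in)   # W_dec, b_dec
params = list(enc.parameters()) + list(dec.parameters())
optimizer = torch.optim.AdamW(params, lr=1e-4)

def train_step(acts):
    # acts: (batch, seq, 640) activations captured at self_attn.o_proj
    acts = acts.clamp(-10, 10)                          # activation clipping to [-10, 10]
    features = torch.relu(enc(acts))                    # sparse features, (batch, seq, 4096)
    recon = dec(features)                               # reconstruction, (batch, seq, 640)
    mse = torch.nn.functional.mse_loss(recon, acts)
    l1 = features.abs().sum(dim=-1).mean()              # L1 sparsity penalty
    loss = mse + l1_coeff * l1                          # MSE + 0.01 * L1
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)  # gradient clipping
    optimizer.step()
    return loss.item(), mse.item()

loss, mse = train_step(torch.randn(1, 16, d_in))        # dummy batch for illustration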

Checkpoints

Each checkpoint contains:

{
    "model_name": "google/functiongemma-270m-it",
    "layer_idx": int,
    "d_in": 640,
    "d_sae": 4096,
    "W_enc": torch.Tensor,  # (640, 4096)
    "b_enc": torch.Tensor,  # (4096,)
    "W_dec": torch.Tensor,  # (4096, 640)
    "b_dec": torch.Tensor,  # (640,)
    "history": {
        "loss": [...],
        "mse": [...],
        "l0": [...]
    }
}
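
A quick sanity check of a downloaded checkpoint against this layout (the local filename is illustrative; see the Usage section for downloading from the Hub):

import torch

ckpt = torch.load("sae_layer_00.pt", map_location="cpu")  # any one of the 18 checkpoint files
assert ckpt["d_in"] == 640 and ckpt["d_sae"] == 4096
assert ckpt["W_enc"].shape == (640, 4096) and ckpt["W_dec"].shape == (4096, 640)
assert ckpt["b_enc"].shape == (4096,) and ckpt["b_dec"].shape == (640,)
print(f"layer {ckpt['layer_idx']}: {len(ckpt['history']['loss'])} logged loss values")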

Usage

import torch
from huggingface_hub import hf_hub_download

# Load SAE for a specific layer
layer_idx = 0
ckpt_path = hf_hub_download(
    "mindchain/functiongemma-270m-sae",
    f"sae_layer_{layer_idx:02d}.pt"
)
sae = torch.load(ckpt_path, map_location="cpu")

# Define the SAE module: a standard ReLU autoencoder (trained with MSE + L1 sparsity);
# the checkpoint stores plain encoder/decoder weights and biases, with no JumpReLU threshold.
class SparseAutoencoder(torch.nn.Module):
    def __init__(self, W_enc, b_enc, W_dec, b_dec):
        super().__init__()
        self.W_enc = torch.nn.Parameter(W_enc)  # (640, 4096)
        self.b_enc = torch.nn.Parameter(b_enc)  # (4096,)
        self.W_dec = torch.nn.Parameter(W_dec)  # (4096, 640)
        self.b_dec = torch.nn.Parameter(b_dec)  # (640,)

    def forward(self, x):
        # x: (batch, seq, 640) activations captured at self_attn.o_proj
        batch, seq, d_in = x.shape
        x_flat = x.view(-1, d_in)
        pre_act = x_flat @ self.W_enc + self.b_enc
        features = torch.relu(pre_act)              # sparse feature activations, (batch*seq, 4096)
        recon = features @ self.W_dec + self.b_dec  # reconstruction of the input, (batch*seq, 640)
        return recon.view(batch, seq, d_in), features.view(batch, seq, -1)

sae_model = SparseAutoencoder(
    sae["W_enc"], sae["b_enc"],
    sae["W_dec"], sae["b_dec"]
)

# Get activations from FunctionGemma and encode
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "google/functiongemma-270m-it",
    torch_dtype=torch.bfloat16,
    device_map="cuda"
)
tokenizer = AutoTokenizer.from_pretrained("google/functiongemma-270m-it")

inputs = tokenizer("What's the weather?", return_tensors="pt").to(model.device)

# Hook to capture the o_proj output for the chosen layer
acts = []
def hook(module, inp, out):
    acts.append(out.detach().float())  # nn.Linear returns a tensor; cast bfloat16 -> float32 for the SAE
handle = model.model.layers[layer_idx].self_attn.o_proj.register_forward_hook(hook)
with torch.no_grad():
    _ = model(**inputs)
handle.remove()

# Run through SAE
recon, features = sae_model(acts[0])
print(f"Active features: {(features > 0).sum().item()}")

Training Results

Layer   Final Loss   Final MSE   L0 (active features)
0       3.4457       3.1244      1225
1       2.0052       1.9042      1386
2       0.1182       0.0759      1546
3       0.1182       0.0758      3096
4       0.0361       0.0170      1635
5       0.0414       0.0351       399
6       0.0318       0.0138      1807
7       0.0877       0.0661      1120
8       0.0733       0.0445      1379
9       0.0561       0.0317      1569
10      0.0997       0.0852       591
11      0.0252       0.0097      3658
12      0.0565       0.0395       962
13      0.0924       0.0619      1403
14      0.2711       0.2504       709
15      0.1501       0.1062      1576
16      0.1670       0.1426       870
17      0.0385       0.0218      1470
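
The table can also be re-derived from the history dict stored in each checkpoint, assuming the last logged value of each metric is the final one:

import torch
from huggingface_hub import hf_hub_download

# Print final loss / MSE / L0 for every layer from the per-checkpoint training history
for layer_idx in range(18):
    path = hf_hub_download("mindchain/functiongemma-270m-sae", f"sae_layer_{layer_idx:02d}.pt")
    hist = torch.load(path, map_location="cpu")["history"]
    print(f"{layer_idx:2d}  loss={hist['loss'][-1]:.4f}  mse={hist['mse'][-1]:.4f}  l0={hist['l0'][-1]:.0f}")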

License

Apache 2.0
