| """ |
| Spatial Context Networks (SCN) |
| Geometric Semantic Routing in Neural Architectures |
| |
| Author: Furkan Nar |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
|
|
|
|
class GeometricActivation(nn.Module):
    """
    Inverse-distance geometric activation.

    Every neuron owns a learnable centroid in a d-dimensional semantic space;
    its response grows as the input approaches that centroid:

        f(v) = 1 / (||v - mu||_2 / sqrt(d) + epsilon)

    The sqrt(d) factor normalizes distances across dimensionalities, and
    epsilon = 1 / stability_factor bounds the activation (max value = SF).

    Args:
        n_neurons (int): Number of centroids (neurons) in the layer.
        dim (int): Dimensionality of the input semantic space.
        stability_factor (float): SF in the paper; epsilon = 1/SF. Default: 10.0
    """

    def __init__(self, n_neurons: int, dim: int, stability_factor: float = 10.0):
        super().__init__()
        self.n_neurons = n_neurons
        self.dim = dim
        # Larger SF -> smaller epsilon -> sharper peak at the centroid.
        self.epsilon = 1.0 / stability_factor
        # One learnable point-mass per neuron, initialized from N(0, 1).
        self.centroids = nn.Parameter(torch.randn(n_neurons, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor of shape (batch_size, dim)
        Returns:
            activations: Tensor of shape (batch_size, n_neurons)
        """
        # Broadcast to (batch, n_neurons, dim) and take per-pair Euclidean norms.
        pairwise = x[:, None, :] - self.centroids[None, :, :]
        scaled_dist = pairwise.norm(dim=-1) / math.sqrt(self.dim)
        # Bounded inverse distance: equals SF exactly at the centroid.
        return (scaled_dist + self.epsilon).reciprocal()
|
|
|
|
class SemanticRoutingLayer(nn.Module):
    """
    Routes inputs to neurons by geometric affinity.

    Computes geometric activations and derives a hard routing mask:

        Active set: S = { n_i | f_i(q) > tau }
        Binary mask: M_ij = I[ f_j(v_i) > tau ]

    Args:
        n_neurons (int): Number of neurons.
        dim (int): Input dimensionality.
        routing_threshold (float): Activation threshold tau. Default: 0.5
        stability_factor (float): Passed to GeometricActivation. Default: 10.0
    """

    def __init__(
        self,
        n_neurons: int,
        dim: int,
        routing_threshold: float = 0.5,
        stability_factor: float = 10.0,
    ):
        super().__init__()
        self.routing_threshold = routing_threshold
        self.geo_activation = GeometricActivation(n_neurons, dim, stability_factor)

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: Input tensor of shape (batch_size, dim)
        Returns:
            activations: Raw activations, shape (batch_size, n_neurons)
            mask: Binary routing mask, shape (batch_size, n_neurons)
        """
        scores = self.geo_activation(x)
        # Hard gate: 1.0 where the score clears tau, 0.0 elsewhere.
        gate = scores.gt(self.routing_threshold).float()
        return scores, gate
|
|
|
|
class ConnectionDensityLayer(nn.Module):
    """
    Scores per-sample connection density with explosion control.

        C = sum_{i in S} w_i / (alpha / z)

    alpha is the total neuron count, z = |S| the number of active neurons,
    so sparse activation patterns are down-weighted and dense ones amplified.
    Values above tau_exp are damped by a square root: C_stable = sqrt(C).

    Args:
        n_neurons (int): Total number of neurons (alpha).
        explosion_threshold (float): tau_exp. Default: 2.0
    """

    def __init__(self, n_neurons: int, explosion_threshold: float = 2.0):
        super().__init__()
        self.n_neurons = n_neurons
        self.explosion_threshold = explosion_threshold
        # Learnable per-neuron contribution weights, initialized from N(0, 1).
        self.connection_weights = nn.Parameter(torch.randn(n_neurons))

    def forward(self, activations: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Args:
            activations: Shape (batch_size, n_neurons)
            mask: Binary mask, shape (batch_size, n_neurons)
        Returns:
            context: Scalar context score per sample, shape (batch_size, 1)
        """
        # z, clamped so an all-zero mask cannot divide by zero.
        active_count = mask.sum(dim=-1, keepdim=True).clamp(min=1.0)
        total = float(self.n_neurons)

        # Gate each activation by its mask entry before weighting.
        gated = self.connection_weights[None, :] * activations * mask
        context = gated.sum(dim=-1, keepdim=True) / (total / active_count)

        # Square-root damping past tau_exp; abs + eps keeps sqrt NaN-free
        # (the where selects the damped branch only for positive values).
        damped = context.sign() * torch.sqrt(context.abs() + 1e-8)
        return torch.where(context > self.explosion_threshold, damped, context)
|
|
|
|
class SpatialContextNetwork(nn.Module):
    """
    Spatial Context Network (SCN).

    Pipeline:
        1. SemanticRoutingLayer — geometric activation + binary routing mask
        2. ConnectionDensityLayer — adaptive normalization + explosion control
        3. Linear projection — map the scalar context score to output space
        4. Pattern distribution — element-wise multiply by softmax(pattern_weights)

    Args:
        input_dim (int): Dimensionality of input features.
        n_neurons (int): Number of hidden geometric neurons. Default: 32
        output_dim (int): Number of output classes/dimensions. Default: 4
        routing_threshold (float): Routing threshold tau. Default: 0.5
        stability_factor (float): Controls epsilon = 1/SF. Default: 10.0
        explosion_threshold (float): Threshold for sqrt damping. Default: 2.0

    Example::

        model = SpatialContextNetwork(input_dim=10, n_neurons=32, output_dim=4)
        x = torch.randn(8, 10)
        output = model(x)  # (8, 4)
    """

    def __init__(
        self,
        input_dim: int = 10,
        n_neurons: int = 32,
        output_dim: int = 4,
        routing_threshold: float = 0.5,
        stability_factor: float = 10.0,
        explosion_threshold: float = 2.0,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.n_neurons = n_neurons
        self.output_dim = output_dim

        self.routing = SemanticRoutingLayer(
            n_neurons, input_dim, routing_threshold, stability_factor
        )
        self.density = ConnectionDensityLayer(n_neurons, explosion_threshold)
        self.projection = nn.Linear(1, output_dim)

        # Logits for the output pattern distribution (softmaxed in forward).
        self.pattern_weights = nn.Parameter(torch.zeros(output_dim))

        # Seed the logits with a fixed empirical prior when the output width
        # matches; otherwise they stay at zero (uniform softmax).
        if output_dim == 4:
            with torch.no_grad():
                prior = torch.tensor([0.38, 0.25, 0.22, 0.15])
                self.pattern_weights.copy_(torch.log(prior + 1e-8))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor of shape (batch_size, input_dim)
        Returns:
            output: Tensor of shape (batch_size, output_dim)
        """
        scores, route_mask = self.routing(x)
        ctx = self.density(scores, route_mask)
        projected = self.projection(ctx)
        # Scale each output channel by the learned pattern distribution.
        distribution = F.softmax(self.pattern_weights, dim=-1)
        return projected * distribution

    def get_network_stats(self, x: torch.Tensor) -> dict:
        """
        Returns diagnostic statistics for a batch of inputs.

        Returns:
            dict with keys: mean_active_neurons, network_efficiency,
            mean_context_score, activations, mask
        """
        with torch.no_grad():
            scores, route_mask = self.routing(x)
            ctx = self.density(scores, route_mask)
            per_sample_active = route_mask.sum(dim=-1)
            return {
                "mean_active_neurons": per_sample_active.mean().item(),
                "network_efficiency": (per_sample_active / self.n_neurons).mean().item(),
                "mean_context_score": ctx.mean().item(),
                "activations": scores,
                "mask": route_mask,
            }