Update modeling_alinlight.py
Browse files- modeling_alinlight.py +25 -3
modeling_alinlight.py
CHANGED
|
@@ -72,6 +72,23 @@ class AlinlightRMSNorm(nn.Module):
|
|
| 72 |
return self.weight * x.to(input_dtype)
|
| 73 |
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
class AlinlightRotaryEmbedding(nn.Module):
|
| 76 |
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
| 77 |
super().__init__()
|
|
@@ -134,7 +151,11 @@ class AlinlightMLP(nn.Module):
|
|
| 134 |
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 135 |
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 136 |
self.act_fn = nn.SiLU()
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
# Tag for specialized initialization
|
| 140 |
self.down_proj._is_residual_projection = True
|
|
@@ -171,8 +192,9 @@ class AlinlightAttention(nn.Module):
|
|
| 171 |
|
| 172 |
self.use_qk_norm = getattr(config, "use_qk_norm", True)
|
| 173 |
if self.use_qk_norm:
|
| 174 |
-
|
| 175 |
-
self.
|
|
|
|
| 176 |
|
| 177 |
self.attn_logit_softcapping = getattr(config, 'attn_logit_softcapping', None)
|
| 178 |
|
|
|
|
| 72 |
return self.weight * x.to(input_dtype)
|
| 73 |
|
| 74 |
|
| 75 |
class GatedNorm(nn.Module):
    """
    Gated Normalization wrapper.

    Blends the raw input with its normalized form through a single
    learnable scalar gate, so the model can learn how much
    normalization to apply — and effectively skip it when the gate
    saturates toward zero.
    """

    def __init__(self, original_norm, initial_gate_value=-1.0):
        super().__init__()
        self.norm = original_norm
        # Start conservatively: sigmoid(-1.0) ≈ 0.27, so the raw
        # (un-normalized) input dominates the mix at initialization.
        self.gate = nn.Parameter(torch.tensor(initial_gate_value))

    def forward(self, x, *args, **kwargs):
        mix = torch.sigmoid(self.gate)
        return (1.0 - mix) * x + mix * self.norm(x, *args, **kwargs)
class AlinlightRotaryEmbedding(nn.Module):
|
| 93 |
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
| 94 |
super().__init__()
|
|
|
|
| 151 |
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 152 |
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 153 |
self.act_fn = nn.SiLU()
|
| 154 |
+
|
| 155 |
+
# Use GatedNorm for the inner normalization
|
| 156 |
+
self.pre_down_norm = GatedNorm(
|
| 157 |
+
AlinlightRMSNorm(self.intermediate_size, eps=config.rms_norm_eps)
|
| 158 |
+
)
|
| 159 |
|
| 160 |
# Tag for specialized initialization
|
| 161 |
self.down_proj._is_residual_projection = True
|
|
|
|
| 192 |
|
| 193 |
self.use_qk_norm = getattr(config, "use_qk_norm", True)
|
| 194 |
if self.use_qk_norm:
|
| 195 |
+
# Use GatedNorm for QK Normalization
|
| 196 |
+
self.q_norm = GatedNorm(AlinlightRMSNorm(self.head_dim, eps=config.rms_norm_eps))
|
| 197 |
+
self.k_norm = GatedNorm(AlinlightRMSNorm(self.head_dim, eps=config.rms_norm_eps))
|
| 198 |
|
| 199 |
self.attn_logit_softcapping = getattr(config, 'attn_logit_softcapping', None)
|
| 200 |
|