muverqqw committed
Commit 18dfd23 · Parent: 09a68a0

Update modeling_alinlight.py

Files changed (1): modeling_alinlight.py (+25 -3)
modeling_alinlight.py CHANGED
@@ -72,6 +72,23 @@ class AlinlightRMSNorm(nn.Module):
         return self.weight * x.to(input_dtype)
 
 
+class GatedNorm(nn.Module):
+    """
+    Gated Normalization wrapper.
+    Allows the model to learn to skip normalization via a learnable gate.
+    """
+    def __init__(self, original_norm, initial_gate_value=-1.0):
+        super().__init__()
+        self.norm = original_norm
+        # Initialize gate to -1.0 (sigmoid(-1) ≈ 0.27) to start conservatively
+        self.gate = nn.Parameter(torch.tensor(initial_gate_value))
+
+    def forward(self, x, *args, **kwargs):
+        normed = self.norm(x, *args, **kwargs)
+        g = torch.sigmoid(self.gate)
+        return (1.0 - g) * x + g * normed
+
+
 class AlinlightRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         super().__init__()
@@ -134,7 +151,11 @@ class AlinlightMLP(nn.Module):
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = nn.SiLU()
-        self.pre_down_norm = AlinlightRMSNorm(self.intermediate_size, eps=config.rms_norm_eps)
+
+        # Use GatedNorm for the inner normalization
+        self.pre_down_norm = GatedNorm(
+            AlinlightRMSNorm(self.intermediate_size, eps=config.rms_norm_eps)
+        )
 
         # Tag for specialized initialization
         self.down_proj._is_residual_projection = True
@@ -171,8 +192,9 @@ class AlinlightAttention(nn.Module):
 
         self.use_qk_norm = getattr(config, "use_qk_norm", True)
         if self.use_qk_norm:
-            self.q_norm = AlinlightRMSNorm(self.head_dim, eps=config.rms_norm_eps)
-            self.k_norm = AlinlightRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+            # Use GatedNorm for QK Normalization
+            self.q_norm = GatedNorm(AlinlightRMSNorm(self.head_dim, eps=config.rms_norm_eps))
+            self.k_norm = GatedNorm(AlinlightRMSNorm(self.head_dim, eps=config.rms_norm_eps))
 
         self.attn_logit_softcapping = getattr(config, 'attn_logit_softcapping', None)
 
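
For context, a minimal runnable sketch of how the GatedNorm added in this commit behaves at initialization. This is an illustration only, not part of the commit: the nn.RMSNorm used as a stand-in for AlinlightRMSNorm is an assumption (it requires PyTorch >= 2.4), and the shapes are arbitrary toy values; the default gate value of -1.0 matches the code above.

import torch
import torch.nn as nn

class GatedNorm(nn.Module):
    """Mirror of the wrapper added above: output = (1 - g) * x + g * norm(x)."""
    def __init__(self, original_norm, initial_gate_value=-1.0):
        super().__init__()
        self.norm = original_norm
        self.gate = nn.Parameter(torch.tensor(initial_gate_value))

    def forward(self, x, *args, **kwargs):
        normed = self.norm(x, *args, **kwargs)
        g = torch.sigmoid(self.gate)
        return (1.0 - g) * x + g * normed

# Stand-in for AlinlightRMSNorm (assumption: stock nn.RMSNorm, PyTorch >= 2.4).
gated = GatedNorm(nn.RMSNorm(16))
x = torch.randn(2, 16)

g = torch.sigmoid(gated.gate)
print(f"g = {g.item():.3f}")  # sigmoid(-1.0) ≈ 0.269, so the wrapper is mostly identity at init
out = gated(x)
# The output is a convex combination of the raw input and its normalized form.
assert torch.allclose(out, (1.0 - g) * x + g * gated.norm(x))

Because the gate is a single learnable scalar per wrapped norm, each site (the MLP's pre_down_norm and the attention q_norm/k_norm) can independently learn how much normalization to apply, or to bypass it almost entirely.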