Yuchan committed
Commit 1e5945c · verified · 1 Parent(s): 7f390c3

Update Inference.py

Files changed (1)
  1. Inference.py +2 -30
Inference.py CHANGED
@@ -178,11 +178,7 @@ class LoU(layers.Layer):
         super().__init__()
         self.d_model = d_model
         self.clip_value = float(clip_value)
-        self.eps = float(eps)
-        self.Q = layers.Dense(d_model, dtype='float32')
-        self.K = layers.Dense(d_model, dtype='float32')
-        self.V = layers.Dense(d_model, dtype='float32')
-        self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
+        self.mha = layers.MultiHeadAttention(8, 20)
         self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
 
         self.glu = SwiGLU(d_model, 320)
@@ -193,31 +189,7 @@ class LoU(layers.Layer):
         residual = x_f32
         x_f32 = self.norm1(x)
 
-        q = self.Q(x_f32)
-        k = self.K(x_f32)
-        V = self.V(x_f32)
-        g_q = (tf.nn.tanh(q) + 1.0) / 2.0
-        g_k = (tf.nn.tanh(k) + 1.0) / 2.0
-        score = g_q * g_k
-
-        score = tf.cumsum(score, axis=1)  # (B, L, D)
-
-        # 💡 Modified part: normalize by the mean of the cumulative sum up to the current token
-        seq_len = tf.shape(score)[1]
-        # Expand [1, 2, 3, ..., L] across the D_model dimension
-        count_for_mean = tf.cast(tf.range(seq_len) + 1, score.dtype)
-        count_for_mean = tf.reshape(count_for_mean, (1, seq_len, 1))
-
-        # Divide the cumulative sum by the number of tokens seen so far to get the mean cumulative sum (B, L, D)
-        score_mean = score / count_for_mean
-
-        # Set the normalization denominator
-        denom = tf.maximum(score_mean, self.eps)
-        score_norm = score / denom
-        # -----------------------------------------------
-
-        score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
-        x_comb = score_clipped * V
+        x_comb = self.mha(x, x, x, use_causal_mask=True)
 
         out = self.norm(x_comb + residual)
         out = self.cross(out, z)
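
For context, a minimal standalone sketch of the forward path after this commit, assuming TensorFlow ≥ 2.10 (for use_causal_mask) and an illustrative d_model of 160 (8 heads × key_dim 20); the SwiGLU and cross blocks from the full class are omitted, and the norm layer below is only a stand-in so the snippet runs on its own.

import tensorflow as tf
from tensorflow.keras import layers

d_model = 160                                   # assumed: 8 heads * key_dim 20
mha = layers.MultiHeadAttention(8, 20)          # same constructor arguments as the new self.mha
norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')  # stand-in for self.norm

x = tf.random.normal((2, 16, d_model))          # (batch, seq_len, d_model), illustrative shape
residual = tf.cast(x, tf.float32)
x_f32 = norm1(x)                                # pre-norm, kept as in the unchanged lines
x_comb = mha(x, x, x, use_causal_mask=True)     # causal self-attention, replacing the gated cumulative-sum score
out = norm(x_comb + residual)                   # residual add + norm from the unchanged context
print(out.shape)                                # (2, 16, 160)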