# model-prototype / Test.py
import os, json, random, numpy as np, tensorflow as tf
from tensorflow.keras import layers, Model
import sentencepiece as spm
import requests
# ===============================
# 0๏ธโƒฃ ํ™˜๊ฒฝ ์„ค์ •
# ===============================
TOKENIZER_PATH = "bpe.model"
DATA_PATH = "corpus.txt" # 36M ๋ฌธ์žฅ ํ…์ŠคํŠธ ํŒŒ์ผ
MAX_LEN = 128
EMBED_DIM = 384
LATENT_DIM = 384
BATCH_SIZE = 400
NEGATIVE_RATIO = 1  # number of negative samples per positive pair
def download_file(url, save_path):
    if not os.path.exists(save_path):
        print(f"Downloading {save_path} ...")
        r = requests.get(url, stream=True)
        r.raise_for_status()
        with open(save_path, "wb") as f:
            for chunk in r.iter_content(8192*2):
                f.write(chunk)
        print(f"✅ {save_path} saved")
download_file("https://huggingface.co/datasets/OpenLab-NLP/ko-corpus/resolve/main/bpe.model?download=true", TOKENIZER_PATH)
download_file("https://huggingface.co/datasets/OpenLab-NLP/ko-corpus/resolve/main/shuffled_corpus%20(1).txt?download=true", DATA_PATH)
# ===============================
# 2๏ธโƒฃ ํ† ํฌ๋‚˜์ด์ € ์ค€๋น„
# ===============================
sp = spm.SentencePieceProcessor(model_file=TOKENIZER_PATH)
# piece_to_id() returns the <unk> id (not -1) for unknown pieces, so verify the round-trip
_pad = sp.piece_to_id("<pad>")
pad_id = _pad if sp.id_to_piece(_pad) == "<pad>" else 0
vocab_size = sp.get_piece_size()
def encode_sentence(sentence, max_len=MAX_LEN):
    return sp.encode(sentence, out_type=int)[:max_len]

def pad_sentence(tokens):
    return tokens + [pad_id] * (MAX_LEN - len(tokens))
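# Optional sanity check (illustrative, safe to remove): encode + pad should always
# yield exactly MAX_LEN token ids, with pad_id filling the tail. The probe sentence
# below is arbitrary.
_demo_ids = pad_sentence(encode_sentence("예시 문장입니다."))
assert len(_demo_ids) == MAX_LEN, f"expected {MAX_LEN} ids, got {len(_demo_ids)}"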
def gen_pairs_streaming(txt_path=DATA_PATH, negative_ratio=NEGATIVE_RATIO):
    with open(txt_path, "r", encoding="utf-8") as f:
        sentences = [line.strip() for line in f if line.strip()]
    while True:
        for s1 in sentences:
            # positive pair (the sentence paired with itself)
            x1 = pad_sentence(encode_sentence(s1))
            yield (x1, x1), 1.0
            # negative pairs (any sentence other than itself)
            for _ in range(negative_ratio):
                s2 = s1
                while s2 == s1:
                    s2 = random.choice(sentences)
                x2 = pad_sentence(encode_sentence(s2))
                yield (x1, x2), 0.0
dataset = tf.data.Dataset.from_generator(
    lambda: gen_pairs_streaming(),
    output_types=((tf.int32, tf.int32), tf.float32),
    output_shapes=(((MAX_LEN,), (MAX_LEN,)), ())
).shuffle(1024).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
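# Optional debug peek (illustrative, disabled by default): pulling one batch re-reads the
# whole corpus through a fresh generator, so only enable this on a small test file.
DEBUG_PEEK = False
if DEBUG_PEEK:
    (xa, xb), y = next(iter(dataset))
    print("pair batch shapes:", xa.shape, xb.shape, "labels:", y.shape)  # (B, MAX_LEN), (B, MAX_LEN), (B,)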
class EncoderBlock(tf.keras.layers.Layer):
    def __init__(self, embed_dim=EMBED_DIM, ff_dim=1152, seq_len=MAX_LEN):
        super().__init__()
        self.embed_dim = embed_dim
        self.seq_len = seq_len
        self.fc1 = layers.Dense(ff_dim)
        self.fc2 = layers.Dense(embed_dim)
        self.fc3 = layers.Dense(ff_dim)
        self.fc4 = layers.Dense(embed_dim)
        # defined as (seq_len, embed_dim): used for the (L -> D) projection
        self.w_proj = self.add_weight(
            name="w_proj_L_to_D",
            shape=(seq_len, embed_dim),
            initializer="glorot_uniform",
            trainable=True
        )
        self.alpha2 = layers.Dense(1)
        self.ln = layers.LayerNormalization(epsilon=1e-5)
        self.ln1 = layers.LayerNormalization(epsilon=1e-5)
        self.ln2 = layers.LayerNormalization(epsilon=1e-5)

    def call(self, x):
        # x: (B, L, D)
        x_norm = self.ln(x)
        h = self.fc1(x_norm)                        # (B, L, ff_dim)
        g, v = tf.split(h, 2, axis=-1)              # (B, L, ff_dim/2) each
        h = tf.nn.silu(g) * v
        h = self.fc2(h)                             # (B, L, D)
        # --- token-token similarity: matmul -> (B, L, L) ---
        sim = tf.matmul(h, h, transpose_b=True)     # (B, L, L)
        # (optional) add extra normalization/scaling here if desired
        sim = tf.nn.softmax(sim, axis=-1)           # (B, L, L)
        # --- (B, L, L) -> (B, L, D): project via tensordot with matching axes ---
        # w_proj: (L, D); the last axis of sim matches the first axis of w_proj
        h2 = tf.tensordot(sim, self.w_proj, axes=[[2], [0]])  # (B, L, D)
        # shapes now line up, so the gated element-wise product works
        v_gate = tf.nn.softmax(self.alpha2(v), axis=1)  # (B, L, 1)
        v = v_gate * h2                             # (B, L, D)
        x_norm = x_norm + self.ln2(v)
        z = self.fc3(x_norm)
        g, v = tf.split(z, 2, axis=-1)
        z = tf.nn.silu(g) * v
        z = self.fc4(z)
        return x_norm + self.ln1(z)
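# Quick shape check on a throwaway instance (illustrative, safe to remove): the block should
# map (B, L, D) -> (B, L, D) for L == MAX_LEN, since w_proj is tied to the fixed sequence length.
_blk_out = EncoderBlock()(tf.zeros((2, MAX_LEN, EMBED_DIM)))
print("EncoderBlock output shape:", _blk_out.shape)  # expected (2, 128, 384)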
class L2NormLayer(layers.Layer):
    def __init__(self, axis=1, epsilon=1e-10, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis
        self.epsilon = epsilon

    def call(self, inputs):
        return tf.math.l2_normalize(inputs, axis=self.axis, epsilon=self.epsilon)

    def get_config(self):
        return {"axis": self.axis, "epsilon": self.epsilon, **super().get_config()}
class SentenceEncoder(tf.keras.Model):
    def __init__(self, vocab_size, embed_dim=384, latent_dim=384, max_len=128, pad_id=pad_id):
        super().__init__()
        self.pad_id = pad_id
        self.embed = layers.Embedding(vocab_size, embed_dim)
        self.pos_embed = layers.Embedding(input_dim=max_len, output_dim=embed_dim)
        self.blocks = [EncoderBlock() for _ in range(1)]
        self.attn_pool = layers.Dense(1)
        self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype=tf.float32)
        self.latent = layers.Dense(latent_dim, activation=None)  # no tanh on the latent projection
        self.l2norm = L2NormLayer()  # L2-normalize the sentence vector

    def call(self, x):
        positions = tf.range(tf.shape(x)[1])[tf.newaxis, :]
        x_embed = self.embed(x) + self.pos_embed(positions)
        mask = tf.cast(tf.not_equal(x, self.pad_id), tf.float32)
        x = x_embed
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        # attention pooling over non-pad positions
        scores = self.attn_pool(x)
        scores = tf.where(tf.equal(mask[..., tf.newaxis], 0), -1e9, scores)
        scores = tf.nn.softmax(scores, axis=1)
        pooled = tf.reduce_sum(x * scores, axis=1)
        latent = self.latent(pooled)
        return self.l2norm(latent)  # return the L2-normalized sentence embedding
# ===============================
# 5๏ธโƒฃ Cosine similarity layer + Contrastive Loss
# ===============================
class CosineSimilarityLayer(layers.Layer):
    def call(self, inputs):
        v1, v2 = inputs
        # inputs are already L2-normalized, so the dot product equals cosine similarity
        return tf.reduce_sum(v1 * v2, axis=-1)
def contrastive_loss(margin=0.5):
    def loss(y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        dist = 1 - y_pred
        pos_loss = y_true * tf.square(dist)
        neg_loss = (1 - y_true) * tf.square(tf.maximum(margin - dist, 0))
        return tf.reduce_mean(pos_loss + neg_loss)
    return loss
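# Worked example (illustrative, safe to remove): with margin=0.5, a positive pair scored at
# cosine 0.9 gives (1 - 0.9)^2 = 0.01, while a negative pair scored at 0.9 gives
# max(0.5 - 0.1, 0)^2 = 0.16, i.e. overly similar negatives are pushed apart.
_loss_fn = contrastive_loss(margin=0.5)
print("pos pair loss:", float(_loss_fn(tf.constant([1.0]), tf.constant([0.9]))))  # ~0.01
print("neg pair loss:", float(_loss_fn(tf.constant([0.0]), tf.constant([0.9]))))  # ~0.16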
encoder = SentenceEncoder(vocab_size=vocab_size)
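# Sanity check (illustrative, safe to remove): the encoder should return a unit-length vector
# of size LATENT_DIM for any input, since the final layer is L2 normalization. The probe
# sentence below is arbitrary.
_probe = np.array([pad_sentence(encode_sentence("테스트 문장"))])
_vec = encoder(_probe).numpy()[0]
print("encoder output dim:", _vec.shape[0], "norm:", float(np.linalg.norm(_vec)))  # 384, ~1.0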
# ===============================
# 6๏ธโƒฃ ์‹œ์•” ๋ชจ๋ธ ์ •์˜
# ===============================
input1 = tf.keras.Input(shape=(MAX_LEN,), dtype=tf.int32)
input2 = tf.keras.Input(shape=(MAX_LEN,), dtype=tf.int32)
v1 = encoder(input1)
v2 = encoder(input2)
cos_sim = CosineSimilarityLayer()([v1, v2])
siamese_model = tf.keras.Model([input1, input2], cos_sim)
siamese_model.compile(optimizer=tf.keras.optimizers.Adam(1e-5), loss=contrastive_loss(margin=0.5))
siamese_model.summary()
# ===============================
# 7๏ธโƒฃ ํ•™์Šต
# ===============================
#steps_per_epoch = 36757266 // 400
steps_per_epoch = 1000000 // 400
# generator-based streaming training
siamese_model.fit(dataset, epochs=1, steps_per_epoch=steps_per_epoch)  # adjust steps_per_epoch as needed
encoder.save_weights("encoder.weights.h5")
siamese_model.save_weights("siamese_model.weights.h5")
# ===============================
# 8๏ธโƒฃ corpus ๋ฒกํ„ฐ ์ƒ์„ฑ + ์บ์‹ฑ (์•ˆ์ „ํ•˜๊ฒŒ ์ƒˆ๋กœ ์ƒ์„ฑ)
# ===============================
LIMIT = 1000  # number of corpus sentences used for search
prompts = []
# prompts ๋จผ์ € ์ฝ๊ธฐ
with open(DATA_PATH, "r", encoding="utf-8") as f:
for i, line in enumerate(f):
if i >= LIMIT:
break
line = line.strip()
if line:
prompts.append(line)
def get_sentence_vector(sentence):
    tokens = pad_sentence(encode_sentence(sentence))
    return encoder(np.array([tokens])).numpy()[0]
# always regenerate corpus_vectors (ignore any existing .npy cache)
corpus_vectors = np.stack([get_sentence_vector(p) for p in prompts]).astype(np.float16)
np.save("corpus_vectors.npy", corpus_vectors)
# compute vector norms
corpus_norms = np.linalg.norm(corpus_vectors, axis=1)
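# Note (illustrative): the encoder already L2-normalizes its output, so these norms should be
# ~1.0 up to float16 rounding; dividing by them in search() below is a cheap extra safeguard.
print("corpus vector norms (first 3):", corpus_norms[:3])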
# ===============================
# 9๏ธโƒฃ ๊ฒ€์ƒ‰ ํ•จ์ˆ˜
# ===============================
def search(query, top_k=3):
    q_vec = get_sentence_vector(query).astype(np.float16)
    sims = corpus_vectors @ q_vec
    sims /= (corpus_norms * np.linalg.norm(q_vec) + 1e-8)
    # clamp top_k to the corpus size
    top_k = min(top_k, len(prompts))
    top_idx = np.argsort(sims)[::-1][:top_k]
    return [(prompts[i], float(sims[i])) for i in top_idx]
# ===============================
# 🔟 Quick test
# ===============================
query = "우리가 핸드폰, 배를 세계에서 제일 잘 만드는 것 이상으로 사랑을 제일 잘 실천할 수 있는 능력, 자질, 저력이 우리에게 있다."  # "We have the ability, character, and strength to practice love even better than we build the world's best phones and ships."
results = search(query)
for p, s in results:
    print(f"Prompt: {p}\nSimilarity: {s:.3f}\n---")
query = "안녕하세요! 오늘 날씨 어떤가요?"  # "Hello! How's the weather today?"
results = search(query)
for p, s in results:
    print(f"Prompt: {p}\nSimilarity: {s:.3f}\n---")