| | import mlx.core as mx |
| | import mlx.nn as nn |
| | import json |
| | from dataclasses import dataclass |
| | from pathlib import Path |
| |
|
| |
|
| | @dataclass |
| | class ModelArgs: |
| | hidden_size: int |
| | num_attention_heads: int |
| | num_hidden_layers: int |
| | vocab_size: int |
| | intermediate_size: int |
| | intermediate_size_mlp: int = None |
| | num_key_value_heads: int = 0 |
| | rms_norm_eps: float = 1e-5 |
| | rope_theta: float = 10000.0 |
| | head_dim: int = None |
| | use_dual_mlp: bool = False |
| | tie_word_embeddings: bool = True |
| | use_qk_norm: bool = False |
| | attn_scale: float = 1.0 |
| | no_rope_layers: list | None = None |
| | attention_chunk_size: int | None = None |
| | attn_temperature_tuning: bool = False |
| |
|
| | @classmethod |
| | def from_dict(cls, params): |
| | return cls( |
| | hidden_size=params["hidden_size"], |
| | num_attention_heads=params["num_attention_heads"], |
| | num_hidden_layers=params["num_hidden_layers"], |
| | vocab_size=params["vocab_size"], |
| | intermediate_size=params["intermediate_size"], |
| | intermediate_size_mlp=params.get("intermediate_size_mlp"), |
| | num_key_value_heads=params.get("num_key_value_heads", 0), |
| | rms_norm_eps=params.get("rms_norm_eps", 1e-5), |
| | rope_theta=params.get("rope_theta", 10000.0), |
| | head_dim=params.get("head_dim"), |
| | |
| | use_dual_mlp=False, |
| | tie_word_embeddings=params.get("tie_word_embeddings", True), |
| | use_qk_norm=params.get("use_qk_norm", False), |
| | attn_scale=params.get("attn_scale", 1.0), |
| | no_rope_layers=params.get("no_rope_layers"), |
| | attention_chunk_size=params.get("attention_chunk_size"), |
| | attn_temperature_tuning=params.get("attn_temperature_tuning", False), |
| | ) |
| |
|
| |
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain."""

    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def _norm(self, x):
        """Divide ``x`` by the RMS of its last axis (``eps`` added under the root)."""
        mean_square = mx.mean(mx.square(x), axis=-1, keepdims=True)
        return x * mx.rsqrt(mean_square + self.eps)

    def __call__(self, x):
        """Normalise in float32, cast back to the input dtype, then apply the gain."""
        input_dtype = x.dtype
        normalised = self._norm(x.astype(mx.float32))
        return self.weight * normalised.astype(input_dtype)
| |
|
| |
|
class Attention(nn.Module):
    """Self-attention layer with optional fewer KV heads, QK-norm and RoPE."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        # Kept on the instance because the enclosing block reads it back.
        self.args = args

        n_heads = args.num_attention_heads
        n_kv_heads = args.num_key_value_heads
        if n_kv_heads <= 0:
            # No explicit KV head count: one KV head per query head.
            n_kv_heads = n_heads
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads

        head_dim = getattr(args, "head_dim", None)
        if head_dim is None:
            head_dim = args.hidden_size // n_heads
        self.head_dim = head_dim
        self.scale = head_dim**-0.5

        model_dim = args.hidden_size
        q_width = n_heads * head_dim
        kv_width = n_kv_heads * head_dim
        self.q_proj = nn.Linear(model_dim, q_width, bias=False)
        self.k_proj = nn.Linear(model_dim, kv_width, bias=False)
        self.v_proj = nn.Linear(model_dim, kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, model_dim, bias=False)

        use_qk_norm = getattr(args, "use_qk_norm", False)
        self.q_norm = RMSNorm(head_dim, eps=args.rms_norm_eps) if use_qk_norm else None
        self.k_norm = RMSNorm(head_dim, eps=args.rms_norm_eps) if use_qk_norm else None

        self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)

    def __call__(
        self,
        x,
        mask=None,
        cache=None,
        apply_rope: bool = True,
        attn_temp: float | None = None,
    ):
        """Run attention over ``x``.

        Args:
            x: Input with three axes, unpacked as (batch, sequence, features).
            mask: Passed unchanged to the fused attention kernel.
            cache: Optional object exposing ``offset`` and
                ``update_and_fetch(keys, values)``.
            apply_rope: When false, rotary embeddings are skipped.
            attn_temp: Optional multiplier applied to the softmax scale.

        Returns:
            The output projection of the attention result.
        """
        batch, seq_len, _ = x.shape

        def split_heads(tensor, heads):
            # (batch, seq, heads * dim) -> (batch, heads, seq, dim)
            return tensor.reshape(batch, seq_len, heads, -1).transpose(0, 2, 1, 3)

        queries = split_heads(self.q_proj(x), self.n_heads)
        keys = split_heads(self.k_proj(x), self.n_kv_heads)
        values = split_heads(self.v_proj(x), self.n_kv_heads)

        if self.q_norm is not None:
            queries = self.q_norm(queries)
            keys = self.k_norm(keys)

        has_cache = cache is not None
        if apply_rope:
            # Rotate relative to the number of positions already cached.
            position = cache.offset if has_cache else 0
            queries = self.rope(queries, offset=position)
            keys = self.rope(keys, offset=position)
        if has_cache:
            keys, values = cache.update_and_fetch(keys, values)

        if self.n_heads != self.n_kv_heads:
            # Expand the KV heads so each query head has a matching KV head.
            group = self.n_heads // self.n_kv_heads
            keys = mx.repeat(keys, group, axis=1)
            values = mx.repeat(values, group, axis=1)

        if attn_temp is None:
            scale = self.scale
        else:
            scale = self.scale * attn_temp

        attended = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=scale, mask=mask
        )
        merged = attended.transpose(0, 2, 1, 3).reshape(batch, seq_len, -1)
        return self.o_proj(merged)
| |
|
| |
|
class SwiGLUMLP(nn.Module):
    """Standard LLaMA-style gated MLP (SwiGLU).

    Computes ``down_proj(activation(gate_proj(x)) * up_proj(x))``.

    Args:
        dim: Input and output feature width.
        intermediate_size: Width of the gate and up projections.
        activation: Callable applied to the gate branch. Defaults to
            ``nn.silu``.
    """

    def __init__(self, dim, intermediate_size, activation=nn.silu):
        super().__init__()
        self.gate_proj = nn.Linear(dim, intermediate_size, bias=False)
        self.up_proj = nn.Linear(dim, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, dim, bias=False)
        # Honour the constructor argument instead of hard-coding nn.silu.
        self.activation = activation

    def __call__(self, x):
        gate = self.activation(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))
| |
|
| |
|
class DualMLP(nn.Module):
    """Dense dual-branch MLP: gated + plain.

    The gated branch computes ``g_down(activation(g_gate(x)) * g_up(x))`` and
    the plain branch ``p_down(activation(p_up(x)))``; the two are summed.

    Args:
        dim: Input and output feature width.
        intermediate_gated: Hidden width of the gated branch.
        intermediate_plain: Hidden width of the plain branch.
        activation: Callable used in both branches. Defaults to ``nn.silu``.
    """

    def __init__(self, dim, intermediate_gated, intermediate_plain, activation=nn.silu):
        super().__init__()
        self.g_up = nn.Linear(dim, intermediate_gated, bias=False)
        self.g_gate = nn.Linear(dim, intermediate_gated, bias=False)
        self.g_down = nn.Linear(intermediate_gated, dim, bias=False)

        self.p_up = nn.Linear(dim, intermediate_plain, bias=False)
        self.p_down = nn.Linear(intermediate_plain, dim, bias=False)

        # Honour the constructor argument instead of hard-coding nn.silu.
        self.activation = activation

    def __call__(self, x):
        gated_out = self.g_down(self.activation(self.g_gate(x)) * self.g_up(x))
        plain_out = self.p_down(self.activation(self.p_up(x)))
        return gated_out + plain_out
| |
|
| |
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention and MLP, each with a residual add."""

    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()
        self.attention = Attention(args)
        self.layer_idx = layer_idx

        # ``no_rope_layers`` is read as per-layer "disable RoPE" flags. A list
        # in which every entry is truthy is treated as disabling nothing, and
        # a list too short to cover this layer is ignored.
        flags = args.no_rope_layers
        disable_rope = False
        if isinstance(flags, list) and len(flags) > layer_idx:
            if not all(bool(flag) for flag in flags):
                disable_rope = bool(flags[layer_idx])
        self.apply_rope = not disable_rope

        if args.use_dual_mlp and args.intermediate_size_mlp:
            self.feed_forward = DualMLP(
                args.hidden_size,
                args.intermediate_size,
                args.intermediate_size_mlp,
            )
        else:
            # ``intermediate_size_mlp`` is optional; without this fallback the
            # linear layers would be built with a ``None`` width.
            mlp_width = args.intermediate_size_mlp or args.intermediate_size
            self.feed_forward = SwiGLUMLP(args.hidden_size, mlp_width)

        self.attention_norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.ffn_norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    @staticmethod
    def _causal_mask(x, cache):
        """Additive causal mask for ``x``, or ``None`` for a single position.

        When ``cache`` already holds ``offset`` positions, the mask has shape
        (L, offset + L) so it lines up with the keys returned by the cache.
        """
        seq_len = x.shape[1]
        if seq_len <= 1:
            return None

        offset = getattr(cache, "offset", 0) if cache is not None else 0
        if not offset:
            return nn.MultiHeadAttention.create_additive_causal_mask(seq_len).astype(
                x.dtype
            )

        # Query i sits at absolute position offset + i and may attend to every
        # key position <= offset + i.
        query_pos = mx.arange(offset, offset + seq_len)
        key_pos = mx.arange(offset + seq_len)
        blocked = query_pos[:, None] < key_pos[None]
        return (blocked.astype(mx.float32) * -1e9).astype(x.dtype)

    def __call__(self, x, mask=None, cache=None):
        """Apply the layer to ``x``.

        Args:
            x: Input whose axis 1 is the sequence axis.
            mask: Optional attention mask. When ``None``, a causal mask is
                built from the sequence length and the cache offset.
            cache: Optional per-layer KV cache forwarded to the attention.

        Returns:
            The layer output, same shape as ``x``.
        """
        attn_mask = mask if mask is not None else self._causal_mask(x, cache)

        args = self.attention.args
        # Currently a neutral factor: 1.0 leaves the attention scale unchanged.
        attn_temp = 1.0 if getattr(args, "attn_temperature_tuning", False) else None

        r = self.attention(
            self.attention_norm(x),
            attn_mask,
            cache,
            apply_rope=self.apply_rope,
            attn_temp=attn_temp,
        )
        h = x + r
        r = self.feed_forward(self.ffn_norm(h))
        return h + r
| |
|
| |
|
class Model(nn.Module):
    """Decoder stack: token embedding, transformer blocks, final norm, LM head."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size)

        blocks = []
        for index in range(args.num_hidden_layers):
            blocks.append(TransformerBlock(args=args, layer_idx=index))
        self.layers = blocks

        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

        # A separate output projection exists only when embeddings are untied.
        if not self.args.tie_word_embeddings:
            self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(self, inputs, cache=None):
        """Return logits for ``inputs``.

        Args:
            inputs: Token ids fed to the embedding table.
            cache: Optional sequence of per-layer caches, paired with the
                layers in order; ``None`` runs every layer without a cache.
        """
        hidden = self.tok_embeddings(inputs)

        layer_caches = cache if cache is not None else [None] * len(self.layers)
        for block, block_cache in zip(self.layers, layer_caches):
            hidden = block(hidden, None, block_cache)

        hidden = self.norm(hidden)

        if not self.args.tie_word_embeddings:
            return self.output(hidden)
        # Tied head: reuse the embedding matrix as the output projection.
        return hidden @ self.tok_embeddings.weight.T
| |
|
| |
|
def load_model(model_path: str):
    """Build a ``Model`` from a directory and load its weights.

    Reads ``config.json`` and ``model.safetensors`` from ``model_path``. The
    checkpoint is read in a single pass: tensor names are translated to this
    file's module names, and the dual-MLP variant is enabled when any
    ``g_up`` / ``p_up`` weights are present.

    Args:
        model_path: Directory containing ``config.json`` and
            ``model.safetensors``.

    Returns:
        The ``Model`` with the checkpoint tensors applied via ``update``.
    """
    model_dir = Path(model_path)
    with open(model_dir / "config.json", "r") as f:
        config = json.load(f)

    from safetensors import safe_open
    from mlx.utils import tree_unflatten

    # Applied in order to every tensor name.
    renames = (
        ("model.embed_tokens", "tok_embeddings"),
        ("model.layers", "layers"),
        ("self_attn", "attention"),
        ("input_layernorm", "attention_norm"),
        ("post_attention_layernorm", "ffn_norm"),
        ("mlp.", "feed_forward."),
        ("model.norm", "norm"),
        # Map the output head to the attribute name Model uses, so an untied
        # head is loaded and a tied one can be dropped below.
        ("lm_head", "output"),
    )
    dual_markers = (
        ".feed_forward.g_up.weight",
        ".mlp.g_up.weight",
        ".feed_forward.p_up.weight",
        ".mlp.p_up.weight",
    )

    weights = {}
    has_dual = False
    with safe_open(model_dir / "model.safetensors", framework="mlx") as f:
        for name in f.keys():
            if not has_dual and any(marker in name for marker in dual_markers):
                has_dual = True
            tensor = f.get_tensor(name)
            for old, new in renames:
                name = name.replace(old, new)
            weights[name] = tensor

    args = ModelArgs.from_dict(config)
    args.use_dual_mlp = has_dual
    model = Model(args)

    # With tied embeddings Model has no ``output`` layer, so drop its weight.
    if args.tie_word_embeddings:
        weights.pop("output.weight", None)

    model.update(tree_unflatten(list(weights.items())))
    return model
| |
|