formatting

Prince Canuma 2025-03-12 10:35:12 +01:00
parent 3c15130f39
commit 822546dc91


@@ -5,8 +5,8 @@ from typing import Any, Dict, List, Optional, Union
 import mlx.core as mx
 import mlx.nn as nn
-from .cache import KVCache, RotatingKVCache
 from .base import BaseModelArgs, create_attention_mask
+from .cache import KVCache, RotatingKVCache


 @dataclass
@@ -102,13 +102,13 @@ class Attention(nn.Module):
         if mask.shape[-1] != key_len:
             mask = mask[..., -key_len:]

         output = mx.fast.scaled_dot_product_attention(
             queries, keys, values, scale=self.scale, mask=mask
         )

         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)


 class RMSNorm(nn.Module):
     def __init__(self, dims: int, eps: float = 1e-5):
         super().__init__()
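For context on the hunk above: the mask is clipped to the current key length before the fused attention kernel runs. A minimal sketch of that call, with illustrative shapes rather than Gemma 3's real head counts:

```python
import mlx.core as mx

# Toy dimensions for illustration only; Gemma 3's actual head counts differ.
B, L, n_heads, n_kv_heads, head_dim = 1, 8, 4, 1, 64
queries = mx.random.normal((B, n_heads, L, head_dim))
keys = mx.random.normal((B, n_kv_heads, L, head_dim))
values = mx.random.normal((B, n_kv_heads, L, head_dim))

# Fused softmax(Q K^T * scale + mask) V; grouped-query attention is handled
# automatically when keys/values carry fewer heads than the queries.
output = mx.fast.scaled_dot_product_attention(
    queries, keys, values, scale=head_dim**-0.5, mask=None
)
print(output.shape)  # (1, 4, 8, 64)
```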
@@ -118,6 +118,7 @@ class RMSNorm(nn.Module):
     def __call__(self, x):
         return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)

+
 class MLP(nn.Module):
     def __init__(self, dim, hidden_dim):
         super().__init__()
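The `1.0 + self.weight` in this hunk is Gemma's zero-centered norm scale: checkpoints store the gain as an offset from 1, so the unit gain is added back at call time. A self-contained sketch (the `__init__` body shown here is an assumption; loaded checkpoint weights replace it anyway):

```python
import mlx.core as mx
import mlx.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.zeros((dims,))  # assumed init; zero-centered gain
        self.eps = eps

    def __call__(self, x):
        # Gemma checkpoints store (gain - 1), hence the 1.0 added here.
        return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)

print(RMSNorm(8)(mx.random.normal((2, 8))).shape)  # (2, 8)
```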
@@ -138,9 +139,7 @@ class TransformerBlock(nn.Module):
         self.self_attn = Attention(args, layer_idx)
         self.mlp = MLP(args.hidden_size, args.intermediate_size)
         self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(
-            args.hidden_size, eps=args.rms_norm_eps
-        )
+        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
         self.pre_feedforward_layernorm = RMSNorm(
             args.hidden_size, eps=args.rms_norm_eps
         )
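This hunk only touches the norm constructors, but the names (plus a `post_feedforward_layernorm` assumed by symmetry) suggest Gemma's "sandwich" layout, with a norm on each side of both sub-layers. A hypothetical sketch of the wiring, not code from this commit:

```python
# attn and mlp stand in for the Attention and MLP modules above; the four
# norm callables are the RMSNorm instances built in __init__.
def block_forward(x, attn, mlp, input_norm, post_attn_norm, pre_ffn_norm, post_ffn_norm):
    h = x + post_attn_norm(attn(input_norm(x)))     # attention sub-layer
    return h + post_ffn_norm(mlp(pre_ffn_norm(h)))  # feed-forward sub-layer
```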
@@ -189,7 +188,6 @@ class Gemma3Model(nn.Module):
             cache = [None] * len(self.layers)

         if mask is None:
             # Sliding window
             j = self.args.sliding_window_pattern
             mask = create_attention_mask(h, cache[j - 1 : j])
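The indexing here leans on the layer layout: every `sliding_window_pattern`-th layer attends globally, and `cache[j - 1 : j]` selects the first such layer's cache, whose offset grows without wrapping and is therefore safe to size the causal mask from. A small illustration with assumed values:

```python
pattern = 6   # assumed for illustration; the real value comes from the config
num_layers = 12
global_layers = [i for i in range(num_layers) if i % pattern == pattern - 1]
print(global_layers)  # [5, 11] -> layers with full (global) attention
print(pattern - 1)    # 5 -> the index picked by cache[j - 1 : j] above
```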
@@ -219,9 +217,7 @@ class Model(nn.Module):
     def sanitize(self, weights):
         if "lm_head.weight" not in weights:
-            weights["lm_head.weight"] = weights[
-                "model.embed_tokens.weight"
-            ]
+            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
         return {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
         }
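What `sanitize` does, in isolation: tie `lm_head.weight` to the embedding matrix when the checkpoint omits it (Gemma ties input and output embeddings) and drop precomputed rotary `inv_freq` buffers. A runnable sketch with placeholder strings standing in for real tensors:

```python
weights = {
    "model.embed_tokens.weight": "<embedding tensor>",
    "model.layers.0.self_attn.rotary_emb.inv_freq": "<rotary buffer>",
}
if "lm_head.weight" not in weights:
    weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
weights = {
    k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}
print(sorted(weights))  # ['lm_head.weight', 'model.embed_tokens.weight']
```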
@@ -245,9 +241,7 @@ class Model(nn.Module):
                 i % self.args.sliding_window_pattern
                 == self.args.sliding_window_pattern - 1
             ):
-                caches.append(
-                    KVCache()
-                )
+                caches.append(KVCache())
             else:
                 caches.append(
                     RotatingKVCache(max_size=self.args.sliding_window, keep=0)
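The branch being reformatted builds the per-layer cache list: an unbounded `KVCache` for each global-attention layer and a `RotatingKVCache` bounded by the sliding window everywhere else. A standalone sketch of that layout (import path and sizes are assumptions, not from this diff):

```python
from mlx_lm.models.cache import KVCache, RotatingKVCache  # import path assumed

def make_cache_sketch(num_layers=12, pattern=6, window=512):
    caches = []
    for i in range(num_layers):
        if i % pattern == pattern - 1:
            caches.append(KVCache())  # global layer: cache grows unbounded
        else:
            caches.append(RotatingKVCache(max_size=window, keep=0))  # local layer
    return caches

print([type(c).__name__ for c in make_cache_sketch()])
```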