formatting

Prince Canuma 2025-03-12 10:35:12 +01:00
parent 3c15130f39
commit 822546dc91


@@ -5,14 +5,14 @@ from typing import Any, Dict, List, Optional, Union
 import mlx.core as mx
 import mlx.nn as nn

-from .cache import KVCache, RotatingKVCache
 from .base import BaseModelArgs, create_attention_mask
+from .cache import KVCache, RotatingKVCache


 @dataclass
 class ModelArgs(BaseModelArgs):
     model_type: str
-    hidden_size: int =1152
+    hidden_size: int = 1152
     num_hidden_layers: int = 26
     intermediate_size: int = 6912
     num_attention_heads: int = 4
@@ -102,13 +102,13 @@ class Attention(nn.Module):
             if mask.shape[-1] != key_len:
                 mask = mask[..., -key_len:]

         output = mx.fast.scaled_dot_product_attention(
             queries, keys, values, scale=self.scale, mask=mask
         )

         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)


 class RMSNorm(nn.Module):
     def __init__(self, dims: int, eps: float = 1e-5):
         super().__init__()
@@ -118,6 +118,7 @@ class RMSNorm(nn.Module):
     def __call__(self, x):
         return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)

+
 class MLP(nn.Module):
     def __init__(self, dim, hidden_dim):
         super().__init__()
@@ -138,9 +139,7 @@ class TransformerBlock(nn.Module):
         self.self_attn = Attention(args, layer_idx)
         self.mlp = MLP(args.hidden_size, args.intermediate_size)
         self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(
-            args.hidden_size, eps=args.rms_norm_eps
-        )
+        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
         self.pre_feedforward_layernorm = RMSNorm(
             args.hidden_size, eps=args.rms_norm_eps
         )
@@ -189,7 +188,6 @@ class Gemma3Model(nn.Module):
             cache = [None] * len(self.layers)

         if mask is None:
-            # Sliding window
             j = self.args.sliding_window_pattern
             mask = create_attention_mask(h, cache[j - 1 : j])
@@ -219,9 +217,7 @@ class Model(nn.Module):
     def sanitize(self, weights):
         if "lm_head.weight" not in weights:
-            weights["lm_head.weight"] = weights[
-                "model.embed_tokens.weight"
-            ]
+            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
         return {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
         }
@@ -245,9 +241,7 @@ class Model(nn.Module):
                 i % self.args.sliding_window_pattern
                 == self.args.sliding_window_pattern - 1
             ):
-                caches.append(
-                    KVCache()
-                )
+                caches.append(KVCache())
             else:
                 caches.append(
                     RotatingKVCache(max_size=self.args.sliding_window, keep=0)
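
The hunk above keeps the per-layer cache selection unchanged: layers whose index i satisfies i % sliding_window_pattern == sliding_window_pattern - 1 get a full KVCache, while every other layer gets a RotatingKVCache capped at sliding_window. A minimal standalone sketch of that rule follows; the pattern, window size, and layer count used here are illustrative assumptions, not values taken from this diff.

# Sketch of the cache-selection rule; all numeric values below are assumed.
sliding_window_pattern = 6   # assumed: every 6th layer uses global attention
sliding_window = 512         # assumed local-attention window size
num_hidden_layers = 12       # assumed layer count, for illustration only

cache_plan = []
for i in range(num_hidden_layers):
    if i % sliding_window_pattern == sliding_window_pattern - 1:
        # global-attention layer: unbounded KV cache
        cache_plan.append("KVCache()")
    else:
        # local-attention layer: rotating cache capped at the sliding window
        cache_plan.append(f"RotatingKVCache(max_size={sliding_window}, keep=0)")

for i, kind in enumerate(cache_plan):
    print(i, kind)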