From a7d139f484cc613afafde0cb99302518cfcaa516 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Wed, 7 Feb 2024 00:27:05 +1100
Subject: [PATCH] fix(mlx-lm): olmo 1b model (#417)
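
The 1B OLMo checkpoints could not be loaded by the olmo model code:
their config ties the output head to the token embedding and derives
the MLP width from mlp_ratio rather than setting mlp_hidden_size
directly. Handle both cases:

- add mlp_ratio and weight_tying to ModelArgs, defaulting
  mlp_hidden_size to mlp_ratio * d_model when it is not set
- only create the ff_out projection when weights are not tied
- when weight_tying is enabled, compute the logits with the
  transposed token embedding (h @ self.wte.weight.T)
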
---
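Note: the snippet below is only a minimal, illustrative sketch of the
weight-tied output head added in this patch; the toy sizes and variable
names are assumptions, only the wte / weight.T matmul mirrors the diff.

    import mlx.core as mx
    import mlx.nn as nn

    d_model, vocab_size = 64, 128            # toy dimensions, not a real OLMo config
    wte = nn.Embedding(vocab_size, d_model)  # wte.weight has shape (vocab_size, d_model)

    h = mx.zeros((1, 8, d_model))            # (batch, seq, d_model) hidden states
    logits = h @ wte.weight.T                # tied head -> (batch, seq, vocab_size) logits
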
llms/mlx_lm/models/olmo.py | 22 ++++++++++++++++++++--
1 file changed, 20 insertions(+), 2 deletions(-)
diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py
index 8a511a1e..d4df3839 100644
--- a/llms/mlx_lm/models/olmo.py
+++ b/llms/mlx_lm/models/olmo.py
@@ -24,6 +24,15 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float = 10000
     rope_traditional: bool = False
     model_type: str = None
+    mlp_ratio: int = 4
+    weight_tying: bool = False
+
+    def __post_init__(self):
+        self.mlp_hidden_size = (
+            self.mlp_hidden_size
+            if self.mlp_hidden_size is not None
+            else self.mlp_ratio * self.d_model
+        )
 
 
 class LayerNorm(nn.LayerNorm):
@@ -99,6 +108,7 @@ class TransformerBlock(nn.Module):
         h = x + r
 
         x1, x2 = mx.split(self.ff_proj(self.ff_norm(h)), 2, axis=-1)
+
         out = h + self.ff_out(nn.silu(x2) * x1)
         return out, cache
 
@@ -107,9 +117,12 @@ class Transformer(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.n_layers = args.n_layers
+        self.weight_tying = args.weight_tying
+
         self.wte = nn.Embedding(args.embedding_size, args.d_model)
         self.blocks = [TransformerBlock(args=args) for _ in range(args.n_layers)]
-        self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
+        if not self.weight_tying:
+            self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
         self.norm = LayerNorm(args.d_model, affine=False)
 
     def __call__(
@@ -130,7 +143,12 @@ class Transformer(nn.Module):
         for e, block in enumerate(self.blocks):
             h, cache[e] = block(h, mask, cache[e])
 
-        return self.ff_out(self.norm(h)), cache
+        h = self.norm(h)
+
+        if self.weight_tying:
+            return h @ self.wte.weight.T, cache
+
+        return self.ff_out(h), cache
 
 
 class OlmoModel(nn.Module):