diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py
index 8a511a1e..d4df3839 100644
--- a/llms/mlx_lm/models/olmo.py
+++ b/llms/mlx_lm/models/olmo.py
@@ -24,6 +24,15 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float = 10000
     rope_traditional: bool = False
     model_type: str = None
+    mlp_ratio: int = 4
+    weight_tying: bool = False
+
+    def __post_init__(self):
+        self.mlp_hidden_size = (
+            self.mlp_hidden_size
+            if self.mlp_hidden_size is not None
+            else self.mlp_ratio * self.d_model
+        )
 
 
 class LayerNorm(nn.LayerNorm):
@@ -99,6 +108,7 @@ class TransformerBlock(nn.Module):
         h = x + r
 
         x1, x2 = mx.split(self.ff_proj(self.ff_norm(h)), 2, axis=-1)
+
         out = h + self.ff_out(nn.silu(x2) * x1)
         return out, cache
 
@@ -107,9 +117,12 @@ class Transformer(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.n_layers = args.n_layers
+        self.weight_tying = args.weight_tying
+
         self.wte = nn.Embedding(args.embedding_size, args.d_model)
         self.blocks = [TransformerBlock(args=args) for _ in range(args.n_layers)]
-        self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
+        if not self.weight_tying:
+            self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
         self.norm = LayerNorm(args.d_model, affine=False)
 
     def __call__(
@@ -130,7 +143,12 @@ class Transformer(nn.Module):
         for e, block in enumerate(self.blocks):
             h, cache[e] = block(h, mask, cache[e])
 
-        return self.ff_out(self.norm(h)), cache
+        h = self.norm(h)
+
+        if self.weight_tying:
+            return h @ self.wte.weight.T, cache
+
+        return self.ff_out(h), cache
 
 
 class OlmoModel(nn.Module):
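
A minimal sketch (not part of the patch; dimensions are illustrative, not from any OLMo config) of what the new weight-tying branch computes: with weight_tying enabled, the output projection reuses the token-embedding matrix via h @ wte.weight.T instead of a separate ff_out Linear layer.

import mlx.core as mx
import mlx.nn as nn

# Toy dimensions for illustration only.
d_model, vocab_size = 8, 32
wte = nn.Embedding(vocab_size, d_model)   # token embedding, as in the model
h = mx.zeros((1, 4, d_model))             # stand-in for hidden states after the final norm

# Untied head: a separate projection with its own (vocab_size, d_model) weight.
ff_out = nn.Linear(d_model, vocab_size, bias=False)
logits_untied = ff_out(h)                 # shape (1, 4, vocab_size)

# Tied head: reuse the embedding weight, as the weight_tying branch above does,
# so no additional vocab_size x d_model output weight needs to be stored.
logits_tied = h @ wte.weight.T            # shape (1, 4, vocab_size)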