From fde4b4dc4249cacd531098413bbfe4fcba5a02eb Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 27 Nov 2024 12:15:45 -0800 Subject: [PATCH] add olmo2 --- llms/mlx_lm/models/olmo2.py | 4 +++- llms/tests/test_models.py | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/models/olmo2.py b/llms/mlx_lm/models/olmo2.py index f8efda14..a28fdcc1 100644 --- a/llms/mlx_lm/models/olmo2.py +++ b/llms/mlx_lm/models/olmo2.py @@ -231,7 +231,9 @@ class TransformerBlock(nn.Module): self.post_attention_layernorm = nn.RMSNorm( args.hidden_size, eps=args.rms_norm_eps ) - self.post_feedforward_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_feedforward_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) self.args = args def __call__( diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 2d5ed502..edb594d7 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -813,6 +813,5 @@ class TestModels(unittest.TestCase): ) - if __name__ == "__main__": unittest.main()