From 1d053e0d1d3653a3310ca7779a03a820c9f10121 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 21 Dec 2023 14:59:25 -0800 Subject: [PATCH] Fix the alibi test that was left unchanged (#252) --- python/tests/test_nn.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 2c27d3587..ebc6f2b7a 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -571,16 +571,15 @@ class TestNN(mlx_tests.MLXTestCase): self.assertTrue(y.dtype, mx.float16) def test_alibi(self): - for kwargs in [{"num_heads": 8}]: - alibi = nn.ALibi(**kwargs) - shape = [1, 8, 20, 20] - x = mx.random.uniform(shape=shape) - y = alibi(x) - self.assertTrue(y.shape, shape) - self.assertTrue(y.dtype, mx.float32) + alibi = nn.ALiBi() + shape = [1, 8, 20, 20] + x = mx.random.uniform(shape=shape) + y = alibi(x) + self.assertTrue(y.shape, shape) + self.assertTrue(y.dtype, mx.float32) - y = alibi(x.astype(mx.float16)) - self.assertTrue(y.dtype, mx.float16) + y = alibi(x.astype(mx.float16)) + self.assertTrue(y.dtype, mx.float16) if __name__ == "__main__":