diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 2c27d3587..ebc6f2b7a 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -571,16 +571,15 @@ class TestNN(mlx_tests.MLXTestCase): self.assertTrue(y.dtype, mx.float16) def test_alibi(self): - for kwargs in [{"num_heads": 8}]: - alibi = nn.ALibi(**kwargs) - shape = [1, 8, 20, 20] - x = mx.random.uniform(shape=shape) - y = alibi(x) - self.assertTrue(y.shape, shape) - self.assertTrue(y.dtype, mx.float32) + alibi = nn.ALiBi() + shape = [1, 8, 20, 20] + x = mx.random.uniform(shape=shape) + y = alibi(x) + self.assertTrue(y.shape, shape) + self.assertTrue(y.dtype, mx.float32) - y = alibi(x.astype(mx.float16)) - self.assertTrue(y.dtype, mx.float16) + y = alibi(x.astype(mx.float16)) + self.assertTrue(y.dtype, mx.float16) if __name__ == "__main__":