diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 0d1c8b2ff..c56a2fc9e 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -556,7 +556,7 @@ class TestNN(mlx_tests.MLXTestCase): ) def test_rope(self): - for kwargs in [{}, {"traditional": False}, {"base": 10000}]: + for kwargs in [{}, {"traditional": False}, {"base": 10000}, {"scale": 0.25}]: rope = nn.RoPE(4, **kwargs) shape = (1, 3, 4) x = mx.random.uniform(shape=shape)