Added RoPE scaling test

This commit is contained in:
Hazem 2023-12-23 00:20:28 +02:00
parent a415bac4f7
commit 0e5ae1d1bf

View File

@ -556,7 +556,7 @@ class TestNN(mlx_tests.MLXTestCase):
)
def test_rope(self):
for kwargs in [{}, {"traditional": False}, {"base": 10000}]:
for kwargs in [{}, {"traditional": False}, {"base": 10000}, {"scale": 0.25}]:
rope = nn.RoPE(4, **kwargs)
shape = (1, 3, 4)
x = mx.random.uniform(shape=shape)