add base kwarg to rope (#186)

This commit is contained in:
Awni Hannun
2023-12-15 16:47:59 -08:00
committed by GitHub
parent 83f266c44c
commit 2e02acdc83
3 changed files with 25 additions and 10 deletions

View File

@@ -463,6 +463,21 @@ class TestNN(mlx_tests.MLXTestCase):
mx.array([0.8651, -0.3034, 0.0000, 0.3752]),
)
def test_rope(self):
for kwargs in [{}, {"traditional": False}, {"base": 10000}]:
rope = nn.RoPE(4, **kwargs)
shape = (1, 3, 4)
x = mx.random.uniform(shape=shape)
y = rope(x)
self.assertTrue(y.shape, shape)
self.assertTrue(y.dtype, mx.float32)
y = rope(x, offset=3)
self.assertTrue(y.shape, shape)
y = rope(x.astype(mx.float16))
self.assertTrue(y.dtype, mx.float16)
if __name__ == "__main__":
unittest.main()