Allow offset to be an mx.array for mx.fast.rope (#1724)

* allow offset for rope

* comment
This commit is contained in:
Awni Hannun
2024-12-19 15:51:44 -08:00
committed by GitHub
parent c3628eea49
commit 0308e9af71
8 changed files with 97 additions and 52 deletions

View File

@@ -8,6 +8,7 @@ import mlx_tests
def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
offset = offset.item() if isinstance(offset, mx.array) else offset
N = x.shape[-2] + offset
dtype = x.dtype
half_D = dims // 2
@@ -76,7 +77,7 @@ class TestFast(mlx_tests.MLXTestCase):
dtypes = [mx.float32, mx.float16, mx.bfloat16]
bases = [10000.0, 1000000.0]
scales = [1.0, 2.0]
offsets = [0, 3]
offsets = [0, 3, mx.array(3)]
traditional = [True, False]
for traditional in [True, False]: