mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-03 09:58:17 +08:00
Allow offset to be an mx.array for mx.fast.rope (#1724)
* allow offset for rope * comment
This commit is contained in:
@@ -8,6 +8,7 @@ import mlx_tests
|
||||
|
||||
|
||||
def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
|
||||
offset = offset.item() if isinstance(offset, mx.array) else offset
|
||||
N = x.shape[-2] + offset
|
||||
dtype = x.dtype
|
||||
half_D = dims // 2
|
||||
@@ -76,7 +77,7 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
dtypes = [mx.float32, mx.float16, mx.bfloat16]
|
||||
bases = [10000.0, 1000000.0]
|
||||
scales = [1.0, 2.0]
|
||||
offsets = [0, 3]
|
||||
offsets = [0, 3, mx.array(3)]
|
||||
traditional = [True, False]
|
||||
|
||||
for traditional in [True, False]:
|
||||
|
||||
Reference in New Issue
Block a user