mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
@@ -8,7 +8,7 @@ import mlx_tests
|
||||
|
||||
|
||||
def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
|
||||
N = x.shape[1] + offset
|
||||
N = x.shape[-2] + offset
|
||||
dtype = x.dtype
|
||||
half_D = dims // 2
|
||||
positions = mx.arange(offset, N, dtype=dtype) * scale
|
||||
@@ -143,6 +143,20 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
|
||||
# Test transpose into rope
|
||||
dims, _, base, scale, offset, traditional = defaults
|
||||
x = mx.random.uniform(shape=(1, 1, 4, dims)).swapaxes(1, 2)
|
||||
rx = rope_orig(x, dims, traditional, base, scale, offset)
|
||||
rx_fast = mx.fast.rope(
|
||||
1.0 * x, # multiply here to allow donation
|
||||
dims,
|
||||
traditional=traditional,
|
||||
base=base,
|
||||
scale=scale,
|
||||
offset=offset,
|
||||
)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[mx.float32])
|
||||
|
||||
def test_rope_with_freqs(self):
|
||||
# Check throws
|
||||
T = 4
|
||||
|
Reference in New Issue
Block a user