* add test

* fix rope

* fix test
This commit is contained in:
Awni Hannun
2024-08-20 17:37:52 -07:00
committed by GitHub
parent bb1b76d9dc
commit d40e76809f
4 changed files with 26 additions and 17 deletions

View File

@@ -8,7 +8,7 @@ import mlx_tests
def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
N = x.shape[1] + offset
N = x.shape[-2] + offset
dtype = x.dtype
half_D = dims // 2
positions = mx.arange(offset, N, dtype=dtype) * scale
@@ -143,6 +143,20 @@ class TestFast(mlx_tests.MLXTestCase):
)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
# Test transpose into rope
dims, _, base, scale, offset, traditional = defaults
x = mx.random.uniform(shape=(1, 1, 4, dims)).swapaxes(1, 2)
rx = rope_orig(x, dims, traditional, base, scale, offset)
rx_fast = mx.fast.rope(
1.0 * x, # multiply here to allow donation
dims,
traditional=traditional,
base=base,
scale=scale,
offset=offset,
)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[mx.float32])
def test_rope_with_freqs(self):
# Check throws
T = 4