mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-05 07:34:42 +08:00
Custom primitive + RoPE fat op (#676)
* extensions start * rope custom op * fix build * docs + rope benchmark * fix test * Add a Metal kernel for RoPE * Fix position of traditional * transform tests * Move rope computation to float and fix tests * Fix the test and a typo * change to fast * fix no metal build --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
158
python/tests/test_fast.py
Normal file
158
python/tests/test_fast.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
|
||||
|
||||
def rope_orig(x, dims, traditional, base, scale, offset):
    """Pure-mlx reference RoPE used to validate the fused ``mx.fast.rope`` op.

    Rotates the first ``dims`` features of ``x`` (shape ``(..., T, D)``) by
    position-dependent angles; features past ``dims`` pass through unchanged
    in the non-traditional layout.
    """
    dtype = x.dtype
    half = dims // 2

    # Rotation angles: outer product of scaled positions and inverse
    # frequencies (computed in the input dtype on purpose — the tests
    # compare precision against a float32 run).
    pos = mx.arange(offset, x.shape[1] + offset, dtype=dtype) * scale
    inv_freq = mx.exp(-mx.arange(0.0, half, dtype=dtype) * (math.log(base) / half))
    angles = mx.reshape(pos, (-1, 1)) * mx.reshape(inv_freq, (1, -1))
    cos_t = mx.cos(angles)
    sin_t = mx.sin(angles)

    if traditional:
        # Traditional layout: rotate interleaved (even, odd) feature pairs,
        # then re-interleave them back into the original shape.
        even = x[..., ::2]
        odd = x[..., 1::2]
        r_even = even * cos_t - odd * sin_t
        r_odd = even * sin_t + odd * cos_t
        stacked = mx.concatenate([r_even[..., None], r_odd[..., None]], axis=-1)
        return mx.reshape(stacked, x.shape)

    # Non-traditional layout: rotate the first and second halves of the
    # rope dims against each other; append the untouched tail if any.
    lo = x[..., :half]
    hi = x[..., half:dims]
    r_lo = lo * cos_t - hi * sin_t
    r_hi = lo * sin_t + hi * cos_t
    pieces = [r_lo, r_hi]
    if dims < x.shape[-1]:
        pieces.append(x[..., dims:])
    return mx.concatenate(pieces, axis=-1)
|
||||
|
||||
|
||||
class TestFast(mlx_tests.MLXTestCase):
    """Tests for the fused "fast" ops — currently ``mx.fast.rope``."""

    def test_rope(self):
        """``mx.fast.rope`` matches the ``rope_orig`` reference numerically.

        Sweeps one parameter at a time (base, dtype, offset, scale) with the
        others held at their defaults, for both traditional and
        non-traditional rotation layouts.
        """
        T = 4

        # Defaults: dims, dtype, base, scale, offset, traditional
        defaults = (8, mx.float32, 10000.0, 1.0, 0, False)

        # Per dtype absolute tolerance
        tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}

        # Test cases. (The original also bound `traditional = [True, False]`
        # here, but that name was immediately shadowed by the loop variable
        # below, so the dead assignment is removed.)
        dtypes = [mx.float32, mx.float16, mx.bfloat16]
        bases = [10000.0, 1000000.0]
        scales = [1.0, 2.0]
        offsets = [0, 3]

        for traditional in [True, False]:
            # Vary the base.
            dims, dtype, _, scale, offset, _ = defaults
            for base in bases:
                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
                rx = rope_orig(x, dims, traditional, base, scale, offset)
                rx_fast = mx.fast.rope(
                    x,
                    dims,
                    traditional=traditional,
                    base=base,
                    scale=scale,
                    offset=offset,
                )
                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

            # Vary the dtype. For low-precision dtypes the fast op should be
            # at least as close to a float32 reference as the low-precision
            # reference is (presumably it computes internally in float32 —
            # see the float32 comparison below).
            dims, _, base, scale, offset, _ = defaults
            for dtype in dtypes:
                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
                ry = rope_orig(
                    x.astype(mx.float32), dims, traditional, base, scale, offset
                )
                rx = rope_orig(x, dims, traditional, base, scale, offset)
                rx_fast = mx.fast.rope(
                    x,
                    dims,
                    traditional=traditional,
                    base=base,
                    scale=scale,
                    offset=offset,
                )
                if dtype != mx.float32:
                    self.assertLessEqual(
                        mx.abs(ry - rx_fast).max(), mx.abs(ry - rx).max()
                    )
                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

            # Vary the offset.
            dims, dtype, base, scale, _, _ = defaults
            for offset in offsets:
                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
                rx = rope_orig(x, dims, traditional, base, scale, offset)
                rx_fast = mx.fast.rope(
                    x,
                    dims,
                    traditional=traditional,
                    base=base,
                    scale=scale,
                    offset=offset,
                )
                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

            # Vary the scale.
            dims, dtype, base, _, offset, _ = defaults
            for scale in scales:
                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
                rx = rope_orig(x, dims, traditional, base, scale, offset)
                rx_fast = mx.fast.rope(
                    x,
                    dims,
                    traditional=traditional,
                    base=base,
                    scale=scale,
                    offset=offset,
                )
                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

    def test_fast_transforms(self):
        """vjp/jvp/vmap of ``mx.fast.rope`` agree with the reference op."""
        x = mx.random.uniform(shape=(2, 2, 8))

        defaults = (8, False, 10000.0, 1.0, 0)
        dims, traditional, base, scale, offset = defaults

        # VJP
        _, vjp_out = mx.vjp(lambda x: rope_orig(x, *defaults), (x,), (mx.ones_like(x),))
        _, vjp_fast_out = mx.vjp(
            lambda x: mx.fast.rope(
                x, dims, traditional=traditional, base=base, scale=scale, offset=offset
            ),
            (x,),
            (mx.ones_like(x),),
        )
        self.assertTrue(mx.allclose(vjp_out[0], vjp_fast_out[0]))

        # JVP
        _, jvp_out = mx.jvp(lambda x: rope_orig(x, *defaults), (x,), (mx.ones_like(x),))
        _, jvp_fast_out = mx.jvp(
            lambda x: mx.fast.rope(
                x, dims, traditional=traditional, base=base, scale=scale, offset=offset
            ),
            (x,),
            (mx.ones_like(x),),
        )
        self.assertTrue(mx.allclose(jvp_out[0], jvp_fast_out[0]))

        # VMAP (extra leading batch axis)
        x = mx.random.uniform(shape=(2, 2, 2, 8))
        vmap_out = mx.vmap(lambda x: rope_orig(x, *defaults))(x)
        vmap_fast_out = mx.vmap(
            lambda x: mx.fast.rope(
                x, dims, traditional=traditional, base=base, scale=scale, offset=offset
            )
        )(x)
        self.assertTrue(mx.allclose(vmap_out, vmap_fast_out))
|
||||
|
||||
|
||||
# Allow running this test file directly: `python test_fast.py`.
if __name__ == "__main__":
    unittest.main()
|
Reference in New Issue
Block a user