mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
more accurate rope fallback (#2792)
This commit is contained in:
@@ -332,6 +332,26 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
rx = rope_orig(x, dims, traditional, base, scale, offset)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
|
||||
|
||||
def test_rope_with_large_offset(self):
|
||||
x = mx.random.normal(shape=(1, 1, 1024, 32))
|
||||
rx_fp32 = mx.fast.rope(
|
||||
x,
|
||||
32,
|
||||
traditional=False,
|
||||
scale=1.0,
|
||||
base=10000,
|
||||
offset=4000,
|
||||
)
|
||||
rx_bf16 = mx.fast.rope(
|
||||
x.astype(mx.bfloat16),
|
||||
32,
|
||||
traditional=False,
|
||||
scale=1.0,
|
||||
base=10000,
|
||||
offset=4000,
|
||||
)
|
||||
self.assertLess((rx_fp32 - rx_bf16).abs().max(), 1e-1)
|
||||
|
||||
def test_rms_norm(self):
|
||||
# Per dtype absolute tolerance
|
||||
tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}
|
||||
|
||||
Reference in New Issue
Block a user