mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Fix rope fallback to not upcast (#1797)
* fix rope fallback to not upcast * Update mlx/fast.cpp Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com> --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
parent
0cea88bcc5
commit
121d9a0702
@ -392,8 +392,8 @@ array rope(
|
|||||||
s);
|
s);
|
||||||
};
|
};
|
||||||
|
|
||||||
auto inv_freqs =
|
auto inv_freqs = inputs.size() == 3 ? astype(reciprocal(inputs[2], s), t, s)
|
||||||
inputs.size() == 3 ? reciprocal(inputs[2], s) : default_inv_freqs();
|
: default_inv_freqs();
|
||||||
auto theta =
|
auto theta =
|
||||||
multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
|
multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
|
||||||
auto coss = cos(theta, s);
|
auto coss = cos(theta, s);
|
||||||
|
@ -18,7 +18,7 @@ def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
|
|||||||
-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
|
-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
inv_freqs = 1 / freqs
|
inv_freqs = (1 / freqs).astype(x.dtype)
|
||||||
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1))
|
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1))
|
||||||
costheta, sintheta = mx.cos(theta), mx.sin(theta)
|
costheta, sintheta = mx.cos(theta), mx.sin(theta)
|
||||||
if traditional:
|
if traditional:
|
||||||
@ -189,9 +189,12 @@ class TestFast(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
freqs = mx.random.uniform(shape=(dims // 2,))
|
freqs = mx.random.uniform(shape=(dims // 2,))
|
||||||
|
|
||||||
rx = rope_orig(x, dims, False, None, 1.0, 0, freqs)
|
tolerances = {mx.float32: 1e-5, mx.float16: 1e-2}
|
||||||
|
for dtype in [mx.float32, mx.float16]:
|
||||||
|
x_ = x.astype(dtype)
|
||||||
|
rx = rope_orig(x_, dims, False, None, 1.0, 0, freqs)
|
||||||
rx_fast = mx.fast.rope(
|
rx_fast = mx.fast.rope(
|
||||||
x,
|
x_,
|
||||||
dims,
|
dims,
|
||||||
traditional=False,
|
traditional=False,
|
||||||
base=None,
|
base=None,
|
||||||
@ -199,7 +202,8 @@ class TestFast(mlx_tests.MLXTestCase):
|
|||||||
offset=0,
|
offset=0,
|
||||||
freqs=freqs,
|
freqs=freqs,
|
||||||
)
|
)
|
||||||
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
|
self.assertEqual(dtype, rx.dtype)
|
||||||
|
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||||
|
|
||||||
# Test single vector
|
# Test single vector
|
||||||
x = mx.random.uniform(shape=(1, 1, dims))
|
x = mx.random.uniform(shape=(1, 1, dims))
|
||||||
|
Loading…
Reference in New Issue
Block a user