mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 17:39:05 +08:00
more accurate rope fallback (#2792)
This commit is contained in:
20
mlx/fast.cpp
20
mlx/fast.cpp
@@ -416,23 +416,25 @@ array rope(
|
|||||||
if (offset.size() > 1) {
|
if (offset.size() > 1) {
|
||||||
offset = expand_dims(offset, {-1, -2}, s);
|
offset = expand_dims(offset, {-1, -2}, s);
|
||||||
}
|
}
|
||||||
auto positions =
|
auto positions = multiply(
|
||||||
multiply(add(arange(x.shape(2), t, s), offset, s), array(scale, t), s);
|
add(arange(x.shape(2), float32, s), offset, s),
|
||||||
|
array(scale, float32),
|
||||||
|
s);
|
||||||
|
|
||||||
auto default_inv_freqs = [&s, &t, base, half_dims]() {
|
auto default_inv_freqs = [&s, base, half_dims]() {
|
||||||
return exp(
|
return exp(
|
||||||
multiply(
|
multiply(
|
||||||
arange(0, -half_dims, -1, t, s),
|
arange(0, -half_dims, -1, float32, s),
|
||||||
array(std::log(base) / half_dims, t),
|
array(std::log(base) / half_dims, float32),
|
||||||
s),
|
s),
|
||||||
s);
|
s);
|
||||||
};
|
};
|
||||||
|
|
||||||
auto inv_freqs = inputs.size() == 3 ? astype(reciprocal(inputs[2], s), t, s)
|
auto inv_freqs =
|
||||||
: default_inv_freqs();
|
inputs.size() == 3 ? reciprocal(inputs[2], s) : default_inv_freqs();
|
||||||
auto theta = multiply(expand_dims(positions, -1, s), inv_freqs, s);
|
auto theta = multiply(expand_dims(positions, -1, s), inv_freqs, s);
|
||||||
auto coss = cos(theta, s);
|
auto coss = astype(cos(theta, s), t, s);
|
||||||
auto sins = sin(theta, s);
|
auto sins = astype(sin(theta, s), t, s);
|
||||||
|
|
||||||
auto apply_rope = [forward, s](
|
auto apply_rope = [forward, s](
|
||||||
const array& x1,
|
const array& x1,
|
||||||
|
|||||||
@@ -332,6 +332,26 @@ class TestFast(mlx_tests.MLXTestCase):
|
|||||||
rx = rope_orig(x, dims, traditional, base, scale, offset)
|
rx = rope_orig(x, dims, traditional, base, scale, offset)
|
||||||
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
|
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
|
||||||
|
|
||||||
|
def test_rope_with_large_offset(self):
|
||||||
|
x = mx.random.normal(shape=(1, 1, 1024, 32))
|
||||||
|
rx_fp32 = mx.fast.rope(
|
||||||
|
x,
|
||||||
|
32,
|
||||||
|
traditional=False,
|
||||||
|
scale=1.0,
|
||||||
|
base=10000,
|
||||||
|
offset=4000,
|
||||||
|
)
|
||||||
|
rx_bf16 = mx.fast.rope(
|
||||||
|
x.astype(mx.bfloat16),
|
||||||
|
32,
|
||||||
|
traditional=False,
|
||||||
|
scale=1.0,
|
||||||
|
base=10000,
|
||||||
|
offset=4000,
|
||||||
|
)
|
||||||
|
self.assertLess((rx_fp32 - rx_bf16).abs().max(), 1e-1)
|
||||||
|
|
||||||
def test_rms_norm(self):
|
def test_rms_norm(self):
|
||||||
# Per dtype absolute tolerance
|
# Per dtype absolute tolerance
|
||||||
tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}
|
tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}
|
||||||
|
|||||||
Reference in New Issue
Block a user