mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
more accurate rope fallback (#2792)
This commit is contained in:
20
mlx/fast.cpp
20
mlx/fast.cpp
@@ -416,23 +416,25 @@ array rope(
|
||||
if (offset.size() > 1) {
|
||||
offset = expand_dims(offset, {-1, -2}, s);
|
||||
}
|
||||
auto positions =
|
||||
multiply(add(arange(x.shape(2), t, s), offset, s), array(scale, t), s);
|
||||
auto positions = multiply(
|
||||
add(arange(x.shape(2), float32, s), offset, s),
|
||||
array(scale, float32),
|
||||
s);
|
||||
|
||||
auto default_inv_freqs = [&s, &t, base, half_dims]() {
|
||||
auto default_inv_freqs = [&s, base, half_dims]() {
|
||||
return exp(
|
||||
multiply(
|
||||
arange(0, -half_dims, -1, t, s),
|
||||
array(std::log(base) / half_dims, t),
|
||||
arange(0, -half_dims, -1, float32, s),
|
||||
array(std::log(base) / half_dims, float32),
|
||||
s),
|
||||
s);
|
||||
};
|
||||
|
||||
auto inv_freqs = inputs.size() == 3 ? astype(reciprocal(inputs[2], s), t, s)
|
||||
: default_inv_freqs();
|
||||
auto inv_freqs =
|
||||
inputs.size() == 3 ? reciprocal(inputs[2], s) : default_inv_freqs();
|
||||
auto theta = multiply(expand_dims(positions, -1, s), inv_freqs, s);
|
||||
auto coss = cos(theta, s);
|
||||
auto sins = sin(theta, s);
|
||||
auto coss = astype(cos(theta, s), t, s);
|
||||
auto sins = astype(sin(theta, s), t, s);
|
||||
|
||||
auto apply_rope = [forward, s](
|
||||
const array& x1,
|
||||
|
||||
Reference in New Issue
Block a user