Fix rope fallback to not upcast (#1797)

* fix rope fallback to not upcast

* Update mlx/fast.cpp

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2025-01-26 19:07:21 -08:00
committed by GitHub
parent 0cea88bcc5
commit 121d9a0702
2 changed files with 18 additions and 14 deletions

View File

@@ -392,8 +392,8 @@ array rope(
s);
};
auto inv_freqs =
inputs.size() == 3 ? reciprocal(inputs[2], s) : default_inv_freqs();
auto inv_freqs = inputs.size() == 3 ? astype(reciprocal(inputs[2], s), t, s)
: default_inv_freqs();
auto theta =
multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
auto coss = cos(theta, s);