mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
Fix rope fallback to not upcast (#1797)
* fix rope fallback to not upcast * Update mlx/fast.cpp Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com> --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -392,8 +392,8 @@ array rope(
|
||||
s);
|
||||
};
|
||||
|
||||
auto inv_freqs =
|
||||
inputs.size() == 3 ? reciprocal(inputs[2], s) : default_inv_freqs();
|
||||
auto inv_freqs = inputs.size() == 3 ? astype(reciprocal(inputs[2], s), t, s)
|
||||
: default_inv_freqs();
|
||||
auto theta =
|
||||
multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
|
||||
auto coss = cos(theta, s);
|
||||
|
Reference in New Issue
Block a user