mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-23 01:51:18 +08:00
Fix cache key in RoPE (#561)
This commit is contained in:
parent
077c1ee64a
commit
874b739f3c
@ -104,9 +104,9 @@ class RoPE(Module):
|
||||
dtype=mx.float32,
|
||||
):
|
||||
if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key:
|
||||
D = D // 2
|
||||
half_D = D // 2
|
||||
positions = mx.arange(offset, N, dtype=dtype) * scale
|
||||
freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
|
||||
freqs = mx.exp(-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D))
|
||||
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
|
||||
cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype)
|
||||
cls._cos_sin_theta_value = (mx.cos(theta), mx.sin(theta))
|
||||
|
Loading…
Reference in New Issue
Block a user