From 874b739f3cd3ff141d8cb8d8f0751d98475e65d9 Mon Sep 17 00:00:00 2001 From: David Koski <46639364+davidkoski@users.noreply.github.com> Date: Fri, 26 Jan 2024 13:10:02 -0800 Subject: [PATCH] Fix cache key in RoPE (#561) --- python/mlx/nn/layers/positional_encoding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mlx/nn/layers/positional_encoding.py b/python/mlx/nn/layers/positional_encoding.py index 1c586693f..1935792a4 100644 --- a/python/mlx/nn/layers/positional_encoding.py +++ b/python/mlx/nn/layers/positional_encoding.py @@ -104,9 +104,9 @@ class RoPE(Module): dtype=mx.float32, ): if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key: - D = D // 2 + half_D = D // 2 positions = mx.arange(offset, N, dtype=dtype) * scale - freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D)) + freqs = mx.exp(-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)) theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1)) cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype) cls._cos_sin_theta_value = (mx.cos(theta), mx.sin(theta))