From 273d558117e19f29febb4177b617ce1db45d047f Mon Sep 17 00:00:00 2001
From: Hazem
Date: Sat, 23 Dec 2023 00:05:55 +0200
Subject: [PATCH] Added scale for RoPE

---
 python/mlx/nn/layers/positional_encoding.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/python/mlx/nn/layers/positional_encoding.py b/python/mlx/nn/layers/positional_encoding.py
index db436f407..38c91f29a 100644
--- a/python/mlx/nn/layers/positional_encoding.py
+++ b/python/mlx/nn/layers/positional_encoding.py
@@ -25,11 +25,12 @@ class RoPE(Module):
             each dimension in the positional encodings. Default: ``10000``
     """
 
-    def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
+    def __init__(self, dims: int, traditional: bool = False, base: float = 10000, scale: float = 1.0):
         super().__init__()
         self.dims = dims
         self.traditional = traditional
         self.base = base
+        self.scale = scale
 
     def _extra_repr(self):
         return f"{self.dims}, traditional={self.traditional}"
@@ -67,7 +68,7 @@ class RoPE(Module):
         x = mx.reshape(x, (-1, shape[-2], shape[-1]))
         N = x.shape[1] + offset
         costheta, sintheta = RoPE.create_cos_sin_theta(
-            N, self.dims, offset=offset, base=self.base, dtype=x.dtype
+            N, self.dims, offset=offset, base=self.base, scale=self.scale, dtype=x.dtype
         )
 
         rope = (
@@ -79,10 +80,10 @@ class RoPE(Module):
 
     @staticmethod
     def create_cos_sin_theta(
-        N: int, D: int, offset: int = 0, base: float = 10000, dtype=mx.float32
+        N: int, D: int, offset: int = 0, base: float = 10000, scale: float = 1.0, dtype=mx.float32
     ):
         D = D // 2
-        positions = mx.arange(offset, N, dtype=dtype)
+        positions = mx.arange(offset, N, dtype=dtype) * scale
         freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
         theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
         return mx.cos(theta), mx.sin(theta)