Added implementation for Scaled RoPE. (#261)

* Added scale for RoPE

* Ran pre-commit

* Added RoPE scaling test

* Added docstring for scale parameter

* Modified docstrings
This commit is contained in:
Hazem Essam 2023-12-31 16:06:01 +02:00 committed by GitHub
parent a020a2d49d
commit e3b8da2a49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 5 deletions

View File

@ -24,13 +24,21 @@ class RoPE(Module):
implementation which is slightly less efficient. Default: ``False``.
base (float, optional): The base used to compute angular frequency for
each dimension in the positional encodings. Default: ``10000``.
scale (float, optional): The scale used to scale the positions. Default: ``1.0``.
"""
def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
def __init__(
self,
dims: int,
traditional: bool = False,
base: float = 10000,
scale: float = 1.0,
):
super().__init__()
self.dims = dims
self.traditional = traditional
self.base = base
self.scale = scale
def _extra_repr(self):
return f"{self.dims}, traditional={self.traditional}"
@ -68,7 +76,7 @@ class RoPE(Module):
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, base=self.base, dtype=x.dtype
N, self.dims, offset=offset, base=self.base, scale=self.scale, dtype=x.dtype
)
rope = (
@ -80,10 +88,15 @@ class RoPE(Module):
@staticmethod
def create_cos_sin_theta(
N: int, D: int, offset: int = 0, base: float = 10000, dtype=mx.float32
N: int,
D: int,
offset: int = 0,
base: float = 10000,
scale: float = 1.0,
dtype=mx.float32,
):
D = D // 2
positions = mx.arange(offset, N, dtype=dtype)
positions = mx.arange(offset, N, dtype=dtype) * scale
freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
return mx.cos(theta), mx.sin(theta)

View File

@ -749,7 +749,7 @@ class TestNN(mlx_tests.MLXTestCase):
self.assertEqual(y.dtype, mx.float32)
def test_rope(self):
for kwargs in [{}, {"traditional": False}, {"base": 10000}]:
for kwargs in [{}, {"traditional": False}, {"base": 10000}, {"scale": 0.25}]:
rope = nn.RoPE(4, **kwargs)
shape = (1, 3, 4)
x = mx.random.uniform(shape=shape)