add base kwarg to rope (#186)

This commit is contained in:
Awni Hannun 2023-12-15 16:47:59 -08:00 committed by GitHub
parent 83f266c44c
commit 2e02acdc83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 10 deletions

View File

@ -154,8 +154,8 @@ class array {
};
private:
int idx;
const array& arr;
int idx;
};
ArrayIterator begin() const {

View File

@ -19,14 +19,17 @@ class RoPE(Module):
Args:
dims (int): The feature dimensions to be rotated. If the input feature
is larger than dims then the rest is left unchanged.
traditional (bool): If set to True choose the traditional
implementation which is slightly less efficient.
traditional (bool, optional): If set to True choose the traditional
implementation which is slightly less efficient. Default: ``False``
base (float, optional): The base used to compute angular frequency for
each dimension in the positional encodings. Default: ``10000``
"""
def __init__(self, dims: int, traditional: bool = False):
# Store the rotary-embedding configuration on the module: ``dims`` (feature
# dimensions to rotate), ``traditional`` (implementation variant, default
# False) and ``base`` (default 10000).
def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
super().__init__()
self.dims = dims
self.traditional = traditional
# Kept so it can be forwarded as ``base=self.base`` when the cos/sin tables
# are built.
self.base = base
def _extra_repr(self):
    """Return the configuration summary text: ``dims`` and the ``traditional`` flag."""
    summary_template = "{}, traditional={}"
    return summary_template.format(self.dims, self.traditional)
@ -64,7 +67,7 @@ class RoPE(Module):
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, dtype=x.dtype
N, self.dims, offset=offset, base=self.base, dtype=x.dtype
)
rope = (
@ -82,10 +85,7 @@ class RoPE(Module):
positions = mx.arange(offset, N, dtype=dtype)
freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
costheta = mx.cos(theta)
sintheta = mx.sin(theta)
return costheta, sintheta
return mx.cos(theta), mx.sin(theta)
class SinusoidalPositionalEncoding(Module):

View File

@ -463,6 +463,21 @@ class TestNN(mlx_tests.MLXTestCase):
mx.array([0.8651, -0.3034, 0.0000, 0.3752]),
)
def test_rope(self):
    """Check RoPE output shape and dtype for each supported constructor kwarg set.

    Covers the default constructor, an explicit ``traditional`` flag and an
    explicit ``base``, each with and without a position ``offset`` and with
    a float16 input.
    """
    for kwargs in [{}, {"traditional": False}, {"base": 10000}]:
        rope = nn.RoPE(4, **kwargs)
        shape = (1, 3, 4)
        x = mx.random.uniform(shape=shape)

        y = rope(x)
        # assertEqual, not assertTrue: assertTrue(a, b) only checks that ``a``
        # is truthy and uses ``b`` as the failure message, so it never fails
        # on a wrong shape or dtype.
        # tuple() keeps the comparison valid whether ``shape`` is exposed as a
        # list or as a tuple.
        self.assertEqual(tuple(y.shape), shape)
        self.assertEqual(y.dtype, mx.float32)

        # A non-zero offset must not change the output shape.
        y = rope(x, offset=3)
        self.assertEqual(tuple(y.shape), shape)

        # The output dtype follows the input dtype.
        y = rope(x.astype(mx.float16))
        self.assertEqual(y.dtype, mx.float16)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
unittest.main()