add base kwarg to rope (#186)

This commit is contained in:
Awni Hannun 2023-12-15 16:47:59 -08:00 committed by GitHub
parent 83f266c44c
commit 2e02acdc83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 10 deletions

View File

@ -154,8 +154,8 @@ class array {
}; };
private: private:
int idx;
const array& arr; const array& arr;
int idx;
}; };
ArrayIterator begin() const { ArrayIterator begin() const {

View File

@ -18,15 +18,18 @@ class RoPE(Module):
Args: Args:
dims (int): The feature dimensions to be rotated. If the input feature dims (int): The feature dimensions to be rotated. If the input feature
is larger than dims then the rest is left unchanged. is larger than dims then the rest is left unchanged.
traditional (bool): If set to True choose the traditional traditional (bool, optional): If set to True choose the traditional
implementation which is slightly less efficient. implementation which is slightly less efficient. Default: ``False``
base (float, optional): The base used to compute angular frequency for
each dimension in the positional encodings. Default: ``10000``
""" """
def __init__(self, dims: int, traditional: bool = False): def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
super().__init__() super().__init__()
self.dims = dims self.dims = dims
self.traditional = traditional self.traditional = traditional
self.base = base
def _extra_repr(self): def _extra_repr(self):
return f"{self.dims}, traditional={self.traditional}" return f"{self.dims}, traditional={self.traditional}"
@ -64,7 +67,7 @@ class RoPE(Module):
x = mx.reshape(x, (-1, shape[-2], shape[-1])) x = mx.reshape(x, (-1, shape[-2], shape[-1]))
N = x.shape[1] + offset N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta( costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, dtype=x.dtype N, self.dims, offset=offset, base=self.base, dtype=x.dtype
) )
rope = ( rope = (
@ -82,10 +85,7 @@ class RoPE(Module):
positions = mx.arange(offset, N, dtype=dtype) positions = mx.arange(offset, N, dtype=dtype)
freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D)) freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1)) theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
costheta = mx.cos(theta) return mx.cos(theta), mx.sin(theta)
sintheta = mx.sin(theta)
return costheta, sintheta
class SinusoidalPositionalEncoding(Module): class SinusoidalPositionalEncoding(Module):

View File

@ -463,6 +463,21 @@ class TestNN(mlx_tests.MLXTestCase):
mx.array([0.8651, -0.3034, 0.0000, 0.3752]), mx.array([0.8651, -0.3034, 0.0000, 0.3752]),
) )
def test_rope(self):
for kwargs in [{}, {"traditional": False}, {"base": 10000}]:
rope = nn.RoPE(4, **kwargs)
shape = (1, 3, 4)
x = mx.random.uniform(shape=shape)
y = rope(x)
self.assertTrue(y.shape, shape)
self.assertTrue(y.dtype, mx.float32)
y = rope(x, offset=3)
self.assertTrue(y.shape, shape)
y = rope(x.astype(mx.float16))
self.assertTrue(y.dtype, mx.float16)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()