mirror of https://github.com/ml-explore/mlx.git
add base kwarg to rope (#186)
This commit is contained in:
parent 83f266c44c
commit 2e02acdc83
@@ -154,8 +154,8 @@ class array {
   };
 
  private:
-  int idx;
   const array& arr;
+  int idx;
 };
 
 ArrayIterator begin() const {
@@ -18,15 +18,18 @@ class RoPE(Module):
 
     Args:
         dims (int): The feature dimensions to be rotated. If the input feature
             is larger than dims then the rest is left unchanged.
-        traditional (bool): If set to True choose the traditional
-            implementation which is slightly less efficient.
+        traditional (bool, optional): If set to True choose the traditional
+            implementation which is slightly less efficient. Default: ``False``
+        base (float, optional): The base used to compute angular frequency for
+            each dimension in the positional encodings. Default: ``10000``
     """
 
-    def __init__(self, dims: int, traditional: bool = False):
+    def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
         super().__init__()
         self.dims = dims
         self.traditional = traditional
+        self.base = base
 
     def _extra_repr(self):
         return f"{self.dims}, traditional={self.traditional}"
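For a quick sense of the new keyword, a minimal usage sketch of the updated constructor follows; the layer width, input shape, and base value are illustrative rather than taken from the commit:

import mlx.core as mx
import mlx.nn as nn

# Rotate all 32 feature dimensions; base controls how quickly the
# per-dimension rotation frequencies decay (10000 remains the default).
rope = nn.RoPE(32, traditional=False, base=10000)

x = mx.random.uniform(shape=(1, 8, 32))  # (batch, sequence, features)
y = rope(x)                              # output keeps x's shape and dtype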
@@ -64,7 +67,7 @@ class RoPE(Module):
         x = mx.reshape(x, (-1, shape[-2], shape[-1]))
         N = x.shape[1] + offset
         costheta, sintheta = RoPE.create_cos_sin_theta(
-            N, self.dims, offset=offset, dtype=x.dtype
+            N, self.dims, offset=offset, base=self.base, dtype=x.dtype
         )
 
         rope = (
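The offset argument that base is threaded alongside here is what allows incremental use of the layer: angles for a chunk that begins after offset already-processed positions are computed starting at offset. A hedged sketch of that pattern, with illustrative shapes:

import mlx.core as mx
import mlx.nn as nn

rope = nn.RoPE(4, base=10000)

prompt = mx.random.uniform(shape=(1, 3, 4))
y_prompt = rope(prompt)               # rotations for positions 0, 1, 2

next_step = mx.random.uniform(shape=(1, 1, 4))
y_next = rope(next_step, offset=3)    # continues the sequence at position 3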
@@ -82,10 +85,7 @@ class RoPE(Module):
         positions = mx.arange(offset, N, dtype=dtype)
         freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
         theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
-        costheta = mx.cos(theta)
-        sintheta = mx.sin(theta)
-
-        return costheta, sintheta
+        return mx.cos(theta), mx.sin(theta)
 
 
 class SinusoidalPositionalEncoding(Module):
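This hunk is where base enters the math: the per-dimension angular frequency is exp(-d * log(base) / D) = base ** (-d / D), so a larger base yields slower-varying rotations in the higher dimensions. Below is a standalone sketch of the same angle computation; the helper is hypothetical and takes the frequency count D directly rather than deriving it from dims the way the layer does:

import math
import mlx.core as mx

def cos_sin_theta(N, D, offset=0, base=10000.0, dtype=mx.float32):
    # Positions offset .. N-1, one angular frequency per index d,
    # with freqs[d] = base ** (-d / D).
    positions = mx.arange(offset, N, dtype=dtype)
    freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
    theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
    return mx.cos(theta), mx.sin(theta)

costheta, sintheta = cos_sin_theta(4, 2)  # both have shape (4, 2)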
@@ -463,6 +463,21 @@ class TestNN(mlx_tests.MLXTestCase):
             mx.array([0.8651, -0.3034, 0.0000, 0.3752]),
         )
 
+    def test_rope(self):
+        for kwargs in [{}, {"traditional": False}, {"base": 10000}]:
+            rope = nn.RoPE(4, **kwargs)
+            shape = (1, 3, 4)
+            x = mx.random.uniform(shape=shape)
+            y = rope(x)
+            self.assertTrue(y.shape, shape)
+            self.assertTrue(y.dtype, mx.float32)
+
+            y = rope(x, offset=3)
+            self.assertTrue(y.shape, shape)
+
+            y = rope(x.astype(mx.float16))
+            self.assertTrue(y.dtype, mx.float16)
+
 
 if __name__ == "__main__":
     unittest.main()