Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-18 10:26:56 +08:00)
add base kwarg to rope (#186)
This commit is contained in:
@@ -18,15 +18,18 @@ class RoPE(Module):
|
||||
|
||||
Args:
|
||||
dims (int): The feature dimensions to be rotated. If the input feature
|
||||
is larger than dims then the rest is left unchanged.
|
||||
traditional (bool): If set to True choose the traditional
|
||||
implementation which is slightly less efficient.
|
||||
is larger than dims then the rest is left unchanged.
|
||||
traditional (bool, optional): If set to True choose the traditional
|
||||
implementation which is slightly less efficient. Default: ``False``
|
||||
base (float, optional): The base used to compute angular frequency for
|
||||
each dimension in the positional encodings. Default: ``10000``
|
||||
"""
|
||||
|
||||
def __init__(self, dims: int, traditional: bool = False):
|
||||
def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
self.traditional = traditional
|
||||
self.base = base
|
||||
|
||||
def _extra_repr(self):
|
||||
return f"{self.dims}, traditional={self.traditional}"
|
||||
@@ -64,7 +67,7 @@ class RoPE(Module):
|
||||
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
|
||||
N = x.shape[1] + offset
|
||||
costheta, sintheta = RoPE.create_cos_sin_theta(
|
||||
N, self.dims, offset=offset, dtype=x.dtype
|
||||
N, self.dims, offset=offset, base=self.base, dtype=x.dtype
|
||||
)
|
||||
|
||||
rope = (
|
||||
@@ -82,10 +85,7 @@ class RoPE(Module):
|
||||
positions = mx.arange(offset, N, dtype=dtype)
|
||||
freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
|
||||
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
|
||||
costheta = mx.cos(theta)
|
||||
sintheta = mx.sin(theta)
|
||||
|
||||
return costheta, sintheta
|
||||
return mx.cos(theta), mx.sin(theta)
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(Module):
|
||||
|
Reference in New Issue
Block a user