mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-20 10:27:41 +08:00
Added scale for RoPE
This commit is contained in:
parent
cd3616a463
commit
273d558117
@ -25,11 +25,12 @@ class RoPE(Module):
|
||||
each dimension in the positional encodings. Default: ``10000``
|
||||
"""
|
||||
|
||||
def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
|
||||
def __init__(self, dims: int, traditional: bool = False, base: float = 10000, scale: float = 1.0):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
self.traditional = traditional
|
||||
self.base = base
|
||||
self.scale = scale
|
||||
|
||||
def _extra_repr(self):
|
||||
return f"{self.dims}, traditional={self.traditional}"
|
||||
@ -67,7 +68,7 @@ class RoPE(Module):
|
||||
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
|
||||
N = x.shape[1] + offset
|
||||
costheta, sintheta = RoPE.create_cos_sin_theta(
|
||||
N, self.dims, offset=offset, base=self.base, dtype=x.dtype
|
||||
N, self.dims, offset=offset, base=self.base, scale=self.scale, dtype=x.dtype
|
||||
)
|
||||
|
||||
rope = (
|
||||
@ -79,10 +80,10 @@ class RoPE(Module):
|
||||
|
||||
@staticmethod
|
||||
def create_cos_sin_theta(
|
||||
N: int, D: int, offset: int = 0, base: float = 10000, dtype=mx.float32
|
||||
N: int, D: int, offset: int = 0, base: float = 10000, scale: float = 1.0, dtype=mx.float32
|
||||
):
|
||||
D = D // 2
|
||||
positions = mx.arange(offset, N, dtype=dtype)
|
||||
positions = mx.arange(offset, N, dtype=dtype) * scale
|
||||
freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
|
||||
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
|
||||
return mx.cos(theta), mx.sin(theta)
|
||||
|
Loading…
Reference in New Issue
Block a user