Added scale for RoPE

This commit is contained in:
Hazem 2023-12-23 00:05:55 +02:00
parent cd3616a463
commit 273d558117

View File

@ -25,11 +25,12 @@ class RoPE(Module):
each dimension in the positional encodings. Default: ``10000`` each dimension in the positional encodings. Default: ``10000``
""" """
def __init__(
    self,
    dims: int,
    traditional: bool = False,
    base: float = 10000,
    scale: float = 1.0,
):
    """Store the rotary positional-encoding configuration.

    Args:
        dims: Number of feature dimensions passed to
            ``create_cos_sin_theta`` when the cos/sin tables are built.
        traditional: Flag kept on the module as ``self.traditional``.
        base: Base used to derive the per-dimension rotation
            frequencies. Default: ``10000``.
        scale: Multiplier applied to the token positions before the
            rotation angles are computed. The default ``1.0`` leaves
            the positions unchanged.
    """
    super().__init__()
    self.dims = dims
    self.traditional = traditional
    self.base = base
    self.scale = scale
def _extra_repr(self):
    """Return the module-specific part of the printed representation.

    Only ``dims`` and ``traditional`` are reported; ``base`` and ``scale``
    are not included in the string.
    """
    return f"{self.dims}, traditional={self.traditional}"
@ -67,7 +68,7 @@ class RoPE(Module):
x = mx.reshape(x, (-1, shape[-2], shape[-1])) x = mx.reshape(x, (-1, shape[-2], shape[-1]))
N = x.shape[1] + offset N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta( costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, base=self.base, dtype=x.dtype N, self.dims, offset=offset, base=self.base, scale=self.scale, dtype=x.dtype
) )
rope = ( rope = (
@ -79,10 +80,10 @@ class RoPE(Module):
@staticmethod
def create_cos_sin_theta(
    N: int,
    D: int,
    offset: int = 0,
    base: float = 10000,
    scale: float = 1.0,
    dtype=mx.float32,
):
    """Build the cosine and sine tables of the rotary angles.

    Args:
        N: Exclusive upper bound of the position range.
        D: Number of feature dimensions; half of it (``D // 2``) is the
            number of rotation frequencies.
        offset: First position in the range. Default: ``0``.
        base: Base of the frequency progression; frequency ``i`` equals
            ``base ** (-i / (D // 2))``. Default: ``10000``.
        scale: Multiplier applied to every position before it is combined
            with the frequencies. Default: ``1.0``.
        dtype: Data type used for the position and frequency arrays.

    Returns:
        A ``(cos, sin)`` tuple; each array has one row per position in
        ``[offset, N)`` and one column per frequency.
    """
    # Each frequency rotates a pair of features, so only half of D is needed.
    D = D // 2
    positions = mx.arange(offset, N, dtype=dtype) * scale
    # exp(-i * log(base) / D) == base ** (-i / D)
    freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
    # Outer product: row = (scaled) position, column = frequency.
    theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
    return mx.cos(theta), mx.sin(theta)