Custom primitive + RoPE fat op (#676)

* extensions start

* rope custom op

* fix build

* docs + rope benchmark

* fix test

* Add a Metal kernel for RoPE

* Fix position of traditional

* transform tests

* Move rope computation to float and fix tests

* Fix the test and a typo

* change to fast

* fix no metal build

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-02-14 14:04:25 -08:00
committed by GitHub
parent 1a48713d32
commit ccf1645995
18 changed files with 624 additions and 70 deletions

View File

@@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import math
from typing import Optional
@@ -20,20 +20,13 @@ class RoPE(Module):
Args:
dims (int): The feature dimensions to be rotated. If the input feature
is larger than dims then the rest is left unchanged.
traditional (bool, optional): If set to True choose the traditional
traditional (bool, optional): If set to ``True`` choose the traditional
implementation which is slightly less efficient. Default: ``False``.
base (float, optional): The base used to compute angular frequency for
each dimension in the positional encodings. Default: ``10000``.
scale (float, optional): The scale used to scale the positions. Default: ``1.0``.
Attributes:
_cos_sin_theta_key (tuple): Cached key for the precomputed cosine and sine values.
_cos_sin_theta_value (tuple): Cached cosine and sine values.
"""
_cos_sin_theta_key = None
_cos_sin_theta_value = None
def __init__(
self,
dims: int,
@@ -50,69 +43,18 @@ class RoPE(Module):
def _extra_repr(self):
return f"{self.dims}, traditional={self.traditional}"
def _compute_rope(self, costheta, sintheta, x):
x1 = x[..., : self.dims // 2]
x2 = x[..., self.dims // 2 : self.dims]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
rx = mx.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
else:
rx = mx.concatenate([rx1, rx2], axis=-1)
return rx
def _compute_traditional_rope(self, costheta, sintheta, x):
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
raise NotImplementedError(
"RoPE doesn't implement partial traditional application"
)
rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
return rx
def __call__(self, x, offset: int = 0):
shape = x.shape
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, base=self.base, scale=self.scale, dtype=x.dtype
x = mx.fast.rope(
x,
self.dims,
traditional=self.traditional,
base=self.base,
scale=self.scale,
offset=offset,
)
rope = (
self._compute_traditional_rope if self.traditional else self._compute_rope
)
rx = rope(costheta, sintheta, x)
return mx.reshape(rx, shape)
@classmethod
def create_cos_sin_theta(
cls,
N: int,
D: int,
offset: int = 0,
base: float = 10000,
scale: float = 1.0,
dtype=mx.float32,
):
if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key:
half_D = D // 2
positions = mx.arange(offset, N, dtype=dtype) * scale
freqs = mx.exp(
-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
)
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype)
cls._cos_sin_theta_value = (mx.cos(theta), mx.sin(theta))
return cls._cos_sin_theta_value
return mx.reshape(x, shape)
class SinusoidalPositionalEncoding(Module):