RoPE with frequencies as optional input (#1337)

* start rope with freq input

* rope with frequencies

* nits

* fix bug

* fix bug + test

* cleanup

* optional base
This commit is contained in:
Awni Hannun
2024-08-19 18:30:50 -07:00
committed by GitHub
parent 9d26441224
commit bb1b76d9dc
6 changed files with 319 additions and 69 deletions

View File

@@ -79,26 +79,29 @@ void init_fast(nb::module_& parent_module) {
"dims"_a,
nb::kw_only(),
"traditional"_a,
"base"_a,
"base"_a.none(),
"scale"_a,
"offset"_a,
"freqs"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def rope(a: array, dims: int, *, traditional: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
"def rope(a: array, dims: int, *, traditional: bool, base: Optional[float], scale: float, offset: int, freqs: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Apply rotary positional encoding to the input.
Args:
a (array): Input array.
dims (int): The feature dimensions to be rotated. If the input feature
is larger than dims then the rest is left unchanged.
is larger than dims then the rest is left unchanged.
traditional (bool): If set to ``True`` choose the traditional
implementation which rotates consecutive dimensions.
base (float): The base used to compute angular frequency for
each dimension in the positional encodings.
implementation which rotates consecutive dimensions.
base (float, optional): The base used to compute angular frequency for
each dimension in the positional encodings. Exactly one of ``base`` and
``freqs`` must be ``None``.
scale (float): The scale used to scale the positions.
offset (int): The position offset to start at.
freqs (array, optional): Optional frequencies to use with RoPE.
If set, the ``base`` parameter must be ``None``. ``Default: None``.
Returns:
array: The output array.
)pbdoc");
@@ -115,7 +118,7 @@ void init_fast(nb::module_& parent_module) {
"memory_efficient_threshold"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.

View File

@@ -7,13 +7,18 @@ import mlx.core as mx
import mlx_tests
def rope_orig(x, dims, traditional, base, scale, offset):
def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
N = x.shape[1] + offset
dtype = x.dtype
half_D = dims // 2
positions = mx.arange(offset, N, dtype=dtype) * scale
freqs = mx.exp(-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D))
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
if freqs is None:
inv_freqs = mx.exp(
-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
)
else:
inv_freqs = 1 / freqs
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1))
costheta, sintheta = mx.cos(theta), mx.sin(theta)
if traditional:
x1 = x[..., :dims:2]
@@ -138,6 +143,84 @@ class TestFast(mlx_tests.MLXTestCase):
)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
def test_rope_with_freqs(self):
    """Exercise fast RoPE with explicit ``freqs`` (and ``base=None``):
    shape validation, parity with the reference ``rope_orig``, and
    gradient agreement between the two implementations."""
    T = 4
    dims = 8

    def fast_rope(x, freqs):
        # Every call in this test uses base=None with explicit frequencies;
        # centralize the shared keyword arguments in one place.
        return mx.fast.rope(
            x,
            dims,
            traditional=False,
            base=None,
            scale=1.0,
            offset=0,
            freqs=freqs,
        )

    # Check throws: a freqs array whose size is not dims // 2 (here dims - 1,
    # and a 2-D (1, dims) array) must be rejected with a ValueError.
    x = mx.random.uniform(shape=(2, T, dims))
    with self.assertRaises(ValueError):
        fast_rope(x, mx.random.uniform(shape=(dims - 1,)))
    with self.assertRaises(ValueError):
        fast_rope(x, mx.random.uniform(shape=(1, dims)))

    # A 1-D freqs of length dims // 2 is valid; output must match the
    # reference implementation on a batch of sequences.
    freqs = mx.random.uniform(shape=(dims // 2,))
    rx = rope_orig(x, dims, False, None, 1.0, 0, freqs)
    self.assertLess(mx.abs(rx - fast_rope(x, freqs)).max(), 1e-5)

    # Test single vector (sequence length 1).
    x = mx.random.uniform(shape=(1, 1, dims))
    rx = rope_orig(x, dims, False, None, 1.0, 0, freqs)
    self.assertLess(mx.abs(rx - fast_rope(x, freqs)).max(), 1e-5)

    # Test grad with freqs: gradients through the reference and the fast
    # implementation must agree. (PEP 8: plain defs instead of named lambdas.)
    def ref_loss(x, y):
        return (rope_orig(x, dims, False, None, 1.0, 0, freqs) * y).sum()

    def fast_loss(x, y):
        return (fast_rope(x, freqs) * y).sum()

    x = mx.random.uniform(shape=(2, 4, dims))
    y = mx.random.uniform(shape=(2, 4, dims))
    g1 = mx.grad(ref_loss)(x, y)
    g2 = mx.grad(fast_loss)(x, y)
    self.assertLess(mx.abs(g1 - g2).max(), 1e-5)
def test_rope_grad(self):
D = 32
defaults = (D, 10000.0, 1.0, 0, False)