Add batch offsets for mx.fast.rope (#2564)

* implement batch rope for Metal

* cuda rope (#2576)
This commit is contained in:
Awni Hannun
2025-09-08 17:35:07 -07:00
committed by GitHub
parent b194d65a6a
commit 17310d91a6
7 changed files with 231 additions and 153 deletions

View File

@@ -164,8 +164,13 @@ void init_fast(nb::module_& parent_module) {
R"pbdoc(
Apply rotary positional encoding to the input.
The input is expected to be at least 3D with shape ``(B, *, T, D)`` where:
* ``B`` is the batch size.
* ``T`` is the sequence length.
* ``D`` is the feature dimension.
Args:
a (array): Input array.
a (array): The input array.
dims (int): The feature dimensions to be rotated. If the input feature
is larger than dims then the rest is left unchanged.
traditional (bool): If set to ``True`` choose the traditional
@@ -174,7 +179,9 @@ void init_fast(nb::module_& parent_module) {
each dimension in the positional encodings. Exactly one of ``base`` and
``freqs`` must be ``None``.
scale (float): The scale used to scale the positions.
offset (int or array): The position offset to start at.
offset (int or array): The position offset to start at. If an
:obj:`array` is given it can be a scalar or vector of ``B``
offsets for each example in the batch.
freqs (array, optional): Optional frequencies to use with RoPE.
If set, the ``base`` parameter must be ``None``. Default: ``None``.

View File

@@ -91,7 +91,7 @@ mx::array to_array_with_accessor(nb::object obj) {
return nb::cast<mx::array>(obj.attr("__mlx_array__")());
} else {
std::ostringstream msg;
msg << "Invalid type " << nb::type_name(obj.type()).c_str()
msg << "Invalid type " << nb::type_name(obj.type()).c_str()
<< " received in array initialization.";
throw std::invalid_argument(msg.str());
}

View File

@@ -8,18 +8,23 @@ import mlx_tests
def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
offset = offset.item() if isinstance(offset, mx.array) else offset
N = x.shape[-2] + offset
N = x.shape[-2]
dtype = x.dtype
half_D = dims // 2
positions = mx.arange(offset, N, dtype=dtype) * scale
positions = mx.arange(N, dtype=dtype)
if isinstance(offset, mx.array) and offset.size > 1:
expand = tuple(range(1, x.ndim - 1))
positions = mx.expand_dims(offset, expand) + positions
else:
positions = offset + positions
positions = positions * scale
if freqs is None:
inv_freqs = mx.exp(
-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
)
else:
inv_freqs = (1 / freqs).astype(x.dtype)
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1))
theta = mx.expand_dims(positions, -1) * inv_freqs
costheta, sintheta = mx.cos(theta), mx.sin(theta)
if traditional:
x1 = x[..., :dims:2]
@@ -214,6 +219,7 @@ class TestFast(mlx_tests.MLXTestCase):
)
self.assertEqual(dtype, rx.dtype)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
return
# Test single vector
x = mx.random.uniform(shape=(1, 1, dims))
@@ -277,6 +283,55 @@ class TestFast(mlx_tests.MLXTestCase):
g2 = mx.grad(f2)(x, y)
self.assertLess(mx.abs(g1 - g2).max(), 1e-5)
def test_rope_batch(self):
T = 4
base = 10000.0
scale = 1.0
traditional = True
batch_sizes = [3, 8, 11]
num_heads = [1, 3, 5]
dims = 32
x = mx.random.uniform(shape=(8, 4, T, dims))
offset = mx.array([1, 2, 3])
with self.assertRaises(ValueError):
mx.fast.rope(
x,
dims,
traditional=traditional,
base=base,
scale=scale,
offset=offset,
)
for batch_size in batch_sizes:
for n_head in num_heads:
x = mx.random.uniform(shape=(batch_size, n_head, T, dims))
offset = mx.arange(batch_size)
rx = rope_orig(x, dims, traditional, base, scale, offset)
rx_fast = mx.fast.rope(
x,
dims,
traditional=traditional,
base=base,
scale=scale,
offset=offset,
)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
x = mx.random.normal(shape=(2, 6, 8, 64)).transpose(0, 2, 1, 3)
dims = 64
offset = 0
rx_fast = mx.fast.rope(
x, dims, traditional=traditional, scale=scale, base=base, offset=offset
)
rx_fast_single = mx.fast.rope(
x[0:1], dims, traditional=traditional, scale=scale, base=base, offset=offset
)
rx = rope_orig(x, dims, traditional, base, scale, offset)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
def test_rms_norm(self):
# Per dtype absolute tolerance
tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}