From 17310d91a6ffc606947df3f7a6f7c7fdf376f7d2 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 8 Sep 2025 17:35:07 -0700
Subject: [PATCH] Add batch offsets for mx.fast.rope (#2564)

* implement batch rope for Metal

* cuda rope (#2576)
---
 mlx/backend/cuda/rope.cu             |  90 +++++++++++++--------
 mlx/backend/metal/kernels/rope.metal | 113 ++++++++++-----------------
 mlx/backend/metal/rope.cpp           |  45 +++++++----
 mlx/fast.cpp                         |  60 ++++++++------
 python/src/fast.cpp                  |  11 ++-
 python/src/utils.cpp                 |   2 +-
 python/tests/test_fast.py            |  63 ++++++++++++++-
 7 files changed, 231 insertions(+), 153 deletions(-)

diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu
index 1c00f7a33..bac67cf90 100644
--- a/mlx/backend/cuda/rope.cu
+++ b/mlx/backend/cuda/rope.cu
@@ -103,15 +103,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
 __device__ void rope_impl(
     const T* in,
     T* out,
-    int offset,
+    const int* offset,
     float inv_freq,
     float scale,
     const cuda::std::array<int64_t, 3> strides,
     const cuda::std::array<int64_t, 3> out_strides,
-    int64_t n_batch,
+    int64_t offset_stride,
+    int n_head,
     uint3 pos,
     uint3 dims) {
-  float L = scale * static_cast<float>(pos.y + offset);
+  auto n_head_up = N * ((n_head + N - 1) / N);
+  auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
+  auto batch_idx = (pos.z * N) / n_head_up;
+  auto batch_offset = offset[batch_idx * offset_stride];
+  float L = scale * static_cast<float>(pos.y + batch_offset);
+  auto mat_idx = batch_idx * n_head + head_idx;
 
   // Compute costheta, sintheta
   float theta = L * inv_freq;
@@ -123,20 +129,19 @@ __device__ void rope_impl(
   size_t out_index_1, out_index_2;
   if (traditional) {
     out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
-        N * pos.z * out_strides[0];
+        mat_idx * out_strides[0];
     out_index_2 = out_index_1 + 1;
     in_index_1 =
-        2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
+        2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
     in_index_2 = in_index_1 + strides[2];
   } else {
     out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
-        N * pos.z * out_strides[0];
+        mat_idx * out_strides[0];
     out_index_2 = out_index_1 + dims.x * out_strides[2];
-    in_index_1 =
-        pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
+    in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
     in_index_2 = in_index_1 + dims.x * strides[2];
   }
-  for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
+  for (int i = 0; i < N && head_idx + i < n_head; ++i) {
     // Read and write the output
     float x1 = static_cast<float>(in[in_index_1]);
     float x2 = static_cast<float>(in[in_index_2]);
@@ -167,7 +172,8 @@ __global__ void rope(
     float base,
     const __grid_constant__ cuda::std::array<int64_t, 3> strides,
     const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
-    int64_t n_batch,
+    int64_t offset_stride,
+    int n_head,
     uint3 dims) {
   uint3 pos = make_uint3(
       blockIdx.x * blockDim.x + threadIdx.x,
@@ -182,12 +188,13 @@
   rope_impl<T, traditional, forward>(
       in,
       out,
-      *offset,
+      offset,
       inv_freq,
       scale,
       strides,
       out_strides,
-      n_batch,
+      offset_stride,
+      n_head,
       pos,
       dims);
 }
@@ -202,7 +209,8 @@ __global__ void rope_freqs(
     float base,
     const __grid_constant__ cuda::std::array<int64_t, 3> strides,
     const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
-    int64_t n_batch,
+    int64_t offset_stride,
+    int n_head,
     uint3 dims,
     int64_t freq_stride) {
   uint3 pos = make_uint3(
@@ -217,12 +225,13 @@
   rope_impl<T, traditional, forward>(
       in,
       out,
-      *offset,
+      offset,
       inv_freq,
       scale,
       strides,
       out_strides,
-      n_batch,
+      offset_stride,
+      n_head,
       pos,
       dims);
 }
@@ -245,23 +254,28 @@ void RoPE::eval_gpu(
   auto& offset = inputs[1];
   auto& out = outputs[0];
 
-  if (in.ndim() < 3) {
-    throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
-  }
-
   cuda::std::array<int64_t, 3> strides;
   cuda::std::array<int64_t, 3> out_strides;
   bool donated = false;
   int ndim = in.ndim();
-  int dispatch_ndim = in.ndim();
+
+  int B = in.shape(0);
+  int T = in.shape(-2);
+  int D = in.shape(-1);
+  size_t mat_size = T * D;
+  int dispatch_ndim = ndim;
   while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
     dispatch_ndim--;
   }
-  size_t mat_size = in.shape(-2) * in.shape(-1);
+
+  int N = 1;
+  for (int i = 1; i < (ndim - 2); ++i) {
+    N *= in.shape(i);
+  }
 
   // We apply rope to less than the whole vector so copy to output and then
   // apply in-place.
-  if (dims_ < in.shape(-1)) {
+  if (dims_ < D) {
     donated = true;
     auto ctype =
         (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
@@ -302,7 +316,7 @@ void RoPE::eval_gpu(
   out_strides[2] = out.strides()[ndim - 1];
 
   // Some flags to help us dispatch below
-  bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
+  bool single = in.flags().row_contiguous && B == 1 && T == 1;
   bool with_freqs = inputs.size() == 3;
 
   auto& encoder = cu::get_command_encoder(s);
@@ -319,7 +333,7 @@ void RoPE::eval_gpu(
     if (single && !with_freqs) {
       auto kernel = cu::rope_single;
-      uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
+      uint2 dims = make_uint2(dims_ / 2, N);
       auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
       encoder.add_kernel_node(
           kernel,
@@ -336,7 +350,7 @@ void RoPE::eval_gpu(
     } else if (single) {
       auto kernel = cu::rope_single_freqs;
-      uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
+      uint2 dims = make_uint2(dims_ / 2, N);
       auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
       encoder.add_kernel_node(
           kernel,
@@ -354,10 +368,14 @@ void RoPE::eval_gpu(
     } else if (with_freqs) {
       auto kernel = cu::rope_freqs;
-      uint3 dims =
-          make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
-      dims.z = (dims.z + 3) / 4;
+      int n_per_thread = 4;
+      uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
+      uint3 dims = make_uint3(dims_ / 2, T, dimz);
       auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
+      int64_t offset_stride = 0;
+      if (inputs[1].ndim() > 0) {
+        offset_stride = inputs[1].strides()[0];
+      }
       encoder.add_kernel_node(
           kernel,
           grid,
@@ -371,15 +389,20 @@ void RoPE::eval_gpu(
           std::log2(base_),
           strides,
           out_strides,
-          in.size() / mat_size,
+          offset_stride,
+          N,
           dims,
           inputs[2].strides(0));
     } else {
       auto kernel = cu::rope;
-      uint3 dims =
-          make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
-      dims.z = (dims.z + 3) / 4;
+      int n_per_thread = 4;
+      uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
+      uint3 dims = make_uint3(dims_ / 2, T, dimz);
       auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
+      int64_t offset_stride = 0;
+      if (inputs[1].ndim() > 0) {
+        offset_stride = inputs[1].strides()[0];
+      }
       encoder.add_kernel_node(
           kernel,
           grid,
           block,
@@ -392,7 +415,8 @@ void RoPE::eval_gpu(
           std::log2(base_),
           strides,
           out_strides,
-          in.size() / mat_size,
+          offset_stride,
+          N,
           dims);
     }
   });
diff --git a/mlx/backend/metal/kernels/rope.metal b/mlx/backend/metal/kernels/rope.metal
index b8f7a7c03..1fda49f74 100644
--- a/mlx/backend/metal/kernels/rope.metal
+++ b/mlx/backend/metal/kernels/rope.metal
@@ -10,7 +10,7 @@ void rope_single_impl(
     constant const int& offset,
     const float inv_freq,
     constant const float& scale,
-    constant const size_t& stride,
+    constant const int64_t& stride,
     uint2 pos,
     uint2 grid) {
   float L = scale * static_cast<float>(offset);
@@ -52,7 +52,7 @@ template <typename T, bool traditional, bool forward>
     device T* out [[buffer(1)]],
     constant const int& offset,
     constant const float& scale,
-    constant const size_t& stride,
+    constant const int64_t& stride,
     constant const float& base [[buffer(10)]],
     uint2 pos [[thread_position_in_grid]],
     uint2 grid [[threads_per_grid]]) {
@@ -68,9 +68,9 @@ template <typename T, bool traditional, bool forward>
     device T* out [[buffer(1)]],
     constant const int& offset,
     constant const float& scale,
-    constant const size_t& stride,
+    constant const int64_t& stride,
     const device float* freqs [[buffer(10)]],
-    constant const size_t& freq_stride [[buffer(11)]],
+    constant const int64_t& freq_stride [[buffer(11)]],
     uint2 pos [[thread_position_in_grid]],
     uint2 grid [[threads_per_grid]]) {
   float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
@@ -82,15 +82,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
 void rope_impl(
     const device T* in,
     device T* out,
-    constant const int& offset,
+    const device int* offset,
     const float inv_freq,
     constant const float& scale,
-    constant const size_t strides[3],
-    constant const size_t out_strides[3],
-    constant const size_t& n_batch,
+    constant const int64_t strides[3],
+    constant const int64_t out_strides[3],
+    constant const int64_t& offset_stride,
+    constant const int& n_head,
     uint3 pos,
     uint3 grid) {
-  float L = scale * static_cast<float>(pos.y + offset);
+  auto n_head_up = N * ((n_head + N - 1) / N);
+  auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
+  auto batch_idx = (pos.z * N) / n_head_up;
+  auto batch_offset = offset[batch_idx * offset_stride];
+  float L = scale * static_cast<float>(pos.y + batch_offset);
+  auto mat_idx = batch_idx * n_head + head_idx;
 
   // Compute costheta, sintheta
   float theta = L * inv_freq;
@@ -102,20 +108,19 @@ void rope_impl(
   size_t out_index_1, out_index_2;
   if (traditional) {
     out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
-        N * pos.z * out_strides[0];
+        mat_idx * out_strides[0];
     out_index_2 = out_index_1 + 1;
     in_index_1 =
-        2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
+        2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
     in_index_2 = in_index_1 + strides[2];
   } else {
     out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
-        N * pos.z * out_strides[0];
+        mat_idx * out_strides[0];
     out_index_2 = out_index_1 + grid.x * out_strides[2];
-    in_index_1 =
-        pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
+    in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
     in_index_2 = in_index_1 + grid.x * strides[2];
   }
-  for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
+  for (int i = 0; i < N && head_idx + i < n_head; ++i) {
     // Read and write the output
     float x1 = static_cast<float>(in[in_index_1]);
     float x2 = static_cast<float>(in[in_index_2]);
@@ -141,11 +146,12 @@ template <typename T, bool traditional, bool forward>
 [[kernel]] void rope(
     const device T* in [[buffer(0)]],
     device T* out [[buffer(1)]],
-    constant const int& offset,
+    const device int* offset,
     constant const float& scale,
-    constant const size_t strides[3],
-    constant const size_t out_strides[3],
-    constant const size_t& n_batch,
+    constant const int64_t strides[3],
+    constant const int64_t out_strides[3],
+    constant const int64_t& offset_stride,
+    constant const int& n_head,
     constant const float& base [[buffer(10)]],
     uint3 pos [[thread_position_in_grid]],
     uint3 grid [[threads_per_grid]]) {
@@ -159,7 +165,8 @@ template <typename T, bool traditional, bool forward>
       scale,
       strides,
       out_strides,
-      n_batch,
+      offset_stride,
+      n_head,
       pos,
       grid);
 }
@@ -168,13 +175,14 @@ template <typename T, bool traditional, bool forward>
 [[kernel]] void rope_freqs(
     const device T* in [[buffer(0)]],
     device T* out [[buffer(1)]],
-    constant const int& offset,
+    const device int* offset,
     constant const float& scale,
-    constant const size_t strides[3],
-    constant const size_t out_strides[3],
-    constant const size_t& n_batch,
+    constant const int64_t strides[3],
+    constant const int64_t out_strides[3],
+    constant const int64_t& offset_stride,
+    constant const int& n_head,
     const device float* freqs [[buffer(10)]],
-    constant const size_t& freq_stride [[buffer(11)]],
+    constant const int64_t& freq_stride [[buffer(11)]],
     uint3 pos [[thread_position_in_grid]],
     uint3 grid [[threads_per_grid]]) {
   float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
@@ -186,61 +194,20 @@ template <typename T, bool traditional, bool forward>
       scale,
       strides,
       out_strides,
-      n_batch,
+      offset_stride,
+      n_head,
       pos,
       grid);
 }
 
 // clang-format off
 #define instantiate_rope_g(name, type, traditional, forward) \
-  template [[host_name("rope_" #name)]] [[kernel]] void \
-  rope<type, traditional, forward>( \
-      const device type* in [[buffer(0)]], \
-      device type* out [[buffer(1)]], \
-      constant const int& offset, \
-      constant const float& scale, \
-      constant const size_t strides[3], \
-      constant const size_t out_strides[3], \
-      constant const size_t& n_batch, \
-      constant const float& base [[buffer(10)]], \
-      uint3 pos [[thread_position_in_grid]], \
-      uint3 grid [[threads_per_grid]]); \
-  template [[host_name("rope_freqs_" #name)]] \
-  [[kernel]] void rope_freqs<type, traditional, forward>( \
-      const device type* in [[buffer(0)]], \
-      device type* out [[buffer(1)]], \
-      constant const int& offset, \
-      constant const float& scale, \
-      constant const size_t strides[3], \
-      constant const size_t out_strides[3], \
-      constant const size_t& n_batch, \
-      const device float* freqs [[buffer(10)]], \
-      constant const size_t& freq_stride [[buffer(11)]], \
-      uint3 pos [[thread_position_in_grid]], \
-      uint3 grid [[threads_per_grid]]);
+  instantiate_kernel("rope_" #name, rope, type, traditional, forward) \
+  instantiate_kernel("rope_freqs_" #name, rope_freqs, type, traditional, forward)
 
-#define instantiate_rope_s(name, type, traditional, forward) \
-  template [[host_name("rope_single_" #name)]] [[kernel]] void \
-  rope_single<type, traditional, forward>( \
-      const device type* in [[buffer(0)]], \
-      device type* out [[buffer(1)]], \
-      constant const int& offset, \
-      constant const float& scale, \
-      constant const size_t& stride, \
-      constant const float& base [[buffer(10)]], \
-      uint2 pos [[thread_position_in_grid]], \
-      uint2 grid [[threads_per_grid]]); \
-  template [[host_name("rope_single_freqs_" #name)]] \
-  [[kernel]] void rope_single_freqs<type, traditional, forward>( \
-      const device type* in [[buffer(0)]], \
-      device type* out [[buffer(1)]], \
-      constant const int& offset, \
-      constant const float& scale, \
-      constant const size_t& stride, \
-      const device float* freqs [[buffer(10)]], \
-      constant const size_t& freq_stride [[buffer(11)]], \
-      uint2 pos [[thread_position_in_grid]], \
-      uint2 grid [[threads_per_grid]]);
+#define instantiate_rope_s(name, type, traditional, forward) \
+  instantiate_kernel("rope_single_" #name, rope_single, type, traditional, forward) \
+  instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type, traditional, forward)
 
 #define instantiate_rope(name, type, traditional, forward) \
   instantiate_rope_s(name, type, traditional, forward) \
diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp
index e141df630..38d822ce0 100644
--- a/mlx/backend/metal/rope.cpp
+++ b/mlx/backend/metal/rope.cpp
@@ -18,23 +18,29 @@ void RoPE::eval_gpu(
   auto& in = inputs[0];
   auto& out = outputs[0];
 
-  if (in.ndim() < 3) {
-    throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
-  }
-
   auto& s = out.primitive().stream();
   auto& d = metal::device(s.device);
 
-  size_t strides[3];
-  size_t out_strides[3];
+  int64_t strides[3];
+  int64_t out_strides[3];
   bool donated = false;
   int ndim = in.ndim();
-  int dispatch_ndim = in.ndim();
+  int B = in.shape(0);
+  int T = in.shape(-2);
+  int D = in.shape(-1);
+  size_t mat_size = T * D;
+
+  int dispatch_ndim = ndim;
   while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
     dispatch_ndim--;
   }
-  size_t mat_size = in.shape(-2) * in.shape(-1);
-  if (dims_ < in.shape(-1)) {
+
+  int N = 1;
+  for (int i = 1; i < (ndim - 2); ++i) {
+    N *= in.shape(i);
+  }
+
+  if (dims_ < D) {
     donated = true;
     auto ctype =
         (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
@@ -71,8 +77,8 @@ void RoPE::eval_gpu(
   out_strides[1] = out.strides()[ndim - 2];
   out_strides[2] = out.strides()[ndim - 1];
 
-  // Special case for inference (single time step and contiguous)
-  bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
+  // Special case for inference (single batch, single time step, and contiguous)
+  bool single = in.flags().row_contiguous && B == 1 && T == 1;
   bool with_freqs = inputs.size() == 3;
 
   std::ostringstream kname;
@@ -86,24 +92,29 @@ void RoPE::eval_gpu(
   compute_encoder.set_compute_pipeline_state(kernel);
   compute_encoder.set_input_array(donated ? out : in, 0);
   compute_encoder.set_output_array(out, 1);
+  compute_encoder.set_input_array(inputs[1], 2);
   compute_encoder.set_bytes(scale_, 3);
 
-  size_t n_batch = in.size() / mat_size;
   MTL::Size group_dims;
   MTL::Size grid_dims;
   if (single) {
     compute_encoder.set_bytes(out_strides, 1, 4);
     uint32_t dim0 = dims_ / 2;
-    group_dims = get_block_dims(dim0, n_batch, 1);
-    grid_dims = MTL::Size(dim0, n_batch, 1);
+    group_dims = get_block_dims(dim0, N, 1);
+    grid_dims = MTL::Size(dim0, N, 1);
   } else {
     compute_encoder.set_bytes(strides, 3, 4);
     compute_encoder.set_bytes(out_strides, 3, 5);
-    compute_encoder.set_bytes(n_batch, 6);
+    int64_t offset_stride = 0;
+    if (inputs[1].ndim() > 0) {
+      offset_stride = inputs[1].strides()[0];
+    }
+    compute_encoder.set_bytes(offset_stride, 6);
+    compute_encoder.set_bytes(N, 7);
     uint32_t dim0 = dims_ / 2;
-    uint32_t dim1 = in.shape(-2);
-    uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
+    uint32_t dim1 = T;
+    uint32_t dim2 = B * ((N + n_per_thread - 1) / n_per_thread);
     group_dims = get_block_dims(dim0, dim1, dim2);
     grid_dims = MTL::Size(dim0, dim1, dim2);
   }
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index befc9f80c..254bbde77 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -366,10 +366,16 @@ array rope(
     msg << "[rope] Input must be a floating type but got " << x.dtype() << ".";
     throw std::invalid_argument(msg.str());
   }
-  if (offset.size() != 1) {
+  if (offset.ndim() > 1) {
     std::ostringstream msg;
-    msg << "[rope] offset must be a scalar but has shape " << offset.shape()
-        << ".";
+    msg << "[rope] offset must have at most one dimension but has shape "
+        << offset.shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+  if (offset.size() != 1 && offset.size() != x.shape(0)) {
+    std::ostringstream msg;
+    msg << "[rope] offset must be a scalar or vector with " << x.shape(0)
+        << " elements but has shape " << offset.shape() << ".";
     throw std::invalid_argument(msg.str());
   }
   if (!issubdtype(offset.dtype(), integer)) {
@@ -379,7 +385,7 @@ array rope(
     throw std::invalid_argument(msg.str());
   }
   if (offset.dtype().size() != 4) {
-    inputs[1] = astype(offset, uint32, s);
+    inputs[1] = astype(offset, int32, s);
   }
   if (inputs.size() == 3 &&
       (inputs[2].ndim() != 1 || inputs[2].shape(0) != dims / 2)) {
@@ -391,15 +397,26 @@ array rope(
   auto fallback = [dims, traditional, base, scale, forward, s](
                       std::vector<array> inputs) {
-    auto& shape = inputs[0].shape();
-    int ndim = shape.size();
-    auto x = flatten(inputs[0], 0, ndim - 3, s);
+    auto x = inputs[0];
+    auto shape = x.shape();
+    if (x.ndim() == 3) {
+      x = expand_dims(x, 1, s);
+    } else if (x.ndim() > 4) {
+      x = flatten(x, 1, 1 + (x.ndim() - 4), s);
+    }
+
+    auto B = x.shape(0);
+    auto N = x.shape(1);
+    auto T = x.shape(2);
     auto t = x.dtype();
     // Compute sines and cosines
     auto half_dims = dims / 2;
-    auto& offset = inputs[1];
+    auto offset = inputs[1];
+    if (offset.size() > 1) {
+      offset = expand_dims(offset, {-1, -2}, s);
+    }
     auto positions =
-        multiply(add(arange(x.shape(1), t, s), offset, s), array(scale, t), s);
+        multiply(add(arange(x.shape(2), t, s), offset, s), array(scale, t), s);
 
     auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
       return exp(
@@ -412,8 +429,7 @@ array rope(
     auto inv_freqs = inputs.size() == 3
         ? astype(reciprocal(inputs[2], s), t, s)
        : default_inv_freqs();
-    auto theta =
-        multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
+    auto theta = multiply(expand_dims(positions, -1, s), inv_freqs, s);
     auto coss = cos(theta, s);
     auto sins = sin(theta, s);
 
@@ -436,32 +452,30 @@ array rope(
     };
 
     if (traditional) {
-      auto x1 =
-          slice(x, {0, 0, 0}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
-      auto x2 =
-          slice(x, {0, 0, 1}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
+      auto x1 = slice(x, {0, 0, 0, 0}, {B, N, T, dims}, {1, 1, 1, 2}, s);
+      auto x2 = slice(x, {0, 0, 0, 1}, {B, N, T, dims}, {1, 1, 1, 2}, s);
       auto outs = apply_rope(x1, x2, coss, sins);
       for (auto& o : outs) {
-        o = expand_dims(o, 3, s);
+        o = expand_dims(o, -1, s);
       }
-      auto out = concatenate(outs, 3, s);
+      auto out = reshape(concatenate(outs, -1, s), {B, N, T, dims}, s);
       if (dims < x.shape(-1)) {
-        out = reshape(out, {x.shape(0), x.shape(1), dims});
-        out = concatenate({out, slice(x, {0, 0, dims}, x.shape(), s)}, 2, s);
+        out =
+            concatenate({out, slice(x, {0, 0, 0, dims}, x.shape(), s)}, -1, s);
       }
       return std::vector<array>{reshape(out, shape, s)};
     } else {
       auto out_s = x.shape();
       out_s.back() = half_dims;
-      auto x1 = slice(x, {0, 0, 0}, out_s, s);
+      auto x1 = slice(x, {0, 0, 0, 0}, out_s, s);
       out_s.back() = dims;
-      auto x2 = slice(x, {0, 0, half_dims}, out_s, s);
+      auto x2 = slice(x, {0, 0, 0, half_dims}, out_s, s);
       auto outs = apply_rope(x1, x2, coss, sins);
       if (dims < x.shape(-1)) {
-        outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
+        outs.push_back(slice(x, {0, 0, 0, dims}, x.shape(), s));
       }
-      return std::vector<array>{reshape(concatenate(outs, 2, s), shape, s)};
+      return std::vector<array>{reshape(concatenate(outs, -1, s), shape, s)};
     }
   };
   auto stream = to_stream(s);
diff --git a/python/src/fast.cpp b/python/src/fast.cpp
index 12d6de358..7b484559f 100644
--- a/python/src/fast.cpp
+++ b/python/src/fast.cpp
@@ -164,8 +164,13 @@ void init_fast(nb::module_& parent_module) {
       R"pbdoc(
        Apply rotary positional encoding to the input.
+        The input is expected to be at least 3D with shape ``(B, *, T, D)`` where:
+
+        * ``B`` is the batch size.
+        * ``T`` is the sequence length.
+        * ``D`` is the feature dimension.
 
         Args:
-            a (array): Input array.
+            a (array): The input array.
             dims (int): The feature dimensions to be rotated. If the input feature
                 is larger than dims then the rest is left unchanged.
             traditional (bool): If set to ``True`` choose the traditional
@@ -174,7 +179,9 @@ void init_fast(nb::module_& parent_module) {
                 each dimension in the positional encodings. Exactly one of ``base`` and
                 ``freqs`` must be ``None``.
             scale (float): The scale used to scale the positions.
-            offset (int or array): The position offset to start at.
+            offset (int or array): The position offset to start at. If an
+                :obj:`array` is given, it can be a scalar or a vector of ``B``
+                offsets, one for each example in the batch.
             freqs (array, optional): Optional frequencies to use with RoPE.
                 If set, the ``base`` parameter must be ``None``. Default:
                 ``None``.
diff --git a/python/src/utils.cpp b/python/src/utils.cpp
index 08f78bdf4..5366e501b 100644
--- a/python/src/utils.cpp
+++ b/python/src/utils.cpp
@@ -91,7 +91,7 @@ mx::array to_array_with_accessor(nb::object obj) {
     return nb::cast<mx::array>(obj.attr("__mlx_array__")());
   } else {
     std::ostringstream msg;
-    msg << "Invalid type " << nb::type_name(obj.type()).c_str()
+    msg << "Invalid type " << nb::type_name(obj.type()).c_str()
         << " received in array initialization.";
     throw std::invalid_argument(msg.str());
   }
diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py
index 5aabaf388..4cfef7b11 100644
--- a/python/tests/test_fast.py
+++ b/python/tests/test_fast.py
@@ -8,18 +8,23 @@ import mlx_tests
 
 
 def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
-    offset = offset.item() if isinstance(offset, mx.array) else offset
-    N = x.shape[-2] + offset
+    N = x.shape[-2]
     dtype = x.dtype
     half_D = dims // 2
-    positions = mx.arange(offset, N, dtype=dtype) * scale
+    positions = mx.arange(N, dtype=dtype)
+    if isinstance(offset, mx.array) and offset.size > 1:
+        expand = tuple(range(1, x.ndim - 1))
+        positions = mx.expand_dims(offset, expand) + positions
+    else:
+        positions = offset + positions
+    positions = positions * scale
     if freqs is None:
         inv_freqs = mx.exp(
             -mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
         )
     else:
         inv_freqs = (1 / freqs).astype(x.dtype)
-    theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1))
+    theta = mx.expand_dims(positions, -1) * inv_freqs
     costheta, sintheta = mx.cos(theta), mx.sin(theta)
     if traditional:
         x1 = x[..., :dims:2]
@@ -277,6 +282,56 @@ class TestFast(mlx_tests.MLXTestCase):
         g2 = mx.grad(f2)(x, y)
         self.assertLess(mx.abs(g1 - g2).max(), 1e-5)
 
+    def test_rope_batch(self):
+        T = 4
+        base = 10000.0
+        scale = 1.0
+        traditional = True
+        batch_sizes = [3, 8, 11]
+        num_heads = [1, 3, 5]
+        dims = 32
+
+        x = mx.random.uniform(shape=(8, 4, T, dims))
+
+        offset = mx.array([1, 2, 3])
+        with self.assertRaises(ValueError):
+            mx.fast.rope(
+                x,
+                dims,
+                traditional=traditional,
+                base=base,
+                scale=scale,
+                offset=offset,
+            )
+
+        for batch_size in batch_sizes:
+            for n_head in num_heads:
+                x = mx.random.uniform(shape=(batch_size, n_head, T, dims))
+                offset = mx.arange(batch_size)
+                rx = rope_orig(x, dims, traditional, base, scale, offset)
+                rx_fast = mx.fast.rope(
+                    x,
+                    dims,
+                    traditional=traditional,
+                    base=base,
+                    scale=scale,
+                    offset=offset,
+                )
+                self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
+
+        x = mx.random.normal(shape=(2, 6, 8, 64)).transpose(0, 2, 1, 3)
+        dims = 64
+        offset = 0
+        rx_fast = mx.fast.rope(
+            x, dims, traditional=traditional, scale=scale, base=base, offset=offset
+        )
+        rx_fast_single = mx.fast.rope(
+            x[0:1], dims, traditional=traditional, scale=scale, base=base, offset=offset
+        )
+
+        rx = rope_orig(x, dims, traditional, base, scale, offset)
+        self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
+        self.assertLess(mx.abs(rx[0:1] - rx_fast_single).max(), 1e-5)
+
     def test_rms_norm(self):
         # Per dtype absolute tolerance
         tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}
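
A minimal usage sketch of the per-batch offsets added above. The shapes,
offset values, and tolerance here are illustrative assumptions, not taken
from the patch:

    import mlx.core as mx

    B, H, T, D = 4, 8, 16, 64
    x = mx.random.normal(shape=(B, H, T, D))

    # One position offset per batch element, e.g. when the examples in a
    # batch are at different decode positions.
    offset = mx.array([0, 3, 7, 2])

    y = mx.fast.rope(
        x, D, traditional=False, base=10000.0, scale=1.0, offset=offset
    )
    print(y.shape)  # (4, 8, 16, 64)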
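
Semantically, the batched form should match applying rope to each example
separately with its scalar offset; the kernels above implement this by
indexing offset[batch_idx * offset_stride] per (batch, head) matrix. A quick
equivalence check, reusing the assumed names from the sketch above:

    # Each batch element matches a separate call with its scalar offset.
    for b in range(B):
        y_b = mx.fast.rope(
            x[b : b + 1],
            D,
            traditional=False,
            base=10000.0,
            scale=1.0,
            offset=int(offset[b].item()),
        )
        assert mx.abs(y[b : b + 1] - y_b).max() < 1e-5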