Add batch offsets for mx.fast.rope (#2564)

* implement batch rope for Metal

* cuda rope (#2576)
This commit is contained in:
Awni Hannun
2025-09-08 17:35:07 -07:00
committed by GitHub
parent b194d65a6a
commit 17310d91a6
7 changed files with 231 additions and 153 deletions

View File

@@ -103,15 +103,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
__device__ void rope_impl( __device__ void rope_impl(
const T* in, const T* in,
T* out, T* out,
int offset, const int* offset,
float inv_freq, float inv_freq,
float scale, float scale,
const cuda::std::array<int64_t, 3> strides, const cuda::std::array<int64_t, 3> strides,
const cuda::std::array<int64_t, 3> out_strides, const cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch, int64_t offset_stride,
int n_head,
uint3 pos, uint3 pos,
uint3 dims) { uint3 dims) {
float L = scale * static_cast<float>(pos.y + offset); auto n_head_up = N * ((n_head + N - 1) / N);
auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
auto batch_idx = (pos.z * N) / n_head_up;
auto batch_offset = offset[batch_idx * offset_stride];
float L = scale * static_cast<float>(pos.y + batch_offset);
auto mat_idx = batch_idx * n_head + head_idx;
// Compute costheta, sintheta // Compute costheta, sintheta
float theta = L * inv_freq; float theta = L * inv_freq;
@@ -123,20 +129,19 @@ __device__ void rope_impl(
size_t out_index_1, out_index_2; size_t out_index_1, out_index_2;
if (traditional) { if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0]; mat_idx * out_strides[0];
out_index_2 = out_index_1 + 1; out_index_2 = out_index_1 + 1;
in_index_1 = in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0]; 2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
in_index_2 = in_index_1 + strides[2]; in_index_2 = in_index_1 + strides[2];
} else { } else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0]; mat_idx * out_strides[0];
out_index_2 = out_index_1 + dims.x * out_strides[2]; out_index_2 = out_index_1 + dims.x * out_strides[2];
in_index_1 = in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_2 = in_index_1 + dims.x * strides[2]; in_index_2 = in_index_1 + dims.x * strides[2];
} }
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) { for (int i = 0; i < N && head_idx + i < n_head; ++i) {
// Read and write the output // Read and write the output
float x1 = static_cast<float>(in[in_index_1]); float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]); float x2 = static_cast<float>(in[in_index_2]);
@@ -167,7 +172,8 @@ __global__ void rope(
float base, float base,
const __grid_constant__ cuda::std::array<int64_t, 3> strides, const __grid_constant__ cuda::std::array<int64_t, 3> strides,
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides, const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch, int64_t offset_stride,
int n_head,
uint3 dims) { uint3 dims) {
uint3 pos = make_uint3( uint3 pos = make_uint3(
blockIdx.x * blockDim.x + threadIdx.x, blockIdx.x * blockDim.x + threadIdx.x,
@@ -182,12 +188,13 @@ __global__ void rope(
rope_impl<T, traditional, forward>( rope_impl<T, traditional, forward>(
in, in,
out, out,
*offset, offset,
inv_freq, inv_freq,
scale, scale,
strides, strides,
out_strides, out_strides,
n_batch, offset_stride,
n_head,
pos, pos,
dims); dims);
} }
@@ -202,7 +209,8 @@ __global__ void rope_freqs(
float base, float base,
const __grid_constant__ cuda::std::array<int64_t, 3> strides, const __grid_constant__ cuda::std::array<int64_t, 3> strides,
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides, const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch, int64_t offset_stride,
int n_head,
uint3 dims, uint3 dims,
int64_t freq_stride) { int64_t freq_stride) {
uint3 pos = make_uint3( uint3 pos = make_uint3(
@@ -217,12 +225,13 @@ __global__ void rope_freqs(
rope_impl<T, traditional, forward>( rope_impl<T, traditional, forward>(
in, in,
out, out,
*offset, offset,
inv_freq, inv_freq,
scale, scale,
strides, strides,
out_strides, out_strides,
n_batch, offset_stride,
n_head,
pos, pos,
dims); dims);
} }
@@ -245,23 +254,28 @@ void RoPE::eval_gpu(
auto& offset = inputs[1]; auto& offset = inputs[1];
auto& out = outputs[0]; auto& out = outputs[0];
if (in.ndim() < 3) {
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
}
cuda::std::array<int64_t, 3> strides; cuda::std::array<int64_t, 3> strides;
cuda::std::array<int64_t, 3> out_strides; cuda::std::array<int64_t, 3> out_strides;
bool donated = false; bool donated = false;
int ndim = in.ndim(); int ndim = in.ndim();
int dispatch_ndim = in.ndim();
int B = in.shape(0);
int T = in.shape(-2);
int D = in.shape(-1);
size_t mat_size = T * D;
int dispatch_ndim = ndim;
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) { while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
dispatch_ndim--; dispatch_ndim--;
} }
size_t mat_size = in.shape(-2) * in.shape(-1);
int N = 1;
for (int i = 1; i < (ndim - 2); ++i) {
N *= in.shape(i);
}
// We apply rope to less that the whole vector so copy to output and then // We apply rope to less that the whole vector so copy to output and then
// apply in-place. // apply in-place.
if (dims_ < in.shape(-1)) { if (dims_ < D) {
donated = true; donated = true;
auto ctype = auto ctype =
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General; (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
@@ -302,7 +316,7 @@ void RoPE::eval_gpu(
out_strides[2] = out.strides()[ndim - 1]; out_strides[2] = out.strides()[ndim - 1];
// Some flags to help us dispatch below // Some flags to help us dispatch below
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1)); bool single = in.flags().row_contiguous && B == 1 && T == 1;
bool with_freqs = inputs.size() == 3; bool with_freqs = inputs.size() == 3;
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
@@ -319,7 +333,7 @@ void RoPE::eval_gpu(
if (single && !with_freqs) { if (single && !with_freqs) {
auto kernel = auto kernel =
cu::rope_single<DataType, traditional.value, forward.value>; cu::rope_single<DataType, traditional.value, forward.value>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); uint2 dims = make_uint2(dims_ / 2, N);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
@@ -336,7 +350,7 @@ void RoPE::eval_gpu(
} else if (single) { } else if (single) {
auto kernel = auto kernel =
cu::rope_single_freqs<DataType, traditional.value, forward.value>; cu::rope_single_freqs<DataType, traditional.value, forward.value>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); uint2 dims = make_uint2(dims_ / 2, N);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
@@ -354,10 +368,14 @@ void RoPE::eval_gpu(
} else if (with_freqs) { } else if (with_freqs) {
auto kernel = auto kernel =
cu::rope_freqs<DataType, traditional.value, forward.value>; cu::rope_freqs<DataType, traditional.value, forward.value>;
uint3 dims = int n_per_thread = 4;
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
dims.z = (dims.z + 3) / 4; uint3 dims = make_uint3(dims_ / 2, T, dimz);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
grid, grid,
@@ -371,15 +389,20 @@ void RoPE::eval_gpu(
std::log2(base_), std::log2(base_),
strides, strides,
out_strides, out_strides,
in.size() / mat_size, offset_stride,
N,
dims, dims,
inputs[2].strides(0)); inputs[2].strides(0));
} else { } else {
auto kernel = cu::rope<DataType, traditional.value, forward.value>; auto kernel = cu::rope<DataType, traditional.value, forward.value>;
uint3 dims = int n_per_thread = 4;
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
dims.z = (dims.z + 3) / 4; uint3 dims = make_uint3(dims_ / 2, T, dimz);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
grid, grid,
@@ -392,7 +415,8 @@ void RoPE::eval_gpu(
std::log2(base_), std::log2(base_),
strides, strides,
out_strides, out_strides,
in.size() / mat_size, offset_stride,
N,
dims); dims);
} }
}); });

View File

@@ -10,7 +10,7 @@ void rope_single_impl(
constant const int& offset, constant const int& offset,
const float inv_freq, const float inv_freq,
constant const float& scale, constant const float& scale,
constant const size_t& stride, constant const int64_t& stride,
uint2 pos, uint2 pos,
uint2 grid) { uint2 grid) {
float L = scale * static_cast<float>(offset); float L = scale * static_cast<float>(offset);
@@ -52,7 +52,7 @@ template <typename T, bool traditional, bool forward>
device T* out [[buffer(1)]], device T* out [[buffer(1)]],
constant const int& offset, constant const int& offset,
constant const float& scale, constant const float& scale,
constant const size_t& stride, constant const int64_t& stride,
constant const float& base [[buffer(10)]], constant const float& base [[buffer(10)]],
uint2 pos [[thread_position_in_grid]], uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) { uint2 grid [[threads_per_grid]]) {
@@ -68,9 +68,9 @@ template <typename T, bool traditional, bool forward>
device T* out [[buffer(1)]], device T* out [[buffer(1)]],
constant const int& offset, constant const int& offset,
constant const float& scale, constant const float& scale,
constant const size_t& stride, constant const int64_t& stride,
const device float* freqs [[buffer(10)]], const device float* freqs [[buffer(10)]],
constant const size_t& freq_stride [[buffer(11)]], constant const int64_t& freq_stride [[buffer(11)]],
uint2 pos [[thread_position_in_grid]], uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) { uint2 grid [[threads_per_grid]]) {
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]); float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
@@ -82,15 +82,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
void rope_impl( void rope_impl(
const device T* in, const device T* in,
device T* out, device T* out,
constant const int& offset, const device int* offset,
const float inv_freq, const float inv_freq,
constant const float& scale, constant const float& scale,
constant const size_t strides[3], constant const int64_t strides[3],
constant const size_t out_strides[3], constant const int64_t out_strides[3],
constant const size_t& n_batch, constant const int64_t& offset_stride,
constant const int& n_head,
uint3 pos, uint3 pos,
uint3 grid) { uint3 grid) {
float L = scale * static_cast<float>(pos.y + offset); auto n_head_up = N * ((n_head + N - 1) / N);
auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
auto batch_idx = (pos.z * N) / n_head_up;
auto batch_offset = offset[batch_idx * offset_stride];
float L = scale * static_cast<float>(pos.y + batch_offset);
auto mat_idx = batch_idx * n_head + head_idx;
// Compute costheta, sintheta // Compute costheta, sintheta
float theta = L * inv_freq; float theta = L * inv_freq;
@@ -102,20 +108,19 @@ void rope_impl(
size_t out_index_1, out_index_2; size_t out_index_1, out_index_2;
if (traditional) { if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0]; mat_idx * out_strides[0];
out_index_2 = out_index_1 + 1; out_index_2 = out_index_1 + 1;
in_index_1 = in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0]; 2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
in_index_2 = in_index_1 + strides[2]; in_index_2 = in_index_1 + strides[2];
} else { } else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0]; mat_idx * out_strides[0];
out_index_2 = out_index_1 + grid.x * out_strides[2]; out_index_2 = out_index_1 + grid.x * out_strides[2];
in_index_1 = in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_2 = in_index_1 + grid.x * strides[2]; in_index_2 = in_index_1 + grid.x * strides[2];
} }
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) { for (int i = 0; i < N && head_idx + i < n_head; ++i) {
// Read and write the output // Read and write the output
float x1 = static_cast<float>(in[in_index_1]); float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]); float x2 = static_cast<float>(in[in_index_2]);
@@ -141,11 +146,12 @@ template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope( [[kernel]] void rope(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device T* out [[buffer(1)]], device T* out [[buffer(1)]],
constant const int& offset, const device int* offset,
constant const float& scale, constant const float& scale,
constant const size_t strides[3], constant const int64_t strides[3],
constant const size_t out_strides[3], constant const int64_t out_strides[3],
constant const size_t& n_batch, constant const int64_t& offset_stride,
constant const int& n_head,
constant const float& base [[buffer(10)]], constant const float& base [[buffer(10)]],
uint3 pos [[thread_position_in_grid]], uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) { uint3 grid [[threads_per_grid]]) {
@@ -159,7 +165,8 @@ template <typename T, bool traditional, bool forward, int N = 4>
scale, scale,
strides, strides,
out_strides, out_strides,
n_batch, offset_stride,
n_head,
pos, pos,
grid); grid);
} }
@@ -168,13 +175,14 @@ template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope_freqs( [[kernel]] void rope_freqs(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device T* out [[buffer(1)]], device T* out [[buffer(1)]],
constant const int& offset, const device int* offset,
constant const float& scale, constant const float& scale,
constant const size_t strides[3], constant const int64_t strides[3],
constant const size_t out_strides[3], constant const int64_t out_strides[3],
constant const size_t& n_batch, constant const int64_t& offset_stride,
constant const int& n_head,
const device float* freqs [[buffer(10)]], const device float* freqs [[buffer(10)]],
constant const size_t& freq_stride [[buffer(11)]], constant const int64_t& freq_stride [[buffer(11)]],
uint3 pos [[thread_position_in_grid]], uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) { uint3 grid [[threads_per_grid]]) {
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]); float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
@@ -186,61 +194,20 @@ template <typename T, bool traditional, bool forward, int N = 4>
scale, scale,
strides, strides,
out_strides, out_strides,
n_batch, offset_stride,
n_head,
pos, pos,
grid); grid);
} }
// clang-format off // clang-format off
#define instantiate_rope_g(name, type, traditional, forward) \ #define instantiate_rope_g(name, type, traditional, forward) \
template [[host_name("rope_" #name)]] [[kernel]] void \ instantiate_kernel("rope_" #name, rope, type, traditional, forward) \
rope<type, traditional, forward>( \ instantiate_kernel("rope_freqs_" #name, rope_freqs, type, traditional, forward)
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const size_t& n_batch, \
constant const float& base [[buffer(10)]], \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]); \
template [[host_name("rope_freqs_" #name)]] \
[[kernel]] void rope_freqs<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const size_t& n_batch, \
const device float* freqs [[buffer(10)]], \
constant const size_t& freq_stride [[buffer(11)]], \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
#define instantiate_rope_s(name, type, traditional, forward) \ #define instantiate_rope_s(name, type, traditional, forward) \
template [[host_name("rope_single_" #name)]] [[kernel]] void \ instantiate_kernel("rope_single_" #name, rope_single, type, traditional, forward) \
rope_single<type, traditional, forward>( \ instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type, traditional, forward)
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t& stride, \
constant const float& base [[buffer(10)]], \
uint2 pos [[thread_position_in_grid]], \
uint2 grid [[threads_per_grid]]); \
template [[host_name("rope_single_freqs_" #name)]] \
[[kernel]] void rope_single_freqs<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t& stride, \
const device float* freqs [[buffer(10)]], \
constant const size_t& freq_stride [[buffer(11)]], \
uint2 pos [[thread_position_in_grid]], \
uint2 grid [[threads_per_grid]]);
#define instantiate_rope(name, type, traditional, forward) \ #define instantiate_rope(name, type, traditional, forward) \
instantiate_rope_s(name, type, traditional, forward) \ instantiate_rope_s(name, type, traditional, forward) \

View File

@@ -18,23 +18,29 @@ void RoPE::eval_gpu(
auto& in = inputs[0]; auto& in = inputs[0];
auto& out = outputs[0]; auto& out = outputs[0];
if (in.ndim() < 3) {
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
}
auto& s = out.primitive().stream(); auto& s = out.primitive().stream();
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
size_t strides[3]; int64_t strides[3];
size_t out_strides[3]; int64_t out_strides[3];
bool donated = false; bool donated = false;
int ndim = in.ndim(); int ndim = in.ndim();
int dispatch_ndim = in.ndim(); int B = in.shape(0);
int T = in.shape(-2);
int D = in.shape(-1);
size_t mat_size = T * D;
int dispatch_ndim = ndim;
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) { while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
dispatch_ndim--; dispatch_ndim--;
} }
size_t mat_size = in.shape(-2) * in.shape(-1);
if (dims_ < in.shape(-1)) { int N = 1;
for (int i = 1; i < (ndim - 2); ++i) {
N *= in.shape(i);
}
if (dims_ < D) {
donated = true; donated = true;
auto ctype = auto ctype =
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General; (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
@@ -71,8 +77,8 @@ void RoPE::eval_gpu(
out_strides[1] = out.strides()[ndim - 2]; out_strides[1] = out.strides()[ndim - 2];
out_strides[2] = out.strides()[ndim - 1]; out_strides[2] = out.strides()[ndim - 1];
// Special case for inference (single time step and contiguous) // Special case for inference (single batch, single time step, and contiguous)
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1)); bool single = in.flags().row_contiguous && B == 1 && T == 1;
bool with_freqs = inputs.size() == 3; bool with_freqs = inputs.size() == 3;
std::ostringstream kname; std::ostringstream kname;
@@ -86,24 +92,29 @@ void RoPE::eval_gpu(
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(donated ? out : in, 0); compute_encoder.set_input_array(donated ? out : in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder.set_input_array(inputs[1], 2); compute_encoder.set_input_array(inputs[1], 2);
compute_encoder.set_bytes(scale_, 3); compute_encoder.set_bytes(scale_, 3);
size_t n_batch = in.size() / mat_size;
MTL::Size group_dims; MTL::Size group_dims;
MTL::Size grid_dims; MTL::Size grid_dims;
if (single) { if (single) {
compute_encoder.set_bytes(out_strides, 1, 4); compute_encoder.set_bytes(out_strides, 1, 4);
uint32_t dim0 = dims_ / 2; uint32_t dim0 = dims_ / 2;
group_dims = get_block_dims(dim0, n_batch, 1); group_dims = get_block_dims(dim0, N, 1);
grid_dims = MTL::Size(dim0, n_batch, 1); grid_dims = MTL::Size(dim0, N, 1);
} else { } else {
compute_encoder.set_bytes(strides, 3, 4); compute_encoder.set_bytes(strides, 3, 4);
compute_encoder.set_bytes(out_strides, 3, 5); compute_encoder.set_bytes(out_strides, 3, 5);
compute_encoder.set_bytes(n_batch, 6); int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
compute_encoder.set_bytes(offset_stride, 6);
compute_encoder.set_bytes(N, 7);
uint32_t dim0 = dims_ / 2; uint32_t dim0 = dims_ / 2;
uint32_t dim1 = in.shape(-2); uint32_t dim1 = T;
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread; uint32_t dim2 = B * ((N + n_per_thread - 1) / n_per_thread);
group_dims = get_block_dims(dim0, dim1, dim2); group_dims = get_block_dims(dim0, dim1, dim2);
grid_dims = MTL::Size(dim0, dim1, dim2); grid_dims = MTL::Size(dim0, dim1, dim2);
} }

View File

@@ -366,10 +366,16 @@ array rope(
msg << "[rope] Input must be a floating type but got " << x.dtype() << "."; msg << "[rope] Input must be a floating type but got " << x.dtype() << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (offset.size() != 1) { if (offset.ndim() > 1) {
std::ostringstream msg; std::ostringstream msg;
msg << "[rope] offset must be a scalar but has shape " << offset.shape() msg << "[rope] offset must have at most one dimension but has shape "
<< "."; << offset.shape() << ".";
throw std::invalid_argument(msg.str());
}
if (offset.size() != 1 && offset.size() != x.shape(0)) {
std::ostringstream msg;
msg << "[rope] offset must be a scalar or vector with " << x.shape(0)
<< " elements but has shape " << offset.shape() << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (!issubdtype(offset.dtype(), integer)) { if (!issubdtype(offset.dtype(), integer)) {
@@ -379,7 +385,7 @@ array rope(
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (offset.dtype().size() != 4) { if (offset.dtype().size() != 4) {
inputs[1] = astype(offset, uint32, s); inputs[1] = astype(offset, int32, s);
} }
if (inputs.size() == 3 && if (inputs.size() == 3 &&
(inputs[2].ndim() != 1 || inputs[2].shape(0) != dims / 2)) { (inputs[2].ndim() != 1 || inputs[2].shape(0) != dims / 2)) {
@@ -391,15 +397,26 @@ array rope(
auto fallback = [dims, traditional, base, scale, forward, s]( auto fallback = [dims, traditional, base, scale, forward, s](
std::vector<array> inputs) { std::vector<array> inputs) {
auto& shape = inputs[0].shape(); auto x = inputs[0];
int ndim = shape.size(); auto shape = x.shape();
auto x = flatten(inputs[0], 0, ndim - 3, s); if (x.ndim() == 3) {
x = expand_dims(x, 1, s);
} else if (x.ndim() > 4) {
x = flatten(x, 1, 1 + (x.ndim() - 4), s);
}
auto B = x.shape(0);
auto N = x.shape(1);
auto T = x.shape(2);
auto t = x.dtype(); auto t = x.dtype();
// Compute sines and cosines // Compute sines and cosines
auto half_dims = dims / 2; auto half_dims = dims / 2;
auto& offset = inputs[1]; auto offset = inputs[1];
if (offset.size() > 1) {
offset = expand_dims(offset, {-1, -2}, s);
}
auto positions = auto positions =
multiply(add(arange(x.shape(1), t, s), offset, s), array(scale, t), s); multiply(add(arange(x.shape(2), t, s), offset, s), array(scale, t), s);
auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() { auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
return exp( return exp(
@@ -412,8 +429,7 @@ array rope(
auto inv_freqs = inputs.size() == 3 ? astype(reciprocal(inputs[2], s), t, s) auto inv_freqs = inputs.size() == 3 ? astype(reciprocal(inputs[2], s), t, s)
: default_inv_freqs(); : default_inv_freqs();
auto theta = auto theta = multiply(expand_dims(positions, -1, s), inv_freqs, s);
multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
auto coss = cos(theta, s); auto coss = cos(theta, s);
auto sins = sin(theta, s); auto sins = sin(theta, s);
@@ -436,32 +452,30 @@ array rope(
}; };
if (traditional) { if (traditional) {
auto x1 = auto x1 = slice(x, {0, 0, 0, 0}, {B, N, T, dims}, {1, 1, 1, 2}, s);
slice(x, {0, 0, 0}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s); auto x2 = slice(x, {0, 0, 0, 1}, {B, N, T, dims}, {1, 1, 1, 2}, s);
auto x2 =
slice(x, {0, 0, 1}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
auto outs = apply_rope(x1, x2, coss, sins); auto outs = apply_rope(x1, x2, coss, sins);
for (auto& o : outs) { for (auto& o : outs) {
o = expand_dims(o, 3, s); o = expand_dims(o, -1, s);
} }
auto out = concatenate(outs, 3, s); auto out = reshape(concatenate(outs, -1, s), {B, N, T, dims}, s);
if (dims < x.shape(-1)) { if (dims < x.shape(-1)) {
out = reshape(out, {x.shape(0), x.shape(1), dims}); out =
out = concatenate({out, slice(x, {0, 0, dims}, x.shape(), s)}, 2, s); concatenate({out, slice(x, {0, 0, 0, dims}, x.shape(), s)}, -1, s);
} }
return std::vector<array>{reshape(out, shape, s)}; return std::vector<array>{reshape(out, shape, s)};
} else { } else {
auto out_s = x.shape(); auto out_s = x.shape();
out_s.back() = half_dims; out_s.back() = half_dims;
auto x1 = slice(x, {0, 0, 0}, out_s, s); auto x1 = slice(x, {0, 0, 0, 0}, out_s, s);
out_s.back() = dims; out_s.back() = dims;
auto x2 = slice(x, {0, 0, half_dims}, out_s, s); auto x2 = slice(x, {0, 0, 0, half_dims}, out_s, s);
auto outs = apply_rope(x1, x2, coss, sins); auto outs = apply_rope(x1, x2, coss, sins);
if (dims < x.shape(-1)) { if (dims < x.shape(-1)) {
outs.push_back(slice(x, {0, 0, dims}, x.shape(), s)); outs.push_back(slice(x, {0, 0, 0, dims}, x.shape(), s));
} }
return std::vector<array>{reshape(concatenate(outs, 2, s), shape, s)}; return std::vector<array>{reshape(concatenate(outs, -1, s), shape, s)};
} }
}; };
auto stream = to_stream(s); auto stream = to_stream(s);

View File

@@ -164,8 +164,13 @@ void init_fast(nb::module_& parent_module) {
R"pbdoc( R"pbdoc(
Apply rotary positional encoding to the input. Apply rotary positional encoding to the input.
The input is expected to be at least 3D with shape ``(B, *, T, D)`` where:
* ``B`` is the batch size.
* ``T`` is the sequence length.
* ``D`` is the feature dimension.
Args: Args:
a (array): Input array. a (array): The input array.
dims (int): The feature dimensions to be rotated. If the input feature dims (int): The feature dimensions to be rotated. If the input feature
is larger than dims then the rest is left unchanged. is larger than dims then the rest is left unchanged.
traditional (bool): If set to ``True`` choose the traditional traditional (bool): If set to ``True`` choose the traditional
@@ -174,7 +179,9 @@ void init_fast(nb::module_& parent_module) {
each dimension in the positional encodings. Exactly one of ``base`` and each dimension in the positional encodings. Exactly one of ``base`` and
``freqs`` must be ``None``. ``freqs`` must be ``None``.
scale (float): The scale used to scale the positions. scale (float): The scale used to scale the positions.
offset (int or array): The position offset to start at. offset (int or array): The position offset to start at. If an
:obj:`array` is given it can be a scalar or vector of ``B``
offsets for each example in the batch.
freqs (array, optional): Optional frequencies to use with RoPE. freqs (array, optional): Optional frequencies to use with RoPE.
If set, the ``base`` parameter must be ``None``. Default: ``None``. If set, the ``base`` parameter must be ``None``. Default: ``None``.

View File

@@ -91,7 +91,7 @@ mx::array to_array_with_accessor(nb::object obj) {
return nb::cast<mx::array>(obj.attr("__mlx_array__")()); return nb::cast<mx::array>(obj.attr("__mlx_array__")());
} else { } else {
std::ostringstream msg; std::ostringstream msg;
msg << "Invalid type " << nb::type_name(obj.type()).c_str() msg << "Invalid type " << nb::type_name(obj.type()).c_str()
<< " received in array initialization."; << " received in array initialization.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }

View File

@@ -8,18 +8,23 @@ import mlx_tests
def rope_orig(x, dims, traditional, base, scale, offset, freqs=None): def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
offset = offset.item() if isinstance(offset, mx.array) else offset N = x.shape[-2]
N = x.shape[-2] + offset
dtype = x.dtype dtype = x.dtype
half_D = dims // 2 half_D = dims // 2
positions = mx.arange(offset, N, dtype=dtype) * scale positions = mx.arange(N, dtype=dtype)
if isinstance(offset, mx.array) and offset.size > 1:
expand = tuple(range(1, x.ndim - 1))
positions = mx.expand_dims(offset, expand) + positions
else:
positions = offset + positions
positions = positions * scale
if freqs is None: if freqs is None:
inv_freqs = mx.exp( inv_freqs = mx.exp(
-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D) -mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
) )
else: else:
inv_freqs = (1 / freqs).astype(x.dtype) inv_freqs = (1 / freqs).astype(x.dtype)
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1)) theta = mx.expand_dims(positions, -1) * inv_freqs
costheta, sintheta = mx.cos(theta), mx.sin(theta) costheta, sintheta = mx.cos(theta), mx.sin(theta)
if traditional: if traditional:
x1 = x[..., :dims:2] x1 = x[..., :dims:2]
@@ -214,6 +219,7 @@ class TestFast(mlx_tests.MLXTestCase):
) )
self.assertEqual(dtype, rx.dtype) self.assertEqual(dtype, rx.dtype)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
return
# Test single vector # Test single vector
x = mx.random.uniform(shape=(1, 1, dims)) x = mx.random.uniform(shape=(1, 1, dims))
@@ -277,6 +283,55 @@ class TestFast(mlx_tests.MLXTestCase):
g2 = mx.grad(f2)(x, y) g2 = mx.grad(f2)(x, y)
self.assertLess(mx.abs(g1 - g2).max(), 1e-5) self.assertLess(mx.abs(g1 - g2).max(), 1e-5)
def test_rope_batch(self):
T = 4
base = 10000.0
scale = 1.0
traditional = True
batch_sizes = [3, 8, 11]
num_heads = [1, 3, 5]
dims = 32
x = mx.random.uniform(shape=(8, 4, T, dims))
offset = mx.array([1, 2, 3])
with self.assertRaises(ValueError):
mx.fast.rope(
x,
dims,
traditional=traditional,
base=base,
scale=scale,
offset=offset,
)
for batch_size in batch_sizes:
for n_head in num_heads:
x = mx.random.uniform(shape=(batch_size, n_head, T, dims))
offset = mx.arange(batch_size)
rx = rope_orig(x, dims, traditional, base, scale, offset)
rx_fast = mx.fast.rope(
x,
dims,
traditional=traditional,
base=base,
scale=scale,
offset=offset,
)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
x = mx.random.normal(shape=(2, 6, 8, 64)).transpose(0, 2, 1, 3)
dims = 64
offset = 0
rx_fast = mx.fast.rope(
x, dims, traditional=traditional, scale=scale, base=base, offset=offset
)
rx_fast_single = mx.fast.rope(
x[0:1], dims, traditional=traditional, scale=scale, base=base, offset=offset
)
rx = rope_orig(x, dims, traditional, base, scale, offset)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
def test_rms_norm(self): def test_rms_norm(self):
# Per dtype absolute tolerance # Per dtype absolute tolerance
tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2} tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}