diff --git a/mlx/backend/metal/kernels/rope.metal b/mlx/backend/metal/kernels/rope.metal index 1fda49f74..fce444a1a 100644 --- a/mlx/backend/metal/kernels/rope.metal +++ b/mlx/backend/metal/kernels/rope.metal @@ -3,7 +3,12 @@ #include #include "mlx/backend/metal/kernels/utils.h" -template + +constant bool forward [[function_constant(1)]]; +constant bool traditional [[function_constant(2)]]; +constant bool hs_transpose [[function_constant(3)]]; + +template void rope_single_impl( const device T* in, device T* out, @@ -46,7 +51,7 @@ void rope_single_impl( out[index_2] = static_cast(rx2); } -template +template [[kernel]] void rope_single( const device T* in [[buffer(0)]], device T* out [[buffer(1)]], @@ -58,11 +63,10 @@ template uint2 grid [[threads_per_grid]]) { float d = static_cast(pos.x) / static_cast(grid.x); float inv_freq = metal::exp2(-d * base); - rope_single_impl( - in, out, offset, inv_freq, scale, stride, pos, grid); + rope_single_impl(in, out, offset, inv_freq, scale, stride, pos, grid); } -template +template [[kernel]] void rope_single_freqs( const device T* in [[buffer(0)]], device T* out [[buffer(1)]], @@ -74,11 +78,10 @@ template uint2 pos [[thread_position_in_grid]], uint2 grid [[threads_per_grid]]) { float inv_freq = 1.0 / (freqs[freq_stride * pos.x]); - rope_single_impl( - in, out, offset, inv_freq, scale, stride, pos, grid); + rope_single_impl(in, out, offset, inv_freq, scale, stride, pos, grid); } -template +template void rope_impl( const device T* in, device T* out, @@ -102,23 +105,29 @@ void rope_impl( float theta = L * inv_freq; float costheta = metal::fast::cos(theta); float sintheta = metal::fast::sin(theta); - // Compute the input and output indices - size_t in_index_1, in_index_2; - size_t out_index_1, out_index_2; - if (traditional) { - out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + - mat_idx * out_strides[0]; - out_index_2 = out_index_1 + 1; + IdxT in_index_1; + if (hs_transpose) { + IdxT batch_stride = grid.y * IdxT(strides[1]); in_index_1 = - 2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0]; - in_index_2 = in_index_1 + strides[2]; + batch_idx * batch_stride + pos.y * strides[1] + head_idx * strides[0]; } else { - out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + - mat_idx * out_strides[0]; - out_index_2 = out_index_1 + grid.x * out_strides[2]; - in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0]; - in_index_2 = in_index_1 + grid.x * strides[2]; + in_index_1 = pos.y * IdxT(strides[1]) + mat_idx * IdxT(strides[0]); + } + IdxT in_index_2; + IdxT out_index_1 = + pos.y * IdxT(out_strides[1]) + mat_idx * IdxT(out_strides[0]); + IdxT out_index_2; + if (traditional) { + out_index_1 += 2 * pos.x * IdxT(out_strides[2]); + out_index_2 = out_index_1 + 1; + in_index_1 += 2 * pos.x * IdxT(strides[2]); + in_index_2 = in_index_1 + IdxT(strides[2]); + } else { + out_index_1 += pos.x * IdxT(out_strides[2]); + out_index_2 = out_index_1 + grid.x * IdxT(out_strides[2]); + in_index_1 += pos.x * IdxT(strides[2]); + in_index_2 = in_index_1 + grid.x * IdxT(strides[2]); } for (int i = 0; i < N && head_idx + i < n_head; ++i) { // Read and write the output @@ -135,14 +144,14 @@ void rope_impl( } out[out_index_1] = static_cast(rx1); out[out_index_2] = static_cast(rx2); - in_index_1 += strides[0]; - in_index_2 += strides[0]; - out_index_1 += out_strides[0]; - out_index_2 += out_strides[0]; + in_index_1 += IdxT(strides[0]); + in_index_2 += IdxT(strides[0]); + out_index_1 += IdxT(out_strides[0]); + out_index_2 += IdxT(out_strides[0]); } } -template +template [[kernel]] void rope( const device T* in [[buffer(0)]], device T* out [[buffer(1)]], @@ -157,7 +166,7 @@ template uint3 grid [[threads_per_grid]]) { float d = static_cast(pos.x) / static_cast(grid.x); float inv_freq = metal::exp2(-d * base); - rope_impl( + rope_impl( in, out, offset, @@ -171,7 +180,7 @@ template grid); } -template +template [[kernel]] void rope_freqs( const device T* in [[buffer(0)]], device T* out [[buffer(1)]], @@ -186,7 +195,7 @@ template uint3 pos [[thread_position_in_grid]], uint3 grid [[threads_per_grid]]) { float inv_freq = 1.0 / (freqs[freq_stride * pos.x]); - rope_impl( + rope_impl( in, out, offset, @@ -201,27 +210,20 @@ template } // clang-format off -#define instantiate_rope_g(name, type, traditional, forward) \ - instantiate_kernel("rope_" #name, rope, type, traditional, forward) \ - instantiate_kernel("rope_freqs_" #name, rope_freqs, type, traditional, forward) +#define instantiate_rope_g(name, type) \ + instantiate_kernel("rope_" #name, rope, type, int32_t) \ + instantiate_kernel("rope_freqs_" #name, rope_freqs, type, int32_t) \ + instantiate_kernel("rope_large_" #name, rope, type, int64_t) \ + instantiate_kernel("rope_freqs_large_" #name, rope_freqs, type, int64_t) -#define instantiate_rope_s(name, type, traditional, forward) \ - instantiate_kernel("rope_single_" #name, rope_single, type, traditional, forward) \ - instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type, traditional, forward) +#define instantiate_rope_s(name, type) \ + instantiate_kernel("rope_single_" #name, rope_single, type) \ + instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type) -#define instantiate_rope(name, type, traditional, forward) \ - instantiate_rope_s(name, type, traditional, forward) \ - instantiate_rope_g(name, type, traditional, forward) +#define instantiate_rope(name, type) \ + instantiate_rope_s(name, type) \ + instantiate_rope_g(name, type) -instantiate_rope(traditional_float16, half, true, true) -instantiate_rope(traditional_bfloat16, bfloat16_t, true, true) -instantiate_rope(traditional_float32, float, true, true) -instantiate_rope(float16, half, false, true) -instantiate_rope(bfloat16, bfloat16_t, false, true) -instantiate_rope(float32, float, false, true) -instantiate_rope(vjp_traditional_float16, half, true, false) -instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false) -instantiate_rope(vjp_traditional_float32, float, true, false) -instantiate_rope(vjp_float16, half, false, false) -instantiate_rope(vjp_bfloat16, bfloat16_t, false, false) -instantiate_rope(vjp_float32, float, false, false) // clang-format on +instantiate_rope(float16, half) +instantiate_rope(bfloat16, bfloat16_t) +instantiate_rope(float32, float) // clang-format on diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp index 38d822ce0..ca0a66221 100644 --- a/mlx/backend/metal/rope.cpp +++ b/mlx/backend/metal/rope.cpp @@ -29,6 +29,7 @@ void RoPE::eval_gpu( int T = in.shape(-2); int D = in.shape(-1); size_t mat_size = T * D; + bool large = in.data_size() > INT32_MAX || in.size() > INT32_MAX; int dispatch_ndim = ndim; while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) { @@ -40,6 +41,8 @@ void RoPE::eval_gpu( N *= in.shape(i); } + bool head_seq_transpose = false; + if (dims_ < D) { donated = true; auto ctype = @@ -64,6 +67,17 @@ void RoPE::eval_gpu( strides[0] = in.strides()[ndim - 3]; strides[1] = in.strides()[ndim - 2]; strides[2] = in.strides()[ndim - 1]; + } else if ( + ndim == 4 && + // batch dim is regularly strided + in.strides()[0] == T * N * D && + // sequence and head dimensions are transposed + in.strides()[1] == D && in.strides()[2] == N * D) { + head_seq_transpose = true; + out.set_data(allocator::malloc(out.nbytes())); + strides[0] = in.strides()[1]; + strides[1] = in.strides()[2]; + strides[2] = in.strides()[3]; } else { // Copy non-contiguous > 3D inputs into the output and treat // input as donated @@ -77,15 +91,33 @@ void RoPE::eval_gpu( out_strides[1] = out.strides()[ndim - 2]; out_strides[2] = out.strides()[ndim - 1]; - // Special case for inference (single batch, single time step, and contiguous) - bool single = in.flags().row_contiguous && B == 1 && T == 1; + // Special case for inference (single time step, contiguous, one offset) + auto& offset = inputs[1]; + bool single = in.flags().row_contiguous && T == 1 && offset.size() == 1; bool with_freqs = inputs.size() == 3; - std::ostringstream kname; - kname << "rope_" << (single ? "single_" : "") - << ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_") - << (traditional_ ? "traditional_" : "") << type_to_name(in); - auto kernel = d.get_kernel(kname.str()); + std::string kname; + concatenate( + kname, + "rope_", + single ? "single_" : "", + (with_freqs) ? "freqs_" : "", + large ? "large_" : "", + type_to_name(in)); + std::string hash_name; + concatenate( + hash_name, + kname, + "_", + forward_ ? "" : "vjp_", + traditional_ ? "traditional_" : "", + head_seq_transpose ? "transpose" : ""); + metal::MTLFCList func_consts = { + {&forward_, MTL::DataType::DataTypeBool, 1}, + {&traditional_, MTL::DataType::DataTypeBool, 2}, + {&head_seq_transpose, MTL::DataType::DataTypeBool, 3}}; + + auto kernel = d.get_kernel(kname, hash_name, func_consts); auto& compute_encoder = d.get_command_encoder(s.index); float base = std::log2(base_); @@ -93,7 +125,7 @@ void RoPE::eval_gpu( compute_encoder.set_input_array(donated ? out : in, 0); compute_encoder.set_output_array(out, 1); - compute_encoder.set_input_array(inputs[1], 2); + compute_encoder.set_input_array(offset, 2); compute_encoder.set_bytes(scale_, 3); MTL::Size group_dims; @@ -107,8 +139,8 @@ void RoPE::eval_gpu( compute_encoder.set_bytes(strides, 3, 4); compute_encoder.set_bytes(out_strides, 3, 5); int64_t offset_stride = 0; - if (inputs[1].ndim() > 0) { - offset_stride = inputs[1].strides()[0]; + if (offset.ndim() > 0) { + offset_stride = offset.strides()[0]; } compute_encoder.set_bytes(offset_stride, 6); compute_encoder.set_bytes(N, 7);