diff --git a/mlx/backend/common/utils.cpp b/mlx/backend/common/utils.cpp
index 08df53a8e..457ecb7f7 100644
--- a/mlx/backend/common/utils.cpp
+++ b/mlx/backend/common/utils.cpp
@@ -209,4 +209,14 @@ Dims get_2d_grid_dims_common(
       static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
 }
 
+std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
+  auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2);
+  auto gx = (dim0 + bx - 1) / bx;
+  auto gy = (dim1 + by - 1) / by;
+  auto gz = (dim2 + bz - 1) / bz;
+
+  return std::make_pair(
+      std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h
index 40bc3efe4..114878846 100644
--- a/mlx/backend/common/utils.h
+++ b/mlx/backend/common/utils.h
@@ -95,6 +95,9 @@ Dims get_2d_grid_dims_common(
     const Shape& shape,
     const Strides& strides,
     size_t divisor);
+// Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
+std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);
+
 struct ContiguousIterator {
   inline void step() {
     int dims = shape_.size();
diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 7cc74353a..d96bb8812 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -32,6 +32,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
diff --git a/mlx/backend/cuda/kernel_utils.cu b/mlx/backend/cuda/kernel_utils.cu
index 575af7cf6..7b87aa5b0 100644
--- a/mlx/backend/cuda/kernel_utils.cu
+++ b/mlx/backend/cuda/kernel_utils.cu
@@ -23,4 +23,11 @@ dim3 get_2d_grid_dims(
   return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
 }
 
+std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2) {
+  auto [grid, block] = get_grid_and_block_common(dim0, dim1, dim2);
+  auto [gx, gy, gz] = grid;
+  auto [bx, by, bz] = block;
+  return std::make_pair(dim3(gx, gy, gz), dim3(bx, by, bz));
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh
index 7e957bbbd..84392a1ec 100644
--- a/mlx/backend/cuda/kernel_utils.cuh
+++ b/mlx/backend/cuda/kernel_utils.cuh
@@ -121,6 +121,7 @@ dim3 get_2d_grid_dims(
     const Shape& shape,
     const Strides& strides,
     size_t divisor);
+std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
 
 // Return a block size that achieves maximum potential occupancy for kernel.
 template <typename T>
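[Note, not part of the patch] get_grid_and_block wraps get_block_dims_common and
rounds the grid up with a ceiling division, so grid * block covers
(dim0, dim1, dim2). A minimal usage sketch under that assumption; fill_ones and
launch_fill_ones are hypothetical names, only get_grid_and_block comes from the
declarations above:

// Hypothetical example kernel and launcher; only get_grid_and_block is from
// the patch (declared in mlx/backend/cuda/kernel_utils.cuh).
#include "mlx/backend/cuda/kernel_utils.cuh"

__global__ void fill_ones(float* out, int nx, int ny) {
  int x = blockIdx.x * blockDim.x + threadIdx.x;
  int y = blockIdx.y * blockDim.y + threadIdx.y;
  // Guard against the rounded-up grid over-covering the domain.
  if (x < nx && y < ny) {
    out[y * nx + x] = 1.0f;
  }
}

void launch_fill_ones(float* out, int nx, int ny, cudaStream_t stream) {
  // Returns a {grid, block} pair of dim3 such that grid * block >= (nx, ny, 1).
  auto [grid, block] = mlx::core::get_grid_and_block(nx, ny, 1);
  fill_ones<<<grid, block, 0, stream>>>(out, nx, ny);
}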
diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu
index 0c4d3a8aa..c2362bea2 100644
--- a/mlx/backend/cuda/primitives.cu
+++ b/mlx/backend/cuda/primitives.cu
@@ -94,7 +94,6 @@ NO_GPU_MULTI(Eig)
 NO_GPU_MULTI(Eigh)
 
 namespace fast {
-NO_GPU_USE_FALLBACK(RoPE)
 NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)
diff --git a/mlx/backend/cuda/random.cu b/mlx/backend/cuda/random.cu
index d2b1b7dd5..0cb550d56 100644
--- a/mlx/backend/cuda/random.cu
+++ b/mlx/backend/cuda/random.cu
@@ -4,6 +4,7 @@
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/primitives.h"
 
+#include <cooperative_groups.h>
 #include <nvtx3/nvtx3.hpp>
 #include <cassert>
 
@@ -12,6 +13,8 @@ namespace mlx::core {
 
 namespace cu {
 
+namespace cg = cooperative_groups;
+
 __constant__ constexpr uint32_t rotations[2][4] = {
     {13, 15, 26, 6},
     {17, 29, 16, 24}};
@@ -47,27 +50,28 @@ __global__ void rbitsc(
     dim3 grid_dims,
     bool odd,
     uint32_t bytes_per_key) {
-  uint2 index{
-      blockIdx.x * blockDim.x + threadIdx.x,
-      blockIdx.y * blockDim.y + threadIdx.y};
-  if (index.x >= grid_dims.x || index.y >= grid_dims.y) {
+  auto grid = cg::this_grid();
+  uint thread_index = grid.thread_rank();
+  uint index_x = thread_index % grid_dims.x;
+  uint index_y = thread_index / grid_dims.x;
+  if (index_x >= grid_dims.x || index_y >= grid_dims.y) {
     return;
   }
 
-  auto kidx = 2 * index.x;
+  auto kidx = 2 * index_x;
   auto key = uint2{keys[kidx], keys[kidx + 1]};
   auto half_size = grid_dims.y - odd;
-  out += index.x * bytes_per_key;
-  bool drop_last = odd && (index.y == half_size);
+  out += index_x * bytes_per_key;
+  bool drop_last = odd && (index_y == half_size);
   auto bits = threefry2x32_hash(
-      key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y});
-  size_t idx = size_t(index.y) << 2;
+      key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y});
+  size_t idx = size_t(index_y) << 2;
   for (int i = 0; i < 4; ++i) {
     out[idx + i] = bits.bytes[0][i];
   }
   if (!drop_last) {
-    idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2;
-    if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
+    idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2;
+    if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) {
       int edge_bytes = (bytes_per_key % 4);
       for (int i = 0; i < edge_bytes; ++i) {
         out[idx + i] = bits.bytes[1][i];
@@ -89,30 +93,31 @@ __global__ void rbits(
     int32_t ndim,
     const __grid_constant__ Shape key_shape,
     const __grid_constant__ Strides key_strides) {
-  uint2 index{
-      blockIdx.x * blockDim.x + threadIdx.x,
-      blockIdx.y * blockDim.y + threadIdx.y};
-  if (index.x >= grid_dims.x || index.y >= grid_dims.y) {
+  auto grid = cg::this_grid();
+  uint thread_index = grid.thread_rank();
+  uint index_x = thread_index % grid_dims.x;
+  uint index_y = thread_index / grid_dims.x;
+  if (index_x >= grid_dims.x || index_y >= grid_dims.y) {
     return;
   }
 
-  auto kidx = 2 * index.x;
+  auto kidx = 2 * index_x;
   auto k1_elem = elem_to_loc(kidx, key_shape.data(), key_strides.data(), ndim);
   auto k2_elem =
       elem_to_loc(kidx + 1, key_shape.data(), key_strides.data(), ndim);
   auto key = uint2{keys[k1_elem], keys[k2_elem]};
   auto half_size = grid_dims.y - odd;
-  out += size_t(index.x) * bytes_per_key;
-  bool drop_last = odd && (index.y == half_size);
+  out += size_t(index_x) * bytes_per_key;
+  bool drop_last = odd && (index_y == half_size);
   auto bits = threefry2x32_hash(
-      key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y});
-  size_t idx = size_t(index.y) << 2;
+      key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y});
+  size_t idx = size_t(index_y) << 2;
   for (int i = 0; i < 4; ++i) {
     out[idx + i] = bits.bytes[0][i];
   }
   if (!drop_last) {
-    idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2;
-    if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
+    idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2;
+    if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) {
       int edge_bytes = (bytes_per_key % 4);
       for (int i = 0; i < edge_bytes; ++i) {
         out[idx + i] = bits.bytes[1][i];
@@ -153,19 +158,22 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
     dim3 grid_dims{num_keys, half_size + odd};
-    dim3 block_dims = get_block_dims(grid_dims.x, grid_dims.y, 1);
-    dim3 num_blocks{
-        cuda::ceil_div(grid_dims.x, block_dims.x),
-        cuda::ceil_div(grid_dims.y, block_dims.y)};
+    int64_t total = grid_dims.x * grid_dims.y;
+    int32_t threads_y = 1;
+    while ((total / threads_y) >= (1U << 31)) {
+      threads_y *= 2;
+    }
+    int32_t threads_x = cuda::ceil_div(total, threads_y);
+    auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
     if (keys.flags().row_contiguous) {
-      cu::rbitsc<<<num_blocks, block_dims, 0, stream>>>(
+      cu::rbitsc<<<grid, block, 0, stream>>>(
          keys.data<uint32_t>(),
          out.data<uint8_t>(),
          grid_dims,
          odd,
          bytes_per_key);
     } else {
-      cu::rbits<<<num_blocks, block_dims, 0, stream>>>(
+      cu::rbits<<<grid, block, 0, stream>>>(
          keys.data<uint32_t>(),
          out.data<uint8_t>(),
          grid_dims,
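[Note, not part of the patch] rbitsc and rbits now derive their 2-D position from
the grid-wide linear thread rank rather than per-axis block/thread indices, which
lets the host pick a (threads_x, threads_y) split whose product covers
grid_dims.x * grid_dims.y while keeping the flattened x extent below 2^31. A small
sketch of that remapping, assuming a row-major layout over grid_dims;
index_from_rank is a hypothetical helper, not part of the patch:

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// Recover a 2-D index from the flat grid rank. Ranks at or beyond
// grid_dims.x * grid_dims.y fail the kernels' bounds check and return early.
__device__ uint2 index_from_rank(dim3 grid_dims) {
  // The patch stores the rank in a uint, mirrored here with an explicit cast.
  unsigned int thread_index =
      static_cast<unsigned int>(cg::this_grid().thread_rank());
  return make_uint2(thread_index % grid_dims.x, thread_index / grid_dims.x);
}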
diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu
new file mode 100644
index 000000000..1d8307811
--- /dev/null
+++ b/mlx/backend/cuda/rope.cu
@@ -0,0 +1,385 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/dtype_utils.h"
+#include "mlx/fast_primitives.h"
+
+#include <nvtx3/nvtx3.hpp>
+
+namespace mlx::core {
+
+namespace cu {
+
+template <typename T, bool traditional, bool forward>
+__device__ void rope_single_impl(
+    const T* in,
+    T* out,
+    int32_t offset,
+    float inv_freq,
+    float scale,
+    int64_t stride,
+    uint2 pos,
+    uint2 dims) {
+  float L = scale * static_cast<float>(offset);
+
+  // Compute costheta, sintheta
+  float theta = L * inv_freq;
+  float costheta = cos(theta);
+  float sintheta = sin(theta);
+
+  // Compute the input and output indices
+  uint index_1, index_2;
+  if (traditional) {
+    index_1 = 2 * pos.x + pos.y * stride;
+    index_2 = index_1 + 1;
+  } else {
+    index_1 = pos.x + pos.y * stride;
+    index_2 = index_1 + dims.x;
+  }
+
+  // Read and write the output
+  float x1 = static_cast<float>(in[index_1]);
+  float x2 = static_cast<float>(in[index_2]);
+  float rx1;
+  float rx2;
+  if (forward) {
+    rx1 = x1 * costheta - x2 * sintheta;
+    rx2 = x1 * sintheta + x2 * costheta;
+  } else {
+    rx1 = x2 * sintheta + x1 * costheta;
+    rx2 = x2 * costheta - x1 * sintheta;
+  }
+  out[index_1] = static_cast<T>(rx1);
+  out[index_2] = static_cast<T>(rx2);
+}
+
+template <typename T, bool traditional, bool forward>
+__global__ void rope_single(
+    const T* in,
+    T* out,
+    const int32_t* offset,
+    float scale,
+    float base,
+    int64_t stride,
+    uint2 dims) {
+  uint2 pos = make_uint2(
+      blockIdx.x * blockDim.x + threadIdx.x,
+      blockIdx.y * blockDim.y + threadIdx.y);
+  if (pos.x >= dims.x || pos.y >= dims.y) {
+    return;
+  }
+
+  float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
+  float inv_freq = exp2(-d * base);
+  rope_single_impl<T, traditional, forward>(
+      in, out, *offset, inv_freq, scale, stride, pos, dims);
+}
+
+template <typename T, bool traditional, bool forward>
+__global__ void rope_single_freqs(
+    const T* in,
+    T* out,
+    const int32_t* offset,
+    const float* freqs,
+    float scale,
+    int64_t stride,
+    uint2 dims,
+    int64_t freq_stride) {
+  uint2 pos = make_uint2(
+      blockIdx.x * blockDim.x + threadIdx.x,
+      blockIdx.y * blockDim.y + threadIdx.y);
+  if (pos.x >= dims.x || pos.y >= dims.y) {
+    return;
+  }
+
+  float inv_freq = 1.0 / freqs[freq_stride * pos.x];
+  rope_single_impl<T, traditional, forward>(
+      in, out, *offset, inv_freq, scale, stride, pos, dims);
+}
+
+template <typename T, bool traditional, bool forward, int N = 4>
+__device__ void rope_impl(
+    const T* in,
+    T* out,
+    int offset,
+    float inv_freq,
+    float scale,
+    const cuda::std::array<int64_t, 3> strides,
+    const cuda::std::array<int64_t, 3> out_strides,
+    int64_t n_batch,
+    uint3 pos,
+    uint3 dims) {
+  float L = scale * static_cast<float>(pos.y + offset);
+
+  // Compute costheta, sintheta
+  float theta = L * inv_freq;
+  float costheta = cos(theta);
+  float sintheta = sin(theta);
+
+  // Compute the input and output indices
+  size_t in_index_1, in_index_2;
+  size_t out_index_1, out_index_2;
+  if (traditional) {
+    out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
+        N * pos.z * out_strides[0];
+    out_index_2 = out_index_1 + 1;
+    in_index_1 =
+        2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
+    in_index_2 = in_index_1 + strides[2];
+  } else {
+    out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
+        N * pos.z * out_strides[0];
+    out_index_2 = out_index_1 + dims.x * out_strides[2];
+    in_index_1 =
+        pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
+    in_index_2 = in_index_1 + dims.x * strides[2];
+  }
+  for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
+    // Read and write the output
+    float x1 = static_cast<float>(in[in_index_1]);
+    float x2 = static_cast<float>(in[in_index_2]);
+    float rx1;
+    float rx2;
+    if (forward) {
+      rx1 = x1 * costheta - x2 * sintheta;
+      rx2 = x1 * sintheta + x2 * costheta;
+    } else {
+      rx1 = x2 * sintheta + x1 * costheta;
+      rx2 = x2 * costheta - x1 * sintheta;
+    }
+    out[out_index_1] = static_cast<T>(rx1);
+    out[out_index_2] = static_cast<T>(rx2);
+    in_index_1 += strides[0];
+    in_index_2 += strides[0];
+    out_index_1 += out_strides[0];
+    out_index_2 += out_strides[0];
+  }
+}
+
+template <typename T, bool traditional, bool forward>
+__global__ void rope(
+    const T* in,
+    T* out,
+    const int32_t* offset,
+    float scale,
+    float base,
+    const __grid_constant__ cuda::std::array<int64_t, 3> strides,
+    const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
+    int64_t n_batch,
+    uint3 dims) {
+  uint3 pos = make_uint3(
+      blockIdx.x * blockDim.x + threadIdx.x,
+      blockIdx.y * blockDim.y + threadIdx.y,
+      blockIdx.z * blockDim.z + threadIdx.z);
+  if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {
+    return;
+  }
+
+  float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
+  float inv_freq = exp2(-d * base);
+  rope_impl<T, traditional, forward>(
+      in,
+      out,
+      *offset,
+      inv_freq,
+      scale,
+      strides,
+      out_strides,
+      n_batch,
+      pos,
+      dims);
+}
+
+template <typename T, bool traditional, bool forward>
+__global__ void rope_freqs(
+    const T* in,
+    T* out,
+    const int32_t* offset,
+    const float* freqs,
+    float scale,
+    float base,
+    const __grid_constant__ cuda::std::array<int64_t, 3> strides,
+    const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
+    int64_t n_batch,
+    uint3 dims,
+    int64_t freq_stride) {
+  uint3 pos = make_uint3(
+      blockIdx.x * blockDim.x + threadIdx.x,
+      blockIdx.y * blockDim.y + threadIdx.y,
+      blockIdx.z * blockDim.z + threadIdx.z);
+  if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {
+    return;
+  }
+
+  float inv_freq = 1.0 / freqs[freq_stride * pos.x];
+  rope_impl<T, traditional, forward>(
+      in,
+      out,
+      *offset,
+      inv_freq,
+      scale,
+      strides,
+      out_strides,
+      n_batch,
+      pos,
+      dims);
+}
+
+} // namespace cu
+
+namespace fast {
+
+bool RoPE::use_fallback(Stream s) {
+  return s.device == Device::cpu;
+}
+
+void RoPE::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  nvtx3::scoped_range r("RoPE::eval_gpu");
+
+  auto& s = stream();
+  auto& in = inputs[0];
+  auto& offset = inputs[1];
+  auto& out = outputs[0];
+
+  if (in.ndim() < 3) {
+    throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
+  }
+
+  cuda::std::array<int64_t, 3> strides;
+  cuda::std::array<int64_t, 3> out_strides;
+  bool donated = false;
+  int ndim = in.ndim();
+  int dispatch_ndim = in.ndim();
+  while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
+    dispatch_ndim--;
+  }
+  size_t mat_size = in.shape(-2) * in.shape(-1);
+
+  // We apply rope to less than the whole vector so copy to output and then
+  // apply in-place.
+  if (dims_ < in.shape(-1)) {
+    donated = true;
+    auto ctype =
+        (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
+    copy_gpu(in, out, ctype, s);
+    strides[0] = mat_size;
+    strides[1] = out.strides()[ndim - 2];
+    strides[2] = out.strides()[ndim - 1];
+  }
+
+  // Either copy or apply in-place
+  else if (in.flags().row_contiguous) {
+    if (in.is_donatable()) {
+      donated = true;
+      out.copy_shared_buffer(in);
+    } else {
+      out.set_data(allocator::malloc(out.nbytes()));
+    }
+    strides[0] = mat_size;
+    strides[1] = in.strides()[ndim - 2];
+    strides[2] = in.strides()[ndim - 1];
+  } else if (dispatch_ndim == 3) {
+    // Handle non-contiguous 3D inputs
+    out.set_data(allocator::malloc(out.nbytes()));
+    strides[0] = in.strides()[ndim - 3];
+    strides[1] = in.strides()[ndim - 2];
+    strides[2] = in.strides()[ndim - 1];
+  } else {
+    // Copy non-contiguous > 3D inputs into the output and treat
+    // input as donated
+    donated = true;
+    copy_gpu(in, out, CopyType::General, s);
+    strides[0] = mat_size;
+    strides[1] = out.strides()[ndim - 2];
+    strides[2] = out.strides()[ndim - 1];
+  }
+  out_strides[0] = mat_size;
+  out_strides[1] = out.strides()[ndim - 2];
+  out_strides[2] = out.strides()[ndim - 1];
+
+  // Some flags to help us dispatch below
+  bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
+  bool with_freqs = inputs.size() == 3;
+
+  auto& encoder = cu::get_command_encoder(s);
+  encoder.set_input_array(donated ? out : in);
+  encoder.set_input_array(offset);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_FLOAT_TYPES_CHECKED(in.dtype(), "rope", CTYPE, {
+      using DataType = cuda_type_t<CTYPE>;
+      MLX_SWITCH_BOOL(traditional_, TRADITIONAL, {
+        MLX_SWITCH_BOOL(forward_, FORWARD, {
+          if (single && !with_freqs) {
+            auto kernel = cu::rope_single<DataType, TRADITIONAL, FORWARD>;
+            uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
+            auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
+            kernel<<<grid, block, 0, stream>>>(
+                (donated ? out : in).data<DataType>(),
+                out.data<DataType>(),
+                offset.data<int32_t>(),
+                scale_,
+                std::log2(base_),
+                mat_size,
+                dims);
+          } else if (single) {
+            auto kernel =
+                cu::rope_single_freqs<DataType, TRADITIONAL, FORWARD>;
+            uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
+            auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
+            kernel<<<grid, block, 0, stream>>>(
+                (donated ? out : in).data<DataType>(),
+                out.data<DataType>(),
+                offset.data<int32_t>(),
+                inputs[2].data<float>(),
+                scale_,
+                mat_size,
+                dims,
+                inputs[2].strides(0));
+          } else if (with_freqs) {
+            auto kernel = cu::rope_freqs<DataType, TRADITIONAL, FORWARD>;
+            uint3 dims =
+                make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
+            dims.z = (dims.z + 3) / 4;
+            auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
+            kernel<<<grid, block, 0, stream>>>(
+                (donated ? out : in).data<DataType>(),
+                out.data<DataType>(),
+                offset.data<int32_t>(),
+                inputs[2].data<float>(),
+                scale_,
+                std::log2(base_),
+                strides,
+                out_strides,
+                in.size() / mat_size,
+                dims,
+                inputs[2].strides(0));
+          } else {
+            auto kernel = cu::rope<DataType, TRADITIONAL, FORWARD>;
+            uint3 dims =
+                make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
+            dims.z = (dims.z + 3) / 4;
+            auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
+            kernel<<<grid, block, 0, stream>>>(
+                (donated ? out : in).data<DataType>(),
+                out.data<DataType>(),
+                offset.data<int32_t>(),
+                scale_,
+                std::log2(base_),
+                strides,
+                out_strides,
+                in.size() / mat_size,
+                dims);
+          }
+        });
+      });
+    });
+  });
+}
+
+} // namespace fast
+
+} // namespace mlx::core
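[Note, not part of the patch] Every kernel above applies the same per-pair
rotation, and base is passed as log2(base) so the device code can compute
base^(-2i/dims) with exp2. A scalar reference sketch of the forward,
non-traditional path (element i pairs with element i + dims/2); rope_reference
is a hypothetical host-side check, not part of the patch:

#include <cmath>
#include <vector>

// Rotate the first `dims` features of one row at sequence position `pos`
// (plus `offset`), matching the forward, non-traditional rotation in
// rope_impl (rope_single_impl is the pos == 0 case).
void rope_reference(std::vector<float>& x, int dims, int pos, int offset,
                    float base, float scale) {
  for (int i = 0; i < dims / 2; ++i) {
    // base^(-2i/dims), written with exp2/log2 as in the kernels.
    float inv_freq = std::exp2(-(2.0f * i / dims) * std::log2(base));
    float theta = scale * static_cast<float>(pos + offset) * inv_freq;
    float c = std::cos(theta);
    float s = std::sin(theta);
    float x1 = x[i];
    float x2 = x[i + dims / 2];
    x[i] = x1 * c - x2 * s;
    x[i + dims / 2] = x1 * s + x2 * c;
  }
}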