diff --git a/mlx/backend/cuda/random.cu b/mlx/backend/cuda/random.cu index d2b1b7dd5..0cb550d56 100644 --- a/mlx/backend/cuda/random.cu +++ b/mlx/backend/cuda/random.cu @@ -4,6 +4,7 @@ #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/primitives.h" +#include #include #include @@ -12,6 +13,8 @@ namespace mlx::core { namespace cu { +namespace cg = cooperative_groups; + __constant__ constexpr uint32_t rotations[2][4] = { {13, 15, 26, 6}, {17, 29, 16, 24}}; @@ -47,27 +50,28 @@ __global__ void rbitsc( dim3 grid_dims, bool odd, uint32_t bytes_per_key) { - uint2 index{ - blockIdx.x * blockDim.x + threadIdx.x, - blockIdx.y * blockDim.y + threadIdx.y}; - if (index.x >= grid_dims.x || index.y >= grid_dims.y) { + auto grid = cg::this_grid(); + uint thread_index = grid.thread_rank(); + uint index_x = thread_index % grid_dims.x; + uint index_y = thread_index / grid_dims.x; + if (index_x >= grid_dims.x || index_y >= grid_dims.y) { return; } - auto kidx = 2 * index.x; + auto kidx = 2 * index_x; auto key = uint2{keys[kidx], keys[kidx + 1]}; auto half_size = grid_dims.y - odd; - out += index.x * bytes_per_key; - bool drop_last = odd && (index.y == half_size); + out += index_x * bytes_per_key; + bool drop_last = odd && (index_y == half_size); auto bits = threefry2x32_hash( - key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y}); - size_t idx = size_t(index.y) << 2; + key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y}); + size_t idx = size_t(index_y) << 2; for (int i = 0; i < 4; ++i) { out[idx + i] = bits.bytes[0][i]; } if (!drop_last) { - idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2; - if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { + idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2; + if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) { int edge_bytes = (bytes_per_key % 4); for (int i = 0; i < edge_bytes; ++i) { out[idx + i] = bits.bytes[1][i]; @@ -89,30 +93,31 @@ __global__ void rbits( int32_t ndim, const __grid_constant__ Shape key_shape, const __grid_constant__ Strides key_strides) { - uint2 index{ - blockIdx.x * blockDim.x + threadIdx.x, - blockIdx.y * blockDim.y + threadIdx.y}; - if (index.x >= grid_dims.x || index.y >= grid_dims.y) { + auto grid = cg::this_grid(); + uint thread_index = grid.thread_rank(); + uint index_x = thread_index % grid_dims.x; + uint index_y = thread_index / grid_dims.x; + if (index_x >= grid_dims.x || index_y >= grid_dims.y) { return; } - auto kidx = 2 * index.x; + auto kidx = 2 * index_x; auto k1_elem = elem_to_loc(kidx, key_shape.data(), key_strides.data(), ndim); auto k2_elem = elem_to_loc(kidx + 1, key_shape.data(), key_strides.data(), ndim); auto key = uint2{keys[k1_elem], keys[k2_elem]}; auto half_size = grid_dims.y - odd; - out += size_t(index.x) * bytes_per_key; - bool drop_last = odd && (index.y == half_size); + out += size_t(index_x) * bytes_per_key; + bool drop_last = odd && (index_y == half_size); auto bits = threefry2x32_hash( - key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y}); - size_t idx = size_t(index.y) << 2; + key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y}); + size_t idx = size_t(index_y) << 2; for (int i = 0; i < 4; ++i) { out[idx + i] = bits.bytes[0][i]; } if (!drop_last) { - idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2; - if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { + idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2; + if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) { int edge_bytes = (bytes_per_key % 4); for (int i = 0; i < edge_bytes; ++i) { out[idx + i] = bits.bytes[1][i]; @@ -153,19 +158,22 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { dim3 grid_dims{num_keys, half_size + odd}; - dim3 block_dims = get_block_dims(grid_dims.x, grid_dims.y, 1); - dim3 num_blocks{ - cuda::ceil_div(grid_dims.x, block_dims.x), - cuda::ceil_div(grid_dims.y, block_dims.y)}; + int64_t total = grid_dims.x * grid_dims.y; + int32_t threads_y = 1; + while ((total / threads_y) >= (1U << 31)) { + threads_y *= 2; + } + int32_t threads_x = cuda::ceil_div(total, threads_y); + auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1); if (keys.flags().row_contiguous) { - cu::rbitsc<<>>( + cu::rbitsc<<>>( keys.data(), out.data(), grid_dims, odd, bytes_per_key); } else { - cu::rbits<<>>( + cu::rbits<<>>( keys.data(), out.data(), grid_dims, diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu index 730b5b789..1d8307811 100644 --- a/mlx/backend/cuda/rope.cu +++ b/mlx/backend/cuda/rope.cu @@ -12,8 +12,6 @@ namespace mlx::core { namespace cu { -namespace cg = cooperative_groups; - template __device__ void rope_single_impl( const T* in,