From f9e1a141352207f4fea7ee4c212b52107c8b79d7 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 20 Nov 2025 07:27:50 -0800 Subject: [PATCH] [CUDA] Partly fix random for large sizes (#2798) --- mlx/backend/cuda/random.cu | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/mlx/backend/cuda/random.cu b/mlx/backend/cuda/random.cu index 10171ff28..6b323e79e 100644 --- a/mlx/backend/cuda/random.cu +++ b/mlx/backend/cuda/random.cu @@ -139,10 +139,10 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { // keys has shape (N1, ..., NK, 2) // out has shape (N1, ..., NK, M1, M2, ...) auto& keys = inputs[0]; - uint32_t num_keys = keys.size() / 2; + size_t num_keys = keys.size() / 2; - uint32_t elems_per_key = out.size() / num_keys; - uint32_t bytes_per_key = out.itemsize() * elems_per_key; + size_t elems_per_key = out.size() / num_keys; + size_t bytes_per_key = out.itemsize() * elems_per_key; auto& s = stream(); auto& encoder = cu::get_command_encoder(s); out.set_data(cu::malloc_async(out.nbytes(), encoder)); @@ -150,19 +150,25 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { return; } - uint32_t out_per_key = (bytes_per_key + 4 - 1) / 4; - uint32_t half_size = out_per_key / 2; + size_t out_per_key = (bytes_per_key + 4 - 1) / 4; + size_t half_size = out_per_key / 2; + bool odd = out_per_key % 2; + if ((half_size + odd) >= UINT32_MAX || num_keys >= UINT32_MAX) { + throw std::runtime_error("[RandomBits::eval_gpu] Large size unsupported"); + } encoder.set_input_array(keys); encoder.set_output_array(out); - dim3 grid_dims{num_keys, half_size + odd}; - int64_t total = grid_dims.x * grid_dims.y; - int32_t threads_y = 1; - while ((total / threads_y) >= (1U << 31)) { + int64_t total = num_keys * (half_size + odd); + uint32_t threads_y = 1; + while ((total / threads_y) >= UINT_MAX) { threads_y *= 2; } - int32_t threads_x = cuda::ceil_div(total, threads_y); + uint32_t threads_x = cuda::ceil_div(total, threads_y); + + dim3 grid_dims{ + static_cast(num_keys), static_cast(half_size + odd)}; auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1); auto& stream = encoder.stream(); if (keys.flags().row_contiguous) {