diff --git a/mlx/backend/cuda/kernels/random.cuh b/mlx/backend/cuda/kernels/random.cuh new file mode 100644 index 000000000..cbd3c6a54 --- /dev/null +++ b/mlx/backend/cuda/kernels/random.cuh @@ -0,0 +1,120 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/kernels/utils.cuh" + +namespace mlx::core::cu { + +__constant__ constexpr uint32_t rotations[2][4] = { + {13, 15, 26, 6}, + {17, 29, 16, 24}}; + +union rbits { + uint2 val; + uint8_t bytes[2][4]; +}; + +__device__ rbits threefry2x32_hash(uint2 key, uint2 count) { + uint32_t ks[] = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA}; + + rbits v; + v.val.x = count.x + ks[0]; + v.val.y = count.y + ks[1]; + + for (int i = 0; i < 5; ++i) { + for (auto r : rotations[i % 2]) { + v.val.x += v.val.y; + v.val.y = (v.val.y << r) | (v.val.y >> (32 - r)); + v.val.y ^= v.val.x; + } + v.val.x += ks[(i + 1) % 3]; + v.val.y += ks[(i + 2) % 3] + i + 1; + } + + return v; +} + +__global__ void rbitsc( + const uint32_t* keys, + uint8_t* out, + const __grid_constant__ dim3 grid_dim, + const __grid_constant__ bool odd, + const __grid_constant__ uint32_t bytes_per_key) { + uint2 index{ + blockIdx.x * blockDim.x + threadIdx.x, + blockIdx.y * blockDim.y + threadIdx.y}; + if (index.x >= grid_dim.x || index.y >= grid_dim.y) { + return; + } + + auto kidx = 2 * index.x; + auto key = uint2{keys[kidx], keys[kidx + 1]}; + auto half_size = grid_dim.y - odd; + out += index.x * bytes_per_key; + bool drop_last = odd && (index.y == half_size); + auto bits = threefry2x32_hash( + key, uint2{index.y, drop_last ? 0 : index.y + grid_dim.y}); + size_t idx = size_t(index.y) << 2; + for (int i = 0; i < 4; ++i) { + out[idx + i] = bits.bytes[0][i]; + } + if (!drop_last) { + idx = (drop_last ? 0 : size_t(index.y) + grid_dim.y) << 2; + if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { + int edge_bytes = (bytes_per_key % 4); + for (int i = 0; i < edge_bytes; ++i) { + out[idx + i] = bits.bytes[1][i]; + } + } else { + for (int i = 0; i < 4; ++i) { + out[idx + i] = bits.bytes[1][i]; + } + } + } +} + +__global__ void rbits( + const uint32_t* keys, + uint8_t* out, + const __grid_constant__ dim3 grid_dim, + const __grid_constant__ bool odd, + const __grid_constant__ uint32_t bytes_per_key, + const __grid_constant__ int32_t ndim, + const __grid_constant__ Shape key_shape, + const __grid_constant__ Strides key_strides) { + uint2 index{ + blockIdx.x * blockDim.x + threadIdx.x, + blockIdx.y * blockDim.y + threadIdx.y}; + if (index.x >= grid_dim.x || index.y >= grid_dim.y) { + return; + } + + auto kidx = 2 * index.x; + auto k1_elem = elem_to_loc(kidx, key_shape.data(), key_strides.data(), ndim); + auto k2_elem = + elem_to_loc(kidx + 1, key_shape.data(), key_strides.data(), ndim); + auto key = uint2{keys[k1_elem], keys[k2_elem]}; + auto half_size = grid_dim.y - odd; + out += size_t(index.x) * bytes_per_key; + bool drop_last = odd && (index.y == half_size); + auto bits = threefry2x32_hash( + key, uint2{index.y, drop_last ? 0 : index.y + grid_dim.y}); + size_t idx = size_t(index.y) << 2; + for (int i = 0; i < 4; ++i) { + out[idx + i] = bits.bytes[0][i]; + } + if (!drop_last) { + idx = (drop_last ? 0 : size_t(index.y) + grid_dim.y) << 2; + if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { + int edge_bytes = (bytes_per_key % 4); + for (int i = 0; i < edge_bytes; ++i) { + out[idx + i] = bits.bytes[1][i]; + } + } else { + for (int i = 0; i < 4; ++i) { + out[idx + i] = bits.bytes[1][i]; + } + } + } +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index ddbc4ef22..347cd16a1 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -4,6 +4,7 @@ #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernels/arange.cuh" #include "mlx/backend/cuda/kernels/fp16_math.cuh" +#include "mlx/backend/cuda/kernels/random.cuh" #include "mlx/distributed/primitives.h" #include "mlx/dtype_utils.h" #include "mlx/fast_primitives.h" @@ -43,6 +44,59 @@ void Arange::eval_gpu(const std::vector& inputs, array& out) { }); } +void RandomBits::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("RandomBits::eval_gpu"); + assert(inputs.size() == 1); + + // keys has shape (N1, ..., NK, 2) + // out has shape (N1, ..., NK, M1, M2, ...) + auto& keys = inputs[0]; + size_t num_keys = keys.size() / 2; + + size_t elems_per_key = out.size() / num_keys; + size_t bytes_per_key = out.itemsize() * elems_per_key; + out.set_data(allocator::malloc(out.nbytes())); + if (out.size() == 0) { + return; + } + + size_t out_per_key = (bytes_per_key + 4 - 1) / 4; + size_t half_size = out_per_key / 2; + bool odd = out_per_key % 2; + + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(keys); + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + dim3 grid_dim{ + static_cast(num_keys), + static_cast(half_size + odd)}; + dim3 block_dim = get_block_dims(grid_dim.x, grid_dim.y, 1); + dim3 num_blocks{ + cuda::ceil_div(grid_dim.x, block_dim.x), + cuda::ceil_div(grid_dim.y, block_dim.y)}; + if (keys.flags().row_contiguous) { + cu::rbitsc<<>>( + keys.data(), + out.data(), + grid_dim, + odd, + bytes_per_key); + } else { + cu::rbits<<>>( + keys.data(), + out.data(), + grid_dim, + odd, + bytes_per_key, + keys.ndim(), + const_param(keys.shape()), + const_param(keys.strides())); + } + }); +} + #define NO_GPU_MULTI(func) \ void func::eval_gpu( \ const std::vector& inputs, std::vector& outputs) { \ @@ -76,7 +130,6 @@ NO_GPU(Matmul) NO_GPU(Partition) NO_GPU_MULTI(QRF) NO_GPU(QuantizedMatmul) -NO_GPU(RandomBits) NO_GPU(Scan) NO_GPU(Scatter) NO_GPU(ScatterAxis) diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp index 093a88a90..a6ecdee20 100644 --- a/mlx/backend/cuda/utils.cpp +++ b/mlx/backend/cuda/utils.cpp @@ -24,6 +24,37 @@ void check_cuda_error(const char* name, cudaError_t err) { } // TODO: The implementation is identical to meta/utils.cpp +dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2) { + int pows[3] = {0, 0, 0}; + int sum = 0; + while (true) { + int presum = sum; + // Check all the pows + if (dim0 >= (1 << (pows[0] + 1))) { + pows[0]++; + sum++; + } + if (sum == 10) { + break; + } + if (dim1 >= (1 << (pows[1] + 1))) { + pows[1]++; + sum++; + } + if (sum == 10) { + break; + } + if (dim2 >= (1 << (pows[2] + 1))) { + pows[2]++; + sum++; + } + if (sum == presum || sum == pow2) { + break; + } + } + return {1u << pows[0], 1u << pows[1], 1u << pows[2]}; +} + dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides) { size_t grid_x = 1; size_t grid_y = 1; diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h index 3edf61076..50dc6cff8 100644 --- a/mlx/backend/cuda/utils.h +++ b/mlx/backend/cuda/utils.h @@ -37,6 +37,9 @@ void check_cuda_error(const char* name, cudaError_t err); // The macro version that prints the command that failed. #define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) +// Compute the thread block dimensions which fit the given input dimensions. +dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10); + // Computes a 2D grid where each element is < UINT_MAX. dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides);