CUDA backend: random (#2261)

This commit is contained in:
Cheng 2025-06-11 00:59:56 +09:00 committed by GitHub
parent bae9a6b404
commit 7c4eb5d03e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 182 additions and 1 deletions

View File

@ -15,6 +15,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu

View File

@ -91,7 +91,6 @@ NO_GPU_MULTI(LUF)
NO_GPU(Partition)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(RandomBits)
NO_GPU(Reduce)
NO_GPU(Scan)
NO_GPU(Scatter)

181
mlx/backend/cuda/random.cu Normal file
View File

@ -0,0 +1,181 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <cassert>
namespace mlx::core {
namespace cu {
__constant__ constexpr uint32_t rotations[2][4] = {
{13, 15, 26, 6},
{17, 29, 16, 24}};
union rbits {
uint2 val;
uint8_t bytes[2][4];
};
__device__ rbits threefry2x32_hash(uint2 key, uint2 count) {
uint32_t ks[] = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA};
rbits v;
v.val.x = count.x + ks[0];
v.val.y = count.y + ks[1];
for (int i = 0; i < 5; ++i) {
for (auto r : rotations[i % 2]) {
v.val.x += v.val.y;
v.val.y = (v.val.y << r) | (v.val.y >> (32 - r));
v.val.y ^= v.val.x;
}
v.val.x += ks[(i + 1) % 3];
v.val.y += ks[(i + 2) % 3] + i + 1;
}
return v;
}
__global__ void rbitsc(
const uint32_t* keys,
uint8_t* out,
dim3 grid_dims,
bool odd,
uint32_t bytes_per_key) {
uint2 index{
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y};
if (index.x >= grid_dims.x || index.y >= grid_dims.y) {
return;
}
auto kidx = 2 * index.x;
auto key = uint2{keys[kidx], keys[kidx + 1]};
auto half_size = grid_dims.y - odd;
out += index.x * bytes_per_key;
bool drop_last = odd && (index.y == half_size);
auto bits = threefry2x32_hash(
key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y});
size_t idx = size_t(index.y) << 2;
for (int i = 0; i < 4; ++i) {
out[idx + i] = bits.bytes[0][i];
}
if (!drop_last) {
idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2;
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
int edge_bytes = (bytes_per_key % 4);
for (int i = 0; i < edge_bytes; ++i) {
out[idx + i] = bits.bytes[1][i];
}
} else {
for (int i = 0; i < 4; ++i) {
out[idx + i] = bits.bytes[1][i];
}
}
}
}
__global__ void rbits(
const uint32_t* keys,
uint8_t* out,
dim3 grid_dims,
bool odd,
uint32_t bytes_per_key,
int32_t ndim,
const __grid_constant__ Shape key_shape,
const __grid_constant__ Strides key_strides) {
uint2 index{
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y};
if (index.x >= grid_dims.x || index.y >= grid_dims.y) {
return;
}
auto kidx = 2 * index.x;
auto k1_elem = elem_to_loc(kidx, key_shape.data(), key_strides.data(), ndim);
auto k2_elem =
elem_to_loc(kidx + 1, key_shape.data(), key_strides.data(), ndim);
auto key = uint2{keys[k1_elem], keys[k2_elem]};
auto half_size = grid_dims.y - odd;
out += size_t(index.x) * bytes_per_key;
bool drop_last = odd && (index.y == half_size);
auto bits = threefry2x32_hash(
key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y});
size_t idx = size_t(index.y) << 2;
for (int i = 0; i < 4; ++i) {
out[idx + i] = bits.bytes[0][i];
}
if (!drop_last) {
idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2;
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
int edge_bytes = (bytes_per_key % 4);
for (int i = 0; i < edge_bytes; ++i) {
out[idx + i] = bits.bytes[1][i];
}
} else {
for (int i = 0; i < 4; ++i) {
out[idx + i] = bits.bytes[1][i];
}
}
}
}
} // namespace cu
void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("RandomBits::eval_gpu");
assert(inputs.size() == 1);
// keys has shape (N1, ..., NK, 2)
// out has shape (N1, ..., NK, M1, M2, ...)
auto& keys = inputs[0];
uint32_t num_keys = keys.size() / 2;
uint32_t elems_per_key = out.size() / num_keys;
uint32_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc(out.nbytes()));
if (out.size() == 0) {
return;
}
uint32_t out_per_key = (bytes_per_key + 4 - 1) / 4;
uint32_t half_size = out_per_key / 2;
bool odd = out_per_key % 2;
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(keys);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
dim3 grid_dims{num_keys, half_size + odd};
dim3 block_dims = get_block_dims(grid_dims.x, grid_dims.y, 1);
dim3 num_blocks{
cuda::ceil_div(grid_dims.x, block_dims.x),
cuda::ceil_div(grid_dims.y, block_dims.y)};
if (keys.flags().row_contiguous) {
cu::rbitsc<<<num_blocks, block_dims, 0, stream>>>(
keys.data<uint32_t>(),
out.data<uint8_t>(),
grid_dims,
odd,
bytes_per_key);
} else {
cu::rbits<<<num_blocks, block_dims, 0, stream>>>(
keys.data<uint32_t>(),
out.data<uint8_t>(),
grid_dims,
odd,
bytes_per_key,
keys.ndim(),
const_param(keys.shape()),
const_param(keys.strides()));
}
});
}
} // namespace mlx::core