Fix random

This commit is contained in:
Angelos Katharopoulos 2025-06-14 23:53:03 -07:00
parent bfe105990b
commit 229e3a29a6
2 changed files with 36 additions and 30 deletions

View File

@ -4,6 +4,7 @@
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
#include <cassert> #include <cassert>
@ -12,6 +13,8 @@ namespace mlx::core {
namespace cu { namespace cu {
namespace cg = cooperative_groups;
__constant__ constexpr uint32_t rotations[2][4] = { __constant__ constexpr uint32_t rotations[2][4] = {
{13, 15, 26, 6}, {13, 15, 26, 6},
{17, 29, 16, 24}}; {17, 29, 16, 24}};
@ -47,27 +50,28 @@ __global__ void rbitsc(
dim3 grid_dims, dim3 grid_dims,
bool odd, bool odd,
uint32_t bytes_per_key) { uint32_t bytes_per_key) {
uint2 index{ auto grid = cg::this_grid();
blockIdx.x * blockDim.x + threadIdx.x, uint thread_index = grid.thread_rank();
blockIdx.y * blockDim.y + threadIdx.y}; uint index_x = thread_index % grid_dims.x;
if (index.x >= grid_dims.x || index.y >= grid_dims.y) { uint index_y = thread_index / grid_dims.x;
if (index_x >= grid_dims.x || index_y >= grid_dims.y) {
return; return;
} }
auto kidx = 2 * index.x; auto kidx = 2 * index_x;
auto key = uint2{keys[kidx], keys[kidx + 1]}; auto key = uint2{keys[kidx], keys[kidx + 1]};
auto half_size = grid_dims.y - odd; auto half_size = grid_dims.y - odd;
out += index.x * bytes_per_key; out += index_x * bytes_per_key;
bool drop_last = odd && (index.y == half_size); bool drop_last = odd && (index_y == half_size);
auto bits = threefry2x32_hash( auto bits = threefry2x32_hash(
key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y}); key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y});
size_t idx = size_t(index.y) << 2; size_t idx = size_t(index_y) << 2;
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
out[idx + i] = bits.bytes[0][i]; out[idx + i] = bits.bytes[0][i];
} }
if (!drop_last) { if (!drop_last) {
idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2; idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2;
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) {
int edge_bytes = (bytes_per_key % 4); int edge_bytes = (bytes_per_key % 4);
for (int i = 0; i < edge_bytes; ++i) { for (int i = 0; i < edge_bytes; ++i) {
out[idx + i] = bits.bytes[1][i]; out[idx + i] = bits.bytes[1][i];
@ -89,30 +93,31 @@ __global__ void rbits(
int32_t ndim, int32_t ndim,
const __grid_constant__ Shape key_shape, const __grid_constant__ Shape key_shape,
const __grid_constant__ Strides key_strides) { const __grid_constant__ Strides key_strides) {
uint2 index{ auto grid = cg::this_grid();
blockIdx.x * blockDim.x + threadIdx.x, uint thread_index = grid.thread_rank();
blockIdx.y * blockDim.y + threadIdx.y}; uint index_x = thread_index % grid_dims.x;
if (index.x >= grid_dims.x || index.y >= grid_dims.y) { uint index_y = thread_index / grid_dims.x;
if (index_x >= grid_dims.x || index_y >= grid_dims.y) {
return; return;
} }
auto kidx = 2 * index.x; auto kidx = 2 * index_x;
auto k1_elem = elem_to_loc(kidx, key_shape.data(), key_strides.data(), ndim); auto k1_elem = elem_to_loc(kidx, key_shape.data(), key_strides.data(), ndim);
auto k2_elem = auto k2_elem =
elem_to_loc(kidx + 1, key_shape.data(), key_strides.data(), ndim); elem_to_loc(kidx + 1, key_shape.data(), key_strides.data(), ndim);
auto key = uint2{keys[k1_elem], keys[k2_elem]}; auto key = uint2{keys[k1_elem], keys[k2_elem]};
auto half_size = grid_dims.y - odd; auto half_size = grid_dims.y - odd;
out += size_t(index.x) * bytes_per_key; out += size_t(index_x) * bytes_per_key;
bool drop_last = odd && (index.y == half_size); bool drop_last = odd && (index_y == half_size);
auto bits = threefry2x32_hash( auto bits = threefry2x32_hash(
key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y}); key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y});
size_t idx = size_t(index.y) << 2; size_t idx = size_t(index_y) << 2;
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
out[idx + i] = bits.bytes[0][i]; out[idx + i] = bits.bytes[0][i];
} }
if (!drop_last) { if (!drop_last) {
idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2; idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2;
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) {
int edge_bytes = (bytes_per_key % 4); int edge_bytes = (bytes_per_key % 4);
for (int i = 0; i < edge_bytes; ++i) { for (int i = 0; i < edge_bytes; ++i) {
out[idx + i] = bits.bytes[1][i]; out[idx + i] = bits.bytes[1][i];
@ -153,19 +158,22 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
dim3 grid_dims{num_keys, half_size + odd}; dim3 grid_dims{num_keys, half_size + odd};
dim3 block_dims = get_block_dims(grid_dims.x, grid_dims.y, 1); int64_t total = grid_dims.x * grid_dims.y;
dim3 num_blocks{ int32_t threads_y = 1;
cuda::ceil_div(grid_dims.x, block_dims.x), while ((total / threads_y) >= (1U << 31)) {
cuda::ceil_div(grid_dims.y, block_dims.y)}; threads_y *= 2;
}
int32_t threads_x = cuda::ceil_div(total, threads_y);
auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
if (keys.flags().row_contiguous) { if (keys.flags().row_contiguous) {
cu::rbitsc<<<num_blocks, block_dims, 0, stream>>>( cu::rbitsc<<<grid, block, 0, stream>>>(
keys.data<uint32_t>(), keys.data<uint32_t>(),
out.data<uint8_t>(), out.data<uint8_t>(),
grid_dims, grid_dims,
odd, odd,
bytes_per_key); bytes_per_key);
} else { } else {
cu::rbits<<<num_blocks, block_dims, 0, stream>>>( cu::rbits<<<grid, block, 0, stream>>>(
keys.data<uint32_t>(), keys.data<uint32_t>(),
out.data<uint8_t>(), out.data<uint8_t>(),
grid_dims, grid_dims,

View File

@ -12,8 +12,6 @@ namespace mlx::core {
namespace cu { namespace cu {
namespace cg = cooperative_groups;
template <typename T, bool traditional, bool forward> template <typename T, bool traditional, bool forward>
__device__ void rope_single_impl( __device__ void rope_single_impl(
const T* in, const T* in,