[CUDA] Partly fix random for large sizes (#2798)

This commit is contained in:
Awni Hannun
2025-11-20 07:27:50 -08:00
committed by GitHub
parent d8e9ded928
commit f9e1a14135

View File

@@ -139,10 +139,10 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
// keys has shape (N1, ..., NK, 2) // keys has shape (N1, ..., NK, 2)
// out has shape (N1, ..., NK, M1, M2, ...) // out has shape (N1, ..., NK, M1, M2, ...)
auto& keys = inputs[0]; auto& keys = inputs[0];
uint32_t num_keys = keys.size() / 2; size_t num_keys = keys.size() / 2;
uint32_t elems_per_key = out.size() / num_keys; size_t elems_per_key = out.size() / num_keys;
uint32_t bytes_per_key = out.itemsize() * elems_per_key; size_t bytes_per_key = out.itemsize() * elems_per_key;
auto& s = stream(); auto& s = stream();
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder)); out.set_data(cu::malloc_async(out.nbytes(), encoder));
@@ -150,19 +150,25 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
return; return;
} }
uint32_t out_per_key = (bytes_per_key + 4 - 1) / 4; size_t out_per_key = (bytes_per_key + 4 - 1) / 4;
uint32_t half_size = out_per_key / 2; size_t half_size = out_per_key / 2;
bool odd = out_per_key % 2; bool odd = out_per_key % 2;
if ((half_size + odd) >= UINT32_MAX || num_keys >= UINT32_MAX) {
throw std::runtime_error("[RandomBits::eval_gpu] Large size unsupported");
}
encoder.set_input_array(keys); encoder.set_input_array(keys);
encoder.set_output_array(out); encoder.set_output_array(out);
dim3 grid_dims{num_keys, half_size + odd}; int64_t total = num_keys * (half_size + odd);
int64_t total = grid_dims.x * grid_dims.y; uint32_t threads_y = 1;
int32_t threads_y = 1; while ((total / threads_y) >= UINT_MAX) {
while ((total / threads_y) >= (1U << 31)) {
threads_y *= 2; threads_y *= 2;
} }
int32_t threads_x = cuda::ceil_div(total, threads_y); uint32_t threads_x = cuda::ceil_div(total, threads_y);
dim3 grid_dims{
static_cast<uint32_t>(num_keys), static_cast<uint32_t>(half_size + odd)};
auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1); auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
auto& stream = encoder.stream(); auto& stream = encoder.stream();
if (keys.flags().row_contiguous) { if (keys.flags().row_contiguous) {