mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Partly fix random for large sizes (#2798)
This commit is contained in:
@@ -139,10 +139,10 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
// keys has shape (N1, ..., NK, 2)
|
// keys has shape (N1, ..., NK, 2)
|
||||||
// out has shape (N1, ..., NK, M1, M2, ...)
|
// out has shape (N1, ..., NK, M1, M2, ...)
|
||||||
auto& keys = inputs[0];
|
auto& keys = inputs[0];
|
||||||
uint32_t num_keys = keys.size() / 2;
|
size_t num_keys = keys.size() / 2;
|
||||||
|
|
||||||
uint32_t elems_per_key = out.size() / num_keys;
|
size_t elems_per_key = out.size() / num_keys;
|
||||||
uint32_t bytes_per_key = out.itemsize() * elems_per_key;
|
size_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||||
auto& s = stream();
|
auto& s = stream();
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
out.set_data(cu::malloc_async(out.nbytes(), encoder));
|
out.set_data(cu::malloc_async(out.nbytes(), encoder));
|
||||||
@@ -150,19 +150,25 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t out_per_key = (bytes_per_key + 4 - 1) / 4;
|
size_t out_per_key = (bytes_per_key + 4 - 1) / 4;
|
||||||
uint32_t half_size = out_per_key / 2;
|
size_t half_size = out_per_key / 2;
|
||||||
|
|
||||||
bool odd = out_per_key % 2;
|
bool odd = out_per_key % 2;
|
||||||
|
if ((half_size + odd) >= UINT32_MAX || num_keys >= UINT32_MAX) {
|
||||||
|
throw std::runtime_error("[RandomBits::eval_gpu] Large size unsupported");
|
||||||
|
}
|
||||||
|
|
||||||
encoder.set_input_array(keys);
|
encoder.set_input_array(keys);
|
||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
dim3 grid_dims{num_keys, half_size + odd};
|
int64_t total = num_keys * (half_size + odd);
|
||||||
int64_t total = grid_dims.x * grid_dims.y;
|
uint32_t threads_y = 1;
|
||||||
int32_t threads_y = 1;
|
while ((total / threads_y) >= UINT_MAX) {
|
||||||
while ((total / threads_y) >= (1U << 31)) {
|
|
||||||
threads_y *= 2;
|
threads_y *= 2;
|
||||||
}
|
}
|
||||||
int32_t threads_x = cuda::ceil_div(total, threads_y);
|
uint32_t threads_x = cuda::ceil_div(total, threads_y);
|
||||||
|
|
||||||
|
dim3 grid_dims{
|
||||||
|
static_cast<uint32_t>(num_keys), static_cast<uint32_t>(half_size + odd)};
|
||||||
auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
|
auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
|
||||||
auto& stream = encoder.stream();
|
auto& stream = encoder.stream();
|
||||||
if (keys.flags().row_contiguous) {
|
if (keys.flags().row_contiguous) {
|
||||||
|
|||||||
Reference in New Issue
Block a user