mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
CUDA backend: random
This commit is contained in:
120
mlx/backend/cuda/kernels/random.cuh
Normal file
120
mlx/backend/cuda/kernels/random.cuh
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/kernels/utils.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
__constant__ constexpr uint32_t rotations[2][4] = {
|
||||||
|
{13, 15, 26, 6},
|
||||||
|
{17, 29, 16, 24}};
|
||||||
|
|
||||||
|
union rbits {
|
||||||
|
uint2 val;
|
||||||
|
uint8_t bytes[2][4];
|
||||||
|
};
|
||||||
|
|
||||||
|
__device__ rbits threefry2x32_hash(uint2 key, uint2 count) {
|
||||||
|
uint32_t ks[] = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA};
|
||||||
|
|
||||||
|
rbits v;
|
||||||
|
v.val.x = count.x + ks[0];
|
||||||
|
v.val.y = count.y + ks[1];
|
||||||
|
|
||||||
|
for (int i = 0; i < 5; ++i) {
|
||||||
|
for (auto r : rotations[i % 2]) {
|
||||||
|
v.val.x += v.val.y;
|
||||||
|
v.val.y = (v.val.y << r) | (v.val.y >> (32 - r));
|
||||||
|
v.val.y ^= v.val.x;
|
||||||
|
}
|
||||||
|
v.val.x += ks[(i + 1) % 3];
|
||||||
|
v.val.y += ks[(i + 2) % 3] + i + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return v;
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void rbitsc(
|
||||||
|
const uint32_t* keys,
|
||||||
|
uint8_t* out,
|
||||||
|
const __grid_constant__ dim3 grid_dim,
|
||||||
|
const __grid_constant__ bool odd,
|
||||||
|
const __grid_constant__ uint32_t bytes_per_key) {
|
||||||
|
uint2 index{
|
||||||
|
blockIdx.x * blockDim.x + threadIdx.x,
|
||||||
|
blockIdx.y * blockDim.y + threadIdx.y};
|
||||||
|
if (index.x >= grid_dim.x || index.y >= grid_dim.y) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto kidx = 2 * index.x;
|
||||||
|
auto key = uint2{keys[kidx], keys[kidx + 1]};
|
||||||
|
auto half_size = grid_dim.y - odd;
|
||||||
|
out += index.x * bytes_per_key;
|
||||||
|
bool drop_last = odd && (index.y == half_size);
|
||||||
|
auto bits = threefry2x32_hash(
|
||||||
|
key, uint2{index.y, drop_last ? 0 : index.y + grid_dim.y});
|
||||||
|
size_t idx = size_t(index.y) << 2;
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
out[idx + i] = bits.bytes[0][i];
|
||||||
|
}
|
||||||
|
if (!drop_last) {
|
||||||
|
idx = (drop_last ? 0 : size_t(index.y) + grid_dim.y) << 2;
|
||||||
|
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
|
||||||
|
int edge_bytes = (bytes_per_key % 4);
|
||||||
|
for (int i = 0; i < edge_bytes; ++i) {
|
||||||
|
out[idx + i] = bits.bytes[1][i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
out[idx + i] = bits.bytes[1][i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void rbits(
|
||||||
|
const uint32_t* keys,
|
||||||
|
uint8_t* out,
|
||||||
|
const __grid_constant__ dim3 grid_dim,
|
||||||
|
const __grid_constant__ bool odd,
|
||||||
|
const __grid_constant__ uint32_t bytes_per_key,
|
||||||
|
const __grid_constant__ int32_t ndim,
|
||||||
|
const __grid_constant__ Shape key_shape,
|
||||||
|
const __grid_constant__ Strides key_strides) {
|
||||||
|
uint2 index{
|
||||||
|
blockIdx.x * blockDim.x + threadIdx.x,
|
||||||
|
blockIdx.y * blockDim.y + threadIdx.y};
|
||||||
|
if (index.x >= grid_dim.x || index.y >= grid_dim.y) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto kidx = 2 * index.x;
|
||||||
|
auto k1_elem = elem_to_loc(kidx, key_shape.data(), key_strides.data(), ndim);
|
||||||
|
auto k2_elem =
|
||||||
|
elem_to_loc(kidx + 1, key_shape.data(), key_strides.data(), ndim);
|
||||||
|
auto key = uint2{keys[k1_elem], keys[k2_elem]};
|
||||||
|
auto half_size = grid_dim.y - odd;
|
||||||
|
out += size_t(index.x) * bytes_per_key;
|
||||||
|
bool drop_last = odd && (index.y == half_size);
|
||||||
|
auto bits = threefry2x32_hash(
|
||||||
|
key, uint2{index.y, drop_last ? 0 : index.y + grid_dim.y});
|
||||||
|
size_t idx = size_t(index.y) << 2;
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
out[idx + i] = bits.bytes[0][i];
|
||||||
|
}
|
||||||
|
if (!drop_last) {
|
||||||
|
idx = (drop_last ? 0 : size_t(index.y) + grid_dim.y) << 2;
|
||||||
|
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
|
||||||
|
int edge_bytes = (bytes_per_key % 4);
|
||||||
|
for (int i = 0; i < edge_bytes; ++i) {
|
||||||
|
out[idx + i] = bits.bytes[1][i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
out[idx + i] = bits.bytes[1][i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
||||||
@@ -4,6 +4,7 @@
|
|||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/cuda/kernels/arange.cuh"
|
#include "mlx/backend/cuda/kernels/arange.cuh"
|
||||||
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
|
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
|
||||||
|
#include "mlx/backend/cuda/kernels/random.cuh"
|
||||||
#include "mlx/distributed/primitives.h"
|
#include "mlx/distributed/primitives.h"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/fast_primitives.h"
|
#include "mlx/fast_primitives.h"
|
||||||
@@ -43,6 +44,59 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
nvtx3::scoped_range r("RandomBits::eval_gpu");
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
|
||||||
|
// keys has shape (N1, ..., NK, 2)
|
||||||
|
// out has shape (N1, ..., NK, M1, M2, ...)
|
||||||
|
auto& keys = inputs[0];
|
||||||
|
size_t num_keys = keys.size() / 2;
|
||||||
|
|
||||||
|
size_t elems_per_key = out.size() / num_keys;
|
||||||
|
size_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
if (out.size() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t out_per_key = (bytes_per_key + 4 - 1) / 4;
|
||||||
|
size_t half_size = out_per_key / 2;
|
||||||
|
bool odd = out_per_key % 2;
|
||||||
|
|
||||||
|
auto& s = stream();
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
encoder.set_input_array(keys);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
dim3 grid_dim{
|
||||||
|
static_cast<uint32_t>(num_keys),
|
||||||
|
static_cast<uint32_t>(half_size + odd)};
|
||||||
|
dim3 block_dim = get_block_dims(grid_dim.x, grid_dim.y, 1);
|
||||||
|
dim3 num_blocks{
|
||||||
|
cuda::ceil_div(grid_dim.x, block_dim.x),
|
||||||
|
cuda::ceil_div(grid_dim.y, block_dim.y)};
|
||||||
|
if (keys.flags().row_contiguous) {
|
||||||
|
cu::rbitsc<<<num_blocks, block_dim, 0, stream>>>(
|
||||||
|
keys.data<uint32_t>(),
|
||||||
|
out.data<uint8_t>(),
|
||||||
|
grid_dim,
|
||||||
|
odd,
|
||||||
|
bytes_per_key);
|
||||||
|
} else {
|
||||||
|
cu::rbits<<<num_blocks, block_dim, 0, stream>>>(
|
||||||
|
keys.data<uint32_t>(),
|
||||||
|
out.data<uint8_t>(),
|
||||||
|
grid_dim,
|
||||||
|
odd,
|
||||||
|
bytes_per_key,
|
||||||
|
keys.ndim(),
|
||||||
|
const_param(keys.shape()),
|
||||||
|
const_param(keys.strides()));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
#define NO_GPU_MULTI(func) \
|
#define NO_GPU_MULTI(func) \
|
||||||
void func::eval_gpu( \
|
void func::eval_gpu( \
|
||||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||||
@@ -76,7 +130,6 @@ NO_GPU(Matmul)
|
|||||||
NO_GPU(Partition)
|
NO_GPU(Partition)
|
||||||
NO_GPU_MULTI(QRF)
|
NO_GPU_MULTI(QRF)
|
||||||
NO_GPU(QuantizedMatmul)
|
NO_GPU(QuantizedMatmul)
|
||||||
NO_GPU(RandomBits)
|
|
||||||
NO_GPU(Scan)
|
NO_GPU(Scan)
|
||||||
NO_GPU(Scatter)
|
NO_GPU(Scatter)
|
||||||
NO_GPU(ScatterAxis)
|
NO_GPU(ScatterAxis)
|
||||||
|
|||||||
@@ -24,6 +24,37 @@ void check_cuda_error(const char* name, cudaError_t err) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: The implementation is identical to meta/utils.cpp
|
// TODO: The implementation is identical to meta/utils.cpp
|
||||||
|
dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2) {
|
||||||
|
int pows[3] = {0, 0, 0};
|
||||||
|
int sum = 0;
|
||||||
|
while (true) {
|
||||||
|
int presum = sum;
|
||||||
|
// Check all the pows
|
||||||
|
if (dim0 >= (1 << (pows[0] + 1))) {
|
||||||
|
pows[0]++;
|
||||||
|
sum++;
|
||||||
|
}
|
||||||
|
if (sum == 10) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (dim1 >= (1 << (pows[1] + 1))) {
|
||||||
|
pows[1]++;
|
||||||
|
sum++;
|
||||||
|
}
|
||||||
|
if (sum == 10) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (dim2 >= (1 << (pows[2] + 1))) {
|
||||||
|
pows[2]++;
|
||||||
|
sum++;
|
||||||
|
}
|
||||||
|
if (sum == presum || sum == pow2) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {1u << pows[0], 1u << pows[1], 1u << pows[2]};
|
||||||
|
}
|
||||||
|
|
||||||
dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides) {
|
dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides) {
|
||||||
size_t grid_x = 1;
|
size_t grid_x = 1;
|
||||||
size_t grid_y = 1;
|
size_t grid_y = 1;
|
||||||
|
|||||||
@@ -37,6 +37,9 @@ void check_cuda_error(const char* name, cudaError_t err);
|
|||||||
// The macro version that prints the command that failed.
|
// The macro version that prints the command that failed.
|
||||||
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
|
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
|
||||||
|
|
||||||
|
// Compute the thread block dimensions which fit the given input dimensions.
|
||||||
|
dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
|
||||||
|
|
||||||
// Computes a 2D grid where each element is < UINT_MAX.
|
// Computes a 2D grid where each element is < UINT_MAX.
|
||||||
dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides);
|
dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user