mirror of https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00

Working 64-bit scans (#1506)

committed by GitHub
parent 32972a5924
commit c9b41d460f
@@ -14,19 +14,27 @@ namespace mlx::core {
 void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-
   auto& s = stream();
   auto& d = metal::device(s.device);
 
-  // Ensure contiguity
   std::vector<array> copies;
   auto in = inputs[0];
-  if (!in.flags().row_contiguous) {
+  if (in.flags().contiguous && in.strides()[axis_] != 0) {
+    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
+      out.move_shared_buffer(in);
+    } else {
+      out.set_data(
+          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
+          in.data_size(),
+          in.strides(),
+          in.flags());
+    }
+  } else {
     array arr_copy(in.shape(), in.dtype(), nullptr, {});
     copy_gpu(in, arr_copy, CopyType::General, s);
     copies.push_back(arr_copy);
     in = arr_copy;
+    out.move_shared_buffer(in);
   }
 
   bool contiguous = in.strides()[axis_] == 1;
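The new input handling avoids a fresh allocation whenever it can: a contiguous input whose scan axis has a nonzero stride is either donated to the output outright or mirrored with a same-strides allocation, and only genuinely non-contiguous inputs are copied first. The sketch below is an illustration of that donation test, not MLX code; can_donate and its plain bool/size parameters are stand-ins for the array queries used in the hunk above.

#include <cstddef>
#include <cstdio>

// Stand-in for the donation test above: the input buffer can be reused as
// the output only if the caller no longer needs it and the element sizes
// match, so the scan can be written in place.
static bool can_donate(bool is_donatable, size_t in_item, size_t out_item) {
  return is_donatable && in_item == out_item;
}

int main() {
  // float32 cumsum over a temporary: reuse the buffer (prints 1).
  std::printf("%d\n", can_donate(true, 4, 4));
  // float16 input accumulated as float32: sizes differ (prints 0).
  std::printf("%d\n", can_donate(true, 2, 4));
}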
@@ -61,7 +69,8 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (contiguous) {
     auto& compute_encoder = d.get_command_encoder(s.index);
     compute_encoder->setComputePipelineState(kernel);
-    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_input_array(
+        in.data_shared_ptr() == nullptr ? out : in, 0);
     compute_encoder.set_output_array(out, 1);
     size_t size = in.shape(axis_);
     compute_encoder->setBytes(&size, sizeof(size_t), 2);
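The rewritten set_input_array call encodes an invariant worth spelling out: after out.move_shared_buffer(in), the input array no longer owns its data, so in.data_shared_ptr() is null and the kernel must read through out, which now holds the (possibly donated) buffer. A toy model of that ownership transfer, using a bare shared_ptr in place of an MLX array:

#include <cassert>
#include <memory>
#include <utility>

// Toy stand-in for an array whose buffer may be donated.
struct Arr {
  std::shared_ptr<int> data;
};

int main() {
  Arr in{std::make_shared<int>(42)};
  Arr out;
  out.data = std::move(in.data);  // models out.move_shared_buffer(in)
  assert(in.data == nullptr);     // the donor is left empty...
  // ...so the kernel input must be rebound to `out`.
  const Arr& kernel_input = (in.data == nullptr) ? out : in;
  assert(*kernel_input.data == 42);
}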
@@ -70,7 +79,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     int n_reads = (in.itemsize() <= 4) ? 4 : 2;
     constexpr int simd_size = 32;
     int elements_per_simd = n_reads * simd_size;
-    int thread_groups = in.size() / size;
     int thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (size <= n_reads * 1024) {
       thread_group_size =
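Dropping `int thread_groups = in.size() / size;` is the heart of the 64-bit fix: the row count was held in a 32-bit int and then multiplied into a single grid dimension, both of which break once the array outgrows 2^31 elements. A minimal demonstration of the failure mode, with made-up sizes:

#include <cstdint>
#include <cstdio>

int main() {
  uint64_t total = 6'000'000'000ull;  // > 2^32 elements overall
  uint64_t size = 2;                  // scan-axis length
  // The old computation: 3e9 rows do not fit in a 32-bit int.
  int thread_groups = static_cast<int>(total / size);
  // The correct 64-bit row count.
  uint64_t rows = total / size;
  std::printf("int: %d vs 64-bit: %llu\n",
              thread_groups, (unsigned long long)rows);
}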
@@ -82,28 +90,41 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
       thread_group_size = std::min(
           thread_group_size,
           static_cast<int>(kernel->maxTotalThreadsPerThreadgroup()));
 
-    MTL::Size grid_dims = MTL::Size(thread_groups * thread_group_size, 1, 1);
-    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
+    auto tmp_grid_dims =
+        get_2d_grid_dims(in.shape(), in.strides(), /** divisor= */ size);
+    MTL::Size grid_dims(
+        thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height);
+    MTL::Size group_dims(thread_group_size, 1, 1);
     compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     auto& compute_encoder = d.get_command_encoder(s.index);
     compute_encoder->setComputePipelineState(kernel);
-    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_input_array(
+        in.data_shared_ptr() == nullptr ? out : in, 0);
     compute_encoder.set_output_array(out, 1);
     size_t size = in.shape(axis_);
     size_t stride = in.strides()[axis_];
+    int bm = 32;
+    int bn = 32;
+    size_t stride_blocks = (stride + bn - 1) / bn;
     compute_encoder->setBytes(&size, sizeof(size_t), 2);
     compute_encoder->setBytes(&stride, sizeof(size_t), 3);
+    compute_encoder->setBytes(&stride_blocks, sizeof(size_t), 4);
 
     // Compute the thread grid
     int n_reads = (in.itemsize() <= 4) ? 4 : 2;
-    int tile_x = 32;
-    int tile_y = 32;
-    int elements_per_tile_x = tile_x * n_reads;
-    int grid_y = in.size() / size / stride;
-    int grid_x = (stride + elements_per_tile_x - 1) / elements_per_tile_x;
-    MTL::Size grid_dims = MTL::Size(grid_x * tile_x, grid_y * tile_y, 1);
-    MTL::Size group_dims = MTL::Size(tile_x, tile_y, 1);
+    int n_simdgroups = bn / n_reads;
+    int thread_group_size = n_simdgroups * 32;
+    auto tmp_grid_dims = get_2d_grid_dims(
+        in.shape(), in.strides(), /** divisor= */ size * stride);
+    if (tmp_grid_dims.width * stride_blocks <= UINT_MAX) {
+      tmp_grid_dims.width *= stride_blocks;
+    } else {
+      tmp_grid_dims.height *= stride_blocks;
+    }
+    MTL::Size grid_dims(
+        thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height);
+    MTL::Size group_dims(thread_group_size, 1, 1);
     compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
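In the strided branch, each scan row spans `stride` columns processed in blocks of `bn`, so `stride_blocks` is a ceiling division, and that block count is then folded into whichever grid dimension still fits under Metal's 32-bit per-dimension limit. A numeric walkthrough with made-up sizes; the width/height pair merely stands in for what get_2d_grid_dims returns:

#include <climits>
#include <cstdint>
#include <cstdio>

int main() {
  uint64_t stride = 100'000;                        // columns per scan row
  uint64_t bn = 32;                                 // columns per block
  uint64_t stride_blocks = (stride + bn - 1) / bn;  // ceil(100000/32) = 3125

  // Pretend get_2d_grid_dims already split the rows into width x height.
  uint64_t width = 3'000'000'000ull, height = 2;
  if (width * stride_blocks <= UINT_MAX) {
    width *= stride_blocks;   // fits: widen the grid
  } else {
    height *= stride_blocks;  // would overflow: deepen the grid instead
  }
  std::printf("stride_blocks=%llu grid=%llu x %llu\n",
              (unsigned long long)stride_blocks,
              (unsigned long long)width,
              (unsigned long long)height);
}

With these numbers the width product would exceed UINT_MAX, so the block count lands in the height dimension and the program prints grid=3000000000 x 6250.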