mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Fix thread group for large arrays (#1543)
* fix thread group for large arrays * comment * one more
This commit is contained in:
parent
048fabdabd
commit
884af42da2
@ -1,5 +1,4 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/common/binary.h"
|
#include "mlx/backend/common/binary.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
#include "mlx/backend/metal/kernels.h"
|
#include "mlx/backend/metal/kernels.h"
|
||||||
@ -110,6 +109,7 @@ void binary_op_gpu_inplace(
|
|||||||
compute_encoder.set_output_array(outputs[1], arg_idx++);
|
compute_encoder.set_output_array(outputs[1], arg_idx++);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
if (bopt == BinaryOpType::General) {
|
if (bopt == BinaryOpType::General) {
|
||||||
// Launch up to 3D grid of threads
|
// Launch up to 3D grid of threads
|
||||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||||
@ -132,7 +132,6 @@ void binary_op_gpu_inplace(
|
|||||||
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
|
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
|
||||||
}
|
}
|
||||||
|
|
||||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
|
||||||
if (thread_group_size != 1024) {
|
if (thread_group_size != 1024) {
|
||||||
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
||||||
}
|
}
|
||||||
@ -142,13 +141,12 @@ void binary_op_gpu_inplace(
|
|||||||
} else {
|
} else {
|
||||||
// Launch a 1D or 2D grid of threads
|
// Launch a 1D or 2D grid of threads
|
||||||
size_t nthreads = out.data_size();
|
size_t nthreads = out.data_size();
|
||||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
|
||||||
: MTL::Size(nthreads, 1, 1);
|
|
||||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
|
||||||
if (thread_group_size > nthreads) {
|
if (thread_group_size > nthreads) {
|
||||||
thread_group_size = nthreads;
|
thread_group_size = nthreads;
|
||||||
}
|
}
|
||||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||||
|
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||||
|
: MTL::Size(nthreads, 1, 1);
|
||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -421,11 +421,12 @@ void Compiled::eval_gpu(
|
|||||||
// Launch the kernel
|
// Launch the kernel
|
||||||
if (contiguous) {
|
if (contiguous) {
|
||||||
size_t nthreads = outputs[0].data_size();
|
size_t nthreads = outputs[0].data_size();
|
||||||
|
MTL::Size group_dims(
|
||||||
|
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
|
||||||
|
|
||||||
MTL::Size grid_dims = use_2d
|
MTL::Size grid_dims = use_2d
|
||||||
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
|
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
|
||||||
: MTL::Size(nthreads, 1, 1);
|
: MTL::Size(nthreads, 1, 1);
|
||||||
MTL::Size group_dims(
|
|
||||||
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
|
|
||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
} else {
|
} else {
|
||||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||||
|
@ -120,6 +120,7 @@ void copy_gpu_inplace(
|
|||||||
compute_encoder.set_input_array(donate_in ? out : in, 0, inp_offset);
|
compute_encoder.set_input_array(donate_in ? out : in, 0, inp_offset);
|
||||||
compute_encoder.set_output_array(out, 1, out_offset);
|
compute_encoder.set_output_array(out, 1, out_offset);
|
||||||
|
|
||||||
|
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||||
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
|
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
|
||||||
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
|
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
|
||||||
@ -145,7 +146,6 @@ void copy_gpu_inplace(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NB assuming thread_group_size is a power of 2 larger than 32 x 32
|
// NB assuming thread_group_size is a power of 2 larger than 32 x 32
|
||||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
|
||||||
if (thread_group_size != 1024) {
|
if (thread_group_size != 1024) {
|
||||||
throw std::runtime_error("[Metal::copy] Must use 1024 sized block");
|
throw std::runtime_error("[Metal::copy] Must use 1024 sized block");
|
||||||
}
|
}
|
||||||
@ -155,13 +155,12 @@ void copy_gpu_inplace(
|
|||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
} else {
|
} else {
|
||||||
size_t nthreads = out.data_size();
|
size_t nthreads = out.data_size();
|
||||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
|
||||||
: MTL::Size(nthreads, 1, 1);
|
|
||||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
|
||||||
if (thread_group_size > nthreads) {
|
if (thread_group_size > nthreads) {
|
||||||
thread_group_size = nthreads;
|
thread_group_size = nthreads;
|
||||||
}
|
}
|
||||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||||
|
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||||
|
: MTL::Size(nthreads, 1, 1);
|
||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -205,14 +204,14 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
|
|||||||
compute_encoder.set_input_array(val, 0);
|
compute_encoder.set_input_array(val, 0);
|
||||||
compute_encoder.set_output_array(out, 1);
|
compute_encoder.set_output_array(out, 1);
|
||||||
|
|
||||||
|
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
size_t nthreads = out.data_size();
|
size_t nthreads = out.data_size();
|
||||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
|
||||||
: MTL::Size(nthreads, 1, 1);
|
|
||||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
|
||||||
if (thread_group_size > nthreads) {
|
if (thread_group_size > nthreads) {
|
||||||
thread_group_size = nthreads;
|
thread_group_size = nthreads;
|
||||||
}
|
}
|
||||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||||
|
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||||
|
: MTL::Size(nthreads, 1, 1);
|
||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,6 +72,7 @@ void ternary_op_gpu_inplace(
|
|||||||
compute_encoder.set_input_array(donate_c ? out : c, 2);
|
compute_encoder.set_input_array(donate_c ? out : c, 2);
|
||||||
compute_encoder.set_output_array(out, 3);
|
compute_encoder.set_output_array(out, 3);
|
||||||
|
|
||||||
|
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
if (topt == TernaryOpType::General) {
|
if (topt == TernaryOpType::General) {
|
||||||
// Launch up to 3D grid of threads
|
// Launch up to 3D grid of threads
|
||||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||||
@ -93,7 +94,6 @@ void ternary_op_gpu_inplace(
|
|||||||
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
|
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
|
||||||
}
|
}
|
||||||
|
|
||||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
|
||||||
if (thread_group_size != 1024) {
|
if (thread_group_size != 1024) {
|
||||||
throw std::runtime_error("[Metal::ternary] Must use 1024 sized block");
|
throw std::runtime_error("[Metal::ternary] Must use 1024 sized block");
|
||||||
}
|
}
|
||||||
@ -103,13 +103,12 @@ void ternary_op_gpu_inplace(
|
|||||||
} else {
|
} else {
|
||||||
// Launch a 1D or 2D grid of threads
|
// Launch a 1D or 2D grid of threads
|
||||||
size_t nthreads = out.data_size();
|
size_t nthreads = out.data_size();
|
||||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
|
||||||
: MTL::Size(nthreads, 1, 1);
|
|
||||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
|
||||||
if (thread_group_size > nthreads) {
|
if (thread_group_size > nthreads) {
|
||||||
thread_group_size = nthreads;
|
thread_group_size = nthreads;
|
||||||
}
|
}
|
||||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||||
|
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||||
|
: MTL::Size(nthreads, 1, 1);
|
||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -47,9 +47,7 @@ void unary_op_gpu_inplace(
|
|||||||
kernel_name += "_" + op + type_to_name(in) + type_to_name(out);
|
kernel_name += "_" + op + type_to_name(in) + type_to_name(out);
|
||||||
auto kernel = get_unary_kernel(d, kernel_name, in.dtype(), out.dtype(), op);
|
auto kernel = get_unary_kernel(d, kernel_name, in.dtype(), out.dtype(), op);
|
||||||
|
|
||||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(in.shape(), in.strides())
|
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
: MTL::Size(nthreads, 1, 1);
|
|
||||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
compute_encoder.set_input_array(
|
compute_encoder.set_input_array(
|
||||||
@ -75,6 +73,8 @@ void unary_op_gpu_inplace(
|
|||||||
thread_group_size = nthreads;
|
thread_group_size = nthreads;
|
||||||
}
|
}
|
||||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||||
|
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||||
|
: MTL::Size(nthreads, 1, 1);
|
||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -103,6 +103,9 @@ MTL::Size get_2d_grid_dims(
|
|||||||
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
|
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
|
||||||
throw std::runtime_error("Unable to safely factor shape.");
|
throw std::runtime_error("Unable to safely factor shape.");
|
||||||
}
|
}
|
||||||
|
if (grid_y > grid_x) {
|
||||||
|
std::swap(grid_x, grid_y);
|
||||||
|
}
|
||||||
return MTL::Size(
|
return MTL::Size(
|
||||||
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
||||||
}
|
}
|
||||||
@ -145,6 +148,9 @@ MTL::Size get_2d_grid_dims(
|
|||||||
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) {
|
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) {
|
||||||
throw std::runtime_error("Unable to safely factor shape.");
|
throw std::runtime_error("Unable to safely factor shape.");
|
||||||
}
|
}
|
||||||
|
if (grid_y > grid_x) {
|
||||||
|
std::swap(grid_x, grid_y);
|
||||||
|
}
|
||||||
return MTL::Size(
|
return MTL::Size(
|
||||||
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user