fix for max block dim (#2631)

Awni Hannun
2025-09-29 08:59:25 -07:00
committed by GitHub
parent e76a8dd5c5
commit dc371ae7a5
7 changed files with 67 additions and 21 deletions

@@ -332,9 +332,9 @@ void Compiled::eval_gpu(
     encoder.set_output_array(out);
   }
-  auto kernel = mod.get_kernel(kernel_name);
+  auto [kernel, max_block_dims] = mod.get_kernel_and_dims(kernel_name);
   auto [num_blocks, block_dims] =
-      get_launch_args(outputs[0], large, work_per_thread);
+      get_launch_args(outputs[0], large, work_per_thread, max_block_dims);
   encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
 }
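What this hunk guards against: a JIT-compiled fused kernel can use enough registers or shared memory per thread that it cannot launch 1024-thread blocks, so a hard-coded block size of 1024 could make the launch fail. The occupancy-derived limit returned by get_kernel_and_dims is therefore threaded into get_launch_args. As a standalone illustration only (not MLX code; kernel is assumed to be a valid CUfunction obtained from cuModuleGetFunction, and error checking is omitted), the relevant per-kernel limits can be queried through the CUDA driver API like this:

#include <cuda.h>
#include <cstdio>

void print_kernel_limits(CUfunction kernel) {
  // Hard upper bound on threads per block for this kernel
  // (limited by its register / shared-memory usage).
  int max_threads = 0;
  cuFuncGetAttribute(&max_threads, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, kernel);

  // Block size that maximizes occupancy; never exceeds max_threads.
  int min_grid = 0, occ_block = 0;
  cuOccupancyMaxPotentialBlockSize(&min_grid, &occ_block, kernel, 0, 0, 0);

  std::printf("max threads per block: %d, occupancy-optimal block: %d\n",
              max_threads, occ_block);
}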

@@ -297,7 +297,8 @@ void load_module(
     const std::string& ptx,
     const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
     CUmodule& module_,
-    std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
+    std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>>&
+        kernels) {
   // Load module.
   char jit_log[4089] = {};
   CUjit_option options[] = {
@@ -314,7 +315,7 @@ void load_module(
   for (const auto& [name, mangled] : ptx_kernels) {
     CUfunction kernel;
     CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
-    kernels[name] = std::make_pair(kernel, false);
+    kernels[name] = std::make_tuple(kernel, false, 0);
   }
 }
@@ -358,7 +359,7 @@ JitModule::~JitModule() {
   CHECK_CUDA_ERROR(cuModuleUnload(module_));
 }
 
-CUfunction JitModule::get_kernel(
+std::pair<CUfunction, uint> JitModule::get_kernel_and_dims(
     const std::string& kernel_name,
     std::function<void(CUfunction)> configure_kernel) {
   auto it = kernels_.find(kernel_name);
@@ -369,14 +370,22 @@ CUfunction JitModule::get_kernel(
   // If it is the first time we run this kernel then configure it. Do it only
   // once!
-  if (!it->second.second) {
+  auto kernel = std::get<0>(it->second);
+  if (!std::get<1>(it->second)) {
     if (configure_kernel) {
-      configure_kernel(it->second.first);
+      configure_kernel(kernel);
     }
-    it->second.second = true;
+    std::get<1>(it->second) = true;
+    std::get<2>(it->second) = max_occupancy_block_dim(kernel);
   }
-  return it->second.first;
+  return {kernel, std::get<2>(it->second)};
+}
+
+CUfunction JitModule::get_kernel(
+    const std::string& kernel_name,
+    std::function<void(CUfunction)> configure_kernel) {
+  return get_kernel_and_dims(kernel_name, std::move(configure_kernel)).first;
 }
 
 std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
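For readers less familiar with the tuple indices, the cache entry above is (CUfunction handle, configured flag, occupancy-derived max block dim), and the last two fields are filled in lazily the first time a kernel is requested. A minimal self-contained sketch of that same lazy-initialization pattern, using a hypothetical stand-in handle type rather than the real CUfunction:

#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>

using FakeKernel = int;  // stand-in for CUfunction in this sketch

// Stand-in for max_occupancy_block_dim; a real version queries the
// CUDA occupancy API.
inline unsigned fake_max_block_dim(FakeKernel) {
  return 256;
}

struct KernelCache {
  std::unordered_map<std::string, std::tuple<FakeKernel, bool, unsigned>> kernels_;

  std::pair<FakeKernel, unsigned> get_kernel_and_dims(const std::string& name) {
    auto& entry = kernels_.at(name);
    auto kernel = std::get<0>(entry);
    if (!std::get<1>(entry)) {
      // First use: mark configured and cache the max block dim once.
      std::get<1>(entry) = true;
      std::get<2>(entry) = fake_max_block_dim(kernel);
    }
    return {kernel, std::get<2>(entry)};
  }
};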

@@ -99,10 +99,13 @@ class JitModule {
   CUfunction get_kernel(
       const std::string& kernel_name,
       std::function<void(CUfunction)> configure_kernel = nullptr);
+  std::pair<CUfunction, uint> get_kernel_and_dims(
+      const std::string& kernel_name,
+      std::function<void(CUfunction)> configure_kernel = nullptr);
 
  private:
   CUmodule module_{nullptr};
-  std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
+  std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>> kernels_;
 };
 
 std::unordered_map<std::string, JitModule>& get_jit_module_cache();

@@ -35,12 +35,10 @@ std::tuple<dim3, uint> get_launch_args(
     const Shape& shape,
     const Strides& strides,
     bool large,
-    int work_per_thread) {
+    int work_per_thread /* = 1 */,
+    uint max_block_dim /* = 1024 */) {
   size_t nthreads = cuda::ceil_div(size, work_per_thread);
-  uint block_dim = 1024;
-  if (block_dim > nthreads) {
-    block_dim = nthreads;
-  }
+  uint block_dim = max_block_dim < nthreads ? max_block_dim : nthreads;
   dim3 num_blocks;
   if (large) {
     num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
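The behavioral change is simply that the block dimension is now clamped to min(max_block_dim, nthreads) instead of min(1024, nthreads). A small self-contained sketch (hypothetical helper names, not the MLX functions) showing what that computes:

#include <cstddef>
#include <cstdio>

// Hypothetical stand-ins for cuda::ceil_div and the clamping logic above.
static size_t ceil_div_sz(size_t a, size_t b) {
  return (a + b - 1) / b;
}

static unsigned launch_block_dim(size_t size, int work_per_thread, unsigned max_block_dim) {
  size_t nthreads = ceil_div_sz(size, work_per_thread);
  return max_block_dim < nthreads ? max_block_dim
                                  : static_cast<unsigned>(nthreads);
}

int main() {
  // A register-heavy fused kernel may only support 256-thread blocks;
  // the old hard-coded 1024 would have exceeded that limit.
  std::printf("%u\n", launch_block_dim(16384, 1, 256)); // 256
  std::printf("%u\n", launch_block_dim(100, 1, 1024));  // 100
  return 0;
}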

@@ -120,19 +120,28 @@ dim3 get_2d_grid_dims(
     size_t divisor);
 
 std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
 
-// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
-// assuming each thread handles |work_per_thread| elements of |arr|.
+// Get the num_blocks and block_dims assuming each thread handles
+// |work_per_thread| elements of |arr|.
 std::tuple<dim3, uint> get_launch_args(
     size_t size,
     const Shape& shape,
     const Strides& strides,
     bool large,
-    int work_per_thread = 1);
+    int work_per_thread = 1,
+    uint max_block_dim = 1024);
 
-inline std::tuple<dim3, uint>
-get_launch_args(const array& arr, bool large, int work_per_thread = 1) {
+inline std::tuple<dim3, uint> get_launch_args(
+    const array& arr,
+    bool large,
+    int work_per_thread = 1,
+    uint max_block_dim = 1024) {
   return get_launch_args(
-      arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
+      arr.size(),
+      arr.shape(),
+      arr.strides(),
+      large,
+      work_per_thread,
+      max_block_dim);
 }
 
 } // namespace mlx::core

@@ -12,6 +12,7 @@ namespace mlx::core {
 
 namespace cu {
 class Device;
 }
 
 struct Dtype;
@@ -86,4 +87,17 @@ class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
   explicit CudaStream(cu::Device& device);
 };
 
+template <typename T>
+inline uint max_occupancy_block_dim(T kernel) {
+  int _, block_dim;
+  if constexpr (std::is_same_v<T, CUfunction>) {
+    CHECK_CUDA_ERROR(
+        cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
+  } else {
+    CHECK_CUDA_ERROR(
+        cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
+  }
+  return block_dim;
+}
+
 } // namespace mlx::core
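The new max_occupancy_block_dim helper picks the driver-API occupancy query for CUfunction handles (kernels loaded from JIT-compiled modules) and the runtime-API template overload otherwise. A standalone runtime-API example, independent of MLX (scale_kernel is a made-up kernel), that performs the same query and then launches with the block size clamped to the amount of work:

#include <cuda_runtime.h>
#include <algorithm>
#include <cstdio>

// Hypothetical kernel used only for this illustration.
__global__ void scale_kernel(float* x, float s, size_t n) {
  size_t i = blockIdx.x * static_cast<size_t>(blockDim.x) + threadIdx.x;
  if (i < n) {
    x[i] *= s;
  }
}

int main() {
  // Ask the runtime for the occupancy-maximizing block size.
  int min_grid = 0, block_dim = 0;
  cudaOccupancyMaxPotentialBlockSize(&min_grid, &block_dim, scale_kernel);

  size_t n = 1 << 20;
  float* x = nullptr;
  cudaMalloc(&x, n * sizeof(float));

  // Clamp to the work size, mirroring get_launch_args.
  int threads = static_cast<int>(std::min<size_t>(block_dim, n));
  int blocks = static_cast<int>((n + threads - 1) / threads);
  scale_kernel<<<blocks, threads>>>(x, 2.0f, n);
  cudaDeviceSynchronize();

  std::printf("occupancy-optimal block size: %d\n", block_dim);
  cudaFree(x);
  return 0;
}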

@@ -828,6 +828,19 @@ class TestCompile(mlx_tests.MLXTestCase):
         out = fun(arrs)
         self.assertTrue(mx.allclose(out, mx.array([64.0, 128.0])))
 
+        inputs = [mx.arange(16384).astype(mx.float16) for _ in range(8)]
+
+        def fun(inputs):
+            a = inputs[0] + inputs[1]
+            b = inputs[2] + inputs[3]
+            c = inputs[4] + inputs[5]
+            d = inputs[6] + inputs[7]
+            return a * b * c * d
+
+        out = mx.compile(fun)(inputs)
+        expected = fun(inputs)
+        self.assertTrue(mx.allclose(out, expected))
+
     def test_compile_many_outputs(self):
         @mx.compile