mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 07:18:29 +08:00
fix for max block dim (#2631)
This commit is contained in:
@@ -332,9 +332,9 @@ void Compiled::eval_gpu(
|
||||
encoder.set_output_array(out);
|
||||
}
|
||||
|
||||
auto kernel = mod.get_kernel(kernel_name);
|
||||
auto [kernel, max_block_dims] = mod.get_kernel_and_dims(kernel_name);
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(outputs[0], large, work_per_thread);
|
||||
get_launch_args(outputs[0], large, work_per_thread, max_block_dims);
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||
}
|
||||
|
||||
|
@@ -297,7 +297,8 @@ void load_module(
|
||||
const std::string& ptx,
|
||||
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
|
||||
CUmodule& module_,
|
||||
std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
|
||||
std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>>&
|
||||
kernels) {
|
||||
// Load module.
|
||||
char jit_log[4089] = {};
|
||||
CUjit_option options[] = {
|
||||
@@ -314,7 +315,7 @@ void load_module(
|
||||
for (const auto& [name, mangled] : ptx_kernels) {
|
||||
CUfunction kernel;
|
||||
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
|
||||
kernels[name] = std::make_pair(kernel, false);
|
||||
kernels[name] = std::make_tuple(kernel, false, 0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -358,7 +359,7 @@ JitModule::~JitModule() {
|
||||
CHECK_CUDA_ERROR(cuModuleUnload(module_));
|
||||
}
|
||||
|
||||
CUfunction JitModule::get_kernel(
|
||||
std::pair<CUfunction, uint> JitModule::get_kernel_and_dims(
|
||||
const std::string& kernel_name,
|
||||
std::function<void(CUfunction)> configure_kernel) {
|
||||
auto it = kernels_.find(kernel_name);
|
||||
@@ -369,14 +370,22 @@ CUfunction JitModule::get_kernel(
|
||||
|
||||
// If it is the first time we run this kernel then configure it. Do it only
|
||||
// once!
|
||||
if (!it->second.second) {
|
||||
auto kernel = std::get<0>(it->second);
|
||||
if (!std::get<1>(it->second)) {
|
||||
if (configure_kernel) {
|
||||
configure_kernel(it->second.first);
|
||||
configure_kernel(kernel);
|
||||
}
|
||||
it->second.second = true;
|
||||
std::get<1>(it->second) = true;
|
||||
std::get<2>(it->second) = max_occupancy_block_dim(kernel);
|
||||
}
|
||||
|
||||
return it->second.first;
|
||||
return {kernel, std::get<2>(it->second)};
|
||||
}
|
||||
|
||||
CUfunction JitModule::get_kernel(
|
||||
const std::string& kernel_name,
|
||||
std::function<void(CUfunction)> configure_kernel) {
|
||||
return get_kernel_and_dims(kernel_name, std::move(configure_kernel)).first;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
|
||||
|
@@ -99,10 +99,13 @@ class JitModule {
|
||||
CUfunction get_kernel(
|
||||
const std::string& kernel_name,
|
||||
std::function<void(CUfunction)> configure_kernel = nullptr);
|
||||
std::pair<CUfunction, uint> get_kernel_and_dims(
|
||||
const std::string& kernel_name,
|
||||
std::function<void(CUfunction)> configure_kernel = nullptr);
|
||||
|
||||
private:
|
||||
CUmodule module_{nullptr};
|
||||
std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
|
||||
std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>> kernels_;
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, JitModule>& get_jit_module_cache();
|
||||
|
@@ -35,12 +35,10 @@ std::tuple<dim3, uint> get_launch_args(
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
bool large,
|
||||
int work_per_thread) {
|
||||
int work_per_thread /* = 1 */,
|
||||
uint max_block_dim /* = 1024 */) {
|
||||
size_t nthreads = cuda::ceil_div(size, work_per_thread);
|
||||
uint block_dim = 1024;
|
||||
if (block_dim > nthreads) {
|
||||
block_dim = nthreads;
|
||||
}
|
||||
uint block_dim = max_block_dim < nthreads ? max_block_dim : nthreads;
|
||||
dim3 num_blocks;
|
||||
if (large) {
|
||||
num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
|
||||
|
@@ -120,19 +120,28 @@ dim3 get_2d_grid_dims(
|
||||
size_t divisor);
|
||||
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
|
||||
|
||||
// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
|
||||
// assuming each thread handles |work_per_thread| elements of |arr|.
|
||||
// Get the num_blocks and block_dims assuming each thread handles
|
||||
// |work_per_thread| elements of |arr|.
|
||||
std::tuple<dim3, uint> get_launch_args(
|
||||
size_t size,
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
bool large,
|
||||
int work_per_thread = 1);
|
||||
int work_per_thread = 1,
|
||||
uint max_block_dim = 1024);
|
||||
|
||||
inline std::tuple<dim3, uint>
|
||||
get_launch_args(const array& arr, bool large, int work_per_thread = 1) {
|
||||
inline std::tuple<dim3, uint> get_launch_args(
|
||||
const array& arr,
|
||||
bool large,
|
||||
int work_per_thread = 1,
|
||||
uint max_block_dim = 1024) {
|
||||
return get_launch_args(
|
||||
arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
|
||||
arr.size(),
|
||||
arr.shape(),
|
||||
arr.strides(),
|
||||
large,
|
||||
work_per_thread,
|
||||
max_block_dim);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -12,6 +12,7 @@ namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
class Device;
|
||||
|
||||
}
|
||||
|
||||
struct Dtype;
|
||||
@@ -86,4 +87,17 @@ class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
|
||||
explicit CudaStream(cu::Device& device);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline uint max_occupancy_block_dim(T kernel) {
|
||||
int _, block_dim;
|
||||
if constexpr (std::is_same_v<T, CUfunction>) {
|
||||
CHECK_CUDA_ERROR(
|
||||
cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
|
||||
} else {
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
|
||||
}
|
||||
return block_dim;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -828,6 +828,19 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
out = fun(arrs)
|
||||
self.assertTrue(mx.allclose(out, mx.array([64.0, 128.0])))
|
||||
|
||||
inputs = [mx.arange(16384).astype(mx.float16) for _ in range(8)]
|
||||
|
||||
def fun(inputs):
|
||||
a = inputs[0] + inputs[1]
|
||||
b = inputs[2] + inputs[3]
|
||||
c = inputs[4] + inputs[5]
|
||||
d = inputs[6] + inputs[7]
|
||||
return a * b * c * d
|
||||
|
||||
out = mx.compile(fun)(inputs)
|
||||
expected = fun(inputs)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
def test_compile_many_outputs(self):
|
||||
|
||||
@mx.compile
|
||||
|
Reference in New Issue
Block a user