Mirror of https://github.com/ml-explore/mlx.git, synced 2025-10-18 15:28:16 +08:00
fix for max block dim (#2631)
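In brief: JitModule now caches, per kernel, the block dimension that maximizes occupancy (queried once and stored alongside the CUfunction) and exposes it through a new get_kernel_and_dims accessor. get_launch_args gains an optional max_block_dim argument (defaulting to the previous 1024), and the compiled path passes the cached limit through to its kernel launches. A compile test over eight float16 inputs is added to cover the case.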
@@ -332,9 +332,9 @@ void Compiled::eval_gpu(
     encoder.set_output_array(out);
   }

-  auto kernel = mod.get_kernel(kernel_name);
+  auto [kernel, max_block_dims] = mod.get_kernel_and_dims(kernel_name);
   auto [num_blocks, block_dims] =
-      get_launch_args(outputs[0], large, work_per_thread);
+      get_launch_args(outputs[0], large, work_per_thread, max_block_dims);
   encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
 }

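Why the cap matters: a JIT-compiled fused kernel may use enough registers that it cannot be launched with the previous fixed 1024 threads per block, so the compiled path now asks the module for the kernel's occupancy-derived block-dim limit and forwards it to get_launch_args. As a rough standalone sketch (not MLX code; the helper name is made up, though both driver-API calls are real), a kernel's limits can be inspected like this:

```cpp
// Sketch only: inspecting a kernel's block-size limits with the CUDA
// driver API. `kernel` is assumed to be a CUfunction obtained from a
// loaded module, as in JitModule below.
#include <cuda.h>

void inspect_block_limits(CUfunction kernel) {
  // Hard upper bound: the most threads per block this kernel can be
  // launched with, given its register and shared-memory usage.
  int max_threads = 0;
  cuFuncGetAttribute(
      &max_threads, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, kernel);

  // Occupancy-based suggestion: the block size that maximizes occupancy;
  // this is the query the commit caches per kernel.
  int min_grid = 0, block_dim = 0;
  cuOccupancyMaxPotentialBlockSize(&min_grid, &block_dim, kernel, 0, 0, 0);

  // Any launch must use a block size <= max_threads; the suggested
  // block_dim already respects that limit.
}
```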
@@ -297,7 +297,8 @@ void load_module(
     const std::string& ptx,
     const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
     CUmodule& module_,
-    std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
+    std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>>&
+        kernels) {
   // Load module.
   char jit_log[4089] = {};
   CUjit_option options[] = {
@@ -314,7 +315,7 @@ void load_module(
   for (const auto& [name, mangled] : ptx_kernels) {
     CUfunction kernel;
     CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
-    kernels[name] = std::make_pair(kernel, false);
+    kernels[name] = std::make_tuple(kernel, false, 0);
   }
 }

@@ -358,7 +359,7 @@ JitModule::~JitModule() {
   CHECK_CUDA_ERROR(cuModuleUnload(module_));
 }

-CUfunction JitModule::get_kernel(
+std::pair<CUfunction, uint> JitModule::get_kernel_and_dims(
     const std::string& kernel_name,
     std::function<void(CUfunction)> configure_kernel) {
   auto it = kernels_.find(kernel_name);
@@ -369,14 +370,22 @@ CUfunction JitModule::get_kernel(

   // If it is the first time we run this kernel then configure it. Do it only
   // once!
-  if (!it->second.second) {
+  auto kernel = std::get<0>(it->second);
+  if (!std::get<1>(it->second)) {
     if (configure_kernel) {
-      configure_kernel(it->second.first);
+      configure_kernel(kernel);
     }
-    it->second.second = true;
+    std::get<1>(it->second) = true;
+    std::get<2>(it->second) = max_occupancy_block_dim(kernel);
   }

-  return it->second.first;
+  return {kernel, std::get<2>(it->second)};
+}
+
+CUfunction JitModule::get_kernel(
+    const std::string& kernel_name,
+    std::function<void(CUfunction)> configure_kernel) {
+  return get_kernel_and_dims(kernel_name, std::move(configure_kernel)).first;
 }

 std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
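Each cache entry now carries three fields: the CUfunction handle, a configured flag, and the occupancy-derived max block dim, computed once the first time the kernel is fetched and reused afterwards. A minimal sketch of that memoization pattern with placeholder names (not MLX's actual types; the occupancy query is stubbed out):

```cpp
// Illustrative memoization pattern: compute an expensive per-kernel
// property once and cache it next to the handle.
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>

using Handle = void*; // placeholder for CUfunction

struct KernelCache {
  // handle, configured?, cached max block dim
  std::unordered_map<std::string, std::tuple<Handle, bool, unsigned>> entries;

  std::pair<Handle, unsigned> get(const std::string& name) {
    auto& entry = entries.at(name);
    if (!std::get<1>(entry)) {
      // First use: do the one-time configuration and cache the limit
      // (the real code calls max_occupancy_block_dim here).
      std::get<1>(entry) = true;
      std::get<2>(entry) = 1024; // stub for the occupancy query
    }
    return {std::get<0>(entry), std::get<2>(entry)};
  }
};
```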
@@ -99,10 +99,13 @@ class JitModule {
   CUfunction get_kernel(
       const std::string& kernel_name,
       std::function<void(CUfunction)> configure_kernel = nullptr);
+  std::pair<CUfunction, uint> get_kernel_and_dims(
+      const std::string& kernel_name,
+      std::function<void(CUfunction)> configure_kernel = nullptr);

  private:
   CUmodule module_{nullptr};
-  std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
+  std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>> kernels_;
 };

 std::unordered_map<std::string, JitModule>& get_jit_module_cache();
@@ -35,12 +35,10 @@ std::tuple<dim3, uint> get_launch_args(
     const Shape& shape,
     const Strides& strides,
     bool large,
-    int work_per_thread) {
+    int work_per_thread /* = 1 */,
+    uint max_block_dim /* = 1024 */) {
   size_t nthreads = cuda::ceil_div(size, work_per_thread);
-  uint block_dim = 1024;
-  if (block_dim > nthreads) {
-    block_dim = nthreads;
-  }
+  uint block_dim = max_block_dim < nthreads ? max_block_dim : nthreads;
   dim3 num_blocks;
   if (large) {
     num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
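With this change block_dim is simply the smaller of max_block_dim and the number of threads, rather than a hard-coded 1024 clamped only by the thread count. A small standalone arithmetic sketch (sample numbers are assumed, not taken from MLX):

```cpp
// Standalone sketch of the new block-dim computation in get_launch_args.
#include <cstdio>

int main() {
  size_t size = 16384;          // elements in the output array
  int work_per_thread = 4;      // elements handled by each thread
  unsigned max_block_dim = 256; // assumed occupancy limit for the kernel

  size_t nthreads = (size + work_per_thread - 1) / work_per_thread; // 4096
  unsigned block_dim = max_block_dim < nthreads ? max_block_dim : nthreads;
  size_t num_blocks = (nthreads + block_dim - 1) / block_dim;

  // The old code would have picked block_dim = 1024 here regardless of the
  // kernel's own limit.
  std::printf("block_dim=%u num_blocks=%zu\n", block_dim, num_blocks); // 256, 16
  return 0;
}
```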
@@ -120,19 +120,28 @@ dim3 get_2d_grid_dims(
     size_t divisor);
 std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);

-// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
-// assuming each thread handles |work_per_thread| elements of |arr|.
+// Get the num_blocks and block_dims assuming each thread handles
+// |work_per_thread| elements of |arr|.
 std::tuple<dim3, uint> get_launch_args(
     size_t size,
     const Shape& shape,
     const Strides& strides,
     bool large,
-    int work_per_thread = 1);
+    int work_per_thread = 1,
+    uint max_block_dim = 1024);

-inline std::tuple<dim3, uint>
-get_launch_args(const array& arr, bool large, int work_per_thread = 1) {
+inline std::tuple<dim3, uint> get_launch_args(
+    const array& arr,
+    bool large,
+    int work_per_thread = 1,
+    uint max_block_dim = 1024) {
   return get_launch_args(
-      arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
+      arr.size(),
+      arr.shape(),
+      arr.strides(),
+      large,
+      work_per_thread,
+      max_block_dim);
 }

 } // namespace mlx::core
@@ -12,6 +12,7 @@ namespace mlx::core {

 namespace cu {
 class Device;

 }

 struct Dtype;
@@ -86,4 +87,17 @@ class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
   explicit CudaStream(cu::Device& device);
 };

+template <typename T>
+inline uint max_occupancy_block_dim(T kernel) {
+  int _, block_dim;
+  if constexpr (std::is_same_v<T, CUfunction>) {
+    CHECK_CUDA_ERROR(
+        cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
+  } else {
+    CHECK_CUDA_ERROR(
+        cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
+  }
+  return block_dim;
+}
+
 } // namespace mlx::core
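The helper dispatches on the handle type: the driver-API query (cuOccupancyMaxPotentialBlockSize) for CUfunction handles loaded from PTX by JitModule, and the runtime-API template otherwise. For reference, a minimal self-contained runtime-API example with a toy kernel (illustrative only, not MLX code):

```cpp
// Toy CUDA program: ask the runtime for the block size that maximizes
// occupancy for a simple kernel (compile with nvcc).
#include <cstdio>
#include <cuda_runtime.h>

__global__ void axpy(float a, const float* x, float* y, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    y[i] = a * x[i] + y[i];
  }
}

int main() {
  int min_grid = 0, block_dim = 0;
  // Returns a block size that respects the kernel's register and
  // shared-memory limits, just like max_occupancy_block_dim above.
  cudaOccupancyMaxPotentialBlockSize(&min_grid, &block_dim, axpy);
  std::printf("suggested block_dim = %d\n", block_dim);
  return 0;
}
```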
@@ -828,6 +828,19 @@ class TestCompile(mlx_tests.MLXTestCase):
         out = fun(arrs)
         self.assertTrue(mx.allclose(out, mx.array([64.0, 128.0])))

+        inputs = [mx.arange(16384).astype(mx.float16) for _ in range(8)]
+
+        def fun(inputs):
+            a = inputs[0] + inputs[1]
+            b = inputs[2] + inputs[3]
+            c = inputs[4] + inputs[5]
+            d = inputs[6] + inputs[7]
+            return a * b * c * d
+
+        out = mx.compile(fun)(inputs)
+        expected = fun(inputs)
+        self.assertTrue(mx.allclose(out, expected))
+
     def test_compile_many_outputs(self):

         @mx.compile
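The added test compiles a fused elementwise graph over eight float16 inputs. Presumably the fused kernel generated for it needs enough registers per thread that its maximum block size falls below 1024, so with the old hard-coded block dim the launch configuration would have been invalid; the test therefore serves as a regression check for the capped launch dimensions.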