Working custom kernels jointly

This commit is contained in:
Angelos Katharopoulos 2025-08-12 14:30:29 -07:00
parent 0b309e8edc
commit 3b94e37270
8 changed files with 425 additions and 15 deletions

View File

@ -234,11 +234,47 @@ MetalKernelFunction metal_kernel(
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value),
init_value,
std::vector<ScalarArg>{},
false,
0),
std::move(inputs));
};
}
std::vector<array> precompiled_custom_kernel(
const std::string& name,
const std::string& compiled_source,
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
const std::vector<ScalarArg>& scalars,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
int shared_memory,
std::optional<float> init_value,
bool ensure_row_contiguous,
StreamOrDevice s) {
std::vector<CustomKernelShapeInfo> shape_infos(
inputs.size(), CustomKernelShapeInfo{false, false, false});
return array::make_arrays(
output_shapes,
output_dtypes,
std::make_shared<CustomKernel>(
to_stream(s),
name,
compiled_source,
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value,
scalars,
true,
shared_memory),
inputs);
}
void CustomKernel::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
@ -274,10 +310,19 @@ void CustomKernel::eval_gpu(
}
// Compile the custom kernel
std::string kernel_name = "mlx::core::cu::" + name_;
cu::JitModule& mod = cu::get_jit_module(s.device, name_, [&]() {
return std::make_pair(source_, std::vector<std::string>{kernel_name});
});
std::string kernel_name =
(is_precompiled_) ? name_ : "mlx::core::cu::" + name_;
auto get_module = [&]() -> cu::JitModule& {
if (is_precompiled_) {
return cu::get_jit_module(
s.device, name_, source_, std::vector<std::string>{kernel_name});
} else {
return cu::get_jit_module(s.device, name_, [&]() {
return std::make_pair(source_, std::vector<std::string>{kernel_name});
});
}
};
cu::JitModule& mod = get_module();
// Make the arguments
cu::KernelArgs args;
@ -298,6 +343,15 @@ void CustomKernel::eval_gpu(
for (auto& out : outputs) {
args.append(out);
}
for (auto& s : scalar_arguments_) {
if (std::holds_alternative<bool>(s)) {
args.append(std::get<bool>(s));
} else if (std::holds_alternative<int>(s)) {
args.append(std::get<int>(s));
} else if (std::holds_alternative<float>(s)) {
args.append(std::get<float>(s));
}
}
// Make the grid
const auto [tx, ty, tz] = threadgroup_;
@ -313,8 +367,17 @@ void CustomKernel::eval_gpu(
for (const auto& out : outputs) {
encoder.set_output_array(out);
}
for (const auto& t : copies) {
encoder.add_temporary(t);
}
auto kernel = mod.get_kernel(kernel_name);
encoder.add_kernel_node(kernel, grid, block, 0, args.args());
if (shared_memory_ > 0 && shared_memory_ > 48000) {
cuFuncSetAttribute(
kernel,
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_memory_);
}
encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args());
}
} // namespace mlx::core::fast

View File

@ -10,6 +10,8 @@
#include <future>
#include <unordered_set>
#include <iostream>
namespace mlx::core::cu {
namespace {
@ -249,9 +251,201 @@ void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
insert_graph_dependencies(GraphNode{node, 'K'});
}
void debugCuGraphAddKernelNode(
CUgraphNode* node,
cudaGraph_t graph,
const CUDA_KERNEL_NODE_PARAMS* params) {
std::cout << "=== Debugging cuGraphAddKernelNode ===" << std::endl;
// Check graph
if (graph == nullptr) {
std::cout << "ERROR: graph is NULL" << std::endl;
return;
}
// Check params structure
if (params == nullptr) {
std::cout << "ERROR: params is NULL" << std::endl;
return;
}
// Check kernel function
if (params->func == nullptr) {
std::cout << "ERROR: kernel function (CUfunction) is NULL" << std::endl;
return;
}
// Validate kernel function and get attributes
int maxThreadsPerBlock;
CUresult funcErr = cuFuncGetAttribute(
&maxThreadsPerBlock,
CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK,
params->func);
if (funcErr != CUDA_SUCCESS) {
const char* errStr;
cuGetErrorString(funcErr, &errStr);
std::cout << "ERROR: Invalid kernel function - " << errStr << std::endl;
return;
}
// Get more function attributes
int sharedSize, constSize, localSize, numRegs;
cuFuncGetAttribute(
&sharedSize, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, params->func);
cuFuncGetAttribute(
&constSize, CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES, params->func);
cuFuncGetAttribute(
&localSize, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, params->func);
cuFuncGetAttribute(&numRegs, CU_FUNC_ATTRIBUTE_NUM_REGS, params->func);
std::cout << "Kernel function attributes:" << std::endl;
std::cout << " Max threads per block: " << maxThreadsPerBlock << std::endl;
std::cout << " Shared memory: " << sharedSize << " bytes" << std::endl;
std::cout << " Const memory: " << constSize << " bytes" << std::endl;
std::cout << " Local memory: " << localSize << " bytes" << std::endl;
std::cout << " Num regs: " << numRegs << std::endl;
// Check dimensions
std::cout << "\nGrid dimensions: (" << params->gridDimX << ", "
<< params->gridDimY << ", " << params->gridDimZ << ")" << std::endl;
std::cout << "Block dimensions: (" << params->blockDimX << ", "
<< params->blockDimY << ", " << params->blockDimZ << ")"
<< std::endl;
if (params->gridDimX * params->gridDimY * params->gridDimZ == 0) {
std::cout << "ERROR: Grid dimension contains zero!" << std::endl;
return;
}
if (params->blockDimX * params->blockDimY * params->blockDimZ == 0) {
std::cout << "ERROR: Block dimension contains zero!" << std::endl;
return;
}
// Get current device and check limits
CUdevice device;
cuCtxGetDevice(&device);
int maxGridX, maxGridY, maxGridZ;
int maxBlockX, maxBlockY, maxBlockZ;
int maxThreadsPerBlockDevice;
int maxSharedMemPerBlock;
cuDeviceGetAttribute(&maxGridX, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, device);
cuDeviceGetAttribute(&maxGridY, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y, device);
cuDeviceGetAttribute(&maxGridZ, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z, device);
cuDeviceGetAttribute(&maxBlockX, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X, device);
cuDeviceGetAttribute(&maxBlockY, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y, device);
cuDeviceGetAttribute(&maxBlockZ, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z, device);
cuDeviceGetAttribute(
&maxThreadsPerBlockDevice,
CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK,
device);
cuDeviceGetAttribute(
&maxSharedMemPerBlock,
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK,
device);
std::cout << "\nDevice limits check:" << std::endl;
std::cout << " Max grid size: (" << maxGridX << ", " << maxGridY << ", "
<< maxGridZ << ")" << std::endl;
std::cout << " Max block size: (" << maxBlockX << ", " << maxBlockY << ", "
<< maxBlockZ << ")" << std::endl;
std::cout << " Max threads per block: " << maxThreadsPerBlockDevice
<< std::endl;
// Check if dimensions exceed limits
if (params->gridDimX > (unsigned)maxGridX ||
params->gridDimY > (unsigned)maxGridY ||
params->gridDimZ > (unsigned)maxGridZ) {
std::cout << "ERROR: Grid dimensions exceed device limits!" << std::endl;
}
if (params->blockDimX > (unsigned)maxBlockX ||
params->blockDimY > (unsigned)maxBlockY ||
params->blockDimZ > (unsigned)maxBlockZ) {
std::cout << "ERROR: Block dimensions exceed device limits!" << std::endl;
}
unsigned totalThreadsPerBlock =
params->blockDimX * params->blockDimY * params->blockDimZ;
if (totalThreadsPerBlock > (unsigned)maxThreadsPerBlockDevice) {
std::cout << "ERROR: Total threads per block (" << totalThreadsPerBlock
<< ") exceeds limit (" << maxThreadsPerBlockDevice << ")"
<< std::endl;
}
// Check shared memory
std::cout << "\nShared memory requested: " << params->sharedMemBytes
<< " bytes" << std::endl;
std::cout << "Max shared memory per block: " << maxSharedMemPerBlock
<< " bytes" << std::endl;
if (params->sharedMemBytes > (unsigned)maxSharedMemPerBlock) {
std::cout << "ERROR: Requested shared memory exceeds limit!" << std::endl;
}
// Check kernel parameters
std::cout << "\nKernel parameters:" << std::endl;
std::cout << " kernelParams pointer: " << std::hex << params->kernelParams
<< std::dec << std::endl;
std::cout << " extra pointer: " << std::hex << params->extra << std::dec
<< std::endl;
if (params->kernelParams == nullptr && params->extra == nullptr) {
std::cout
<< "WARNING: Both kernelParams and extra are NULL (no arguments to kernel)"
<< std::endl;
}
// If using kernelParams, try to print the array
if (params->kernelParams != nullptr) {
std::cout << " Kernel parameter pointers:" << std::endl;
for (int i = 0; i < 10; i++) { // Check first 10 slots (adjust as needed)
if (params->kernelParams[i] == nullptr) {
std::cout << " [" << i << "]: NULL (end of params)" << std::endl;
break;
}
std::cout << " [" << i << "]: " << std::hex << params->kernelParams[i]
<< std::dec << std::endl;
}
}
// Try to add the node
std::cout << "\nAttempting to add kernel node..." << std::endl;
CUresult err = cuGraphAddKernelNode(node, graph, NULL, 0, params);
if (err != CUDA_SUCCESS) {
const char* errStr;
cuGetErrorString(err, &errStr);
std::cout << "ERROR: " << errStr << " (code: " << err << ")" << std::endl;
// Additional hints based on error code
if (err == CUDA_ERROR_INVALID_VALUE) {
std::cout << "\nHints for 'invalid argument':" << std::endl;
std::cout
<< " - Check if CUDA_KERNEL_NODE_PARAMS struct is properly initialized"
<< std::endl;
std::cout << " - Verify CUfunction handle is valid" << std::endl;
std::cout << " - Ensure grid/block dimensions are non-zero" << std::endl;
std::cout << " - Check kernelParams array is properly set up"
<< std::endl;
std::cout << " - Verify context is current" << std::endl;
}
} else {
std::cout << "SUCCESS: Kernel node added to graph!" << std::endl;
}
std::cout << "=== End Debug ===" << std::endl;
}
void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
CUgraphNode node;
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, &params));
// CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, &params));
debugCuGraphAddKernelNode(&node, graph_, &params);
insert_graph_dependencies(GraphNode{node, 'K'});
}

View File

@ -316,6 +316,31 @@ JitModule::JitModule(
}
}
JitModule::JitModule(
Device& device,
const std::string& module_name,
const std::string& ptx,
const std::vector<std::string>& kernel_names) {
// Load module.
char jit_log[4089] = {};
CUjit_option options[] = {
CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};
void* values[] = {jit_log, reinterpret_cast<void*>(std::size(jit_log) - 1)};
CUresult jit_result = cuModuleLoadDataEx(
&module_, ptx.c_str(), std::size(options), options, values);
if (jit_result != CUDA_SUCCESS) {
throw std::runtime_error(fmt::format(
"Failed to load compiled {} kernel: {}.", module_name, jit_log));
}
// Load kernels.
for (const auto& name : kernel_names) {
CUfunction kernel;
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, name.c_str()));
kernels_[name] = kernel;
}
}
JitModule::~JitModule() {
CHECK_CUDA_ERROR(cuModuleUnload(module_));
}
@ -346,4 +371,18 @@ JitModule& get_jit_module(
return it->second;
}
JitModule& get_jit_module(
const mlx::core::Device& device,
const std::string& name,
const std::string& ptx,
const std::vector<std::string>& kernel_names) {
auto& map = get_jit_module_cache();
auto it = map.find(name);
if (it == map.end()) {
it = map.try_emplace(name, cu::device(device), name, ptx, kernel_names)
.first;
}
return it->second;
}
} // namespace mlx::core::cu

View File

@ -68,9 +68,11 @@ struct KernelArgs {
using Arg = std::variant<
std::monostate,
CUdeviceptr,
bool,
int32_t,
uint32_t,
int64_t,
float,
SmallVector<const void*>,
SmallVector<int32_t>,
SmallVector<int64_t>>;
@ -83,6 +85,11 @@ class JitModule {
Device& device,
const std::string& module_name,
const KernelBuilder& builder);
JitModule(
Device& device,
const std::string& module_name,
const std::string& ptx,
const std::vector<std::string>& kernel_names);
~JitModule();
JitModule(const JitModule&) = delete;
@ -101,4 +108,10 @@ JitModule& get_jit_module(
const std::string& name,
const KernelBuilder& builder);
JitModule& get_jit_module(
const mlx::core::Device& device,
const std::string& name,
const std::string& ptx,
const std::vector<std::string>& kernel_names);
} // namespace mlx::core::cu

View File

@ -316,7 +316,10 @@ MetalKernelFunction metal_kernel(
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value),
init_value,
std::vector<ScalarArg>{},
false,
0),
std::move(inputs));
};
}

View File

@ -66,9 +66,10 @@ array affine_dequantize(
int bits = 4,
StreamOrDevice s = {});
typedef std::variant<int, bool, Dtype> TemplateArg;
using TemplateArg = std::variant<int, bool, Dtype>;
using ScalarArg = std::variant<bool, int, float>;
typedef std::function<std::vector<array>(
using MetalKernelFunction = std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<Shape>&,
const std::vector<Dtype>&,
@ -77,8 +78,7 @@ typedef std::function<std::vector<array>(
std::vector<std::pair<std::string, TemplateArg>>,
std::optional<float>,
bool,
StreamOrDevice)>
MetalKernelFunction;
StreamOrDevice)>;
MetalKernelFunction metal_kernel(
const std::string& name,
@ -89,4 +89,18 @@ MetalKernelFunction metal_kernel(
bool ensure_row_contiguous = true,
bool atomic_outputs = false);
std::vector<array> precompiled_custom_kernel(
const std::string& name,
const std::string& compiled_source,
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
const std::vector<ScalarArg>& scalars,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
int shared_memory = 0,
std::optional<float> init_value = std::nullopt,
bool ensure_row_contiguous = false,
StreamOrDevice s = {});
} // namespace mlx::core::fast

View File

@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.
#include <optional>
#include <variant>
#include "mlx/primitives.h"
@ -283,6 +284,8 @@ struct CustomKernelShapeInfo {
bool ndim = false;
};
using ScalarArg = std::variant<bool, int, float>;
class CustomKernel : public Primitive {
public:
CustomKernel(
@ -293,7 +296,10 @@ class CustomKernel : public Primitive {
std::tuple<int, int, int> threadgroup,
std::vector<CustomKernelShapeInfo> shape_infos,
bool ensure_row_contiguous,
std::optional<float> init_value)
std::optional<float> init_value,
std::vector<ScalarArg> scalar_arguments,
bool is_precompiled,
int shared_memory)
: Primitive(stream),
source_(std::move(source)),
name_(std::move(name)),
@ -301,11 +307,14 @@ class CustomKernel : public Primitive {
threadgroup_(threadgroup),
shape_infos_(std::move(shape_infos)),
ensure_row_contiguous_(ensure_row_contiguous),
init_value_(init_value) {}
init_value_(init_value),
scalar_arguments_(scalar_arguments),
is_precompiled_(is_precompiled),
shared_memory_(shared_memory) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
throw std::runtime_error("Custom Metal kernels only run on GPU.");
throw std::runtime_error("Custom kernels only run on GPU.");
}
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@ -321,6 +330,9 @@ class CustomKernel : public Primitive {
std::vector<CustomKernelShapeInfo> shape_infos_;
bool ensure_row_contiguous_;
std::optional<float> init_value_;
std::vector<ScalarArg> scalar_arguments_;
bool is_precompiled_;
int shared_memory_;
};
} // namespace mlx::core::fast

View File

@ -384,4 +384,76 @@ void init_fast(nb::module_& parent_module) {
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
)pbdoc");
m.def(
"precompiled_custom_kernel",
[](const std::string& name,
const std::string& compiled_source,
const std::vector<ScalarOrArray>& inputs_,
const std::vector<mx::Shape>& output_shapes,
const std::vector<mx::Dtype>& output_dtypes,
const std::vector<nb::object>& scalars_,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
int shared_memory,
std::optional<float> init_value = std::nullopt,
bool ensure_row_contiguous = false,
mx::StreamOrDevice s = {}) {
// Collect the inputs and cast them to array
std::vector<mx::array> inputs;
for (const auto& value : inputs_) {
inputs.push_back(to_array(value, std::nullopt));
}
// Collect the scalar inputs
std::vector<mx::fast::ScalarArg> scalars;
scalars.reserve(scalars_.size());
for (const auto& v : scalars_) {
if (nb::isinstance<bool>(v)) {
scalars.push_back(nb::cast<bool>(v));
} else if (nb::isinstance<int>(v)) {
scalars.push_back(nb::cast<int>(v));
} else if (nb::isinstance<float>(v)) {
scalars.push_back(nb::cast<float>(v));
} else {
nb::object vtype = v.attr("__class__");
std::string vtype_name =
nb::cast<std::string>(vtype.attr("__name__"));
std::ostringstream msg;
msg << "[precompiled_custom_kernel] Invalid scalar argument type. "
<< "Received " << vtype_name
<< " but must be one of bool, int or float";
throw std::invalid_argument(msg.str());
}
}
return mx::fast::precompiled_custom_kernel(
name,
compiled_source,
inputs,
output_shapes,
output_dtypes,
scalars,
grid,
threadgroup,
shared_memory,
init_value,
ensure_row_contiguous,
s);
},
nb::kw_only(),
"name"_a,
"compiled_source"_a,
"inputs"_a,
"output_shapes"_a,
"output_dtypes"_a,
"scalars"_a,
"grid"_a,
"threadgroup"_a,
"shared_memory"_a = 0,
"init_value"_a = nb::none(),
"ensure_row_contiguous"_a = false,
"stream"_a = nb::none(),
R"pbdoc(
)pbdoc");
}