From 3b94e372704fc42bdcd4e69daec0d114d772e1ff Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 12 Aug 2025 14:30:29 -0700 Subject: [PATCH] Working custom kernels jointly --- mlx/backend/cuda/custom_kernel.cpp | 75 ++++++++++- mlx/backend/cuda/device.cpp | 196 +++++++++++++++++++++++++++- mlx/backend/cuda/jit_module.cpp | 39 ++++++ mlx/backend/cuda/jit_module.h | 13 ++ mlx/backend/metal/custom_kernel.cpp | 5 +- mlx/fast.h | 22 +++- mlx/fast_primitives.h | 18 ++- python/src/fast.cpp | 72 ++++++++++ 8 files changed, 425 insertions(+), 15 deletions(-) diff --git a/mlx/backend/cuda/custom_kernel.cpp b/mlx/backend/cuda/custom_kernel.cpp index 3aeebbe9a..560429b4a 100644 --- a/mlx/backend/cuda/custom_kernel.cpp +++ b/mlx/backend/cuda/custom_kernel.cpp @@ -234,11 +234,47 @@ MetalKernelFunction metal_kernel( threadgroup, shape_infos, ensure_row_contiguous, - init_value), + init_value, + std::vector{}, + false, + 0), std::move(inputs)); }; } +std::vector precompiled_custom_kernel( + const std::string& name, + const std::string& compiled_source, + const std::vector& inputs, + const std::vector& output_shapes, + const std::vector& output_dtypes, + const std::vector& scalars, + std::tuple grid, + std::tuple threadgroup, + int shared_memory, + std::optional init_value, + bool ensure_row_contiguous, + StreamOrDevice s) { + std::vector shape_infos( + inputs.size(), CustomKernelShapeInfo{false, false, false}); + return array::make_arrays( + output_shapes, + output_dtypes, + std::make_shared( + to_stream(s), + name, + compiled_source, + grid, + threadgroup, + shape_infos, + ensure_row_contiguous, + init_value, + scalars, + true, + shared_memory), + inputs); +} + void CustomKernel::eval_gpu( const std::vector& inputs, std::vector& outputs) { @@ -274,10 +310,19 @@ void CustomKernel::eval_gpu( } // Compile the custom kernel - std::string kernel_name = "mlx::core::cu::" + name_; - cu::JitModule& mod = cu::get_jit_module(s.device, name_, [&]() { - return std::make_pair(source_, std::vector{kernel_name}); - }); + std::string kernel_name = + (is_precompiled_) ? name_ : "mlx::core::cu::" + name_; + auto get_module = [&]() -> cu::JitModule& { + if (is_precompiled_) { + return cu::get_jit_module( + s.device, name_, source_, std::vector{kernel_name}); + } else { + return cu::get_jit_module(s.device, name_, [&]() { + return std::make_pair(source_, std::vector{kernel_name}); + }); + } + }; + cu::JitModule& mod = get_module(); // Make the arguments cu::KernelArgs args; @@ -298,6 +343,15 @@ void CustomKernel::eval_gpu( for (auto& out : outputs) { args.append(out); } + for (auto& s : scalar_arguments_) { + if (std::holds_alternative(s)) { + args.append(std::get(s)); + } else if (std::holds_alternative(s)) { + args.append(std::get(s)); + } else if (std::holds_alternative(s)) { + args.append(std::get(s)); + } + } // Make the grid const auto [tx, ty, tz] = threadgroup_; @@ -313,8 +367,17 @@ void CustomKernel::eval_gpu( for (const auto& out : outputs) { encoder.set_output_array(out); } + for (const auto& t : copies) { + encoder.add_temporary(t); + } auto kernel = mod.get_kernel(kernel_name); - encoder.add_kernel_node(kernel, grid, block, 0, args.args()); + if (shared_memory_ > 0 && shared_memory_ > 48000) { + cuFuncSetAttribute( + kernel, + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + shared_memory_); + } + encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args()); } } // namespace mlx::core::fast diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index 371ae020c..08c6bb7ae 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -10,6 +10,8 @@ #include #include +#include + namespace mlx::core::cu { namespace { @@ -249,9 +251,201 @@ void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) { insert_graph_dependencies(GraphNode{node, 'K'}); } +void debugCuGraphAddKernelNode( + CUgraphNode* node, + cudaGraph_t graph, + const CUDA_KERNEL_NODE_PARAMS* params) { + std::cout << "=== Debugging cuGraphAddKernelNode ===" << std::endl; + + // Check graph + if (graph == nullptr) { + std::cout << "ERROR: graph is NULL" << std::endl; + return; + } + + // Check params structure + if (params == nullptr) { + std::cout << "ERROR: params is NULL" << std::endl; + return; + } + + // Check kernel function + if (params->func == nullptr) { + std::cout << "ERROR: kernel function (CUfunction) is NULL" << std::endl; + return; + } + + // Validate kernel function and get attributes + int maxThreadsPerBlock; + CUresult funcErr = cuFuncGetAttribute( + &maxThreadsPerBlock, + CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, + params->func); + if (funcErr != CUDA_SUCCESS) { + const char* errStr; + cuGetErrorString(funcErr, &errStr); + std::cout << "ERROR: Invalid kernel function - " << errStr << std::endl; + return; + } + + // Get more function attributes + int sharedSize, constSize, localSize, numRegs; + cuFuncGetAttribute( + &sharedSize, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, params->func); + cuFuncGetAttribute( + &constSize, CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES, params->func); + cuFuncGetAttribute( + &localSize, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, params->func); + cuFuncGetAttribute(&numRegs, CU_FUNC_ATTRIBUTE_NUM_REGS, params->func); + + std::cout << "Kernel function attributes:" << std::endl; + std::cout << " Max threads per block: " << maxThreadsPerBlock << std::endl; + std::cout << " Shared memory: " << sharedSize << " bytes" << std::endl; + std::cout << " Const memory: " << constSize << " bytes" << std::endl; + std::cout << " Local memory: " << localSize << " bytes" << std::endl; + std::cout << " Num regs: " << numRegs << std::endl; + + // Check dimensions + std::cout << "\nGrid dimensions: (" << params->gridDimX << ", " + << params->gridDimY << ", " << params->gridDimZ << ")" << std::endl; + + std::cout << "Block dimensions: (" << params->blockDimX << ", " + << params->blockDimY << ", " << params->blockDimZ << ")" + << std::endl; + + if (params->gridDimX * params->gridDimY * params->gridDimZ == 0) { + std::cout << "ERROR: Grid dimension contains zero!" << std::endl; + return; + } + + if (params->blockDimX * params->blockDimY * params->blockDimZ == 0) { + std::cout << "ERROR: Block dimension contains zero!" << std::endl; + return; + } + + // Get current device and check limits + CUdevice device; + cuCtxGetDevice(&device); + + int maxGridX, maxGridY, maxGridZ; + int maxBlockX, maxBlockY, maxBlockZ; + int maxThreadsPerBlockDevice; + int maxSharedMemPerBlock; + + cuDeviceGetAttribute(&maxGridX, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, device); + cuDeviceGetAttribute(&maxGridY, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y, device); + cuDeviceGetAttribute(&maxGridZ, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z, device); + cuDeviceGetAttribute(&maxBlockX, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X, device); + cuDeviceGetAttribute(&maxBlockY, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y, device); + cuDeviceGetAttribute(&maxBlockZ, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z, device); + cuDeviceGetAttribute( + &maxThreadsPerBlockDevice, + CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK, + device); + cuDeviceGetAttribute( + &maxSharedMemPerBlock, + CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK, + device); + + std::cout << "\nDevice limits check:" << std::endl; + std::cout << " Max grid size: (" << maxGridX << ", " << maxGridY << ", " + << maxGridZ << ")" << std::endl; + + std::cout << " Max block size: (" << maxBlockX << ", " << maxBlockY << ", " + << maxBlockZ << ")" << std::endl; + + std::cout << " Max threads per block: " << maxThreadsPerBlockDevice + << std::endl; + + // Check if dimensions exceed limits + if (params->gridDimX > (unsigned)maxGridX || + params->gridDimY > (unsigned)maxGridY || + params->gridDimZ > (unsigned)maxGridZ) { + std::cout << "ERROR: Grid dimensions exceed device limits!" << std::endl; + } + + if (params->blockDimX > (unsigned)maxBlockX || + params->blockDimY > (unsigned)maxBlockY || + params->blockDimZ > (unsigned)maxBlockZ) { + std::cout << "ERROR: Block dimensions exceed device limits!" << std::endl; + } + + unsigned totalThreadsPerBlock = + params->blockDimX * params->blockDimY * params->blockDimZ; + if (totalThreadsPerBlock > (unsigned)maxThreadsPerBlockDevice) { + std::cout << "ERROR: Total threads per block (" << totalThreadsPerBlock + << ") exceeds limit (" << maxThreadsPerBlockDevice << ")" + << std::endl; + } + + // Check shared memory + std::cout << "\nShared memory requested: " << params->sharedMemBytes + << " bytes" << std::endl; + std::cout << "Max shared memory per block: " << maxSharedMemPerBlock + << " bytes" << std::endl; + + if (params->sharedMemBytes > (unsigned)maxSharedMemPerBlock) { + std::cout << "ERROR: Requested shared memory exceeds limit!" << std::endl; + } + + // Check kernel parameters + std::cout << "\nKernel parameters:" << std::endl; + std::cout << " kernelParams pointer: " << std::hex << params->kernelParams + << std::dec << std::endl; + std::cout << " extra pointer: " << std::hex << params->extra << std::dec + << std::endl; + + if (params->kernelParams == nullptr && params->extra == nullptr) { + std::cout + << "WARNING: Both kernelParams and extra are NULL (no arguments to kernel)" + << std::endl; + } + + // If using kernelParams, try to print the array + if (params->kernelParams != nullptr) { + std::cout << " Kernel parameter pointers:" << std::endl; + for (int i = 0; i < 10; i++) { // Check first 10 slots (adjust as needed) + if (params->kernelParams[i] == nullptr) { + std::cout << " [" << i << "]: NULL (end of params)" << std::endl; + break; + } + std::cout << " [" << i << "]: " << std::hex << params->kernelParams[i] + << std::dec << std::endl; + } + } + + // Try to add the node + std::cout << "\nAttempting to add kernel node..." << std::endl; + CUresult err = cuGraphAddKernelNode(node, graph, NULL, 0, params); + + if (err != CUDA_SUCCESS) { + const char* errStr; + cuGetErrorString(err, &errStr); + std::cout << "ERROR: " << errStr << " (code: " << err << ")" << std::endl; + + // Additional hints based on error code + if (err == CUDA_ERROR_INVALID_VALUE) { + std::cout << "\nHints for 'invalid argument':" << std::endl; + std::cout + << " - Check if CUDA_KERNEL_NODE_PARAMS struct is properly initialized" + << std::endl; + std::cout << " - Verify CUfunction handle is valid" << std::endl; + std::cout << " - Ensure grid/block dimensions are non-zero" << std::endl; + std::cout << " - Check kernelParams array is properly set up" + << std::endl; + std::cout << " - Verify context is current" << std::endl; + } + } else { + std::cout << "SUCCESS: Kernel node added to graph!" << std::endl; + } + + std::cout << "=== End Debug ===" << std::endl; +} + void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) { CUgraphNode node; - CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms)); + // CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms)); + debugCuGraphAddKernelNode(&node, graph_, ¶ms); insert_graph_dependencies(GraphNode{node, 'K'}); } diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index 25db207e3..9805b7cf6 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -316,6 +316,31 @@ JitModule::JitModule( } } +JitModule::JitModule( + Device& device, + const std::string& module_name, + const std::string& ptx, + const std::vector& kernel_names) { + // Load module. + char jit_log[4089] = {}; + CUjit_option options[] = { + CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES}; + void* values[] = {jit_log, reinterpret_cast(std::size(jit_log) - 1)}; + CUresult jit_result = cuModuleLoadDataEx( + &module_, ptx.c_str(), std::size(options), options, values); + if (jit_result != CUDA_SUCCESS) { + throw std::runtime_error(fmt::format( + "Failed to load compiled {} kernel: {}.", module_name, jit_log)); + } + + // Load kernels. + for (const auto& name : kernel_names) { + CUfunction kernel; + CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, name.c_str())); + kernels_[name] = kernel; + } +} + JitModule::~JitModule() { CHECK_CUDA_ERROR(cuModuleUnload(module_)); } @@ -346,4 +371,18 @@ JitModule& get_jit_module( return it->second; } +JitModule& get_jit_module( + const mlx::core::Device& device, + const std::string& name, + const std::string& ptx, + const std::vector& kernel_names) { + auto& map = get_jit_module_cache(); + auto it = map.find(name); + if (it == map.end()) { + it = map.try_emplace(name, cu::device(device), name, ptx, kernel_names) + .first; + } + return it->second; +} + } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/jit_module.h b/mlx/backend/cuda/jit_module.h index df3f58352..1de77bf26 100644 --- a/mlx/backend/cuda/jit_module.h +++ b/mlx/backend/cuda/jit_module.h @@ -68,9 +68,11 @@ struct KernelArgs { using Arg = std::variant< std::monostate, CUdeviceptr, + bool, int32_t, uint32_t, int64_t, + float, SmallVector, SmallVector, SmallVector>; @@ -83,6 +85,11 @@ class JitModule { Device& device, const std::string& module_name, const KernelBuilder& builder); + JitModule( + Device& device, + const std::string& module_name, + const std::string& ptx, + const std::vector& kernel_names); ~JitModule(); JitModule(const JitModule&) = delete; @@ -101,4 +108,10 @@ JitModule& get_jit_module( const std::string& name, const KernelBuilder& builder); +JitModule& get_jit_module( + const mlx::core::Device& device, + const std::string& name, + const std::string& ptx, + const std::vector& kernel_names); + } // namespace mlx::core::cu diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index 161503a0e..e40d853fc 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -316,7 +316,10 @@ MetalKernelFunction metal_kernel( threadgroup, shape_infos, ensure_row_contiguous, - init_value), + init_value, + std::vector{}, + false, + 0), std::move(inputs)); }; } diff --git a/mlx/fast.h b/mlx/fast.h index 7aebe3863..05a185ecc 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -66,9 +66,10 @@ array affine_dequantize( int bits = 4, StreamOrDevice s = {}); -typedef std::variant TemplateArg; +using TemplateArg = std::variant; +using ScalarArg = std::variant; -typedef std::function( +using MetalKernelFunction = std::function( const std::vector&, const std::vector&, const std::vector&, @@ -77,8 +78,7 @@ typedef std::function( std::vector>, std::optional, bool, - StreamOrDevice)> - MetalKernelFunction; + StreamOrDevice)>; MetalKernelFunction metal_kernel( const std::string& name, @@ -89,4 +89,18 @@ MetalKernelFunction metal_kernel( bool ensure_row_contiguous = true, bool atomic_outputs = false); +std::vector precompiled_custom_kernel( + const std::string& name, + const std::string& compiled_source, + const std::vector& inputs, + const std::vector& output_shapes, + const std::vector& output_dtypes, + const std::vector& scalars, + std::tuple grid, + std::tuple threadgroup, + int shared_memory = 0, + std::optional init_value = std::nullopt, + bool ensure_row_contiguous = false, + StreamOrDevice s = {}); + } // namespace mlx::core::fast diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 52135adad..98730104d 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -1,6 +1,7 @@ // Copyright © 2024 Apple Inc. #include +#include #include "mlx/primitives.h" @@ -283,6 +284,8 @@ struct CustomKernelShapeInfo { bool ndim = false; }; +using ScalarArg = std::variant; + class CustomKernel : public Primitive { public: CustomKernel( @@ -293,7 +296,10 @@ class CustomKernel : public Primitive { std::tuple threadgroup, std::vector shape_infos, bool ensure_row_contiguous, - std::optional init_value) + std::optional init_value, + std::vector scalar_arguments, + bool is_precompiled, + int shared_memory) : Primitive(stream), source_(std::move(source)), name_(std::move(name)), @@ -301,11 +307,14 @@ class CustomKernel : public Primitive { threadgroup_(threadgroup), shape_infos_(std::move(shape_infos)), ensure_row_contiguous_(ensure_row_contiguous), - init_value_(init_value) {} + init_value_(init_value), + scalar_arguments_(scalar_arguments), + is_precompiled_(is_precompiled), + shared_memory_(shared_memory) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override { - throw std::runtime_error("Custom Metal kernels only run on GPU."); + throw std::runtime_error("Custom kernels only run on GPU."); } void eval_gpu(const std::vector& inputs, std::vector& outputs) @@ -321,6 +330,9 @@ class CustomKernel : public Primitive { std::vector shape_infos_; bool ensure_row_contiguous_; std::optional init_value_; + std::vector scalar_arguments_; + bool is_precompiled_; + int shared_memory_; }; } // namespace mlx::core::fast diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 3d0bc4147..b359f84bd 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -384,4 +384,76 @@ void init_fast(nb::module_& parent_module) { b = exp_elementwise(a) assert mx.allclose(b, mx.exp(a)) )pbdoc"); + + m.def( + "precompiled_custom_kernel", + [](const std::string& name, + const std::string& compiled_source, + const std::vector& inputs_, + const std::vector& output_shapes, + const std::vector& output_dtypes, + const std::vector& scalars_, + std::tuple grid, + std::tuple threadgroup, + int shared_memory, + std::optional init_value = std::nullopt, + bool ensure_row_contiguous = false, + mx::StreamOrDevice s = {}) { + // Collect the inputs and cast them to array + std::vector inputs; + for (const auto& value : inputs_) { + inputs.push_back(to_array(value, std::nullopt)); + } + + // Collect the scalar inputs + std::vector scalars; + scalars.reserve(scalars_.size()); + for (const auto& v : scalars_) { + if (nb::isinstance(v)) { + scalars.push_back(nb::cast(v)); + } else if (nb::isinstance(v)) { + scalars.push_back(nb::cast(v)); + } else if (nb::isinstance(v)) { + scalars.push_back(nb::cast(v)); + } else { + nb::object vtype = v.attr("__class__"); + std::string vtype_name = + nb::cast(vtype.attr("__name__")); + std::ostringstream msg; + msg << "[precompiled_custom_kernel] Invalid scalar argument type. " + << "Received " << vtype_name + << " but must be one of bool, int or float"; + throw std::invalid_argument(msg.str()); + } + } + + return mx::fast::precompiled_custom_kernel( + name, + compiled_source, + inputs, + output_shapes, + output_dtypes, + scalars, + grid, + threadgroup, + shared_memory, + init_value, + ensure_row_contiguous, + s); + }, + nb::kw_only(), + "name"_a, + "compiled_source"_a, + "inputs"_a, + "output_shapes"_a, + "output_dtypes"_a, + "scalars"_a, + "grid"_a, + "threadgroup"_a, + "shared_memory"_a = 0, + "init_value"_a = nb::none(), + "ensure_row_contiguous"_a = false, + "stream"_a = nb::none(), + R"pbdoc( + )pbdoc"); }