diff --git a/mlx/backend/cuda/custom_kernel.cpp b/mlx/backend/cuda/custom_kernel.cpp index d85e02e51..31e8a0b67 100644 --- a/mlx/backend/cuda/custom_kernel.cpp +++ b/mlx/backend/cuda/custom_kernel.cpp @@ -378,6 +378,7 @@ void CustomKernel::eval_gpu( } }); encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args()); + encoder.add_completed_handler([args = std::move(args)]() {}); } } // namespace mlx::core::fast diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index 08c6bb7ae..4cf153774 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -251,201 +251,9 @@ void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) { insert_graph_dependencies(GraphNode{node, 'K'}); } -void debugCuGraphAddKernelNode( - CUgraphNode* node, - cudaGraph_t graph, - const CUDA_KERNEL_NODE_PARAMS* params) { - std::cout << "=== Debugging cuGraphAddKernelNode ===" << std::endl; - - // Check graph - if (graph == nullptr) { - std::cout << "ERROR: graph is NULL" << std::endl; - return; - } - - // Check params structure - if (params == nullptr) { - std::cout << "ERROR: params is NULL" << std::endl; - return; - } - - // Check kernel function - if (params->func == nullptr) { - std::cout << "ERROR: kernel function (CUfunction) is NULL" << std::endl; - return; - } - - // Validate kernel function and get attributes - int maxThreadsPerBlock; - CUresult funcErr = cuFuncGetAttribute( - &maxThreadsPerBlock, - CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, - params->func); - if (funcErr != CUDA_SUCCESS) { - const char* errStr; - cuGetErrorString(funcErr, &errStr); - std::cout << "ERROR: Invalid kernel function - " << errStr << std::endl; - return; - } - - // Get more function attributes - int sharedSize, constSize, localSize, numRegs; - cuFuncGetAttribute( - &sharedSize, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, params->func); - cuFuncGetAttribute( - &constSize, CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES, params->func); - cuFuncGetAttribute( - &localSize, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, params->func); - cuFuncGetAttribute(&numRegs, CU_FUNC_ATTRIBUTE_NUM_REGS, params->func); - - std::cout << "Kernel function attributes:" << std::endl; - std::cout << " Max threads per block: " << maxThreadsPerBlock << std::endl; - std::cout << " Shared memory: " << sharedSize << " bytes" << std::endl; - std::cout << " Const memory: " << constSize << " bytes" << std::endl; - std::cout << " Local memory: " << localSize << " bytes" << std::endl; - std::cout << " Num regs: " << numRegs << std::endl; - - // Check dimensions - std::cout << "\nGrid dimensions: (" << params->gridDimX << ", " - << params->gridDimY << ", " << params->gridDimZ << ")" << std::endl; - - std::cout << "Block dimensions: (" << params->blockDimX << ", " - << params->blockDimY << ", " << params->blockDimZ << ")" - << std::endl; - - if (params->gridDimX * params->gridDimY * params->gridDimZ == 0) { - std::cout << "ERROR: Grid dimension contains zero!" << std::endl; - return; - } - - if (params->blockDimX * params->blockDimY * params->blockDimZ == 0) { - std::cout << "ERROR: Block dimension contains zero!" << std::endl; - return; - } - - // Get current device and check limits - CUdevice device; - cuCtxGetDevice(&device); - - int maxGridX, maxGridY, maxGridZ; - int maxBlockX, maxBlockY, maxBlockZ; - int maxThreadsPerBlockDevice; - int maxSharedMemPerBlock; - - cuDeviceGetAttribute(&maxGridX, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, device); - cuDeviceGetAttribute(&maxGridY, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y, device); - cuDeviceGetAttribute(&maxGridZ, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z, device); - cuDeviceGetAttribute(&maxBlockX, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X, device); - cuDeviceGetAttribute(&maxBlockY, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y, device); - cuDeviceGetAttribute(&maxBlockZ, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z, device); - cuDeviceGetAttribute( - &maxThreadsPerBlockDevice, - CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK, - device); - cuDeviceGetAttribute( - &maxSharedMemPerBlock, - CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK, - device); - - std::cout << "\nDevice limits check:" << std::endl; - std::cout << " Max grid size: (" << maxGridX << ", " << maxGridY << ", " - << maxGridZ << ")" << std::endl; - - std::cout << " Max block size: (" << maxBlockX << ", " << maxBlockY << ", " - << maxBlockZ << ")" << std::endl; - - std::cout << " Max threads per block: " << maxThreadsPerBlockDevice - << std::endl; - - // Check if dimensions exceed limits - if (params->gridDimX > (unsigned)maxGridX || - params->gridDimY > (unsigned)maxGridY || - params->gridDimZ > (unsigned)maxGridZ) { - std::cout << "ERROR: Grid dimensions exceed device limits!" << std::endl; - } - - if (params->blockDimX > (unsigned)maxBlockX || - params->blockDimY > (unsigned)maxBlockY || - params->blockDimZ > (unsigned)maxBlockZ) { - std::cout << "ERROR: Block dimensions exceed device limits!" << std::endl; - } - - unsigned totalThreadsPerBlock = - params->blockDimX * params->blockDimY * params->blockDimZ; - if (totalThreadsPerBlock > (unsigned)maxThreadsPerBlockDevice) { - std::cout << "ERROR: Total threads per block (" << totalThreadsPerBlock - << ") exceeds limit (" << maxThreadsPerBlockDevice << ")" - << std::endl; - } - - // Check shared memory - std::cout << "\nShared memory requested: " << params->sharedMemBytes - << " bytes" << std::endl; - std::cout << "Max shared memory per block: " << maxSharedMemPerBlock - << " bytes" << std::endl; - - if (params->sharedMemBytes > (unsigned)maxSharedMemPerBlock) { - std::cout << "ERROR: Requested shared memory exceeds limit!" << std::endl; - } - - // Check kernel parameters - std::cout << "\nKernel parameters:" << std::endl; - std::cout << " kernelParams pointer: " << std::hex << params->kernelParams - << std::dec << std::endl; - std::cout << " extra pointer: " << std::hex << params->extra << std::dec - << std::endl; - - if (params->kernelParams == nullptr && params->extra == nullptr) { - std::cout - << "WARNING: Both kernelParams and extra are NULL (no arguments to kernel)" - << std::endl; - } - - // If using kernelParams, try to print the array - if (params->kernelParams != nullptr) { - std::cout << " Kernel parameter pointers:" << std::endl; - for (int i = 0; i < 10; i++) { // Check first 10 slots (adjust as needed) - if (params->kernelParams[i] == nullptr) { - std::cout << " [" << i << "]: NULL (end of params)" << std::endl; - break; - } - std::cout << " [" << i << "]: " << std::hex << params->kernelParams[i] - << std::dec << std::endl; - } - } - - // Try to add the node - std::cout << "\nAttempting to add kernel node..." << std::endl; - CUresult err = cuGraphAddKernelNode(node, graph, NULL, 0, params); - - if (err != CUDA_SUCCESS) { - const char* errStr; - cuGetErrorString(err, &errStr); - std::cout << "ERROR: " << errStr << " (code: " << err << ")" << std::endl; - - // Additional hints based on error code - if (err == CUDA_ERROR_INVALID_VALUE) { - std::cout << "\nHints for 'invalid argument':" << std::endl; - std::cout - << " - Check if CUDA_KERNEL_NODE_PARAMS struct is properly initialized" - << std::endl; - std::cout << " - Verify CUfunction handle is valid" << std::endl; - std::cout << " - Ensure grid/block dimensions are non-zero" << std::endl; - std::cout << " - Check kernelParams array is properly set up" - << std::endl; - std::cout << " - Verify context is current" << std::endl; - } - } else { - std::cout << "SUCCESS: Kernel node added to graph!" << std::endl; - } - - std::cout << "=== End Debug ===" << std::endl; -} - void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) { CUgraphNode node; - // CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms)); - debugCuGraphAddKernelNode(&node, graph_, ¶ms); + CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms)); insert_graph_dependencies(GraphNode{node, 'K'}); } diff --git a/mlx/backend/cuda/jit_module.h b/mlx/backend/cuda/jit_module.h index 3aa7ffa8a..1933bfccb 100644 --- a/mlx/backend/cuda/jit_module.h +++ b/mlx/backend/cuda/jit_module.h @@ -64,7 +64,7 @@ struct KernelArgs { std::vector args_; // The cuLaunchKernel API requires passing pointers to arguments so store - // temporary values untill kernel is launched. + // temporary values until the kernel is launched. using Arg = std::variant< std::monostate, CUdeviceptr,