Mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-28 07:58:41 +08:00).
Commit 3938aaaf24 ("tmp"); parent commit 055c1ca929.
@ -378,6 +378,7 @@ void CustomKernel::eval_gpu(
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args());
|
encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args());
|
||||||
|
encoder.add_completed_handler([args = std::move(args)]() {});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::fast
|
} // namespace mlx::core::fast
|
||||||
|
@ -251,201 +251,9 @@ void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
|
|||||||
insert_graph_dependencies(GraphNode{node, 'K'});
|
insert_graph_dependencies(GraphNode{node, 'K'});
|
||||||
}
|
}
|
||||||
|
|
||||||
void debugCuGraphAddKernelNode(
|
|
||||||
CUgraphNode* node,
|
|
||||||
cudaGraph_t graph,
|
|
||||||
const CUDA_KERNEL_NODE_PARAMS* params) {
|
|
||||||
std::cout << "=== Debugging cuGraphAddKernelNode ===" << std::endl;
|
|
||||||
|
|
||||||
// Check graph
|
|
||||||
if (graph == nullptr) {
|
|
||||||
std::cout << "ERROR: graph is NULL" << std::endl;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check params structure
|
|
||||||
if (params == nullptr) {
|
|
||||||
std::cout << "ERROR: params is NULL" << std::endl;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check kernel function
|
|
||||||
if (params->func == nullptr) {
|
|
||||||
std::cout << "ERROR: kernel function (CUfunction) is NULL" << std::endl;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate kernel function and get attributes
|
|
||||||
int maxThreadsPerBlock;
|
|
||||||
CUresult funcErr = cuFuncGetAttribute(
|
|
||||||
&maxThreadsPerBlock,
|
|
||||||
CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK,
|
|
||||||
params->func);
|
|
||||||
if (funcErr != CUDA_SUCCESS) {
|
|
||||||
const char* errStr;
|
|
||||||
cuGetErrorString(funcErr, &errStr);
|
|
||||||
std::cout << "ERROR: Invalid kernel function - " << errStr << std::endl;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get more function attributes
|
|
||||||
int sharedSize, constSize, localSize, numRegs;
|
|
||||||
cuFuncGetAttribute(
|
|
||||||
&sharedSize, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, params->func);
|
|
||||||
cuFuncGetAttribute(
|
|
||||||
&constSize, CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES, params->func);
|
|
||||||
cuFuncGetAttribute(
|
|
||||||
&localSize, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, params->func);
|
|
||||||
cuFuncGetAttribute(&numRegs, CU_FUNC_ATTRIBUTE_NUM_REGS, params->func);
|
|
||||||
|
|
||||||
std::cout << "Kernel function attributes:" << std::endl;
|
|
||||||
std::cout << " Max threads per block: " << maxThreadsPerBlock << std::endl;
|
|
||||||
std::cout << " Shared memory: " << sharedSize << " bytes" << std::endl;
|
|
||||||
std::cout << " Const memory: " << constSize << " bytes" << std::endl;
|
|
||||||
std::cout << " Local memory: " << localSize << " bytes" << std::endl;
|
|
||||||
std::cout << " Num regs: " << numRegs << std::endl;
|
|
||||||
|
|
||||||
// Check dimensions
|
|
||||||
std::cout << "\nGrid dimensions: (" << params->gridDimX << ", "
|
|
||||||
<< params->gridDimY << ", " << params->gridDimZ << ")" << std::endl;
|
|
||||||
|
|
||||||
std::cout << "Block dimensions: (" << params->blockDimX << ", "
|
|
||||||
<< params->blockDimY << ", " << params->blockDimZ << ")"
|
|
||||||
<< std::endl;
|
|
||||||
|
|
||||||
if (params->gridDimX * params->gridDimY * params->gridDimZ == 0) {
|
|
||||||
std::cout << "ERROR: Grid dimension contains zero!" << std::endl;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params->blockDimX * params->blockDimY * params->blockDimZ == 0) {
|
|
||||||
std::cout << "ERROR: Block dimension contains zero!" << std::endl;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get current device and check limits
|
|
||||||
CUdevice device;
|
|
||||||
cuCtxGetDevice(&device);
|
|
||||||
|
|
||||||
int maxGridX, maxGridY, maxGridZ;
|
|
||||||
int maxBlockX, maxBlockY, maxBlockZ;
|
|
||||||
int maxThreadsPerBlockDevice;
|
|
||||||
int maxSharedMemPerBlock;
|
|
||||||
|
|
||||||
cuDeviceGetAttribute(&maxGridX, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, device);
|
|
||||||
cuDeviceGetAttribute(&maxGridY, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y, device);
|
|
||||||
cuDeviceGetAttribute(&maxGridZ, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z, device);
|
|
||||||
cuDeviceGetAttribute(&maxBlockX, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X, device);
|
|
||||||
cuDeviceGetAttribute(&maxBlockY, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y, device);
|
|
||||||
cuDeviceGetAttribute(&maxBlockZ, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z, device);
|
|
||||||
cuDeviceGetAttribute(
|
|
||||||
&maxThreadsPerBlockDevice,
|
|
||||||
CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK,
|
|
||||||
device);
|
|
||||||
cuDeviceGetAttribute(
|
|
||||||
&maxSharedMemPerBlock,
|
|
||||||
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK,
|
|
||||||
device);
|
|
||||||
|
|
||||||
std::cout << "\nDevice limits check:" << std::endl;
|
|
||||||
std::cout << " Max grid size: (" << maxGridX << ", " << maxGridY << ", "
|
|
||||||
<< maxGridZ << ")" << std::endl;
|
|
||||||
|
|
||||||
std::cout << " Max block size: (" << maxBlockX << ", " << maxBlockY << ", "
|
|
||||||
<< maxBlockZ << ")" << std::endl;
|
|
||||||
|
|
||||||
std::cout << " Max threads per block: " << maxThreadsPerBlockDevice
|
|
||||||
<< std::endl;
|
|
||||||
|
|
||||||
// Check if dimensions exceed limits
|
|
||||||
if (params->gridDimX > (unsigned)maxGridX ||
|
|
||||||
params->gridDimY > (unsigned)maxGridY ||
|
|
||||||
params->gridDimZ > (unsigned)maxGridZ) {
|
|
||||||
std::cout << "ERROR: Grid dimensions exceed device limits!" << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params->blockDimX > (unsigned)maxBlockX ||
|
|
||||||
params->blockDimY > (unsigned)maxBlockY ||
|
|
||||||
params->blockDimZ > (unsigned)maxBlockZ) {
|
|
||||||
std::cout << "ERROR: Block dimensions exceed device limits!" << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned totalThreadsPerBlock =
|
|
||||||
params->blockDimX * params->blockDimY * params->blockDimZ;
|
|
||||||
if (totalThreadsPerBlock > (unsigned)maxThreadsPerBlockDevice) {
|
|
||||||
std::cout << "ERROR: Total threads per block (" << totalThreadsPerBlock
|
|
||||||
<< ") exceeds limit (" << maxThreadsPerBlockDevice << ")"
|
|
||||||
<< std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check shared memory
|
|
||||||
std::cout << "\nShared memory requested: " << params->sharedMemBytes
|
|
||||||
<< " bytes" << std::endl;
|
|
||||||
std::cout << "Max shared memory per block: " << maxSharedMemPerBlock
|
|
||||||
<< " bytes" << std::endl;
|
|
||||||
|
|
||||||
if (params->sharedMemBytes > (unsigned)maxSharedMemPerBlock) {
|
|
||||||
std::cout << "ERROR: Requested shared memory exceeds limit!" << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check kernel parameters
|
|
||||||
std::cout << "\nKernel parameters:" << std::endl;
|
|
||||||
std::cout << " kernelParams pointer: " << std::hex << params->kernelParams
|
|
||||||
<< std::dec << std::endl;
|
|
||||||
std::cout << " extra pointer: " << std::hex << params->extra << std::dec
|
|
||||||
<< std::endl;
|
|
||||||
|
|
||||||
if (params->kernelParams == nullptr && params->extra == nullptr) {
|
|
||||||
std::cout
|
|
||||||
<< "WARNING: Both kernelParams and extra are NULL (no arguments to kernel)"
|
|
||||||
<< std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
// If using kernelParams, try to print the array
|
|
||||||
if (params->kernelParams != nullptr) {
|
|
||||||
std::cout << " Kernel parameter pointers:" << std::endl;
|
|
||||||
for (int i = 0; i < 10; i++) { // Check first 10 slots (adjust as needed)
|
|
||||||
if (params->kernelParams[i] == nullptr) {
|
|
||||||
std::cout << " [" << i << "]: NULL (end of params)" << std::endl;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
std::cout << " [" << i << "]: " << std::hex << params->kernelParams[i]
|
|
||||||
<< std::dec << std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to add the node
|
|
||||||
std::cout << "\nAttempting to add kernel node..." << std::endl;
|
|
||||||
CUresult err = cuGraphAddKernelNode(node, graph, NULL, 0, params);
|
|
||||||
|
|
||||||
if (err != CUDA_SUCCESS) {
|
|
||||||
const char* errStr;
|
|
||||||
cuGetErrorString(err, &errStr);
|
|
||||||
std::cout << "ERROR: " << errStr << " (code: " << err << ")" << std::endl;
|
|
||||||
|
|
||||||
// Additional hints based on error code
|
|
||||||
if (err == CUDA_ERROR_INVALID_VALUE) {
|
|
||||||
std::cout << "\nHints for 'invalid argument':" << std::endl;
|
|
||||||
std::cout
|
|
||||||
<< " - Check if CUDA_KERNEL_NODE_PARAMS struct is properly initialized"
|
|
||||||
<< std::endl;
|
|
||||||
std::cout << " - Verify CUfunction handle is valid" << std::endl;
|
|
||||||
std::cout << " - Ensure grid/block dimensions are non-zero" << std::endl;
|
|
||||||
std::cout << " - Check kernelParams array is properly set up"
|
|
||||||
<< std::endl;
|
|
||||||
std::cout << " - Verify context is current" << std::endl;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
std::cout << "SUCCESS: Kernel node added to graph!" << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::cout << "=== End Debug ===" << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
|
void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
|
||||||
CUgraphNode node;
|
CUgraphNode node;
|
||||||
// CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms));
|
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms));
|
||||||
debugCuGraphAddKernelNode(&node, graph_, ¶ms);
|
|
||||||
insert_graph_dependencies(GraphNode{node, 'K'});
|
insert_graph_dependencies(GraphNode{node, 'K'});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ struct KernelArgs {
|
|||||||
std::vector<void*> args_;
|
std::vector<void*> args_;
|
||||||
|
|
||||||
// The cuLaunchKernel API requires passing pointers to arguments so store
|
// The cuLaunchKernel API requires passing pointers to arguments so store
|
||||||
// temporary values until the kernel is launched.
|
// temporary values until the kernel is launched.
|
||||||
using Arg = std::variant<
|
using Arg = std::variant<
|
||||||
std::monostate,
|
std::monostate,
|
||||||
CUdeviceptr,
|
CUdeviceptr,
|
||||||
|
Loading…
Reference in New Issue
Block a user