diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index ba58b9aec..38bc73f7f 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -56,7 +56,8 @@ Device::Device(int device) : device_(device) { // The cudnn handle is used by Convolution. CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_)); - // Ensure the jit module cache is initialized + // Initialize the jit module cache here ensures it is not + // unloaded before any evaluation is done get_jit_module_cache(); }