diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index fff752fe5..4129563af 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -54,8 +54,8 @@ void Device::make_current() { CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) { CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0)); - CHECK_CUDA_ERROR(cudaStreamBeginCaptureToGraph( - enc.stream(), graph, NULL, NULL, 0, cudaStreamCaptureModeGlobal)); + CHECK_CUDA_ERROR( + cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal)); } CommandEncoder::CaptureContext::~CaptureContext() {