diff --git a/mlx/backend/cuda/fence.cpp b/mlx/backend/cuda/fence.cpp index a4aedeb74..f11ff7e9e 100644 --- a/mlx/backend/cuda/fence.cpp +++ b/mlx/backend/cuda/fence.cpp @@ -32,9 +32,10 @@ void Fence::update(Stream s, const array& a, bool cross_device) { void* new_data; CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size)); cbuf.device = -1; + auto& encoder = cu::device(s.device).get_command_encoder(s); + encoder.commit(); CHECK_CUDA_ERROR( cudaMemcpyAsync(new_data, cbuf.data, cbuf.size, cudaMemcpyDefault)); - auto& encoder = cu::device(s.device).get_command_encoder(s); CHECK_CUDA_ERROR(cudaFreeAsync(cbuf.data, encoder.stream())); cbuf.data = new_data; }