diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index e2627c87b..06d7bf58c 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -75,7 +75,7 @@ void CustomKernel::eval_gpu( MTL::Size group_dims = MTL::Size(tx, ty, tz); const auto [gx, gy, gz] = grid_; MTL::Size grid_dims = MTL::Size(gx, gy, gz); - compute_encoder->dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatchThreads(grid_dims, group_dims); d.add_temporaries(std::move(copies), s.index); } diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index 43ded5378..12668eca7 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -738,7 +738,7 @@ void fft_op( auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft); auto grid_dims = MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft); - compute_encoder->dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatchThreads(grid_dims, group_dims); } d.add_temporaries(std::move(copies), s.index); diff --git a/mlx/backend/metal/hadamard.cpp b/mlx/backend/metal/hadamard.cpp index dc2268f7d..83a17a7a3 100644 --- a/mlx/backend/metal/hadamard.cpp +++ b/mlx/backend/metal/hadamard.cpp @@ -144,7 +144,7 @@ void Hadamard::eval_gpu(const std::vector& inputs, array& out) { MTL::Size group_dims = MTL::Size(1, threads_per, 1); MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1); - compute_encoder->dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatchThreads(grid_dims, group_dims); }; if (m > 1) { diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 54ec91a4c..19af2a850 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -139,7 +139,7 @@ void sdpa_full_self_attention_metal( MTL::Size grid_dims = MTL::Size(1, tm, batch_size_out); MTL::Size group_dims = MTL::Size(32, wm, wn); - compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.dispatchThreadgroups(grid_dims, group_dims); } void sdpa_vector(