From 9a3842a2d9e6889056a291759b7caff200b688af Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 6 Nov 2024 17:10:33 -0800 Subject: [PATCH] fix (#1566) --- mlx/backend/metal/custom_kernel.cpp | 2 +- mlx/backend/metal/fft.cpp | 2 +- mlx/backend/metal/hadamard.cpp | 2 +- mlx/backend/metal/scaled_dot_product_attention.cpp | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index e2627c87b5..06d7bf58ca 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -75,7 +75,7 @@ void CustomKernel::eval_gpu( MTL::Size group_dims = MTL::Size(tx, ty, tz); const auto [gx, gy, gz] = grid_; MTL::Size grid_dims = MTL::Size(gx, gy, gz); - compute_encoder->dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatchThreads(grid_dims, group_dims); d.add_temporaries(std::move(copies), s.index); } diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index 43ded53780..12668eca74 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -738,7 +738,7 @@ void fft_op( auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft); auto grid_dims = MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft); - compute_encoder->dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatchThreads(grid_dims, group_dims); } d.add_temporaries(std::move(copies), s.index); diff --git a/mlx/backend/metal/hadamard.cpp b/mlx/backend/metal/hadamard.cpp index dc2268f7d3..83a17a7a31 100644 --- a/mlx/backend/metal/hadamard.cpp +++ b/mlx/backend/metal/hadamard.cpp @@ -144,7 +144,7 @@ void Hadamard::eval_gpu(const std::vector& inputs, array& out) { MTL::Size group_dims = MTL::Size(1, threads_per, 1); MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1); - compute_encoder->dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatchThreads(grid_dims, group_dims); }; if (m > 1) { diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 54ec91a4cd..19af2a8508 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -139,7 +139,7 @@ void sdpa_full_self_attention_metal( MTL::Size grid_dims = MTL::Size(1, tm, batch_size_out); MTL::Size group_dims = MTL::Size(32, wm, wn); - compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.dispatchThreadgroups(grid_dims, group_dims); } void sdpa_vector(