From 9eb5fa764c53f682a05d42ea7e08c72ff929f823 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 2 May 2025 20:40:22 -0700 Subject: [PATCH] fix --- mlx/backend/metal/kernels/ternary.h | 2 +- mlx/backend/metal/ternary.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlx/backend/metal/kernels/ternary.h b/mlx/backend/metal/kernels/ternary.h index 8debf2267..5251dc7e9 100644 --- a/mlx/backend/metal/kernels/ternary.h +++ b/mlx/backend/metal/kernels/ternary.h @@ -5,8 +5,8 @@ template ::n> device const bool* a, device const T* b, device const T* c, - constant uint& size, device T* d, + constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; for (int i = 0; i < N && (index + i) < size; ++i) { diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index e81ae1562..0b821151e 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -113,10 +113,10 @@ void ternary_op_gpu_inplace( MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size grid_dims; if (large) { - compute_encoder.set_bytes(out.data_size(), 2); + compute_encoder.set_bytes(out.data_size(), 4); grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); } else { - compute_encoder.set_bytes(out.data_size(), 2); + compute_encoder.set_bytes(out.data_size(), 4); grid_dims = MTL::Size(nthreads, 1, 1); } compute_encoder.dispatch_threads(grid_dims, group_dims);