diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index b45b4bd96..1652207e3 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -648,14 +648,14 @@ METAL_FUNC void qmv_fast_impl( // Adjust positions const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; const int in_vec_size_g = in_vec_size / group_size; - const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) + + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup; ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; - x += tid.y * in_vec_size + simd_lid * values_per_thread; - y += tid.y * out_vec_size + out_row; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; for (int k = 0; k < in_vec_size; k += block_size) { U sum = load_vector(x, x_thread); @@ -716,7 +716,7 @@ METAL_FUNC void qmv_impl( // Adjust positions const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; const int in_vec_size_g = in_vec_size / group_size; - const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) + + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup; const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row); @@ -731,8 +731,8 @@ METAL_FUNC void qmv_impl( out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; - x += tid.y * in_vec_size + simd_lid * values_per_thread; - y += tid.y * out_vec_size + out_row; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; int k = 0; for (; k < in_vec_size - block_size; k += block_size) { @@ -788,8 +788,8 @@ METAL_FUNC void qmv_impl( simd_lid * packs_per_thread * bytes_per_pack; scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; - x += tid.y * in_vec_size + simd_lid * values_per_thread; - y += tid.y * out_vec_size + used_out_row; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + used_out_row; int k = 0; for (; k < in_vec_size - block_size; k += block_size) { @@ -876,12 +876,12 @@ METAL_FUNC void qvm_impl( // Adjust positions const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor; const int out_vec_size_g = out_vec_size / group_size; - int out_col = pack_factor * tn * (tid.x * num_simdgroups + simd_gid); + int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid); ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w; scales += out_col / group_size + simd_lid * out_vec_size_g; biases += out_col / group_size + simd_lid * out_vec_size_g; - x += tid.y * in_vec_size + simd_lid; - y += tid.y * out_vec_size + out_col; + x += tid.x * in_vec_size + simd_lid; + y += tid.x * out_vec_size + out_col; if (out_col >= out_vec_size) { return; diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 4454476c9..d6fddd058 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -167,7 +167,7 @@ void qvm_split_k( int bo = 64; int bd = 32; MTL::Size group_dims = MTL::Size(bd, 2, 1); - MTL::Size grid_dims = MTL::Size(O / bo, B, N); + MTL::Size grid_dims = MTL::Size(B, O / bo, N); auto& x_pre = inputs[0]; auto& w_pre = inputs[1]; @@ -313,13 +313,13 @@ void qmm_op( int bo = 8; int bd = 32; group_dims = MTL::Size(bd, 2, 1); - grid_dims = MTL::Size(O / bo, B, N); + grid_dims = MTL::Size(B, O / bo, N); } else if (B < 6) { name += "qmv"; int bo = 8; int bd = 32; group_dims = MTL::Size(bd, 2, 1); - grid_dims = MTL::Size((O + bo - 1) / bo, B, N); + grid_dims = MTL::Size(B, (O + bo - 1) / bo, N); } else { int wn = 2; int wm = 2; @@ -339,7 +339,7 @@ void qmm_op( int bo = 64; int bd = 32; group_dims = MTL::Size(bd, 2, 1); - grid_dims = MTL::Size(O / bo, B, N); + grid_dims = MTL::Size(B, O / bo, N); } else { name += "qmm_n"; int wn = 2;