diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index c3164af0c..70faa1d24 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -108,8 +108,8 @@ if(NOT MLX_METAL_JIT)
     reduction/reduce_all.h
     reduction/reduce_col.h
     reduction/reduce_row.h)
-  build_kernel(quantized quantized.h ${STEEL_HEADERS})
-  build_kernel(fp4_quantized fp4_quantized.h ${STEEL_HEADERS})
+  build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS})
+  build_kernel(fp4_quantized fp4_quantized.h quantized_utils.h ${STEEL_HEADERS})
   build_kernel(scan scan.h)
   build_kernel(softmax softmax.h)
   build_kernel(logsumexp logsumexp.h)
diff --git a/mlx/backend/metal/kernels/fp4_quantized.h b/mlx/backend/metal/kernels/fp4_quantized.h
index 40c5fa187..b5a8918e4 100644
--- a/mlx/backend/metal/kernels/fp4_quantized.h
+++ b/mlx/backend/metal/kernels/fp4_quantized.h
@@ -1595,92 +1595,6 @@ template <
       w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
 }
 
-template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
-METAL_FUNC void gemm_loop_aligned(
-    threadgroup T* As,
-    threadgroup T* Bs,
-    thread mma_t& mma_op,
-    thread loader_a_t& loader_a,
-    thread loader_b_t& loader_b,
-    const int k_iterations) {
-  for (int k = 0; k < k_iterations; k++) {
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // Load elements into threadgroup memory
-    loader_a.load_unsafe();
-    loader_b.load_unsafe();
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // Multiply and accumulate threadgroup elements
-    mma_op.mma(As, Bs);
-
-    // Prepare for next iteration
-    loader_a.next();
-    loader_b.next();
-  }
-}
-
-template <
-    bool rows_aligned,
-    bool cols_aligned,
-    bool transpose,
-    typename T,
-    typename mma_t,
-    typename loader_a_t,
-    typename loader_b_t>
-METAL_FUNC void gemm_loop_unaligned(
-    threadgroup T* As,
-    threadgroup T* Bs,
-    thread mma_t& mma_op,
-    thread loader_a_t& loader_a,
-    thread loader_b_t& loader_b,
-    const int k_iterations,
-    const short tgp_bm,
-    const short tgp_bn,
-    const short tgp_bk) {
-  for (int k = 0; k < k_iterations; k++) {
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // Load elements into threadgroup memory
-    if (rows_aligned) {
-      loader_a.load_unsafe();
-    } else {
-      loader_a.load_safe(short2(tgp_bk, tgp_bm));
-    }
-    if (cols_aligned) {
-      loader_b.load_unsafe();
-    } else {
-      loader_b.load_safe(
-          transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // Multiply and accumulate threadgroup elements
-    mma_op.mma(As, Bs);
-
-    // Prepare for next iteration
-    loader_a.next();
-    loader_b.next();
-  }
-}
-
-template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
-METAL_FUNC void gemm_loop_finalize(
-    threadgroup T* As,
-    threadgroup T* Bs,
-    thread mma_t& mma_op,
-    thread loader_a_t& loader_a,
-    thread loader_b_t& loader_b,
-    const short2 tile_a,
-    const short2 tile_b) {
-  loader_a.load_safe(tile_a);
-  loader_b.load_safe(tile_b);
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-  mma_op.mma(As, Bs);
-}
-
 template <
     typename T,
     int group_size,
diff --git a/mlx/backend/metal/kernels/fp4_quantized.metal b/mlx/backend/metal/kernels/fp4_quantized.metal
index 9a817d20d..c982b4bf4 100644
--- a/mlx/backend/metal/kernels/fp4_quantized.metal
+++ b/mlx/backend/metal/kernels/fp4_quantized.metal
@@ -3,6 +3,7 @@
 // clang-format off
 #include "mlx/backend/metal/kernels/utils.h"
 #include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
+#include "mlx/backend/metal/kernels/quantized_utils.h"
 #include "mlx/backend/metal/kernels/fp4_quantized.h"
 
 #define instantiate_quantized(name, type) \
diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h
index 09a73f887..bf639814b 100644
--- a/mlx/backend/metal/kernels/quantized.h
+++ b/mlx/backend/metal/kernels/quantized.h
@@ -2138,92 +2138,6 @@ template <
       w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
 }
 
-template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
-METAL_FUNC void gemm_loop_aligned(
-    threadgroup T* As,
-    threadgroup T* Bs,
-    thread mma_t& mma_op,
-    thread loader_a_t& loader_a,
-    thread loader_b_t& loader_b,
-    const int k_iterations) {
-  for (int k = 0; k < k_iterations; k++) {
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // Load elements into threadgroup memory
-    loader_a.load_unsafe();
-    loader_b.load_unsafe();
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // Multiply and accumulate threadgroup elements
-    mma_op.mma(As, Bs);
-
-    // Prepare for next iteration
-    loader_a.next();
-    loader_b.next();
-  }
-}
-
-template <
-    bool rows_aligned,
-    bool cols_aligned,
-    bool transpose,
-    typename T,
-    typename mma_t,
-    typename loader_a_t,
-    typename loader_b_t>
-METAL_FUNC void gemm_loop_unaligned(
-    threadgroup T* As,
-    threadgroup T* Bs,
-    thread mma_t& mma_op,
-    thread loader_a_t& loader_a,
-    thread loader_b_t& loader_b,
-    const int k_iterations,
-    const short tgp_bm,
-    const short tgp_bn,
-    const short tgp_bk) {
-  for (int k = 0; k < k_iterations; k++) {
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // Load elements into threadgroup memory
-    if (rows_aligned) {
-      loader_a.load_unsafe();
-    } else {
-      loader_a.load_safe(short2(tgp_bk, tgp_bm));
-    }
-    if (cols_aligned) {
-      loader_b.load_unsafe();
-    } else {
-      loader_b.load_safe(
-          transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // Multiply and accumulate threadgroup elements
-    mma_op.mma(As, Bs);
-
-    // Prepare for next iteration
-    loader_a.next();
-    loader_b.next();
-  }
-}
-
-template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
-METAL_FUNC void gemm_loop_finalize(
-    threadgroup T* As,
-    threadgroup T* Bs,
-    thread mma_t& mma_op,
-    thread loader_a_t& loader_a,
-    thread loader_b_t& loader_b,
-    const short2 tile_a,
-    const short2 tile_b) {
-  loader_a.load_safe(tile_a);
-  loader_b.load_safe(tile_b);
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-  mma_op.mma(As, Bs);
-}
-
 template <
     typename T,
     int group_size,
diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal
index 56e8e4ca5..f734b9bce 100644
--- a/mlx/backend/metal/kernels/quantized.metal
+++ b/mlx/backend/metal/kernels/quantized.metal
@@ -3,6 +3,7 @@
 // clang-format off
 #include "mlx/backend/metal/kernels/utils.h"
 #include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
+#include "mlx/backend/metal/kernels/quantized_utils.h"
 #include "mlx/backend/metal/kernels/quantized.h"
 
 #define instantiate_quantized(name, type, group_size, bits) \
diff --git a/mlx/backend/metal/kernels/quantized_utils.h b/mlx/backend/metal/kernels/quantized_utils.h
new file mode 100644
index 000000000..38253f8fe
--- /dev/null
+++ b/mlx/backend/metal/kernels/quantized_utils.h
@@ -0,0 +1,90 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#include <metal_simdgroup>
+#include <metal_stdlib>
+
+template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
+METAL_FUNC void gemm_loop_aligned(
+    threadgroup T* As,
+    threadgroup T* Bs,
+    thread mma_t& mma_op,
+    thread loader_a_t& loader_a,
+    thread loader_b_t& loader_b,
+    const int k_iterations) {
+  for (int k = 0; k < k_iterations; k++) {
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // Load elements into threadgroup memory
+    loader_a.load_unsafe();
+    loader_b.load_unsafe();
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // Multiply and accumulate threadgroup elements
+    mma_op.mma(As, Bs);
+
+    // Prepare for next iteration
+    loader_a.next();
+    loader_b.next();
+  }
+}
+
+template <
+    bool rows_aligned,
+    bool cols_aligned,
+    bool transpose,
+    typename T,
+    typename mma_t,
+    typename loader_a_t,
+    typename loader_b_t>
+METAL_FUNC void gemm_loop_unaligned(
+    threadgroup T* As,
+    threadgroup T* Bs,
+    thread mma_t& mma_op,
+    thread loader_a_t& loader_a,
+    thread loader_b_t& loader_b,
+    const int k_iterations,
+    const short tgp_bm,
+    const short tgp_bn,
+    const short tgp_bk) {
+  for (int k = 0; k < k_iterations; k++) {
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // Load elements into threadgroup memory
+    if (rows_aligned) {
+      loader_a.load_unsafe();
+    } else {
+      loader_a.load_safe(short2(tgp_bk, tgp_bm));
+    }
+    if (cols_aligned) {
+      loader_b.load_unsafe();
+    } else {
+      loader_b.load_safe(
+          transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // Multiply and accumulate threadgroup elements
+    mma_op.mma(As, Bs);
+
+    // Prepare for next iteration
+    loader_a.next();
+    loader_b.next();
+  }
+}
+
+template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
+METAL_FUNC void gemm_loop_finalize(
+    threadgroup T* As,
+    threadgroup T* Bs,
+    thread mma_t& mma_op,
+    thread loader_a_t& loader_a,
+    thread loader_b_t& loader_b,
+    const short2 tile_a,
+    const short2 tile_b) {
+  loader_a.load_safe(tile_a);
+  loader_b.load_safe(tile_b);
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  mma_op.mma(As, Bs);
+}