mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-28 14:59:22 +08:00
91 lines
2.3 KiB
C++
91 lines
2.3 KiB
C++
// Copyright © 2023-2024 Apple Inc.
|
|
|
|
#include <metal_simdgroup>
|
|
#include <metal_stdlib>
|
|
|
|
/// Runs the main GEMM loop for `k_iterations` steps, loading both operands
/// through the loaders' `load_unsafe()` path.
///
/// Every step performs, in order: a threadgroup-memory barrier, the A and B
/// loads, a second barrier, `mma_op.mma(As, Bs)`, and finally `next()` on
/// both loaders. Nothing runs when `k_iterations <= 0`.
///
/// NOTE(review): `load_unsafe()` presumably omits bounds checks, so this loop
/// looks intended for tiles that lie fully in range -- confirm against the
/// loader definitions.
///
/// @param As            Threadgroup buffer handed to `mma_op.mma` as the first operand.
/// @param Bs            Threadgroup buffer handed to `mma_op.mma` as the second operand.
/// @param mma_op        Multiply-accumulate object; its `mma(As, Bs)` is called once per step.
/// @param loader_a      Loader for A; `load_unsafe()` and `next()` are called once per step.
/// @param loader_b      Loader for B; `load_unsafe()` and `next()` are called once per step.
/// @param k_iterations  Number of loop steps to run.
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_aligned(
    threadgroup T* As,
    threadgroup T* Bs,
    thread mma_t& mma_op,
    thread loader_a_t& loader_a,
    thread loader_b_t& loader_b,
    const int k_iterations) {
  for (int steps_left = k_iterations; steps_left > 0; --steps_left) {
    // Threadgroup-memory barrier ahead of this step's loads
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Stage the current A and B tiles in threadgroup memory
    loader_a.load_unsafe();
    loader_b.load_unsafe();

    // Second barrier, placed between the loads and the multiply
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Accumulate the product of the staged tiles
    mma_op.mma(As, Bs);

    // Move both loaders on to the next step
    loader_a.next();
    loader_b.next();
  }
}
|
|
|
|
/// Runs the main GEMM loop for `k_iterations` steps, choosing per operand
/// between the loaders' `load_unsafe()` and `load_safe(tile)` paths.
///
/// The choice is driven by the template flags:
///   - A uses `load_unsafe()` when `rows_aligned` is true, otherwise
///     `load_safe(short2(tgp_bk, tgp_bm))`.
///   - B uses `load_unsafe()` when `cols_aligned` is true, otherwise
///     `load_safe(...)` with `short2(tgp_bk, tgp_bn)` if `transpose` is set
///     and `short2(tgp_bn, tgp_bk)` if it is not.
///
/// Every step performs, in order: a threadgroup-memory barrier, the A and B
/// loads, a second barrier, `mma_op.mma(As, Bs)`, and finally `next()` on
/// both loaders. Nothing runs when `k_iterations <= 0`.
///
/// @param As            Threadgroup buffer handed to `mma_op.mma` as the first operand.
/// @param Bs            Threadgroup buffer handed to `mma_op.mma` as the second operand.
/// @param mma_op        Multiply-accumulate object; its `mma(As, Bs)` is called once per step.
/// @param loader_a      Loader for A.
/// @param loader_b      Loader for B.
/// @param k_iterations  Number of loop steps to run.
/// @param tgp_bm        Extent forwarded to `loader_a.load_safe` (second component).
/// @param tgp_bn        Extent forwarded to `loader_b.load_safe`.
/// @param tgp_bk        Extent forwarded to both `load_safe` calls.
template <
    bool rows_aligned,
    bool cols_aligned,
    bool transpose,
    typename T,
    typename mma_t,
    typename loader_a_t,
    typename loader_b_t>
METAL_FUNC void gemm_loop_unaligned(
    threadgroup T* As,
    threadgroup T* Bs,
    thread mma_t& mma_op,
    thread loader_a_t& loader_a,
    thread loader_b_t& loader_b,
    const int k_iterations,
    const short tgp_bm,
    const short tgp_bn,
    const short tgp_bk) {
  // The tile extents never change between steps, so build them once up front.
  const short2 tile_a = short2(tgp_bk, tgp_bm);
  const short2 tile_b =
      transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk);

  for (int steps_left = k_iterations; steps_left > 0; --steps_left) {
    // Threadgroup-memory barrier ahead of this step's loads
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Stage the A tile, using the tile-limited load unless rows are aligned
    if (!rows_aligned) {
      loader_a.load_safe(tile_a);
    } else {
      loader_a.load_unsafe();
    }

    // Stage the B tile, using the tile-limited load unless columns are aligned
    if (!cols_aligned) {
      loader_b.load_safe(tile_b);
    } else {
      loader_b.load_unsafe();
    }

    // Second barrier, placed between the loads and the multiply
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Accumulate the product of the staged tiles
    mma_op.mma(As, Bs);

    // Move both loaders on to the next step
    loader_a.next();
    loader_b.next();
  }
}
|
|
|
|
/// Performs one final load-and-multiply step using the loaders'
/// `load_safe(tile)` path for both operands.
///
/// Sequence: `loader_a.load_safe(tile_a)`, `loader_b.load_safe(tile_b)`, a
/// threadgroup-memory barrier, then `mma_op.mma(As, Bs)`. The loaders are not
/// advanced afterwards.
///
/// NOTE(review): unlike the loop helpers, no barrier is issued before the
/// loads here; presumably the caller synchronizes the threadgroup before
/// calling this -- confirm at the call sites.
///
/// @param As        Threadgroup buffer handed to `mma_op.mma` as the first operand.
/// @param Bs        Threadgroup buffer handed to `mma_op.mma` as the second operand.
/// @param mma_op    Multiply-accumulate object; its `mma(As, Bs)` is called once.
/// @param loader_a  Loader for A; receives `tile_a`.
/// @param loader_b  Loader for B; receives `tile_b`.
/// @param tile_a    Tile extents forwarded to `loader_a.load_safe`.
/// @param tile_b    Tile extents forwarded to `loader_b.load_safe`.
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_finalize(
    threadgroup T* As,
    threadgroup T* Bs,
    thread mma_t& mma_op,
    thread loader_a_t& loader_a,
    thread loader_b_t& loader_b,
    const short2 tile_a,
    const short2 tile_b) {
  // Stage the last A and B tiles through the tile-limited load path
  loader_a.load_safe(tile_a);
  loader_b.load_safe(tile_b);

  // Barrier between the loads and the multiply
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Accumulate the product of the staged tiles
  mma_op.mma(As, Bs);
}
|