// Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 09:21:16 +08:00) - 138 lines, 3.7 KiB, C++
// Copyright © 2024 Apple Inc.
|
|
|
|
#pragma once
|
|
|
|
#include "mlx/backend/metal/kernels/steel/defines.h"
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace mlx {
|
|
namespace steel {
|
|
|
|
template <
|
|
typename T,
|
|
short BROWS,
|
|
short BCOLS,
|
|
short dst_ld,
|
|
short reduction_dim,
|
|
short tgp_size,
|
|
short alignment = 1,
|
|
short n_reads = (BCOLS * BROWS) / (tgp_size),
|
|
short TCOLS = BCOLS / n_reads,
|
|
short TROWS = tgp_size / TCOLS>
|
|
struct BlockLoader {
|
|
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
|
|
STEEL_CONST short vec_size = n_reads;
|
|
|
|
// Leading dimension for src
|
|
const int src_ld;
|
|
const int tile_stride;
|
|
|
|
// Thread location indices
|
|
const short thread_idx;
|
|
const short bi;
|
|
const short bj;
|
|
|
|
// threadgroup and device memory
|
|
threadgroup T* dst;
|
|
const device T* src;
|
|
|
|
struct alignas(alignment * sizeof(T)) ReadVector {
|
|
uint8_t v[sizeof(T) * vec_size];
|
|
};
|
|
|
|
/* Constructor */
|
|
METAL_FUNC BlockLoader(
|
|
const device T* src_,
|
|
const int src_ld_,
|
|
threadgroup T* dst_,
|
|
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
|
: src_ld(src_ld_),
|
|
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
|
|
thread_idx(simd_group_id * 32 + simd_lane_id),
|
|
bi(thread_idx / TCOLS),
|
|
bj(vec_size * (thread_idx % TCOLS)),
|
|
dst(dst_ + bi * dst_ld + bj),
|
|
src(src_ + bi * src_ld + bj) {}
|
|
|
|
/* Apply operation to threadgroup without bound checking */
|
|
template <typename UnaryOp>
|
|
METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short i = 0; i < BROWS; i += TROWS) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short j = 0; j < vec_size; j++) {
|
|
dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
|
|
}
|
|
}
|
|
}
|
|
|
|
/* Load from device memory into threadgroup memory - without bound checking */
|
|
METAL_FUNC void load_unsafe() const {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short i = 0; i < BROWS; i += TROWS) {
|
|
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
|
|
*((const device ReadVector*)(&src[i * src_ld]));
|
|
}
|
|
}
|
|
|
|
/* Load from device memory into threadgroup memory - with bound checking */
|
|
METAL_FUNC void load_safe(short2 src_tile_dim) const {
|
|
src_tile_dim = src_tile_dim - short2(bj, bi);
|
|
|
|
// Skip loading if thread has no valid reads
|
|
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short i = 0; i < BROWS; i += TROWS) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short j = 0; j < vec_size; j++) {
|
|
dst[i * dst_ld + j] = T(0);
|
|
}
|
|
}
|
|
return;
|
|
}
|
|
|
|
// Use fast thread memory for bound checks
|
|
bool tmp_idx[vec_size];
|
|
T tmp_val[vec_size];
|
|
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short i = 0; i < BROWS; i += TROWS) {
|
|
// Make sure tmp_idx only contains valid indices
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short j = 0; j < vec_size; j++) {
|
|
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
|
|
}
|
|
|
|
// Read valid indices into tmp_val
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short j = 0; j < vec_size; j++) {
|
|
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
|
|
}
|
|
|
|
// Zero out unneeded values
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short j = 0; j < vec_size; j++) {
|
|
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
|
|
}
|
|
|
|
// Copy values to threadgroup memory
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short j = 0; j < vec_size; j++) {
|
|
dst[i * dst_ld + j] = tmp_val[j];
|
|
}
|
|
}
|
|
}
|
|
|
|
/* Iteration helper */
|
|
METAL_FUNC void next() {
|
|
src += tile_stride;
|
|
}
|
|
};
|
|
|
|
} // namespace steel
|
|
} // namespace mlx
|