mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-29 07:03:10 +08:00
refactor
This commit is contained in:
parent
27e31ab249
commit
a9716cd34c
@ -108,8 +108,8 @@ if(NOT MLX_METAL_JIT)
|
|||||||
reduction/reduce_all.h
|
reduction/reduce_all.h
|
||||||
reduction/reduce_col.h
|
reduction/reduce_col.h
|
||||||
reduction/reduce_row.h)
|
reduction/reduce_row.h)
|
||||||
build_kernel(quantized quantized.h ${STEEL_HEADERS})
|
build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS})
|
||||||
build_kernel(fp4_quantized fp4_quantized.h ${STEEL_HEADERS})
|
build_kernel(fp4_quantized fp4_quantized.h quantized_utils.h ${STEEL_HEADERS})
|
||||||
build_kernel(scan scan.h)
|
build_kernel(scan scan.h)
|
||||||
build_kernel(softmax softmax.h)
|
build_kernel(softmax softmax.h)
|
||||||
build_kernel(logsumexp logsumexp.h)
|
build_kernel(logsumexp logsumexp.h)
|
||||||
|
@ -1595,92 +1595,6 @@ template <
|
|||||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
|
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
|
|
||||||
METAL_FUNC void gemm_loop_aligned(
|
|
||||||
threadgroup T* As,
|
|
||||||
threadgroup T* Bs,
|
|
||||||
thread mma_t& mma_op,
|
|
||||||
thread loader_a_t& loader_a,
|
|
||||||
thread loader_b_t& loader_b,
|
|
||||||
const int k_iterations) {
|
|
||||||
for (int k = 0; k < k_iterations; k++) {
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Load elements into threadgroup memory
|
|
||||||
loader_a.load_unsafe();
|
|
||||||
loader_b.load_unsafe();
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Multiply and accumulate threadgroup elements
|
|
||||||
mma_op.mma(As, Bs);
|
|
||||||
|
|
||||||
// Prepare for next iteration
|
|
||||||
loader_a.next();
|
|
||||||
loader_b.next();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <
|
|
||||||
bool rows_aligned,
|
|
||||||
bool cols_aligned,
|
|
||||||
bool transpose,
|
|
||||||
typename T,
|
|
||||||
typename mma_t,
|
|
||||||
typename loader_a_t,
|
|
||||||
typename loader_b_t>
|
|
||||||
METAL_FUNC void gemm_loop_unaligned(
|
|
||||||
threadgroup T* As,
|
|
||||||
threadgroup T* Bs,
|
|
||||||
thread mma_t& mma_op,
|
|
||||||
thread loader_a_t& loader_a,
|
|
||||||
thread loader_b_t& loader_b,
|
|
||||||
const int k_iterations,
|
|
||||||
const short tgp_bm,
|
|
||||||
const short tgp_bn,
|
|
||||||
const short tgp_bk) {
|
|
||||||
for (int k = 0; k < k_iterations; k++) {
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Load elements into threadgroup memory
|
|
||||||
if (rows_aligned) {
|
|
||||||
loader_a.load_unsafe();
|
|
||||||
} else {
|
|
||||||
loader_a.load_safe(short2(tgp_bk, tgp_bm));
|
|
||||||
}
|
|
||||||
if (cols_aligned) {
|
|
||||||
loader_b.load_unsafe();
|
|
||||||
} else {
|
|
||||||
loader_b.load_safe(
|
|
||||||
transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Multiply and accumulate threadgroup elements
|
|
||||||
mma_op.mma(As, Bs);
|
|
||||||
|
|
||||||
// Prepare for next iteration
|
|
||||||
loader_a.next();
|
|
||||||
loader_b.next();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
|
|
||||||
METAL_FUNC void gemm_loop_finalize(
|
|
||||||
threadgroup T* As,
|
|
||||||
threadgroup T* Bs,
|
|
||||||
thread mma_t& mma_op,
|
|
||||||
thread loader_a_t& loader_a,
|
|
||||||
thread loader_b_t& loader_b,
|
|
||||||
const short2 tile_a,
|
|
||||||
const short2 tile_b) {
|
|
||||||
loader_a.load_safe(tile_a);
|
|
||||||
loader_b.load_safe(tile_b);
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
mma_op.mma(As, Bs);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
int group_size,
|
int group_size,
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
// clang-format off
|
// clang-format off
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||||
|
#include "mlx/backend/metal/kernels/quantized_utils.h"
|
||||||
#include "mlx/backend/metal/kernels/fp4_quantized.h"
|
#include "mlx/backend/metal/kernels/fp4_quantized.h"
|
||||||
|
|
||||||
#define instantiate_quantized(name, type) \
|
#define instantiate_quantized(name, type) \
|
||||||
|
@ -2138,92 +2138,6 @@ template <
|
|||||||
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
|
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
|
|
||||||
METAL_FUNC void gemm_loop_aligned(
|
|
||||||
threadgroup T* As,
|
|
||||||
threadgroup T* Bs,
|
|
||||||
thread mma_t& mma_op,
|
|
||||||
thread loader_a_t& loader_a,
|
|
||||||
thread loader_b_t& loader_b,
|
|
||||||
const int k_iterations) {
|
|
||||||
for (int k = 0; k < k_iterations; k++) {
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Load elements into threadgroup memory
|
|
||||||
loader_a.load_unsafe();
|
|
||||||
loader_b.load_unsafe();
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Multiply and accumulate threadgroup elements
|
|
||||||
mma_op.mma(As, Bs);
|
|
||||||
|
|
||||||
// Prepare for next iteration
|
|
||||||
loader_a.next();
|
|
||||||
loader_b.next();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <
|
|
||||||
bool rows_aligned,
|
|
||||||
bool cols_aligned,
|
|
||||||
bool transpose,
|
|
||||||
typename T,
|
|
||||||
typename mma_t,
|
|
||||||
typename loader_a_t,
|
|
||||||
typename loader_b_t>
|
|
||||||
METAL_FUNC void gemm_loop_unaligned(
|
|
||||||
threadgroup T* As,
|
|
||||||
threadgroup T* Bs,
|
|
||||||
thread mma_t& mma_op,
|
|
||||||
thread loader_a_t& loader_a,
|
|
||||||
thread loader_b_t& loader_b,
|
|
||||||
const int k_iterations,
|
|
||||||
const short tgp_bm,
|
|
||||||
const short tgp_bn,
|
|
||||||
const short tgp_bk) {
|
|
||||||
for (int k = 0; k < k_iterations; k++) {
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Load elements into threadgroup memory
|
|
||||||
if (rows_aligned) {
|
|
||||||
loader_a.load_unsafe();
|
|
||||||
} else {
|
|
||||||
loader_a.load_safe(short2(tgp_bk, tgp_bm));
|
|
||||||
}
|
|
||||||
if (cols_aligned) {
|
|
||||||
loader_b.load_unsafe();
|
|
||||||
} else {
|
|
||||||
loader_b.load_safe(
|
|
||||||
transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Multiply and accumulate threadgroup elements
|
|
||||||
mma_op.mma(As, Bs);
|
|
||||||
|
|
||||||
// Prepare for next iteration
|
|
||||||
loader_a.next();
|
|
||||||
loader_b.next();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
|
|
||||||
METAL_FUNC void gemm_loop_finalize(
|
|
||||||
threadgroup T* As,
|
|
||||||
threadgroup T* Bs,
|
|
||||||
thread mma_t& mma_op,
|
|
||||||
thread loader_a_t& loader_a,
|
|
||||||
thread loader_b_t& loader_b,
|
|
||||||
const short2 tile_a,
|
|
||||||
const short2 tile_b) {
|
|
||||||
loader_a.load_safe(tile_a);
|
|
||||||
loader_b.load_safe(tile_b);
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
mma_op.mma(As, Bs);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
int group_size,
|
int group_size,
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
// clang-format off
|
// clang-format off
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||||
|
#include "mlx/backend/metal/kernels/quantized_utils.h"
|
||||||
#include "mlx/backend/metal/kernels/quantized.h"
|
#include "mlx/backend/metal/kernels/quantized.h"
|
||||||
|
|
||||||
#define instantiate_quantized(name, type, group_size, bits) \
|
#define instantiate_quantized(name, type, group_size, bits) \
|
||||||
|
90
mlx/backend/metal/kernels/quantized_utils.h
Normal file
90
mlx/backend/metal/kernels/quantized_utils.h
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
#include <metal_simdgroup>
|
||||||
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
|
||||||
|
METAL_FUNC void gemm_loop_aligned(
|
||||||
|
threadgroup T* As,
|
||||||
|
threadgroup T* Bs,
|
||||||
|
thread mma_t& mma_op,
|
||||||
|
thread loader_a_t& loader_a,
|
||||||
|
thread loader_b_t& loader_b,
|
||||||
|
const int k_iterations) {
|
||||||
|
for (int k = 0; k < k_iterations; k++) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Load elements into threadgroup memory
|
||||||
|
loader_a.load_unsafe();
|
||||||
|
loader_b.load_unsafe();
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Multiply and accumulate threadgroup elements
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
|
// Prepare for next iteration
|
||||||
|
loader_a.next();
|
||||||
|
loader_b.next();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
bool rows_aligned,
|
||||||
|
bool cols_aligned,
|
||||||
|
bool transpose,
|
||||||
|
typename T,
|
||||||
|
typename mma_t,
|
||||||
|
typename loader_a_t,
|
||||||
|
typename loader_b_t>
|
||||||
|
METAL_FUNC void gemm_loop_unaligned(
|
||||||
|
threadgroup T* As,
|
||||||
|
threadgroup T* Bs,
|
||||||
|
thread mma_t& mma_op,
|
||||||
|
thread loader_a_t& loader_a,
|
||||||
|
thread loader_b_t& loader_b,
|
||||||
|
const int k_iterations,
|
||||||
|
const short tgp_bm,
|
||||||
|
const short tgp_bn,
|
||||||
|
const short tgp_bk) {
|
||||||
|
for (int k = 0; k < k_iterations; k++) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Load elements into threadgroup memory
|
||||||
|
if (rows_aligned) {
|
||||||
|
loader_a.load_unsafe();
|
||||||
|
} else {
|
||||||
|
loader_a.load_safe(short2(tgp_bk, tgp_bm));
|
||||||
|
}
|
||||||
|
if (cols_aligned) {
|
||||||
|
loader_b.load_unsafe();
|
||||||
|
} else {
|
||||||
|
loader_b.load_safe(
|
||||||
|
transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Multiply and accumulate threadgroup elements
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
|
// Prepare for next iteration
|
||||||
|
loader_a.next();
|
||||||
|
loader_b.next();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
|
||||||
|
METAL_FUNC void gemm_loop_finalize(
|
||||||
|
threadgroup T* As,
|
||||||
|
threadgroup T* Bs,
|
||||||
|
thread mma_t& mma_op,
|
||||||
|
thread loader_a_t& loader_a,
|
||||||
|
thread loader_b_t& loader_b,
|
||||||
|
const short2 tile_a,
|
||||||
|
const short2 tile_b) {
|
||||||
|
loader_a.load_safe(tile_a);
|
||||||
|
loader_b.load_safe(tile_b);
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user