mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Gather qmm batched kernel and refactoring of quantized (#2078)
This commit is contained in:
committed by
GitHub
parent
99eefd2ec0
commit
5de6d94a90
@@ -3,6 +3,10 @@
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_stdlib>
|
||||
|
||||
constant bool align_M [[function_constant(200)]];
|
||||
constant bool align_N [[function_constant(201)]];
|
||||
constant bool align_K [[function_constant(202)]];
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#define MLX_MTL_CONST static constant constexpr const
|
||||
@@ -1686,26 +1690,26 @@ template <
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
[[kernel]] void bs_qmv_fast(
|
||||
[[kernel]] void gather_qmv_fast(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
const device T* biases [[buffer(2)]],
|
||||
const device T* x [[buffer(3)]],
|
||||
device T* y [[buffer(4)]],
|
||||
const constant int& in_vec_size [[buffer(5)]],
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
const constant int& x_batch_ndims [[buffer(7)]],
|
||||
const constant int* x_shape [[buffer(8)]],
|
||||
const constant int64_t* x_strides [[buffer(9)]],
|
||||
const constant int& w_batch_ndims [[buffer(10)]],
|
||||
const constant int* w_shape [[buffer(11)]],
|
||||
const constant int64_t* w_strides [[buffer(12)]],
|
||||
const constant int64_t* s_strides [[buffer(13)]],
|
||||
const constant int64_t* b_strides [[buffer(14)]],
|
||||
const constant int& batch_ndims [[buffer(15)]],
|
||||
const constant int* batch_shape [[buffer(16)]],
|
||||
const device uint32_t* lhs_indices [[buffer(17)]],
|
||||
const device uint32_t* rhs_indices [[buffer(18)]],
|
||||
const device uint32_t* lhs_indices [[buffer(4)]],
|
||||
const device uint32_t* rhs_indices [[buffer(5)]],
|
||||
device T* y [[buffer(6)]],
|
||||
const constant int& in_vec_size [[buffer(7)]],
|
||||
const constant int& out_vec_size [[buffer(8)]],
|
||||
const constant int& x_batch_ndims [[buffer(9)]],
|
||||
const constant int* x_shape [[buffer(10)]],
|
||||
const constant int64_t* x_strides [[buffer(11)]],
|
||||
const constant int& w_batch_ndims [[buffer(12)]],
|
||||
const constant int* w_shape [[buffer(13)]],
|
||||
const constant int64_t* w_strides [[buffer(14)]],
|
||||
const constant int64_t* s_strides [[buffer(15)]],
|
||||
const constant int64_t* b_strides [[buffer(16)]],
|
||||
const constant int& batch_ndims [[buffer(17)]],
|
||||
const constant int* batch_shape [[buffer(18)]],
|
||||
const constant int64_t* lhs_strides [[buffer(19)]],
|
||||
const constant int64_t* rhs_strides [[buffer(20)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
@@ -1748,26 +1752,26 @@ template <typename T, int group_size, int bits>
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
[[kernel]] void bs_qmv(
|
||||
[[kernel]] void gather_qmv(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
const device T* biases [[buffer(2)]],
|
||||
const device T* x [[buffer(3)]],
|
||||
device T* y [[buffer(4)]],
|
||||
const constant int& in_vec_size [[buffer(5)]],
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
const constant int& x_batch_ndims [[buffer(7)]],
|
||||
const constant int* x_shape [[buffer(8)]],
|
||||
const constant int64_t* x_strides [[buffer(9)]],
|
||||
const constant int& w_batch_ndims [[buffer(10)]],
|
||||
const constant int* w_shape [[buffer(11)]],
|
||||
const constant int64_t* w_strides [[buffer(12)]],
|
||||
const constant int64_t* s_strides [[buffer(13)]],
|
||||
const constant int64_t* b_strides [[buffer(14)]],
|
||||
const constant int& batch_ndims [[buffer(15)]],
|
||||
const constant int* batch_shape [[buffer(16)]],
|
||||
const device uint32_t* lhs_indices [[buffer(17)]],
|
||||
const device uint32_t* rhs_indices [[buffer(18)]],
|
||||
const device uint32_t* lhs_indices [[buffer(4)]],
|
||||
const device uint32_t* rhs_indices [[buffer(5)]],
|
||||
device T* y [[buffer(6)]],
|
||||
const constant int& in_vec_size [[buffer(7)]],
|
||||
const constant int& out_vec_size [[buffer(8)]],
|
||||
const constant int& x_batch_ndims [[buffer(9)]],
|
||||
const constant int* x_shape [[buffer(10)]],
|
||||
const constant int64_t* x_strides [[buffer(11)]],
|
||||
const constant int& w_batch_ndims [[buffer(12)]],
|
||||
const constant int* w_shape [[buffer(13)]],
|
||||
const constant int64_t* w_strides [[buffer(14)]],
|
||||
const constant int64_t* s_strides [[buffer(15)]],
|
||||
const constant int64_t* b_strides [[buffer(16)]],
|
||||
const constant int& batch_ndims [[buffer(17)]],
|
||||
const constant int* batch_shape [[buffer(18)]],
|
||||
const constant int64_t* lhs_strides [[buffer(19)]],
|
||||
const constant int64_t* rhs_strides [[buffer(20)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
@@ -1810,26 +1814,26 @@ template <typename T, int group_size, int bits>
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
[[kernel]] void bs_qvm(
|
||||
[[kernel]] void gather_qvm(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
const device T* biases [[buffer(2)]],
|
||||
const device T* x [[buffer(3)]],
|
||||
device T* y [[buffer(4)]],
|
||||
const constant int& in_vec_size [[buffer(5)]],
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
const constant int& x_batch_ndims [[buffer(7)]],
|
||||
const constant int* x_shape [[buffer(8)]],
|
||||
const constant int64_t* x_strides [[buffer(9)]],
|
||||
const constant int& w_batch_ndims [[buffer(10)]],
|
||||
const constant int* w_shape [[buffer(11)]],
|
||||
const constant int64_t* w_strides [[buffer(12)]],
|
||||
const constant int64_t* s_strides [[buffer(13)]],
|
||||
const constant int64_t* b_strides [[buffer(14)]],
|
||||
const constant int& batch_ndims [[buffer(15)]],
|
||||
const constant int* batch_shape [[buffer(16)]],
|
||||
const device uint32_t* lhs_indices [[buffer(17)]],
|
||||
const device uint32_t* rhs_indices [[buffer(18)]],
|
||||
const device uint32_t* lhs_indices [[buffer(4)]],
|
||||
const device uint32_t* rhs_indices [[buffer(5)]],
|
||||
device T* y [[buffer(6)]],
|
||||
const constant int& in_vec_size [[buffer(7)]],
|
||||
const constant int& out_vec_size [[buffer(8)]],
|
||||
const constant int& x_batch_ndims [[buffer(9)]],
|
||||
const constant int* x_shape [[buffer(10)]],
|
||||
const constant int64_t* x_strides [[buffer(11)]],
|
||||
const constant int& w_batch_ndims [[buffer(12)]],
|
||||
const constant int* w_shape [[buffer(13)]],
|
||||
const constant int64_t* w_strides [[buffer(14)]],
|
||||
const constant int64_t* s_strides [[buffer(15)]],
|
||||
const constant int64_t* b_strides [[buffer(16)]],
|
||||
const constant int& batch_ndims [[buffer(17)]],
|
||||
const constant int* batch_shape [[buffer(18)]],
|
||||
const constant int64_t* lhs_strides [[buffer(19)]],
|
||||
const constant int64_t* rhs_strides [[buffer(20)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
@@ -1879,27 +1883,27 @@ template <
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
[[kernel]] void bs_qmm_t(
|
||||
[[kernel]] void gather_qmm_t(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
const device T* biases [[buffer(2)]],
|
||||
const device T* x [[buffer(3)]],
|
||||
device T* y [[buffer(4)]],
|
||||
const constant int& K [[buffer(5)]],
|
||||
const constant int& N [[buffer(6)]],
|
||||
const constant int& M [[buffer(7)]],
|
||||
const constant int& x_batch_ndims [[buffer(8)]],
|
||||
const constant int* x_shape [[buffer(9)]],
|
||||
const constant int64_t* x_strides [[buffer(10)]],
|
||||
const constant int& w_batch_ndims [[buffer(11)]],
|
||||
const constant int* w_shape [[buffer(12)]],
|
||||
const constant int64_t* w_strides [[buffer(13)]],
|
||||
const constant int64_t* s_strides [[buffer(14)]],
|
||||
const constant int64_t* b_strides [[buffer(15)]],
|
||||
const constant int& batch_ndims [[buffer(16)]],
|
||||
const constant int* batch_shape [[buffer(17)]],
|
||||
const device uint32_t* lhs_indices [[buffer(18)]],
|
||||
const device uint32_t* rhs_indices [[buffer(19)]],
|
||||
const device uint32_t* lhs_indices [[buffer(4)]],
|
||||
const device uint32_t* rhs_indices [[buffer(5)]],
|
||||
device T* y [[buffer(6)]],
|
||||
const constant int& K [[buffer(7)]],
|
||||
const constant int& N [[buffer(8)]],
|
||||
const constant int& M [[buffer(9)]],
|
||||
const constant int& x_batch_ndims [[buffer(10)]],
|
||||
const constant int* x_shape [[buffer(11)]],
|
||||
const constant int64_t* x_strides [[buffer(12)]],
|
||||
const constant int& w_batch_ndims [[buffer(13)]],
|
||||
const constant int* w_shape [[buffer(14)]],
|
||||
const constant int64_t* w_strides [[buffer(15)]],
|
||||
const constant int64_t* s_strides [[buffer(16)]],
|
||||
const constant int64_t* b_strides [[buffer(17)]],
|
||||
const constant int& batch_ndims [[buffer(18)]],
|
||||
const constant int* batch_shape [[buffer(19)]],
|
||||
const constant int64_t* lhs_strides [[buffer(20)]],
|
||||
const constant int64_t* rhs_strides [[buffer(21)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
@@ -1946,27 +1950,27 @@ template <
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
[[kernel]] void bs_qmm_n(
|
||||
[[kernel]] void gather_qmm_n(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
const device T* biases [[buffer(2)]],
|
||||
const device T* x [[buffer(3)]],
|
||||
device T* y [[buffer(4)]],
|
||||
const constant int& K [[buffer(5)]],
|
||||
const constant int& N [[buffer(6)]],
|
||||
const constant int& M [[buffer(7)]],
|
||||
const constant int& x_batch_ndims [[buffer(8)]],
|
||||
const constant int* x_shape [[buffer(9)]],
|
||||
const constant int64_t* x_strides [[buffer(10)]],
|
||||
const constant int& w_batch_ndims [[buffer(11)]],
|
||||
const constant int* w_shape [[buffer(12)]],
|
||||
const constant int64_t* w_strides [[buffer(13)]],
|
||||
const constant int64_t* s_strides [[buffer(14)]],
|
||||
const constant int64_t* b_strides [[buffer(15)]],
|
||||
const constant int& batch_ndims [[buffer(16)]],
|
||||
const constant int* batch_shape [[buffer(17)]],
|
||||
const device uint32_t* lhs_indices [[buffer(18)]],
|
||||
const device uint32_t* rhs_indices [[buffer(19)]],
|
||||
const device uint32_t* lhs_indices [[buffer(4)]],
|
||||
const device uint32_t* rhs_indices [[buffer(5)]],
|
||||
device T* y [[buffer(6)]],
|
||||
const constant int& K [[buffer(7)]],
|
||||
const constant int& N [[buffer(8)]],
|
||||
const constant int& M [[buffer(9)]],
|
||||
const constant int& x_batch_ndims [[buffer(10)]],
|
||||
const constant int* x_shape [[buffer(11)]],
|
||||
const constant int64_t* x_strides [[buffer(12)]],
|
||||
const constant int& w_batch_ndims [[buffer(13)]],
|
||||
const constant int* w_shape [[buffer(14)]],
|
||||
const constant int64_t* w_strides [[buffer(15)]],
|
||||
const constant int64_t* s_strides [[buffer(16)]],
|
||||
const constant int64_t* b_strides [[buffer(17)]],
|
||||
const constant int& batch_ndims [[buffer(18)]],
|
||||
const constant int* batch_shape [[buffer(19)]],
|
||||
const constant int64_t* lhs_strides [[buffer(20)]],
|
||||
const constant int64_t* rhs_strides [[buffer(21)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
@@ -2007,6 +2011,289 @@ template <
|
||||
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
|
||||
}
|
||||
|
||||
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
|
||||
METAL_FUNC void gemm_loop_aligned(
|
||||
threadgroup T* As,
|
||||
threadgroup T* Bs,
|
||||
thread mma_t& mma_op,
|
||||
thread loader_a_t& loader_a,
|
||||
thread loader_b_t& loader_b,
|
||||
const int k_iterations) {
|
||||
for (int k = 0; k < k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load elements into threadgroup memory
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
bool rows_aligned,
|
||||
bool cols_aligned,
|
||||
bool transpose,
|
||||
typename T,
|
||||
typename mma_t,
|
||||
typename loader_a_t,
|
||||
typename loader_b_t>
|
||||
METAL_FUNC void gemm_loop_unaligned(
|
||||
threadgroup T* As,
|
||||
threadgroup T* Bs,
|
||||
thread mma_t& mma_op,
|
||||
thread loader_a_t& loader_a,
|
||||
thread loader_b_t& loader_b,
|
||||
const int k_iterations,
|
||||
const short tgp_bm,
|
||||
const short tgp_bn,
|
||||
const short tgp_bk) {
|
||||
for (int k = 0; k < k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load elements into threadgroup memory
|
||||
if (rows_aligned) {
|
||||
loader_a.load_unsafe();
|
||||
} else {
|
||||
loader_a.load_safe(short2(tgp_bk, tgp_bm));
|
||||
}
|
||||
if (cols_aligned) {
|
||||
loader_b.load_unsafe();
|
||||
} else {
|
||||
loader_b.load_safe(
|
||||
transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
|
||||
METAL_FUNC void gemm_loop_finalize(
|
||||
threadgroup T* As,
|
||||
threadgroup T* Bs,
|
||||
thread mma_t& mma_op,
|
||||
thread loader_a_t& loader_a,
|
||||
thread loader_b_t& loader_b,
|
||||
const short2 tile_a,
|
||||
const short2 tile_b) {
|
||||
loader_a.load_safe(tile_a);
|
||||
loader_b.load_safe(tile_b);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int group_size,
|
||||
int bits,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose>
|
||||
[[kernel]] void gather_qmm_rhs(
|
||||
const device T* x [[buffer(0)]],
|
||||
const device uint32_t* w [[buffer(1)]],
|
||||
const device T* scales [[buffer(2)]],
|
||||
const device T* biases [[buffer(3)]],
|
||||
const device uint32_t* indices [[buffer(4)]],
|
||||
device T* y [[buffer(5)]],
|
||||
const constant int& M [[buffer(6)]],
|
||||
const constant int& N [[buffer(7)]],
|
||||
const constant int& K [[buffer(8)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]]) {
|
||||
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
|
||||
constexpr int BK_padded = (BK + 16 / sizeof(T));
|
||||
constexpr int BN_padded = (BN + 16 / sizeof(T));
|
||||
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
|
||||
|
||||
using mma_t = mlx::steel::BlockMMA<
|
||||
T,
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
false,
|
||||
transpose,
|
||||
BK_padded,
|
||||
transpose ? BK_padded : BN_padded>;
|
||||
using loader_x_t =
|
||||
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
|
||||
using loader_w_t = QuantizedBlockLoader<
|
||||
T,
|
||||
transpose ? BN : BK,
|
||||
transpose ? BK : BN,
|
||||
transpose ? BK_padded : BN_padded,
|
||||
transpose,
|
||||
WM * WN * SIMD_SIZE,
|
||||
group_size,
|
||||
bits>;
|
||||
|
||||
threadgroup T Xs[BM * BK_padded];
|
||||
threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
|
||||
|
||||
// Compute the block
|
||||
const int K_w = K * bytes_per_pack / pack_factor;
|
||||
const int K_g = K / group_size;
|
||||
const int N_w = N * bytes_per_pack / pack_factor;
|
||||
const int N_g = N / group_size;
|
||||
const int K_it = K / BK;
|
||||
const size_t stride_w = transpose ? N * K_w : K * N_w;
|
||||
const size_t stride_s = transpose ? N * K_g : K * N_g;
|
||||
const int y_row = tid.y * BM;
|
||||
const int y_col = tid.x * BN;
|
||||
const size_t y_row_long = size_t(y_row);
|
||||
const size_t y_col_long = size_t(y_col);
|
||||
|
||||
// Prepare threadgroup bounds
|
||||
const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));
|
||||
const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));
|
||||
|
||||
// Calculate the final tiles in the case that K is not aligned
|
||||
const int k_remain = K - K_it * BK;
|
||||
const short2 tile_x = short2(k_remain, tgp_bm);
|
||||
const short2 tile_w =
|
||||
transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
|
||||
|
||||
// Move x and output to the correct block
|
||||
auto wl = (const device uint8_t*)w;
|
||||
x += y_row_long * K;
|
||||
y += y_row_long * N + y_col_long;
|
||||
wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;
|
||||
scales += transpose ? y_col_long * K_g : y_col / group_size;
|
||||
biases += transpose ? y_col_long * K_g : y_col / group_size;
|
||||
|
||||
// Do as many matmuls as necessary
|
||||
uint32_t index;
|
||||
short offset;
|
||||
uint32_t index_next = indices[y_row];
|
||||
short offset_next = 0;
|
||||
int n = 0;
|
||||
while (n < tgp_bm) {
|
||||
n++;
|
||||
offset = offset_next;
|
||||
index = index_next;
|
||||
offset_next = tgp_bm;
|
||||
for (; n < tgp_bm; n++) {
|
||||
if (indices[y_row + n] != index) {
|
||||
offset_next = n;
|
||||
index_next = indices[y_row + n];
|
||||
break;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id);
|
||||
thread loader_w_t loader_w(
|
||||
wl + index * stride_w,
|
||||
scales + index * stride_s,
|
||||
biases + index * stride_s,
|
||||
transpose ? K : N,
|
||||
Ws,
|
||||
simd_group_id,
|
||||
simd_lane_id);
|
||||
|
||||
// Matrices are all aligned check nothing
|
||||
if (align_M && align_N) {
|
||||
gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
|
||||
if (!align_K) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
if (offset_next - offset == BM) {
|
||||
mma_op.store_result(y, N);
|
||||
} else {
|
||||
mma_op.store_result_slice(
|
||||
y, N, short2(0, offset), short2(BN, offset_next));
|
||||
}
|
||||
} else {
|
||||
// Tile aligned so check outside of the hot loop
|
||||
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
|
||||
gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
|
||||
if (!align_K) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
gemm_loop_finalize(
|
||||
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
if (offset_next - offset == BM) {
|
||||
mma_op.store_result(y, N);
|
||||
} else {
|
||||
mma_op.store_result_slice(
|
||||
y, N, short2(0, offset), short2(BN, offset_next));
|
||||
}
|
||||
}
|
||||
|
||||
// Tile partially aligned check rows
|
||||
else if (align_N || tgp_bn == BN) {
|
||||
gemm_loop_unaligned<false, true, transpose>(
|
||||
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
|
||||
if (!align_K) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
gemm_loop_finalize(
|
||||
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
|
||||
}
|
||||
mma_op.store_result_slice(
|
||||
y, N, short2(0, offset), short2(BN, offset_next));
|
||||
}
|
||||
|
||||
// Tile partially aligned check cols
|
||||
else if (align_M || tgp_bm == BM) {
|
||||
gemm_loop_unaligned<true, false, transpose>(
|
||||
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
|
||||
if (!align_K) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
gemm_loop_finalize(
|
||||
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
|
||||
}
|
||||
mma_op.store_result_slice(
|
||||
y, N, short2(0, offset), short2(tgp_bn, offset_next));
|
||||
}
|
||||
|
||||
// Nothing aligned so check both rows and cols
|
||||
else {
|
||||
gemm_loop_unaligned<false, false, transpose>(
|
||||
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
|
||||
if (!align_K) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
gemm_loop_finalize(
|
||||
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
|
||||
}
|
||||
mma_op.store_result_slice(
|
||||
y, N, short2(0, offset), short2(tgp_bn, offset_next));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, const int bits>
|
||||
[[kernel]] void affine_quantize(
|
||||
const device T* w [[buffer(0)]],
|
||||
|
||||
Reference in New Issue
Block a user