Gather qmm batched kernel and refactoring of quantized (#2078)

Angelos Katharopoulos 2025-04-17 13:53:11 -07:00 committed by GitHub
parent 99eefd2ec0
commit 5de6d94a90
15 changed files with 1479 additions and 449 deletions

View File

@ -1,4 +1,4 @@
# Copyright © 2023-2024 Apple Inc.
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn

View File

@ -0,0 +1,84 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn

N = 1024
D = 1024
M = 1024
E = 32
I = 4


def gather_sort(x, indices):
    N, M = indices.shape
    indices = indices.flatten()
    order = mx.argsort(indices)
    inv_order = mx.argsort(order)
    return x.flatten(0, -3)[order // M], indices[order], inv_order


def scatter_unsort(x, inv_order, shape=None):
    x = x[inv_order]
    if shape is not None:
        x = mx.unflatten(x, 0, shape)
    return x


def gather_mm_simulate(x, w, indices):
    x, idx, inv_order = gather_sort(x, indices)
    for i in range(2):
        y = mx.concatenate(
            [
                mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
                for i, j in enumerate(idx.tolist())
            ],
            axis=0,
        )
        x = y[:, None]
    x = scatter_unsort(x, inv_order, indices.shape)
    return x


def time_gather_qmm():
    x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
    w1 = mx.random.normal((E, M, D)) / 1024**0.5
    w2 = mx.random.normal((E, D, M)) / 1024**0.5
    w1 = mx.quantize(w1)
    w2 = mx.quantize(w2)
    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
    mx.eval(x, w1, w2, indices, sorted_indices)

    def gather_mm(x, w1, w2, indices, sort):
        idx = indices
        inv_order = None
        if sort:
            x, idx, inv_order = gather_sort(x, indices)
        x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
        x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
        if sort:
            x = scatter_unsort(x, inv_order, indices.shape)
        return x

    time_fn(gather_mm, x, w1, w2, indices, False)
    time_fn(gather_mm, x, w1, w2, sorted_indices, False)
    time_fn(gather_mm, x, w1, w2, indices, True)

    x = mx.random.normal((N * I, D)) / 1024**0.5
    w1 = mx.random.normal((M, D)) / 1024**0.5
    w2 = mx.random.normal((D, M)) / 1024**0.5
    w1 = mx.quantize(w1)
    w2 = mx.quantize(w2)
    mx.eval(x, w1, w2)

    def equivalent_matmul(x, w1, w2):
        x = mx.quantized_matmul(x, *w1, transpose=True)
        x = mx.quantized_matmul(x, *w2, transpose=True)
        return x

    time_fn(equivalent_matmul, x, w1, w2)


if __name__ == "__main__":
    time_gather_qmm()

View File

@ -752,4 +752,43 @@ MTL::ComputePipelineState* get_quantized_kernel(
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& x,
int group_size,
int bits,
int bm,
int bn,
int bk,
int wm,
int wn,
bool transpose) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source;
concatenate(
kernel_source,
metal::utils(),
metal::gemm(),
metal::quantized(),
get_template_definition(
lib_name,
"gather_qmm_rhs",
get_type_string(x.dtype()),
group_size,
bits,
bm,
bn,
bk,
wm,
wn,
transpose));
return kernel_source;
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
} // namespace mlx::core

View File

@ -224,6 +224,21 @@ MTL::ComputePipelineState* get_quantized_kernel(
const std::string& kernel_name,
const std::string& template_def);
MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& x,
int group_size,
int bits,
int bm,
int bn,
int bk,
int wm,
int wn,
bool transpose);
// Create a GPU kernel template definition for JIT compilation
template <typename... Args>
std::string

View File

@ -3,6 +3,10 @@
#include <metal_simdgroup>
#include <metal_stdlib>
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
using namespace metal;
#define MLX_MTL_CONST static constant constexpr const
@ -1686,26 +1690,26 @@ template <
}
template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv_fast(
[[kernel]] void gather_qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]],
const constant int& x_batch_ndims [[buffer(9)]],
const constant int* x_shape [[buffer(10)]],
const constant int64_t* x_strides [[buffer(11)]],
const constant int& w_batch_ndims [[buffer(12)]],
const constant int* w_shape [[buffer(13)]],
const constant int64_t* w_strides [[buffer(14)]],
const constant int64_t* s_strides [[buffer(15)]],
const constant int64_t* b_strides [[buffer(16)]],
const constant int& batch_ndims [[buffer(17)]],
const constant int* batch_shape [[buffer(18)]],
const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
@ -1748,26 +1752,26 @@ template <typename T, int group_size, int bits>
}
template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv(
[[kernel]] void gather_qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]],
const constant int& x_batch_ndims [[buffer(9)]],
const constant int* x_shape [[buffer(10)]],
const constant int64_t* x_strides [[buffer(11)]],
const constant int& w_batch_ndims [[buffer(12)]],
const constant int* w_shape [[buffer(13)]],
const constant int64_t* w_strides [[buffer(14)]],
const constant int64_t* s_strides [[buffer(15)]],
const constant int64_t* b_strides [[buffer(16)]],
const constant int& batch_ndims [[buffer(17)]],
const constant int* batch_shape [[buffer(18)]],
const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
@ -1810,26 +1814,26 @@ template <typename T, int group_size, int bits>
}
template <typename T, int group_size, int bits>
[[kernel]] void bs_qvm(
[[kernel]] void gather_qvm(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
const constant int64_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]],
const constant int& x_batch_ndims [[buffer(9)]],
const constant int* x_shape [[buffer(10)]],
const constant int64_t* x_strides [[buffer(11)]],
const constant int& w_batch_ndims [[buffer(12)]],
const constant int* w_shape [[buffer(13)]],
const constant int64_t* w_strides [[buffer(14)]],
const constant int64_t* s_strides [[buffer(15)]],
const constant int64_t* b_strides [[buffer(16)]],
const constant int& batch_ndims [[buffer(17)]],
const constant int* batch_shape [[buffer(18)]],
const constant int64_t* lhs_strides [[buffer(19)]],
const constant int64_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
@ -1879,27 +1883,27 @@ template <
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void bs_qmm_t(
[[kernel]] void gather_qmm_t(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& K [[buffer(5)]],
const constant int& N [[buffer(6)]],
const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant int64_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant int64_t* w_strides [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]],
const constant int& batch_ndims [[buffer(16)]],
const constant int* batch_shape [[buffer(17)]],
const device uint32_t* lhs_indices [[buffer(18)]],
const device uint32_t* rhs_indices [[buffer(19)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& K [[buffer(7)]],
const constant int& N [[buffer(8)]],
const constant int& M [[buffer(9)]],
const constant int& x_batch_ndims [[buffer(10)]],
const constant int* x_shape [[buffer(11)]],
const constant int64_t* x_strides [[buffer(12)]],
const constant int& w_batch_ndims [[buffer(13)]],
const constant int* w_shape [[buffer(14)]],
const constant int64_t* w_strides [[buffer(15)]],
const constant int64_t* s_strides [[buffer(16)]],
const constant int64_t* b_strides [[buffer(17)]],
const constant int& batch_ndims [[buffer(18)]],
const constant int* batch_shape [[buffer(19)]],
const constant int64_t* lhs_strides [[buffer(20)]],
const constant int64_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]],
@ -1946,27 +1950,27 @@ template <
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void bs_qmm_n(
[[kernel]] void gather_qmm_n(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& K [[buffer(5)]],
const constant int& N [[buffer(6)]],
const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant int64_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant int64_t* w_strides [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]],
const constant int& batch_ndims [[buffer(16)]],
const constant int* batch_shape [[buffer(17)]],
const device uint32_t* lhs_indices [[buffer(18)]],
const device uint32_t* rhs_indices [[buffer(19)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& K [[buffer(7)]],
const constant int& N [[buffer(8)]],
const constant int& M [[buffer(9)]],
const constant int& x_batch_ndims [[buffer(10)]],
const constant int* x_shape [[buffer(11)]],
const constant int64_t* x_strides [[buffer(12)]],
const constant int& w_batch_ndims [[buffer(13)]],
const constant int* w_shape [[buffer(14)]],
const constant int64_t* w_strides [[buffer(15)]],
const constant int64_t* s_strides [[buffer(16)]],
const constant int64_t* b_strides [[buffer(17)]],
const constant int& batch_ndims [[buffer(18)]],
const constant int* batch_shape [[buffer(19)]],
const constant int64_t* lhs_strides [[buffer(20)]],
const constant int64_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]],
@ -2007,6 +2011,289 @@ template <
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_aligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <
bool rows_aligned,
bool cols_aligned,
bool transpose,
typename T,
typename mma_t,
typename loader_a_t,
typename loader_b_t>
METAL_FUNC void gemm_loop_unaligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations,
const short tgp_bm,
const short tgp_bn,
const short tgp_bk) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
if (rows_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(short2(tgp_bk, tgp_bm));
}
if (cols_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(
transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_finalize(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const short2 tile_a,
const short2 tile_b) {
loader_a.load_safe(tile_a);
loader_b.load_safe(tile_b);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
template <
typename T,
int group_size,
int bits,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose>
[[kernel]] void gather_qmm_rhs(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
const device uint32_t* indices [[buffer(4)]],
device T* y [[buffer(5)]],
const constant int& M [[buffer(6)]],
const constant int& N [[buffer(7)]],
const constant int& K [[buffer(8)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
using mma_t = mlx::steel::BlockMMA<
T,
T,
BM,
BN,
BK,
WM,
WN,
false,
transpose,
BK_padded,
transpose ? BK_padded : BN_padded>;
using loader_x_t =
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
using loader_w_t = QuantizedBlockLoader<
T,
transpose ? BN : BK,
transpose ? BK : BN,
transpose ? BK_padded : BN_padded,
transpose,
WM * WN * SIMD_SIZE,
group_size,
bits>;
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
// Compute the block
const int K_w = K * bytes_per_pack / pack_factor;
const int K_g = K / group_size;
const int N_w = N * bytes_per_pack / pack_factor;
const int N_g = N / group_size;
const int K_it = K / BK;
const size_t stride_w = transpose ? N * K_w : K * N_w;
const size_t stride_s = transpose ? N * K_g : K * N_g;
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
const size_t y_row_long = size_t(y_row);
const size_t y_col_long = size_t(y_col);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));
const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));
// Calculate the final tiles in the case that K is not aligned
const int k_remain = K - K_it * BK;
const short2 tile_x = short2(k_remain, tgp_bm);
const short2 tile_w =
transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
// Move x and output to the correct block
auto wl = (const device uint8_t*)w;
x += y_row_long * K;
y += y_row_long * N + y_col_long;
wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;
scales += transpose ? y_col_long * K_g : y_col / group_size;
biases += transpose ? y_col_long * K_g : y_col / group_size;
// Do as many matmuls as necessary
uint32_t index;
short offset;
uint32_t index_next = indices[y_row];
short offset_next = 0;
int n = 0;
while (n < tgp_bm) {
  n++;
  offset = offset_next;
  index = index_next;
  offset_next = tgp_bm;
  for (; n < tgp_bm; n++) {
    if (indices[y_row + n] != index) {
      offset_next = n;
      index_next = indices[y_row + n];
      break;
    }
  }
  threadgroup_barrier(mem_flags::mem_none);

  // Prepare threadgroup mma operation
  thread mma_t mma_op(simd_group_id, simd_lane_id);

  // Prepare threadgroup loading operations
  thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id);
  thread loader_w_t loader_w(
      wl + index * stride_w,
      scales + index * stride_s,
      biases + index * stride_s,
      transpose ? K : N,
      Ws,
      simd_group_id,
      simd_lane_id);

  // Matrices are all aligned check nothing
  if (align_M && align_N) {
    gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
    if (!align_K) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
    }

    // Store results to device memory
    if (offset_next - offset == BM) {
      mma_op.store_result(y, N);
    } else {
      mma_op.store_result_slice(
          y, N, short2(0, offset), short2(BN, offset_next));
    }
  } else {
    // Tile aligned so check outside of the hot loop
    if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
      gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
      if (!align_K) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        gemm_loop_finalize(
            Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
      }

      // Store results to device memory
      if (offset_next - offset == BM) {
        mma_op.store_result(y, N);
      } else {
        mma_op.store_result_slice(
            y, N, short2(0, offset), short2(BN, offset_next));
      }
    }

    // Tile partially aligned check rows
    else if (align_N || tgp_bn == BN) {
      gemm_loop_unaligned<false, true, transpose>(
          Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
      if (!align_K) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        gemm_loop_finalize(
            Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
      }
      mma_op.store_result_slice(
          y, N, short2(0, offset), short2(BN, offset_next));
    }

    // Tile partially aligned check cols
    else if (align_M || tgp_bm == BM) {
      gemm_loop_unaligned<true, false, transpose>(
          Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
      if (!align_K) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        gemm_loop_finalize(
            Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
      }
      mma_op.store_result_slice(
          y, N, short2(0, offset), short2(tgp_bn, offset_next));
    }

    // Nothing aligned so check both rows and cols
    else {
      gemm_loop_unaligned<false, false, transpose>(
          Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
      if (!align_K) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        gemm_loop_finalize(
            Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
      }
      mma_op.store_result_slice(
          y, N, short2(0, offset), short2(tgp_bn, offset_next));
    }
  }
}
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
const device T* w [[buffer(0)]],

View File

@ -60,6 +60,20 @@
bits, \
split_k)
#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
func, \
type, \
group_size, \
bits, \
bm, \
bn, \
bk, \
wm, \
wn, \
transpose)
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
instantiate_quantized_batched(name, type, group_size, bits, 1) \
instantiate_quantized_batched(name, type, group_size, bits, 0)
@ -73,14 +87,14 @@
#define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_quantize, type, group_size, bits) \
instantiate_quantized(affine_dequantize, type, group_size, bits) \
instantiate_quantized(bs_qmv_fast, type, group_size, bits) \
instantiate_quantized(bs_qmv, type, group_size, bits) \
instantiate_quantized(bs_qvm, type, group_size, bits) \
instantiate_quantized(bs_qmm_n, type, group_size, bits)
instantiate_quantized(gather_qmv_fast, type, group_size, bits) \
instantiate_quantized(gather_qmv, type, group_size, bits) \
instantiate_quantized(gather_qvm, type, group_size, bits) \
instantiate_quantized(gather_qmm_n, type, group_size, bits)
#define instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, true) \
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, false) \
instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, true) \
instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, false) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \
@ -96,12 +110,17 @@
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
#define instantiate_quantized_all_rhs(type, group_size, bits) \
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false)
#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_all_quad(type, group_size, bits) \
instantiate_quantized_all_splitk(type, group_size, bits)
instantiate_quantized_all_splitk(type, group_size, bits) \
instantiate_quantized_all_rhs(type, group_size, bits)
#define instantiate_quantized_types(group_size, bits) \
instantiate_quantized_funcs(float, group_size, bits) \

View File

@ -1908,8 +1908,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
// Extract shapes strides from inputs and copy in case of non-contiguous
// vectors.
// Extract shapes from inputs.
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);

View File

@ -269,4 +269,21 @@ MTL::ComputePipelineState* get_quantized_kernel(
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array&,
int,
int,
int,
int,
int,
int,
int,
bool) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
}
} // namespace mlx::core

File diff suppressed because it is too large.

View File

@ -4028,6 +4028,7 @@ array gather_qmm(
bool transpose /* = true */,
int group_size /* = 64 */,
int bits /* = 4 */,
bool sorted_indices /* = false */,
StreamOrDevice s /* = {} */) {
if (!lhs_indices_ && !rhs_indices_) {
return quantized_matmul(
@ -4067,13 +4068,19 @@ array gather_qmm(
return array(
std::move(out_shape),
out_type,
std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
std::make_shared<GatherQMM>(
to_stream(s),
group_size,
bits,
transpose,
sorted_indices && !rhs_indices_,
sorted_indices && !lhs_indices_),
{astype(x, out_type, s),
w,
std::move(w),
astype(scales, out_type, s),
astype(biases, out_type, s),
lhs_indices,
rhs_indices});
std::move(lhs_indices),
std::move(rhs_indices)});
}
array tensordot(

View File

@ -1352,6 +1352,7 @@ array gather_qmm(
bool transpose = true,
int group_size = 64,
int bits = 4,
bool sorted_indices = false,
StreamOrDevice s = {});
/** Returns a contraction of a and b over multiple dimensions. */

View File

@ -3080,6 +3080,8 @@ std::vector<array> GatherQMM::vjp(
auto& lhs_indices = primals[4];
auto& rhs_indices = primals[5];
bool sorted = left_sorted_ || right_sorted_;
for (auto arg : argnums) {
// gradient wrt to x
if (arg == 0) {
@ -3098,6 +3100,7 @@ std::vector<array> GatherQMM::vjp(
!transpose_,
group_size_,
bits_,
sorted,
stream()),
-3,
stream()),

View File

@ -1591,11 +1591,19 @@ class QuantizedMatmul : public UnaryPrimitive {
class GatherQMM : public UnaryPrimitive {
public:
explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
explicit GatherQMM(
Stream stream,
int group_size,
int bits,
bool transpose,
bool left_sorted = false,
bool right_sorted = false)
: UnaryPrimitive(stream),
group_size_(group_size),
bits_(bits),
transpose_(transpose) {}
transpose_(transpose),
left_sorted_(left_sorted),
right_sorted_(right_sorted) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1605,13 +1613,16 @@ class GatherQMM : public UnaryPrimitive {
DEFINE_PRINT(GatherQMM)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(group_size_, bits_, transpose_);
return std::make_tuple(
group_size_, bits_, transpose_, left_sorted_, right_sorted_);
}
private:
int group_size_;
int bits_;
bool transpose_;
bool left_sorted_;
bool right_sorted_;
};
class RandomBits : public UnaryPrimitive {

View File

@ -4250,9 +4250,10 @@ void init_ops(nb::module_& m) {
"group_size"_a = 64,
"bits"_a = 4,
nb::kw_only(),
"sorted_indices"_a = false,
"stream"_a = nb::none(),
nb::sig(
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform quantized matrix multiplication with matrix-level gather.
@ -4265,23 +4266,25 @@ void init_ops(nb::module_& m) {
as ``w`` since they represent the same quantized matrix.
Args:
x (array): Input array
w (array): Quantized matrix packed in unsigned integers
scales (array): The scales to use per ``group_size`` elements of ``w``
biases (array): The biases to use per ``group_size`` elements of ``w``
lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``.
rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``.
transpose (bool, optional): Defines whether to multiply with the
transposed ``w`` or not, namely whether we are performing
``x @ w.T`` or ``x @ w``. Default: ``True``.
group_size (int, optional): The size of the group in ``w`` that
shares a scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. Default: ``4``.
x (array): Input array
w (array): Quantized matrix packed in unsigned integers
scales (array): The scales to use per ``group_size`` elements of ``w``
biases (array): The biases to use per ``group_size`` elements of ``w``
lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``.
rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``.
transpose (bool, optional): Defines whether to multiply with the
transposed ``w`` or not, namely whether we are performing
``x @ w.T`` or ``x @ w``. Default: ``True``.
group_size (int, optional): The size of the group in ``w`` that
shares a scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. Default: ``4``.
sorted_indices (bool, optional): May allow a faster implementation
if the passed indices are sorted. Default: ``False``.
Returns:
array: The result of the multiplication of ``x`` with ``w``
after gathering using ``lhs_indices`` and ``rhs_indices``.
array: The result of the multiplication of ``x`` with ``w``
after gathering using ``lhs_indices`` and ``rhs_indices``.
)pbdoc");
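Note: the following is a minimal usage sketch of the new ``sorted_indices`` flag; the shapes and variable names are illustrative and not part of the commit, and the sort/unsort pattern mirrors the helpers used in the benchmark and tests elsewhere in this diff. Sorting can pay off because the new gather_qmm_rhs kernel processes runs of rows that share the same expert index within each tile.

import mlx.core as mx

# Route 8 tokens to 2 of 4 quantized experts each (illustrative sizes only).
x = mx.random.normal((8, 1, 1, 64))
w = mx.random.normal((4, 32, 64))
wq, scales, biases = mx.quantize(w)  # default group_size=64, bits=4
indices = (mx.random.uniform(shape=(8, 2)) * 4).astype(mx.uint32)

# Unsorted indices: plain gather_qmm.
y = mx.gather_qmm(x, wq, scales, biases, transpose=True, rhs_indices=indices)

# Sorted path: sort rows by expert, run with sorted_indices=True, then unsort.
order = mx.argsort(indices.flatten())
inv_order = mx.argsort(order)
xs = x.flatten(0, -3)[order // indices.shape[1]]
idx = indices.flatten()[order]
ys = mx.gather_qmm(
    xs, wq, scales, biases, transpose=True, rhs_indices=idx, sorted_indices=True
)
ys = mx.unflatten(ys[inv_order], 0, indices.shape)  # same values as y, up to numerics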
m.def(
"tensordot",
@ -4311,16 +4314,16 @@ void init_ops(nb::module_& m) {
Compute the tensor dot product along the specified axes.
Args:
a (array): Input array
b (array): Input array
axes (int or list(list(int)), optional): The number of dimensions to
sum over. If an integer is provided, then sum over the last
``axes`` dimensions of ``a`` and the first ``axes`` dimensions of
``b``. If a list of lists is provided, then sum over the
corresponding dimensions of ``a`` and ``b``. Default: 2.
a (array): Input array
b (array): Input array
axes (int or list(list(int)), optional): The number of dimensions to
sum over. If an integer is provided, then sum over the last
``axes`` dimensions of ``a`` and the first ``axes`` dimensions of
``b``. If a list of lists is provided, then sum over the
corresponding dimensions of ``a`` and ``b``. Default: 2.
Returns:
array: The tensor dot product.
array: The tensor dot product.
)pbdoc");
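Note: a small illustration of the axes semantics described above; the shapes here are made up for this note and are not part of the commit.

import mlx.core as mx

a = mx.random.normal((3, 4, 5))
b = mx.random.normal((4, 5, 6))

# axes=2: contract the last two axes of a with the first two axes of b -> shape (3, 6)
c = mx.tensordot(a, b, axes=2)

# The same contraction written with explicit axis pairs.
d = mx.tensordot(a, b, axes=[[1, 2], [0, 1]])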
m.def(
"inner",

View File

@ -174,12 +174,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
        tests = product(
            [128, 64, 32],  # group_size
            [2, 3, 4, 6, 8],  # bits
            [128, 256],  # M
            [32, 128, 256],  # M
            [128, 256, 67],  # N
            [0, 1, 3, 8],  # B
        )
        for group_size, bits, M, N, B in tests:
            with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
                if M < group_size:
                    continue
                x_shape = (1, N) if B == 0 else (B, 1, N)
                w_shape = (N, M) if B == 0 else (B, N, M)
                x = mx.random.normal(shape=x_shape, key=k1)
@ -448,6 +450,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
        )
        for kwargs in inputs:
            test_shape(1, 32, 128, **kwargs)
            test_shape(32, 32, 256, **kwargs)
            test_shape(1, 32, 256, **kwargs)
            test_shape(32, 256, 32, transpose=False, **kwargs)
@ -486,6 +489,66 @@ class TestQuantized(mlx_tests.MLXTestCase):
        g2 = mx.grad(f_test)(x, qw, s, b, lhs_indices, rhs_indices)
        self.assertTrue(mx.allclose(g1, g2, atol=1e-4))

    def test_gather_qmm_sorted(self):
        def quantize(w, transpose=True, group_size=64, bits=4):
            qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
            if transpose:
                w_hat = w_hat.swapaxes(-1, -2)
            return w_hat, qw, s, b

        def gather_sort(x, indices):
            N, M = indices.shape
            indices = indices.flatten()
            order = mx.argsort(indices)
            inv_order = mx.argsort(order)
            return x.flatten(0, -3)[order // M], indices[order], inv_order

        def scatter_unsort(x, inv_order, shape=None):
            x = x[inv_order]
            if shape is not None:
                x = mx.unflatten(x, 0, shape)
            return x

        parameters = [
            # L, K, D, E, I, transpose
            (128, 1024, 1024, 32, 4, True),
            (128, 1024, 544, 32, 4, True),
            (433, 1024, 1024, 32, 4, True),
            (433, 1024, 555, 32, 4, True),
            (433, 2048, 1024, 32, 4, True),
            (128, 1024, 1024, 32, 4, False),
            (128, 1024, 544, 32, 4, False),
            (433, 1024, 1024, 32, 4, False),
            (433, 1024, 544, 32, 4, False),
            (433, 1024, 555, 32, 4, False),
            (433, 2048, 1024, 32, 4, False),
        ]
        for L, K, D, E, I, transpose in parameters:
            K, D = (K, D) if transpose else (D, K)
            ishape = (L, I)
            xshape = (L, 1, 1, K)
            wshape = (E, D, K) if transpose else (E, K, D)

            indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
            x = mx.random.normal(xshape) / K**0.5
            w = mx.random.normal(wshape) / K**0.5
            w, *wq = quantize(w, transpose=transpose)

            y1 = mx.gather_mm(x, w, rhs_indices=indices)
            y2 = mx.gather_qmm(x, *wq, transpose=transpose, rhs_indices=indices)

            xs, idx, inv_order = gather_sort(x, indices)
            y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
            y4 = mx.gather_qmm(
                xs, *wq, rhs_indices=idx, transpose=transpose, sorted_indices=True
            )
            y3 = scatter_unsort(y3, inv_order, indices.shape)
            y4 = scatter_unsort(y4, inv_order, indices.shape)

            self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
            self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
            self.assertTrue(mx.allclose(y1, y4, atol=1e-5))


if __name__ == "__main__":
    unittest.main()