mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-26 02:33:21 +08:00)

Gather qmm batched kernel and refactoring of quantized (#2078)

This commit is contained in:
Parent: 99eefd2ec0
Commit: 5de6d94a90
@@ -1,4 +1,4 @@
-# Copyright © 2023-2024 Apple Inc.
+# Copyright © 2025 Apple Inc.
 
 import mlx.core as mx
 from time_utils import time_fn
benchmarks/python/gather_qmm_bench.py (new file, 84 lines)

@@ -0,0 +1,84 @@
# Copyright © 2025 Apple Inc.

import mlx.core as mx

from time_utils import time_fn

N = 1024
D = 1024
M = 1024
E = 32
I = 4


def gather_sort(x, indices):
    N, M = indices.shape
    indices = indices.flatten()
    order = mx.argsort(indices)
    inv_order = mx.argsort(order)
    return x.flatten(0, -3)[order // M], indices[order], inv_order


def scatter_unsort(x, inv_order, shape=None):
    x = x[inv_order]
    if shape is not None:
        x = mx.unflatten(x, 0, shape)
    return x


def gather_mm_simulate(x, w, indices):
    x, idx, inv_order = gather_sort(x, indices)
    for i in range(2):
        y = mx.concatenate(
            [
                mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
                for i, j in enumerate(idx.tolist())
            ],
            axis=0,
        )
        x = y[:, None]
    x = scatter_unsort(x, inv_order, indices.shape)
    return x


def time_gather_qmm():
    x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
    w1 = mx.random.normal((E, M, D)) / 1024**0.5
    w2 = mx.random.normal((E, D, M)) / 1024**0.5
    w1 = mx.quantize(w1)
    w2 = mx.quantize(w2)
    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
    mx.eval(x, w1, w2, indices, sorted_indices)

    def gather_mm(x, w1, w2, indices, sort):
        idx = indices
        inv_order = None
        if sort:
            x, idx, inv_order = gather_sort(x, indices)
        x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
        x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
        if sort:
            x = scatter_unsort(x, inv_order, indices.shape)
        return x

    time_fn(gather_mm, x, w1, w2, indices, False)
    time_fn(gather_mm, x, w1, w2, sorted_indices, False)
    time_fn(gather_mm, x, w1, w2, indices, True)

    x = mx.random.normal((N * I, D)) / 1024**0.5
    w1 = mx.random.normal((M, D)) / 1024**0.5
    w2 = mx.random.normal((D, M)) / 1024**0.5
    w1 = mx.quantize(w1)
    w2 = mx.quantize(w2)
    mx.eval(x, w1, w2)

    def equivalent_matmul(x, w1, w2):
        x = mx.quantized_matmul(x, *w1, transpose=True)
        x = mx.quantized_matmul(x, *w2, transpose=True)
        return x

    time_fn(equivalent_matmul, x, w1, w2)


if __name__ == "__main__":
    time_gather_qmm()
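The benchmark relies on the argsort trick above: sort the token rows by expert index, run the matmuls on contiguous per-expert blocks, then invert the permutation afterwards. Below is a minimal sketch of that round trip; the arrays and values are illustrative and not part of the commit.

# Minimal sketch of the argsort round trip used by gather_sort / scatter_unsort.
# The data here is made up for illustration.
import mlx.core as mx

x = mx.arange(12).reshape(6, 2)                      # 6 rows of features
expert_ids = mx.array([3, 0, 2, 0, 3, 1], dtype=mx.uint32)

order = mx.argsort(expert_ids)                       # rows grouped by expert id
inv_order = mx.argsort(order)                        # inverse permutation

x_sorted = x[order]                                  # contiguous per-expert blocks
x_restored = x_sorted[inv_order]                     # back to the original order
print(mx.array_equal(x, x_restored).item())          # True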
@@ -752,4 +752,43 @@ MTL::ComputePipelineState* get_quantized_kernel(
   return d.get_kernel(kernel_name, lib);
 }
 
+MTL::ComputePipelineState* get_gather_qmm_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
+    const array& x,
+    int group_size,
+    int bits,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn,
+    bool transpose) {
+  const auto& lib_name = kernel_name;
+  auto lib = d.get_library(lib_name, [&]() {
+    std::string kernel_source;
+    concatenate(
+        kernel_source,
+        metal::utils(),
+        metal::gemm(),
+        metal::quantized(),
+        get_template_definition(
+            lib_name,
+            "gather_qmm_rhs",
+            get_type_string(x.dtype()),
+            group_size,
+            bits,
+            bm,
+            bn,
+            bk,
+            wm,
+            wn,
+            transpose));
+    return kernel_source;
+  });
+  return d.get_kernel(kernel_name, lib, hash_name, func_consts);
+}
+
 } // namespace mlx::core
@@ -224,6 +224,21 @@ MTL::ComputePipelineState* get_quantized_kernel(
     const std::string& kernel_name,
     const std::string& template_def);
 
+MTL::ComputePipelineState* get_gather_qmm_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
+    const array& x,
+    int group_size,
+    int bits,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn,
+    bool transpose);
+
 // Create a GPU kernel template definition for JIT compilation
 template <typename... Args>
 std::string
@@ -3,6 +3,10 @@
 #include <metal_simdgroup>
 #include <metal_stdlib>
 
+constant bool align_M [[function_constant(200)]];
+constant bool align_N [[function_constant(201)]];
+constant bool align_K [[function_constant(202)]];
+
 using namespace metal;
 
 #define MLX_MTL_CONST static constant constexpr const
@@ -1686,26 +1690,26 @@ template <
 }
 
 template <typename T, int group_size, int bits>
-[[kernel]] void bs_qmv_fast(
+[[kernel]] void gather_qmv_fast(
     const device uint32_t* w [[buffer(0)]],
     const device T* scales [[buffer(1)]],
     const device T* biases [[buffer(2)]],
     const device T* x [[buffer(3)]],
-    device T* y [[buffer(4)]],
-    const constant int& in_vec_size [[buffer(5)]],
-    const constant int& out_vec_size [[buffer(6)]],
-    const constant int& x_batch_ndims [[buffer(7)]],
-    const constant int* x_shape [[buffer(8)]],
-    const constant int64_t* x_strides [[buffer(9)]],
-    const constant int& w_batch_ndims [[buffer(10)]],
-    const constant int* w_shape [[buffer(11)]],
-    const constant int64_t* w_strides [[buffer(12)]],
-    const constant int64_t* s_strides [[buffer(13)]],
-    const constant int64_t* b_strides [[buffer(14)]],
-    const constant int& batch_ndims [[buffer(15)]],
-    const constant int* batch_shape [[buffer(16)]],
-    const device uint32_t* lhs_indices [[buffer(17)]],
-    const device uint32_t* rhs_indices [[buffer(18)]],
+    const device uint32_t* lhs_indices [[buffer(4)]],
+    const device uint32_t* rhs_indices [[buffer(5)]],
+    device T* y [[buffer(6)]],
+    const constant int& in_vec_size [[buffer(7)]],
+    const constant int& out_vec_size [[buffer(8)]],
+    const constant int& x_batch_ndims [[buffer(9)]],
+    const constant int* x_shape [[buffer(10)]],
+    const constant int64_t* x_strides [[buffer(11)]],
+    const constant int& w_batch_ndims [[buffer(12)]],
+    const constant int* w_shape [[buffer(13)]],
+    const constant int64_t* w_strides [[buffer(14)]],
+    const constant int64_t* s_strides [[buffer(15)]],
+    const constant int64_t* b_strides [[buffer(16)]],
+    const constant int& batch_ndims [[buffer(17)]],
+    const constant int* batch_shape [[buffer(18)]],
     const constant int64_t* lhs_strides [[buffer(19)]],
     const constant int64_t* rhs_strides [[buffer(20)]],
     uint3 tid [[threadgroup_position_in_grid]],
@@ -1748,26 +1752,26 @@ template <typename T, int group_size, int bits>
 }
 
 template <typename T, int group_size, int bits>
-[[kernel]] void bs_qmv(
+[[kernel]] void gather_qmv(
     const device uint32_t* w [[buffer(0)]],
     const device T* scales [[buffer(1)]],
     const device T* biases [[buffer(2)]],
     const device T* x [[buffer(3)]],
-    device T* y [[buffer(4)]],
-    const constant int& in_vec_size [[buffer(5)]],
-    const constant int& out_vec_size [[buffer(6)]],
-    const constant int& x_batch_ndims [[buffer(7)]],
-    const constant int* x_shape [[buffer(8)]],
-    const constant int64_t* x_strides [[buffer(9)]],
-    const constant int& w_batch_ndims [[buffer(10)]],
-    const constant int* w_shape [[buffer(11)]],
-    const constant int64_t* w_strides [[buffer(12)]],
-    const constant int64_t* s_strides [[buffer(13)]],
-    const constant int64_t* b_strides [[buffer(14)]],
-    const constant int& batch_ndims [[buffer(15)]],
-    const constant int* batch_shape [[buffer(16)]],
-    const device uint32_t* lhs_indices [[buffer(17)]],
-    const device uint32_t* rhs_indices [[buffer(18)]],
+    const device uint32_t* lhs_indices [[buffer(4)]],
+    const device uint32_t* rhs_indices [[buffer(5)]],
+    device T* y [[buffer(6)]],
+    const constant int& in_vec_size [[buffer(7)]],
+    const constant int& out_vec_size [[buffer(8)]],
+    const constant int& x_batch_ndims [[buffer(9)]],
+    const constant int* x_shape [[buffer(10)]],
+    const constant int64_t* x_strides [[buffer(11)]],
+    const constant int& w_batch_ndims [[buffer(12)]],
+    const constant int* w_shape [[buffer(13)]],
+    const constant int64_t* w_strides [[buffer(14)]],
+    const constant int64_t* s_strides [[buffer(15)]],
+    const constant int64_t* b_strides [[buffer(16)]],
+    const constant int& batch_ndims [[buffer(17)]],
+    const constant int* batch_shape [[buffer(18)]],
     const constant int64_t* lhs_strides [[buffer(19)]],
     const constant int64_t* rhs_strides [[buffer(20)]],
     uint3 tid [[threadgroup_position_in_grid]],
@@ -1810,26 +1814,26 @@ template <typename T, int group_size, int bits>
 }
 
 template <typename T, int group_size, int bits>
-[[kernel]] void bs_qvm(
+[[kernel]] void gather_qvm(
     const device uint32_t* w [[buffer(0)]],
     const device T* scales [[buffer(1)]],
     const device T* biases [[buffer(2)]],
     const device T* x [[buffer(3)]],
-    device T* y [[buffer(4)]],
-    const constant int& in_vec_size [[buffer(5)]],
-    const constant int& out_vec_size [[buffer(6)]],
-    const constant int& x_batch_ndims [[buffer(7)]],
-    const constant int* x_shape [[buffer(8)]],
-    const constant int64_t* x_strides [[buffer(9)]],
-    const constant int& w_batch_ndims [[buffer(10)]],
-    const constant int* w_shape [[buffer(11)]],
-    const constant int64_t* w_strides [[buffer(12)]],
-    const constant int64_t* s_strides [[buffer(13)]],
-    const constant int64_t* b_strides [[buffer(14)]],
-    const constant int& batch_ndims [[buffer(15)]],
-    const constant int* batch_shape [[buffer(16)]],
-    const device uint32_t* lhs_indices [[buffer(17)]],
-    const device uint32_t* rhs_indices [[buffer(18)]],
+    const device uint32_t* lhs_indices [[buffer(4)]],
+    const device uint32_t* rhs_indices [[buffer(5)]],
+    device T* y [[buffer(6)]],
+    const constant int& in_vec_size [[buffer(7)]],
+    const constant int& out_vec_size [[buffer(8)]],
+    const constant int& x_batch_ndims [[buffer(9)]],
+    const constant int* x_shape [[buffer(10)]],
+    const constant int64_t* x_strides [[buffer(11)]],
+    const constant int& w_batch_ndims [[buffer(12)]],
+    const constant int* w_shape [[buffer(13)]],
+    const constant int64_t* w_strides [[buffer(14)]],
+    const constant int64_t* s_strides [[buffer(15)]],
+    const constant int64_t* b_strides [[buffer(16)]],
+    const constant int& batch_ndims [[buffer(17)]],
+    const constant int* batch_shape [[buffer(18)]],
     const constant int64_t* lhs_strides [[buffer(19)]],
     const constant int64_t* rhs_strides [[buffer(20)]],
     uint3 tid [[threadgroup_position_in_grid]],
@@ -1879,27 +1883,27 @@ template <
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
-[[kernel]] void bs_qmm_t(
+[[kernel]] void gather_qmm_t(
     const device uint32_t* w [[buffer(0)]],
     const device T* scales [[buffer(1)]],
     const device T* biases [[buffer(2)]],
     const device T* x [[buffer(3)]],
-    device T* y [[buffer(4)]],
-    const constant int& K [[buffer(5)]],
-    const constant int& N [[buffer(6)]],
-    const constant int& M [[buffer(7)]],
-    const constant int& x_batch_ndims [[buffer(8)]],
-    const constant int* x_shape [[buffer(9)]],
-    const constant int64_t* x_strides [[buffer(10)]],
-    const constant int& w_batch_ndims [[buffer(11)]],
-    const constant int* w_shape [[buffer(12)]],
-    const constant int64_t* w_strides [[buffer(13)]],
-    const constant int64_t* s_strides [[buffer(14)]],
-    const constant int64_t* b_strides [[buffer(15)]],
-    const constant int& batch_ndims [[buffer(16)]],
-    const constant int* batch_shape [[buffer(17)]],
-    const device uint32_t* lhs_indices [[buffer(18)]],
-    const device uint32_t* rhs_indices [[buffer(19)]],
+    const device uint32_t* lhs_indices [[buffer(4)]],
+    const device uint32_t* rhs_indices [[buffer(5)]],
+    device T* y [[buffer(6)]],
+    const constant int& K [[buffer(7)]],
+    const constant int& N [[buffer(8)]],
+    const constant int& M [[buffer(9)]],
+    const constant int& x_batch_ndims [[buffer(10)]],
+    const constant int* x_shape [[buffer(11)]],
+    const constant int64_t* x_strides [[buffer(12)]],
+    const constant int& w_batch_ndims [[buffer(13)]],
+    const constant int* w_shape [[buffer(14)]],
+    const constant int64_t* w_strides [[buffer(15)]],
+    const constant int64_t* s_strides [[buffer(16)]],
+    const constant int64_t* b_strides [[buffer(17)]],
+    const constant int& batch_ndims [[buffer(18)]],
+    const constant int* batch_shape [[buffer(19)]],
     const constant int64_t* lhs_strides [[buffer(20)]],
     const constant int64_t* rhs_strides [[buffer(21)]],
     uint3 tid [[threadgroup_position_in_grid]],
@@ -1946,27 +1950,27 @@ template <
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
-[[kernel]] void bs_qmm_n(
+[[kernel]] void gather_qmm_n(
     const device uint32_t* w [[buffer(0)]],
     const device T* scales [[buffer(1)]],
     const device T* biases [[buffer(2)]],
     const device T* x [[buffer(3)]],
-    device T* y [[buffer(4)]],
-    const constant int& K [[buffer(5)]],
-    const constant int& N [[buffer(6)]],
-    const constant int& M [[buffer(7)]],
-    const constant int& x_batch_ndims [[buffer(8)]],
-    const constant int* x_shape [[buffer(9)]],
-    const constant int64_t* x_strides [[buffer(10)]],
-    const constant int& w_batch_ndims [[buffer(11)]],
-    const constant int* w_shape [[buffer(12)]],
-    const constant int64_t* w_strides [[buffer(13)]],
-    const constant int64_t* s_strides [[buffer(14)]],
-    const constant int64_t* b_strides [[buffer(15)]],
-    const constant int& batch_ndims [[buffer(16)]],
-    const constant int* batch_shape [[buffer(17)]],
-    const device uint32_t* lhs_indices [[buffer(18)]],
-    const device uint32_t* rhs_indices [[buffer(19)]],
+    const device uint32_t* lhs_indices [[buffer(4)]],
+    const device uint32_t* rhs_indices [[buffer(5)]],
+    device T* y [[buffer(6)]],
+    const constant int& K [[buffer(7)]],
+    const constant int& N [[buffer(8)]],
+    const constant int& M [[buffer(9)]],
+    const constant int& x_batch_ndims [[buffer(10)]],
+    const constant int* x_shape [[buffer(11)]],
+    const constant int64_t* x_strides [[buffer(12)]],
+    const constant int& w_batch_ndims [[buffer(13)]],
+    const constant int* w_shape [[buffer(14)]],
+    const constant int64_t* w_strides [[buffer(15)]],
+    const constant int64_t* s_strides [[buffer(16)]],
+    const constant int64_t* b_strides [[buffer(17)]],
+    const constant int& batch_ndims [[buffer(18)]],
+    const constant int* batch_shape [[buffer(19)]],
     const constant int64_t* lhs_strides [[buffer(20)]],
     const constant int64_t* rhs_strides [[buffer(21)]],
     uint3 tid [[threadgroup_position_in_grid]],
@@ -2007,6 +2011,289 @@ template <
       w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
 }
 
+template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
+METAL_FUNC void gemm_loop_aligned(
+    threadgroup T* As,
+    threadgroup T* Bs,
+    thread mma_t& mma_op,
+    thread loader_a_t& loader_a,
+    thread loader_b_t& loader_b,
+    const int k_iterations) {
+  for (int k = 0; k < k_iterations; k++) {
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // Load elements into threadgroup memory
+    loader_a.load_unsafe();
+    loader_b.load_unsafe();
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // Multiply and accumulate threadgroup elements
+    mma_op.mma(As, Bs);
+
+    // Prepare for next iteration
+    loader_a.next();
+    loader_b.next();
+  }
+}
+
+template <
+    bool rows_aligned,
+    bool cols_aligned,
+    bool transpose,
+    typename T,
+    typename mma_t,
+    typename loader_a_t,
+    typename loader_b_t>
+METAL_FUNC void gemm_loop_unaligned(
+    threadgroup T* As,
+    threadgroup T* Bs,
+    thread mma_t& mma_op,
+    thread loader_a_t& loader_a,
+    thread loader_b_t& loader_b,
+    const int k_iterations,
+    const short tgp_bm,
+    const short tgp_bn,
+    const short tgp_bk) {
+  for (int k = 0; k < k_iterations; k++) {
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // Load elements into threadgroup memory
+    if (rows_aligned) {
+      loader_a.load_unsafe();
+    } else {
+      loader_a.load_safe(short2(tgp_bk, tgp_bm));
+    }
+    if (cols_aligned) {
+      loader_b.load_unsafe();
+    } else {
+      loader_b.load_safe(
+          transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // Multiply and accumulate threadgroup elements
+    mma_op.mma(As, Bs);
+
+    // Prepare for next iteration
+    loader_a.next();
+    loader_b.next();
+  }
+}
+
+template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
+METAL_FUNC void gemm_loop_finalize(
+    threadgroup T* As,
+    threadgroup T* Bs,
+    thread mma_t& mma_op,
+    thread loader_a_t& loader_a,
+    thread loader_b_t& loader_b,
+    const short2 tile_a,
+    const short2 tile_b) {
+  loader_a.load_safe(tile_a);
+  loader_b.load_safe(tile_b);
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  mma_op.mma(As, Bs);
+}
+
+template <
+    typename T,
+    int group_size,
+    int bits,
+    int BM,
+    int BN,
+    int BK,
+    int WM,
+    int WN,
+    bool transpose>
+[[kernel]] void gather_qmm_rhs(
+    const device T* x [[buffer(0)]],
+    const device uint32_t* w [[buffer(1)]],
+    const device T* scales [[buffer(2)]],
+    const device T* biases [[buffer(3)]],
+    const device uint32_t* indices [[buffer(4)]],
+    device T* y [[buffer(5)]],
+    const constant int& M [[buffer(6)]],
+    const constant int& N [[buffer(7)]],
+    const constant int& K [[buffer(8)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_group_id [[simdgroup_index_in_threadgroup]],
+    uint simd_lane_id [[thread_index_in_simdgroup]]) {
+  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
+  constexpr int BK_padded = (BK + 16 / sizeof(T));
+  constexpr int BN_padded = (BN + 16 / sizeof(T));
+  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
+  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
+
+  using mma_t = mlx::steel::BlockMMA<
+      T,
+      T,
+      BM,
+      BN,
+      BK,
+      WM,
+      WN,
+      false,
+      transpose,
+      BK_padded,
+      transpose ? BK_padded : BN_padded>;
+  using loader_x_t =
+      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
+  using loader_w_t = QuantizedBlockLoader<
+      T,
+      transpose ? BN : BK,
+      transpose ? BK : BN,
+      transpose ? BK_padded : BN_padded,
+      transpose,
+      WM * WN * SIMD_SIZE,
+      group_size,
+      bits>;
+
+  threadgroup T Xs[BM * BK_padded];
+  threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
+
+  // Compute the block
+  const int K_w = K * bytes_per_pack / pack_factor;
+  const int K_g = K / group_size;
+  const int N_w = N * bytes_per_pack / pack_factor;
+  const int N_g = N / group_size;
+  const int K_it = K / BK;
+  const size_t stride_w = transpose ? N * K_w : K * N_w;
+  const size_t stride_s = transpose ? N * K_g : K * N_g;
+  const int y_row = tid.y * BM;
+  const int y_col = tid.x * BN;
+  const size_t y_row_long = size_t(y_row);
+  const size_t y_col_long = size_t(y_col);
+
+  // Prepare threadgroup bounds
+  const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));
+  const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));
+
+  // Calculate the final tiles in the case that K is not aligned
+  const int k_remain = K - K_it * BK;
+  const short2 tile_x = short2(k_remain, tgp_bm);
+  const short2 tile_w =
+      transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
+
+  // Move x and output to the correct block
+  auto wl = (const device uint8_t*)w;
+  x += y_row_long * K;
+  y += y_row_long * N + y_col_long;
+  wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;
+  scales += transpose ? y_col_long * K_g : y_col / group_size;
+  biases += transpose ? y_col_long * K_g : y_col / group_size;
+
+  // Do as many matmuls as necessary
+  uint32_t index;
+  short offset;
+  uint32_t index_next = indices[y_row];
+  short offset_next = 0;
+  int n = 0;
+  while (n < tgp_bm) {
+    n++;
+    offset = offset_next;
+    index = index_next;
+    offset_next = tgp_bm;
+    for (; n < tgp_bm; n++) {
+      if (indices[y_row + n] != index) {
+        offset_next = n;
+        index_next = indices[y_row + n];
+        break;
+      }
+    }
+    threadgroup_barrier(mem_flags::mem_none);
+
+    // Prepare threadgroup mma operation
+    thread mma_t mma_op(simd_group_id, simd_lane_id);
+
+    // Prepare threadgroup loading operations
+    thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id);
+    thread loader_w_t loader_w(
+        wl + index * stride_w,
+        scales + index * stride_s,
+        biases + index * stride_s,
+        transpose ? K : N,
+        Ws,
+        simd_group_id,
+        simd_lane_id);
+
+    // Matrices are all aligned check nothing
+    if (align_M && align_N) {
+      gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
+      if (!align_K) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
+      }
+
+      // Store results to device memory
+      if (offset_next - offset == BM) {
+        mma_op.store_result(y, N);
+      } else {
+        mma_op.store_result_slice(
+            y, N, short2(0, offset), short2(BN, offset_next));
+      }
+    } else {
+      // Tile aligned so check outside of the hot loop
+      if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
+        gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
+        if (!align_K) {
+          threadgroup_barrier(mem_flags::mem_threadgroup);
+          gemm_loop_finalize(
+              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
+        }
+
+        // Store results to device memory
+        if (offset_next - offset == BM) {
+          mma_op.store_result(y, N);
+        } else {
+          mma_op.store_result_slice(
+              y, N, short2(0, offset), short2(BN, offset_next));
+        }
+      }
+
+      // Tile partially aligned check rows
+      else if (align_N || tgp_bn == BN) {
+        gemm_loop_unaligned<false, true, transpose>(
+            Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
+        if (!align_K) {
+          threadgroup_barrier(mem_flags::mem_threadgroup);
+          gemm_loop_finalize(
+              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
+        }
+        mma_op.store_result_slice(
+            y, N, short2(0, offset), short2(BN, offset_next));
+      }
+
+      // Tile partially aligned check cols
+      else if (align_M || tgp_bm == BM) {
+        gemm_loop_unaligned<true, false, transpose>(
+            Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
+        if (!align_K) {
+          threadgroup_barrier(mem_flags::mem_threadgroup);
+          gemm_loop_finalize(
+              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
+        }
+        mma_op.store_result_slice(
+            y, N, short2(0, offset), short2(tgp_bn, offset_next));
+      }
+
+      // Nothing aligned so check both rows and cols
+      else {
+        gemm_loop_unaligned<false, false, transpose>(
+            Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
+        if (!align_K) {
+          threadgroup_barrier(mem_flags::mem_threadgroup);
+          gemm_loop_finalize(
+              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
+        }
+        mma_op.store_result_slice(
+            y, N, short2(0, offset), short2(tgp_bn, offset_next));
+      }
+    }
+  }
+}
+
 template <typename T, const int group_size, const int bits>
 [[kernel]] void affine_quantize(
     const device T* w [[buffer(0)]],
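The new gather_qmm_rhs kernel above assumes the row indices are sorted: each BM-row tile is scanned for contiguous runs of equal expert indices and one tile matmul is issued per run. Below is a rough Python model of that control flow, with a dense matmul standing in for the quantized tile pipeline; the function and variable names are illustrative and not part of the commit.

# Rough array-level model of the per-tile loop in gather_qmm_rhs:
# one matmul per contiguous run of identical (sorted) expert indices.
import mlx.core as mx

def gather_mm_rhs_reference(x, w, sorted_indices):
    # x: (M, K) rows already sorted by expert id, w: (E, K, N) dense stand-in
    chunks = []
    start = 0
    M = x.shape[0]
    while start < M:
        expert = sorted_indices[start].item()
        end = start
        while end < M and sorted_indices[end].item() == expert:
            end += 1
        chunks.append(x[start:end] @ w[expert])  # one matmul per run
        start = end
    return mx.concatenate(chunks, axis=0)

x = mx.random.normal((8, 16))
w = mx.random.normal((4, 16, 32))
idx = mx.sort((mx.random.uniform(shape=(8,)) * 4).astype(mx.uint32))
print(gather_mm_rhs_reference(x, w, idx).shape)  # (8, 32)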
@@ -60,6 +60,20 @@
     bits, \
     split_k)
 
+#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \
+  instantiate_kernel( \
+      #name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
+      func, \
+      type, \
+      group_size, \
+      bits, \
+      bm, \
+      bn, \
+      bk, \
+      wm, \
+      wn, \
+      transpose)
+
 #define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
   instantiate_quantized_batched(name, type, group_size, bits, 1) \
   instantiate_quantized_batched(name, type, group_size, bits, 0)

@@ -73,14 +87,14 @@
 #define instantiate_quantized_all_single(type, group_size, bits) \
   instantiate_quantized(affine_quantize, type, group_size, bits) \
   instantiate_quantized(affine_dequantize, type, group_size, bits) \
-  instantiate_quantized(bs_qmv_fast, type, group_size, bits) \
-  instantiate_quantized(bs_qmv, type, group_size, bits) \
-  instantiate_quantized(bs_qvm, type, group_size, bits) \
-  instantiate_quantized(bs_qmm_n, type, group_size, bits)
+  instantiate_quantized(gather_qmv_fast, type, group_size, bits) \
+  instantiate_quantized(gather_qmv, type, group_size, bits) \
+  instantiate_quantized(gather_qvm, type, group_size, bits) \
+  instantiate_quantized(gather_qmm_n, type, group_size, bits)
 
 #define instantiate_quantized_all_aligned(type, group_size, bits) \
-  instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, true) \
-  instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, false) \
+  instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, true) \
+  instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, false) \
   instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
   instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
   instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \

@@ -96,12 +110,17 @@
   instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
   instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
 
+#define instantiate_quantized_all_rhs(type, group_size, bits) \
+  instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
+  instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false)
+
 #define instantiate_quantized_funcs(type, group_size, bits) \
   instantiate_quantized_all_single(type, group_size, bits) \
   instantiate_quantized_all_batched(type, group_size, bits) \
   instantiate_quantized_all_aligned(type, group_size, bits) \
   instantiate_quantized_all_quad(type, group_size, bits) \
-  instantiate_quantized_all_splitk(type, group_size, bits)
+  instantiate_quantized_all_splitk(type, group_size, bits) \
+  instantiate_quantized_all_rhs(type, group_size, bits)
 
 #define instantiate_quantized_types(group_size, bits) \
   instantiate_quantized_funcs(float, group_size, bits) \
@@ -1908,8 +1908,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   out.set_data(allocator::malloc(out.nbytes()));
 
-  // Extract shapes strides from inputs and copy in case of non-contiguous
-  // vectors.
+  // Extract shapes from inputs.
   int M = a.shape(-2);
   int N = b.shape(-1);
   int K = a.shape(-1);
@@ -269,4 +269,21 @@ MTL::ComputePipelineState* get_quantized_kernel(
   return d.get_kernel(kernel_name);
 }
 
+MTL::ComputePipelineState* get_gather_qmm_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
+    const array&,
+    int,
+    int,
+    int,
+    int,
+    int,
+    int,
+    int,
+    bool) {
+  return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
+}
+
 } // namespace mlx::core
(One file's diff was suppressed because it is too large.)

mlx/ops.cpp (15 changed lines)
@@ -4028,6 +4028,7 @@ array gather_qmm(
     bool transpose /* = true */,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    bool sorted_indices /* = false */,
     StreamOrDevice s /* = {} */) {
   if (!lhs_indices_ && !rhs_indices_) {
     return quantized_matmul(
@@ -4067,13 +4068,19 @@ array gather_qmm(
   return array(
       std::move(out_shape),
       out_type,
-      std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
+      std::make_shared<GatherQMM>(
+          to_stream(s),
+          group_size,
+          bits,
+          transpose,
+          sorted_indices && !rhs_indices_,
+          sorted_indices && !lhs_indices_),
       {astype(x, out_type, s),
-       w,
+       std::move(w),
        astype(scales, out_type, s),
        astype(biases, out_type, s),
-       lhs_indices,
-       rhs_indices});
+       std::move(lhs_indices),
+       std::move(rhs_indices)});
 }
 
 array tensordot(
@@ -1352,6 +1352,7 @@ array gather_qmm(
     bool transpose = true,
     int group_size = 64,
     int bits = 4,
+    bool sorted_indices = false,
     StreamOrDevice s = {});
 
 /** Returns a contraction of a and b over multiple dimensions. */
@@ -3080,6 +3080,8 @@ std::vector<array> GatherQMM::vjp(
   auto& lhs_indices = primals[4];
   auto& rhs_indices = primals[5];
 
+  bool sorted = left_sorted_ || right_sorted_;
+
   for (auto arg : argnums) {
     // gradient wrt to x
     if (arg == 0) {

@@ -3098,6 +3100,7 @@ std::vector<array> GatherQMM::vjp(
             !transpose_,
             group_size_,
             bits_,
+            sorted,
             stream()),
         -3,
         stream()),
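With the sorted flags stored on the primitive, the vjp above can rebuild gather_qmm with the same sorted hint, so gradients keep using the fast path. A small sketch in the spirit of the existing gradient test follows; the shapes and the loss are illustrative, not taken from the commit.

# Sketch: gradients flow through mx.gather_qmm whether or not sorted_indices is
# set. Shapes, the loss, and the index construction are illustrative.
import mlx.core as mx

w = mx.random.normal((8, 256, 512))
qw, s, b = mx.quantize(w)
x = mx.random.normal((4, 1, 1, 512))
rhs_indices = mx.sort((mx.random.uniform(shape=(4,)) * 8).astype(mx.uint32)).reshape(4, 1)

def loss(x):
    y = mx.gather_qmm(
        x, qw, s, b,
        rhs_indices=rhs_indices,
        transpose=True,
        sorted_indices=True,  # valid only because the indices above are sorted
    )
    return y.sum()

grad = mx.grad(loss)(x)
print(grad.shape)  # same shape as x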
@@ -1591,11 +1591,19 @@ class QuantizedMatmul : public UnaryPrimitive {
 
 class GatherQMM : public UnaryPrimitive {
  public:
-  explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
+  explicit GatherQMM(
+      Stream stream,
+      int group_size,
+      int bits,
+      bool transpose,
+      bool left_sorted = false,
+      bool right_sorted = false)
       : UnaryPrimitive(stream),
         group_size_(group_size),
         bits_(bits),
-        transpose_(transpose) {}
+        transpose_(transpose),
+        left_sorted_(left_sorted),
+        right_sorted_(right_sorted) {}
 
   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;

@@ -1605,13 +1613,16 @@ class GatherQMM : public UnaryPrimitive {
   DEFINE_PRINT(GatherQMM)
   bool is_equivalent(const Primitive& other) const override;
   auto state() const {
-    return std::make_tuple(group_size_, bits_, transpose_);
+    return std::make_tuple(
+        group_size_, bits_, transpose_, left_sorted_, right_sorted_);
   }
 
  private:
   int group_size_;
   int bits_;
   bool transpose_;
+  bool left_sorted_;
+  bool right_sorted_;
 };
 
 class RandomBits : public UnaryPrimitive {
@@ -4250,9 +4250,10 @@ void init_ops(nb::module_& m) {
       "group_size"_a = 64,
       "bits"_a = 4,
       nb::kw_only(),
+      "sorted_indices"_a = false,
       "stream"_a = nb::none(),
       nb::sig(
-          "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform quantized matrix multiplication with matrix-level gather.
 
@@ -4265,23 +4266,25 @@ void init_ops(nb::module_& m) {
         as ``w`` since they represent the same quantized matrix.
 
         Args:
          x (array): Input array
          w (array): Quantized matrix packed in unsigned integers
          scales (array): The scales to use per ``group_size`` elements of ``w``
          biases (array): The biases to use per ``group_size`` elements of ``w``
          lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``.
          rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``.
          transpose (bool, optional): Defines whether to multiply with the
            transposed ``w`` or not, namely whether we are performing
            ``x @ w.T`` or ``x @ w``. Default: ``True``.
          group_size (int, optional): The size of the group in ``w`` that
            shares a scale and bias. Default: ``64``.
          bits (int, optional): The number of bits occupied by each element in
            ``w``. Default: ``4``.
+         sorted_indices (bool, optional): May allow a faster implementation
+           if the passed indices are sorted. Default: ``False``.
 
         Returns:
          array: The result of the multiplication of ``x`` with ``w``
          after gathering using ``lhs_indices`` and ``rhs_indices``.
       )pbdoc");
   m.def(
       "tensordot",
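A short usage sketch of the signature documented above, including the new sorted_indices keyword. The expert count, group size, and shapes are illustrative; sorted_indices=True is only valid here because the indices really are sorted.

# Illustrative use of mx.gather_qmm with the new sorted_indices keyword.
import mlx.core as mx

w = mx.random.normal((32, 1024, 1024))             # (experts, N, K)
qw, scales, biases = mx.quantize(w)                # default group_size=64, bits=4

x = mx.random.normal((128, 1, 1, 1024))            # one row per token
indices = mx.sort((mx.random.uniform(shape=(128,)) * 32).astype(mx.uint32))
indices = indices.reshape(128, 1)

y = mx.gather_qmm(
    x, qw, scales, biases,
    rhs_indices=indices,
    transpose=True,
    sorted_indices=True,                           # indices above are sorted
)
print(y.shape)                                     # (128, 1, 1, 1024)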
@@ -4311,16 +4314,16 @@ void init_ops(nb::module_& m) {
         Compute the tensor dot product along the specified axes.
 
         Args:
          a (array): Input array
          b (array): Input array
          axes (int or list(list(int)), optional): The number of dimensions to
            sum over. If an integer is provided, then sum over the last
            ``axes`` dimensions of ``a`` and the first ``axes`` dimensions of
            ``b``. If a list of lists is provided, then sum over the
            corresponding dimensions of ``a`` and ``b``. Default: 2.
 
         Returns:
          array: The tensor dot product.
       )pbdoc");
   m.def(
       "inner",
@@ -174,12 +174,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
         tests = product(
             [128, 64, 32],  # group_size
             [2, 3, 4, 6, 8],  # bits
-            [128, 256],  # M
+            [32, 128, 256],  # M
             [128, 256, 67],  # N
             [0, 1, 3, 8],  # B
         )
         for group_size, bits, M, N, B in tests:
             with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
+                if M < group_size:
+                    continue
                 x_shape = (1, N) if B == 0 else (B, 1, N)
                 w_shape = (N, M) if B == 0 else (B, N, M)
                 x = mx.random.normal(shape=x_shape, key=k1)
@@ -448,6 +450,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
         )
 
         for kwargs in inputs:
+            test_shape(1, 32, 128, **kwargs)
             test_shape(32, 32, 256, **kwargs)
             test_shape(1, 32, 256, **kwargs)
             test_shape(32, 256, 32, transpose=False, **kwargs)
@@ -486,6 +489,66 @@ class TestQuantized(mlx_tests.MLXTestCase):
             g2 = mx.grad(f_test)(x, qw, s, b, lhs_indices, rhs_indices)
             self.assertTrue(mx.allclose(g1, g2, atol=1e-4))
 
+    def test_gather_qmm_sorted(self):
+        def quantize(w, transpose=True, group_size=64, bits=4):
+            qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
+            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
+            if transpose:
+                w_hat = w_hat.swapaxes(-1, -2)
+            return w_hat, qw, s, b
+
+        def gather_sort(x, indices):
+            N, M = indices.shape
+            indices = indices.flatten()
+            order = mx.argsort(indices)
+            inv_order = mx.argsort(order)
+            return x.flatten(0, -3)[order // M], indices[order], inv_order
+
+        def scatter_unsort(x, inv_order, shape=None):
+            x = x[inv_order]
+            if shape is not None:
+                x = mx.unflatten(x, 0, shape)
+            return x
+
+        parameters = [
+            # L, K, D, E, I, transpose
+            (128, 1024, 1024, 32, 4, True),
+            (128, 1024, 544, 32, 4, True),
+            (433, 1024, 1024, 32, 4, True),
+            (433, 1024, 555, 32, 4, True),
+            (433, 2048, 1024, 32, 4, True),
+            (128, 1024, 1024, 32, 4, False),
+            (128, 1024, 544, 32, 4, False),
+            (433, 1024, 1024, 32, 4, False),
+            (433, 1024, 544, 32, 4, False),
+            (433, 1024, 555, 32, 4, False),
+            (433, 2048, 1024, 32, 4, False),
+        ]
+        for L, K, D, E, I, transpose in parameters:
+            K, D = (K, D) if transpose else (D, K)
+            ishape = (L, I)
+            xshape = (L, 1, 1, K)
+            wshape = (E, D, K) if transpose else (E, K, D)
+
+            indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
+            x = mx.random.normal(xshape) / K**0.5
+            w = mx.random.normal(wshape) / K**0.5
+            w, *wq = quantize(w, transpose=transpose)
+
+            y1 = mx.gather_mm(x, w, rhs_indices=indices)
+            y2 = mx.gather_qmm(x, *wq, transpose=transpose, rhs_indices=indices)
+            xs, idx, inv_order = gather_sort(x, indices)
+            y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
+            y4 = mx.gather_qmm(
+                xs, *wq, rhs_indices=idx, transpose=transpose, sorted_indices=True
+            )
+            y3 = scatter_unsort(y3, inv_order, indices.shape)
+            y4 = scatter_unsort(y4, inv_order, indices.shape)
+
+            self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
+            self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
+            self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
+
 
 if __name__ == "__main__":
     unittest.main()