mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-18 00:16:39 +08:00
Add packed_affine_qmm_t
This commit is contained in:
parent
410ccdbed5
commit
fb7be036af
@ -1,19 +1,19 @@
|
||||
import argparse
|
||||
import math
|
||||
from functools import partial
|
||||
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
|
||||
B = 1024
|
||||
D = 1024
|
||||
M = 4 * D
|
||||
group_size = 64
|
||||
bits = 4
|
||||
dtype = mx.float16
|
||||
loops = 100
|
||||
loops = 10
|
||||
|
||||
|
||||
def qmv_(x, wq1, wq2, q_type):
|
||||
def qmm_(x, wq1, wq2, q_type):
|
||||
for i in range(loops):
|
||||
x = mx.quantized_matmul(
|
||||
x,
|
||||
@ -32,28 +32,28 @@ def qmv_(x, wq1, wq2, q_type):
|
||||
return x
|
||||
|
||||
|
||||
def affine_qmv(x, wq1, wq2):
|
||||
return qmv_(x, wq1, wq2, "affine")
|
||||
def affine_qmm(x, wq1, wq2):
|
||||
return qmm_(x, wq1, wq2, "affine")
|
||||
|
||||
|
||||
def affine_packed_qmv(x, wq1, wq2):
|
||||
return qmv_(x, wq1, wq2, "affine-packed")
|
||||
def affine_packed_qmm(x, wq1, wq2):
|
||||
return qmm_(x, wq1, wq2, "affine-packed")
|
||||
|
||||
|
||||
def time_qmv():
|
||||
def time_qmm():
|
||||
mx.random.seed(3)
|
||||
x = mx.random.normal(shape=(1, D)).astype(dtype)
|
||||
x = mx.random.normal(shape=(B, D)).astype(dtype)
|
||||
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
|
||||
wq1 = mx.quantize(w1, group_size=group_size, bits=bits, quantization_type="affine")
|
||||
w2 = mx.random.normal(shape=(D, M)).astype(dtype)
|
||||
wq2 = mx.quantize(w2, group_size=group_size, bits=bits, quantization_type="affine")
|
||||
mx.eval(x, wq1, wq2)
|
||||
time_fn(affine_qmv, x, wq1, wq2)
|
||||
time_fn(affine_qmm, x, wq1, wq2)
|
||||
|
||||
|
||||
def time_packed_qmv():
|
||||
def time_packed_qmm():
|
||||
mx.random.seed(3)
|
||||
x = mx.random.normal(shape=(1, D)).astype(dtype)
|
||||
x = mx.random.normal(shape=(B, D)).astype(dtype)
|
||||
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
|
||||
wq1 = mx.quantize(
|
||||
w1, group_size=group_size, bits=bits, quantization_type="affine-packed"
|
||||
@ -63,12 +63,12 @@ def time_packed_qmv():
|
||||
w2, group_size=group_size, bits=bits, quantization_type="affine-packed"
|
||||
)
|
||||
mx.eval(x, wq1, wq2)
|
||||
time_fn(affine_packed_qmv, x, wq1, wq2)
|
||||
time_fn(affine_packed_qmm, x, wq1, wq2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
for b in [2, 4, 8]:
|
||||
bits = b
|
||||
print(f"Bits {bits}:")
|
||||
time_qmv()
|
||||
time_packed_qmv()
|
||||
time_qmm()
|
||||
time_packed_qmm()
|
@ -1248,6 +1248,41 @@ METAL_FUNC void adjust_matrix_offsets(
|
||||
y += tid.z * output_stride;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void adjust_matrix_offsets(
|
||||
const device T*& x,
|
||||
const device uint32_t*& w,
|
||||
const device T*& scales,
|
||||
device T*& y,
|
||||
int output_stride,
|
||||
const constant int& x_batch_ndims,
|
||||
const constant int* x_shape,
|
||||
const constant int64_t* x_strides,
|
||||
const constant int& w_batch_ndims,
|
||||
const constant int* w_shape,
|
||||
const constant int64_t* w_strides,
|
||||
const constant int64_t* s_strides,
|
||||
uint3 tid [[threadgroup_position_in_grid]]) {
|
||||
// Set the input/output matrices
|
||||
uint32_t x_idx = tid.z;
|
||||
uint32_t w_idx = tid.z;
|
||||
if (x_batch_ndims == 1) {
|
||||
x += x_idx * x_strides[0];
|
||||
} else {
|
||||
x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
|
||||
}
|
||||
if (w_batch_ndims == 1) {
|
||||
w += w_idx * w_strides[0];
|
||||
scales += w_idx * s_strides[0];
|
||||
} else {
|
||||
ulong2 idx = elem_to_loc_broadcast(
|
||||
w_idx, w_shape, w_strides, s_strides, w_batch_ndims);
|
||||
w += idx.x;
|
||||
scales += idx.y;
|
||||
}
|
||||
y += tid.z * output_stride;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void adjust_matrix_offsets(
|
||||
const device T*& x,
|
||||
@ -2266,11 +2301,322 @@ template <typename T, int group_size, int bits>
|
||||
const device vec<T, 4>* scales [[buffer(1)]],
|
||||
const device T* x [[buffer(2)]],
|
||||
device T* y [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(5)]],
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
affine_packed_qmv_fast_impl<T, group_size, bits>(
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
short BROWS,
|
||||
short BCOLS,
|
||||
short dst_ld,
|
||||
short reduction_dim,
|
||||
short tgp_size,
|
||||
short group_size,
|
||||
short bits>
|
||||
struct AffinePackedQuantizedBlockLoader {
|
||||
static_assert(
|
||||
BCOLS <= group_size,
|
||||
"The group size should be larger than the columns");
|
||||
static_assert(
|
||||
group_size % BCOLS == 0,
|
||||
"The group size should be divisible by the columns");
|
||||
static_assert(
|
||||
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
|
||||
"Template undefined for bits not in {2, 3, 4, 6, 8}");
|
||||
|
||||
MLX_MTL_CONST short pack_factor = 32 / bits;
|
||||
MLX_MTL_CONST short row_pack_factor = 4;
|
||||
MLX_MTL_CONST short BCOLS_PACKED = BCOLS * row_pack_factor / pack_factor;
|
||||
MLX_MTL_CONST short BROWS_PACKED = BROWS / row_pack_factor;
|
||||
MLX_MTL_CONST short TOTAL_INTS = BCOLS_PACKED * BROWS_PACKED;
|
||||
MLX_MTL_CONST short n_reads =
|
||||
(TOTAL_INTS < tgp_size) ? 1 : TOTAL_INTS / tgp_size;
|
||||
MLX_MTL_CONST short group_steps = group_size / BCOLS;
|
||||
|
||||
static_assert(
|
||||
n_reads <= row_pack_factor,
|
||||
"The loader only supports per thread reads <= row_pack_factor");
|
||||
|
||||
const int src_ld;
|
||||
const int tile_stride;
|
||||
short group_step_cnt;
|
||||
const int group_stride;
|
||||
|
||||
const short thread_idx;
|
||||
const short bi;
|
||||
const short bj;
|
||||
const short bii;
|
||||
const short bjj;
|
||||
|
||||
const device uint32_t* src;
|
||||
const device T* scales;
|
||||
const device T* biases;
|
||||
threadgroup T* dst;
|
||||
|
||||
AffinePackedQuantizedBlockLoader(
|
||||
const device uint32_t* src_,
|
||||
const device T* scales_,
|
||||
const int src_ld_,
|
||||
threadgroup T* dst_,
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(src_ld_),
|
||||
tile_stride(reduction_dim ? BCOLS_PACKED : BROWS_PACKED * src_ld),
|
||||
group_step_cnt(0),
|
||||
group_stride(BROWS_PACKED * 2 * src_ld / group_size),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(n_reads * thread_idx / BCOLS_PACKED),
|
||||
bj((n_reads * thread_idx) % BCOLS_PACKED),
|
||||
bii(bi * row_pack_factor + bj % row_pack_factor),
|
||||
bjj(bj / row_pack_factor),
|
||||
src(src_ + bi * src_ld * row_pack_factor / pack_factor + bj),
|
||||
scales(
|
||||
scales_ + bi * 2 * src_ld * row_pack_factor / group_size +
|
||||
bj % row_pack_factor),
|
||||
biases(scales + row_pack_factor),
|
||||
dst(dst_ + bii * dst_ld + bjj * pack_factor) {}
|
||||
|
||||
void load_unsafe() const {
|
||||
if (bits == 2 && BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
T scale = scales[i];
|
||||
T bias = biases[i];
|
||||
dequantize<T, pack_factor, bits>(
|
||||
(const device uint8_t*)(src + i), scale, bias, dst + i * dst_ld);
|
||||
}
|
||||
}
|
||||
|
||||
void load_safe(short2 src_tile_dim) const {
|
||||
if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (reduction_dim == 1 && bii >= src_tile_dim.y) {
|
||||
for (int i = 0; i < n_reads * pack_factor; i++) {
|
||||
dst[i] = T(0);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (reduction_dim == 0 && bii >= src_tile_dim.x) {
|
||||
for (int i = 0; i < n_reads * pack_factor; i++) {
|
||||
dst[i] = T(0);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
T scale = scales[i];
|
||||
T bias = biases[i];
|
||||
dequantize<T, pack_factor, bits>(
|
||||
(const device uint8_t*)(src + i), scale, bias, dst + i * dst_ld);
|
||||
}
|
||||
}
|
||||
|
||||
void next() {
|
||||
src += tile_stride;
|
||||
if (reduction_dim == 1) {
|
||||
if (group_steps > 1) {
|
||||
group_step_cnt++;
|
||||
if (group_step_cnt == group_steps) {
|
||||
group_step_cnt = 0;
|
||||
scales += 8;
|
||||
biases += 8;
|
||||
}
|
||||
} else {
|
||||
scales += 8;
|
||||
biases += 8;
|
||||
}
|
||||
} else {
|
||||
scales += group_stride;
|
||||
biases += group_stride;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
const int bits,
|
||||
const bool aligned_N,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
METAL_FUNC void affine_packed_qmm_t_impl(
|
||||
const device uint32_t* w,
|
||||
const device T* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
threadgroup T* Xs,
|
||||
threadgroup T* Ws,
|
||||
const constant int& K,
|
||||
const constant int& N,
|
||||
const constant int& M,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_index_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
|
||||
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
|
||||
|
||||
(void)lid;
|
||||
|
||||
constexpr int WM = 2;
|
||||
constexpr int WN = 2;
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
constexpr int row_pack_factor = 4;
|
||||
constexpr int BK_padded = (BK + 16 / sizeof(T));
|
||||
|
||||
// Instantiate the appropriate BlockMMA and Loader
|
||||
using mma_t = mlx::steel::
|
||||
BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
|
||||
using loader_x_t =
|
||||
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
|
||||
using loader_w_t = AffinePackedQuantizedBlockLoader<
|
||||
T,
|
||||
BN,
|
||||
BK,
|
||||
BK_padded,
|
||||
1,
|
||||
WM * WN * SIMD_SIZE,
|
||||
group_size,
|
||||
bits>;
|
||||
|
||||
// Set the block
|
||||
const int K_w = K * row_pack_factor / pack_factor;
|
||||
const int K_g = K * 2 * row_pack_factor / group_size;
|
||||
const int y_row = tid.y * BM;
|
||||
const int y_col = tid.x * BN;
|
||||
const int packed_y_col = tid.x * (BN / row_pack_factor);
|
||||
|
||||
x += y_row * K;
|
||||
w += packed_y_col * K_w;
|
||||
scales += packed_y_col * K_g;
|
||||
y += y_row * N + y_col;
|
||||
|
||||
// Make the x loader and mma operation
|
||||
const short num_els = min(BM, M - y_row);
|
||||
const short num_outs = min(BN, N - y_col);
|
||||
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
|
||||
loader_w_t loader_w(w, scales, K, Ws, simd_gid, simd_lid);
|
||||
mma_t mma_op(simd_gid, simd_lid);
|
||||
|
||||
if (num_els < BM) {
|
||||
if (!aligned_N && num_outs < BN) {
|
||||
for (int k = 0; k < K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_x.load_safe(short2(BK, num_els));
|
||||
loader_w.load_safe(short2(BK, num_outs));
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
mma_op.mma(Xs, Ws);
|
||||
loader_x.next();
|
||||
loader_w.next();
|
||||
}
|
||||
} else {
|
||||
for (int k = 0; k < K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_x.load_safe(short2(BK, num_els));
|
||||
loader_w.load_unsafe();
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
mma_op.mma(Xs, Ws);
|
||||
loader_x.next();
|
||||
loader_w.next();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (!aligned_N && num_outs < BN) {
|
||||
for (int k = 0; k < K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_x.load_unsafe();
|
||||
loader_w.load_safe(short2(BK, num_outs));
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
mma_op.mma(Xs, Ws);
|
||||
loader_x.next();
|
||||
loader_w.next();
|
||||
}
|
||||
} else {
|
||||
for (int k = 0; k < K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_x.load_unsafe();
|
||||
loader_w.load_unsafe();
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(Xs, Ws);
|
||||
loader_x.next();
|
||||
loader_w.next();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (num_els < BM || num_outs < BN) {
|
||||
mma_op.store_result_safe(y, N, short2(num_outs, num_els));
|
||||
} else {
|
||||
mma_op.store_result(y, N);
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
const int bits,
|
||||
const bool aligned_N,
|
||||
const bool batched,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
[[kernel]] void affine_packed_qmm_t(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
const device T* x [[buffer(2)]],
|
||||
device T* y [[buffer(3)]],
|
||||
const constant int& K [[buffer(4)]],
|
||||
const constant int& N [[buffer(5)]],
|
||||
const constant int& M [[buffer(6)]],
|
||||
const constant int& x_batch_ndims [[buffer(7)]],
|
||||
const constant int* x_shape [[buffer(8)]],
|
||||
const constant int64_t* x_strides [[buffer(9)]],
|
||||
const constant int& w_batch_ndims [[buffer(10)]],
|
||||
const constant int* w_shape [[buffer(11)]],
|
||||
const constant int64_t* w_strides [[buffer(12)]],
|
||||
const constant int64_t* s_strides [[buffer(13)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_index_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
(void)lid;
|
||||
|
||||
constexpr int BK_padded = (BK + 16 / sizeof(T));
|
||||
|
||||
threadgroup T Xs[BM * BK_padded];
|
||||
threadgroup T Ws[BN * BK_padded];
|
||||
|
||||
if (batched) {
|
||||
adjust_matrix_offsets<T>(
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
y,
|
||||
M * N,
|
||||
x_batch_ndims,
|
||||
x_shape,
|
||||
x_strides,
|
||||
w_batch_ndims,
|
||||
w_shape,
|
||||
w_strides,
|
||||
s_strides,
|
||||
tid);
|
||||
}
|
||||
affine_packed_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
|
||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
|
||||
}
|
||||
|
@ -104,8 +104,12 @@
|
||||
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
|
||||
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
|
||||
|
||||
#define instantiate_quantized_all_affine_packed(type, group_size, bits) \
|
||||
instantiate_quantized_affine_packed(affine_packed_qmv_fast, type, group_size, bits)
|
||||
#define instantiate_quantized_all_affine_packed(type, group_size, bits) \
|
||||
instantiate_quantized_affine_packed(affine_packed_qmv_fast, type, group_size, bits) \
|
||||
instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, true, true) \
|
||||
instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, true, false) \
|
||||
instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, false, true) \
|
||||
instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, false, false)
|
||||
|
||||
#define instantiate_quantized_funcs(type, group_size, bits) \
|
||||
instantiate_quantized_all_single(type, group_size, bits) \
|
||||
|
@ -428,8 +428,91 @@ void affine_packed_qmv(
|
||||
compute_encoder.set_input_array(scales, 1);
|
||||
compute_encoder.set_input_array(x, 2);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
compute_encoder.set_bytes(D, 5);
|
||||
compute_encoder.set_bytes(O, 6);
|
||||
compute_encoder.set_bytes(D, 4);
|
||||
compute_encoder.set_bytes(O, 5);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void affine_packed_qmm_t(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
bool batched,
|
||||
int B,
|
||||
int D,
|
||||
int O,
|
||||
int group_size,
|
||||
int bits,
|
||||
const Stream& s) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
auto ensure_row_contiguous_last_dims = [&d, &s](const array& arr) {
|
||||
auto stride_0 = arr.strides()[arr.ndim() - 2];
|
||||
auto stride_1 = arr.strides()[arr.ndim() - 1];
|
||||
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
|
||||
return arr;
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||
d.add_temporary(arr_copy, s.index);
|
||||
return arr_copy;
|
||||
}
|
||||
};
|
||||
// TODO: Deal with this in routing towards qmm_n instead of qmm_t
|
||||
auto x = ensure_row_contiguous_last_dims(inputs[0]);
|
||||
auto w = ensure_row_contiguous_last_dims(inputs[1]);
|
||||
auto scales = ensure_row_contiguous_last_dims(inputs[2]);
|
||||
|
||||
int x_batch_ndims = x.ndim() - 2;
|
||||
auto& x_shape = x.shape();
|
||||
auto& x_strides = x.strides();
|
||||
int w_batch_ndims = w.ndim() - 2;
|
||||
auto& w_shape = w.shape();
|
||||
auto& w_strides = w.strides();
|
||||
auto& s_strides = scales.strides();
|
||||
|
||||
const int wn = 2;
|
||||
const int wm = 2;
|
||||
const int bm = 32;
|
||||
const int bn = 32;
|
||||
const int N = (batched) ? out.size() / B / O : 1;
|
||||
MTL::Size group_dims(32, wn, wm);
|
||||
MTL::Size grid_dims((O + bn - 1) / bn, (B + bm - 1) / bm, N);
|
||||
|
||||
std::string name;
|
||||
name.reserve(64);
|
||||
concatenate(
|
||||
name,
|
||||
"affine_packed_qmm_t_",
|
||||
get_type_string(out.dtype()),
|
||||
"_gs_",
|
||||
std::to_string(group_size),
|
||||
"_b_",
|
||||
std::to_string(bits),
|
||||
"_alN_",
|
||||
((O % 32) == 0) ? "true" : "false",
|
||||
"_batch_",
|
||||
(batched) ? "true" : "false");
|
||||
auto kernel = get_quantized_kernel(d, name, "");
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(w, 0);
|
||||
compute_encoder.set_input_array(scales, 1);
|
||||
compute_encoder.set_input_array(x, 2);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
compute_encoder.set_bytes(D, 4);
|
||||
compute_encoder.set_bytes(O, 5);
|
||||
compute_encoder.set_bytes(B, 6);
|
||||
if (batched) {
|
||||
compute_encoder.set_bytes(x_batch_ndims, 7);
|
||||
compute_encoder.set_vector_bytes(x_shape, 8);
|
||||
compute_encoder.set_vector_bytes(x_strides, 9);
|
||||
compute_encoder.set_bytes(w_batch_ndims, 10);
|
||||
compute_encoder.set_vector_bytes(w_shape, 11);
|
||||
compute_encoder.set_vector_bytes(w_strides, 12);
|
||||
compute_encoder.set_vector_bytes(s_strides, 13);
|
||||
}
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
@ -451,6 +534,7 @@ void affine_packed_qmm_op(
|
||||
if (B < 6) {
|
||||
affine_packed_qmv(inputs, out, B, D, O, group_size, bits, s);
|
||||
} else {
|
||||
affine_packed_qmm_t(inputs, out, batched, B, D, O, group_size, bits, s);
|
||||
}
|
||||
} else {
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user