Add mode parameter for quantization (#2499)

* add mode parameter for quantization

* mxfp4 quantize/dequantize + start of optional biases

* mxfp4 works

* speedup

* cpu mxfp4

* fix

* fix test tol

* fix

* refactor

* add quant mode enum
Author: Awni Hannun
Date: 2025-08-28 06:45:26 -07:00 (committed by GitHub)
Parent: 7ef8a6f2d5
Commit: 70560b6bd5
28 changed files with 3635 additions and 757 deletions
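
The "add quant mode enum" bullet surfaces in the host code below as quantization_mode_to_string(mode_), which turns the enum into the string prefix ("affine" or "mxfp4") used to pick kernels. A minimal sketch of that mapping; the enum and value spellings here are assumptions, and MLX's actual names may differ:

// Hedged sketch: the enum and value names are assumptions; only the strings
// "affine" and "mxfp4" are taken from the kernel names in this diff.
#include <string>

enum class QuantizationMode { Affine, Mxfp4 };

inline std::string quantization_mode_to_string(QuantizationMode mode) {
  return mode == QuantizationMode::Affine ? "affine" : "mxfp4";
}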


@@ -108,7 +108,8 @@ if(NOT MLX_METAL_JIT)
reduction/reduce_all.h
reduction/reduce_col.h
reduction/reduce_row.h)
build_kernel(quantized quantized.h ${STEEL_HEADERS})
build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS})
build_kernel(fp4_quantized fp4_quantized.h quantized_utils.h ${STEEL_HEADERS})
build_kernel(scan scan.h)
build_kernel(softmax softmax.h)
build_kernel(logsumexp logsumexp.h)

File diff suppressed because it is too large.
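
The suppressed diff hides the largest part of the change, so here is a rough CPU-side sketch of what mxfp4 dequantization means, assuming the OCP MX convention (FP4 E2M1 elements packed two per byte, one shared power-of-two E8M0 scale per group of 32). This is an illustration only; the packing and layout in the actual Metal kernels may differ.

#include <cmath>
#include <cstdint>
#include <vector>

// Hedged sketch, not MLX's implementation. FP4 (E2M1) can represent
// {0, 0.5, 1, 1.5, 2, 3, 4, 6} and their negatives, so a 16-entry lookup
// table is the simplest decode. Each group of 32 elements shares one E8M0
// scale: a biased power-of-two exponent stored in a uint8_t.
static const float kFp4Lut[16] = {
    0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

std::vector<float> mxfp4_dequantize(
    const std::vector<uint8_t>& packed, // two 4-bit values per byte (order assumed)
    const std::vector<uint8_t>& scales, // one E8M0 scale per 32 values
    size_t n) {
  std::vector<float> out(n);
  for (size_t i = 0; i < n; ++i) {
    uint8_t byte = packed[i / 2];
    uint8_t nibble = (i % 2 == 0) ? (byte & 0x0f) : (byte >> 4);
    float scale = std::ldexp(1.0f, int(scales[i / 32]) - 127); // 2^(e - 127)
    out[i] = kFp4Lut[nibble] * scale;
  }
  return out;
}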


@@ -0,0 +1,127 @@
// Copyright © 2025 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized_utils.h"
#include "mlx/backend/metal/kernels/fp4_quantized.h"
#define instantiate_quantized(name, type) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4", \
name, \
type, \
32, \
uint8_t)
#define instantiate_quantized_batched(name, type, batched) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_batch_" #batched, \
name, \
type, \
32, \
batched, \
uint8_t)
#define instantiate_quantized_aligned(name, type, aligned) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_alN_" #aligned, \
name, \
type, \
32, \
aligned, \
uint8_t)
#define instantiate_quantized_aligned_batched(name, type, aligned, batched) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \
name, \
type, \
32, \
aligned, \
batched, \
uint8_t)
#define instantiate_quantized_quad(name, type, D, batched) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \
name, \
type, \
32, \
D, \
batched, \
uint8_t)
#define instantiate_quantized_split_k(name, type, split_k) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_spk_" #split_k, \
name, \
type, \
32, \
split_k, \
uint8_t)
#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
func, \
type, \
32, \
uint8_t, \
bm, \
bn, \
bk, \
wm, \
wn, \
transpose)
#define instantiate_quantized_batched_wrap(name, type) \
instantiate_quantized_batched(name, type, 1) \
instantiate_quantized_batched(name, type, 0)
#define instantiate_quantized_all_batched(type) \
instantiate_quantized_batched_wrap(mxfp4_qmv_fast, type) \
instantiate_quantized_batched_wrap(mxfp4_qmv, type) \
instantiate_quantized_batched_wrap(mxfp4_qvm, type) \
instantiate_quantized_batched_wrap(mxfp4_qmm_n, type)
#define instantiate_quantized_all_single(type) \
instantiate_quantized(mxfp4_gather_qmv_fast, type) \
instantiate_quantized(mxfp4_gather_qmv, type) \
instantiate_quantized(mxfp4_gather_qvm, type) \
instantiate_quantized(mxfp4_gather_qmm_n, type)
#define instantiate_quantized_all_aligned(type) \
instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, true) \
instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, false) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 1) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 0) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 1) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 0)
#define instantiate_quantized_all_quad(type) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 1) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 0) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 1) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 0)
#define instantiate_quantized_all_splitk(type) \
instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 8) \
instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 32)
#define instantiate_quantized_all_rhs(type) \
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)
#define instantiate_quantized_types(type) \
instantiate_quantized_all_batched(type) \
instantiate_quantized_all_quad(type) \
instantiate_quantized_all_splitk(type) \
instantiate_quantized_all_single(type) \
instantiate_quantized_all_aligned(type) \
instantiate_quantized_all_rhs(type)
instantiate_quantized_types(float)
instantiate_quantized_types(bfloat16_t)
instantiate_quantized_types(float16_t)
// clang-format on
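
Every mxfp4 instantiation above hard-codes group size 32 and 4 bits, so the generated kernel names share the _gs_32_b_4 suffix that the host code later reconstructs when looking kernels up. A small illustrative helper (not MLX code) that mirrors the batched-kernel naming pattern:

#include <iostream>
#include <string>

// Illustrative only: mirrors the names produced by the instantiation macros
// above, e.g. "mxfp4_qmv_fast_float16_t_gs_32_b_4_batch_1". Other variants
// (aligned, quad, split-k, rhs) append different suffixes.
std::string mxfp4_kernel_name(
    const std::string& op, const std::string& type, bool batched) {
  std::string name = "mxfp4_" + op + "_" + type + "_gs_32_b_4";
  name += batched ? "_batch_1" : "_batch_0";
  return name;
}

int main() {
  std::cout << mxfp4_kernel_name("qmv_fast", "float16_t", true) << "\n";
}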


@@ -1434,7 +1434,7 @@ METAL_FUNC void adjust_matrix_offsets(
}
template <typename T, int group_size, int bits, int D, bool batched>
[[kernel]] void qmv_quad(
[[kernel]] void affine_qmv_quad(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -1486,7 +1486,7 @@ template <typename T, int group_size, int bits, int D, bool batched>
}
template <typename T, int group_size, int bits, bool batched>
[[kernel]] void qmv_fast(
[[kernel]] void affine_qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -1538,7 +1538,7 @@ template <typename T, int group_size, int bits, bool batched>
}
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qmv(
[[kernel]] void affine_qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -1590,7 +1590,7 @@ template <typename T, const int group_size, const int bits, bool batched>
}
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qvm(
[[kernel]] void affine_qvm(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -1642,7 +1642,7 @@ template <typename T, const int group_size, const int bits, bool batched>
}
template <typename T, const int group_size, const int bits, int split_k = 32>
[[kernel]] void qvm_split_k(
[[kernel]] void affine_qvm_split_k(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -1706,7 +1706,7 @@ template <
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void qmm_t(
[[kernel]] void affine_qmm_t(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -1764,7 +1764,7 @@ template <
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void qmm_n(
[[kernel]] void affine_qmm_n(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -1817,7 +1817,7 @@ template <
}
template <typename T, int group_size, int bits>
[[kernel]] void gather_qmv_fast(
[[kernel]] void affine_gather_qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -1879,7 +1879,7 @@ template <typename T, int group_size, int bits>
}
template <typename T, int group_size, int bits>
[[kernel]] void gather_qmv(
[[kernel]] void affine_gather_qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -1941,7 +1941,7 @@ template <typename T, int group_size, int bits>
}
template <typename T, int group_size, int bits>
[[kernel]] void gather_qvm(
[[kernel]] void affine_gather_qvm(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -2010,7 +2010,7 @@ template <
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void gather_qmm_t(
[[kernel]] void affine_gather_qmm_t(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -2077,7 +2077,7 @@ template <
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void gather_qmm_n(
[[kernel]] void affine_gather_qmm_n(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -2138,92 +2138,6 @@ template <
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_aligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <
bool rows_aligned,
bool cols_aligned,
bool transpose,
typename T,
typename mma_t,
typename loader_a_t,
typename loader_b_t>
METAL_FUNC void gemm_loop_unaligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations,
const short tgp_bm,
const short tgp_bn,
const short tgp_bk) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
if (rows_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(short2(tgp_bk, tgp_bm));
}
if (cols_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(
transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_finalize(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const short2 tile_a,
const short2 tile_b) {
loader_a.load_safe(tile_a);
loader_b.load_safe(tile_b);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
template <
typename T,
int group_size,
@@ -2234,7 +2148,7 @@ template <
int WM,
int WN,
bool transpose>
[[kernel]] void gather_qmm_rhs(
[[kernel]] void affine_gather_qmm_rhs(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],


@@ -3,6 +3,7 @@
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized_utils.h"
#include "mlx/backend/metal/kernels/quantized.h"
#define instantiate_quantized(name, type, group_size, bits) \
@@ -79,40 +80,40 @@
instantiate_quantized_batched(name, type, group_size, bits, 0)
#define instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_batched_wrap(qmv_fast, type, group_size, bits) \
instantiate_quantized_batched_wrap(qmv, type, group_size, bits) \
instantiate_quantized_batched_wrap(qvm, type, group_size, bits) \
instantiate_quantized_batched_wrap(qmm_n, type, group_size, bits)
instantiate_quantized_batched_wrap(affine_qmv_fast, type, group_size, bits) \
instantiate_quantized_batched_wrap(affine_qmv, type, group_size, bits) \
instantiate_quantized_batched_wrap(affine_qvm, type, group_size, bits) \
instantiate_quantized_batched_wrap(affine_qmm_n, type, group_size, bits)
#define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_quantize, type, group_size, bits) \
instantiate_quantized(affine_dequantize, type, group_size, bits) \
instantiate_quantized(gather_qmv_fast, type, group_size, bits) \
instantiate_quantized(gather_qmv, type, group_size, bits) \
instantiate_quantized(gather_qvm, type, group_size, bits) \
instantiate_quantized(gather_qmm_n, type, group_size, bits)
instantiate_quantized(affine_gather_qmv_fast, type, group_size, bits) \
instantiate_quantized(affine_gather_qmv, type, group_size, bits) \
instantiate_quantized(affine_gather_qvm, type, group_size, bits) \
instantiate_quantized(affine_gather_qmm_n, type, group_size, bits)
#define instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, true) \
instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, false) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 0)
instantiate_quantized_aligned(affine_gather_qmm_t, type, group_size, bits, true) \
instantiate_quantized_aligned(affine_gather_qmm_t, type, group_size, bits, false) \
instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, true, 1) \
instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, true, 0) \
instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, false, 1) \
instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, false, 0)
#define instantiate_quantized_all_quad(type, group_size, bits) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 1) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 0) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0)
instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 64, 1) \
instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 64, 0) \
instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 128, 1) \
instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 128, 0)
#define instantiate_quantized_all_splitk(type, group_size, bits) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 32)
#define instantiate_quantized_all_rhs(type, group_size, bits) \
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false)
instantiate_gather_qmm_rhs(affine_gather_qmm_rhs, affine_gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(affine_gather_qmm_rhs, affine_gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false)
#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_single(type, group_size, bits) \


@@ -0,0 +1,90 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_simdgroup>
#include <metal_stdlib>
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_aligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <
bool rows_aligned,
bool cols_aligned,
bool transpose,
typename T,
typename mma_t,
typename loader_a_t,
typename loader_b_t>
METAL_FUNC void gemm_loop_unaligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations,
const short tgp_bm,
const short tgp_bn,
const short tgp_bk) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
if (rows_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(short2(tgp_bk, tgp_bm));
}
if (cols_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(
transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_finalize(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const short2 tile_a,
const short2 tile_b) {
loader_a.load_safe(tile_a);
loader_b.load_safe(tile_b);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
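
These helpers split the quantized GEMM K-loop into a fast path over full tiles (load_unsafe), a bounds-checked path for partially filled tiles (load_safe), and a finalize step for the K remainder. As a loose CPU analogy of that control flow only, with none of the Metal threadgroup mechanics, a sketch:

#include <vector>

// CPU analogy of the aligned / finalize split above: run full K-tiles with
// no bounds checks, then handle the leftover columns safely.
// Assumes a.size() == b.size().
float tiled_dot(const std::vector<float>& a, const std::vector<float>& b, int BK) {
  int K = static_cast<int>(a.size());
  int k_iterations = K / BK; // full tiles, analogous to gemm_loop_aligned
  float acc = 0.0f;
  for (int t = 0; t < k_iterations; ++t) {
    for (int k = 0; k < BK; ++k) { // "load_unsafe": no edge handling needed
      acc += a[t * BK + k] * b[t * BK + k];
    }
  }
  int rem = K - k_iterations * BK; // analogous to gemm_loop_finalize
  for (int k = 0; k < rem; ++k) { // "load_safe": bounded by the remainder
    acc += a[k_iterations * BK + k] * b[k_iterations * BK + k];
  }
  return acc;
}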


@@ -1,7 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/gpu/copy.h"
@@ -99,7 +97,7 @@ inline int add_strides_and_shapes(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
int offset) {
if (skip) {
return 0;
@@ -109,16 +107,18 @@ inline int add_strides_and_shapes(
int x_batch_ndims = x.ndim() - 2;
int w_batch_ndims = w.ndim() - 2;
compute_encoder.set_bytes(x_batch_ndims, offset);
compute_encoder.set_vector_bytes(x.shape(), offset + 1);
compute_encoder.set_vector_bytes(x.strides(), offset + 2);
compute_encoder.set_bytes(w_batch_ndims, offset + 3);
compute_encoder.set_vector_bytes(w.shape(), offset + 4);
compute_encoder.set_vector_bytes(w.strides(), offset + 5);
compute_encoder.set_vector_bytes(scales.strides(), offset + 6);
compute_encoder.set_vector_bytes(biases.strides(), offset + 7);
compute_encoder.set_bytes(x_batch_ndims, offset++);
compute_encoder.set_vector_bytes(x.shape(), offset++);
compute_encoder.set_vector_bytes(x.strides(), offset++);
compute_encoder.set_bytes(w_batch_ndims, offset++);
compute_encoder.set_vector_bytes(w.shape(), offset++);
compute_encoder.set_vector_bytes(w.strides(), offset++);
compute_encoder.set_vector_bytes(scales.strides(), offset++);
if (biases) {
compute_encoder.set_vector_bytes(biases->strides(), offset++);
}
return 8;
return offset;
}
inline int add_gather_strides_and_shapes(
@@ -130,12 +130,12 @@ inline int add_gather_strides_and_shapes(
lhs_indices.shape(), {lhs_indices.strides(), rhs_indices.strides()});
int ndims = shape.size();
compute_encoder.set_bytes(ndims, offset);
compute_encoder.set_vector_bytes(shape, offset + 1);
compute_encoder.set_vector_bytes(strides[0], offset + 2);
compute_encoder.set_vector_bytes(strides[1], offset + 3);
compute_encoder.set_bytes(ndims, offset++);
compute_encoder.set_vector_bytes(shape, offset++);
compute_encoder.set_vector_bytes(strides[0], offset++);
compute_encoder.set_vector_bytes(strides[1], offset++);
return 4;
return offset;
}
} // namespace
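
add_strides_and_shapes and add_gather_strides_and_shapes now take a running buffer index and return the next free slot, which is what lets the biases buffer be optional: when biases are absent, everything bound after them shifts down by one index, and the kernel variant for that mode presumably declares its buffers to match. A stripped-down sketch of the pattern with stand-in types (Encoder and Buffer are placeholders, not MLX's command encoder or array):

#include <map>
#include <optional>

// Stand-ins for the sketch only; not MLX types.
struct Buffer { int id; };
struct Encoder {
  std::map<int, int> bound; // buffer slot -> buffer id
  void set_input(const Buffer& b, int slot) { bound[slot] = b.id; }
};

// Bind w and scales, then biases only if present, and return the next free
// slot so the caller can keep appending arguments without hard-coded indices.
int bind_quant_args(Encoder& enc, const Buffer& w, const Buffer& scales,
                    const std::optional<Buffer>& biases, int c = 0) {
  enc.set_input(w, c++);
  enc.set_input(scales, c++);
  if (biases) {
    enc.set_input(*biases, c++);
  }
  return c;
}
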
@@ -144,7 +144,7 @@ void qmv_quad(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
array& out,
int group_size,
int bits,
@@ -152,7 +152,8 @@ void qmv_quad(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;
constexpr int quads_per_simd = 8;
@@ -165,9 +166,10 @@ void qmv_quad(
std::string kname;
kname.reserve(64);
std::string type_string = get_type_string(x.dtype());
concatenate(
kname,
"qmv_quad_",
mode + "_qmv_quad_",
type_string,
"_gs_",
group_size,
@@ -177,20 +179,23 @@ void qmv_quad(
K,
B > 1 ? "_batch_1" : "_batch_0");
auto template_def = get_template_definition(
kname, "qmv_quad", type_string, group_size, bits, K, B > 1);
kname, mode + "_qmv_quad", type_string, group_size, bits, K, B > 1);
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder.set_bytes(K, 5);
compute_encoder.set_bytes(N, 6);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c++);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@@ -199,7 +204,7 @@ void qmv(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
array& out,
int group_size,
int bits,
@@ -207,7 +212,8 @@ void qmv(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;
int bn = 8;
@@ -219,9 +225,10 @@ void qmv(
kname.reserve(64);
std::string type_string = get_type_string(x.dtype());
bool fast = N % bn == 0 && K % 512 == 0;
concatenate(
kname,
fast ? "qmv_fast_" : "qmv_",
mode + (fast ? "_qmv_fast_" : "_qmv_"),
type_string,
"_gs_",
group_size,
@@ -229,20 +236,28 @@ void qmv(
bits,
B > 1 ? "_batch_1" : "_batch_0");
auto template_def = get_template_definition(
kname, fast ? "qmv_fast" : "qmv", type_string, group_size, bits, B > 1);
kname,
mode + (fast ? "_qmv_fast" : "_qmv"),
type_string,
group_size,
bits,
B > 1);
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder.set_bytes(K, 5);
compute_encoder.set_bytes(N, 6);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@@ -251,7 +266,7 @@ void qvm_split_k(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
array& out,
int group_size,
int bits,
@@ -259,7 +274,8 @@ void qvm_split_k(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int split_k = K > 8192 ? 32 : 8;
int split_D = (K + split_k - 1) / split_k;
int B = out.size() / M / N;
@@ -283,7 +299,6 @@ void qvm_split_k(
auto w_shape = w.shape();
auto w_strides = w.strides();
auto s_strides = scales.strides();
auto b_strides = biases.strides();
// Add split_k dim with reshapes
x_shape.insert(x_shape.end() - 2, split_k);
@@ -297,7 +312,6 @@ void qvm_split_k(
w_strides.insert(w_strides.end() - 2, split_D * w.shape(-1));
w_batch_ndims += 1;
s_strides.insert(s_strides.end() - 2, split_D * scales.shape(-1));
b_strides.insert(b_strides.end() - 2, split_D * biases.shape(-1));
int final_block_size = K - (split_k - 1) * split_D;
@@ -315,7 +329,7 @@ void qvm_split_k(
kname.reserve(64);
concatenate(
kname,
"qvm_split_k_",
mode + "_qvm_split_k_",
type_string,
"_gs_",
group_size,
@@ -324,30 +338,37 @@ void qvm_split_k(
"_spk_",
split_k);
auto template_def = get_template_definition(
kname, "qvm_split_k", type_string, group_size, bits, split_k);
kname, mode + "_qvm_split_k", type_string, group_size, bits, split_k);
// Encode and dispatch kernel
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(intermediate, 4);
compute_encoder.set_bytes(split_D, 5);
compute_encoder.set_bytes(N, 6);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_output_array(intermediate, c++);
compute_encoder.set_bytes(split_D, c++);
compute_encoder.set_bytes(N, c++);
compute_encoder.set_bytes(x_batch_ndims, 7);
compute_encoder.set_vector_bytes(x_shape, 8);
compute_encoder.set_vector_bytes(x_strides, 9);
compute_encoder.set_bytes(w_batch_ndims, 10);
compute_encoder.set_vector_bytes(w_shape, 11);
compute_encoder.set_vector_bytes(w_strides, 12);
compute_encoder.set_vector_bytes(s_strides, 13);
compute_encoder.set_vector_bytes(b_strides, 14);
compute_encoder.set_bytes(final_block_size, 15);
compute_encoder.set_bytes(x_batch_ndims, c++);
compute_encoder.set_vector_bytes(x_shape, c++);
compute_encoder.set_vector_bytes(x_strides, c++);
compute_encoder.set_bytes(w_batch_ndims, c++);
compute_encoder.set_vector_bytes(w_shape, c++);
compute_encoder.set_vector_bytes(w_strides, c++);
compute_encoder.set_vector_bytes(s_strides, c++);
if (biases) {
auto b_strides = biases->strides();
b_strides.insert(b_strides.end() - 2, split_D * biases->shape(-1));
compute_encoder.set_vector_bytes(b_strides, c++);
}
compute_encoder.set_bytes(final_block_size, c++);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@@ -364,7 +385,7 @@ void qvm(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
array& out,
int group_size,
int bits,
@@ -372,7 +393,8 @@ void qvm(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;
int bn = 64;
@@ -385,7 +407,7 @@ void qvm(
std::string type_string = get_type_string(x.dtype());
concatenate(
kname,
"qvm_",
mode + "_qvm_",
type_string,
"_gs_",
group_size,
@@ -393,20 +415,23 @@ void qvm(
bits,
B > 1 ? "_batch_1" : "_batch_0");
auto template_def = get_template_definition(
kname, "qvm", type_string, group_size, bits, B > 1);
kname, mode + "_qvm", type_string, group_size, bits, B > 1);
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder.set_bytes(K, 5);
compute_encoder.set_bytes(N, 6);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c++);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@@ -415,7 +440,7 @@ void qmm(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
array& out,
bool transpose,
int group_size,
@@ -424,7 +449,8 @@ void qmm(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;
int wm = 2;
@@ -441,7 +467,7 @@ void qmm(
std::string type_string = get_type_string(x.dtype());
concatenate(
kname,
transpose ? "qmm_t_" : "qmm_n_",
mode + (transpose ? "_qmm_t_" : "_qmm_n_"),
type_string,
"_gs_",
group_size,
@@ -452,25 +478,34 @@ void qmm(
std::string template_def;
if (transpose) {
template_def = get_template_definition(
kname, "qmm_t", type_string, group_size, bits, aligned, batched);
kname,
mode + "_qmm_t",
type_string,
group_size,
bits,
aligned,
batched);
} else {
template_def = get_template_definition(
kname, "qmm_n", type_string, group_size, bits, batched);
kname, mode + "_qmm_n", type_string, group_size, bits, batched);
}
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder.set_bytes(K, 5);
compute_encoder.set_bytes(N, 6);
compute_encoder.set_bytes(M, 7);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 8);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
compute_encoder.set_bytes(M, c++);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@@ -479,7 +514,7 @@ void gather_qmm(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
const array& lhs_indices,
const array& rhs_indices,
array& out,
@@ -490,7 +525,8 @@ void gather_qmm(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;
int wm = 2;
@@ -507,7 +543,7 @@ void gather_qmm(
std::string type_string = get_type_string(x.dtype());
concatenate(
kname,
transpose ? "gather_qmm_t_" : "gather_qmm_n_",
mode + (transpose ? "_gather_qmm_t_" : "_gather_qmm_n_"),
type_string,
"_gs_",
group_size,
@@ -517,30 +553,31 @@ void gather_qmm(
std::string template_def;
if (transpose) {
template_def = get_template_definition(
kname, "gather_qmm_t", type_string, group_size, bits, aligned);
kname, mode + "_gather_qmm_t", type_string, group_size, bits, aligned);
} else {
template_def = get_template_definition(
kname, "gather_qmm_n", type_string, group_size, bits);
kname, mode + "_gather_qmm_n", type_string, group_size, bits);
}
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder.set_bytes(K, 7);
compute_encoder.set_bytes(N, 8);
compute_encoder.set_bytes(M, 9);
int n =
add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 10);
add_gather_strides_and_shapes(
compute_encoder, lhs_indices, rhs_indices, 10 + n);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_input_array(lhs_indices, c++);
compute_encoder.set_input_array(rhs_indices, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
compute_encoder.set_bytes(M, c++);
c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c);
add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@@ -549,7 +586,7 @@ void gather_qmv(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
const array& lhs_indices,
const array& rhs_indices,
array& out,
@@ -559,7 +596,8 @@ void gather_qmv(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;
int bn = 8;
@@ -573,7 +611,7 @@ void gather_qmv(
bool fast = N % bn == 0 && K % 512 == 0;
concatenate(
kname,
fast ? "gather_qmv_fast_" : "gather_qmv_",
mode + (fast ? "_gather_qmv_fast_" : "_gather_qmv_"),
type_string,
"_gs_",
group_size,
@@ -581,7 +619,7 @@ void gather_qmv(
bits);
auto template_def = get_template_definition(
kname,
fast ? "gather_qmv_fast" : "gather_qmv",
mode + (fast ? "_gather_qmv_fast" : "_gather_qmv"),
type_string,
group_size,
bits);
@@ -590,19 +628,20 @@ void gather_qmv(
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder.set_bytes(K, 7);
compute_encoder.set_bytes(N, 8);
int n =
add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 9);
add_gather_strides_and_shapes(
compute_encoder, lhs_indices, rhs_indices, 9 + n);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_input_array(lhs_indices, c++);
compute_encoder.set_input_array(rhs_indices, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c);
add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@@ -611,7 +650,7 @@ void gather_qvm(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
const array& lhs_indices,
const array& rhs_indices,
array& out,
@@ -621,7 +660,8 @@ void gather_qvm(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;
int bn = 64;
@@ -633,27 +673,34 @@ void gather_qvm(
kname.reserve(64);
std::string type_string = get_type_string(x.dtype());
concatenate(
kname, "gather_qvm_", type_string, "_gs_", group_size, "_b_", bits);
kname,
mode + "_gather_qvm_",
type_string,
"_gs_",
group_size,
"_b_",
bits);
auto template_def = get_template_definition(
kname, "gather_qvm", type_string, group_size, bits);
kname, mode + "_gather_qvm", type_string, group_size, bits);
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder.set_bytes(K, 7);
compute_encoder.set_bytes(N, 8);
int n =
add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 9);
add_gather_strides_and_shapes(
compute_encoder, lhs_indices, rhs_indices, 9 + n);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_input_array(lhs_indices, c++);
compute_encoder.set_input_array(rhs_indices, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c++);
add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@@ -662,7 +709,7 @@ void gather_qmm_rhs(
const array& x_,
const array& w_,
const array& scales_,
const array& biases_,
const std::optional<array>& biases_,
const array& indices_,
array& out,
bool transpose,
@@ -672,7 +719,8 @@ void gather_qmm_rhs(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string mode) {
// Start by normalizing the indices
array indices = ensure_row_contiguous(indices_, d, s);
@@ -697,7 +745,6 @@ void gather_qmm_rhs(
array x = broadcast_with_indices(x_);
array w = ensure_row_contiguous(w_, d, s);
array scales = ensure_row_contiguous(scales_, d, s);
array biases = ensure_row_contiguous(biases_, d, s);
// TODO: Tune the block sizes
int bm = 16, bn = 32, bk = 32;
@@ -713,7 +760,7 @@ void gather_qmm_rhs(
std::string type_string = get_type_string(x.dtype());
concatenate(
kname,
transpose ? "gather_qmm_rhs_nt_" : "gather_qmm_rhs_nn_",
mode + (transpose ? "_gather_qmm_rhs_nt_" : "_gather_qmm_rhs_nn_"),
type_string,
"_gs_",
group_size,
@@ -770,15 +817,19 @@ void gather_qmm_rhs(
MTL::Size group_dims(32, wn, wm);
MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, 1);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_input_array(indices, 4);
compute_encoder.set_output_array(out, 5);
compute_encoder.set_bytes(M, 6);
compute_encoder.set_bytes(N, 7);
compute_encoder.set_bytes(K, 8);
int c = 0;
compute_encoder.set_input_array(x, c++);
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases_) {
array biases = ensure_row_contiguous(*biases_, d, s);
compute_encoder.set_input_array(biases, c++);
}
compute_encoder.set_input_array(indices, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(M, c++);
compute_encoder.set_bytes(N, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@@ -794,7 +845,10 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
array x = ensure_row_contiguous_matrix(inputs[0], d, s);
array w = ensure_row_contiguous_matrix(inputs[1], d, s);
array scales = ensure_row_contiguous_matrix(inputs[2], d, s);
array biases = ensure_row_contiguous_matrix(inputs[3], d, s);
std::optional<array> biases = std::nullopt;
if (inputs.size() == 4) {
biases = ensure_row_contiguous_matrix(inputs[3], d, s);
}
// Extract the matmul shapes
bool non_batched = w.ndim() == 2 && x.flags().row_contiguous;
@@ -803,7 +857,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int N = out.shape(-1);
int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4;
auto mode = quantization_mode_to_string(mode_);
// It is a matrix matrix product.
if (M >= vector_limit) {
qmm(x,
@@ -818,30 +872,33 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
N,
K,
d,
s);
s,
mode);
return;
}
// It is a qmv with a small inner dimension so route to qmv_quad kernel
if (transpose_ && (K == 128 || K == 64) && is_power_of_2(bits_)) {
qmv_quad(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s);
qmv_quad(
x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode);
return;
}
// Run of the mill qmv
if (transpose_) {
qmv(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s);
qmv(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode);
return;
}
// Run of the mill qvm
if (K < 1024) {
qvm(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s);
qvm(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode);
return;
}
// Qvm with large dimension so route to a split K kernel for more parallelism
qvm_split_k(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s);
qvm_split_k(
x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode);
return;
}
@@ -854,9 +911,12 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
array x = ensure_row_contiguous_matrix(inputs[0], d, s);
array w = ensure_row_contiguous_matrix(inputs[1], d, s);
array scales = ensure_row_contiguous_matrix(inputs[2], d, s);
array biases = ensure_row_contiguous_matrix(inputs[3], d, s);
const array& lhs_indices = inputs[4];
const array& rhs_indices = inputs[5];
std::optional<array> biases = std::nullopt;
if (inputs.size() == 6) {
biases = ensure_row_contiguous_matrix(inputs[3], d, s);
}
const array& lhs_indices = inputs[inputs.size() - 2];
const array& rhs_indices = inputs[inputs.size() - 1];
int K = x.shape(-1);
int M = x.shape(-2);
@@ -864,6 +924,7 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
int B = out.size() / M / N;
int E = w.size() / w.shape(-1) / w.shape(-2);
int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4;
auto mode = quantization_mode_to_string(mode_);
// We are walking x in order and w is also in order so we can batch up the
// matmuls and reuse reading x and w.
@@ -884,7 +945,8 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
N,
K,
d,
s);
s,
mode);
return;
}
@@ -905,7 +967,8 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
N,
K,
d,
s);
s,
mode);
return;
}
@@ -924,7 +987,8 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
N,
K,
d,
s);
s,
mode);
return;
}
@@ -942,10 +1006,11 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
N,
K,
d,
s);
s,
mode);
}
void fast::AffineQuantize::eval_gpu(
void fast::Quantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& w_pre = inputs[0];