add trellis quant mode

Alex Barron 2025-03-18 18:52:22 -07:00
parent e9e268336b
commit d7acf59fd0
16 changed files with 852 additions and 108 deletions

View File

@ -684,6 +684,115 @@ METAL_FUNC void qmv_fast_impl(
} }
} }
template <uint32_t a = 89226354, uint32_t b = 64248484, uint32_t m = 996162400>
float inst3(uint16_t xi) {
uint32_t x = xi;
x = a * x + b;
x = (x & 0b10001111111111111000111111111111) ^ m;
auto xf = reinterpret_cast<thread float16_t*>(&x);
return xf[0] + xf[1];
}
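For reference, the inst3 decompressor above maps a 16-bit trellis state to a weight value by running one 32-bit LCG step, masking and flipping a fixed bit pattern, and summing the two float16 halves of the result. A rough host-side port (a Python sketch; inst3_ref is an illustrative name, and it assumes the little-endian float16 layout used on Apple GPUs):
import struct
def inst3_ref(xi: int, a: int = 89226354, b: int = 64248484, m: int = 996162400) -> float:
    x = (a * xi + b) & 0xFFFFFFFF                         # one 32-bit LCG step
    x = (x & 0b10001111111111111000111111111111) ^ m      # clear high exponent bits of each fp16 half, xor a fixed pattern
    lo, hi = struct.unpack("<2e", struct.pack("<I", x))   # reinterpret the 32 bits as two float16 values
    return lo + hi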
template <typename T, int bits>
METAL_FUNC void qmv_trellis_impl(
const device uint32_t* w,
const device T* scales,
const device T* biases,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int packs_per_thread = 2;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int reads_per = 16 / bits;
constexpr int local_w_size =
results_per_simdgroup * values_per_thread / reads_per;
const device uint8_t* ws = (const device uint8_t*)w;
typedef float U;
thread U x_thread[values_per_thread];
thread uint16_t w_thread[local_w_size];
thread U result[results_per_simdgroup] = {0};
// Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup;
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
x += tid.x * in_vec_size + simd_lid * values_per_thread;
y += tid.x * out_vec_size + out_row;
T scale = scales[0];
for (int k = 0; k < in_vec_size; k += block_size) {
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_thread; i++) {
x_thread[i] = x[i];
}
#pragma clang loop unroll(full)
for (int row = 0; row < results_per_simdgroup; row++) {
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_thread / reads_per; i++) {
auto wl = (const device uint16_t*)(ws + row * in_vec_size_w);
w_thread[row * values_per_thread / reads_per + i] = wl[i];
}
}
#pragma clang loop unroll(full)
for (int row = 0; row < results_per_simdgroup; row++) {
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_thread / reads_per; i++) {
int index = row * values_per_thread / reads_per + i;
uint16_t w0 = w_thread[index];
uint16_t w1 = w_thread[(index + 1) % local_w_size];
uint16_t wx = w0 ^ w1;
uint16_t wx1 = wx ^ 1;
uint16_t wf = w0 ^ (1 << bits);
if (bits == 2) {
result[row] += x_thread[8 * i] * inst3(w0);
result[row] += x_thread[8 * i + 1] * inst3(wf ^ (wx1 & 0x3));
result[row] += x_thread[8 * i + 2] * inst3(w0 ^ (wx & 0xf));
result[row] += x_thread[8 * i + 3] * inst3(w0 ^ (wx1 & 0x3f));
result[row] += x_thread[8 * i + 4] * inst3(w0 ^ (wx & 0xff));
result[row] += x_thread[8 * i + 5] * inst3(w0 ^ (wx1 & 0x3ff));
result[row] += x_thread[8 * i + 6] * inst3(w0 ^ (wx & 0xfff));
result[row] += x_thread[8 * i + 7] * inst3(w0 ^ (wx1 & 0x3fff));
} else if (bits == 4) {
result[row] += x_thread[4 * i] * inst3(w0);
result[row] += x_thread[4 * i + 1] * inst3(wf ^ (wx1 & 0xf));
result[row] += x_thread[4 * i + 2] * inst3(w0 ^ (wx & 0xff));
result[row] += x_thread[4 * i + 3] * inst3(w0 ^ (wx1 & 0xfff));
}
}
}
ws += block_size * bytes_per_pack / pack_factor;
x += block_size;
}
for (int row = 0; row < results_per_simdgroup; row++) {
result[row] = simd_sum(result[row]);
if (simd_lid == 0) {
y[row] = static_cast<T>(scale * result[row]);
}
}
}
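Conceptually, a trellis-coded row decodes like a shift register: each new bits-wide symbol is pushed into an L = 16 bit state and every state is mapped to a weight through inst3, which is why consecutive decoded weights share most of their state bits and the kernel can reconstruct neighbouring states with XOR deltas instead of reloading them. A toy Python sketch of that idea (illustrative only; the kernel above keeps the state bits in place and applies masked XOR deltas rather than shifting, and applies the single per-tensor scale at the end):
def decode_row_ref(symbols, bits=2, L=16):
    # symbols: the 2-bit codes for one row, in decode order
    # uses inst3_ref from the sketch above
    state, out = 0, []
    for s in symbols:
        state = ((state << bits) | (s & ((1 << bits) - 1))) & ((1 << L) - 1)
        out.append(inst3_ref(state))
    return out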
template <typename T, int group_size, int bits> template <typename T, int group_size, int bits>
METAL_FUNC void qmv_impl( METAL_FUNC void qmv_impl(
const device uint32_t* w, const device uint32_t* w,
@ -1302,7 +1411,13 @@ METAL_FUNC void adjust_matrix_offsets(
y += tid.z * output_stride; y += tid.z * output_stride;
} }
template <typename T, int group_size, int bits, int D, bool batched> template <
typename T,
int group_size,
int bits,
int D,
bool batched,
bool trellis = false>
[[kernel]] void qmv_quad( [[kernel]] void qmv_quad(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
@ -1354,7 +1469,12 @@ template <typename T, int group_size, int bits, int D, bool batched>
quad_lid); quad_lid);
} }
template <typename T, int group_size, int bits, bool batched> template <
typename T,
int group_size,
int bits,
bool batched,
bool trellis = false>
[[kernel]] void qmv_fast( [[kernel]] void qmv_fast(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
@ -1393,20 +1513,39 @@ template <typename T, int group_size, int bits, bool batched>
b_strides, b_strides,
tid); tid);
} }
qmv_fast_impl<T, group_size, bits>( if (trellis) {
w, qmv_trellis_impl<T, bits>(
scales, w,
biases, scales,
x, biases,
y, x,
in_vec_size, y,
out_vec_size, in_vec_size,
tid, out_vec_size,
simd_gid, tid,
simd_lid); simd_gid,
simd_lid);
} else {
qmv_fast_impl<T, group_size, bits>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
}
} }
template <typename T, const int group_size, const int bits, bool batched> template <
typename T,
const int group_size,
const int bits,
bool batched,
bool trellis = false>
[[kernel]] void qmv( [[kernel]] void qmv(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
@ -1458,7 +1597,12 @@ template <typename T, const int group_size, const int bits, bool batched>
simd_lid); simd_lid);
} }
template <typename T, const int group_size, const int bits, bool batched> template <
typename T,
const int group_size,
const int bits,
bool batched,
bool trellis = false>
[[kernel]] void qvm( [[kernel]] void qvm(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
@ -1572,6 +1716,7 @@ template <
const int bits, const int bits,
const bool aligned_N, const bool aligned_N,
const bool batched, const bool batched,
bool trellis = false,
const int BM = 32, const int BM = 32,
const int BK = 32, const int BK = 32,
const int BN = 32> const int BN = 32>
@ -1630,6 +1775,7 @@ template <
const int group_size, const int group_size,
const int bits, const int bits,
const bool batched, const bool batched,
bool trellis = false,
const int BM = 32, const int BM = 32,
const int BK = 32, const int BK = 32,
const int BN = 32> const int BN = 32>
@ -1685,7 +1831,7 @@ template <
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
} }
template <typename T, int group_size, int bits> template <typename T, int group_size, int bits, bool trellis = false>
[[kernel]] void bs_qmv_fast( [[kernel]] void bs_qmv_fast(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
@ -1734,20 +1880,34 @@ template <typename T, int group_size, int bits>
s_strides, s_strides,
b_strides, b_strides,
tid); tid);
qmv_fast_impl<T, group_size, bits>( if (trellis) {
w, qmv_trellis_impl<T, bits>(
scales, w,
biases, scales,
x, biases,
y, x,
in_vec_size, y,
out_vec_size, in_vec_size,
tid, out_vec_size,
simd_gid, tid,
simd_lid); simd_gid,
simd_lid);
} else {
qmv_fast_impl<T, group_size, bits>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
}
} }
template <typename T, int group_size, int bits> template <typename T, int group_size, int bits, bool trellis = false>
[[kernel]] void bs_qmv( [[kernel]] void bs_qmv(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
@ -1809,7 +1969,7 @@ template <typename T, int group_size, int bits>
simd_lid); simd_lid);
} }
template <typename T, int group_size, int bits> template <typename T, int group_size, int bits, bool trellis = false>
[[kernel]] void bs_qvm( [[kernel]] void bs_qvm(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
@ -1876,6 +2036,7 @@ template <
const int group_size, const int group_size,
const int bits, const int bits,
const bool aligned_N, const bool aligned_N,
bool trellis = false,
const int BM = 32, const int BM = 32,
const int BK = 32, const int BK = 32,
const int BN = 32> const int BN = 32>
@ -1943,6 +2104,7 @@ template <
typename T, typename T,
const int group_size, const int group_size,
const int bits, const int bits,
bool trellis = false,
const int BM = 32, const int BM = 32,
const int BK = 32, const int BK = 32,
const int BN = 32> const int BN = 32>
@ -2157,3 +2319,211 @@ template <typename T, const int group_size, const int bits>
} }
} }
} }
template <
typename T,
const bool use_overlap,
const int bits = 2,
const int timesteps = 128>
[[kernel]] void trellis_viterbi(
const device T* w [[buffer(0)]],
device float16_t* score [[buffer(1)]],
device uint8_t* pointers [[buffer(2)]],
const device uint16_t* overlap [[buffer(3)]],
uint3 tid [[thread_position_in_grid]]) {
constexpr uint16_t L = 16;
constexpr uint L2 = 1 << L;
uint16_t idx = tid.y * 16;
threadgroup float16_t swap_V[16384];
thread float16_t min_V[16] = {0};
for (uint16_t t = 0; t < timesteps; t++) {
uint16_t tt = t % 8 == 0 ? L / bits : t % 8;
uint16_t shift = ((tt - 1) % (L / bits)) * bits;
uint16_t flip = (t == 0 || (t > 1 && t % 8 == 1)) ? (1 << bits) + 1 : t % 2;
uint16_t s000 = 1 << (shift - 6);
uint16_t s0 = 1 << (shift - 2);
uint16_t s1 = 1 << (shift);
uint16_t s2 = 1 << (shift + 2);
uint16_t s4 = 1 << (shift + 4);
if (t > 1) {
uint16_t i = 0;
uint16_t loff = 1 << (metal::clamp((shift + 14) % 16, 2, 12));
uint16_t hoff = shift > 4 ? 4 : shift == 4 ? 16 : 1;
uint16_t ind = idx;
if (shift == 0) {
ind >>= 2;
} else if (shift == 14) {
ind = (ind & 0xfff) + (ind >> 12);
} else if (shift == 2) {
} else if (shift == 4) {
ind = ((ind >> 4) & 0x3) + (ind & ~0x3f);
} else if (shift == 6) {
ind = ((ind / s0) % 4) * s1 + ((ind / s1) % 4) + (ind / s2) * s2;
} else {
ind = ((ind / 16) % s000) * 16 + ((ind / s0) % 4) * s1 +
((ind / s1) % 4) + (ind / s2) * s2;
}
for (uint16_t high = 0; high < 4; high++) {
uint16_t sub_ind = ind;
for (uint16_t low = 0; low < 4; low++) {
swap_V[sub_ind] = min_V[i];
i++;
sub_ind += loff;
}
ind += hoff;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint16_t i = 0; i < 16; i++) {
min_V[i] = swap_V[idx + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
uint16_t rolled_t = use_overlap ? t : (t + 64) % 128;
T w_t = w[tid.x * timesteps + rolled_t];
for (uint16_t i = 0; i < 4; i++) {
thread float16_t min_val[4] = {INFINITY, INFINITY, INFINITY, INFINITY};
thread uint16_t min_idx[4] = {0};
uint16_t ii = idx * 4 + i * 16;
uint16_t big_idx = ii;
if (shift > 0 && shift < 14) {
big_idx = ((ii / s2) % 4) + (ii / s4 * s4);
if (shift > 2) {
big_idx += ((ii / 16) % s0) * 4;
}
} else if (shift == 14 && t > 0) {
big_idx >>= 2;
}
uint16_t loff = t == 0 ? 4 : s1;
uint16_t hoff = (t == 0 || shift == 14) ? 1 : s2;
for (uint16_t high = 0; high < 4; high++) {
uint16_t sub_ind = big_idx;
for (uint16_t low = 0; low < 4; low++) {
float mse = inst3(sub_ind ^ flip) - w_t;
mse *= mse;
float16_t new_val = min_V[i * 4 + high] + mse;
if (new_val < min_val[low]) {
min_val[low] = new_val;
min_idx[low] = high;
}
sub_ind += loff;
}
big_idx += hoff;
}
for (uint16_t j = 0; j < 4; j++) {
min_V[i * 4 + j] = min_val[j];
pointers[tid.x * L2 / 4 * timesteps + t * L2 / 4 + idx + i * 4 + j] =
min_idx[j];
}
}
if (t == 0 && use_overlap) {
uint16_t over = overlap[tid.x * 128 + 64];
over = over & ((1 << 14) - 1);
for (uint16_t i = 0; i < 16; i++) {
uint16_t rs = (over >> 2) ^ 1;
uint16_t ls = (idx + i) & ((1 << 12) - 1);
min_V[i] = rs == ls ? min_V[i] : INFINITY;
}
}
}
if (use_overlap) {
uint16_t over = overlap[tid.x * 128 + 64];
over = over & ((1 << 14) - 1);
uint16_t node =
(over % 4) * 4096 + ((over / 4) % 1024) * 4 + (over / 4096) % 4;
for (uint16_t i = 0; i < 16; i++) {
min_V[i] = (idx + i) == node ? min_V[i] : INFINITY;
}
}
for (uint16_t i = 0; i < 16; i++) {
score[tid.x * L2 / 4 + idx + i] = min_V[i];
}
}
uint16_t remove_bits(uint16_t i, uint16_t shift) {
uint16_t lower = i & ((1 << shift) - 1);
uint16_t upper = i & ~((1 << (shift + 2)) - 1);
return lower + (upper >> 2);
}
uint16_t swap_bits(uint16_t i, uint16_t shift) {
uint16_t diff = ((i >> shift) ^ i) & 0x3;
i = i ^ diff;
i ^= diff << shift;
return i;
}
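remove_bits drops the bit pair at positions [shift, shift + 2) and closes the gap, while swap_bits exchanges that bit pair with the lowest pair. The same logic in Python, with a small worked example (reference names are illustrative):
def remove_bits_ref(i, shift):
    lower = i & ((1 << shift) - 1)
    upper = i & ~((1 << (shift + 2)) - 1)
    return lower + (upper >> 2)
def swap_bits_ref(i, shift):
    diff = ((i >> shift) ^ i) & 0x3
    return (i ^ diff) ^ (diff << shift)
assert remove_bits_ref(0b110110, 2) == 0b1110   # the pair at bits 2-3 is removed and the gap closed
assert swap_bits_ref(0b1101, 2) == 0b0111       # the pairs at bits 0-1 and 2-3 are exchanged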
template <const bool use_overlap, const int bits = 2, const int timesteps = 128>
[[kernel]] void trellis_backtrack(
const device uint32_t* start [[buffer(0)]],
const device uint8_t* pointers [[buffer(1)]],
device uint16_t* out [[buffer(2)]],
const device uint16_t* overlap [[buffer(3)]],
uint3 tid [[thread_position_in_grid]]) {
constexpr uint16_t L = 16;
uint16_t node = start[tid.x];
uint16_t dir =
pointers[tid.x * timesteps * 16384 + (timesteps - 1) * 16384 + node];
node = (node % 4) * 4096 + ((node / 4) % 1024) * 4 + (node / 4096) % 4;
node ^= 1;
node += dir * 16384;
out[tid.x * timesteps + timesteps - 1] = node;
for (int t = timesteps - 2; t >= 0; t--) {
uint16_t shift = (t % (L / bits)) * bits;
uint16_t mask = ((1 << L) - 1) ^ (((1 << bits) - 1) << shift);
uint16_t flip = t % (L / bits) == 0 ? 1 << bits : 1;
uint16_t i = (node & mask) ^ flip;
if (shift > 0) {
i = remove_bits(i, shift);
}
if (t == 0) {
i >>= 2;
}
if (t % 2 == 1 || t == 0) {
i ^= 1;
}
shift = shift == 0 ? L : shift;
if (t > 0) {
i = swap_bits(i, shift - 2);
}
shift = shift == L ? 0 : shift;
uint16_t last_p = pointers[tid.x * timesteps * 16384 + t * 16384 + i];
if ((t % 8 == 1 && t > 1) || t == 0) {
last_p ^= 1;
}
node = ((node & mask) ^ flip) | (last_p << shift);
if (t == 0 && use_overlap) {
uint16_t over = overlap[tid.x * 128 + 64];
over = over & ((1 << 14) - 1);
node = (node & 0xfffc) + (over & 0x3);
}
out[tid.x * timesteps + t] = node;
}
}
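Together, trellis_viterbi and trellis_backtrack implement a standard Viterbi search: the forward kernel keeps, for each tracked state (1024 threads x 16 states = 16384 per row), the best accumulated squared error plus a small backpointer per timestep, and the backtrack kernel walks the pointers from the argmin state back to t = 0. Below is a brute-force reference over an idealized 2^L shift-register trellis, just to make the recurrence explicit; it is a hedged sketch, far too slow for real sizes, and the kernels above use a reduced 2^14 state space with a rotated in-place bit layout rather than this plain shift.
import numpy as np
def viterbi_ref(w, bits=2, L=16):
    # w: one row of scaled weights to quantize; uses inst3_ref from the earlier sketch
    S, mask = 1 << L, (1 << L) - 1
    cost = np.zeros(S, dtype=np.float32)
    back = np.zeros((len(w), S), dtype=np.uint8)
    for t, wt in enumerate(w):
        new_cost = np.empty(S, dtype=np.float32)
        for state in range(S):
            err = (inst3_ref(state) - float(wt)) ** 2
            base = state >> bits                      # the predecessor's low (L - bits) bits
            best, best_top = np.inf, 0
            for top in range(1 << bits):              # the bits the predecessor shifts out
                c = cost[base | (top << (L - bits))] + err
                if c < best:
                    best, best_top = c, top
            new_cost[state], back[t, state] = best, best_top
        cost = new_cost
    # backtrack from the cheapest final state, recovering one state per timestep
    state, states = int(np.argmin(cost)), []
    for t in range(len(w) - 1, -1, -1):
        states.append(state)
        state = (state >> bits) | (int(back[t, state]) << (L - bits))
    return states[::-1]                               # states[t] & ((1 << bits) - 1) is the symbol chosen at t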

View File

@ -16,6 +16,8 @@
#include "mlx/scheduler.h" #include "mlx/scheduler.h"
#include "mlx/utils.h" #include "mlx/utils.h"
namespace mlx::core { namespace mlx::core {
template <typename T> template <typename T>
@ -158,33 +160,25 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) { void arg_reduce_dispatch(
assert(inputs.size() == 1); const array& in,
auto& in = inputs[0]; array& out,
out.set_data(allocator::malloc(out.nbytes())); int axis,
auto& s = stream(); std::string op_name,
const Stream& s) {
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
std::string op_name;
switch (reduce_type_) {
case ArgReduce::ArgMin:
op_name = "argmin_";
break;
case ArgReduce::ArgMax:
op_name = "argmax_";
break;
}
// Prepare the shapes, strides and axis arguments. // Prepare the shapes, strides and axis arguments.
auto in_strides = in.strides(); auto in_strides = in.strides();
auto shape = in.shape(); auto shape = in.shape();
auto out_strides = out.strides(); auto out_strides = out.strides();
auto axis_stride = in_strides[axis_]; auto axis_stride = in_strides[axis];
size_t axis_size = shape[axis_]; size_t axis_size = shape[axis];
if (out_strides.size() == in_strides.size()) { if (out_strides.size() == in_strides.size()) {
out_strides.erase(out_strides.begin() + axis_); out_strides.erase(out_strides.begin() + axis);
} }
in_strides.erase(in_strides.begin() + axis_); in_strides.erase(in_strides.begin() + axis);
shape.erase(shape.begin() + axis_); shape.erase(shape.begin() + axis);
size_t ndim = shape.size(); size_t ndim = shape.size();
// ArgReduce // ArgReduce
@ -192,7 +186,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
int n_reads = 4; int n_reads = 4;
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
{ {
auto kernel = d.get_kernel(op_name + type_to_name(in)); auto kernel = d.get_kernel(op_name + "_" + type_to_name(in));
NS::UInteger thread_group_size = std::min( NS::UInteger thread_group_size = std::min(
(axis_size + n_reads - 1) / n_reads, (axis_size + n_reads - 1) / n_reads,
kernel->maxTotalThreadsPerThreadgroup()); kernel->maxTotalThreadsPerThreadgroup());
@ -226,6 +220,23 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
} }
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
std::string op_name;
switch (reduce_type_) {
case ArgReduce::ArgMin:
op_name = "argmin";
break;
case ArgReduce::ArgMax:
op_name = "argmax";
break;
}
auto& in = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
arg_reduce_dispatch(in, out, axis_, op_name, s);
}
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) { void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
CopyType ctype = CopyType ctype =
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General; inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;

View File

@ -7,11 +7,14 @@
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/reduce.h" #include "mlx/backend/metal/reduce.h"
#include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h" #include "mlx/fast_primitives.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/utils.h" #include "mlx/utils.h"
namespace mlx::core { namespace mlx::core {
void launch_qmm( void launch_qmm(
@ -31,6 +34,7 @@ void launch_qmm(
bool gather, bool gather,
bool aligned, bool aligned,
bool quad, bool quad,
const std::string& mode,
const Stream& s) { const Stream& s) {
auto& x_pre = inputs[0]; auto& x_pre = inputs[0];
auto& w_pre = inputs[1]; auto& w_pre = inputs[1];
@ -54,8 +58,12 @@ void launch_qmm(
}; };
auto x = ensure_row_contiguous_last_dims(x_pre); auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre); auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre); auto scales = scales_pre;
auto biases = ensure_row_contiguous_last_dims(biases_pre); auto biases = biases_pre;
if (mode == "affine") {
scales = ensure_row_contiguous_last_dims(scales_pre);
biases = ensure_row_contiguous_last_dims(biases_pre);
}
int x_batch_ndims = x.ndim() - 2; int x_batch_ndims = x.ndim() - 2;
auto& x_shape = x.shape(); auto& x_shape = x.shape();
@ -68,6 +76,8 @@ void launch_qmm(
std::string aligned_n = (O % 32) == 0 ? "true" : "false"; std::string aligned_n = (O % 32) == 0 ? "true" : "false";
bool is_trellis = (mode == "trellis");
std::ostringstream kname; std::ostringstream kname;
auto type_string = get_type_string(x.dtype()); auto type_string = get_type_string(x.dtype());
kname << name << "_" << type_string << "_gs_" << group_size << "_b_" << bits; kname << name << "_" << type_string << "_gs_" << group_size << "_b_" << bits;
@ -80,24 +90,47 @@ void launch_qmm(
if (!gather) { if (!gather) {
kname << "_batch_" << batched; kname << "_batch_" << batched;
} }
if (mode == "trellis") {
kname << "_mode_" << is_trellis;
}
// Encode and dispatch kernel // Encode and dispatch kernel
std::string template_def; std::string template_def;
if (quad) { if (quad) {
template_def = get_template_definition( template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, D, batched); kname.str(),
name,
type_string,
group_size,
bits,
D,
batched,
is_trellis);
} else if (aligned && !gather) { } else if (aligned && !gather) {
template_def = get_template_definition( template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, aligned_n, batched); kname.str(),
name,
type_string,
group_size,
bits,
aligned_n,
batched,
is_trellis);
} else if (!gather && !aligned) { } else if (!gather && !aligned) {
template_def = get_template_definition( template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, batched); kname.str(), name, type_string, group_size, bits, batched, is_trellis);
} else if (aligned && gather) { } else if (aligned && gather) {
template_def = get_template_definition( template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, aligned_n); kname.str(),
name,
type_string,
group_size,
bits,
aligned_n,
is_trellis);
} else { } else {
template_def = get_template_definition( template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits); kname.str(), name, type_string, group_size, bits, is_trellis);
} }
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
auto kernel = get_quantized_kernel(d, kname.str(), template_def); auto kernel = get_quantized_kernel(d, kname.str(), template_def);
@ -276,6 +309,7 @@ void qmm_op(
int group_size, int group_size,
int bits, int bits,
bool gather, bool gather,
const std::string& mode,
const Stream& s) { const Stream& s) {
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
@ -354,7 +388,7 @@ void qmm_op(
group_dims = MTL::Size(simdgroup_size, 1, 1); group_dims = MTL::Size(simdgroup_size, 1, 1);
grid_dims = MTL::Size(B, (O + bo - 1) / bo, N); grid_dims = MTL::Size(B, (O + bo - 1) / bo, N);
quad = true; quad = true;
} else if (B < qmv_batch_limit && O % 8 == 0 && D % 512 == 0 && D >= 512) { } else if (B < 10000 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
name += "qmv_fast"; name += "qmv_fast";
int bo = 8; int bo = 8;
int bd = 32; int bd = 32;
@ -420,19 +454,34 @@ void qmm_op(
gather, gather,
aligned, aligned,
quad, quad,
mode,
s); s);
} }
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) { void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4); assert(inputs.size() == 4);
qmm_op( qmm_op(
inputs, out, transpose_, group_size_, bits_, /*gather=*/false, stream()); inputs,
out,
transpose_,
group_size_,
bits_,
/*gather=*/false,
mode_,
stream());
} }
void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) { void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6); assert(inputs.size() == 6);
qmm_op( qmm_op(
inputs, out, transpose_, group_size_, bits_, /*gather=*/true, stream()); inputs,
out,
transpose_,
group_size_,
bits_,
/*gather=*/true,
mode_,
stream());
} }
void fast::AffineQuantize::eval_gpu( void fast::AffineQuantize::eval_gpu(
@ -516,4 +565,123 @@ void fast::AffineQuantize::eval_gpu(
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
} }
void viterbi(
array& w,
array& scores,
array& pointers,
array& start,
array& overlap,
bool use_overlap,
const Stream& s) {
int B = scores.shape(0);
auto& d = metal::device(s.device);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_output_array(scores, 1);
compute_encoder.set_output_array(pointers, 2);
if (use_overlap) {
compute_encoder.set_input_array(overlap, 3);
}
std::ostringstream kname;
auto type_string = get_type_string(w.dtype());
kname << "trellis_viterbi_" << type_string << "_overlap_" << use_overlap;
auto template_def = get_template_definition(
kname.str(), "trellis_viterbi", type_string, use_overlap);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder.set_compute_pipeline_state(kernel);
auto group_dims = MTL::Size(1, 1024, 1);
auto grid_dims = MTL::Size(B, 1024, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
arg_reduce_dispatch(scores, start, 1, "argmin", s);
}
void viterbi_backtrack(
array& start,
array& pointers,
array& out,
array& overlap,
bool use_overlap,
const Stream& s) {
int B = start.shape(0);
auto& d = metal::device(s.device);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_input_array(start, 0);
compute_encoder.set_input_array(pointers, 1);
compute_encoder.set_output_array(out, 2);
if (use_overlap) {
compute_encoder.set_input_array(overlap, 3);
}
std::ostringstream kname;
kname << "trellis_backtrack" << "_overlap_" << use_overlap;
auto template_def =
get_template_definition(kname.str(), "trellis_backtrack", use_overlap);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder.set_compute_pipeline_state(kernel);
auto group_dims = MTL::Size(256, 1, 1);
auto grid_dims = MTL::Size(B, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
void fast::TrellisQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& w_pre = inputs[0];
auto& out = outputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
std::vector<array> copies;
auto ensure_row_contiguous = [&copies, &s](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return arr_copy;
}
};
auto w = ensure_row_contiguous(w_pre);
int B = w.shape(0);
int T = w.shape(1);
constexpr int num_states = 1 << 14;
array scores({B, num_states}, float16, nullptr, {});
scores.set_data(allocator::malloc(scores.nbytes()));
copies.push_back(scores);
array pointers({B, T, num_states}, uint8, nullptr, {});
pointers.set_data(allocator::malloc(pointers.nbytes()));
copies.push_back(pointers);
array start({B}, uint32, nullptr, {});
start.set_data(allocator::malloc(start.nbytes()));
copies.push_back(start);
array rolled({B, T}, uint16, nullptr, {});
rolled.set_data(allocator::malloc(rolled.nbytes()));
copies.push_back(rolled);
viterbi(w, scores, pointers, start, out, false, s);
viterbi_backtrack(start, pointers, rolled, out, false, s);
viterbi(w, scores, pointers, start, rolled, true, s);
viterbi_backtrack(start, pointers, out, rolled, true, s);
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core } // namespace mlx::core

View File

@ -38,4 +38,11 @@ void strided_reduce_general_dispatch(
metal::Device& d, metal::Device& d,
const Stream& s); const Stream& s);
void arg_reduce_dispatch(
const array& in,
array& out,
int axis,
std::string op_name,
const Stream& s);
} // namespace mlx::core } // namespace mlx::core

View File

@ -11,6 +11,8 @@
#include "mlx/transforms.h" #include "mlx/transforms.h"
#include "mlx/transforms_impl.h" #include "mlx/transforms_impl.h"
namespace mlx::core::fast { namespace mlx::core::fast {
std::vector<array> Custom::vjp( std::vector<array> Custom::vjp(
@ -832,7 +834,7 @@ array pack_and_quantize(
return packed_w; return packed_w;
} }
std::tuple<array, array, array> std::vector<array>
affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
auto s = to_stream(s_); auto s = to_stream(s_);
@ -1028,6 +1030,54 @@ array affine_dequantize(
return fallback({w, scales, biases})[0]; return fallback({w, scales, biases})[0];
} }
std::vector<array>
trellis_quantize(const array& w_, int bits, StreamOrDevice s_) {
if (bits != 2) {
throw std::invalid_argument(
"Only 2 bit Trellis quants are currently supported.");
}
int Tx = 4;
int Ty = 32;
int batch_size = 256;
auto s = to_stream(s_);
int L = 16;
int M = w_.shape(-2);
int T = Tx * Ty;
auto scale = std(astype(w_, float32, s), s);
auto w = divide(w_, scale, s);
w = astype(w, float16, s);
w = reshape(w, {M / Tx, Tx, -1, Ty}, s);
w = transpose(w, {0, 2, 1, 3}, s);
w = reshape(w, {-1, T}, s);
auto fallback = [bits, s](const std::vector<array>& inputs) mutable
-> std::vector<array> { return {inputs[0]}; };
auto q = zeros({w.shape(0), w.shape(1) * bits / L}, uint16, s);
for (int i = 0; i < w.shape(0); i += batch_size) {
auto w_batch = slice(w, {i, 0}, {i + batch_size, w.shape(-1)}, s);
auto q_batch = array(
w_batch.shape(),
uint16,
std::make_shared<TrellisQuantize>(s, fallback, bits, true),
{w_batch});
q_batch = slice(q_batch, {0, 0}, q_batch.shape(), {1, L / bits}, s);
q = slice_update(q, q_batch, {i, 0}, {i + batch_size, q.shape(-1)}, s);
eval(q);
}
q = reshape(q, {M / Tx, -1, Tx, Ty * bits / L}, s);
q = transpose(q, {0, 2, 1, 3}, s);
q = reshape(q, {M, -1}, s);
q = view(q, uint32, s);
return {q, scale, scale};
}
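trellis_quantize normalizes the weights by their global standard deviation, tiles them into Tx x Ty = 4 x 32 blocks so each row of T = 128 values becomes an independent Viterbi problem, runs TrellisQuantize over 256-row batches, keeps every (L / bits)-th 16-bit state as the packed code, and re-tiles the result before viewing it as uint32. The host-side blocking, sketched in NumPy (block_for_trellis is an illustrative name and mirrors the reshape/transpose sequence above):
import numpy as np
def block_for_trellis(w, Tx=4, Ty=32):
    # w: (M, N) float weights; returns rows of Tx*Ty values plus the global scale
    M = w.shape[0]
    scale = w.astype(np.float32).std()
    wn = (w / scale).astype(np.float16)
    wn = wn.reshape(M // Tx, Tx, -1, Ty).transpose(0, 2, 1, 3).reshape(-1, Tx * Ty)
    return wn, scale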
bool AffineQuantize::is_equivalent(const Primitive& other) const { bool AffineQuantize::is_equivalent(const Primitive& other) const {
const AffineQuantize& p_other = static_cast<const AffineQuantize&>(other); const AffineQuantize& p_other = static_cast<const AffineQuantize&>(other);
return ( return (

View File

@ -52,7 +52,7 @@ array scaled_dot_product_attention(
const std::vector<array>& mask_arrs = {}, const std::vector<array>& mask_arrs = {},
StreamOrDevice s = {}); StreamOrDevice s = {});
std::tuple<array, array, array> affine_quantize( std::vector<array> affine_quantize(
const array& w, const array& w,
int group_size = 64, int group_size = 64,
int bits = 4, int bits = 4,
@ -66,6 +66,9 @@ array affine_dequantize(
int bits = 4, int bits = 4,
StreamOrDevice s = {}); StreamOrDevice s = {});
std::vector<array>
trellis_quantize(const array& w, int bits = 4, StreamOrDevice s = {});
typedef std::variant<int, bool, Dtype> TemplateArg; typedef std::variant<int, bool, Dtype> TemplateArg;
typedef std::function<std::vector<array>( typedef std::function<std::vector<array>(

View File

@ -269,6 +269,38 @@ class AffineQuantize : public Custom {
bool dequantize_; bool dequantize_;
}; };
class TrellisQuantize : public Custom {
public:
explicit TrellisQuantize(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
int bits,
bool dequantize)
: Custom(stream, fallback), bits_(bits), dequantize_(dequantize) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
throw std::runtime_error("NYI");
};
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(TrellisQuantize);
// bool is_equivalent(const Primitive& other) const override;
// std::vector<Shape> output_shapes(const std::vector<array>& inputs)
// override;
auto state() const {
return std::make_tuple(nullptr, bits_, dequantize_);
}
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
int bits_;
bool dequantize_;
};
struct CustomKernelShapeInfo { struct CustomKernelShapeInfo {
bool shape = false; bool shape = false;
bool strides = false; bool strides = false;

View File

@ -17,6 +17,8 @@
#include "mlx/transforms_impl.h" #include "mlx/transforms_impl.h"
#include "mlx/utils.h" #include "mlx/utils.h"
namespace mlx::core { namespace mlx::core {
namespace { namespace {
@ -79,7 +81,8 @@ std::pair<int, int> extract_quantized_matmul_dims(
const array& biases, const array& biases,
bool transpose, bool transpose,
int group_size, int group_size,
int bits) { int bits,
const std::string& mode) {
if (w.dtype() != uint32) { if (w.dtype() != uint32) {
std::ostringstream msg; std::ostringstream msg;
msg << "[" << tag << "] The weight matrix should be uint32 " msg << "[" << tag << "] The weight matrix should be uint32 "
@ -87,12 +90,23 @@ std::pair<int, int> extract_quantized_matmul_dims(
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (scales.shape() != biases.shape()) { if (mode == "affine") {
std::ostringstream msg; if (scales.shape() != biases.shape()) {
msg << "[" << tag << "] Scales and biases should have the same shape. " std::ostringstream msg;
<< "Received scales with shape " << scales.shape() msg << "[" << tag << "] Scales and biases should have the same shape. "
<< " and biases with " << biases.shape(); << "Received scales with shape " << scales.shape()
throw std::invalid_argument(msg.str()); << " and biases with " << biases.shape();
throw std::invalid_argument(msg.str());
}
if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) {
std::ostringstream msg;
msg << "[" << tag << "] The shapes of the weight and scales are "
<< "incompatible based on bits and group_size. w.shape() == "
<< w.shape() << " and scales.shape() == " << scales.shape()
<< " with group_size=" << group_size << " and bits=" << bits;
throw std::invalid_argument(msg.str());
}
} }
if (!std::equal( if (!std::equal(
@ -105,15 +119,6 @@ std::pair<int, int> extract_quantized_matmul_dims(
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) {
std::ostringstream msg;
msg << "[" << tag << "] The shapes of the weight and scales are "
<< "incompatible based on bits and group_size. w.shape() == "
<< w.shape() << " and scales.shape() == " << scales.shape()
<< " with group_size=" << group_size << " and bits=" << bits;
throw std::invalid_argument(msg.str());
}
int x_inner_dims = x.shape(-1); int x_inner_dims = x.shape(-1);
// Calculate the expanded w's dims // Calculate the expanded w's dims
@ -717,6 +722,9 @@ array slice(
<< "array with dimension " << a.ndim() << "."; << "array with dimension " << a.ndim() << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
// std::cout << "start " << start << std::endl;
// std::cout << "stop " << stop << std::endl;
// std::cout << "strides " << strides << std::endl;
auto [has_neg_strides, out_shape] = auto [has_neg_strides, out_shape] =
normalize_slice(a.shape(), start, stop, strides); normalize_slice(a.shape(), start, stop, strides);
@ -3969,10 +3977,19 @@ array quantized_matmul(
bool transpose /* = true */, bool transpose /* = true */,
int group_size /* = 64 */, int group_size /* = 64 */,
int bits /* = 4 */, int bits /* = 4 */,
const std::string& mode /* = "affine" */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
// Check and extract the quantized matrix shape against x // Check and extract the quantized matrix shape against x
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
"quantized_matmul", x, w, scales, biases, transpose, group_size, bits); "quantized_matmul",
x,
w,
scales,
biases,
transpose,
group_size,
bits,
mode);
auto dtype = result_type(x, scales, biases); auto dtype = result_type(x, scales, biases);
if (!issubdtype(dtype, floating)) { if (!issubdtype(dtype, floating)) {
@ -3996,16 +4013,26 @@ array quantized_matmul(
std::move(out_shape), std::move(out_shape),
dtype, dtype,
std::make_shared<QuantizedMatmul>( std::make_shared<QuantizedMatmul>(
to_stream(s), group_size, bits, transpose), to_stream(s), group_size, bits, transpose, mode),
std::move(inputs)); std::move(inputs));
} }
std::tuple<array, array, array> quantize( std::vector<array> quantize(
const array& w, const array& w,
int group_size /* = 64 */, int group_size /* = 64 */,
int bits /* = 4 */, int bits /* = 4 */,
const std::string& mode, /* = affine */
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
return fast::affine_quantize(w, group_size, bits, s); if (mode == "affine") {
return fast::affine_quantize(w, group_size, bits, s);
} else if (mode == "trellis") {
return fast::trellis_quantize(w, bits, s);
} else {
std::ostringstream msg;
msg << "[quantize] Unsupported quantization mode " << mode << "."
<< std::endl;
throw std::invalid_argument(msg.str());
}
} }
array dequantize( array dequantize(
@ -4028,14 +4055,15 @@ array gather_qmm(
bool transpose /* = true */, bool transpose /* = true */,
int group_size /* = 64 */, int group_size /* = 64 */,
int bits /* = 4 */, int bits /* = 4 */,
const std::string& mode /* = "affine" */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
if (!lhs_indices_ && !rhs_indices_) { if (!lhs_indices_ && !rhs_indices_) {
return quantized_matmul( return quantized_matmul(
x, w, scales, biases, transpose, group_size, bits, s); x, w, scales, biases, transpose, group_size, bits, mode, s);
} }
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
"gather_qmm", x, w, scales, biases, transpose, group_size, bits); "gather_qmm", x, w, scales, biases, transpose, group_size, bits, mode);
// Extract indices and broadcast them // Extract indices and broadcast them
array lhs_indices = indices_or_default(lhs_indices_, x, s); array lhs_indices = indices_or_default(lhs_indices_, x, s);
@ -4067,7 +4095,8 @@ array gather_qmm(
return array( return array(
std::move(out_shape), std::move(out_shape),
out_type, out_type,
std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose), std::make_shared<GatherQMM>(
to_stream(s), group_size, bits, transpose, mode),
{astype(x, out_type, s), {astype(x, out_type, s),
w, w,
astype(scales, out_type, s), astype(scales, out_type, s),

View File

@ -1323,13 +1323,15 @@ array quantized_matmul(
bool transpose = true, bool transpose = true,
int group_size = 64, int group_size = 64,
int bits = 4, int bits = 4,
const std::string& mode = "affine",
StreamOrDevice s = {}); StreamOrDevice s = {});
/** Quantize a matrix along its last axis */ /** Quantize a matrix along its last axis */
std::tuple<array, array, array> quantize( std::vector<array> quantize(
const array& w, const array& w,
int group_size = 64, int group_size = 64,
int bits = 4, int bits = 4,
const std::string& mode = "affine",
StreamOrDevice s = {}); StreamOrDevice s = {});
/** Dequantize a matrix produced by quantize() */ /** Dequantize a matrix produced by quantize() */
@ -1352,6 +1354,7 @@ array gather_qmm(
bool transpose = true, bool transpose = true,
int group_size = 64, int group_size = 64,
int bits = 4, int bits = 4,
const std::string& mode = "affine",
StreamOrDevice s = {}); StreamOrDevice s = {});
/** Returns a contraction of a and b over multiple dimensions. */ /** Returns a contraction of a and b over multiple dimensions. */

View File

@ -3012,6 +3012,7 @@ std::vector<array> QuantizedMatmul::vjp(
!transpose_, !transpose_,
group_size_, group_size_,
bits_, bits_,
mode_,
stream())); stream()));
} }
@ -3040,6 +3041,7 @@ std::vector<array> QuantizedMatmul::jvp(
transpose_, transpose_,
group_size_, group_size_,
bits_, bits_,
mode_,
stream())}; stream())};
} }
@ -3098,6 +3100,7 @@ std::vector<array> GatherQMM::vjp(
!transpose_, !transpose_,
group_size_, group_size_,
bits_, bits_,
mode_,
stream()), stream()),
-3, -3,
stream()), stream()),

View File

@ -1552,11 +1552,13 @@ class QuantizedMatmul : public UnaryPrimitive {
Stream stream, Stream stream,
int group_size, int group_size,
int bits, int bits,
bool transpose) bool transpose,
const std::string mode)
: UnaryPrimitive(stream), : UnaryPrimitive(stream),
group_size_(group_size), group_size_(group_size),
bits_(bits), bits_(bits),
transpose_(transpose) {} transpose_(transpose),
mode_(mode) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1567,22 +1569,29 @@ class QuantizedMatmul : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override; std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
auto state() const { auto state() const {
return std::make_tuple(group_size_, bits_, transpose_); return std::make_tuple(group_size_, bits_, transpose_, mode_);
} }
private: private:
int group_size_; int group_size_;
int bits_; int bits_;
bool transpose_; bool transpose_;
const std::string mode_;
}; };
class GatherQMM : public UnaryPrimitive { class GatherQMM : public UnaryPrimitive {
public: public:
explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose) explicit GatherQMM(
Stream stream,
int group_size,
int bits,
bool transpose,
const std::string& mode)
: UnaryPrimitive(stream), : UnaryPrimitive(stream),
group_size_(group_size), group_size_(group_size),
bits_(bits), bits_(bits),
transpose_(transpose) {} transpose_(transpose),
mode_(mode) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1592,13 +1601,14 @@ class GatherQMM : public UnaryPrimitive {
DEFINE_PRINT(GatherQMM) DEFINE_PRINT(GatherQMM)
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
auto state() const { auto state() const {
return std::make_tuple(group_size_, bits_, transpose_); return std::make_tuple(group_size_, bits_, transpose_, mode_);
} }
private: private:
int group_size_; int group_size_;
int bits_; int bits_;
bool transpose_; bool transpose_;
const std::string mode_;
}; };
class RandomBits : public UnaryPrimitive { class RandomBits : public UnaryPrimitive {

View File

@ -1,6 +1,7 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
import math import math
from typing import Literal
import mlx.core as mx import mlx.core as mx
from mlx.nn.layers.base import Module from mlx.nn.layers.base import Module
@ -39,6 +40,12 @@ class Embedding(Module):
""" """
return x @ self.weight.T return x @ self.weight.T
def to_quantized(self, group_size: int = 64, bits: int = 4): def to_quantized(
self,
group_size: int = 64,
bits: int = 4,
mode: Literal["affine", "trellis"] = "affine",
fake: bool = False,
):
"""Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer.""" """Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
return QuantizedEmbedding.from_embedding(self, group_size, bits) return QuantizedEmbedding.from_embedding(self, group_size, bits)

View File

@ -1,11 +1,12 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import math import math
from typing import Any from typing import Any, Literal
import mlx.core as mx import mlx.core as mx
from mlx.nn.layers.base import Module from mlx.nn.layers.base import Module
from mlx.nn.layers.quantized import QuantizedLinear from mlx.nn.layers.quantized import QuantizedLinear
from mlx.nn.layers.viterbi import quantize as trellis_quantize
class Identity(Module): class Identity(Module):
@ -70,9 +71,15 @@ class Linear(Module):
x = x @ self["weight"].T x = x @ self["weight"].T
return x return x
def to_quantized(self, group_size: int = 64, bits: int = 4): def to_quantized(
self,
group_size: int = 64,
bits: int = 4,
mode: Literal["affine", "trellis"] = "affine",
fake: bool = False,
):
"""Return a :obj:`QuantizedLinear` layer that approximates this layer.""" """Return a :obj:`QuantizedLinear` layer that approximates this layer."""
return QuantizedLinear.from_linear(self, group_size, bits) return QuantizedLinear.from_linear(self, group_size, bits, mode=mode, fake=fake)
class Bilinear(Module): class Bilinear(Module):

View File

@ -1,10 +1,11 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
import math import math
from typing import Callable, Optional, Union from typing import Callable, Literal, Optional, Union
import mlx.core as mx import mlx.core as mx
from mlx.nn.layers.base import Module from mlx.nn.layers.base import Module
from mlx.nn.layers.viterbi import quantize as trellis_quantize
from mlx.utils import tree_map_with_path from mlx.utils import tree_map_with_path
@ -12,7 +13,9 @@ def quantize(
model: Module, model: Module,
group_size: int = 64, group_size: int = 64,
bits: int = 4, bits: int = 4,
mode: Literal["affine", "trellis"] = "affine",
class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None, class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
fake: bool = False,
): ):
"""Quantize the sub-modules of a module according to a predicate. """Quantize the sub-modules of a module according to a predicate.
@ -21,7 +24,7 @@ def quantize(
will be quantized. Note also, the module is updated in-place. will be quantized. Note also, the module is updated in-place.
Args: Args:
model (mlx.nn.Module): The model whose leaf modules may be quantized. model (mlx.nn.Module): The model whose leaf modules may be quantized.
mode (str): The quantization mode, either ``affine`` or ``trellis``. Default: ``affine``.
group_size (int): The quantization group size (see group_size (int): The quantization group size (see
:func:`mlx.core.quantize`). Default: ``64``. :func:`mlx.core.quantize`). Default: ``64``.
bits (int): The number of bits per parameter (see bits (int): The number of bits per parameter (see
@ -36,12 +39,15 @@ def quantize(
class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized")) class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized"))
def _maybe_quantize(path, m): def _maybe_quantize(path, m):
if bool_or_params := class_predicate(path, m): if bool_or_params := class_predicate(path, m):
if hasattr(m, "to_quantized"): if hasattr(m, "to_quantized"):
if isinstance(bool_or_params, bool): if isinstance(bool_or_params, bool):
return m.to_quantized(group_size=group_size, bits=bits) return m.to_quantized(
group_size=group_size, bits=bits, mode=mode, fake=fake
)
elif isinstance(bool_or_params, dict): elif isinstance(bool_or_params, dict):
return m.to_quantized(**bool_or_params) return m.to_quantized(**bool_or_params, fake=fake)
else: else:
raise ValueError( raise ValueError(
"``class_predicate`` must return a bool" "``class_predicate`` must return a bool"
@ -131,7 +137,11 @@ class QuantizedEmbedding(Module):
@classmethod @classmethod
def from_embedding( def from_embedding(
cls, embedding_layer: Module, group_size: int = 64, bits: int = 4 cls,
embedding_layer: Module,
group_size: int = 64,
bits: int = 4,
mode: Literal["affine", "trellis"] = "affine",
): ):
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer.""" """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
embedding_dims, dims = embedding_layer.weight.shape embedding_dims, dims = embedding_layer.weight.shape
@ -170,12 +180,14 @@ class QuantizedLinear(Module):
bias: bool = True, bias: bool = True,
group_size: int = 64, group_size: int = 64,
bits: int = 4, bits: int = 4,
mode: Literal["affine", "trellis"] = "affine",
): ):
super().__init__() super().__init__()
# Quantization config # Quantization config
self.group_size = group_size self.group_size = group_size
self.bits = bits self.bits = bits
self.mode = mode
# Initialize the quantized weight # Initialize the quantized weight
scale = math.sqrt(1 / input_dims) scale = math.sqrt(1 / input_dims)
@ -216,19 +228,40 @@ class QuantizedLinear(Module):
transpose=True, transpose=True,
group_size=self.group_size, group_size=self.group_size,
bits=self.bits, bits=self.bits,
mode=self.mode,
) )
if "bias" in self: if "bias" in self:
x = x + self["bias"] x = x + self["bias"]
return x return x
@classmethod @classmethod
def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4): def from_linear(
cls,
linear_layer: Module,
group_size: int = 64,
bits: int = 4,
mode: Literal["affine", "trellis"] = "affine",
fake: bool = False,
):
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer.""" """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
output_dims, input_dims = linear_layer.weight.shape output_dims, input_dims = linear_layer.weight.shape
ql = cls(input_dims, output_dims, False, group_size, bits) ql = cls(input_dims, output_dims, False, group_size, bits, mode)
ql.weight, ql.scales, ql.biases = mx.quantize( if mode == "trellis":
linear_layer.weight, group_size, bits if fake:
) ql.weight = mx.zeros(
(output_dims, input_dims // 32 * bits), dtype=mx.uint32
)
ql.scales = mx.array(0.0)
ql.biases = mx.array(0.0)
else:
ql.weight, ql.scales, ql.biases = mx.quantize(
linear_layer.weight, bits=bits, mode="trellis"
)
else:
ql.weight, ql.scales, ql.biases = mx.quantize(
linear_layer.weight, group_size, bits, mode="affine"
)
if "bias" in linear_layer: if "bias" in linear_layer:
ql.bias = linear_layer.bias ql.bias = linear_layer.bias
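With these changes a model can be switched to the trellis scheme through the usual nn.quantize entry point. A minimal usage sketch (assuming the mode and fake arguments behave as wired up above; TinyMLP is an illustrative module name):
import mlx.core as mx
import mlx.nn as nn
class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(512, 512)
    def __call__(self, x):
        return self.proj(x)
model = TinyMLP()
# Replaces each Linear with a QuantizedLinear(mode="trellis"); requires the Metal
# backend, since TrellisQuantize only has a GPU implementation in this commit.
nn.quantize(model, bits=2, mode="trellis")
y = model(mx.random.normal((1, 512)))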

View File

@ -4116,10 +4116,11 @@ void init_ops(nb::module_& m) {
"transpose"_a = true, "transpose"_a = true,
"group_size"_a = 64, "group_size"_a = 64,
"bits"_a = 4, "bits"_a = 4,
"mode"_a = "affine",
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: Literal['affine', 'trellis'], *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Perform the matrix multiplication with the quantized matrix ``w``. The Perform the matrix multiplication with the quantized matrix ``w``. The
quantization uses one floating point scale and bias per ``group_size`` of quantization uses one floating point scale and bias per ``group_size`` of
@ -4138,6 +4139,8 @@ void init_ops(nb::module_& m) {
shares a scale and bias. Default: ``64``. shares a scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in bits (int, optional): The number of bits occupied by each element in
``w``. Default: ``4``. ``w``. Default: ``4``.
mode (str, optional): The mode to use for quantization.
Default: ``affine``.
Returns: Returns:
array: The result of the multiplication of ``x`` with ``w``. array: The result of the multiplication of ``x`` with ``w``.
@ -4149,9 +4152,10 @@ void init_ops(nb::module_& m) {
"group_size"_a = 64, "group_size"_a = 64,
"bits"_a = 4, "bits"_a = 4,
nb::kw_only(), nb::kw_only(),
"mode"_a = "affine",
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"), "def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, mode: Literal['affine', 'trellis'], stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
R"pbdoc( R"pbdoc(
Quantize the matrix ``w`` using ``bits`` bits per element. Quantize the matrix ``w`` using ``bits`` bits per element.
@ -4193,6 +4197,7 @@ void init_ops(nb::module_& m) {
scale and bias. Default: ``64``. scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element of bits (int, optional): The number of bits occupied by each element of
``w`` in the returned quantized matrix. Default: ``4``. ``w`` in the returned quantized matrix. Default: ``4``.
mode (str): The quantization mode to use. Default: ``affine``.
Returns: Returns:
tuple: A tuple containing tuple: A tuple containing
@ -4249,10 +4254,11 @@ void init_ops(nb::module_& m) {
"transpose"_a = true, "transpose"_a = true,
"group_size"_a = 64, "group_size"_a = 64,
"bits"_a = 4, "bits"_a = 4,
"mode"_a = "affine",
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: Literal['affine', 'trellis'], *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Perform quantized matrix multiplication with matrix-level gather. Perform quantized matrix multiplication with matrix-level gather.
@ -4278,6 +4284,8 @@ void init_ops(nb::module_& m) {
shares a scale and bias. Default: ``64``. shares a scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in bits (int, optional): The number of bits occupied by each element in
``w``. Default: ``4``. ``w``. Default: ``4``.
mode (str, optional): The mode to use for quantization.
Default: ``affine``.
Returns: Returns:
array: The result of the multiplication of ``x`` with ``w`` array: The result of the multiplication of ``x`` with ``w``

View File

@ -10,6 +10,9 @@ import mlx_tests
class TestQuantized(mlx_tests.MLXTestCase): class TestQuantized(mlx_tests.MLXTestCase):
def test_quantize_dequantize(self): def test_quantize_dequantize(self):
w = mx.random.normal(shape=(128, 512)) w = mx.random.normal(shape=(128, 512))
w_q, scales, biases = mx.quantize(w, bits=2, mode="trellis")
for gs in [32, 64, 128]: for gs in [32, 64, 128]:
for b in [2, 3, 6, 4, 8]: for b in [2, 3, 6, 4, 8]:
with self.subTest(gs=gs, b=b): with self.subTest(gs=gs, b=b):