mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
add trellis quant mode
This commit is contained in:
parent
e9e268336b
commit
d7acf59fd0
@ -684,6 +684,115 @@ METAL_FUNC void qmv_fast_impl(
|
||||
}
|
||||
}
|
||||
|
||||
template <uint32_t a = 89226354, uint32_t b = 64248484, uint32_t m = 996162400>
|
||||
float inst3(uint16_t xi) {
|
||||
uint32_t x = xi;
|
||||
x = a * x + b;
|
||||
x = (x & 0b10001111111111111000111111111111) ^ m;
|
||||
auto xf = reinterpret_cast<thread float16_t*>(&x);
|
||||
return xf[0] + xf[1];
|
||||
}
|
||||
|
||||
template <typename T, int bits>
|
||||
METAL_FUNC void qmv_trellis_impl(
|
||||
const device uint32_t* w,
|
||||
const device T* scales,
|
||||
const device T* biases,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& in_vec_size,
|
||||
const constant int& out_vec_size,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||
constexpr int packs_per_thread = 2;
|
||||
constexpr int num_simdgroups = 2;
|
||||
constexpr int results_per_simdgroup = 4;
|
||||
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
||||
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
|
||||
constexpr int values_per_thread = pack_factor * packs_per_thread;
|
||||
constexpr int block_size = values_per_thread * SIMD_SIZE;
|
||||
constexpr int reads_per = 16 / bits;
|
||||
constexpr int local_w_size =
|
||||
results_per_simdgroup * values_per_thread / reads_per;
|
||||
|
||||
const device uint8_t* ws = (const device uint8_t*)w;
|
||||
|
||||
typedef float U;
|
||||
|
||||
thread U x_thread[values_per_thread];
|
||||
thread uint16_t w_thread[local_w_size];
|
||||
thread U result[results_per_simdgroup] = {0};
|
||||
|
||||
// Adjust positions
|
||||
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
|
||||
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
|
||||
simd_gid * results_per_simdgroup;
|
||||
|
||||
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
|
||||
x += tid.x * in_vec_size + simd_lid * values_per_thread;
|
||||
y += tid.x * out_vec_size + out_row;
|
||||
|
||||
T scale = scales[0];
|
||||
|
||||
for (int k = 0; k < in_vec_size; k += block_size) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 0; i < values_per_thread; i++) {
|
||||
x_thread[i] = x[i];
|
||||
}
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 0; i < values_per_thread / reads_per; i++) {
|
||||
auto wl = (const device uint16_t*)(ws + row * in_vec_size_w);
|
||||
w_thread[row * values_per_thread / reads_per + i] = wl[i];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 0; i < values_per_thread / reads_per; i++) {
|
||||
int index = row * values_per_thread / reads_per + i;
|
||||
uint16_t w0 = w_thread[index];
|
||||
uint16_t w1 = w_thread[(index + 1) % local_w_size];
|
||||
|
||||
uint16_t wx = w0 ^ w1;
|
||||
uint16_t wx1 = wx ^ 1;
|
||||
uint16_t wf = w0 ^ (1 << bits);
|
||||
|
||||
if (bits == 2) {
|
||||
result[row] += x_thread[8 * i] * inst3(w0);
|
||||
result[row] += x_thread[8 * i + 1] * inst3(wf ^ (wx1 & 0x3));
|
||||
result[row] += x_thread[8 * i + 2] * inst3(w0 ^ (wx & 0xf));
|
||||
result[row] += x_thread[8 * i + 3] * inst3(w0 ^ (wx1 & 0x3f));
|
||||
result[row] += x_thread[8 * i + 4] * inst3(w0 ^ (wx & 0xff));
|
||||
result[row] += x_thread[8 * i + 5] * inst3(w0 ^ (wx1 & 0x3ff));
|
||||
result[row] += x_thread[8 * i + 6] * inst3(w0 ^ (wx & 0xfff));
|
||||
result[row] += x_thread[8 * i + 7] * inst3(w0 ^ (wx1 & 0x3fff));
|
||||
} else if (bits == 4) {
|
||||
result[row] += x_thread[4 * i] * inst3(w0);
|
||||
result[row] += x_thread[4 * i + 1] * inst3(wf ^ (wx1 & 0xf));
|
||||
result[row] += x_thread[4 * i + 2] * inst3(w0 ^ (wx & 0xff));
|
||||
result[row] += x_thread[4 * i + 3] * inst3(w0 ^ (wx1 & 0xfff));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ws += block_size * bytes_per_pack / pack_factor;
|
||||
x += block_size;
|
||||
}
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
result[row] = simd_sum(result[row]);
|
||||
if (simd_lid == 0) {
|
||||
y[row] = static_cast<T>(scale * result[row]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
METAL_FUNC void qmv_impl(
|
||||
const device uint32_t* w,
|
||||
@ -1302,7 +1411,13 @@ METAL_FUNC void adjust_matrix_offsets(
|
||||
y += tid.z * output_stride;
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits, int D, bool batched>
|
||||
template <
|
||||
typename T,
|
||||
int group_size,
|
||||
int bits,
|
||||
int D,
|
||||
bool batched,
|
||||
bool trellis = false>
|
||||
[[kernel]] void qmv_quad(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
@ -1354,7 +1469,12 @@ template <typename T, int group_size, int bits, int D, bool batched>
|
||||
quad_lid);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits, bool batched>
|
||||
template <
|
||||
typename T,
|
||||
int group_size,
|
||||
int bits,
|
||||
bool batched,
|
||||
bool trellis = false>
|
||||
[[kernel]] void qmv_fast(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
@ -1393,6 +1513,19 @@ template <typename T, int group_size, int bits, bool batched>
|
||||
b_strides,
|
||||
tid);
|
||||
}
|
||||
if (trellis) {
|
||||
qmv_trellis_impl<T, bits>(
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
x,
|
||||
y,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
tid,
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
} else {
|
||||
qmv_fast_impl<T, group_size, bits>(
|
||||
w,
|
||||
scales,
|
||||
@ -1405,8 +1538,14 @@ template <typename T, int group_size, int bits, bool batched>
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, const int bits, bool batched>
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
const int bits,
|
||||
bool batched,
|
||||
bool trellis = false>
|
||||
[[kernel]] void qmv(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
@ -1458,7 +1597,12 @@ template <typename T, const int group_size, const int bits, bool batched>
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, const int bits, bool batched>
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
const int bits,
|
||||
bool batched,
|
||||
bool trellis = false>
|
||||
[[kernel]] void qvm(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
@ -1572,6 +1716,7 @@ template <
|
||||
const int bits,
|
||||
const bool aligned_N,
|
||||
const bool batched,
|
||||
bool trellis = false,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
@ -1630,6 +1775,7 @@ template <
|
||||
const int group_size,
|
||||
const int bits,
|
||||
const bool batched,
|
||||
bool trellis = false,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
@ -1685,7 +1831,7 @@ template <
|
||||
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
template <typename T, int group_size, int bits, bool trellis = false>
|
||||
[[kernel]] void bs_qmv_fast(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
@ -1734,6 +1880,19 @@ template <typename T, int group_size, int bits>
|
||||
s_strides,
|
||||
b_strides,
|
||||
tid);
|
||||
if (trellis) {
|
||||
qmv_trellis_impl<T, bits>(
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
x,
|
||||
y,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
tid,
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
} else {
|
||||
qmv_fast_impl<T, group_size, bits>(
|
||||
w,
|
||||
scales,
|
||||
@ -1746,8 +1905,9 @@ template <typename T, int group_size, int bits>
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
template <typename T, int group_size, int bits, bool trellis = false>
|
||||
[[kernel]] void bs_qmv(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
@ -1809,7 +1969,7 @@ template <typename T, int group_size, int bits>
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
template <typename T, int group_size, int bits, bool trellis = false>
|
||||
[[kernel]] void bs_qvm(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
@ -1876,6 +2036,7 @@ template <
|
||||
const int group_size,
|
||||
const int bits,
|
||||
const bool aligned_N,
|
||||
bool trellis = false,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
@ -1943,6 +2104,7 @@ template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
const int bits,
|
||||
bool trellis = false,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
@ -2157,3 +2319,211 @@ template <typename T, const int group_size, const int bits>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const bool use_overlap,
|
||||
const int bits = 2,
|
||||
const int timesteps = 128>
|
||||
[[kernel]] void trellis_viterbi(
|
||||
const device T* w [[buffer(0)]],
|
||||
device float16_t* score [[buffer(1)]],
|
||||
device uint8_t* pointers [[buffer(2)]],
|
||||
const device uint16_t* overlap [[buffer(3)]],
|
||||
uint3 tid [[thread_position_in_grid]]) {
|
||||
constexpr uint16_t L = 16;
|
||||
constexpr uint L2 = 1 << L;
|
||||
|
||||
uint16_t idx = tid.y * 16;
|
||||
|
||||
threadgroup float16_t swap_V[16384];
|
||||
|
||||
thread float16_t min_V[16] = {0};
|
||||
|
||||
for (uint16_t t = 0; t < timesteps; t++) {
|
||||
uint16_t tt = t % 8 == 0 ? L / bits : t % 8;
|
||||
uint16_t shift = ((tt - 1) % (L / bits)) * bits;
|
||||
uint16_t flip = (t == 0 || (t > 1 && t % 8 == 1)) ? (1 << bits) + 1 : t % 2;
|
||||
|
||||
uint16_t s000 = 1 << (shift - 6);
|
||||
uint16_t s0 = 1 << (shift - 2);
|
||||
uint16_t s1 = 1 << (shift);
|
||||
uint16_t s2 = 1 << (shift + 2);
|
||||
uint16_t s4 = 1 << (shift + 4);
|
||||
|
||||
if (t > 1) {
|
||||
uint16_t i = 0;
|
||||
uint16_t loff = 1 << (metal::clamp((shift + 14) % 16, 2, 12));
|
||||
uint16_t hoff = shift > 4 ? 4 : shift == 4 ? 16 : 1;
|
||||
uint16_t ind = idx;
|
||||
|
||||
if (shift == 0) {
|
||||
ind >>= 2;
|
||||
} else if (shift == 14) {
|
||||
ind = (ind & 0xfff) + (ind >> 12);
|
||||
} else if (shift == 2) {
|
||||
} else if (shift == 4) {
|
||||
ind = ((ind >> 4) & 0x3) + (ind & ~0x3f);
|
||||
} else if (shift == 6) {
|
||||
ind = ((ind / s0) % 4) * s1 + ((ind / s1) % 4) + (ind / s2) * s2;
|
||||
} else {
|
||||
ind = ((ind / 16) % s000) * 16 + ((ind / s0) % 4) * s1 +
|
||||
((ind / s1) % 4) + (ind / s2) * s2;
|
||||
}
|
||||
|
||||
for (uint16_t high = 0; high < 4; high++) {
|
||||
uint16_t sub_ind = ind;
|
||||
for (uint16_t low = 0; low < 4; low++) {
|
||||
swap_V[sub_ind] = min_V[i];
|
||||
i++;
|
||||
sub_ind += loff;
|
||||
}
|
||||
ind += hoff;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for (uint16_t i = 0; i < 16; i++) {
|
||||
min_V[i] = swap_V[idx + i];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
uint16_t rolled_t = use_overlap ? t : (t + 64) % 128;
|
||||
T w_t = w[tid.x * timesteps + rolled_t];
|
||||
|
||||
for (uint16_t i = 0; i < 4; i++) {
|
||||
thread float16_t min_val[4] = {INFINITY, INFINITY, INFINITY, INFINITY};
|
||||
thread uint16_t min_idx[4] = {0};
|
||||
|
||||
uint16_t ii = idx * 4 + i * 16;
|
||||
uint16_t big_idx = ii;
|
||||
if (shift > 0 && shift < 14) {
|
||||
big_idx = ((ii / s2) % 4) + (ii / s4 * s4);
|
||||
if (shift > 2) {
|
||||
big_idx += ((ii / 16) % s0) * 4;
|
||||
}
|
||||
} else if (shift == 14 && t > 0) {
|
||||
big_idx >>= 2;
|
||||
}
|
||||
|
||||
uint16_t loff = t == 0 ? 4 : s1;
|
||||
uint16_t hoff = (t == 0 || shift == 14) ? 1 : s2;
|
||||
|
||||
for (uint16_t high = 0; high < 4; high++) {
|
||||
uint16_t sub_ind = big_idx;
|
||||
for (uint16_t low = 0; low < 4; low++) {
|
||||
float mse = inst3(sub_ind ^ flip) - w_t;
|
||||
mse *= mse;
|
||||
|
||||
float16_t new_val = min_V[i * 4 + high] + mse;
|
||||
if (new_val < min_val[low]) {
|
||||
min_val[low] = new_val;
|
||||
min_idx[low] = high;
|
||||
}
|
||||
sub_ind += loff;
|
||||
}
|
||||
big_idx += hoff;
|
||||
}
|
||||
|
||||
for (uint16_t j = 0; j < 4; j++) {
|
||||
min_V[i * 4 + j] = min_val[j];
|
||||
pointers[tid.x * L2 / 4 * timesteps + t * L2 / 4 + idx + i * 4 + j] =
|
||||
min_idx[j];
|
||||
}
|
||||
}
|
||||
if (t == 0 && use_overlap) {
|
||||
uint16_t over = overlap[tid.x * 128 + 64];
|
||||
over = over & ((1 << 14) - 1);
|
||||
for (uint16_t i = 0; i < 16; i++) {
|
||||
uint16_t rs = (over >> 2) ^ 1;
|
||||
uint16_t ls = (idx + i) & ((1 << 12) - 1);
|
||||
min_V[i] = rs == ls ? min_V[i] : INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (use_overlap) {
|
||||
uint16_t over = overlap[tid.x * 128 + 64];
|
||||
over = over & ((1 << 14) - 1);
|
||||
uint16_t node =
|
||||
(over % 4) * 4096 + ((over / 4) % 1024) * 4 + (over / 4096) % 4;
|
||||
for (uint16_t i = 0; i < 16; i++) {
|
||||
min_V[i] = (idx + i) == node ? min_V[i] : INFINITY;
|
||||
}
|
||||
}
|
||||
for (uint16_t i = 0; i < 16; i++) {
|
||||
score[tid.x * L2 / 4 + idx + i] = min_V[i];
|
||||
}
|
||||
}
|
||||
|
||||
uint16_t remove_bits(uint16_t i, uint16_t shift) {
|
||||
uint16_t lower = i & ((1 << shift) - 1);
|
||||
uint16_t upper = i & ~((1 << (shift + 2)) - 1);
|
||||
return lower + (upper >> 2);
|
||||
}
|
||||
|
||||
uint16_t swap_bits(uint16_t i, uint16_t shift) {
|
||||
uint16_t diff = ((i >> shift) ^ i) & 0x3;
|
||||
i = i ^ diff;
|
||||
i ^= diff << shift;
|
||||
return i;
|
||||
}
|
||||
|
||||
template <const bool use_overlap, const int bits = 2, const int timesteps = 128>
|
||||
[[kernel]] void trellis_backtrack(
|
||||
const device uint32_t* start [[buffer(0)]],
|
||||
const device uint8_t* pointers [[buffer(1)]],
|
||||
device uint16_t* out [[buffer(2)]],
|
||||
const device uint16_t* overlap [[buffer(3)]],
|
||||
uint3 tid [[thread_position_in_grid]]) {
|
||||
constexpr uint16_t L = 16;
|
||||
|
||||
uint16_t node = start[tid.x];
|
||||
|
||||
uint16_t dir =
|
||||
pointers[tid.x * timesteps * 16384 + (timesteps - 1) * 16384 + node];
|
||||
|
||||
node = (node % 4) * 4096 + ((node / 4) % 1024) * 4 + (node / 4096) % 4;
|
||||
node ^= 1;
|
||||
node += dir * 16384;
|
||||
|
||||
out[tid.x * timesteps + timesteps - 1] = node;
|
||||
|
||||
for (int t = timesteps - 2; t >= 0; t--) {
|
||||
uint16_t shift = (t % (L / bits)) * bits;
|
||||
uint16_t mask = ((1 << L) - 1) ^ (((1 << bits) - 1) << shift);
|
||||
uint16_t flip = t % (L / bits) == 0 ? 1 << bits : 1;
|
||||
uint16_t i = (node & mask) ^ flip;
|
||||
|
||||
if (shift > 0) {
|
||||
i = remove_bits(i, shift);
|
||||
}
|
||||
|
||||
if (t == 0) {
|
||||
i >>= 2;
|
||||
}
|
||||
|
||||
if (t % 2 == 1 || t == 0) {
|
||||
i ^= 1;
|
||||
}
|
||||
|
||||
shift = shift == 0 ? L : shift;
|
||||
|
||||
if (t > 0) {
|
||||
i = swap_bits(i, shift - 2);
|
||||
}
|
||||
|
||||
shift = shift == L ? 0 : shift;
|
||||
|
||||
uint16_t last_p = pointers[tid.x * timesteps * 16384 + t * 16384 + i];
|
||||
if ((t % 8 == 1 && t > 1) || t == 0) {
|
||||
last_p ^= 1;
|
||||
}
|
||||
|
||||
node = ((node & mask) ^ flip) | (last_p << shift);
|
||||
if (t == 0 && use_overlap) {
|
||||
uint16_t over = overlap[tid.x * 128 + 64];
|
||||
over = over & ((1 << 14) - 1);
|
||||
node = (node & 0xfffc) + (over & 0x3);
|
||||
}
|
||||
out[tid.x * timesteps + t] = node;
|
||||
}
|
||||
}
|
@ -16,6 +16,8 @@
|
||||
#include "mlx/scheduler.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T>
|
||||
@ -158,33 +160,25 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& s = stream();
|
||||
void arg_reduce_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
int axis,
|
||||
std::string op_name,
|
||||
const Stream& s) {
|
||||
auto& d = metal::device(s.device);
|
||||
std::string op_name;
|
||||
switch (reduce_type_) {
|
||||
case ArgReduce::ArgMin:
|
||||
op_name = "argmin_";
|
||||
break;
|
||||
case ArgReduce::ArgMax:
|
||||
op_name = "argmax_";
|
||||
break;
|
||||
}
|
||||
|
||||
// Prepare the shapes, strides and axis arguments.
|
||||
auto in_strides = in.strides();
|
||||
auto shape = in.shape();
|
||||
auto out_strides = out.strides();
|
||||
auto axis_stride = in_strides[axis_];
|
||||
size_t axis_size = shape[axis_];
|
||||
auto axis_stride = in_strides[axis];
|
||||
size_t axis_size = shape[axis];
|
||||
if (out_strides.size() == in_strides.size()) {
|
||||
out_strides.erase(out_strides.begin() + axis_);
|
||||
out_strides.erase(out_strides.begin() + axis);
|
||||
}
|
||||
in_strides.erase(in_strides.begin() + axis_);
|
||||
shape.erase(shape.begin() + axis_);
|
||||
in_strides.erase(in_strides.begin() + axis);
|
||||
shape.erase(shape.begin() + axis);
|
||||
size_t ndim = shape.size();
|
||||
|
||||
// ArgReduce
|
||||
@ -192,7 +186,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int n_reads = 4;
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
{
|
||||
auto kernel = d.get_kernel(op_name + type_to_name(in));
|
||||
auto kernel = d.get_kernel(op_name + "_" + type_to_name(in));
|
||||
NS::UInteger thread_group_size = std::min(
|
||||
(axis_size + n_reads - 1) / n_reads,
|
||||
kernel->maxTotalThreadsPerThreadgroup());
|
||||
@ -226,6 +220,23 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
std::string op_name;
|
||||
switch (reduce_type_) {
|
||||
case ArgReduce::ArgMin:
|
||||
op_name = "argmin";
|
||||
break;
|
||||
case ArgReduce::ArgMax:
|
||||
op_name = "argmax";
|
||||
break;
|
||||
}
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto& s = stream();
|
||||
arg_reduce_dispatch(in, out, axis_, op_name, s);
|
||||
}
|
||||
|
||||
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
CopyType ctype =
|
||||
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
|
@ -7,11 +7,14 @@
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/reduce.h"
|
||||
#include "mlx/backend/metal/slicing.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void launch_qmm(
|
||||
@ -31,6 +34,7 @@ void launch_qmm(
|
||||
bool gather,
|
||||
bool aligned,
|
||||
bool quad,
|
||||
const std::string& mode,
|
||||
const Stream& s) {
|
||||
auto& x_pre = inputs[0];
|
||||
auto& w_pre = inputs[1];
|
||||
@ -54,8 +58,12 @@ void launch_qmm(
|
||||
};
|
||||
auto x = ensure_row_contiguous_last_dims(x_pre);
|
||||
auto w = ensure_row_contiguous_last_dims(w_pre);
|
||||
auto scales = ensure_row_contiguous_last_dims(scales_pre);
|
||||
auto biases = ensure_row_contiguous_last_dims(biases_pre);
|
||||
auto scales = scales_pre;
|
||||
auto biases = biases_pre;
|
||||
if (mode == "affine") {
|
||||
scales = ensure_row_contiguous_last_dims(scales_pre);
|
||||
biases = ensure_row_contiguous_last_dims(biases_pre);
|
||||
}
|
||||
|
||||
int x_batch_ndims = x.ndim() - 2;
|
||||
auto& x_shape = x.shape();
|
||||
@ -68,6 +76,8 @@ void launch_qmm(
|
||||
|
||||
std::string aligned_n = (O % 32) == 0 ? "true" : "false";
|
||||
|
||||
bool is_trellis = (mode == "trellis");
|
||||
|
||||
std::ostringstream kname;
|
||||
auto type_string = get_type_string(x.dtype());
|
||||
kname << name << "_" << type_string << "_gs_" << group_size << "_b_" << bits;
|
||||
@ -80,24 +90,47 @@ void launch_qmm(
|
||||
if (!gather) {
|
||||
kname << "_batch_" << batched;
|
||||
}
|
||||
if (mode == "trellis") {
|
||||
kname << "_mode_" << is_trellis;
|
||||
}
|
||||
|
||||
// Encode and dispatch kernel
|
||||
std::string template_def;
|
||||
if (quad) {
|
||||
template_def = get_template_definition(
|
||||
kname.str(), name, type_string, group_size, bits, D, batched);
|
||||
kname.str(),
|
||||
name,
|
||||
type_string,
|
||||
group_size,
|
||||
bits,
|
||||
D,
|
||||
batched,
|
||||
is_trellis);
|
||||
} else if (aligned && !gather) {
|
||||
template_def = get_template_definition(
|
||||
kname.str(), name, type_string, group_size, bits, aligned_n, batched);
|
||||
kname.str(),
|
||||
name,
|
||||
type_string,
|
||||
group_size,
|
||||
bits,
|
||||
aligned_n,
|
||||
batched,
|
||||
is_trellis);
|
||||
} else if (!gather && !aligned) {
|
||||
template_def = get_template_definition(
|
||||
kname.str(), name, type_string, group_size, bits, batched);
|
||||
kname.str(), name, type_string, group_size, bits, batched, is_trellis);
|
||||
} else if (aligned && gather) {
|
||||
template_def = get_template_definition(
|
||||
kname.str(), name, type_string, group_size, bits, aligned_n);
|
||||
kname.str(),
|
||||
name,
|
||||
type_string,
|
||||
group_size,
|
||||
bits,
|
||||
aligned_n,
|
||||
is_trellis);
|
||||
} else {
|
||||
template_def = get_template_definition(
|
||||
kname.str(), name, type_string, group_size, bits);
|
||||
kname.str(), name, type_string, group_size, bits, is_trellis);
|
||||
}
|
||||
auto& d = metal::device(s.device);
|
||||
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
|
||||
@ -276,6 +309,7 @@ void qmm_op(
|
||||
int group_size,
|
||||
int bits,
|
||||
bool gather,
|
||||
const std::string& mode,
|
||||
const Stream& s) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
@ -354,7 +388,7 @@ void qmm_op(
|
||||
group_dims = MTL::Size(simdgroup_size, 1, 1);
|
||||
grid_dims = MTL::Size(B, (O + bo - 1) / bo, N);
|
||||
quad = true;
|
||||
} else if (B < qmv_batch_limit && O % 8 == 0 && D % 512 == 0 && D >= 512) {
|
||||
} else if (B < 10000 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
|
||||
name += "qmv_fast";
|
||||
int bo = 8;
|
||||
int bd = 32;
|
||||
@ -420,19 +454,34 @@ void qmm_op(
|
||||
gather,
|
||||
aligned,
|
||||
quad,
|
||||
mode,
|
||||
s);
|
||||
}
|
||||
|
||||
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 4);
|
||||
qmm_op(
|
||||
inputs, out, transpose_, group_size_, bits_, /*gather=*/false, stream());
|
||||
inputs,
|
||||
out,
|
||||
transpose_,
|
||||
group_size_,
|
||||
bits_,
|
||||
/*gather=*/false,
|
||||
mode_,
|
||||
stream());
|
||||
}
|
||||
|
||||
void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 6);
|
||||
qmm_op(
|
||||
inputs, out, transpose_, group_size_, bits_, /*gather=*/true, stream());
|
||||
inputs,
|
||||
out,
|
||||
transpose_,
|
||||
group_size_,
|
||||
bits_,
|
||||
/*gather=*/true,
|
||||
mode_,
|
||||
stream());
|
||||
}
|
||||
|
||||
void fast::AffineQuantize::eval_gpu(
|
||||
@ -516,4 +565,123 @@ void fast::AffineQuantize::eval_gpu(
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
void viterbi(
|
||||
array& w,
|
||||
array& scores,
|
||||
array& pointers,
|
||||
array& start,
|
||||
array& overlap,
|
||||
bool use_overlap,
|
||||
const Stream& s) {
|
||||
int B = scores.shape(0);
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_input_array(w, 0);
|
||||
compute_encoder.set_output_array(scores, 1);
|
||||
compute_encoder.set_output_array(pointers, 2);
|
||||
if (use_overlap) {
|
||||
compute_encoder.set_input_array(overlap, 3);
|
||||
}
|
||||
|
||||
std::ostringstream kname;
|
||||
auto type_string = get_type_string(w.dtype());
|
||||
kname << "trellis_viterbi_" << type_string << "_overlap_" << use_overlap;
|
||||
auto template_def = get_template_definition(
|
||||
kname.str(), "trellis_viterbi", type_string, use_overlap);
|
||||
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
auto group_dims = MTL::Size(1, 1024, 1);
|
||||
auto grid_dims = MTL::Size(B, 1024, 1);
|
||||
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
arg_reduce_dispatch(scores, start, 1, "argmin", s);
|
||||
}
|
||||
|
||||
void viterbi_backtrack(
|
||||
array& start,
|
||||
array& pointers,
|
||||
array& out,
|
||||
array& overlap,
|
||||
bool use_overlap,
|
||||
const Stream& s) {
|
||||
int B = start.shape(0);
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_input_array(start, 0);
|
||||
compute_encoder.set_input_array(pointers, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
if (use_overlap) {
|
||||
compute_encoder.set_input_array(overlap, 3);
|
||||
}
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "trellis_backtrack" << "_overlap_" << use_overlap;
|
||||
auto template_def =
|
||||
get_template_definition(kname.str(), "trellis_backtrack", use_overlap);
|
||||
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
auto group_dims = MTL::Size(256, 1, 1);
|
||||
auto grid_dims = MTL::Size(B, 1, 1);
|
||||
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void fast::TrellisQuantize::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& w_pre = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
std::vector<array> copies;
|
||||
auto ensure_row_contiguous = [&copies, &s](const array& arr) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return arr;
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||
copies.push_back(arr_copy);
|
||||
return arr_copy;
|
||||
}
|
||||
};
|
||||
auto w = ensure_row_contiguous(w_pre);
|
||||
|
||||
int B = w.shape(0);
|
||||
int T = w.shape(1);
|
||||
|
||||
constexpr int num_states = 1 << 14;
|
||||
|
||||
array scores({B, num_states}, float16, nullptr, {});
|
||||
scores.set_data(allocator::malloc_or_wait(scores.nbytes()));
|
||||
copies.push_back(scores);
|
||||
|
||||
array pointers({B, T, num_states}, uint8, nullptr, {});
|
||||
pointers.set_data(allocator::malloc_or_wait(pointers.nbytes()));
|
||||
copies.push_back(pointers);
|
||||
|
||||
array start({B}, uint32, nullptr, {});
|
||||
start.set_data(allocator::malloc_or_wait(start.nbytes()));
|
||||
copies.push_back(start);
|
||||
|
||||
array rolled({B, T}, uint16, nullptr, {});
|
||||
rolled.set_data(allocator::malloc_or_wait(rolled.nbytes()));
|
||||
copies.push_back(rolled);
|
||||
|
||||
viterbi(w, scores, pointers, start, out, false, s);
|
||||
viterbi_backtrack(start, pointers, rolled, out, false, s);
|
||||
|
||||
viterbi(w, scores, pointers, start, rolled, true, s);
|
||||
viterbi_backtrack(start, pointers, out, rolled, true, s);
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -38,4 +38,11 @@ void strided_reduce_general_dispatch(
|
||||
metal::Device& d,
|
||||
const Stream& s);
|
||||
|
||||
void arg_reduce_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
int axis,
|
||||
std::string op_name,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
52
mlx/fast.cpp
52
mlx/fast.cpp
@ -11,6 +11,8 @@
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
std::vector<array> Custom::vjp(
|
||||
@ -832,7 +834,7 @@ array pack_and_quantize(
|
||||
return packed_w;
|
||||
}
|
||||
|
||||
std::tuple<array, array, array>
|
||||
std::vector<array>
|
||||
affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
|
||||
auto s = to_stream(s_);
|
||||
|
||||
@ -1028,6 +1030,54 @@ array affine_dequantize(
|
||||
return fallback({w, scales, biases})[0];
|
||||
}
|
||||
|
||||
std::vector<array>
|
||||
trellis_quantize(const array& w_, int bits, StreamOrDevice s_) {
|
||||
if (bits != 2) {
|
||||
throw std::invalid_argument(
|
||||
"Only 2 bit Trellis quants are currently supported.");
|
||||
}
|
||||
|
||||
int Tx = 4;
|
||||
int Ty = 32;
|
||||
int batch_size = 256;
|
||||
|
||||
auto s = to_stream(s_);
|
||||
|
||||
int L = 16;
|
||||
int M = w_.shape(-2);
|
||||
int T = Tx * Ty;
|
||||
auto scale = std(astype(w_, float32, s), s);
|
||||
auto w = divide(w_, scale, s);
|
||||
w = astype(w, float16, s);
|
||||
|
||||
w = reshape(w, {M / Tx, Tx, -1, Ty}, s);
|
||||
w = transpose(w, {0, 2, 1, 3}, s);
|
||||
w = reshape(w, {-1, T}, s);
|
||||
|
||||
auto fallback = [bits, s](const std::vector<array>& inputs) mutable
|
||||
-> std::vector<array> { return {inputs[0]}; };
|
||||
|
||||
auto q = zeros({w.shape(0), w.shape(1) * bits / L}, uint16, s);
|
||||
for (int i = 0; i < w.shape(0); i += batch_size) {
|
||||
auto w_batch = slice(w, {i, 0}, {i + batch_size, w.shape(-1)}, s);
|
||||
auto q_batch = array(
|
||||
w_batch.shape(),
|
||||
uint16,
|
||||
std::make_shared<TrellisQuantize>(s, fallback, bits, true),
|
||||
{w_batch});
|
||||
q_batch = slice(q_batch, {0, 0}, q_batch.shape(), {1, L / bits}, s);
|
||||
q = slice_update(q, q_batch, {i, 0}, {i + batch_size, q.shape(-1)}, s);
|
||||
eval(q);
|
||||
}
|
||||
|
||||
q = reshape(q, {M / Tx, -1, Tx, Ty * bits / L}, s);
|
||||
q = transpose(q, {0, 2, 1, 3}, s);
|
||||
q = reshape(q, {M, -1}, s);
|
||||
q = view(q, uint32, s);
|
||||
|
||||
return {q, scale, scale};
|
||||
}
|
||||
|
||||
bool AffineQuantize::is_equivalent(const Primitive& other) const {
|
||||
const AffineQuantize& p_other = static_cast<const AffineQuantize&>(other);
|
||||
return (
|
||||
|
@ -52,7 +52,7 @@ array scaled_dot_product_attention(
|
||||
const std::vector<array>& mask_arrs = {},
|
||||
StreamOrDevice s = {});
|
||||
|
||||
std::tuple<array, array, array> affine_quantize(
|
||||
std::vector<array> affine_quantize(
|
||||
const array& w,
|
||||
int group_size = 64,
|
||||
int bits = 4,
|
||||
@ -66,6 +66,9 @@ array affine_dequantize(
|
||||
int bits = 4,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
std::vector<array>
|
||||
trellis_quantize(const array& w, int bits = 4, StreamOrDevice s = {});
|
||||
|
||||
typedef std::variant<int, bool, Dtype> TemplateArg;
|
||||
|
||||
typedef std::function<std::vector<array>(
|
||||
|
@ -269,6 +269,38 @@ class AffineQuantize : public Custom {
|
||||
bool dequantize_;
|
||||
};
|
||||
|
||||
class TrellisQuantize : public Custom {
|
||||
public:
|
||||
explicit TrellisQuantize(
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
int bits,
|
||||
bool dequantize)
|
||||
: Custom(stream, fallback), bits_(bits), dequantize_(dequantize) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
throw std::runtime_error("NYI");
|
||||
};
|
||||
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_PRINT(TrellisQuantize);
|
||||
|
||||
// bool is_equivalent(const Primitive& other) const override;
|
||||
// std::vector<Shape> output_shapes(const std::vector<array>& inputs)
|
||||
// override;
|
||||
auto state() const {
|
||||
return std::make_tuple(nullptr, bits_, dequantize_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
int bits_;
|
||||
bool dequantize_;
|
||||
};
|
||||
|
||||
struct CustomKernelShapeInfo {
|
||||
bool shape = false;
|
||||
bool strides = false;
|
||||
|
63
mlx/ops.cpp
63
mlx/ops.cpp
@ -17,6 +17,8 @@
|
||||
#include "mlx/transforms_impl.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
@ -79,7 +81,8 @@ std::pair<int, int> extract_quantized_matmul_dims(
|
||||
const array& biases,
|
||||
bool transpose,
|
||||
int group_size,
|
||||
int bits) {
|
||||
int bits,
|
||||
const std::string& mode) {
|
||||
if (w.dtype() != uint32) {
|
||||
std::ostringstream msg;
|
||||
msg << "[" << tag << "] The weight matrix should be uint32 "
|
||||
@ -87,6 +90,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (mode == "affine") {
|
||||
if (scales.shape() != biases.shape()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[" << tag << "] Scales and biases should have the same shape. "
|
||||
@ -95,16 +99,6 @@ std::pair<int, int> extract_quantized_matmul_dims(
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (!std::equal(
|
||||
w.shape().begin(), w.shape().end() - 2, scales.shape().begin())) {
|
||||
std::ostringstream msg;
|
||||
msg << "[" << tag
|
||||
<< "] Weight, scales and biases should have the same batch shape. "
|
||||
<< "Received weight with shape " << w.shape() << ", scales with "
|
||||
<< scales.shape() << " and biases with " << biases.shape();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) {
|
||||
std::ostringstream msg;
|
||||
msg << "[" << tag << "] The shapes of the weight and scales are "
|
||||
@ -113,6 +107,17 @@ std::pair<int, int> extract_quantized_matmul_dims(
|
||||
<< " with group_size=" << group_size << " and bits=" << bits;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
if (!std::equal(
|
||||
w.shape().begin(), w.shape().end() - 2, scales.shape().begin())) {
|
||||
std::ostringstream msg;
|
||||
msg << "[" << tag
|
||||
<< "] Weight, scales and biases should have the same batch shape. "
|
||||
<< "Received weight with shape " << w.shape() << ", scales with "
|
||||
<< scales.shape() << " and biases with " << biases.shape();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
int x_inner_dims = x.shape(-1);
|
||||
|
||||
@ -717,6 +722,9 @@ array slice(
|
||||
<< "array with dimension " << a.ndim() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
// std::cout << "start " << start << std::endl;
|
||||
// std::cout << "stop " << stop << std::endl;
|
||||
// std::cout << "strides " << strides << std::endl;
|
||||
|
||||
auto [has_neg_strides, out_shape] =
|
||||
normalize_slice(a.shape(), start, stop, strides);
|
||||
@ -3969,10 +3977,19 @@ array quantized_matmul(
|
||||
bool transpose /* = true */,
|
||||
int group_size /* = 64 */,
|
||||
int bits /* = 4 */,
|
||||
const std::string& mode /* = "affine" */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
// Check and extract the quantized matrix shape against x
|
||||
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
|
||||
"quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
|
||||
"quantized_matmul",
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
transpose,
|
||||
group_size,
|
||||
bits,
|
||||
mode);
|
||||
|
||||
auto dtype = result_type(x, scales, biases);
|
||||
if (!issubdtype(dtype, floating)) {
|
||||
@ -3996,16 +4013,26 @@ array quantized_matmul(
|
||||
std::move(out_shape),
|
||||
dtype,
|
||||
std::make_shared<QuantizedMatmul>(
|
||||
to_stream(s), group_size, bits, transpose),
|
||||
to_stream(s), group_size, bits, transpose, mode),
|
||||
std::move(inputs));
|
||||
}
|
||||
|
||||
std::tuple<array, array, array> quantize(
|
||||
std::vector<array> quantize(
|
||||
const array& w,
|
||||
int group_size /* = 64 */,
|
||||
int bits /* = 4 */,
|
||||
const std::string& mode, /* = affine */
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (mode == "affine") {
|
||||
return fast::affine_quantize(w, group_size, bits, s);
|
||||
} else if (mode == "trellis") {
|
||||
return fast::trellis_quantize(w, bits, s);
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] Unsupported quantization mode " << mode << "."
|
||||
<< std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
array dequantize(
|
||||
@ -4028,14 +4055,15 @@ array gather_qmm(
|
||||
bool transpose /* = true */,
|
||||
int group_size /* = 64 */,
|
||||
int bits /* = 4 */,
|
||||
const std::string& mode /* = "affine" */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (!lhs_indices_ && !rhs_indices_) {
|
||||
return quantized_matmul(
|
||||
x, w, scales, biases, transpose, group_size, bits, s);
|
||||
x, w, scales, biases, transpose, group_size, bits, mode, s);
|
||||
}
|
||||
|
||||
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
|
||||
"gather_qmm", x, w, scales, biases, transpose, group_size, bits);
|
||||
"gather_qmm", x, w, scales, biases, transpose, group_size, bits, mode);
|
||||
|
||||
// Extract indices and broadcast them
|
||||
array lhs_indices = indices_or_default(lhs_indices_, x, s);
|
||||
@ -4067,7 +4095,8 @@ array gather_qmm(
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
out_type,
|
||||
std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
|
||||
std::make_shared<GatherQMM>(
|
||||
to_stream(s), group_size, bits, transpose, mode),
|
||||
{astype(x, out_type, s),
|
||||
w,
|
||||
astype(scales, out_type, s),
|
||||
|
@ -1323,13 +1323,15 @@ array quantized_matmul(
|
||||
bool transpose = true,
|
||||
int group_size = 64,
|
||||
int bits = 4,
|
||||
const std::string& mode = "affine",
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Quantize a matrix along its last axis */
|
||||
std::tuple<array, array, array> quantize(
|
||||
std::vector<array> quantize(
|
||||
const array& w,
|
||||
int group_size = 64,
|
||||
int bits = 4,
|
||||
const std::string& mode = "affine",
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Dequantize a matrix produced by quantize() */
|
||||
@ -1352,6 +1354,7 @@ array gather_qmm(
|
||||
bool transpose = true,
|
||||
int group_size = 64,
|
||||
int bits = 4,
|
||||
const std::string& mode = "affine",
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Returns a contraction of a and b over multiple dimensions. */
|
||||
|
@ -3012,6 +3012,7 @@ std::vector<array> QuantizedMatmul::vjp(
|
||||
!transpose_,
|
||||
group_size_,
|
||||
bits_,
|
||||
mode_,
|
||||
stream()));
|
||||
}
|
||||
|
||||
@ -3040,6 +3041,7 @@ std::vector<array> QuantizedMatmul::jvp(
|
||||
transpose_,
|
||||
group_size_,
|
||||
bits_,
|
||||
mode_,
|
||||
stream())};
|
||||
}
|
||||
|
||||
@ -3098,6 +3100,7 @@ std::vector<array> GatherQMM::vjp(
|
||||
!transpose_,
|
||||
group_size_,
|
||||
bits_,
|
||||
mode_,
|
||||
stream()),
|
||||
-3,
|
||||
stream()),
|
||||
|
@ -1552,11 +1552,13 @@ class QuantizedMatmul : public UnaryPrimitive {
|
||||
Stream stream,
|
||||
int group_size,
|
||||
int bits,
|
||||
bool transpose)
|
||||
bool transpose,
|
||||
const std::string mode)
|
||||
: UnaryPrimitive(stream),
|
||||
group_size_(group_size),
|
||||
bits_(bits),
|
||||
transpose_(transpose) {}
|
||||
transpose_(transpose),
|
||||
mode_(mode) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1567,22 +1569,29 @@ class QuantizedMatmul : public UnaryPrimitive {
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
auto state() const {
|
||||
return std::make_tuple(group_size_, bits_, transpose_);
|
||||
return std::make_tuple(group_size_, bits_, transpose_, mode_);
|
||||
}
|
||||
|
||||
private:
|
||||
int group_size_;
|
||||
int bits_;
|
||||
bool transpose_;
|
||||
const std::string mode_;
|
||||
};
|
||||
|
||||
class GatherQMM : public UnaryPrimitive {
|
||||
public:
|
||||
explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
|
||||
explicit GatherQMM(
|
||||
Stream stream,
|
||||
int group_size,
|
||||
int bits,
|
||||
bool transpose,
|
||||
const std::string& mode)
|
||||
: UnaryPrimitive(stream),
|
||||
group_size_(group_size),
|
||||
bits_(bits),
|
||||
transpose_(transpose) {}
|
||||
transpose_(transpose),
|
||||
mode_(mode) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@ -1592,13 +1601,14 @@ class GatherQMM : public UnaryPrimitive {
|
||||
DEFINE_PRINT(GatherQMM)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return std::make_tuple(group_size_, bits_, transpose_);
|
||||
return std::make_tuple(group_size_, bits_, transpose_, mode_);
|
||||
}
|
||||
|
||||
private:
|
||||
int group_size_;
|
||||
int bits_;
|
||||
bool transpose_;
|
||||
const std::string mode_;
|
||||
};
|
||||
|
||||
class RandomBits : public UnaryPrimitive {
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from typing import Literal
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.nn.layers.base import Module
|
||||
@ -39,6 +40,12 @@ class Embedding(Module):
|
||||
"""
|
||||
return x @ self.weight.T
|
||||
|
||||
def to_quantized(self, group_size: int = 64, bits: int = 4):
|
||||
def to_quantized(
|
||||
self,
|
||||
group_size: int = 64,
|
||||
bits: int = 4,
|
||||
mode: Literal["affine", "trellis"] = "affine",
|
||||
fake: bool = False,
|
||||
):
|
||||
"""Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
|
||||
return QuantizedEmbedding.from_embedding(self, group_size, bits)
|
||||
|
@ -1,11 +1,12 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import math
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.nn.layers.base import Module
|
||||
from mlx.nn.layers.quantized import QuantizedLinear
|
||||
from mlx.nn.layers.viterbi import quantize as trellis_quantize
|
||||
|
||||
|
||||
class Identity(Module):
|
||||
@ -70,9 +71,15 @@ class Linear(Module):
|
||||
x = x @ self["weight"].T
|
||||
return x
|
||||
|
||||
def to_quantized(self, group_size: int = 64, bits: int = 4):
|
||||
def to_quantized(
|
||||
self,
|
||||
group_size: int = 64,
|
||||
bits: int = 4,
|
||||
mode: Literal["affine", "trellis"] = "affine",
|
||||
fake: bool = False,
|
||||
):
|
||||
"""Return a :obj:`QuantizedLinear` layer that approximates this layer."""
|
||||
return QuantizedLinear.from_linear(self, group_size, bits)
|
||||
return QuantizedLinear.from_linear(self, group_size, bits, mode=mode, fake=fake)
|
||||
|
||||
|
||||
class Bilinear(Module):
|
||||
|
@ -1,10 +1,11 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from typing import Callable, Optional, Union
|
||||
from typing import Callable, Literal, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.nn.layers.base import Module
|
||||
from mlx.nn.layers.viterbi import quantize as trellis_quantize
|
||||
from mlx.utils import tree_map_with_path
|
||||
|
||||
|
||||
@ -12,7 +13,9 @@ def quantize(
|
||||
model: Module,
|
||||
group_size: int = 64,
|
||||
bits: int = 4,
|
||||
mode: Literal["affine", "trellis"] = "affine",
|
||||
class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
|
||||
fake: bool = False,
|
||||
):
|
||||
"""Quantize the sub-modules of a module according to a predicate.
|
||||
|
||||
@ -21,7 +24,7 @@ def quantize(
|
||||
will be quantized. Note also, the module is updated in-place.
|
||||
|
||||
Args:
|
||||
model (mlx.nn.Module): The model whose leaf modules may be quantized.
|
||||
model (mlx.nn.Module):, mode: Literal["affine", "trellis"] = "affine" The model whose leaf modules may be quantized.
|
||||
group_size (int): The quantization group size (see
|
||||
:func:`mlx.core.quantize`). Default: ``64``.
|
||||
bits (int): The number of bits per parameter (see
|
||||
@ -36,12 +39,15 @@ def quantize(
|
||||
class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized"))
|
||||
|
||||
def _maybe_quantize(path, m):
|
||||
print(path)
|
||||
if bool_or_params := class_predicate(path, m):
|
||||
if hasattr(m, "to_quantized"):
|
||||
if isinstance(bool_or_params, bool):
|
||||
return m.to_quantized(group_size=group_size, bits=bits)
|
||||
return m.to_quantized(
|
||||
group_size=group_size, bits=bits, mode=mode, fake=fake
|
||||
)
|
||||
elif isinstance(bool_or_params, dict):
|
||||
return m.to_quantized(**bool_or_params)
|
||||
return m.to_quantized(**bool_or_params, fake=fake)
|
||||
else:
|
||||
raise ValueError(
|
||||
"``class_predicate`` must return a bool"
|
||||
@ -131,7 +137,11 @@ class QuantizedEmbedding(Module):
|
||||
|
||||
@classmethod
|
||||
def from_embedding(
|
||||
cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
|
||||
cls,
|
||||
embedding_layer: Module,
|
||||
group_size: int = 64,
|
||||
bits: int = 4,
|
||||
mode: Literal["affine", "trellis"] = "affine",
|
||||
):
|
||||
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
|
||||
embedding_dims, dims = embedding_layer.weight.shape
|
||||
@ -170,12 +180,14 @@ class QuantizedLinear(Module):
|
||||
bias: bool = True,
|
||||
group_size: int = 64,
|
||||
bits: int = 4,
|
||||
mode: Literal["affine", "trellis"] = "affine",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Quantization config
|
||||
self.group_size = group_size
|
||||
self.bits = bits
|
||||
self.mode = mode
|
||||
|
||||
# Initialize the quantized weight
|
||||
scale = math.sqrt(1 / input_dims)
|
||||
@ -216,19 +228,40 @@ class QuantizedLinear(Module):
|
||||
transpose=True,
|
||||
group_size=self.group_size,
|
||||
bits=self.bits,
|
||||
mode=self.mode,
|
||||
)
|
||||
if "bias" in self:
|
||||
x = x + self["bias"]
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
|
||||
def from_linear(
|
||||
cls,
|
||||
linear_layer: Module,
|
||||
group_size: int = 64,
|
||||
bits: int = 4,
|
||||
mode: Literal["affine", "trellis"] = "affine",
|
||||
fake: bool = False,
|
||||
):
|
||||
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
|
||||
output_dims, input_dims = linear_layer.weight.shape
|
||||
ql = cls(input_dims, output_dims, False, group_size, bits)
|
||||
ql.weight, ql.scales, ql.biases = mx.quantize(
|
||||
linear_layer.weight, group_size, bits
|
||||
ql = cls(input_dims, output_dims, False, group_size, bits, mode)
|
||||
if mode == "trellis":
|
||||
if fake:
|
||||
ql.weight = mx.zeros(
|
||||
(output_dims, input_dims // 32 * bits), dtype=mx.uint32
|
||||
)
|
||||
ql.scales = mx.array(0.0)
|
||||
ql.biases = mx.array(0.0)
|
||||
else:
|
||||
ql.weight, ql.scales, ql.biases = mx.quantize(
|
||||
linear_layer.weight, bits=bits, mode="trellis"
|
||||
)
|
||||
else:
|
||||
ql.weight, ql.scales, ql.biases = mx.quantize(
|
||||
linear_layer.weight, group_size, bits, mode="affine"
|
||||
)
|
||||
|
||||
if "bias" in linear_layer:
|
||||
ql.bias = linear_layer.bias
|
||||
|
||||
|
@ -4116,10 +4116,11 @@ void init_ops(nb::module_& m) {
|
||||
"transpose"_a = true,
|
||||
"group_size"_a = 64,
|
||||
"bits"_a = 4,
|
||||
"mode"_a = "affine",
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: Literal['affine', 'trellis'], *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Perform the matrix multiplication with the quantized matrix ``w``. The
|
||||
quantization uses one floating point scale and bias per ``group_size`` of
|
||||
@ -4138,6 +4139,8 @@ void init_ops(nb::module_& m) {
|
||||
shares a scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element in
|
||||
``w``. Default: ``4``.
|
||||
mode (str, optional): The mode to use for quantization.
|
||||
Default: ``affine``.
|
||||
|
||||
Returns:
|
||||
array: The result of the multiplication of ``x`` with ``w``.
|
||||
@ -4149,9 +4152,10 @@ void init_ops(nb::module_& m) {
|
||||
"group_size"_a = 64,
|
||||
"bits"_a = 4,
|
||||
nb::kw_only(),
|
||||
"mode"_a = "affine",
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
|
||||
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, mode: Literal['affine', 'trellis'], stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
|
||||
R"pbdoc(
|
||||
Quantize the matrix ``w`` using ``bits`` bits per element.
|
||||
|
||||
@ -4193,6 +4197,7 @@ void init_ops(nb::module_& m) {
|
||||
scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element of
|
||||
``w`` in the returned quantized matrix. Default: ``4``.
|
||||
mode (str): The quantization mode to use. Default: ``affine``.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing
|
||||
@ -4249,10 +4254,11 @@ void init_ops(nb::module_& m) {
|
||||
"transpose"_a = true,
|
||||
"group_size"_a = 64,
|
||||
"bits"_a = 4,
|
||||
"mode"_a = "affine",
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: Literal['affine', 'trellis'], *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Perform quantized matrix multiplication with matrix-level gather.
|
||||
|
||||
@ -4278,6 +4284,8 @@ void init_ops(nb::module_& m) {
|
||||
shares a scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element in
|
||||
``w``. Default: ``4``.
|
||||
mode (str, optional): The mode to use for quantization.
|
||||
Default: ``affine``.
|
||||
|
||||
Returns:
|
||||
array: The result of the multiplication of ``x`` with ``w``
|
||||
|
@ -10,6 +10,9 @@ import mlx_tests
|
||||
class TestQuantized(mlx_tests.MLXTestCase):
|
||||
def test_quantize_dequantize(self):
|
||||
w = mx.random.normal(shape=(128, 512))
|
||||
w_q, scales, biases = mx.quantize(w, bits=2, mode="trellis")
|
||||
print(w_q, scales, biases)
|
||||
|
||||
for gs in [32, 64, 128]:
|
||||
for b in [2, 3, 6, 4, 8]:
|
||||
with self.subTest(gs=gs, b=b):
|
||||
|
Loading…
Reference in New Issue
Block a user