Support for quantized matmul with w and w^T (#349)

* Add the metal qvm implementation
* Add qmm_n
* Add gradient wrt the input for quantized_matmul
Author: Angelos Katharopoulos
Date:   2024-01-03 14:22:36 -08:00
Committed by: GitHub
Parent: d7ac050f4b
Commit: e7f5059fe4
12 changed files with 718 additions and 193 deletions
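To frame the diff below: with transposed_w == false the packed weight encodes a K x N matrix and the kernel accumulates out = x · dequant(w); with transposed_w == true it encodes an N x K matrix and the kernel computes out = x · dequant(w)^T, which was the only layout supported before this change. The following standalone sketch (illustrative only, not MLX code; all names are made up) shows the two access patterns on plain float weights, ignoring quantization:

#include <vector>

// out[M x N] = x[M x K] * w[K x N]   -- the new non-transposed path
void matmul_w(const std::vector<float>& x, const std::vector<float>& w,
              std::vector<float>& out, int M, int N, int K) {
  for (int m = 0; m < M; m++) {
    for (int n = 0; n < N; n++) {
      float acc = 0;
      for (int k = 0; k < K; k++) {
        acc += x[m * K + k] * w[k * N + n];
      }
      out[m * N + n] = acc;
    }
  }
}

// out[M x N] = x[M x K] * w[N x K]^T -- the pre-existing transposed path
void matmul_wt(const std::vector<float>& x, const std::vector<float>& w,
               std::vector<float>& out, int M, int N, int K) {
  for (int m = 0; m < M; m++) {
    for (int n = 0; n < N; n++) {
      float acc = 0;
      for (int k = 0; k < K; k++) {
        acc += x[m * K + k] * w[n * K + k];
      }
      out[m * N + n] = acc;
    }
  }
}

In the quantized kernels below, each group of group_size weights along the packed dimension is stored as uint32_t words and dequantized on the fly with one scale and bias per group. Because the packed weight's trailing dimension is divided by the pack factor, the dispatch code reads N from out.shape(-1) rather than from w.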


@@ -1,13 +1,62 @@
// Copyright © 2023 Apple Inc.

#include <cassert>
#include <iostream>

#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace {
+template <typename T, int bits, int group_size>
+void _qmm(
+    T* result,
+    const T* x,
+    const uint32_t* w,
+    const T* scales,
+    const T* biases,
+    int M,
+    int N,
+    int K) {
+  constexpr int bitmask = (1 << bits) - 1;
+  constexpr int pack_factor = 32 / bits;
+  constexpr int packs_in_group = group_size / pack_factor;
+  const int Ng = N / group_size;
+  const int Nw = N / pack_factor;
+  for (int m = 0; m < M; m++) {
+    const uint32_t* w_local = w;
+    const T* scales_local = scales;
+    const T* biases_local = biases;
+    std::fill(result, result + N, 0);
+    for (int k = 0; k < K; k++) {
+      T* result_local = result;
+      T xi = *x++;
+      for (int n = 0; n < N; n += group_size) {
+        T scale = *scales_local++;
+        T bias = *biases_local++;
+        for (int ng = 0; ng < packs_in_group; ng++) {
+          uint32_t wi = *w_local++;
+#pragma clang loop unroll(full)
+          for (int p = 0; p < pack_factor; p++) {
+            (*result_local++) +=
+                xi * (scale * static_cast<T>(wi & bitmask) + bias);
+            wi >>= bits;
+          }
+        }
+      }
+    }
+    result += N;
+  }
+}
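// Editorial sketch (not part of the diff above): the unpacking arithmetic in
// _qmm for bits = 4 gives pack_factor = 32 / 4 = 8 codes per uint32_t and
// bitmask = 0xF; with group_size = 64, each group of 64 outputs for a given
// row k consumes packs_in_group = 64 / 8 = 8 packed words sharing one scale
// and one bias. Decoding a single word looks like this:

#include <array>
#include <cstdint>

std::array<float, 8> dequantize_word_sketch(uint32_t wi, float scale, float bias) {
  constexpr int bits = 4;
  constexpr int pack_factor = 32 / bits;          // 8 codes per word
  constexpr uint32_t bitmask = (1u << bits) - 1;  // 0xF
  std::array<float, pack_factor> out;
  for (int p = 0; p < pack_factor; p++) {
    out[p] = scale * static_cast<float>(wi & bitmask) + bias;  // affine dequant
    wi >>= bits;                                               // next 4-bit code
  }
  return out;
}

// Example: wi = 0x76543210 decodes (lowest bits first) to codes 0, 1, ..., 7,
// each mapped to scale * code + bias.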
template <typename T, int bits, int group_size>
void _qmm_t(
    T* result,
@@ -55,7 +104,7 @@ void _qmm_t(
}
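// Editorial note: the body of _qmm_t above is untouched by this commit and is
// therefore elided from the hunk. Assuming it mirrors _qmm with the packed and
// reduced dimensions swapped, its contract is
//
//   result[m, n] = sum over k of
//       x[m, k] * (scale[n, k / group_size] * code(w[n, k / pack_factor], k)
//                  + bias[n, k / group_size]),
//
// i.e. each output is a dot product over K, with the weight packed along K and
// one scale/bias pair per group of group_size consecutive k values, which is
// why the old eval required the weight to be passed in transposed form.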
template <typename T>
-void _qmm_t_dispatch_typed(
+void _qmm_dispatch_typed(
    T* result,
    const T* x,
    const uint32_t* w,
@@ -65,30 +114,55 @@ void _qmm_t_dispatch_typed(
    int N,
    int K,
    int group_size,
-    int bits) {
+    int bits,
+    bool transposed_w) {
  switch (bits) {
    case 2: {
      switch (group_size) {
        case 64:
-          return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
+          if (transposed_w) {
+            return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 2, 64>(result, x, w, scales, biases, M, N, K);
+          }
        case 128:
-          return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
+          if (transposed_w) {
+            return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 2, 128>(result, x, w, scales, biases, M, N, K);
+          }
      }
    }
    case 4: {
      switch (group_size) {
        case 64:
-          return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
+          if (transposed_w) {
+            return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 4, 64>(result, x, w, scales, biases, M, N, K);
+          }
        case 128:
-          return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
+          if (transposed_w) {
+            return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 4, 128>(result, x, w, scales, biases, M, N, K);
+          }
      }
    }
    case 8: {
      switch (group_size) {
        case 64:
-          return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
+          if (transposed_w) {
+            return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 8, 64>(result, x, w, scales, biases, M, N, K);
+          }
        case 128:
-          return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
+          if (transposed_w) {
+            return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 8, 128>(result, x, w, scales, biases, M, N, K);
+          }
      }
    }
  }
@@ -100,21 +174,22 @@ void _qmm_t_dispatch_typed(
  throw std::invalid_argument(msg.str());
}
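// Editorial sketch (not part of the diff): the template parameters above fix
// the unpacking constants at compile time. For the supported configurations:
//
//   bits   pack_factor = 32/bits   packs_in_group (group_size 64 / 128)
//    2             16                      4  /  8
//    4              8                      8  / 16
//    8              4                     16  / 32
//
// A small illustrative helper (hypothetical name) expressing the same relation:

template <int bits, int group_size>
struct qmm_constants_sketch {
  static_assert(32 % bits == 0, "bits must divide a 32-bit word");
  static constexpr int pack_factor = 32 / bits;
  static_assert(group_size % pack_factor == 0, "groups must pack evenly");
  static constexpr int packs_in_group = group_size / pack_factor;
};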
-void _qmm_t_dispatch(
+void _qmm_dispatch(
    array out,
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    int bits,
-    int group_size) {
+    int group_size,
+    bool transposed_w) {
  int K = x.shape(-1);
  int M = x.size() / K;
-  int N = w.shape(1);
+  int N = out.shape(-1);
  switch (x.dtype()) {
    case float32:
-      _qmm_t_dispatch_typed<float>(
+      _qmm_dispatch_typed<float>(
          out.data<float>(),
          x.data<float>(),
          w.data<uint32_t>(),
@@ -124,10 +199,11 @@ void _qmm_t_dispatch(
          N,
          K,
          bits,
-          group_size);
+          group_size,
+          transposed_w);
      break;
    case float16:
-      _qmm_t_dispatch_typed<float16_t>(
+      _qmm_dispatch_typed<float16_t>(
          out.data<float16_t>(),
          x.data<float16_t>(),
          w.data<uint32_t>(),
@@ -137,10 +213,11 @@ void _qmm_t_dispatch(
          N,
          K,
          bits,
-          group_size);
+          group_size,
+          transposed_w);
      break;
    case bfloat16:
-      _qmm_t_dispatch_typed<bfloat16_t>(
+      _qmm_dispatch_typed<bfloat16_t>(
          out.data<bfloat16_t>(),
          x.data<bfloat16_t>(),
          w.data<uint32_t>(),
@@ -150,7 +227,8 @@ void _qmm_t_dispatch(
          N,
          K,
          bits,
-          group_size);
+          group_size,
+          transposed_w);
      break;
    default:
      throw std::invalid_argument(
@@ -163,22 +241,28 @@ void _qmm_t_dispatch(
void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 4);
-  auto& x = inputs[0];
-  auto& w = inputs[1];
-  auto& scales = inputs[2];
-  auto& biases = inputs[3];
+  auto& x_pre = inputs[0];
+  auto& w_pre = inputs[1];
+  auto& scales_pre = inputs[2];
+  auto& biases_pre = inputs[3];
-  if (w.strides()[0] != 1) {
-    throw std::runtime_error("The quantized weight should be transposed");
-  }
+  auto ensure_row_contiguous = [](const array& arr) {
+    if (arr.flags().row_contiguous) {
+      return arr;
+    } else {
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy(arr, arr_copy, CopyType::General);
+      return arr_copy;
+    }
+  };
-  if (!x.flags().row_contiguous || !scales.flags().row_contiguous ||
-      !biases.flags().row_contiguous) {
-    throw std::runtime_error("x, scales and biases should be row contiguous.");
-  }
+  auto x = ensure_row_contiguous(x_pre);
+  auto w = ensure_row_contiguous(w_pre);
+  auto scales = ensure_row_contiguous(scales_pre);
+  auto biases = ensure_row_contiguous(biases_pre);
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  _qmm_t_dispatch(out, x, w, scales, biases, group_size_, bits_);
+  _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
}
} // namespace mlx::core
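For completeness, here is a standalone sketch of the packing side that produces the w, scales, and biases consumed by the kernels above, using a per-group affine scheme with bits = 4 and group_size = 64. This is illustrative only; the actual quantize op may differ in rounding and edge-case handling.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Quantize one row of n floats (n divisible by 64) into packed 4-bit codes
// plus one (scale, bias) pair per group of 64 values.
void quantize_row_sketch(const float* v, int n, std::vector<uint32_t>& packed,
                         std::vector<float>& scales, std::vector<float>& biases) {
  constexpr int bits = 4;
  constexpr int group_size = 64;
  constexpr int pack_factor = 32 / bits;   // 8 codes per uint32_t
  constexpr int levels = (1 << bits) - 1;  // 15

  for (int g = 0; g < n; g += group_size) {
    float lo = *std::min_element(v + g, v + g + group_size);
    float hi = *std::max_element(v + g, v + g + group_size);
    float scale = (hi - lo) / levels;  // affine step size
    if (scale == 0) {
      scale = 1;  // degenerate group: all values equal
    }
    scales.push_back(scale);
    biases.push_back(lo);

    // Pack pack_factor codes per 32-bit word, lowest bits first, matching the
    // unpacking order used by _qmm and _qmm_t.
    for (int i = g; i < g + group_size; i += pack_factor) {
      uint32_t word = 0;
      for (int p = 0; p < pack_factor; p++) {
        int code = static_cast<int>(std::round((v[i + p] - lo) / scale));
        code = std::min(std::max(code, 0), levels);
        word |= static_cast<uint32_t>(code) << (p * bits);
      }
      packed.push_back(word);
    }
  }
}

Dequantizing with scale * code + bias then recovers each value to within about half a step, which is exactly the on-the-fly approximation both _qmm and _qmm_t apply.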