Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-25 18:11:15 +08:00)
Parent: 0c5eea226b
Commit: c79f6a4a8c
@@ -12,5 +12,4 @@ Fast
 layer_norm
 rope
 scaled_dot_product_attention
-affine_quantize
 metal_kernel
@@ -6,11 +6,34 @@
 #include "mlx/backend/common/ops.h"
 #include "mlx/fast_primitives.h"
 #include "mlx/primitives.h"
+#include "mlx/utils.h"

 namespace mlx::core {

 namespace {

+template <typename T, int bits>
+void extract_bits(const uint8_t* w_in, T* w_out) {
+  assert(bits == 3 || bits == 6);
+  if (bits == 3) {
+    w_out[0] = static_cast<T>(w_in[0] & 0x7);
+    w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
+    w_out[2] = static_cast<T>(((w_in[0] & 0xc0) >> 6) + ((w_in[1] & 0x1) << 2));
+    w_out[3] = static_cast<T>((w_in[1] & 0xe) >> 1);
+    w_out[4] = static_cast<T>((w_in[1] & 0x70) >> 4);
+    w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
+    w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
+    w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
+  } else if (bits == 6) {
+    w_out[0] = static_cast<T>(w_in[0] & 0x3f);
+    w_out[1] =
+        static_cast<T>(((w_in[0] >> 6) & 0x03) + ((w_in[1] & 0x0f) << 2));
+    w_out[2] =
+        static_cast<T>(((w_in[1] >> 4) & 0x0f) + ((w_in[2] & 0x03) << 4));
+    w_out[3] = static_cast<T>((w_in[2] >> 2) & 0x3f);
+  }
+}
+
 template <typename T, int bits, int group_size>
 void _qmm(
     T* result,
@@ -22,13 +45,12 @@ void _qmm(
     int N,
     int K) {
   constexpr int bitmask = (1 << bits) - 1;
-  constexpr int pack_factor = 32 / bits;
+  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
+  constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
   constexpr int packs_in_group = group_size / pack_factor;
-  const int Ng = N / group_size;
-  const int Nw = N / pack_factor;

   for (int m = 0; m < M; m++) {
-    const uint32_t* w_local = w;
+    const uint8_t* w_local = (const uint8_t*)w;
     const T* scales_local = scales;
     const T* biases_local = biases;

@@ -42,17 +64,29 @@ void _qmm(
       T scale = *scales_local++;
       T bias = *biases_local++;
       for (int ng = 0; ng < packs_in_group; ng++) {
-        uint32_t wi = *w_local++;
-#pragma clang loop unroll(full)
-        for (int p = 0; p < pack_factor; p++) {
-          (*result_local++) +=
-              xi * (scale * static_cast<T>(wi & bitmask) + bias);
-          wi >>= bits;
-        }
+        if (bits == 3 || bits == 6) {
+          T wl[pack_factor];
+          extract_bits<T, bits>(w_local, wl);
+#pragma clang loop unroll(full)
+          for (int p = 0; p < pack_factor; p++) {
+            (*result_local++) += xi * (scale * wl[p] + bias);
+          }
+          w_local += bytes_per_pack;
+
+        } else {
+          uint8_t wi = *w_local++;
+#pragma clang loop unroll(full)
+          for (int p = 0; p < pack_factor; p++) {
+            (*result_local++) +=
+                xi * (scale * static_cast<T>(wi & bitmask) + bias);
+            if (bits != 8) {
+              wi >>= bits;
+            }
+          }
+        }
       }
     }
   }

   result += N;
 }
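Note: to make the new 3-bit layout concrete, the following is a small standalone host-side sketch (not part of the commit). It packs eight 3-bit values LSB-first into a 24-bit group of three bytes, which is the byte order the updated quantize() later in this diff appears to emit, and unpacks them with the same masks and shifts as extract_bits.

// Standalone sketch (not MLX code): round-trip the 3-bit packing used above.
#include <cassert>
#include <cstdint>

int main() {
  uint8_t vals[8] = {5, 0, 7, 3, 1, 6, 2, 4}; // arbitrary values in [0, 7]

  // Pack: value k lands at bit offset 3 * k of the 24-bit group.
  uint32_t group = 0;
  for (int k = 0; k < 8; k++) {
    group |= uint32_t(vals[k]) << (3 * k);
  }
  uint8_t w_in[3] = {
      uint8_t(group & 0xff),
      uint8_t((group >> 8) & 0xff),
      uint8_t((group >> 16) & 0xff)};

  // Unpack with the same masks and shifts as extract_bits<T, 3>.
  uint8_t w_out[8];
  w_out[0] = w_in[0] & 0x7;
  w_out[1] = (w_in[0] & 0x38) >> 3;
  w_out[2] = ((w_in[0] & 0xc0) >> 6) + ((w_in[1] & 0x1) << 2);
  w_out[3] = (w_in[1] & 0xe) >> 1;
  w_out[4] = (w_in[1] & 0x70) >> 4;
  w_out[5] = ((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1);
  w_out[6] = (w_in[2] & 0x1c) >> 2;
  w_out[7] = (w_in[2] & 0xe0) >> 5;

  for (int k = 0; k < 8; k++) {
    assert(w_out[k] == vals[k]); // lossless round trip
  }
  return 0;
}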
@@ -69,13 +103,12 @@ void _qmm_t(
     int N,
     int K) {
   constexpr int bitmask = (1 << bits) - 1;
-  constexpr int pack_factor = 32 / bits;
+  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
+  constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
   constexpr int packs_in_group = group_size / pack_factor;
-  const int Kg = K / group_size;
-  const int Kw = K / pack_factor;

   for (int m = 0; m < M; m++) {
-    const uint32_t* w_local = w;
+    const uint8_t* w_local = (const uint8_t*)w;
     const T* scales_local = scales;
     const T* biases_local = biases;

@@ -87,15 +120,29 @@ void _qmm_t(
       T bias = *biases_local++;

       for (int kw = 0; kw < packs_in_group; kw++) {
-        uint32_t wi = *w_local++;
-#pragma clang loop unroll(full)
-        for (int p = 0; p < pack_factor; p++) {
-          sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
-          wi >>= bits;
-        }
+        if (bits == 3 || bits == 6) {
+          T wl[pack_factor];
+          extract_bits<T, bits>(w_local, wl);
+#pragma clang loop unroll(full)
+          for (int p = 0; p < pack_factor; p++) {
+            sum += x_local[p] * (scale * wl[p] + bias);
+          }
+          w_local += bytes_per_pack;
+          x_local += pack_factor;
+
+        } else {
+          uint8_t wi = *w_local++;
+#pragma clang loop unroll(full)
+          for (int p = 0; p < pack_factor; p++) {
+            sum +=
+                (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
+            if (bits != 8) {
+              wi >>= bits;
+            }
+          }
+        }
       }
     }
     *result = sum;
     result++;
   }
@@ -104,6 +151,55 @@ void _qmm_t(
   }
 }

+template <typename T, int bits, int group_size>
+void _qmm_dispatch_transpose(
+    T* result,
+    const T* x,
+    const uint32_t* w,
+    const T* scales,
+    const T* biases,
+    int M,
+    int N,
+    int K,
+    bool transposed_w) {
+  if (transposed_w) {
+    return _qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
+  } else {
+    return _qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
+  }
+}
+
+template <typename T, int bits>
+void _qmm_dispatch_group(
+    T* result,
+    const T* x,
+    const uint32_t* w,
+    const T* scales,
+    const T* biases,
+    int M,
+    int N,
+    int K,
+    int group_size,
+    bool transposed_w) {
+  switch (group_size) {
+    case 32:
+      _qmm_dispatch_transpose<T, bits, 32>(
+          result, x, w, scales, biases, M, N, K, transposed_w);
+      break;
+    case 64:
+      _qmm_dispatch_transpose<T, bits, 64>(
+          result, x, w, scales, biases, M, N, K, transposed_w);
+      break;
+    case 128:
+      _qmm_dispatch_transpose<T, bits, 128>(
+          result, x, w, scales, biases, M, N, K, transposed_w);
+      break;
+    default:
+      throw std::invalid_argument(
+          "Quantization group size must be 32, 64 or 128.");
+  }
+}
+
 template <typename T>
 void _qmm_dispatch_typed(
     T* result,
@@ -118,79 +214,29 @@ void _qmm_dispatch_typed(
     int bits,
     bool transposed_w) {
   switch (bits) {
-    case 2: {
-      switch (group_size) {
-        case 32:
-          if (transposed_w) {
-            return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K);
-          }
-        case 64:
-          if (transposed_w) {
-            return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 2, 64>(result, x, w, scales, biases, M, N, K);
-          }
-        case 128:
-          if (transposed_w) {
-            return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 2, 128>(result, x, w, scales, biases, M, N, K);
-          }
-      }
-    }
-    case 4: {
-      switch (group_size) {
-        case 32:
-          if (transposed_w) {
-            return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
-          }
-        case 64:
-          if (transposed_w) {
-            return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 4, 64>(result, x, w, scales, biases, M, N, K);
-          }
-        case 128:
-          if (transposed_w) {
-            return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 4, 128>(result, x, w, scales, biases, M, N, K);
-          }
-      }
-    }
-    case 8: {
-      switch (group_size) {
-        case 32:
-          if (transposed_w) {
-            return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
-          }
-        case 64:
-          if (transposed_w) {
-            return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 8, 64>(result, x, w, scales, biases, M, N, K);
-          }
-        case 128:
-          if (transposed_w) {
-            return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 8, 128>(result, x, w, scales, biases, M, N, K);
-          }
-      }
-    }
-  }
-  std::ostringstream msg;
-  msg << "Quantization type not supported. Provided bits=" << bits
-      << " and group_size=" << group_size
-      << ". The supported options are bits in "
-      << "{2, 4, 8} and group_size in {64, 128}.";
-  throw std::invalid_argument(msg.str());
+    case 2:
+      _qmm_dispatch_group<T, 2>(
+          result, x, w, scales, biases, M, N, K, group_size, transposed_w);
+      break;
+    case 3:
+      _qmm_dispatch_group<T, 3>(
+          result, x, w, scales, biases, M, N, K, group_size, transposed_w);
+      break;
+    case 4:
+      _qmm_dispatch_group<T, 4>(
+          result, x, w, scales, biases, M, N, K, group_size, transposed_w);
+      break;
+    case 6:
+      _qmm_dispatch_group<T, 6>(
+          result, x, w, scales, biases, M, N, K, group_size, transposed_w);
+      break;
+    case 8:
+      _qmm_dispatch_group<T, 8>(
+          result, x, w, scales, biases, M, N, K, group_size, transposed_w);
+      break;
+    default:
+      throw std::invalid_argument("Quantization bits must be 2, 3, 4, 6 or 8.");
+  }
 }

 void _qmm_dispatch(
@@ -406,29 +452,31 @@ void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
       transpose_);
 }

-template <typename T>
+template <typename T, typename U>
 void quantize(
     const array& w_,
     array& out_,
     array& scales_,
     array& biases_,
     int bits,
-    int group_size,
-    bool compute_scale_bias) {
+    int group_size) {
   const T* w = w_.data<T>();
+
+  auto out = out_.data<U>();
   T* scales = scales_.data<T>();
   T* biases = biases_.data<T>();
-  auto out = out_.data<uint32_t>();

   T n_bins = (1 << bits) - 1;
   T eps = 1e-7;
-  int el_per_int = 32 / bits;
-  int int_per_group = group_size / el_per_int;
+  bool power_of_2_bits = is_power_of_2(bits);
+  int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
+  // For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
+  int bytes_per_pack = power_of_2_bits ? 1 : 3;
+  int int_per_group = group_size * bytes_per_pack / el_per_int;
   size_t n_groups = w_.size() / group_size;

   for (size_t i = 0; i < n_groups; ++i) {
     size_t w_idx = i * group_size;
-    if (compute_scale_bias) {
     T w_min = std::numeric_limits<float>::infinity();
     T w_max = -w_min;
     for (int j = 0; j < group_size; ++j) {
@@ -448,9 +496,8 @@ void quantize(
       scales[i] = edge / q0;
       biases[i] = edge;
     }
-    }
     size_t out_idx = i * int_per_group;
-    for (int j = 0; j < int_per_group; ++j) {
+    for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
       uint32_t out_el = 0;
       for (int k = 0; k < el_per_int; ++k) {
         T w_el = w[w_idx + j * el_per_int + k];
@@ -458,7 +505,13 @@ void quantize(
         w_el = std::min(std::max(w_el, T(0)), n_bins);
         out_el |= static_cast<uint32_t>(w_el) << (k * bits);
       }
-      out[out_idx + j] = out_el;
+      if (power_of_2_bits) {
+        out[out_idx + j] = out_el;
+      } else {
+        out[out_idx + bytes_per_pack * j] = out_el & 0xff;
+        out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
+        out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
+      }
     }
   }
 }
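Note: after this change the packed buffer holds uint32 words for 2/4/8 bits but raw bytes for 3/6 bits, so int_per_group counts 32-bit words in the first case and bytes in the second. A quick standalone arithmetic check of the resulting sizes, my own sketch rather than code from the commit:

// Standalone sketch: packed elements written per group of `group_size` weights.
constexpr bool pow2_sketch(int n) { return (n & (n - 1)) == 0; }

// Mirrors the expressions in quantize() above.
constexpr int el_per_int_sketch(int bits) {
  return bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
}
constexpr int bytes_per_pack_sketch(int bits) { return pow2_sketch(bits) ? 1 : 3; }
constexpr int packed_per_group(int bits, int group_size) {
  return group_size * bytes_per_pack_sketch(bits) / el_per_int_sketch(bits);
}

// 64 weights at 4 bits: 8 uint32 words = 32 bytes = 64 * 4 bits.
static_assert(packed_per_group(4, 64) == 8, "");
// 64 weights at 3 bits: 24 uint8 bytes = 64 * 3 bits.
static_assert(packed_per_group(3, 64) == 24, "");
// 64 weights at 6 bits: 48 uint8 bytes = 64 * 6 bits.
static_assert(packed_per_group(6, 64) == 48, "");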
@@ -466,8 +519,6 @@ void quantize(
 void fast::AffineQuantize::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  bool compute_scale_bias = inputs.size() == 1;
-
   auto ensure_row_contiguous = [](const array& arr) {
     if (arr.flags().row_contiguous) {
       return arr;
@@ -482,23 +533,29 @@ void fast::AffineQuantize::eval_cpu(
   auto& out = outputs[0];
   out.set_data(allocator::malloc_or_wait(out.nbytes()));

-  auto& scales =
-      compute_scale_bias ? outputs[1] : const_cast<array&>(inputs[1]);
-  auto& biases =
-      compute_scale_bias ? outputs[2] : const_cast<array&>(inputs[2]);
-  if (compute_scale_bias) {
+  auto& scales = outputs[1];
+  auto& biases = outputs[2];
   scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
   biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
-  }
   if (w.dtype() == float16) {
-    quantize<float16_t>(
-        w, out, scales, biases, bits_, group_size_, compute_scale_bias);
+    if (is_power_of_2(bits_)) {
+      quantize<float16_t, uint32_t>(w, out, scales, biases, bits_, group_size_);
+    } else {
+      quantize<float16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
+    }
   } else if (w.dtype() == bfloat16) {
-    quantize<bfloat16_t>(
-        w, out, scales, biases, bits_, group_size_, compute_scale_bias);
+    if (is_power_of_2(bits_)) {
+      quantize<bfloat16_t, uint32_t>(
+          w, out, scales, biases, bits_, group_size_);
+    } else {
+      quantize<bfloat16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
+    }
   } else if (w.dtype() == float32) {
-    quantize<float>(
-        w, out, scales, biases, bits_, group_size_, compute_scale_bias);
+    if (is_power_of_2(bits_)) {
+      quantize<float, uint32_t>(w, out, scales, biases, bits_, group_size_);
+    } else {
+      quantize<float, uint8_t>(w, out, scales, biases, bits_, group_size_);
+    }
   } else {
     throw std::runtime_error(
         "[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
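Note: is_power_of_2 is presumably supplied by the newly included mlx/utils.h; the Metal kernels further down spell the same test out inline as (bits & (bits - 1)) == 0. A minimal stand-in, for reference only:

// Minimal equivalent of the predicate used above (assumption: this is what the
// mlx/utils.h helper does); it decides whether the packed output is stored as
// uint32_t words (2/4/8 bits) or as 3-byte uint8_t groups (3/6 bits).
constexpr bool is_power_of_2_sketch(int n) {
  return n > 0 && (n & (n - 1)) == 0;
}

static_assert(is_power_of_2_sketch(2) && is_power_of_2_sketch(4) &&
                  is_power_of_2_sketch(8),
              "2/4/8-bit packs stay in uint32 words");
static_assert(!is_power_of_2_sketch(3) && !is_power_of_2_sketch(6),
              "3/6-bit packs move to 3-byte groups of uint8");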
@@ -13,8 +13,8 @@ MLX_MTL_CONST int QUAD_SIZE = 4;
 template <typename T, typename U, int values_per_thread, int bits>
 inline U load_vector(const device T* x, thread U* x_thread) {
   static_assert(
-      bits == 2 || bits == 4 || bits == 8,
-      "Template undefined for bits not in {2, 4, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 6, 8}");

   U sum = 0;

@@ -28,6 +28,21 @@ inline U load_vector(const device T* x, thread U* x_thread) {
     }
   }

+  else if (bits == 3) {
+    for (int i = 0; i < values_per_thread; i += 8) {
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
+          x[i + 6] + x[i + 7];
+      x_thread[i] = x[i];
+      x_thread[i + 1] = x[i + 1] / 8.0f;
+      x_thread[i + 2] = x[i + 2] / 64.0f;
+      x_thread[i + 3] = x[i + 3] / 2.0f;
+      x_thread[i + 4] = x[i + 4] / 16.0f;
+      x_thread[i + 5] = x[i + 5] / 128.0f;
+      x_thread[i + 6] = x[i + 6] / 4.0f;
+      x_thread[i + 7] = x[i + 7] / 32.0f;
+    }
+  }
+
   else if (bits == 4) {
     for (int i = 0; i < values_per_thread; i += 4) {
       sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
@@ -38,6 +53,16 @@ inline U load_vector(const device T* x, thread U* x_thread) {
     }
   }

+  else if (bits == 6) {
+    for (int i = 0; i < values_per_thread; i += 4) {
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
+      x_thread[i] = x[i];
+      x_thread[i + 1] = x[i + 1] / 64.0f;
+      x_thread[i + 2] = x[i + 2] / 16.0f;
+      x_thread[i + 3] = x[i + 3] / 4.0f;
+    }
+  }
+
   else if (bits == 8) {
     for (int i = 0; i < values_per_thread; i++) {
       sum += x[i];
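Note: the divisors in the new 3-bit and 6-bit branches are 2^(bit offset of value k within its byte), that is 2^((k * bits) % 8). Pre-dividing x_thread[k] this way lets qdot below multiply by the raw masked byte without shifting it first. A small constexpr sketch of that relationship (illustration only, assumes a C++14 host compiler):

// Sketch (not from the commit): divisor applied to x_thread[k] is
// 2^((k * bits) % 8), the bit offset of packed value k within its byte.
constexpr int kDiv3[8] = {1, 8, 64, 2, 16, 128, 4, 32}; // 3-bit branch above
constexpr int kDiv6[4] = {1, 64, 16, 4};                // 6-bit branch above

constexpr bool divisors_match(const int* div, int n, int bits) {
  for (int k = 0; k < n; k++) {
    if (div[k] != (1 << ((k * bits) % 8))) {
      return false;
    }
  }
  return true;
}
static_assert(divisors_match(kDiv3, 8, 3), "3-bit divisors are 2^(bit offset)");
static_assert(divisors_match(kDiv6, 4, 6), "6-bit divisors are 2^(bit offset)");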
@@ -51,8 +76,8 @@ inline U load_vector(const device T* x, thread U* x_thread) {
 template <typename T, typename U, int values_per_thread, int bits>
 inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
   static_assert(
-      bits == 2 || bits == 4 || bits == 8,
-      "Template undefined for bits not in {2, 4, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 6, 8}");

   U sum = 0;

@@ -64,8 +89,21 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
       x_thread[i + 2] = x[i + 2] / 16.0f;
       x_thread[i + 3] = x[i + 3] / 64.0f;
     }
-    for (int i = N; i < values_per_thread; i++) {
-      x_thread[i] = 0;
+  }
+
+  else if (bits == 3) {
+    for (int i = 0; i < N; i += 8) {
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
+          x[i + 6] + x[i + 7];
+
+      x_thread[i] = x[i];
+      x_thread[i + 1] = x[i + 1] / 8.0f;
+      x_thread[i + 2] = x[i + 2] / 64.0f;
+      x_thread[i + 3] = x[i + 3] / 2.0f;
+      x_thread[i + 4] = x[i + 4] / 16.0f;
+      x_thread[i + 5] = x[i + 5] / 128.0f;
+      x_thread[i + 6] = x[i + 6] / 4.0f;
+      x_thread[i + 7] = x[i + 7] / 32.0f;
     }
   }

@@ -77,8 +115,15 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
       x_thread[i + 2] = x[i + 2] / 256.0f;
       x_thread[i + 3] = x[i + 3] / 4096.0f;
     }
-    for (int i = N; i < values_per_thread; i++) {
-      x_thread[i] = 0;
+  }
+
+  else if (bits == 6) {
+    for (int i = 0; i < N; i += 4) {
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
+      x_thread[i] = x[i];
+      x_thread[i + 1] = x[i + 1] / 64.0f;
+      x_thread[i + 2] = x[i + 2] / 16.0f;
+      x_thread[i + 3] = x[i + 3] / 4.0f;
     }
   }

@@ -87,10 +132,11 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
       sum += x[i];
       x_thread[i] = x[i];
     }
+  }
+
   for (int i = N; i < values_per_thread; i++) {
     x_thread[i] = 0;
   }
-  }

   return sum;
 }
@@ -103,8 +149,8 @@ inline U qdot(
     U bias,
     U sum) {
   static_assert(
-      bits == 2 || bits == 4 || bits == 8,
-      "Template undefined for bits not in {2, 4, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 6, 8}");

   U accum = 0;

@@ -118,6 +164,26 @@ inline U qdot(
     }
   }

+  else if (bits == 3) {
+    for (int i = 0; i < (values_per_thread / 8); i++) {
+      x_thread += 8 * i;
+      w += 3 * i;
+
+      accum += (w[0] & 0x07) * x_thread[0];
+      accum += (w[0] & 0x38) * x_thread[1];
+      accum += (w[0] & 0xc0) * x_thread[2];
+      accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
+
+      accum += (w[1] & 0x0e) * x_thread[3];
+      accum += (w[1] & 0x70) * x_thread[4];
+      accum += (w[1] & 0x80) * x_thread[5];
+      accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
+
+      accum += (w[2] & 0x1c) * x_thread[6];
+      accum += (w[2] & 0xe0) * x_thread[7];
+    }
+  }
+
   else if (bits == 4) {
     const device uint16_t* ws = (const device uint16_t*)w;
     for (int i = 0; i < (values_per_thread / 4); i++) {
@@ -129,6 +195,23 @@ inline U qdot(
     }
   }

+  else if (bits == 6) {
+    for (int i = 0; i < (values_per_thread / 4); i++) {
+      x_thread += 4 * i;
+      w += 3 * i;
+
+      accum += (w[0] & 0x3f) * x_thread[0];
+
+      accum += (w[0] & 0xc0) * x_thread[1];
+      accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
+
+      accum += (w[1] & 0xf0) * x_thread[2];
+      accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
+
+      accum += (w[2] & 0xfc) * x_thread[3];
+    }
+  }
+
   else if (bits == 8) {
     for (int i = 0; i < values_per_thread; i++) {
       accum += x_thread[i] * w[i];
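Note: together with the pre-scaled activations from load_vector, the masked multiplies above reconstruct the exact dot product. (w[b] & mask) contributes the low bits of a value shifted left by its bit offset, the earlier division by 2^offset cancels that shift, and the * 256.0f terms pick up the bits that spill into the next byte. A standalone host-side check of the 3-bit case, my own sketch rather than MLX code:

// Host-side sketch: verify the 3-bit masked-multiply identity used in qdot.
#include <cassert>
#include <cstdint>

int main() {
  uint8_t vals[8] = {5, 0, 7, 3, 1, 6, 2, 4}; // arbitrary 3-bit values
  // Pack LSB-first into 3 bytes, matching the layout produced by quantize().
  uint32_t packed = 0;
  for (int k = 0; k < 8; k++)
    packed |= uint32_t(vals[k]) << (3 * k);
  uint8_t w[3] = {uint8_t(packed & 0xff),
                  uint8_t((packed >> 8) & 0xff),
                  uint8_t((packed >> 16) & 0xff)};

  float x[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  // Pre-scaled activations, divisor = 2^(bit offset of lane k within its byte).
  float xs[8] = {x[0], x[1] / 8, x[2] / 64, x[3] / 2,
                 x[4] / 16, x[5] / 128, x[6] / 4, x[7] / 32};

  // Masked multiplies exactly as in the 3-bit qdot branch above.
  float accum = 0;
  accum += (w[0] & 0x07) * xs[0];
  accum += (w[0] & 0x38) * xs[1];
  accum += (w[0] & 0xc0) * xs[2];
  accum += (w[1] & 0x01) * (xs[2] * 256.0f);
  accum += (w[1] & 0x0e) * xs[3];
  accum += (w[1] & 0x70) * xs[4];
  accum += (w[1] & 0x80) * xs[5];
  accum += (w[2] & 0x03) * (xs[5] * 256.0f);
  accum += (w[2] & 0x1c) * xs[6];
  accum += (w[2] & 0xe0) * xs[7];

  float expected = 0;
  for (int k = 0; k < 8; k++)
    expected += float(vals[k]) * x[k];
  assert(accum == expected); // exact: everything here is a small dyadic float
  return 0;
}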
@@ -147,8 +230,8 @@ inline U qdot_safe(
     U sum,
     int N) {
   static_assert(
-      bits == 2 || bits == 4 || bits == 8,
-      "Template undefined for bits not in {2, 4, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 6, 8}");

   U accum = 0;

@@ -162,6 +245,26 @@ inline U qdot_safe(
     }
   }

+  else if (bits == 3) {
+    for (int i = 0; i < (N / 8); i++) {
+      x_thread += 8 * i;
+      w += 3 * i;
+
+      accum += (w[0] & 0x07) * x_thread[0];
+      accum += (w[0] & 0x38) * x_thread[1];
+      accum += (w[0] & 0xc0) * x_thread[2];
+      accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
+
+      accum += (w[1] & 0x0e) * x_thread[3];
+      accum += (w[1] & 0x70) * x_thread[4];
+      accum += (w[1] & 0x80) * x_thread[5];
+      accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
+
+      accum += (w[2] & 0x1c) * x_thread[6];
+      accum += (w[2] & 0xe0) * x_thread[7];
+    }
+  }
+
   else if (bits == 4) {
     const device uint16_t* ws = (const device uint16_t*)w;
     for (int i = 0; i < (N / 4); i++) {
@@ -173,6 +276,23 @@ inline U qdot_safe(
     }
   }

+  else if (bits == 6) {
+    for (int i = 0; i < (N / 4); i++) {
+      x_thread += 4 * i;
+      w += 3 * i;
+
+      accum += (w[0] & 0x3f) * x_thread[0];
+
+      accum += (w[0] & 0xc0) * x_thread[1];
+      accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
+
+      accum += (w[1] & 0xf0) * x_thread[2];
+      accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
+
+      accum += (w[2] & 0xfc) * x_thread[3];
+    }
+  }
+
   else if (bits == 8) {
     for (int i = 0; i < N; i++) {
       accum += x_thread[i] * w[i];
@@ -186,8 +306,8 @@ template <typename U, int values_per_thread, int bits>
 inline void
 qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
   static_assert(
-      bits == 2 || bits == 4 || bits == 8,
-      "Template undefined for bits not in {2, 4, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 6, 8}");

   if (bits == 2) {
     U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
@@ -199,12 +319,45 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
     }
   }

+  else if (bits == 3) {
+    for (int i = 0; i < (values_per_thread / 8); i++) {
+      uint8_t w0 = w[3 * i];
+      uint8_t w1 = w[3 * i + 1];
+      uint8_t w2 = w[3 * i + 2];
+
+      result[8 * i] += x * ((w0 & 0x7) * scale + bias);
+      result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias);
+      result[8 * i + 2] +=
+          x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias);
+      result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias);
+      result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias);
+      result[8 * i + 5] +=
+          x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias);
+      result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias);
+      result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias);
+    }
+  }
+
   else if (bits == 4) {
     U s[2] = {scale, scale / 16.0f};
     for (int i = 0; i < (values_per_thread / 2); i++) {
       result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
       result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
     }
+
+  } else if (bits == 6) {
+    for (int i = 0; i < (values_per_thread / 4); i++) {
+      uint8_t w0 = w[3 * i];
+      uint8_t w1 = w[3 * i + 1];
+      uint8_t w2 = w[3 * i + 2];
+
+      result[4 * i] += x * ((w0 & 0x3f) * scale + bias);
+      result[4 * i + 1] +=
+          x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias);
+      result[4 * i + 2] +=
+          x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias);
+      result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias);
+    }
   }

   else if (bits == 8) {
@@ -218,8 +371,8 @@ template <typename U, int N, int bits>
 inline void
 dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
   static_assert(
-      bits == 2 || bits == 4 || bits == 8,
-      "Template undefined for bits not in {2, 4, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 6, 8}");

   if (bits == 2) {
     U s[4] = {
@@ -235,6 +388,22 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
     }
   }

+  else if (bits == 3) {
+    for (int i = 0; i < (N / 8); i++) {
+      w_local += 8 * i;
+      w += 3 * i;
+
+      w_local[0] = (w[0] & 0x7) * scale + bias;
+      w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias;
+      w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
+      w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias;
+      w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias;
+      w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
+      w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
+      w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
+    }
+  }
+
   else if (bits == 4) {
     U s[2] = {scale, scale / static_cast<U>(16.0f)};
     for (int i = 0; i < (N / 2); i++) {
@@ -243,6 +412,18 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
     }
   }

+  else if (bits == 6) {
+    for (int i = 0; i < (N / 4); i++) {
+      w_local += 4 * i;
+      w += 3 * i;
+
+      w_local[0] = (w[0] & 0x3f) * scale + bias;
+      w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
+      w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
+      w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
+    }
+  }
+
   else if (bits == 8) {
     for (int i = 0; i < N; i++) {
       w_local[i] = scale * w[i] + bias;
@@ -267,10 +448,11 @@ struct QuantizedBlockLoader {
       group_size % BCOLS == 0,
       "The group size should be divisible by the columns");
   static_assert(
-      bits == 2 || bits == 4 || bits == 8,
-      "Template undefined for bits not in {2, 4, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 6, 8}");

-  MLX_MTL_CONST short pack_factor = 32 / bits;
+  MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
+  MLX_MTL_CONST short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
   MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
   MLX_MTL_CONST short n_reads =
       (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
@@ -286,12 +468,12 @@ struct QuantizedBlockLoader {
   const short bj;

   threadgroup T* dst;
-  const device uint32_t* src;
+  const device uint8_t* src;
   const device T* scales;
   const device T* biases;

   QuantizedBlockLoader(
-      const device uint32_t* src_,
+      const device uint8_t* src_,
       const device T* scales_,
       const device T* biases_,
       const int src_ld_,
@@ -300,14 +482,16 @@ struct QuantizedBlockLoader {
       ushort simd_lane_id [[thread_index_in_simdgroup]])
       : src_ld(src_ld_),
         tile_stride(
-            reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
+            reduction_dim ? BCOLS_PACKED * bytes_per_pack
+                          : BROWS * src_ld * bytes_per_pack / pack_factor),
         group_step_cnt(0),
         group_stride(BROWS * src_ld / group_size),
         thread_idx(simd_group_id * 32 + simd_lane_id),
         bi(n_reads * thread_idx / BCOLS_PACKED),
         bj((n_reads * thread_idx) % BCOLS_PACKED),
         dst(dst_ + bi * dst_ld + bj * pack_factor),
-        src(src_ + bi * src_ld / pack_factor + bj),
+        src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
+            bj * bytes_per_pack),
         scales(scales_ + bi * src_ld / group_size),
         biases(biases_ + bi * src_ld / group_size) {}

@@ -320,7 +504,7 @@ struct QuantizedBlockLoader {
     T bias = *biases;
     for (int i = 0; i < n_reads; i++) {
       dequantize<T, pack_factor, bits>(
-          (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
+          src + i * bytes_per_pack, scale, bias, dst + i * pack_factor);
     }
   }

@@ -347,7 +531,10 @@ struct QuantizedBlockLoader {
     T bias = *biases;
     for (int i = 0; i < n_reads; i++) {
       dequantize<T, pack_factor, bits>(
-          (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
+          (device uint8_t*)(src + i * bytes_per_pack),
+          scale,
+          bias,
+          dst + i * pack_factor);
     }
   }

@@ -410,8 +597,7 @@ METAL_FUNC void qmv_quad_impl(
   U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

   for (int row = 0; row < results_per_quadgroup; row++) {
-    const device uint8_t* wl =
-        (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
+    auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
     const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
     const device T* bl = biases + row * in_vec_size_g * quads_per_simd;

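Note: with src now a uint8_t pointer, every stride in the loader is expressed in bytes, so one row of src_ld packed weights occupies src_ld * bytes_per_pack / pack_factor bytes. A small arithmetic sketch with an assumed leading dimension of 128 (illustration only, not loader code):

// Sketch: byte length of one packed row, using the loader's constants above.
constexpr int loader_pack_factor(int bits) {
  return bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
}
constexpr int loader_bytes_per_pack(int bits) {
  return (bits == 3 || bits == 6) ? 3 : 1;
}
constexpr int row_bytes(int bits, int src_ld) {
  return src_ld * loader_bytes_per_pack(bits) / loader_pack_factor(bits);
}
static_assert(row_bytes(4, 128) == 64, "128 values * 4 bits = 64 bytes");
static_assert(row_bytes(3, 128) == 48, "128 values * 3 bits = 48 bytes");
static_assert(row_bytes(6, 128) == 96, "128 values * 6 bits = 96 bytes");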
@@ -442,25 +628,34 @@ METAL_FUNC void qmv_fast_impl(
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-  constexpr int packs_per_thread = bits > 2 ? 2 : 1;
+  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
+  constexpr int packs_per_thread = bits == 2 ? 1 : 2;
   constexpr int num_simdgroups = 2;
   constexpr int results_per_simdgroup = 4;
-  constexpr int pack_factor = 32 / bits;
+  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
+  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
   constexpr int values_per_thread = pack_factor * packs_per_thread;
   constexpr int block_size = values_per_thread * SIMD_SIZE;
   constexpr int scale_step_per_thread = group_size / values_per_thread;

+  // When bits is a power of two, read 1 uint32_t at a time
+  // When bits is 3 or 6, read 3 uint8_ts at a time
+  using W_T =
+      typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
+  const device W_T* ws = (const device W_T*)w;
+
   typedef float U;

   thread U x_thread[values_per_thread];
   thread U result[results_per_simdgroup] = {0};

   // Adjust positions
-  const int in_vec_size_w = in_vec_size / pack_factor;
+  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
   const int in_vec_size_g = in_vec_size / group_size;
   const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
       simd_gid * results_per_simdgroup;
-  w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
+  ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
   scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
   biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
   x += tid.y * in_vec_size + simd_lid * values_per_thread;
@@ -470,8 +665,7 @@ METAL_FUNC void qmv_fast_impl(
     U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

     for (int row = 0; row < results_per_simdgroup; row++) {
-      const device uint8_t* wl =
-          (const device uint8_t*)(w + row * in_vec_size_w);
+      auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
       const device T* sl = scales + row * in_vec_size_g;
       const device T* bl = biases + row * in_vec_size_g;

@@ -480,7 +674,7 @@ METAL_FUNC void qmv_fast_impl(
       result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
     }

-    w += block_size / pack_factor;
+    ws += block_size * bytes_per_pack / pack_factor;
     scales += block_size / group_size;
     biases += block_size / group_size;
     x += block_size;
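Note: ConditionalType is not defined in this diff; it presumably lives in the Metal kernel utility headers and selects between two types the way std::conditional does on the host. A minimal host-side equivalent under that assumption:

// Host-side sketch of the ConditionalType usage above (assumption: the Metal
// helper behaves like std::conditional).
#include <cstdint>

template <bool condition, typename T, typename U>
struct ConditionalTypeSketch {
  using type = U; // picked when condition is false
};

template <typename T, typename U>
struct ConditionalTypeSketch<true, T, U> {
  using type = T; // picked when condition is true
};

// Power-of-two widths keep reading packed weights one uint32_t at a time;
// 3- and 6-bit widths read 3-byte groups of uint8_t instead.
using PowerOf2W = ConditionalTypeSketch<true, uint32_t, uint8_t>::type;
using OddW = ConditionalTypeSketch<false, uint32_t, uint8_t>::type;
static_assert(sizeof(PowerOf2W) == 4 && sizeof(OddW) == 1, "");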
@@ -506,21 +700,29 @@ METAL_FUNC void qmv_impl(
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
+  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
   constexpr int num_simdgroups = 2;
   constexpr int results_per_simdgroup = 4;
   constexpr int packs_per_thread = 1;
-  constexpr int pack_factor = 32 / bits;
+  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
+  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
   constexpr int values_per_thread = pack_factor * packs_per_thread;
   constexpr int block_size = values_per_thread * SIMD_SIZE;
   constexpr int scale_step_per_thread = group_size / values_per_thread;

+  // When bits is a power of two, read 1 uint32_t at a time
+  // When bits is 3 or 6, read 3 uint8_ts at a time
+  using W_T =
+      typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
+  const device W_T* ws = (const device W_T*)w;
+
   typedef float U;

   thread U x_thread[values_per_thread];
   thread U result[results_per_simdgroup] = {0};

   // Adjust positions
-  const int in_vec_size_w = in_vec_size / pack_factor;
+  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
   const int in_vec_size_g = in_vec_size / group_size;
   const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
       simd_gid * results_per_simdgroup;
@@ -533,7 +735,8 @@ METAL_FUNC void qmv_impl(
   // In this case we need to properly guard all our reads because there isn't
   // even 1 tile in the matrix
   if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
-    w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
+    ws +=
+        out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
     scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
     biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
     x += tid.y * in_vec_size + simd_lid * values_per_thread;
@@ -544,8 +747,7 @@ METAL_FUNC void qmv_impl(
       U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

      for (int row = 0; out_row + row < out_vec_size; row++) {
-        const device uint8_t* wl =
-            (const device uint8_t*)(w + row * in_vec_size_w);
+        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
         const device T* sl = scales + row * in_vec_size_g;
         const device T* bl = biases + row * in_vec_size_g;

@@ -555,7 +757,7 @@ METAL_FUNC void qmv_impl(
             qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
       }

-      w += block_size / pack_factor;
+      ws += block_size * bytes_per_pack / pack_factor;
       scales += block_size / group_size;
       biases += block_size / group_size;
       x += block_size;
@@ -569,8 +771,7 @@ METAL_FUNC void qmv_impl(
         x, x_thread, remaining);

     for (int row = 0; out_row + row < out_vec_size; row++) {
-      const device uint8_t* wl =
-          (const device uint8_t*)(w + row * in_vec_size_w);
+      auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
       const device T* sl = scales + row * in_vec_size_g;
       const device T* bl = biases + row * in_vec_size_g;

@@ -591,7 +792,8 @@ METAL_FUNC void qmv_impl(

   // In this case the last tile is moved back to redo some output values
   else {
-    w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread;
+    ws += used_out_row * in_vec_size_w +
+        simd_lid * packs_per_thread * bytes_per_pack;
     scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
     biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
     x += tid.y * in_vec_size + simd_lid * values_per_thread;
@@ -602,8 +804,7 @@ METAL_FUNC void qmv_impl(
      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

       for (int row = 0; row < results_per_simdgroup; row++) {
-        const device uint8_t* wl =
-            (const device uint8_t*)(w + row * in_vec_size_w);
+        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
         const device T* sl = scales + row * in_vec_size_g;
         const device T* bl = biases + row * in_vec_size_g;

@@ -613,7 +814,7 @@ METAL_FUNC void qmv_impl(
             qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
       }

-      w += block_size / pack_factor;
+      ws += block_size * bytes_per_pack / pack_factor;
       scales += block_size / group_size;
       biases += block_size / group_size;
       x += block_size;
@@ -627,8 +828,7 @@ METAL_FUNC void qmv_impl(
         x, x_thread, remaining);

     for (int row = 0; row < results_per_simdgroup; row++) {
-      const device uint8_t* wl =
-          (const device uint8_t*)(w + row * in_vec_size_w);
+      auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
       const device T* sl = scales + row * in_vec_size_g;
       const device T* bl = biases + row * in_vec_size_g;

@@ -659,14 +859,22 @@ METAL_FUNC void qvm_impl(
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
+  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
   constexpr int num_simdgroups = 2;
-  constexpr int pack_factor = 32 / bits;
+  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
+  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
   constexpr int tn = 32 / pack_factor;
-  constexpr int blocksize = SIMD_SIZE;
+  constexpr int block_size = SIMD_SIZE;
+
+  // When bits is a power of two, read 1 uint32_t at a time
+  // When bits is 3 or 6, read 3 uint8_ts at a time
+  using W_T =
+      typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
+  const device W_T* ws = (const device W_T*)w;
+
   typedef float U;
   typedef struct {
-    uint32_t wi[tn];
+    W_T wi[tn * bytes_per_pack];
   } vec_w;

   thread vec_w w_local;
@@ -676,11 +884,10 @@ METAL_FUNC void qvm_impl(
   thread U x_local = 0;

   // Adjust positions
-  const int out_vec_size_w = out_vec_size / pack_factor;
+  const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
   const int out_vec_size_g = out_vec_size / group_size;
-  int out_col =
-      tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn;
-  w += out_col / pack_factor + simd_lid * out_vec_size_w;
+  int out_col = pack_factor * tn * (tid.x * num_simdgroups + simd_gid);
+  ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w;
   scales += out_col / group_size + simd_lid * out_vec_size_g;
   biases += out_col / group_size + simd_lid * out_vec_size_g;
   x += tid.y * in_vec_size + simd_lid;
@@ -690,43 +897,42 @@ METAL_FUNC void qvm_impl(
     return;
   }

-  // Loop over in_vec in blocks of blocksize
-  int remaining = in_vec_size % blocksize;
+  // Loop over in_vec in blocks of block_size
+  int remaining = in_vec_size % block_size;
   if (remaining == 0) {
-    for (int i = 0; i < in_vec_size; i += blocksize) {
+    for (int i = 0; i < in_vec_size; i += block_size) {
       x_local = *x;
       scale = *scales;
       bias = *biases;
-      w_local = *((device vec_w*)w);
+      w_local = *((device vec_w*)ws);

       qouter<U, tn * pack_factor, bits>(
           (thread uint8_t*)&w_local, x_local, scale, bias, result);

-      x += blocksize;
-      scales += blocksize * out_vec_size_g;
-      biases += blocksize * out_vec_size_g;
-      w += blocksize * out_vec_size_w;
+      x += block_size;
+      scales += block_size * out_vec_size_g;
+      biases += block_size * out_vec_size_g;
+      ws += block_size * out_vec_size_w;
     }
   } else {
-    for (int i = blocksize; i < in_vec_size; i += blocksize) {
+    for (int i = block_size; i < in_vec_size; i += block_size) {
       x_local = *x;
       scale = *scales;
       bias = *biases;
-      w_local = *((device vec_w*)w);
+      w_local = *((device vec_w*)ws);

       qouter<U, tn * pack_factor, bits>(
           (thread uint8_t*)&w_local, x_local, scale, bias, result);

-      x += blocksize;
-      scales += blocksize * out_vec_size_g;
-      biases += blocksize * out_vec_size_g;
-      w += blocksize * out_vec_size_w;
+      x += block_size;
+      scales += block_size * out_vec_size_g;
+      biases += block_size * out_vec_size_g;
+      ws += block_size * out_vec_size_w;
     }
     if (static_cast<int>(simd_lid) < remaining) {
       x_local = *x;
       scale = *scales;
       bias = *biases;
-      w_local = *((device vec_w*)w);
+      w_local = *((device vec_w*)ws);
     } else {
       x_local = 0;
       scale = 0;
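Note: sizing the struct as W_T wi[tn * bytes_per_pack] keeps one vec_w load covering 32 output columns for every bit width: tn uint32 words for 2/4/8 bits and tn * 3 bytes for 3/6 bits, which is 32 * bits bits either way. A quick constexpr check (my own sketch, assumes a C++14 host compiler):

// Sketch: size in bits of one vec_w load, per bit width, in qvm_impl terms.
constexpr int qvm_pack_factor(int bits) {
  return bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
}
constexpr int qvm_bytes_per_pack(int bits) {
  return (bits == 3 || bits == 6) ? 3 : 1;
}
constexpr int qvm_elem_bits(int bits) {
  // W_T is uint32_t for power-of-two widths, uint8_t otherwise.
  return qvm_bytes_per_pack(bits) == 1 ? 32 : 8;
}
constexpr int vec_w_bits(int bits) {
  int tn = 32 / qvm_pack_factor(bits);
  return tn * qvm_bytes_per_pack(bits) * qvm_elem_bits(bits);
}
static_assert(vec_w_bits(2) == 32 * 2, "");
static_assert(vec_w_bits(3) == 32 * 3, "");
static_assert(vec_w_bits(4) == 32 * 4, "");
static_assert(vec_w_bits(6) == 32 * 6, "");
static_assert(vec_w_bits(8) == 32 * 8, "");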
@@ -781,8 +987,9 @@ METAL_FUNC void qmm_t_impl(

   constexpr int WM = 2;
   constexpr int WN = 2;
-  constexpr int pack_factor = 32 / bits;
+  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
   constexpr int BK_padded = (BK + 16 / sizeof(T));
+  constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;

   // Instantiate the appropriate BlockMMA and Loader
   using mma_t = mlx::steel::
@@ -800,13 +1007,15 @@ METAL_FUNC void qmm_t_impl(
       bits>;

   // Set the block
-  const int K_w = K / pack_factor;
+  const int K_w = K * bytes_per_pack / pack_factor;
   const int K_g = K / group_size;
   const int y_row = tid.y * BM;
   const int y_col = tid.x * BN;

+  auto wl = (const device uint8_t*)w;

   x += y_row * K;
-  w += y_col * K_w;
+  wl += y_col * K_w;
   scales += y_col * K_g;
   biases += y_col * K_g;
   y += y_row * N + y_col;
@@ -815,7 +1024,7 @@ METAL_FUNC void qmm_t_impl(
   const short num_els = min(BM, M - y_row);
   const short num_outs = min(BN, N - y_col);
   loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
-  loader_w_t loader_w(w, scales, biases, K, Ws, simd_gid, simd_lid);
+  loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);
   mma_t mma_op(simd_gid, simd_lid);

   if (num_els < BM) {
@@ -857,6 +1066,7 @@ METAL_FUNC void qmm_t_impl(
       loader_x.load_unsafe();
       loader_w.load_unsafe();
       threadgroup_barrier(mem_flags::mem_threadgroup);

       mma_op.mma(Xs, Ws);
       loader_x.next();
       loader_w.next();
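
Note (not part of the diff): the new pack_factor / bytes_per_pack constants above encode the packing geometry for the added bit widths. 3-bit packs 8 values into 3 bytes and 6-bit packs 4 values into 3 bytes, while the power-of-two widths keep packing into single bytes. A minimal host-side C++ sketch that checks this arithmetic, using the same formulas as the kernel:

// Sketch only: mirrors the constants introduced in qmm_t_impl above.
#include <cstdio>
#include <initializer_list>

constexpr int pack_factor(int bits) {
  return bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
}
constexpr int bytes_per_pack(int bits) {
  return (bits == 3 || bits == 6) ? 3 : 1;
}

int main() {
  for (int bits : {2, 3, 4, 6, 8}) {
    // For every supported width, one pack holds pack_factor values in exactly
    // bytes_per_pack bytes: bits * pack_factor == 8 * bytes_per_pack.
    std::printf(
        "bits=%d  values/pack=%d  bytes/pack=%d  check=%d\n",
        bits,
        pack_factor(bits),
        bytes_per_pack(bits),
        bits * pack_factor(bits) == 8 * bytes_per_pack(bits));
  }
  return 0;
}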
@@ -902,9 +1112,11 @@ METAL_FUNC void qmm_n_impl(

   constexpr int WM = 2;
   constexpr int WN = 2;
-  constexpr int pack_factor = 32 / bits;
+  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
   constexpr int BK_padded = (BK + 16 / sizeof(T));
   constexpr int BN_padded = (BN + 16 / sizeof(T));
+  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
+  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;

   // Instantiate the appropriate BlockMMA and Loader
   using mma_t = mlx::steel::
@@ -921,11 +1133,13 @@ METAL_FUNC void qmm_n_impl(
       group_size,
       bits>;

+  auto wl = (const device uint8_t*)w;

   // Set the block
   const int y_row = tid.y * BM;
   const int y_col = tid.x * BN;
   x += y_row * K;
-  w += y_col / pack_factor;
+  wl += y_col * bytes_per_pack / pack_factor;
   scales += y_col / group_size;
   biases += y_col / group_size;
   y += y_row * N + y_col;
@@ -933,7 +1147,7 @@ METAL_FUNC void qmm_n_impl(
   // Make the x loader and mma operation
   const short num_els = min(BM, M - y_row);
   loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
-  loader_w_t loader_w(w, scales, biases, N, Ws, simd_gid, simd_lid);
+  loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid);
   mma_t mma_op(simd_gid, simd_lid);

   if (num_els < BM) {
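
Note (not part of the diff): qmm_n_impl now distinguishes the two packing schemes with the power-of-two test (bits & (bits - 1)) == 0 and addresses the packed weights through a byte pointer, so the column offset becomes y_col * bytes_per_pack / pack_factor bytes. A small sketch with illustrative values only:

// Sketch only: the power-of-two test and byte addressing used above.
#include <cstdio>

int main() {
  const int y_col = 32; // illustrative starting output column of a block
  struct Cfg { int bits, pack_factor; };
  const Cfg cfgs[] = {{2, 4}, {3, 8}, {4, 2}, {6, 4}, {8, 1}};
  for (const auto& c : cfgs) {
    const bool power_of_2_bits = (c.bits & (c.bits - 1)) == 0;
    const int bytes_per_pack = power_of_2_bits ? 1 : 3;
    // 3 and 6 are the only non-power-of-two widths, so only they use 3-byte
    // packs; e.g. 32 columns at 3 bits start 12 bytes into the packed row.
    std::printf(
        "bits=%d  power_of_2=%d  byte offset of column %d = %d\n",
        c.bits,
        int(power_of_2_bits),
        y_col,
        y_col * bytes_per_pack / c.pack_factor);
  }
  return 0;
}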
@@ -1805,13 +2019,14 @@ template <typename T, const int group_size, const int bits>
     uint2 grid_dim [[threads_per_grid]]) {
   constexpr T eps = T(1e-7);
   constexpr int simd_size = 32;
-  constexpr int uint8_bits = 8;
   constexpr T n_bins = (1 << bits) - 1;
-  constexpr int packs_per_int = uint8_bits / bits;
+  constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
   constexpr int values_per_reduce = group_size / simd_size;
   constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
   constexpr int writes_per_pack =
       writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
+  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
+  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;

   static_assert(
       group_size % simd_size == 0,
@@ -1819,7 +2034,9 @@ template <typename T, const int group_size, const int bits>

   size_t offset = index.x + grid_dim.x * size_t(index.y);
   size_t in_index = offset * values_per_reduce;
-  size_t out_index = offset * writes_per_pack;
+  size_t out_index = power_of_2_bits
+      ? offset * writes_per_pack
+      : offset * bytes_per_pack / writes_per_reduce;

   T w_thread[values_per_reduce];
   T w_min = Limits<T>::max;
@@ -1852,7 +2069,11 @@ template <typename T, const int group_size, const int bits>
     biases[gindex] = bias;
   }

-  uint8_t output = 0;
+  // We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t
+  using OutT =
+      typename ConditionalType<power_of_2_bits, uint8_t, uint32_t>::type;
+  OutT output = 0;

 #pragma clang loop unroll(full)
   for (int i = 0; i < values_per_reduce; i++) {
     uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
@@ -1868,47 +2089,23 @@ template <typename T, const int group_size, const int bits>
       output = 0;
     } else {
 #pragma clang loop unroll(full)
-      for (int j = 0; j < writes_per_reduce - 1; j++) {
-        uint8_t sval = simd_shuffle_down(val, j + 1);
-        output += sval << (bits * (values_per_reduce + j + i));
+      for (int j = 1; j < writes_per_reduce; j++) {
+        uint8_t sval = simd_shuffle_down(val, j);
+        output += sval << (bits * (j * values_per_reduce + i));
       }
     }
   }
+  if (bits == 3 || bits == 6) {
+    if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) {
+      out[out_index] = output & 0xff;
+      out[out_index + 1] = (output & 0xff00) >> 8;
+      out[out_index + 2] = (output & 0xff0000) >> 16;
+    }
+  } else {
   if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
     out[out_index / writes_per_reduce] = output;
   }
+  }

-template <typename T, const int group_size, const int bits>
-[[kernel]] void affine_quantize_scales_biases(
-    const device T* w [[buffer(0)]],
-    const device T* scales [[buffer(1)]],
-    const device T* biases [[buffer(2)]],
-    device uint8_t* out [[buffer(3)]],
-    uint2 index [[thread_position_in_grid]],
-    uint2 grid_dim [[threads_per_grid]]) {
-  constexpr int uint8_bits = 8;
-  constexpr int packs_per_int = uint8_bits / bits;
-  constexpr T n_bins = (1 << bits) - 1;
-
-  size_t offset = index.x + grid_dim.x * size_t(index.y);
-  size_t in_index = offset * packs_per_int;
-  size_t gindex = in_index / group_size;
-
-  T scale = scales[gindex];
-  T bias = biases[gindex];
-
-  uint8_t output = 0;
-#pragma clang loop unroll(full)
-  for (int i = 0; i < packs_per_int; i++) {
-    uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins);
-    if (bits == 8) {
-      output = val;
-    } else {
-      output += val << (bits * i);
-    }
-  }
-  out[offset] = output;
 }

 template <typename T, const int group_size, const int bits>
@@ -1919,16 +2116,37 @@ template <typename T, const int group_size, const int bits>
     device T* out [[buffer(3)]],
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  constexpr int uint8_bits = 8;
-  constexpr int packs_per_int = uint8_bits / bits;
+  constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
+  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
+  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;

   size_t offset = index.x + grid_dim.x * size_t(index.y);
   size_t oindex = offset * packs_per_int;
   size_t gindex = oindex / group_size;
   T scale = scales[gindex];
   T bias = biases[gindex];
-  uint val = w[offset];

+  out += oindex;
+
+  if (bits == 3) {
+    w += offset * bytes_per_pack;
+    out[0] = (w[0] & 0x7) * scale + bias;
+    out[1] = ((w[0] & 0x38) >> 3) * scale + bias;
+    out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
+    out[3] = ((w[1] & 0xe) >> 1) * scale + bias;
+    out[4] = ((w[1] & 0x70) >> 4) * scale + bias;
+    out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
+    out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
+    out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
+
+  } else if (bits == 6) {
+    w += offset * bytes_per_pack;
+    out[0] = (w[0] & 0x3f) * scale + bias;
+    out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
+    out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
+    out[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
+  } else {
+    uint val = w[offset];
 #pragma clang loop unroll(full)
   for (int i = 0; i < packs_per_int; i++) {
     uint8_t d;
@@ -1939,6 +2157,7 @@ template <typename T, const int group_size, const int bits>
     } else if (bits == 8) {
       d = val;
     }
-    out[oindex + i] = scale * d + bias;
+    out[i] = scale * d + bias;
+  }
   }
 }
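
Note (not part of the diff): the 3-bit branch of affine_dequantize above reads value i from bits [3*i, 3*i + 3) of a 24-bit little-endian pack spread over three bytes. A self-contained host-side C++ sketch that round-trips eight 3-bit values through that layout, using the same masks and shifts as the kernel:

// Sketch only: pack eight 3-bit values into 3 bytes, then decode them with
// the expressions from the 3-bit branch above and check the round trip.
#include <cassert>
#include <cstdint>

void pack3(const uint8_t vals[8], uint8_t out[3]) {
  uint32_t acc = 0;
  for (int i = 0; i < 8; ++i) {
    acc |= uint32_t(vals[i] & 0x7) << (3 * i); // value i in bits 3*i .. 3*i+2
  }
  out[0] = acc & 0xff;
  out[1] = (acc >> 8) & 0xff;
  out[2] = (acc >> 16) & 0xff;
}

int main() {
  const uint8_t vals[8] = {0, 7, 3, 5, 1, 6, 2, 4};
  uint8_t packed[3];
  pack3(vals, packed);

  // Decode with the same masks/shifts used by the kernel's 3-bit branch.
  const uint8_t decoded[8] = {
      uint8_t(packed[0] & 0x7),
      uint8_t((packed[0] & 0x38) >> 3),
      uint8_t(((packed[0] & 0xc0) >> 6) + ((packed[1] & 0x1) << 2)),
      uint8_t((packed[1] & 0xe) >> 1),
      uint8_t((packed[1] & 0x70) >> 4),
      uint8_t(((packed[1] & 0x80) >> 7) + ((packed[2] & 0x3) << 1)),
      uint8_t((packed[2] & 0x1c) >> 2),
      uint8_t((packed[2] & 0xe0) >> 5)};
  for (int i = 0; i < 8; ++i) {
    assert(decoded[i] == vals[i]);
  }
  return 0;
}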
@@ -72,7 +72,6 @@

 #define instantiate_quantized_all_single(type, group_size, bits) \
   instantiate_quantized(affine_quantize, type, group_size, bits) \
-  instantiate_quantized(affine_quantize_scales_biases, type, group_size, bits) \
   instantiate_quantized(affine_dequantize, type, group_size, bits) \
   instantiate_quantized(bs_qmv_fast, type, group_size, bits) \
   instantiate_quantized(bs_qmv, type, group_size, bits) \
@@ -116,7 +115,9 @@

 #define instantiate_quantized_all() \
   instantiate_quantized_groups(2) \
+  instantiate_quantized_groups(3) \
   instantiate_quantized_groups(4) \
+  instantiate_quantized_groups(6) \
   instantiate_quantized_groups(8)

 instantiate_quantized_all() // clang-format on
@@ -421,3 +421,14 @@ inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) {
   return complex64_t(
       simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane));
 }
+
+// std::conditional is not included with Metal
+template <bool condition, typename T, typename U>
+struct ConditionalType {
+  using type = U;
+};
+
+template <typename T, typename U>
+struct ConditionalType<true, T, U> {
+  using type = T;
+};
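
Note (not part of the diff): the ConditionalType helper added above is a stand-in for std::conditional, which is not available in Metal. A quick plain-C++ sketch of how it selects a type, matching how the quantize kernel widens its accumulator for the 3/6-bit packs:

// Sketch only: ConditionalType picks T when the condition is true, else U.
#include <cstdint>
#include <type_traits>

template <bool condition, typename T, typename U>
struct ConditionalType {
  using type = U;
};

template <typename T, typename U>
struct ConditionalType<true, T, U> {
  using type = T;
};

static_assert(
    std::is_same<ConditionalType<true, uint8_t, uint32_t>::type, uint8_t>::value,
    "power-of-two widths accumulate into a single byte");
static_assert(
    std::is_same<ConditionalType<false, uint8_t, uint32_t>::type, uint32_t>::value,
    "3- and 6-bit widths accumulate three bytes in a uint32_t");

int main() { return 0; }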
@@ -10,6 +10,7 @@
 #include "mlx/backend/metal/utils.h"
 #include "mlx/fast_primitives.h"
 #include "mlx/primitives.h"
+#include "mlx/utils.h"

 namespace mlx::core {

@@ -298,7 +299,7 @@ void qmm_op(
   bool quad = false;

   if (transpose) {
-    if (B < 6 && (D == 128 || D == 64)) {
+    if (B < 6 && (D == 128 || D == 64) && is_power_of_2(bits)) {
       name += "qmv_quad";
       constexpr int quads_per_simd = 8;
       constexpr int results_per_quadgroup = 8;
@@ -391,8 +392,6 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
 void fast::AffineQuantize::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  bool compute_scale_bias = inputs.size() == 1;
-
   auto& w_pre = inputs[0];
   auto& out = outputs[0];
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -415,7 +414,7 @@ void fast::AffineQuantize::eval_gpu(

   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_input_array(w, 0);
-  if (!compute_scale_bias) {
+  if (dequantize_) {
     auto& scales_pre = inputs[1];
     auto& biases_pre = inputs[2];
     auto scales = ensure_row_contiguous(scales_pre);
@@ -436,12 +435,7 @@ void fast::AffineQuantize::eval_gpu(
   std::ostringstream kname;
   auto type_string = dequantize_ ? get_type_string(out.dtype())
                                  : get_type_string(w_pre.dtype());
-  auto kernel_func = "affine_quantize_scales_biases";
+  auto kernel_func = dequantize_ ? "affine_dequantize" : "affine_quantize";
-  if (dequantize_) {
-    kernel_func = "affine_dequantize";
-  } else if (compute_scale_bias) {
-    kernel_func = "affine_quantize";
-  }
   kname << kernel_func << "_" << type_string << "_gs_" << group_size_ << "_b_"
         << bits_;
   auto template_def = get_template_definition(
@@ -452,10 +446,10 @@ void fast::AffineQuantize::eval_gpu(
   // Treat uint32 as uint8 in kernel
   constexpr int uint8_per_uint32 = 4;
   constexpr int simd_size = 32;
-  int packs_per_int = 8 / bits_;
+  int packs_per_int = bits_ == 3 ? 8 : bits_ == 6 ? 4 : 8 / bits_;
-  int per_thread = compute_scale_bias ? group_size_ / simd_size : packs_per_int;
+  int per_thread = dequantize_ ? packs_per_int : group_size_ / simd_size;
   size_t nthreads =
-      dequantize_ ? w.size() * uint8_per_uint32 : w.size() / per_thread;
+      dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;

   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
   if (thread_group_size > nthreads) {
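
Note (not part of the diff): with the dispatch collapsed to two kernel families, the kernel name assembled by kname is simply "<func>_<type>_gs_<group size>_b_<bits>". A small sketch reproducing that string construction; the exact type strings come from get_type_string, which is outside this diff, so "float16_t" here is only illustrative:

// Sketch only: reproduces the kname construction in AffineQuantize::eval_gpu.
#include <iostream>
#include <sstream>
#include <string>

std::string make_kname(
    bool dequantize, const std::string& type_string, int group_size, int bits) {
  std::string kernel_func = dequantize ? "affine_dequantize" : "affine_quantize";
  std::ostringstream kname;
  kname << kernel_func << "_" << type_string << "_gs_" << group_size << "_b_"
        << bits;
  return kname.str();
}

int main() {
  // Prints "affine_quantize_float16_t_gs_64_b_3" and
  // "affine_dequantize_float16_t_gs_64_b_6" for these illustrative inputs.
  std::cout << make_kname(false, "float16_t", 64, 3) << "\n";
  std::cout << make_kname(true, "float16_t", 64, 6) << "\n";
  return 0;
}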
mlx/fast.cpp
@@ -686,13 +686,11 @@ array pack_and_quantize(
     array& packed_w,
     const array& scales,
     const array& biases,
-    int group_size,
     int bits,
     const Stream& s) {
   int el_per_int = 32 / bits;
   array zero(0, packed_w.dtype());
   array n_bins((1 << bits) - 1, packed_w.dtype()); // 2**bits - 1
-  array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
   packed_w = astype(
       clip(
           round(divide(subtract(packed_w, biases, s), scales, s), s),
@@ -701,9 +699,30 @@ array pack_and_quantize(
           s),
       uint32,
       s);
+  if (is_power_of_2(bits)) {
+    array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
     packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
     packed_w = sum(
         multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
+  } else {
+    // This is slow but we have fast GPU/CPU versions of this function so we
+    // shouldn't be here often.
+    packed_w = expand_dims(packed_w, /* axis= */ -1, s);
+    packed_w = bitwise_and(
+        right_shift(packed_w, arange(bits, uint32, s), s),
+        array({1}, uint32),
+        s);
+    auto new_shape = packed_w.shape();
+    new_shape[new_shape.size() - 2] = -1;
+    new_shape.back() = 32;
+    packed_w = reshape(packed_w, new_shape, s);
+    array shifts = arange(32, uint32, s);
+    packed_w =
+        sum(left_shift(packed_w, shifts, s),
+            /* axis= */ -1,
+            /* keepdims= */ false,
+            s);
+  }
   return packed_w;
 }

@@ -718,10 +737,10 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
     throw std::invalid_argument(msg.str());
   }

-  if (bits != 2 && bits != 4 && bits != 8) {
+  if (bits != 2 && bits != 3 && bits != 4 && bits != 6 && bits != 8) {
     std::ostringstream msg;
     msg << "[quantize] The requested number of bits " << bits
-        << " is not supported. The supported bits are 2, 4 and 8.";
+        << " is not supported. The supported bits are 2, 3, 4, 6 and 8.";
     throw std::invalid_argument(msg.str());
   }

@@ -740,9 +759,7 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
     throw std::invalid_argument(msg.str());
   }

-  int el_per_int = 32 / bits;
-
-  auto fallback = [group_size, bits, el_per_int, s](
+  auto fallback = [group_size, bits, s](
       const std::vector<array>& inputs) -> std::vector<array> {
     auto& w = inputs[0];
     auto wshape = w.shape();
@@ -765,7 +782,7 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
     scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
     array biases = where(equal(q0, zero, s), zero, edge, s);

-    packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s);
+    packed_w = pack_and_quantize(packed_w, scales, biases, bits, s);
     return {
         reshape(packed_w, wshape, s),
         reshape(scales, wshape, s),
@@ -774,7 +791,7 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
   };

   auto wq_shape = w.shape();
-  wq_shape.back() = w.shape(-1) / el_per_int;
+  wq_shape.back() = w.shape(-1) * bits / 32;
   auto sshape = w.shape();
   sshape.back() = w.shape(-1) / group_size;
   auto outputs = array::make_arrays(
@@ -785,39 +802,6 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
   return {outputs[0], outputs[1], outputs[2]};
 }

-array affine_quantize(
-    const array& w,
-    const array& scales,
-    const array& biases,
-    int group_size,
-    int bits,
-    StreamOrDevice s_) {
-  auto s = to_stream(s_);
-
-  int el_per_int = 32 / bits;
-  auto fallback = [group_size, bits, el_per_int, s](
-      const std::vector<array>& inputs) -> std::vector<array> {
-    auto& w = inputs[0];
-    auto scales = expand_dims(inputs[1], -1, s);
-    auto biases = expand_dims(inputs[2], -1, s);
-
-    auto wshape = w.shape();
-    wshape.back() = -1;
-
-    array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
-    packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s);
-    return {reshape(packed_w, wshape, s)};
-  };
-
-  auto out_shape = w.shape();
-  out_shape.back() = w.shape(-1) / el_per_int;
-  return array(
-      std::move(out_shape),
-      uint32,
-      std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
-      {w, scales, biases});
-}
-
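
Note (not part of the diff): the quantized output width is now computed directly as w.shape(-1) * bits / 32 uint32 words per row, which also covers the non-power-of-two widths that do not divide 32. A quick sketch of the arithmetic, using a 512-column matrix purely as an illustrative example:

// Sketch only: packed row width in uint32 words, matching the
// wq_shape.back() = w.shape(-1) * bits / 32 change above.
#include <cstdio>
#include <initializer_list>

int main() {
  const int cols = 512; // illustrative input width
  for (int bits : {2, 3, 4, 6, 8}) {
    // 512 cols: 2 bits -> 32 words, 3 -> 48, 4 -> 64, 6 -> 96, 8 -> 128.
    std::printf("bits=%d  packed width=%d uint32\n", bits, cols * bits / 32);
  }
  return 0;
}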
 array affine_dequantize(
     const array& w,
     const array& scales,
@@ -860,9 +844,9 @@ array affine_dequantize(
   }

   // Packing into uint32
-  int el_per_int = 32 / bits;
+  int out_size = w.shape(-1) * 32 / bits;

-  if (w.shape(-1) * el_per_int != scales.shape(-1) * group_size) {
+  if (out_size != scales.shape(-1) * group_size) {
     std::ostringstream msg;
     msg << "[dequantize] Shape of scales and biases does not match the matrix "
         << "given the quantization parameters. Provided matrix of shape "
@@ -873,12 +857,12 @@ array affine_dequantize(

   auto s = to_stream(s_);

-  auto fallback =
-      [&wshape, &sshape, &scales, &biases, group_size, bits, el_per_int, s](
+  auto fallback = [&wshape, &sshape, &scales, &biases, group_size, bits, s](
       const std::vector<array>& inputs) -> std::vector<array> {
-    auto& w = inputs[0];
+    auto w = inputs[0];
     auto& scales = inputs[1];
     auto& biases = inputs[2];
+    if (is_power_of_2(bits)) {
       std::vector<array> parts;
       for (int start = 0; start < 32; start += bits) {
         int shift_left = 32 - (start + bits);
@@ -892,21 +876,33 @@ array affine_dequantize(
             -1,
             s));
       }
-      array w_full = concatenate(parts, -1, s);
+      w = concatenate(parts, -1, s);
+    } else {
+      w = expand_dims(w, /* axis= */ -1, s);
+      w = bitwise_and(
+          right_shift(w, arange(32, uint32, s), s), array({1}, uint32), s);
+      auto new_shape = w.shape();
+      new_shape[new_shape.size() - 2] = -1;
+      new_shape.back() = bits;
+      w = reshape(w, new_shape, s);
+      array shifts = arange(bits, uint32, s);
+      w = sum(
+          left_shift(w, shifts, s), /* axis= */ -1, /* keepdims= */ false, s);
+    }

     // Dequantize
     wshape.push_back(group_size);
-    w_full = reshape(w_full, wshape, s);
-    w_full = multiply(w_full, expand_dims(scales, -1, s), s);
-    w_full = add(w_full, expand_dims(biases, -1, s), s);
-    w_full = reshape(w_full, sshape, s);
+    w = reshape(w, wshape, s);
+    w = multiply(w, expand_dims(scales, -1, s), s);
+    w = add(w, expand_dims(biases, -1, s), s);
+    w = reshape(w, sshape, s);

-    return {w_full};
+    return {w};
   };

   if (s.device == Device::gpu) {
     auto out_shape = w.shape();
-    out_shape.back() = w.shape(-1) * el_per_int;
+    out_shape.back() = out_size;
     return array(
         std::move(out_shape),
         scales.dtype(),
@@ -47,14 +47,6 @@ std::tuple<array, array, array> affine_quantize(
     int bits = 4,
     StreamOrDevice s = {});

-array affine_quantize(
-    const array& w,
-    const array& scales,
-    const array& biases,
-    int group_size = 64,
-    int bits = 4,
-    StreamOrDevice s = {});
-
 array affine_dequantize(
     const array& w,
     const array& scales,
@@ -3683,7 +3683,7 @@ std::tuple<array, array, array> quantize(
     int group_size /* = 64 */,
     int bits /* = 4 */,
     StreamOrDevice s /* = {} */) {
-  return fast::affine_quantize(w, group_size, bits);
+  return fast::affine_quantize(w, group_size, bits, s);
 }

 array dequantize(
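
Note (not part of the diff): with the stream now forwarded, the C++ quantize/dequantize entry points accept the new bit widths like any other. A minimal usage sketch; names outside this diff (mlx/mlx.h, random::normal, abs, max, subtract, the operator<< for arrays) are assumed from the public mlx C++ API:

// Sketch only: quantize/dequantize round trip at 3 bits via the C++ API.
#include <iostream>
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  array w = random::normal({128, 512});
  auto [w_q, scales, biases] = quantize(w, /* group_size= */ 64, /* bits= */ 3);
  array w_hat = dequantize(w_q, scales, biases, /* group_size= */ 64, /* bits= */ 3);
  // 512 columns at 3 bits pack into 512 * 3 / 32 = 48 uint32 per row.
  std::cout << w_q.shape(-1) << std::endl; // expected: 48
  std::cout << max(abs(subtract(w, w_hat))) << std::endl; // quantization error
  return 0;
}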
@@ -161,49 +161,6 @@ void init_fast(nb::module_& parent_module) {
           array: The output array.
         )pbdoc");

-  m.def(
-      "affine_quantize",
-      nb::overload_cast<
-          const array&,
-          const array&,
-          const array&,
-          int,
-          int,
-          StreamOrDevice>(&fast::affine_quantize),
-      "w"_a,
-      "scales"_a,
-      "biases"_a,
-      "group_size"_a = 64,
-      "bits"_a = 4,
-      nb::kw_only(),
-      "stream"_a = nb::none(),
-      nb::sig(
-          "def affine_quantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
-      R"pbdoc(
-        Quantize the matrix ``w`` using the provided ``scales`` and
-        ``biases`` and the ``group_size`` and ``bits`` configuration.
-
-        Formally, given the notation in :func:`quantize`, we compute
-        :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and
-        :math:`\beta` as follows
-
-        .. math::
-
-          w_i = s (\hat{w_i} + \beta)
-
-        Args:
-          w (array): Matrix to be quantize
-          scales (array): The scales to use per ``group_size`` elements of ``w``
-          biases (array): The biases to use per ``group_size`` elements of ``w``
-          group_size (int, optional): The size of the group in ``w`` that shares a
-            scale and bias. (default: ``64``)
-          bits (int, optional): The number of bits occupied by each element in
-            ``w``. (default: ``4``)
-
-        Returns:
-          array: The quantized version of ``w``
-      )pbdoc");
-
   m.def(
       "metal_kernel",
       [](const std::string& name,
@@ -549,18 +549,6 @@ class TestFast(mlx_tests.MLXTestCase):
         )(x)
         self.assertTrue(mx.allclose(vmap_out, vmap_fast_out))

-    def test_affine_quantize(self):
-        mx.random.seed(7)
-        x = mx.random.uniform(shape=(4, 1024))
-        for bits in (2, 4, 8):
-            for group_size in (32, 64, 128):
-                with self.subTest(bits=bits, group_size=group_size):
-                    w, scales, biases = mx.quantize(x, bits=bits, group_size=group_size)
-                    w_p = mx.fast.affine_quantize(
-                        x, scales, biases, bits=bits, group_size=group_size
-                    )
-                    self.assertTrue(mx.allclose(w, w_p))
-
     @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
     def test_custom_kernel_basic(self):
         mx.random.seed(7)
@@ -11,7 +11,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
     def test_quantize_dequantize(self):
         w = mx.random.normal(shape=(128, 512))
         for gs in [32, 64, 128]:
-            for b in [2, 4, 8]:
+            for b in [2, 3, 6, 4, 8]:
                 with self.subTest(gs=gs, b=b):
                     w_q, scales, biases = mx.quantize(w, group_size=gs, bits=b)
                     w_hat = mx.dequantize(w_q, scales, biases, gs, b)
@@ -22,7 +22,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
         # test quantize/dequantize 0s
         a = mx.zeros((256, 512))
         for gs in [32, 64, 128]:
-            for b in [2, 4, 8]:
+            for b in [2, 3, 4, 6, 8]:
                 w_q, scales, biases = mx.quantize(a, gs, b)
                 a_hat = mx.dequantize(w_q, scales, biases, gs, b)
                 self.assertTrue(mx.all(a_hat == 0))
@@ -116,7 +116,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
         k1, k2 = mx.random.split(key)
         tests = product(
             [128, 64, 32],  # group_size
-            [2, 4, 8],  # bits
+            [2, 3, 4, 6, 8],  # bits
             [512, 1024, 67],  # M
             [64, 128, 512, 1024],  # N
             [0, 1, 3, 8],  # B
@@ -143,7 +143,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
         k1, k2 = mx.random.split(key)
         tests = product(
             [128, 64, 32],  # group_size
-            [2, 4, 8],  # bits
+            [2, 3, 4, 6, 8],  # bits
             [512, 1024],  # M
             [512, 1024, 67],  # N
             [0, 1, 3, 8],  # B