Add 3bit packed quants

Repository: https://github.com/ml-explore/mlx.git
Commit: d75a509234 (parent: 14420949d2)
@@ -2295,6 +2295,68 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
   }
 }
 
+template <typename T, int group_size, int bits, int results_per_simdgroup>
+METAL_FUNC void affine_packed_byte_qmv_fast_impl(
+    const device uint8_t* w,
+    const device vec<T, 2 * results_per_simdgroup>* scales,
+    const device T* x,
+    device T* y,
+    const constant int& in_vec_size,
+    const constant int& out_vec_size,
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  constexpr int packs_per_thread = 2;
+  constexpr int num_simdgroups = 2;
+  constexpr int pack_factor = (bits == 3) ? 8 : 4;
+  constexpr int bytes_per_pack = 3;
+  constexpr int values_per_thread = pack_factor * packs_per_thread;
+  constexpr int block_size = values_per_thread * SIMD_SIZE;
+  constexpr int scale_step_per_thread = group_size / values_per_thread;
+
+  typedef float U;
+
+  thread U x_thread[values_per_thread];
+  vec<U, results_per_simdgroup> result = 0;
+
+  // Adjust positions
+  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
+  const int in_vec_size_g = in_vec_size / group_size;
+  const int scales_row = tid.x * num_simdgroups + simd_gid;
+  const int out_row = scales_row * results_per_simdgroup;
+
+  w += out_row * in_vec_size_w + simd_lid * (packs_per_thread * bytes_per_pack);
+  scales += scales_row * in_vec_size_g + simd_lid / scale_step_per_thread;
+  x += tid.y * in_vec_size + simd_lid * values_per_thread;
+  y += tid.y * out_vec_size + out_row;
+
+  for (int k = 0; k < in_vec_size; k += block_size) {
+    // Load the input vector
+    U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
+
+    // Load the scales and biases
+    vec<T, 2 * results_per_simdgroup> sb = scales[0];
+
+    // Load the weights and perform the partial dot product
+    for (int row = 0; row < results_per_simdgroup; row++) {
+      result[row] += qdot<U, values_per_thread, bits>(
+          w + row * in_vec_size_w, x_thread, sb[row], sb[2 + row], sum);
+    }
+
+    w += block_size * bytes_per_pack / pack_factor;
+    scales += block_size / group_size;
+    x += block_size;
+  }
+
+  for (int row = 0; row < results_per_simdgroup; row++) {
+    result[row] = simd_sum(result[row]);
+    if (simd_lid == 0) {
+      y[row] = static_cast<T>(result[row]);
+    }
+  }
+}
+
 template <typename T, int group_size, int bits>
 [[kernel]] void affine_packed_qmv_fast(
     const device vec<uint32_t, 4>* w [[buffer(0)]],
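A quick sanity check on the constants in the new byte-packed kernel: for 3-bit weights each 3-byte pack holds 8 values, and the per-thread and per-block sizes follow directly. The short Python sketch below just re-derives those constexpr values for bits = 3; the numbers are illustrative, not part of the change.

    # Re-derive the constexpr values from affine_packed_byte_qmv_fast_impl
    # for the 3-bit case (illustrative sketch only).
    SIMD_SIZE = 32                                        # Metal simdgroup width
    bits = 3
    packs_per_thread = 2
    pack_factor = 8 if bits == 3 else 4                   # values per 3-byte pack
    bytes_per_pack = 3                                    # 8 values * 3 bits = 24 bits

    values_per_thread = pack_factor * packs_per_thread    # 16 input values per lane
    block_size = values_per_thread * SIMD_SIZE            # 512 values per simdgroup pass
    w_step = block_size * bytes_per_pack // pack_factor   # 192 weight bytes per pass

    assert pack_factor * bits == bytes_per_pack * 8
    print(values_per_thread, block_size, w_step)          # 16 512 192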
@@ -2306,8 +2368,21 @@ template <typename T, int group_size, int bits>
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-  affine_packed_qmv_fast_impl<T, group_size, bits>(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
+  if (bits & (bits - 1)) {
+    affine_packed_byte_qmv_fast_impl<T, group_size, bits, 2>(
+        (const device uint8_t*)w,
+        scales,
+        x,
+        y,
+        in_vec_size,
+        out_vec_size,
+        tid,
+        simd_gid,
+        simd_lid);
+  } else {
+    affine_packed_qmv_fast_impl<T, group_size, bits>(
+        w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
+  }
 }
 
 template <
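The entry point now branches on the standard power-of-two test: `bits & (bits - 1)` is zero exactly when bits is a power of two, so 2/4/8-bit weights stay on the existing uint32-packed path while non-power-of-two widths such as 3 fall through to the new byte-packed implementation. A minimal illustration of the test:

    # bits & (bits - 1) == 0  <=>  bits is a power of two (for bits > 0).
    for bits in (2, 3, 4, 6, 8):
        path = "byte-packed" if bits & (bits - 1) else "uint32-packed"
        print(bits, path)   # 2/4/8 -> uint32-packed, 3/6 -> byte-packed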
@@ -2617,6 +2692,9 @@ template <
         s_strides,
         tid);
   }
-  affine_packed_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
-      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
+  if (bits & (bits - 1)) {
+  } else {
+    affine_packed_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
+        w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
+  }
 }
@@ -405,10 +405,11 @@ void affine_packed_qmv(
   auto w = ensure_row_contiguous_last_dims(inputs[1]);
   auto scales = ensure_row_contiguous_last_dims(inputs[2]);
 
+  const bool pow2_bits = (bits & (bits - 1)) == 0;
   const int n_simdgroups = 2;
-  const int n_outs_per_simdgroup = 4;
+  const int results_per_simdgroup = (pow2_bits) ? 4 : 2;
   MTL::Size group_dims(32, n_simdgroups, 1);
-  MTL::Size grid_dims(O / n_simdgroups / n_outs_per_simdgroup, B, 1);
+  MTL::Size grid_dims(O / n_simdgroups / results_per_simdgroup, B, 1);
 
   std::string name;
   name.reserve(64);
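On the host side, the same power-of-two test picks the number of rows each simdgroup produces, which in turn sets the dispatch width. A rough worked example, with O = 4096 output rows chosen purely for illustration:

    # Grid-width arithmetic from affine_packed_qmv (O = 4096 is hypothetical).
    O, n_simdgroups = 4096, 2
    for bits in (4, 3):
        results_per_simdgroup = 4 if (bits & (bits - 1)) == 0 else 2
        print(bits, O // n_simdgroups // results_per_simdgroup)   # 4 -> 512, 3 -> 1024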
mlx/ops.cpp
@@ -99,15 +99,11 @@ std::pair<int, int> extract_quantized_matmul_dims(
           << "biases but biases were provided";
       throw std::invalid_argument(msg.str());
     }
-    if (bits & (bits - 1)) {
-      std::ostringstream msg;
-      msg << "[" << tag << "] Quantization type '" << quantization_type
-          << "' does not support " << bits << " bits.";
-      throw std::invalid_argument(msg.str());
-    }
     break;
   }
 
+  bool pow2_bits = (bits & (bits - 1)) == 0;
+
   if (w.dtype() != uint32) {
     std::ostringstream msg;
     msg << "[" << tag << "] The weight matrix should be uint32 "
@@ -136,8 +132,12 @@ std::pair<int, int> extract_quantized_matmul_dims(
   int weight_dims = w.shape(-1) * 32 / bits;
   int scales_dims = scales.shape(-1) * group_size;
   if (quantization_type == QuantizationType::AffinePacked) {
-    scales_dims /= 8;
-    weight_dims /= 4;
+    if (pow2_bits) {
+      scales_dims /= 8;
+      weight_dims /= 4;
+    } else {
+      scales_dims /= 4;
+    }
   }
 
   if (weight_dims != scales_dims) {
@@ -155,7 +155,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
 
   // Calculate the expanded w's dims
   int weight_dims_other = w.shape(-2);
-  if (quantization_type == QuantizationType::AffinePacked) {
+  if (quantization_type == QuantizationType::AffinePacked && pow2_bits) {
     weight_dims_other *= 4;
   }
   int w_inner_dims = (transpose) ? weight_dims : weight_dims_other;
@@ -3793,23 +3793,22 @@ std::tuple<array, array, std::optional<array>> quantize(
     case QuantizationType::Affine:
      return fast::affine_quantize(w, group_size, bits, s);
    case QuantizationType::AffinePacked: {
-      if (bits & (bits - 1)) {
-        std::ostringstream msg;
-        msg << "[quantize] Quantization type '" << quantization_type
-            << "' does not support " << bits << " bits.";
-        throw std::invalid_argument(msg.str());
-      }
       auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s);
 
-      scales = unflatten(scales, -2, {-1, 4}, s);
-      biases = unflatten(biases, -2, {-1, 4}, s);
+      int pow2_bits = (bits & (bits - 1)) == 0;
+      int row_packing = (pow2_bits) ? 4 : 2;
+
+      scales = unflatten(scales, -2, {-1, row_packing}, s);
+      biases = unflatten(biases, -2, {-1, row_packing}, s);
       scales = concatenate({scales, biases}, -2, s);
       scales = moveaxis(scales, -2, -1, s);
       scales = flatten(scales, -2, -1, s);
 
-      wq = unflatten(wq, -2, {-1, 4}, s);
-      wq = moveaxis(wq, -2, -1, s);
-      wq = flatten(wq, -2, -1, s);
+      if (pow2_bits) {
+        wq = unflatten(wq, -2, {-1, row_packing}, s);
+        wq = moveaxis(wq, -2, -1, s);
+        wq = flatten(wq, -2, -1, s);
+      }
 
       return std::make_tuple(wq, scales, std::nullopt);
     }
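The reshuffling in quantize() interleaves row_packing rows of scales with their biases (4 rows for power-of-two widths, 2 otherwise) and only repacks the quantized weights themselves in the power-of-two case. A small NumPy sketch of the shape bookkeeping; the 16x4 sizes are made up, and only the reshapes mirror the unflatten/concatenate/moveaxis/flatten calls above:

    import numpy as np

    rows, groups = 16, 4                 # hypothetical scales shape
    scales = np.zeros((rows, groups))
    biases = np.zeros((rows, groups))

    row_packing = 2                      # non-power-of-two bits, e.g. 3-bit
    s = scales.reshape(-1, row_packing, groups)    # unflatten(-2, {-1, row_packing})
    b = biases.reshape(-1, row_packing, groups)
    sb = np.concatenate([s, b], axis=-2)           # concatenate along -2
    sb = np.moveaxis(sb, -2, -1)                   # moveaxis(-2, -1)
    sb = sb.reshape(sb.shape[0], -1)               # flatten(-2, -1)
    print(sb.shape)   # (8, 16): 2 scales + 2 biases per group per packed row

Per packed row this lays out [scale0, scale1, bias0, bias1] for each group, which is exactly what the kernel reads back as sb[row] and sb[2 + row].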
@@ -181,12 +181,14 @@ class QuantizedLinear(Module):
         bias: bool = True,
         group_size: int = 64,
         bits: int = 4,
+        quantization_type: str = "affine",
     ):
         super().__init__()
 
         # Quantization config
         self.group_size = group_size
         self.bits = bits
+        self.quantization_type = quantization_type
 
         # Initialize the quantized weight
         scale = math.sqrt(1 / input_dims)
@@ -195,7 +197,9 @@ class QuantizedLinear(Module):
             high=scale,
             shape=(output_dims, input_dims),
         )
-        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+        self.weight, self.scales, self.biases = mx.quantize(
+            weight, group_size, bits, quantization_type=quantization_type
+        )
 
         # And bias if needed
         if bias:
@@ -223,10 +227,11 @@ class QuantizedLinear(Module):
             x,
             self["weight"],
             scales=self["scales"],
-            biases=self["biases"],
+            biases=self.get("biases", None),
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
+            quantization_type=self.quantization_type,
         )
         if "bias" in self:
             x = x + self["bias"]
@@ -242,7 +247,7 @@ class QuantizedLinear(Module):
     ):
         """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
-        ql = cls(input_dims, output_dims, False, group_size, bits)
+        ql = cls(input_dims, output_dims, False, group_size, bits, quantization_type)
         ql.weight, ql.scales, ql.biases = mx.quantize(
             linear_layer.weight, group_size, bits, quantization_type=quantization_type
         )
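With the layer plumbing above, the quantization type flows from the QuantizedLinear constructor through mx.quantize and the quantized matmul call. A minimal usage sketch against this branch; it sticks to the default "affine" type because the string identifier for the packed variant is not visible in these hunks, and the quantization_type keyword itself is introduced by this PR:

    import mlx.core as mx
    import mlx.nn as nn

    # Quantize an existing Linear layer, forwarding the quantization type.
    linear = nn.Linear(512, 256)
    qlayer = nn.QuantizedLinear.from_linear(
        linear, group_size=64, bits=4, quantization_type="affine"
    )
    y = qlayer(mx.zeros((1, 512)))   # runs the quantized matmul path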