Add 3bit packed quants

Angelos Katharopoulos 2024-12-17 10:08:47 -08:00
parent 14420949d2
commit d75a509234
4 changed files with 112 additions and 29 deletions

View File

@@ -2295,6 +2295,68 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
   }
 }
 
+template <typename T, int group_size, int bits, int results_per_simdgroup>
+METAL_FUNC void affine_packed_byte_qmv_fast_impl(
+    const device uint8_t* w,
+    const device vec<T, 2 * results_per_simdgroup>* scales,
+    const device T* x,
+    device T* y,
+    const constant int& in_vec_size,
+    const constant int& out_vec_size,
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  constexpr int packs_per_thread = 2;
+  constexpr int num_simdgroups = 2;
+  // Non-power-of-two widths are byte packed: 8 x 3-bit (or 4 x 6-bit) values
+  // fill one 24-bit, 3-byte pack.
+  constexpr int pack_factor = (bits == 3) ? 8 : 4;
+  constexpr int bytes_per_pack = 3;
+  constexpr int values_per_thread = pack_factor * packs_per_thread;
+  constexpr int block_size = values_per_thread * SIMD_SIZE;
+  constexpr int scale_step_per_thread = group_size / values_per_thread;
+
+  typedef float U;
+
+  thread U x_thread[values_per_thread];
+  vec<U, results_per_simdgroup> result = 0;
+
+  // Adjust positions
+  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
+  const int in_vec_size_g = in_vec_size / group_size;
+  const int scales_row = tid.x * num_simdgroups + simd_gid;
+  const int out_row = scales_row * results_per_simdgroup;
+
+  w += out_row * in_vec_size_w + simd_lid * (packs_per_thread * bytes_per_pack);
+  scales += scales_row * in_vec_size_g + simd_lid / scale_step_per_thread;
+  x += tid.y * in_vec_size + simd_lid * values_per_thread;
+  y += tid.y * out_vec_size + out_row;
+
+  for (int k = 0; k < in_vec_size; k += block_size) {
+    // Load the input vector
+    U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
+
+    // Load the scales and biases
+    vec<T, 2 * results_per_simdgroup> sb = scales[0];
+
+    // Load the weights and perform the partial dot product
+    for (int row = 0; row < results_per_simdgroup; row++) {
+      result[row] += qdot<U, values_per_thread, bits>(
+          w + row * in_vec_size_w, x_thread, sb[row], sb[2 + row], sum);
+    }
+
+    w += block_size * bytes_per_pack / pack_factor;
+    scales += block_size / group_size;
+    x += block_size;
+  }
+
+  for (int row = 0; row < results_per_simdgroup; row++) {
+    result[row] = simd_sum(result[row]);
+    if (simd_lid == 0) {
+      y[row] = static_cast<T>(result[row]);
+    }
+  }
+}
+
 template <typename T, int group_size, int bits>
 [[kernel]] void affine_packed_qmv_fast(
     const device vec<uint32_t, 4>* w [[buffer(0)]],
@@ -2306,8 +2368,21 @@ template <typename T, int group_size, int bits>
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-  affine_packed_qmv_fast_impl<T, group_size, bits>(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
+  if (bits & (bits - 1)) {
+    affine_packed_byte_qmv_fast_impl<T, group_size, bits, 2>(
+        (const device uint8_t*)w,
+        scales,
+        x,
+        y,
+        in_vec_size,
+        out_vec_size,
+        tid,
+        simd_gid,
+        simd_lid);
+  } else {
+    affine_packed_qmv_fast_impl<T, group_size, bits>(
+        w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
+  }
 }
 
 template <
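Note on the new code path: bytes_per_pack is fixed at 3 and pack_factor is 8 for 3-bit values (4 otherwise, which would also cover a 6-bit width), so one pack always spans 24 bits, and the bits & (bits - 1) test routes every non-power-of-two width to the byte-packed kernel. A rough Python sketch of that arithmetic, for illustration only (not part of the change):

SIMD_SIZE = 32
packs_per_thread = 2

for bits in (2, 3, 4, 6, 8):
    byte_path = (bits & (bits - 1)) != 0           # non-power-of-two widths
    if byte_path:
        pack_factor = 8 if bits == 3 else 4        # 8 x 3 bits or 4 x 6 bits = 24 bits
        values_per_thread = pack_factor * packs_per_thread
        block_size = values_per_thread * SIMD_SIZE # inputs consumed per simdgroup pass
        print(f"{bits}-bit: byte-packed path, {pack_factor} values per 3 bytes, "
              f"block_size={block_size}")
    else:
        print(f"{bits}-bit: existing uint32 path")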
@@ -2617,6 +2692,9 @@ template <
           s_strides,
           tid);
   }
-  affine_packed_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
-      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
+  if (bits & (bits - 1)) {
+  } else {
+    affine_packed_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
+        w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
+  }
 }
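The byte path only reaches the matrix-vector kernel for now (the non-power-of-two branch of the qmm_t dispatch above is left empty). In the kernel that is wired up, each lane reads packs_per_thread * bytes_per_pack = 6 contiguous weight bytes and 16 inputs per pass; a small Python model of the "Adjust positions" indexing, illustrative only:

bits, group_size = 3, 64
pack_factor, bytes_per_pack, packs_per_thread = 8, 3, 2
values_per_thread = pack_factor * packs_per_thread   # 16 inputs per lane per pass
results_per_simdgroup, num_simdgroups = 2, 2

def offsets(in_vec_size, tid_x, tid_y, simd_gid, simd_lid):
    in_vec_size_w = in_vec_size * bytes_per_pack // pack_factor  # weight bytes per row
    scales_row = tid_x * num_simdgroups + simd_gid
    out_row = scales_row * results_per_simdgroup
    w_off = out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack
    x_off = tid_y * in_vec_size + simd_lid * values_per_thread
    return w_off, x_off

# Adjacent lanes of the first simdgroup read adjacent 6-byte weight slices.
print(offsets(in_vec_size=4096, tid_x=0, tid_y=0, simd_gid=0, simd_lid=0))  # (0, 0)
print(offsets(in_vec_size=4096, tid_x=0, tid_y=0, simd_gid=0, simd_lid=1))  # (6, 16)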

View File

@@ -405,10 +405,11 @@ void affine_packed_qmv(
   auto w = ensure_row_contiguous_last_dims(inputs[1]);
   auto scales = ensure_row_contiguous_last_dims(inputs[2]);
 
+  const bool pow2_bits = (bits & (bits - 1)) == 0;
   const int n_simdgroups = 2;
-  const int n_outs_per_simdgroup = 4;
+  const int results_per_simdgroup = (pow2_bits) ? 4 : 2;
   MTL::Size group_dims(32, n_simdgroups, 1);
-  MTL::Size grid_dims(O / n_simdgroups / n_outs_per_simdgroup, B, 1);
+  MTL::Size grid_dims(O / n_simdgroups / results_per_simdgroup, B, 1);
 
   std::string name;
   name.reserve(64);
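Each threadgroup is 32 lanes by 2 simdgroups; on the byte-packed path a simdgroup produces 2 output rows instead of 4, so the grid doubles along x for the same output size. A quick check of the launch arithmetic with hypothetical sizes (illustration only):

O, B = 4096, 1          # hypothetical output rows and batch size
n_simdgroups = 2

for bits in (4, 3):
    pow2_bits = (bits & (bits - 1)) == 0
    results_per_simdgroup = 4 if pow2_bits else 2
    grid_x = O // n_simdgroups // results_per_simdgroup
    print(f"{bits}-bit: group_dims=(32, {n_simdgroups}, 1), grid_dims=({grid_x}, {B}, 1)")
# 4-bit:  512 threadgroups x 2 simdgroups x 4 rows = 4096 rows
# 3-bit: 1024 threadgroups x 2 simdgroups x 2 rows = 4096 rows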

View File

@@ -99,15 +99,11 @@ std::pair<int, int> extract_quantized_matmul_dims(
           << "biases but biases were provided";
       throw std::invalid_argument(msg.str());
     }
-    if (bits & (bits - 1)) {
-      std::ostringstream msg;
-      msg << "[" << tag << "] Quantization type '" << quantization_type
-          << "' does not support " << bits << " bits.";
-      throw std::invalid_argument(msg.str());
-    }
     break;
   }
 
+  bool pow2_bits = (bits & (bits - 1)) == 0;
+
   if (w.dtype() != uint32) {
     std::ostringstream msg;
     msg << "[" << tag << "] The weight matrix should be uint32 "
@@ -136,8 +132,12 @@ std::pair<int, int> extract_quantized_matmul_dims(
   int weight_dims = w.shape(-1) * 32 / bits;
   int scales_dims = scales.shape(-1) * group_size;
   if (quantization_type == QuantizationType::AffinePacked) {
-    scales_dims /= 8;
-    weight_dims /= 4;
+    if (pow2_bits) {
+      scales_dims /= 8;
+      weight_dims /= 4;
+    } else {
+      scales_dims /= 4;
+    }
   }
 
   if (weight_dims != scales_dims) {
@@ -155,7 +155,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
   // Calculate the expanded w's dims
   int weight_dims_other = w.shape(-2);
-  if (quantization_type == QuantizationType::AffinePacked) {
+  if (quantization_type == QuantizationType::AffinePacked && pow2_bits) {
     weight_dims_other *= 4;
   }
 
   int w_inner_dims = (transpose) ? weight_dims : weight_dims_other;
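Worked through for a hypothetical (O, 4096) weight with group_size 64, both branches recover the same inner dimension; the following Python sketch mirrors the bookkeeping above with assumed shapes:

in_dim, group_size = 4096, 64

def check(bits, w_last, scales_last):
    pow2_bits = (bits & (bits - 1)) == 0
    weight_dims = w_last * 32 // bits
    scales_dims = scales_last * group_size
    if pow2_bits:
        scales_dims //= 8          # 4 scales + 4 biases packed per group
        weight_dims //= 4          # 4 output rows packed per uint32 row
    else:
        scales_dims //= 4          # 2 scales + 2 biases packed per group
    return weight_dims, scales_dims

# 4-bit: weight rows are 4-row packed, scales carry 4 scales + 4 biases per group.
print(check(4, w_last=(in_dim * 4 // 32) * 4, scales_last=(in_dim // group_size) * 8))
# 3-bit: weight rows keep their original layout, scales carry 2 scales + 2 biases.
print(check(3, w_last=in_dim * 3 // 32, scales_last=(in_dim // group_size) * 4))
# Both print (4096, 4096), matching the unpacked input dimension.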
@@ -3793,23 +3793,22 @@ std::tuple<array, array, std::optional<array>> quantize(
     case QuantizationType::Affine:
       return fast::affine_quantize(w, group_size, bits, s);
     case QuantizationType::AffinePacked: {
-      if (bits & (bits - 1)) {
-        std::ostringstream msg;
-        msg << "[quantize] Quantization type '" << quantization_type
-            << "' does not support " << bits << " bits.";
-        throw std::invalid_argument(msg.str());
-      }
       auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s);
-      scales = unflatten(scales, -2, {-1, 4}, s);
-      biases = unflatten(biases, -2, {-1, 4}, s);
+      int pow2_bits = (bits & (bits - 1)) == 0;
+      int row_packing = (pow2_bits) ? 4 : 2;
+      scales = unflatten(scales, -2, {-1, row_packing}, s);
+      biases = unflatten(biases, -2, {-1, row_packing}, s);
       scales = concatenate({scales, biases}, -2, s);
       scales = moveaxis(scales, -2, -1, s);
       scales = flatten(scales, -2, -1, s);
-      wq = unflatten(wq, -2, {-1, 4}, s);
-      wq = moveaxis(wq, -2, -1, s);
-      wq = flatten(wq, -2, -1, s);
+      if (pow2_bits) {
+        wq = unflatten(wq, -2, {-1, row_packing}, s);
+        wq = moveaxis(wq, -2, -1, s);
+        wq = flatten(wq, -2, -1, s);
+      }
       return std::make_tuple(wq, scales, std::nullopt);
     }
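The scales/biases rearrangement above interleaves row_packing scales and their biases per group so the kernel can read them as a single vec<T, 2 * results_per_simdgroup>. A small mx sketch of the 3-bit case (row_packing = 2; hypothetical shapes, using reshape in place of the C++ unflatten):

import mlx.core as mx

# Hypothetical: 8 output rows, 64 inputs, group_size=64 -> scales/biases are (8, 1).
scales = mx.zeros((8, 1))
biases = mx.ones((8, 1))

row_packing = 2                                          # non-power-of-two path
s = scales.reshape(-1, row_packing, scales.shape[-1])    # (4, 2, 1), like unflatten(-2, {-1, 2})
b = biases.reshape(-1, row_packing, biases.shape[-1])    # (4, 2, 1)
sb = mx.concatenate([s, b], axis=-2)                     # (4, 4, 1): s0, s1, b0, b1 per packed pair
sb = mx.moveaxis(sb, -2, -1)                             # (4, 1, 4)
sb = mx.flatten(sb, -2, -1)                              # (4, 4): one vec<T, 4> per group
print(sb.shape)  # (4, 4)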

View File

@@ -181,12 +181,14 @@ class QuantizedLinear(Module):
         bias: bool = True,
         group_size: int = 64,
         bits: int = 4,
+        quantization_type: str = "affine",
     ):
         super().__init__()
 
         # Quantization config
         self.group_size = group_size
         self.bits = bits
+        self.quantization_type = quantization_type
 
         # Initialize the quantized weight
         scale = math.sqrt(1 / input_dims)
@@ -195,7 +197,9 @@ class QuantizedLinear(Module):
             high=scale,
             shape=(output_dims, input_dims),
         )
-        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+        self.weight, self.scales, self.biases = mx.quantize(
+            weight, group_size, bits, quantization_type=quantization_type
+        )
 
         # And bias if needed
         if bias:
@@ -223,10 +227,11 @@ class QuantizedLinear(Module):
             x,
             self["weight"],
             scales=self["scales"],
-            biases=self["biases"],
+            biases=self.get("biases", None),
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
+            quantization_type=self.quantization_type,
         )
         if "bias" in self:
             x = x + self["bias"]
@@ -242,7 +247,7 @@
     ):
         """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
-        ql = cls(input_dims, output_dims, False, group_size, bits)
+        ql = cls(input_dims, output_dims, False, group_size, bits, quantization_type)
         ql.weight, ql.scales, ql.biases = mx.quantize(
             linear_layer.weight, group_size, bits, quantization_type=quantization_type
         )
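Putting the Python-facing change together, the layer now threads quantization_type through construction, the forward call, and from_linear. A usage sketch; the exact string that maps to QuantizationType::AffinePacked is assumed here, since only the "affine" default appears in the diff:

import mlx.nn as nn

layer = nn.Linear(4096, 4096)

# Default affine quantization, unchanged behaviour.
qlayer4 = nn.QuantizedLinear.from_linear(layer, group_size=64, bits=4)

# 3-bit packed variant; "affine-packed" is an assumed name for whatever string
# maps to QuantizationType::AffinePacked in this build.
qlayer3 = nn.QuantizedLinear.from_linear(
    layer, group_size=64, bits=3, quantization_type="affine-packed"
)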