Faster contiguous gather for indices in the first axis (#2552)

* faster contiguous gather for indices in the first axis

* work per thread > 1

* angelos suggestion for scales / biases
This commit is contained in:
Awni Hannun
2025-08-28 21:26:30 -07:00
committed by GitHub
parent 827003d568
commit 111f1e71af
10 changed files with 97 additions and 33 deletions

View File

@@ -98,11 +98,10 @@ class QuantizedEmbedding(Module):
         # Initialize the quantized weight
         scale = math.sqrt(1 / dims)
         weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)
-        self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
-        if mode == "affine":
-            self.scales, self.biases = scales_biases
-        else:
-            (self.scales,) = scales_biases
+        self.weight, self.scales, *biases = mx.quantize(
+            weight, group_size, bits, mode=mode
+        )
+        self.biases = biases[0] if biases else None
         self.num_embeddings = num_embeddings
         self.dims = dims
@@ -155,16 +154,13 @@ class QuantizedEmbedding(Module):
         """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
         embedding_dims, dims = embedding_layer.weight.shape
         ql = cls(embedding_dims, dims, group_size, bits, mode=mode)
-        ql.weight, *scales_biases = mx.quantize(
+        ql.weight, ql.scales, *biases = mx.quantize(
             embedding_layer.weight,
             group_size,
             bits,
             mode=mode,
         )
-        if mode == "affine":
-            ql.scales, ql.biases = scales_biases
-        else:
-            (ql.scales,) = scales_biases
+        ql.biases = biases[0] if biases else None
         return ql
@@ -214,11 +210,10 @@ class QuantizedLinear(Module):
             high=scale,
             shape=(output_dims, input_dims),
         )
-        self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
-        if mode == "affine":
-            self.scales, self.biases = scales_biases
-        else:
-            (self.scales,) = scales_biases
+        self.weight, self.scales, *biases = mx.quantize(
+            weight, group_size, bits, mode=mode
+        )
+        self.biases = biases[0] if biases else None
 
         # And bias if needed
         if bias:
@@ -261,16 +256,13 @@ class QuantizedLinear(Module):
         """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
         ql = cls(input_dims, output_dims, False, group_size, bits, mode=mode)
-        ql.weight, *scales_biases = mx.quantize(
+        ql.weight, ql.scales, *biases = mx.quantize(
             linear_layer.weight,
             group_size,
             bits,
             mode=mode,
         )
-        if mode == "affine":
-            ql.scales, ql.biases = scales_biases
-        else:
-            (ql.scales,) = scales_biases
+        ql.biases = biases[0] if biases else None
         if "bias" in linear_layer:
             ql.bias = linear_layer.bias