Mirror of https://github.com/ml-explore/mlx.git (synced 2025-11-05 03:18:12 +08:00)
Faster contiguous gather for indices in the first axis (#2552)
* Faster contiguous gather for indices in the first axis
* Work per thread > 1
* Angelos' suggestion for scales / biases
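For illustration, a minimal micro-benchmark sketch of the access pattern the title refers to: a contiguous gather whose indices select along the first axis, as in an embedding lookup. The table size, iteration count, and use of mx.take here are illustrative assumptions, not part of this commit.

import time

import mlx.core as mx

# Hypothetical timing of a row gather (indices in the first axis).
table = mx.random.normal(shape=(50_000, 4096))
idx = mx.random.randint(0, 50_000, shape=(8192,))
mx.eval(table, idx)  # materialize inputs before timing

tic = time.perf_counter()
for _ in range(100):
    out = mx.take(table, idx, axis=0)  # contiguous rows, axis-0 indices
    mx.eval(out)  # force evaluation of the lazy graph
print(f"{(time.perf_counter() - tic) * 1e3 / 100:.3f} ms per gather")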
@@ -98,11 +98,10 @@ class QuantizedEmbedding(Module):
         # Initialize the quantized weight
         scale = math.sqrt(1 / dims)
         weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)
-        self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
-        if mode == "affine":
-            self.scales, self.biases = scales_biases
-        else:
-            (self.scales,) = scales_biases
+        self.weight, self.scales, *biases = mx.quantize(
+            weight, group_size, bits, mode=mode
+        )
+        self.biases = biases[0] if biases else None
         self.num_embeddings = num_embeddings
         self.dims = dims

@@ -155,16 +154,13 @@ class QuantizedEmbedding(Module):
         """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
         embedding_dims, dims = embedding_layer.weight.shape
         ql = cls(embedding_dims, dims, group_size, bits, mode=mode)
-        ql.weight, *scales_biases = mx.quantize(
+        ql.weight, ql.scales, *biases = mx.quantize(
             embedding_layer.weight,
             group_size,
             bits,
             mode=mode,
         )
-        if mode == "affine":
-            ql.scales, ql.biases = scales_biases
-        else:
-            (ql.scales,) = scales_biases
+        ql.biases = biases[0] if biases else None
         return ql

@@ -214,11 +210,10 @@ class QuantizedLinear(Module):
             high=scale,
             shape=(output_dims, input_dims),
         )
-        self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
-        if mode == "affine":
-            self.scales, self.biases = scales_biases
-        else:
-            (self.scales,) = scales_biases
+        self.weight, self.scales, *biases = mx.quantize(
+            weight, group_size, bits, mode=mode
+        )
+        self.biases = biases[0] if biases else None

         # And bias if needed
         if bias:
@@ -261,16 +256,13 @@ class QuantizedLinear(Module):
         """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
         ql = cls(input_dims, output_dims, False, group_size, bits, mode=mode)
-        ql.weight, *scales_biases = mx.quantize(
+        ql.weight, ql.scales, *biases = mx.quantize(
             linear_layer.weight,
             group_size,
             bits,
             mode=mode,
         )
-        if mode == "affine":
-            ql.scales, ql.biases = scales_biases
-        else:
-            (ql.scales,) = scales_biases
+        ql.biases = biases[0] if biases else None

         if "bias" in linear_layer:
             ql.bias = linear_layer.bias
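A note on the pattern repeated in all four hunks: the starred assignment absorbs the optional trailing biases that mx.quantize returns only in the affine mode. A minimal sketch, assuming affine returns (wq, scales, biases) and other modes return (wq, scales), as the diff implies; the mxfp4 mode name and its group size of 32 are assumptions for illustration.

import mlx.core as mx

w = mx.random.normal(shape=(128, 256))

# Affine mode: biases are returned, so the starred list has one element.
wq, scales, *biases = mx.quantize(w, group_size=64, bits=4, mode="affine")
b = biases[0] if biases else None  # -> the biases array

# A mode without biases (assumed): the starred list is empty.
wq2, scales2, *rest = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
b2 = rest[0] if rest else None  # -> None

This keeps a single call site for every mode instead of branching on mode == "affine".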