mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Allow arbitrary first dimension in quantization kernels. (#458)
* Allow arbitrary first dim on qmm_t and qmv * Allow arbitrary first dim on qmm and qvm * Specialized aligned vs unaligned case * Add more checks for valid quantizations
This commit is contained in:

committed by
GitHub

parent
f44c132f4a
commit
c15fe3e61b
@@ -60,20 +60,48 @@ def matmul(x, y):
|
||||
mx.eval(ys)
|
||||
|
||||
|
||||
def _quant_matmul(x, w, s, b, group_size, bits):
|
||||
def _quant_matmul(x, w, s, b, transpose, group_size, bits):
|
||||
ys = []
|
||||
for i in range(10):
|
||||
ys.append(mx.quantized_matmul(x, w, s, b, group_size=group_size, bits=bits))
|
||||
ys.append(
|
||||
mx.quantized_matmul(
|
||||
x, w, s, b, transpose=transpose, group_size=group_size, bits=bits
|
||||
)
|
||||
)
|
||||
mx.eval(ys)
|
||||
|
||||
|
||||
quant_matmul = {
|
||||
"quant_matmul_64_2": partial(_quant_matmul, group_size=64, bits=2),
|
||||
"quant_matmul_64_4": partial(_quant_matmul, group_size=64, bits=4),
|
||||
"quant_matmul_64_8": partial(_quant_matmul, group_size=64, bits=8),
|
||||
"quant_matmul_128_2": partial(_quant_matmul, group_size=128, bits=2),
|
||||
"quant_matmul_128_4": partial(_quant_matmul, group_size=128, bits=4),
|
||||
"quant_matmul_128_8": partial(_quant_matmul, group_size=128, bits=8),
|
||||
"quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2),
|
||||
"quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4),
|
||||
"quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8),
|
||||
"quant_matmul_128_2": partial(
|
||||
_quant_matmul, transpose=False, group_size=128, bits=2
|
||||
),
|
||||
"quant_matmul_128_4": partial(
|
||||
_quant_matmul, transpose=False, group_size=128, bits=4
|
||||
),
|
||||
"quant_matmul_128_8": partial(
|
||||
_quant_matmul, transpose=False, group_size=128, bits=8
|
||||
),
|
||||
"quant_matmul_t_64_2": partial(
|
||||
_quant_matmul, transpose=True, group_size=64, bits=2
|
||||
),
|
||||
"quant_matmul_t_64_4": partial(
|
||||
_quant_matmul, transpose=True, group_size=64, bits=4
|
||||
),
|
||||
"quant_matmul_t_64_8": partial(
|
||||
_quant_matmul, transpose=True, group_size=64, bits=8
|
||||
),
|
||||
"quant_matmul_t_128_2": partial(
|
||||
_quant_matmul, transpose=True, group_size=128, bits=2
|
||||
),
|
||||
"quant_matmul_t_128_4": partial(
|
||||
_quant_matmul, transpose=True, group_size=128, bits=4
|
||||
),
|
||||
"quant_matmul_t_128_8": partial(
|
||||
_quant_matmul, transpose=True, group_size=128, bits=8
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user