Support for quantized matmul with w and w^T (#349)

* Add the Metal qvm implementation
* Add qmm_n
* Add gradient w.r.t. the input for quantized_matmul
Angelos Katharopoulos
2024-01-03 14:22:36 -08:00
committed by GitHub
parent d7ac050f4b
commit e7f5059fe4
12 changed files with 718 additions and 193 deletions
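
As a rough illustration of what the title refers to (multiplying by both w and w^T), here is a minimal sketch against the current mlx.core API; the transpose keyword, the shapes, and the variable names are assumptions for illustration and are not taken from this diff.

import mlx.core as mx

# Hypothetical shapes, chosen only for illustration.
x = mx.random.normal((1, 512))

# Weight stored as (out_features, in_features): used as x @ w^T.
w_t = mx.random.normal((1024, 512))
wq_t, scales_t, biases_t = mx.quantize(w_t, group_size=64, bits=4)
y_t = mx.quantized_matmul(
    x, wq_t, scales_t, biases_t, transpose=True, group_size=64, bits=4
)  # shape (1, 1024)

# Weight stored as (in_features, out_features): used as x @ w.
w_n = mx.random.normal((512, 1024))
wq_n, scales_n, biases_n = mx.quantize(w_n, group_size=64, bits=4)
y_n = mx.quantized_matmul(
    x, wq_n, scales_n, biases_n, transpose=False, group_size=64, bits=4
)  # shape (1, 1024)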


@@ -4,6 +4,7 @@ import argparse
 import math
 import os
 import time
+from functools import partial

 import mlx.core as mx
 import mlx.nn as nn
@@ -59,15 +60,23 @@ def matmul(x, y):
     mx.eval(ys)


-def quant_matmul(x, w, s, b):
-    groups = x.shape[-1] // s.shape[-1]
-    width = 32 // (x.shape[-1] // w.shape[0])
+def _quant_matmul(x, w, s, b, group_size, bits):
     ys = []
     for i in range(10):
-        ys.append(mx.quantized_matmul(x, w, s, b, groups=groups, width=width))
+        ys.append(mx.quantized_matmul(x, w, s, b, group_size=group_size, bits=bits))
     mx.eval(ys)


+quant_matmul = {
+    "quant_matmul_64_2": partial(_quant_matmul, group_size=64, bits=2),
+    "quant_matmul_64_4": partial(_quant_matmul, group_size=64, bits=4),
+    "quant_matmul_64_8": partial(_quant_matmul, group_size=64, bits=8),
+    "quant_matmul_128_2": partial(_quant_matmul, group_size=128, bits=2),
+    "quant_matmul_128_4": partial(_quant_matmul, group_size=128, bits=4),
+    "quant_matmul_128_8": partial(_quant_matmul, group_size=128, bits=8),
+}
+
+
 def conv1d(x, y):
     ys = []
     for i in range(10):
@@ -356,8 +365,8 @@ if __name__ == "__main__":
     elif args.benchmark == "matmul":
         print(bench(matmul, *xs))
-    elif args.benchmark == "quant_matmul":
-        print(bench(quant_matmul, *xs))
+    elif args.benchmark.startswith("quant_matmul"):
+        print(bench(quant_matmul[args.benchmark], *xs))
     elif args.benchmark == "linear":
         print(bench(linear, *xs))
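
Finally, a minimal sketch tying the updated benchmark call to the gradient feature from the commit message, assuming the group_size/bits signature shown above and using mx.quantize to produce the packed weights, scales and biases; the shapes are hypothetical.

import mlx.core as mx

# Hypothetical input and weight, quantized with matching group_size/bits.
x = mx.random.normal((4, 512))
w = mx.random.normal((1024, 512))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)

def loss(x):
    # Same calling convention as the updated benchmark:
    # group_size/bits instead of the old groups/width arguments.
    y = mx.quantized_matmul(x, w_q, scales, biases, group_size=64, bits=4)
    return y.sum()

# Gradient w.r.t. the floating-point input x, the case this commit adds.
dx = mx.grad(loss)(x)
print(dx.shape)  # (4, 512)

In this sketch the gradient flows only through x; the packed weights, scales and biases are captured by the closure and treated as constants.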