mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Support for quantized matmul with w and w^T (#349)
* Add the Metal qvm implementation
* Add qmm_n
* Add gradient with respect to the input for quantized_matmul
This commit is contained in:

committed by
GitHub

parent
d7ac050f4b
commit
e7f5059fe4
@@ -4,6 +4,7 @@ import argparse
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -59,15 +60,23 @@ def matmul(x, y):
|
||||
mx.eval(ys)
|
||||
|
||||
|
||||
def quant_matmul(x, w, s, b):
|
||||
groups = x.shape[-1] // s.shape[-1]
|
||||
width = 32 // (x.shape[-1] // w.shape[0])
|
||||
def _quant_matmul(x, w, s, b, group_size, bits):
|
||||
ys = []
|
||||
for i in range(10):
|
||||
ys.append(mx.quantized_matmul(x, w, s, b, groups=groups, width=width))
|
||||
ys.append(mx.quantized_matmul(x, w, s, b, group_size=group_size, bits=bits))
|
||||
mx.eval(ys)
|
||||
|
||||
|
||||
quant_matmul = {
|
||||
"quant_matmul_64_2": partial(_quant_matmul, group_size=64, bits=2),
|
||||
"quant_matmul_64_4": partial(_quant_matmul, group_size=64, bits=4),
|
||||
"quant_matmul_64_8": partial(_quant_matmul, group_size=64, bits=8),
|
||||
"quant_matmul_128_2": partial(_quant_matmul, group_size=128, bits=2),
|
||||
"quant_matmul_128_4": partial(_quant_matmul, group_size=128, bits=4),
|
||||
"quant_matmul_128_8": partial(_quant_matmul, group_size=128, bits=8),
|
||||
}
|
||||
|
||||
|
||||
def conv1d(x, y):
|
||||
ys = []
|
||||
for i in range(10):
|
||||
@@ -356,8 +365,8 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "matmul":
|
||||
print(bench(matmul, *xs))
|
||||
|
||||
elif args.benchmark == "quant_matmul":
|
||||
print(bench(quant_matmul, *xs))
|
||||
elif args.benchmark.startswith("quant_matmul"):
|
||||
print(bench(quant_matmul[args.benchmark], *xs))
|
||||
|
||||
elif args.benchmark == "linear":
|
||||
print(bench(linear, *xs))
|
||||
|
Reference in New Issue
Block a user