mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
An initial quantized matmul implementation (#205)
* Add quantized matvec
* Add quantized matrix matrix with 2nd matrix transposed
* Add quantized matmul tests
* Add a slow cpu quantized matmul
* Add a slightly faster vectorized cpu version
This commit is contained in:

committed by
GitHub

parent
e6872a4149
commit
dfa9f4bc58
@@ -23,6 +23,16 @@ def none_or_list(x):
|
||||
return [int(xi) for xi in x.split(",")]
|
||||
|
||||
|
||||
def dtype_from_str(x):
    """Convert a dtype name into the corresponding mlx dtype object.

    Args:
        x: Name of an attribute of ``mx`` (for example ``"float16"``). The
            empty string selects ``mx.float32``.

    Returns:
        The ``mx.Dtype`` instance named by ``x``.

    Raises:
        ValueError: If ``x`` does not name an ``mx.Dtype`` attribute of ``mx``.
    """
    if x == "":
        return mx.float32

    # Look the name up with a default so that an unknown name reaches the
    # ValueError below instead of escaping as AttributeError. argparse only
    # reports ValueError/TypeError/ArgumentTypeError from a ``type=`` callable
    # as a usage error; anything else surfaces as a raw traceback.
    dt = getattr(mx, x, None)
    if not isinstance(dt, mx.Dtype):
        raise ValueError(f"{x} is not an mlx dtype")
    return dt
|
||||
|
||||
|
||||
def bench(f, *args):
|
||||
for i in range(10):
|
||||
f(*args)
|
||||
@@ -49,6 +59,15 @@ def matmul(x, y):
|
||||
mx.eval(ys)
|
||||
|
||||
|
||||
def quant_matmul(x, w, s, b):
    """Queue ten ``mx.quantized_matmul`` calls and evaluate them together.

    ``groups`` and ``width`` are derived from the argument shapes rather than
    passed in by the caller.

    Args:
        x: Input array; its last dimension drives both derived parameters.
        w: Quantized weights; only ``w.shape[0]`` is read here.
        s: Scales forwarded to ``mx.quantized_matmul``; ``s.shape[-1]`` is read.
        b: Biases forwarded to ``mx.quantized_matmul`` unchanged.
    """
    inner_dim = x.shape[-1]
    groups = inner_dim // s.shape[-1]
    # NOTE(review): the 32 presumably reflects weights packed into 32-bit
    # words, making ``width`` the bits per quantized element - confirm against
    # the mx.quantized_matmul implementation.
    elements_per_row = inner_dim // w.shape[0]
    width = 32 // elements_per_row

    # Build all ten results first, then evaluate them in a single call.
    outputs = [
        mx.quantized_matmul(x, w, s, b, groups=groups, width=width)
        for _ in range(10)
    ]
    mx.eval(outputs)
|
||||
|
||||
|
||||
def conv1d(x, y):
|
||||
ys = []
|
||||
for i in range(10):
|
||||
@@ -296,9 +315,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--fused", action="store_true", help="Use fused functions where possible"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype", choices=["float32", "float16", "bfloat16"], default="float32"
|
||||
)
|
||||
parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -315,11 +332,15 @@ if __name__ == "__main__":
|
||||
mx.set_default_device(mx.cpu)
|
||||
else:
|
||||
mx.set_default_device(mx.gpu)
|
||||
dtype = dict(float32=mx.float32, float16=mx.float16, bfloat16=mx.bfloat16)[
|
||||
args.dtype
|
||||
]
|
||||
|
||||
types = args.dtype
|
||||
if not types:
|
||||
types = [mx.float32]
|
||||
if len(types) < len(args.size):
|
||||
types = types + [types[0]] * (len(args.size) - len(types))
|
||||
|
||||
xs = []
|
||||
for size in args.size:
|
||||
for size, dtype in zip(args.size, types):
|
||||
xs.append(mx.random.normal(size).astype(dtype))
|
||||
for i, t in enumerate(args.transpose):
|
||||
if t is None:
|
||||
@@ -335,6 +356,9 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "matmul":
|
||||
print(bench(matmul, *xs))
|
||||
|
||||
elif args.benchmark == "quant_matmul":
|
||||
print(bench(quant_matmul, *xs))
|
||||
|
||||
elif args.benchmark == "linear":
|
||||
print(bench(linear, *xs))
|
||||
|
||||
|
Reference in New Issue
Block a user