An initial quantized matmul implementation (#205)

* Add quantized matvec
* Add quantized matrix matrix with 2nd matrix transposed
* Add quantized matmul tests
* Add a slow cpu quantized matmul
* Add a slightly faster vectorized cpu version
This commit is contained in:
Angelos Katharopoulos
2023-12-18 23:18:57 -08:00
committed by GitHub
parent e6872a4149
commit dfa9f4bc58
18 changed files with 1029 additions and 10 deletions

View File

@@ -23,6 +23,16 @@ def none_or_list(x):
return [int(xi) for xi in x.split(",")]
def dtype_from_str(x):
if x == "":
return mx.float32
else:
dt = getattr(mx, x)
if not isinstance(dt, mx.Dtype):
raise ValueError(f"{x} is not an mlx dtype")
return dt
def bench(f, *args):
for i in range(10):
f(*args)
@@ -49,6 +59,15 @@ def matmul(x, y):
mx.eval(ys)
def quant_matmul(x, w, s, b):
groups = x.shape[-1] // s.shape[-1]
width = 32 // (x.shape[-1] // w.shape[0])
ys = []
for i in range(10):
ys.append(mx.quantized_matmul(x, w, s, b, groups=groups, width=width))
mx.eval(ys)
def conv1d(x, y):
ys = []
for i in range(10):
@@ -296,9 +315,7 @@ if __name__ == "__main__":
parser.add_argument(
"--fused", action="store_true", help="Use fused functions where possible"
)
parser.add_argument(
"--dtype", choices=["float32", "float16", "bfloat16"], default="float32"
)
parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")
args = parser.parse_args()
@@ -315,11 +332,15 @@ if __name__ == "__main__":
mx.set_default_device(mx.cpu)
else:
mx.set_default_device(mx.gpu)
dtype = dict(float32=mx.float32, float16=mx.float16, bfloat16=mx.bfloat16)[
args.dtype
]
types = args.dtype
if not types:
types = [mx.float32]
if len(types) < len(args.size):
types = types + [types[0]] * (len(args.size) - len(types))
xs = []
for size in args.size:
for size, dtype in zip(args.size, types):
xs.append(mx.random.normal(size).astype(dtype))
for i, t in enumerate(args.transpose):
if t is None:
@@ -335,6 +356,9 @@ if __name__ == "__main__":
elif args.benchmark == "matmul":
print(bench(matmul, *xs))
elif args.benchmark == "quant_matmul":
print(bench(quant_matmul, *xs))
elif args.benchmark == "linear":
print(bench(linear, *xs))

View File

@@ -22,6 +22,16 @@ def none_or_list(x):
return [int(xi) for xi in x.split(",")]
def dtype_from_str(x):
if x == "":
return torch.float32
else:
dt = getattr(torch, x)
if not isinstance(dt, torch.dtype):
raise ValueError(f"{x} is not a torch dtype")
return dt
def bench(f, *args):
for i in range(10):
f(*args)
@@ -312,7 +322,7 @@ if __name__ == "__main__":
parser.add_argument(
"--fused", action="store_true", help="Use fused functions where possible"
)
parser.add_argument("--dtype", choices=["float32", "float16"], default="float32")
parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")
args = parser.parse_args()
@@ -327,9 +337,15 @@ if __name__ == "__main__":
torch.set_num_threads(1)
device = "cpu" if args.cpu else "mps"
dtype = dict(float32=torch.float32, float16=torch.float16)[args.dtype]
types = args.dtype
if not types:
types = [torch.float32]
if len(types) < len(args.size):
types = types + [types[0]] * (len(args.size) - len(types))
xs = []
for size in args.size:
for size, dtype in zip(args.size, types):
xs.append(torch.randn(*size).to(device).to(dtype))
for i, t in enumerate(args.transpose):
if t is None: