Update GEMM (#424)

* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/`
* Update gemm elements for better performance 
* Add split-K specialization for gemm
* Add `addmm` primitive, op and bindings for fused matmul and bias addition 
* Update tests and benchmarks as needed
This commit is contained in:
Jagrit Digani
2024-01-17 12:42:39 -08:00
committed by GitHub
parent 556cdf0e06
commit 78102a47ad
30 changed files with 2361 additions and 646 deletions

View File

@@ -63,9 +63,10 @@ class Linear(Module):
return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
def __call__(self, x: mx.array) -> mx.array:
x = x @ self.weight.T
if "bias" in self:
x = x + self.bias
x = mx.addmm(self.bias, x, self.weight.T)
else:
x = x @ self.weight.T
return x

View File

@@ -3476,4 +3476,34 @@ void init_ops(py::module_& m) {
Returns:
result (array): The tiled array.
)pbdoc");
m.def(
"addmm",
&addmm,
"c"_a,
"a"_a,
"b"_a,
py::pos_only(),
"alpha"_a = 1.0f,
"beta"_a = 1.0f,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
addmm(c: array, a: array, b: array, /, alpha: float = 1.0, beta: float = 1.0, *, stream: Union[None, Stream, Device] = None) -> array
Matrix multiplication with addition and optional scaling.
Perform the (possibly batched) matrix multiplication of two arrays and add to the result
with optional scaling factors.
Args:
c (array): Input array or scalar.
a (array): Input array or scalar.
b (array): Input array or scalar.
alpha (float, optional): Scaling factor for the
matrix product of ``a`` and ``b`` (default: ``1``)
beta (float, optional): Scaling factor for ``c`` (default: ``1``)
Returns:
array: ``alpha * (a @ b) + beta * c``
)pbdoc");
}

View File

@@ -74,6 +74,7 @@ class TestBlas(mlx_tests.MLXTestCase):
if mx.default_device() == mx.gpu:
shapes += [
(16, 768, 768, 128),
(1, 64, 64, 4096),
]
for dtype in self.dtypes:
@@ -444,3 +445,139 @@ class TestBlas(mlx_tests.MLXTestCase):
list(c_npy.shape), list(c_mlx.shape)
)
self.assertTrue(np.array_equal(c_mlx, c_npy))
def test_addmm(self):
np.random.seed(0)
# Batched matmul
alpha = 0.5
beta = 2.0
# Regular batched case
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Batched and transposed matmul
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (32, 1, 128), (1, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
b_np_t = np.transpose(b_npy, (0, 2, 1))
b_mx_t = mx.transpose(b_mlx, (0, 2, 1))
d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# # Batched matmul with simple broadcast
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in (
(1,),
(32, 128),
):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Split K specializtion
a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
def test_addmm_grad(self):
def make_ref_addmm(alpha, beta):
return lambda c, a, b: alpha * (a @ b) + beta * c
def make_addmm(alpha, beta):
return lambda c, a, b: mx.addmm(c, a, b, alpha, beta)
# B, M, N, K
shapes = ((1, 64, 32, 128), (4, 28, 24, 47), (1, 1, 24, 47))
alpha = 2.0
beta = 0.5
f_test = make_addmm(alpha, beta)
f_ref = make_ref_addmm(alpha, beta)
for B, M, N, K in shapes:
cotan = mx.ones((B, M, N))
c = mx.random.normal((B, M, N))
a = mx.random.normal((B, M, K))
b = mx.random.normal((B, K, N))
out_ref, dout_ref = mx.vjp(
f_ref,
[c, a, b],
[
cotan,
],
)
out_test, dout_test = mx.vjp(
f_test,
[c, a, b],
[
cotan,
],
)
self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-5).item())
for r, t in zip(dout_ref, dout_test):
self.assertListEqual(r.shape, t.shape)
self.assertTrue(mx.allclose(r, t, atol=1e-5).item())