Mirror of https://github.com/ml-explore/mlx.git
Synced 2025-10-18 15:28:16 +08:00

Update GEMM (#424)

* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/`
* Update gemm elements for better performance
* Add split-K specialization for gemm
* Add `addmm` primitive, op and bindings for fused matmul and bias addition
* Update tests and benchmarks as needed
This commit is contained in:
@@ -63,9 +63,10 @@ class Linear(Module):
|
||||
return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
    """Apply the linear transformation to ``x``.

    Computes ``x @ self.weight.T + self.bias`` when a bias is present,
    otherwise just the matrix product.

    Args:
        x (mx.array): Input array; its last dimension must equal
            ``self.weight.shape[1]`` (the layer's input dims).

    Returns:
        mx.array: Output array whose last dimension is
        ``self.weight.shape[0]`` (the layer's output dims).
    """
    if "bias" in self:
        # Fused matmul + bias addition in a single kernel:
        # addmm(c, a, b) computes a @ b + c.
        x = mx.addmm(self.bias, x, self.weight.T)
    else:
        x = x @ self.weight.T
    return x
|
||||
|
||||
|
||||
|
@@ -3476,4 +3476,34 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
result (array): The tiled array.
|
||||
)pbdoc");
|
||||
// Python binding for the fused matmul-with-add op:
// addmm(c, a, b, alpha, beta) == alpha * (a @ b) + beta * c.
// c, a, b are positional-only; stream is keyword-only.
m.def(
    "addmm",
    &addmm,
    "c"_a,
    "a"_a,
    "b"_a,
    py::pos_only(),
    "alpha"_a = 1.0f,
    "beta"_a = 1.0f,
    py::kw_only(),
    "stream"_a = none,
    R"pbdoc(
      addmm(c: array, a: array, b: array, /, alpha: float = 1.0, beta: float = 1.0, *, stream: Union[None, Stream, Device] = None) -> array

      Matrix multiplication with addition and optional scaling.

      Perform the (possibly batched) matrix multiplication of two arrays and add to the result
      with optional scaling factors.

      Args:
          c (array): Input array or scalar.
          a (array): Input array or scalar.
          b (array): Input array or scalar.
          alpha (float, optional): Scaling factor for the
              matrix product of ``a`` and ``b`` (default: ``1``)
          beta (float, optional): Scaling factor for ``c`` (default: ``1``)

      Returns:
          array: ``alpha * (a @ b) + beta * c``
    )pbdoc");
}
|
||||
|
@@ -74,6 +74,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
if mx.default_device() == mx.gpu:
|
||||
shapes += [
|
||||
(16, 768, 768, 128),
|
||||
(1, 64, 64, 4096),
|
||||
]
|
||||
|
||||
for dtype in self.dtypes:
|
||||
@@ -444,3 +445,139 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
list(c_npy.shape), list(c_mlx.shape)
|
||||
)
|
||||
self.assertTrue(np.array_equal(c_mlx, c_npy))
|
||||
|
||||
def test_addmm(self):
    """Check mx.addmm against the NumPy reference alpha * (a @ b) + beta * c.

    Exercises broadcasting of ``c``, transposed inputs, matrix-vector
    products, and the split-K shape regime.
    """
    np.random.seed(0)
    # Scaling factors applied to the matmul result and the addend.
    alpha = 0.5
    beta = 2.0

    # Regular batched case
    a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
    b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)

    a_mlx = mx.array(a_npy)
    b_mlx = mx.array(b_npy)

    # c shapes chosen to broadcast against the (32, 128, 16) result.
    for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
        c_npy = np.ones(c_shape).astype(np.float32)
        c_mlx = mx.array(c_npy)

        d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
        d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

        self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
        self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

    # Batched and transposed matmul
    b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
    b_mlx = mx.array(b_npy)

    for c_shape in ((1,), (32, 1, 128), (1, 128)):
        c_npy = np.ones(c_shape).astype(np.float32)
        c_mlx = mx.array(c_npy)

        b_np_t = np.transpose(b_npy, (0, 2, 1))
        b_mx_t = mx.transpose(b_mlx, (0, 2, 1))

        d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
        d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)

        self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
        self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

    # Batched matmul with simple broadcast (2D b against batched a)
    a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
    b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)

    a_mlx = mx.array(a_npy)
    b_mlx = mx.array(b_npy)

    for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
        c_npy = np.ones(c_shape).astype(np.float32)
        c_mlx = mx.array(c_npy)

        d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
        d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

        self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
        self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

    # Matmul with vector
    a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
    b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
    a_mlx = mx.array(a_npy)
    b_mlx = mx.array(b_npy)

    for c_shape in ((1,), (32, 128)):
        c_npy = np.ones(c_shape).astype(np.float32)
        c_mlx = mx.array(c_npy)

        d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
        d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

        self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
        self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

    # Split-K specialization (small M/N with a large reduction dim K)
    a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
    b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)

    a_mlx = mx.array(a_npy)
    b_mlx = mx.array(b_npy)

    for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
        c_npy = np.ones(c_shape).astype(np.float32)
        c_mlx = mx.array(c_npy)

        d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
        d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

        self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
        self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
||||
|
||||
def test_addmm_grad(self):
    """Check that the VJP of mx.addmm matches the VJP of its unfused form.

    Builds both a reference closure ``alpha * (a @ b) + beta * c`` and the
    fused ``mx.addmm`` call, then compares primal outputs and cotangents
    for each of c, a, b over several batched shapes.
    """

    def make_ref_addmm(alpha, beta):
        # Unfused reference: differentiated by mx's autograd directly.
        return lambda c, a, b: alpha * (a @ b) + beta * c

    def make_addmm(alpha, beta):
        return lambda c, a, b: mx.addmm(c, a, b, alpha, beta)

    # B, M, N, K
    shapes = ((1, 64, 32, 128), (4, 28, 24, 47), (1, 1, 24, 47))

    alpha = 2.0
    beta = 0.5

    f_test = make_addmm(alpha, beta)
    f_ref = make_ref_addmm(alpha, beta)

    for B, M, N, K in shapes:
        cotan = mx.ones((B, M, N))
        c = mx.random.normal((B, M, N))
        a = mx.random.normal((B, M, K))
        b = mx.random.normal((B, K, N))

        out_ref, dout_ref = mx.vjp(
            f_ref,
            [c, a, b],
            [
                cotan,
            ],
        )
        out_test, dout_test = mx.vjp(
            f_test,
            [c, a, b],
            [
                cotan,
            ],
        )

        # Primal outputs must agree between fused and unfused forms.
        self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-5).item())

        # Cotangents w.r.t. c, a, b must agree in both shape and value.
        for r, t in zip(dout_ref, dout_test):
            self.assertListEqual(r.shape, t.shape)
            self.assertTrue(mx.allclose(r, t, atol=1e-5).item())
|
||||
|
Reference in New Issue
Block a user