From 380aeb58ae159f38b03050acc860a71a76a470a0 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 10 Oct 2025 09:50:54 -0700
Subject: [PATCH] enable admm low-precision cpu (#2661)

---
 mlx/backend/cpu/matmul.cpp | 4 ----
 python/tests/test_blas.py  | 9 +++++++++
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/mlx/backend/cpu/matmul.cpp b/mlx/backend/cpu/matmul.cpp
index 029e94aab..777b31c02 100644
--- a/mlx/backend/cpu/matmul.cpp
+++ b/mlx/backend/cpu/matmul.cpp
@@ -131,10 +131,6 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
 }
 
 void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
-  if (out.dtype() != float32) {
-    throw std::runtime_error(
-        "[AddMM::eval_cpu] Currently only supports float32.");
-  }
   if (out.size() == 0) {
     out.set_data(allocator::malloc(out.nbytes()));
     return;
diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py
index 67289ceef..82c63e3d8 100644
--- a/python/tests/test_blas.py
+++ b/python/tests/test_blas.py
@@ -712,6 +712,15 @@ class TestBlas(mlx_tests.MLXTestCase):
             expected = beta * c + alpha * (a @ b)
             self.assertTrue(mx.allclose(expected, out))
 
+        # Test half precision
+        for t, tol in [(mx.float16, 1e-3), (mx.bfloat16, 1e-2)]:
+            c = mx.ones((32, 32)).astype(t)
+            a = mx.random.uniform(shape=(32, 32)).astype(t)
+            b = mx.random.uniform(shape=(32, 32)).astype(t)
+            out = mx.addmm(c, a, b)
+            expected = a @ b + c
+            self.assertTrue(mx.allclose(out, expected, rtol=tol, atol=tol))
+
     def test_addmm_grad(self):
         def make_ref_addmm(alpha, beta):
             return lambda c, a, b: alpha * (a @ b) + beta * c
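
For reference, a minimal sketch of what this patch enables: with the float32-only guard removed from `AddMM::eval_cpu`, `mx.addmm` should now run in half precision on the CPU backend. The shapes and dtypes below are illustrative, mirroring the new test rather than prescribing exact values.

```python
import mlx.core as mx

# Pin the CPU backend, since this patch specifically lifts the
# float32-only restriction in AddMM::eval_cpu.
mx.set_default_device(mx.cpu)

for t in (mx.float16, mx.bfloat16):
    a = mx.random.uniform(shape=(32, 32)).astype(t)
    b = mx.random.uniform(shape=(32, 32)).astype(t)
    c = mx.ones((32, 32)).astype(t)

    # Before this patch, this raised
    # "[AddMM::eval_cpu] Currently only supports float32." on the CPU.
    # With default alpha=1.0 and beta=1.0, addmm computes a @ b + c.
    out = mx.addmm(c, a, b)
    print(out.dtype)  # mlx.core.float16 / mlx.core.bfloat16
```

Note the looser tolerance the new test uses for bfloat16 (1e-2 versus 1e-3 for float16), consistent with bfloat16's smaller mantissa.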