diff --git a/mlx/backend/cpu/matmul.cpp b/mlx/backend/cpu/matmul.cpp
index 029e94aab..777b31c02 100644
--- a/mlx/backend/cpu/matmul.cpp
+++ b/mlx/backend/cpu/matmul.cpp
@@ -131,10 +131,6 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
 }
 
 void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
-  if (out.dtype() != float32) {
-    throw std::runtime_error(
-        "[AddMM::eval_cpu] Currently only supports float32.");
-  }
   if (out.size() == 0) {
     out.set_data(allocator::malloc(out.nbytes()));
     return;
diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py
index 67289ceef..82c63e3d8 100644
--- a/python/tests/test_blas.py
+++ b/python/tests/test_blas.py
@@ -712,6 +712,15 @@ class TestBlas(mlx_tests.MLXTestCase):
         expected = beta * c + alpha * (a @ b)
         self.assertTrue(mx.allclose(expected, out))
 
+        # Test half precision
+        for t, tol in [(mx.float16, 1e-3), (mx.bfloat16, 1e-2)]:
+            c = mx.ones((32, 32)).astype(t)
+            a = mx.random.uniform(shape=(32, 32)).astype(t)
+            b = mx.random.uniform(shape=(32, 32)).astype(t)
+            out = mx.addmm(c, a, b)
+            expected = a @ b + c
+            self.assertTrue(mx.allclose(out, expected, rtol=tol, atol=tol))
+
     def test_addmm_grad(self):
         def make_ref_addmm(alpha, beta):
             return lambda c, a, b: alpha * (a @ b) + beta * c
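
Removing the float32 guard means AddMM presumably falls through to the same dtype handling the surrounding CPU matmul code already provides (not shown in this diff), so mx.addmm should now accept half-precision inputs on the CPU backend. A minimal usage sketch, assuming the alpha/beta scaling from the existing mx.addmm signature and the looser tolerance used in the test above (float16 accumulates more rounding error than float32):

    import mlx.core as mx

    a = mx.random.uniform(shape=(32, 32)).astype(mx.float16)
    b = mx.random.uniform(shape=(32, 32)).astype(mx.float16)
    c = mx.ones((32, 32)).astype(mx.float16)

    # Before this change, the CPU backend raised:
    #   "[AddMM::eval_cpu] Currently only supports float32."
    # addmm computes beta * c + alpha * (a @ b) in one fused call.
    out = mx.addmm(c, a, b, alpha=2.0, beta=0.5)
    expected = 2.0 * (a @ b) + 0.5 * c
    assert mx.allclose(out, expected, rtol=1e-3, atol=1e-3)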