diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index b70f61e3d..9ee470888 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -119,7 +119,6 @@ class MatMul {
       uint64_t b_rows,
       uint64_t b_cols,
       int64_t ldb,
-      bool c_transposed,
       int64_t ldc,
       int32_t batch_count,
       int64_t a_batch_stride,
@@ -141,7 +140,7 @@ class MatMul {
             b_batch_stride) {
     auto type = dtype_to_cuda_type(dtype);
     c_desc_ = create_matrix_layout(
-        type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
+        type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
   }
 
   ~MatMul() {
@@ -404,9 +403,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 3);
   auto& a_pre = inputs[0];
   auto& b_pre = inputs[1];
-  auto& c_pre = inputs[2];
-
-  out.set_data(allocator::malloc(out.nbytes()));
+  auto c = inputs[2];
 
   /////////////////////////////////////////////////////////////////////////////
   // Init checks and prep
@@ -419,7 +416,24 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // the arrays
   auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
   auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
-  auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre);
+
+  int64_t ldc;
+  {
+    auto stx = c.strides()[c.ndim() - 2];
+    auto sty = c.strides()[c.ndim() - 1];
+    if (sty == 1 && stx == c.shape(-1)) {
+      ldc = stx;
+      out.set_data(allocator::malloc(out.nbytes()));
+    } else if (sty == 1 && stx == 0) {
+      ldc = 0;
+      out.set_data(allocator::malloc(out.nbytes()));
+    } else {
+      // Copy C into out and set C to out
+      ldc = c.shape(-1);
+      copy_gpu(c, out, CopyType::General, s);
+      c = out;
+    }
+  }
 
   /////////////////////////////////////////////////////////////////////////////
   // Check and collapse batch dimensions
@@ -457,7 +471,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       K,
       N,
       ldb,
-      c_transposed,
       ldc,
       batch_shape.back(),
       a_batch_strides.back(),
diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py
index 5e096d9c5..7cc39f06a 100644
--- a/python/tests/test_blas.py
+++ b/python/tests/test_blas.py
@@ -691,6 +691,21 @@ class TestBlas(mlx_tests.MLXTestCase):
         self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
         self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
 
+        # Transposed c
+        a = mx.ones((10, 5)).T
+        b = mx.ones((5, 5))
+        out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
+        expected = 1.5 * a + 0.5 * (b @ a)
+        self.assertTrue(mx.allclose(expected, out))
+
+        # Broadcast c
+        a = mx.ones((5, 5))
+        b = mx.ones((5, 5))
+        c = mx.ones((1, 5))
+        out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
+        expected = 1.5 * c + 0.5 * (a @ b)
+        self.assertTrue(mx.allclose(expected, out))
+
     def test_addmm_grad(self):
         def make_ref_addmm(alpha, beta):
             return lambda c, a, b: alpha * (a @ b) + beta * c