fix addmm

Awni Hannun
2025-07-18 08:41:22 -07:00
parent 508bd25e29
commit 8435c047e1
2 changed files with 35 additions and 7 deletions
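
A quick illustration of the two cases this fix covers, mirroring the tests added below: calling addmm with a transposed (non-contiguous) c and with a broadcast c. Shapes and coefficients are taken directly from the new tests.

    # Sketch based on the new tests below.
    import mlx.core as mx

    # Transposed (non-contiguous) c
    a = mx.ones((10, 5)).T
    b = mx.ones((5, 5))
    out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
    assert mx.allclose(1.5 * a + 0.5 * (b @ a), out)

    # Broadcast c
    c = mx.ones((1, 5))
    a = mx.ones((5, 5))
    out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
    assert mx.allclose(1.5 * c + 0.5 * (a @ b), out)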

View File

@@ -119,7 +119,6 @@ class MatMul {
       uint64_t b_rows,
       uint64_t b_cols,
       int64_t ldb,
-      bool c_transposed,
       int64_t ldc,
       int32_t batch_count,
       int64_t a_batch_stride,
@@ -141,7 +140,7 @@ class MatMul {
           b_batch_stride) {
     auto type = dtype_to_cuda_type(dtype);
     c_desc_ = create_matrix_layout(
-        type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
+        type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
   }
 
   ~MatMul() {
@@ -404,9 +403,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 3);
   auto& a_pre = inputs[0];
   auto& b_pre = inputs[1];
-  auto& c_pre = inputs[2];
-
-  out.set_data(allocator::malloc(out.nbytes()));
+  auto c = inputs[2];
 
   /////////////////////////////////////////////////////////////////////////////
   // Init checks and prep
@@ -419,7 +416,24 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // the arrays
   auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
   auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
-  auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre);
+
+  int64_t ldc;
+  {
+    auto stx = c.strides()[c.ndim() - 2];
+    auto sty = c.strides()[c.ndim() - 1];
+    if (sty == 1 && stx == c.shape(-1)) {
+      ldc = stx;
+      out.set_data(allocator::malloc(out.nbytes()));
+    } else if (sty == 1 && stx == 0) {
+      ldc = 0;
+      out.set_data(allocator::malloc(out.nbytes()));
+    } else {
+      // Copy C into out and set C to out
+      ldc = c.shape(-1);
+      copy_gpu(c, out, CopyType::General, s);
+      c = out;
+    }
+  }
 
   /////////////////////////////////////////////////////////////////////////////
   // Check and collapse batch dimensions
@@ -457,7 +471,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       K,
       N,
       ldb,
-      c_transposed,
       ldc,
       batch_shape.back(),
       a_batch_strides.back(),
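
The new ldc selection above keeps c in place only when it is row-major contiguous (ldc equals the row stride) or a broadcast row (row stride 0); any other layout is first copied into out, which is then used as c. A small NumPy sketch of the same stride check, for illustration only (not MLX internals; NumPy strides are in bytes, so they are divided by the itemsize to get element strides like MLX's):

    # Illustration only: reproduce the element-stride conditions with NumPy.
    import numpy as np

    def pick_ldc(c):
        # Element strides of the last two axes (NumPy strides are in bytes).
        stx, sty = (s // c.itemsize for s in c.strides[-2:])
        if sty == 1 and stx == c.shape[-1]:
            return stx   # row-major contiguous: use c in place
        if sty == 1 and stx == 0:
            return 0     # broadcast row: every output row reads the same c row
        return None      # anything else: copy c into the output buffer first

    pick_ldc(np.ones((10, 5)))                            # 5
    pick_ldc(np.broadcast_to(np.ones((1, 5)), (10, 5)))   # 0
    pick_ldc(np.ones((5, 10)).T)                          # None -> copy path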

View File

@@ -691,6 +691,21 @@ class TestBlas(mlx_tests.MLXTestCase):
         self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
         self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+
+        # Transposed c
+        a = mx.ones((10, 5)).T
+        b = mx.ones((5, 5))
+        out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
+        expected = 1.5 * a + 0.5 * (b @ a)
+        self.assertTrue(mx.allclose(expected, out))
+
+        # Broadcast c
+        a = mx.ones((5, 5))
+        b = mx.ones((5, 5))
+        c = mx.ones((1, 5))
+        out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
+        expected = 1.5 * c + 0.5 * (a @ b)
+        self.assertTrue(mx.allclose(expected, out))
 
     def test_addmm_grad(self):
         def make_ref_addmm(alpha, beta):
             return lambda c, a, b: alpha * (a @ b) + beta * c