fix addmm
@@ -119,7 +119,6 @@ class MatMul {
       uint64_t b_rows,
       uint64_t b_cols,
       int64_t ldb,
-      bool c_transposed,
       int64_t ldc,
       int32_t batch_count,
       int64_t a_batch_stride,
@@ -141,7 +140,7 @@ class MatMul {
             b_batch_stride) {
     auto type = dtype_to_cuda_type(dtype);
     c_desc_ = create_matrix_layout(
-        type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
+        type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
   }
 
   ~MatMul() {
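Note: with `c_transposed` gone, the output descriptor is always created non-transposed; `eval_gpu` below now guarantees C reaches `MatMul` either row-major contiguous, row-broadcast, or copied into `out`. A small NumPy illustration (not MLX code) of the leading-dimension convention this relies on — NumPy strides are in bytes, hence the itemsize division:

    import numpy as np

    # Row-major (M, N): the leading dimension is the row stride, i.e. N.
    c = np.ones((4, 5), dtype=np.float32)
    print(c.strides[0] // c.itemsize)  # 5 == N -> ldc = 5

    # Rows broadcast from a single row have stride 0 -> ldc = 0.
    cb = np.broadcast_to(np.ones((1, 5), dtype=np.float32), (4, 5))
    print(cb.strides[0] // cb.itemsize)  # 0 -> ldc = 0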
@@ -404,9 +403,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 3);
   auto& a_pre = inputs[0];
   auto& b_pre = inputs[1];
-  auto& c_pre = inputs[2];
-
-  out.set_data(allocator::malloc(out.nbytes()));
+  auto c = inputs[2];
 
   /////////////////////////////////////////////////////////////////////////////
   // Init checks and prep
@@ -419,7 +416,24 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // the arrays
   auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
   auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
-  auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre);
+
+  int64_t ldc;
+  {
+    auto stx = c.strides()[c.ndim() - 2];
+    auto sty = c.strides()[c.ndim() - 1];
+    if (sty == 1 && stx == c.shape(-1)) {
+      ldc = stx;
+      out.set_data(allocator::malloc(out.nbytes()));
+    } else if (sty == 1 && stx == 0) {
+      ldc = 0;
+      out.set_data(allocator::malloc(out.nbytes()));
+    } else {
+      // Copy C into out and set C to out
+      ldc = c.shape(-1);
+      copy_gpu(c, out, CopyType::General, s);
+      c = out;
+    }
+  }
 
   /////////////////////////////////////////////////////////////////////////////
   // Check and collapse batch dimensions
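Note: the block above replaces the old `check_transpose` on C with an explicit three-way layout check, and defers allocating `out` until the layout is known. A rough NumPy rendering of that case analysis (hypothetical helper, not MLX API):

    import numpy as np

    def resolve_c_layout(c):
        # Returns (ldc, needs_copy), mirroring the three branches above.
        stx = c.strides[-2] // c.itemsize  # row stride, in elements
        sty = c.strides[-1] // c.itemsize  # column stride, in elements
        if sty == 1 and stx == c.shape[-1]:
            return stx, False      # contiguous row-major: use the row stride
        if sty == 1 and stx == 0:
            return 0, False        # rows are broadcast: ldc = 0
        return c.shape[-1], True   # otherwise copy C into out first

    print(resolve_c_layout(np.ones((4, 5), dtype=np.float32)))  # (5, False)

In the copy branch, C is materialized directly into `out` via `copy_gpu` and then aliased as C, so the GEMM epilogue reads and writes the same buffer.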
@@ -457,7 +471,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       K,
       N,
       ldb,
-      c_transposed,
       ldc,
       batch_shape.back(),
       a_batch_strides.back(),
@@ -691,6 +691,21 @@ class TestBlas(mlx_tests.MLXTestCase):
         self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
         self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
 
+        # Transposed c
+        a = mx.ones((10, 5)).T
+        b = mx.ones((5, 5))
+        out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
+        expected = 1.5 * a + 0.5 * (b @ a)
+        self.assertTrue(mx.allclose(expected, out))
+
+        # Broadcast c
+        a = mx.ones((5, 5))
+        b = mx.ones((5, 5))
+        c = mx.ones((1, 5))
+        out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
+        expected = 1.5 * c + 0.5 * (a @ b)
+        self.assertTrue(mx.allclose(expected, out))
+
     def test_addmm_grad(self):
         def make_ref_addmm(alpha, beta):
             return lambda c, a, b: alpha * (a @ b) + beta * c
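The new tests exercise exactly the two C layouts the commit fixes. A minimal standalone repro of the fixed behavior (assuming an MLX build with this backend; addmm evaluates alpha * (a @ b) + beta * c):

    import mlx.core as mx

    # Transposed (non-contiguous) c
    a = mx.ones((10, 5)).T
    b = mx.ones((5, 5))
    out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
    assert mx.allclose(out, 1.5 * a + 0.5 * (b @ a))

    # Broadcast c: a (1, 5) bias row added to a (5, 5) product
    c = mx.ones((1, 5))
    out = mx.addmm(c, mx.ones((5, 5)), mx.ones((5, 5)), beta=1.5, alpha=0.5)
    assert mx.allclose(out, 1.5 * c + 0.5 * (mx.ones((5, 5)) @ mx.ones((5, 5))))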