[CUDA] Fix alpha not respected when using bias epilogue (#2578)

Cheng
2025-09-10 09:08:01 +09:00
committed by GitHub
parent dde3682b69
commit 44cc5da4bc
6 changed files with 146 additions and 125 deletions
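In `mx.addmm`, the CUDA backend can hand the added array `c` to cuBLAS as a fused bias epilogue, and on that path the `alpha` scale was silently dropped. A minimal repro sketch of the fixed behavior (shapes are illustrative; judging by the new test coverage, `beta == 1.0` is the case that exercises the fused path):

```python
import mlx.core as mx

a = mx.random.normal((64, 32))
b = mx.random.normal((32, 16))
c = mx.random.normal((16,))  # 1-D bias, broadcast over the output rows

# Before this fix the bias-epilogue path ignored alpha, returning
# a @ b + c instead of alpha * (a @ b) + c.
out = mx.addmm(c, a, b, alpha=0.5, beta=1.0)
expected = 0.5 * (a @ b) + c
assert mx.allclose(out, expected, atol=1e-5).item()
```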

View File

@@ -248,11 +248,19 @@ void CublasGemm::run(
     const array& b,
     const Shape& batch_shape,
     const Strides& a_batch_strides,
-    const Strides& b_batch_strides) {
+    const Strides& b_batch_strides,
+    float alpha) {
   int batch_count = out.size() / (M_ * N_);
   if (batch_count / batch_shape.back() > 1) {
     run_batched(
-        encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
+        encoder,
+        out,
+        a,
+        b,
+        batch_shape,
+        a_batch_strides,
+        b_batch_strides,
+        alpha);
     return;
   }
@@ -260,7 +268,13 @@ void CublasGemm::run(
   encoder.set_input_array(b);
   encoder.set_output_array(out);
-  execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
+  execute(
+      encoder,
+      out.data<void>(),
+      a.data<void>(),
+      b.data<void>(),
+      nullptr,
+      alpha);
 }
 
 void CublasGemm::run(

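`CublasGemm::run` now threads `alpha` through both branches: the direct `execute` call and the `run_batched` fallback taken when the collapsed batch dimensions cannot be expressed as a single strided batch. A Python-level sketch of a broadcast case the new tests sweep (which branch it lands on depends on how the batch strides collapse):

```python
import mlx.core as mx

a = mx.random.normal((32, 128, 16))
b = mx.random.normal((16, 16))  # broadcast across all 32 batches
c = mx.ones((1, 16))

# alpha must survive regardless of which GEMM branch the backend picks.
out = mx.addmm(c, a, b, alpha=0.5, beta=1.0)
expected = 0.5 * (a @ b) + c
assert mx.allclose(out, expected, atol=1e-4).item()
```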
View File

@@ -64,7 +64,8 @@ class CublasGemm {
       const array& b,
       const Shape& batch_shape,
       const Strides& a_batch_strides,
-      const Strides& b_batch_strides);
+      const Strides& b_batch_strides,
+      float alpha = 1.0f);
 
   void run(
       cu::CommandEncoder& encoder,
@@ -87,7 +88,8 @@ class CublasGemm {
       const array& b,
       const Shape& batch_shape,
       const Strides& a_batch_strides,
-      const Strides& b_batch_strides);
+      const Strides& b_batch_strides,
+      float alpha);
 
   void run_batched(
       cu::CommandEncoder& encoder,

View File

@@ -13,7 +13,8 @@ void CublasGemm::run_batched(
     const array& b,
     const Shape& batch_shape,
     const Strides& a_batch_strides,
-    const Strides& b_batch_strides) {
+    const Strides& b_batch_strides,
+    float alpha) {
   encoder.set_input_array(a);
   encoder.set_input_array(b);
   encoder.set_output_array(out);
@@ -27,7 +28,8 @@ void CublasGemm::run_batched(
         out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
         a.data<int8_t>() + a.itemsize() * a_it.loc,
         b.data<int8_t>() + b.itemsize() * b_it.loc,
-        nullptr);
+        nullptr,
+        alpha);
     a_it.step();
     b_it.step();
   }

View File

@@ -154,7 +154,8 @@ void CublasGemm::run_batched(
     const array& b,
     const Shape& batch_shape,
     const Strides& a_batch_strides,
-    const Strides& b_batch_strides) {
+    const Strides& b_batch_strides,
+    float alpha) {
   int batch_count = out.size() / (M_ * N_);
   set_pointer_mode(a_desc_, batch_count);
   set_pointer_mode(b_desc_, batch_count);
@@ -226,7 +227,8 @@ void CublasGemm::run_batched(
       reinterpret_cast<void*>(out_pointers),
       reinterpret_cast<void*>(a_pointers),
      reinterpret_cast<void*>(b_pointers),
-      nullptr);
+      nullptr,
+      alpha);
 }
 
 void CublasGemm::run_batched(

View File

@@ -41,7 +41,8 @@ void gemm_and_bias(
     array& out,
     const array& a,
     const array& b,
-    void* bias = nullptr) {
+    void* bias = nullptr,
+    float alpha = 1.0f) {
   // Check and collapse batch dimensions
   auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
@@ -94,7 +95,8 @@ void gemm_and_bias(
   if (bias) {
     gemm.set_bias(bias);
   }
-  gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
+  gemm.run(
+      encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
 }
 
 } // namespace
@@ -169,7 +171,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
         out,
         a,
         b,
-        c.data<void>());
+        c.data<void>(),
+        alpha_);
     return;
   }

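`gemm_and_bias` defaults `alpha` to `1.0f`, so existing callers keep their behavior; only `AddMM::eval_gpu` forwards its `alpha_` when it routes `c` through the bias pointer. A usage-level sketch of the call pattern this enables (the exact condition for choosing the bias path is outside this excerpt):

```python
import mlx.core as mx

a = mx.random.normal((10, 5))
b = mx.random.normal((5, 5))
c = mx.random.normal((5,))

# addmm with a bias-like c: alpha_ is now forwarded into gemm_and_bias,
# so the product is scaled even when c is fused as an epilogue.
out = mx.addmm(c, a, b, alpha=2.0, beta=1.0)
assert mx.allclose(out, 2.0 * (a @ b) + c, atol=1e-5).item()
```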
View File

@@ -594,124 +594,123 @@ class TestBlas(mlx_tests.MLXTestCase):
         np.random.seed(0)
         # Batched matmul
         alpha = 0.5
-        beta = 2.0
-
-        # c must broadcast to the output shape
-        with self.assertRaises(ValueError):
-            mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
-
-        # Regular batched case
-        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
-        b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
-        a_mlx = mx.array(a_npy)
-        b_mlx = mx.array(b_npy)
-
-        for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
-            c_npy = np.ones(c_shape).astype(np.float32)
-            c_mlx = mx.array(c_npy)
-            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
-            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
-
-            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
-            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
-
-        # Batched and transposed matmul
-        b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
-        b_mlx = mx.array(b_npy)
-
-        for c_shape in ((1,), (32, 1, 128), (1, 128)):
-            c_npy = np.ones(c_shape).astype(np.float32)
-            c_mlx = mx.array(c_npy)
-            b_np_t = np.transpose(b_npy, (0, 2, 1))
-            b_mx_t = mx.transpose(b_mlx, (0, 2, 1))
-
-            d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
-            d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)
-
-            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
-            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
-
-        # Batched matmul with simple broadcast
-        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
-        b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
-        a_mlx = mx.array(a_npy)
-        b_mlx = mx.array(b_npy)
-
-        for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
-            c_npy = np.ones(c_shape).astype(np.float32)
-            c_mlx = mx.array(c_npy)
-            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
-            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
-
-            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
-            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
-
-        # Matmul with vector
-        a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
-        b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
-        a_mlx = mx.array(a_npy)
-        b_mlx = mx.array(b_npy)
-
-        for c_shape in ((1,), (128,), (32, 128)):
-            c_npy = np.ones(c_shape).astype(np.float32)
-            c_mlx = mx.array(c_npy)
-
-            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
-            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
-
-            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
-            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
-
-        # Matmul with vector
-        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
-        b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
-        a_mlx = mx.array(a_npy)
-        b_mlx = mx.array(b_npy)
-
-        for c_shape in ((1,), (32, 128)):
-            c_npy = np.ones(c_shape).astype(np.float32)
-            c_mlx = mx.array(c_npy)
-
-            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
-            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
-
-            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
-            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
-
-        # Split K specialization
-        a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
-        b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)
-        a_mlx = mx.array(a_npy)
-        b_mlx = mx.array(b_npy)
-
-        for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
-            c_npy = np.ones(c_shape).astype(np.float32)
-            c_mlx = mx.array(c_npy)
-            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
-            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
-
-            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
-            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
-
-        # Transposed c
-        a = mx.ones((10, 5)).T
-        b = mx.ones((5, 5))
-        out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
-        expected = 1.5 * a + 0.5 * (b @ a)
-        self.assertTrue(mx.allclose(expected, out))
-
-        # Broadcast c
-        a = mx.ones((5, 5))
-        b = mx.ones((5, 5))
-        c = mx.ones((1, 5))
-        out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
-        expected = 1.5 * c + 0.5 * (a @ b)
-        self.assertTrue(mx.allclose(expected, out))
+        for beta in (1.0, 2.0):
+            # c must broadcast to the output shape
+            with self.assertRaises(ValueError):
+                mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
+
+            # Regular batched case
+            a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
+            b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
+            a_mlx = mx.array(a_npy)
+            b_mlx = mx.array(b_npy)
+
+            for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
+                c_npy = np.ones(c_shape).astype(np.float32)
+                c_mlx = mx.array(c_npy)
+                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
+                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
+
+                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+
+            # Batched and transposed matmul
+            b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
+            b_mlx = mx.array(b_npy)
+
+            for c_shape in ((1,), (32, 1, 128), (1, 128)):
+                c_npy = np.ones(c_shape).astype(np.float32)
+                c_mlx = mx.array(c_npy)
+                b_np_t = np.transpose(b_npy, (0, 2, 1))
+                b_mx_t = mx.transpose(b_mlx, (0, 2, 1))
+
+                d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
+                d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)
+
+                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+
+            # Batched matmul with simple broadcast
+            a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
+            b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
+            a_mlx = mx.array(a_npy)
+            b_mlx = mx.array(b_npy)
+
+            for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
+                c_npy = np.ones(c_shape).astype(np.float32)
+                c_mlx = mx.array(c_npy)
+                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
+                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
+
+                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+
+            # Matmul with vector
+            a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
+            b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
+            a_mlx = mx.array(a_npy)
+            b_mlx = mx.array(b_npy)
+
+            for c_shape in ((1,), (128,), (32, 128)):
+                c_npy = np.ones(c_shape).astype(np.float32)
+                c_mlx = mx.array(c_npy)
+
+                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
+                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
+
+                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+
+            # Matmul with vector
+            a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
+            b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
+            a_mlx = mx.array(a_npy)
+            b_mlx = mx.array(b_npy)
+
+            for c_shape in ((1,), (32, 128)):
+                c_npy = np.ones(c_shape).astype(np.float32)
+                c_mlx = mx.array(c_npy)
+
+                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
+                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
+
+                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+
+            # Split K specialization
+            a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
+            b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)
+            a_mlx = mx.array(a_npy)
+            b_mlx = mx.array(b_npy)
+
+            for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
+                c_npy = np.ones(c_shape).astype(np.float32)
+                c_mlx = mx.array(c_npy)
+                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
+                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
+
+                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+
+            # Transposed c
+            a = mx.ones((10, 5)).T
+            b = mx.ones((5, 5))
+            out = mx.addmm(a, b, a, beta=beta, alpha=alpha)
+            expected = beta * a + alpha * (b @ a)
+            self.assertTrue(mx.allclose(expected, out))
+
+            # Broadcast c
+            a = mx.ones((5, 5))
+            b = mx.ones((5, 5))
+            c = mx.ones((1, 5))
+            out = mx.addmm(c, a, b, beta=beta, alpha=alpha)
+            expected = beta * c + alpha * (a @ b)
+            self.assertTrue(mx.allclose(expected, out))
 
     def test_addmm_grad(self):
         def make_ref_addmm(alpha, beta):
@@ -724,33 +723,32 @@ class TestBlas(mlx_tests.MLXTestCase):
         shapes = ((1, 64, 32, 128), (4, 28, 24, 47), (1, 1, 24, 47))
         alpha = 2.0
-        beta = 0.5
-
-        f_test = make_addmm(alpha, beta)
-        f_ref = make_ref_addmm(alpha, beta)
-
-        for B, M, N, K in shapes:
-            cotan = mx.ones((B, M, N))
-            c = mx.random.normal((B, M, N))
-            a = mx.random.normal((B, M, K))
-            b = mx.random.normal((B, K, N))
-
-            out_ref, dout_ref = mx.vjp(
-                f_ref,
-                [c, a, b],
-                [cotan],
-            )
-            out_test, dout_test = mx.vjp(
-                f_test,
-                [c, a, b],
-                [cotan],
-            )
-
-            self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
-
-            for r, t in zip(dout_ref, dout_test):
-                self.assertEqual(r.shape, t.shape)
-                self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
+        for beta in (1.0, 0.5):
+            f_test = make_addmm(alpha, beta)
+            f_ref = make_ref_addmm(alpha, beta)
+
+            for B, M, N, K in shapes:
+                cotan = mx.ones((B, M, N))
+                c = mx.random.normal((B, M, N))
+                a = mx.random.normal((B, M, K))
+                b = mx.random.normal((B, K, N))
+
+                out_ref, dout_ref = mx.vjp(
+                    f_ref,
+                    [c, a, b],
+                    [cotan],
+                )
+                out_test, dout_test = mx.vjp(
+                    f_test,
+                    [c, a, b],
+                    [cotan],
+                )
+
+                self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
+
+                for r, t in zip(dout_ref, dout_test):
+                    self.assertEqual(r.shape, t.shape)
+                    self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
 
     def test_empty_matmul(self):
         a = mx.array([[], []]).T