[CUDA] Set bias as input when using bias epilogue (#2584)
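The change takes effect in the hunks below: `CublasGemm::set_bias` now receives the command encoder together with the bias `array`, registers it via `encoder.set_input_array(bias)`, and hands cuBLASLt the device pointer obtained from `bias.data<void>()`; `gemm_and_bias` accepts a `std::optional<array>` bias instead of a raw `void*`, and `AddMM::eval_gpu` forwards the array `c` directly. For context, here is a minimal sketch of the cuBLASLt bias-epilogue setup this path relies on. It is illustrative only: the function name, error handling, and the `d_bias` argument are assumptions, not MLX code.

// Illustrative sketch, not MLX code: enable the fused bias epilogue on a
// cuBLASLt matmul descriptor. `d_bias` is assumed to be a device pointer to a
// bias vector that cuBLASLt broadcasts over the GEMM output.
#include <cublasLt.h>
#include <cstdio>

bool enable_bias_epilogue(cublasLtMatmulDesc_t desc, const void* d_bias) {
  // Request "GEMM result + bias" as the epilogue.
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
  if (cublasLtMatmulDescSetAttribute(
          desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)) !=
      CUBLAS_STATUS_SUCCESS) {
    std::fprintf(stderr, "setting epilogue failed\n");
    return false;
  }
  // The attribute stores the device pointer itself, so pass the address of the
  // pointer variable -- the same `&bias_ptr` pattern used in the diff.
  if (cublasLtMatmulDescSetAttribute(
          desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &d_bias, sizeof(d_bias)) !=
      CUBLAS_STATUS_SUCCESS) {
    std::fprintf(stderr, "setting bias pointer failed\n");
    return false;
  }
  return true;
}

Registering the bias with the encoder as an input, as the diff does, presumably lets the scheduler track the bias buffer as a dependency of the matmul rather than treating it as an untracked raw pointer.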
@@ -230,15 +230,20 @@ void CublasGemm::set_out(
       batch_stride);
 }

-void CublasGemm::set_bias(void* bias) {
+void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
+  encoder.set_input_array(bias);
   cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
       matmul_desc_,
       CUBLASLT_MATMUL_DESC_EPILOGUE,
       &epilogue,
       sizeof(epilogue)));
+  auto* bias_ptr = bias.data<void>();
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
-      matmul_desc_, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));
+      matmul_desc_,
+      CUBLASLT_MATMUL_DESC_BIAS_POINTER,
+      &bias_ptr,
+      sizeof(bias_ptr)));
 }

 void CublasGemm::run(

@@ -55,7 +55,7 @@ class CublasGemm {
       int32_t batch_count,
       int64_t batch_stride);

-  void set_bias(void* bias);
+  void set_bias(cu::CommandEncoder& encoder, const array& bias);

   void run(
       cu::CommandEncoder& encoder,

@@ -41,7 +41,7 @@ void gemm_and_bias(
     array& out,
     const array& a,
     const array& b,
-    void* bias = nullptr,
+    const std::optional<array>& bias = std::nullopt,
     float alpha = 1.0f) {
   // Check and collapse batch dimensions
   auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);

@@ -93,7 +93,7 @@ void gemm_and_bias(
       a_batch_strides.back(),
       b_batch_strides.back());
   if (bias) {
-    gemm.set_bias(bias);
+    gemm.set_bias(encoder, *bias);
   }
   gemm.run(
       encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);

@@ -171,7 +171,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
         out,
         a,
         b,
-        c.data<void>(),
+        c,
         alpha_);
     return;
   }

@@ -702,7 +702,7 @@ class TestBlas(mlx_tests.MLXTestCase):
         b = mx.ones((5, 5))
         out = mx.addmm(a, b, a, beta=beta, alpha=alpha)
         expected = beta * a + alpha * (b @ a)
-        self.assertTrue(mx.allclose(expected, out, atol=1e-5))
+        self.assertTrue(mx.allclose(expected, out))

         # Broadcast c
         a = mx.ones((5, 5))

@@ -710,7 +710,7 @@ class TestBlas(mlx_tests.MLXTestCase):
         c = mx.ones((1, 5))
         out = mx.addmm(c, a, b, beta=beta, alpha=alpha)
         expected = beta * c + alpha * (a @ b)
-        self.assertTrue(mx.allclose(expected, out, atol=1e-5))
+        self.assertTrue(mx.allclose(expected, out))

     def test_addmm_grad(self):
         def make_ref_addmm(alpha, beta):