Compare commits

...

3 Commits

Author SHA1 Message Date
Awni Hannun
ee18e1cbf0 patch bump (#2588) 2025-09-11 17:10:09 -07:00
Awni Hannun
af120c2bc0 set nccl ABI version (#2587) 2025-09-11 16:55:53 -07:00
Cheng
6a3acf2301 [CUDA] Set bias as input when using bias epilogue (#2584) 2025-09-11 15:31:09 +09:00
6 changed files with 15 additions and 9 deletions

View File

@@ -230,15 +230,20 @@ void CublasGemm::set_out(
       batch_stride);
 }
 
-void CublasGemm::set_bias(void* bias) {
+void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
+  encoder.set_input_array(bias);
   cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
       matmul_desc_,
       CUBLASLT_MATMUL_DESC_EPILOGUE,
       &epilogue,
       sizeof(epilogue)));
+  auto* bias_ptr = bias.data<void>();
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
-      matmul_desc_, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));
+      matmul_desc_,
+      CUBLASLT_MATMUL_DESC_BIAS_POINTER,
+      &bias_ptr,
+      sizeof(bias_ptr)));
 }
 
 void CublasGemm::run(

View File

@@ -55,7 +55,7 @@ class CublasGemm {
       int32_t batch_count,
       int64_t batch_stride);
 
-  void set_bias(void* bias);
+  void set_bias(cu::CommandEncoder& encoder, const array& bias);
 
   void run(
       cu::CommandEncoder& encoder,

View File

@@ -41,7 +41,7 @@ void gemm_and_bias(
     array& out,
     const array& a,
     const array& b,
-    void* bias = nullptr,
+    const std::optional<array>& bias = std::nullopt,
     float alpha = 1.0f) {
   // Check and collapse batch dimensions
   auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
@@ -93,7 +93,7 @@ void gemm_and_bias(
       a_batch_strides.back(),
       b_batch_strides.back());
   if (bias) {
-    gemm.set_bias(bias);
+    gemm.set_bias(encoder, *bias);
   }
   gemm.run(
       encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
@@ -171,7 +171,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
         out,
         a,
         b,
-        c.data<void>(),
+        c,
         alpha_);
     return;
   }

View File

@@ -8,6 +8,7 @@ file(
   "${CMAKE_CURRENT_BINARY_DIR}/nccl.h")
 
 add_library(nccl SHARED nccl_stubs.cpp)
+set_target_properties(nccl PROPERTIES SOVERSION 2)
 find_package(CUDAToolkit REQUIRED)
 target_include_directories(nccl PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
 target_include_directories(nccl PRIVATE ${CMAKE_CURRENT_BINARY_DIR})

View File

@@ -4,7 +4,7 @@
 #define MLX_VERSION_MAJOR 0
 #define MLX_VERSION_MINOR 29
-#define MLX_VERSION_PATCH 0
+#define MLX_VERSION_PATCH 1
 #define MLX_VERSION_NUMERIC \
   (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

View File

@@ -702,7 +702,7 @@ class TestBlas(mlx_tests.MLXTestCase):
             b = mx.ones((5, 5))
             out = mx.addmm(a, b, a, beta=beta, alpha=alpha)
             expected = beta * a + alpha * (b @ a)
-            self.assertTrue(mx.allclose(expected, out, atol=1e-5))
+            self.assertTrue(mx.allclose(expected, out))
 
             # Broadcast c
             a = mx.ones((5, 5))
@@ -710,7 +710,7 @@ class TestBlas(mlx_tests.MLXTestCase):
             c = mx.ones((1, 5))
             out = mx.addmm(c, a, b, beta=beta, alpha=alpha)
             expected = beta * c + alpha * (a @ b)
-            self.assertTrue(mx.allclose(expected, out, atol=1e-5))
+            self.assertTrue(mx.allclose(expected, out))
 
     def test_addmm_grad(self):
         def make_ref_addmm(alpha, beta):