mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
3 Commits
d6977f2a57
...
v0.29.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ee18e1cbf0 | ||
|
|
af120c2bc0 | ||
|
|
6a3acf2301 |
@@ -230,15 +230,20 @@ void CublasGemm::set_out(
|
|||||||
batch_stride);
|
batch_stride);
|
||||||
}
|
}
|
||||||
|
|
||||||
void CublasGemm::set_bias(void* bias) {
|
void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
|
||||||
|
encoder.set_input_array(bias);
|
||||||
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
|
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||||
matmul_desc_,
|
matmul_desc_,
|
||||||
CUBLASLT_MATMUL_DESC_EPILOGUE,
|
CUBLASLT_MATMUL_DESC_EPILOGUE,
|
||||||
&epilogue,
|
&epilogue,
|
||||||
sizeof(epilogue)));
|
sizeof(epilogue)));
|
||||||
|
auto* bias_ptr = bias.data<void>();
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||||
matmul_desc_, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));
|
matmul_desc_,
|
||||||
|
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
|
||||||
|
&bias_ptr,
|
||||||
|
sizeof(bias_ptr)));
|
||||||
}
|
}
|
||||||
|
|
||||||
void CublasGemm::run(
|
void CublasGemm::run(
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class CublasGemm {
|
|||||||
int32_t batch_count,
|
int32_t batch_count,
|
||||||
int64_t batch_stride);
|
int64_t batch_stride);
|
||||||
|
|
||||||
void set_bias(void* bias);
|
void set_bias(cu::CommandEncoder& encoder, const array& bias);
|
||||||
|
|
||||||
void run(
|
void run(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ void gemm_and_bias(
|
|||||||
array& out,
|
array& out,
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
void* bias = nullptr,
|
const std::optional<array>& bias = std::nullopt,
|
||||||
float alpha = 1.0f) {
|
float alpha = 1.0f) {
|
||||||
// Check and collapse batch dimensions
|
// Check and collapse batch dimensions
|
||||||
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
|
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
|
||||||
@@ -93,7 +93,7 @@ void gemm_and_bias(
|
|||||||
a_batch_strides.back(),
|
a_batch_strides.back(),
|
||||||
b_batch_strides.back());
|
b_batch_strides.back());
|
||||||
if (bias) {
|
if (bias) {
|
||||||
gemm.set_bias(bias);
|
gemm.set_bias(encoder, *bias);
|
||||||
}
|
}
|
||||||
gemm.run(
|
gemm.run(
|
||||||
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
|
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
|
||||||
@@ -171,7 +171,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
out,
|
out,
|
||||||
a,
|
a,
|
||||||
b,
|
b,
|
||||||
c.data<void>(),
|
c,
|
||||||
alpha_);
|
alpha_);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ file(
|
|||||||
"${CMAKE_CURRENT_BINARY_DIR}/nccl.h")
|
"${CMAKE_CURRENT_BINARY_DIR}/nccl.h")
|
||||||
|
|
||||||
add_library(nccl SHARED nccl_stubs.cpp)
|
add_library(nccl SHARED nccl_stubs.cpp)
|
||||||
|
set_target_properties(nccl PROPERTIES SOVERSION 2)
|
||||||
find_package(CUDAToolkit REQUIRED)
|
find_package(CUDAToolkit REQUIRED)
|
||||||
target_include_directories(nccl PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
|
target_include_directories(nccl PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
|
||||||
target_include_directories(nccl PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
|
target_include_directories(nccl PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
#define MLX_VERSION_MAJOR 0
|
#define MLX_VERSION_MAJOR 0
|
||||||
#define MLX_VERSION_MINOR 29
|
#define MLX_VERSION_MINOR 29
|
||||||
#define MLX_VERSION_PATCH 0
|
#define MLX_VERSION_PATCH 1
|
||||||
#define MLX_VERSION_NUMERIC \
|
#define MLX_VERSION_NUMERIC \
|
||||||
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
|
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
|
||||||
|
|
||||||
|
|||||||
@@ -702,7 +702,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
b = mx.ones((5, 5))
|
b = mx.ones((5, 5))
|
||||||
out = mx.addmm(a, b, a, beta=beta, alpha=alpha)
|
out = mx.addmm(a, b, a, beta=beta, alpha=alpha)
|
||||||
expected = beta * a + alpha * (b @ a)
|
expected = beta * a + alpha * (b @ a)
|
||||||
self.assertTrue(mx.allclose(expected, out, atol=1e-5))
|
self.assertTrue(mx.allclose(expected, out))
|
||||||
|
|
||||||
# Broadcast c
|
# Broadcast c
|
||||||
a = mx.ones((5, 5))
|
a = mx.ones((5, 5))
|
||||||
@@ -710,7 +710,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
c = mx.ones((1, 5))
|
c = mx.ones((1, 5))
|
||||||
out = mx.addmm(c, a, b, beta=beta, alpha=alpha)
|
out = mx.addmm(c, a, b, beta=beta, alpha=alpha)
|
||||||
expected = beta * c + alpha * (a @ b)
|
expected = beta * c + alpha * (a @ b)
|
||||||
self.assertTrue(mx.allclose(expected, out, atol=1e-5))
|
self.assertTrue(mx.allclose(expected, out))
|
||||||
|
|
||||||
def test_addmm_grad(self):
|
def test_addmm_grad(self):
|
||||||
def make_ref_addmm(alpha, beta):
|
def make_ref_addmm(alpha, beta):
|
||||||
|
|||||||
Reference in New Issue
Block a user