Compare commits

..

3 Commits

Author:  acsweet
SHA1:    323cc645ab
Message: Merge 992eac905a into 5adf185f86
Date:    2025-06-21 11:11:23 +02:00

Author:  Angelos Katharopoulos
SHA1:    5adf185f86
Message: Fix update_modules() when providing a subset (#2308)
Date:    2025-06-20 17:19:46 -07:00

Author:  Awni Hannun
SHA1:    c9a9180584
Message: Cuda perf tuning (#2307)

    * perf tuning
    * fix adding inputs arrays in matmul / srot
    * format
    * fix

Date:    2025-06-20 14:50:57 -07:00
7 changed files with 69 additions and 25 deletions

View File

@@ -3,6 +3,7 @@
 #include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/utils.h"
 #include "mlx/backend/cuda/worker.h"
+#include "mlx/utils.h"

 #include <cuda_runtime.h>
 #include <fmt/format.h>
@@ -14,9 +15,11 @@ namespace mlx::core {

 namespace cu {

+constexpr int page_size = 16384;
+
 CudaAllocator::CudaAllocator()
     : buffer_cache_(
-          getpagesize(),
+          page_size,
           [](CudaBuffer* buf) { return buf->size; },
           [this](CudaBuffer* buf) {
             cuda_free(buf->data);
@@ -31,7 +34,14 @@ CudaAllocator::CudaAllocator()

 Buffer CudaAllocator::malloc(size_t size) {
   // Find available buffer from cache.
+  auto orig_size = size;
   std::unique_lock lock(mutex_);
+  if (size < page_size) {
+    size = next_power_of_2(size);
+  } else {
+    size = page_size * ((size + page_size - 1) / page_size);
+  }
+
   CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
   if (!buf) {
     // If we have a lot of memory pressure or are over the maximum cache size,

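The allocator now ignores the host page size and buckets every request before consulting the buffer cache: sizes below the fixed 16 KiB page_size round up to the next power of two, larger ones to a whole number of pages, so requests of similar size land on reusable cached buffers. Below is a self-contained sketch of that rounding policy, assuming next_power_of_2 (presumably supplied by the newly included mlx/utils.h) rounds up to the nearest power of two:

```cpp
#include <cstddef>
#include <cstdio>

constexpr std::size_t page_size = 16384;  // 16 KiB, as in the diff

// Stand-in for the next_power_of_2 helper assumed from mlx/utils.h.
std::size_t next_power_of_2(std::size_t n) {
  std::size_t p = 1;
  while (p < n) {
    p <<= 1;
  }
  return p;
}

// Mirrors the size-class rounding added to CudaAllocator::malloc.
std::size_t round_size(std::size_t size) {
  return size < page_size
      ? next_power_of_2(size)                             // small: power-of-two bucket
      : page_size * ((size + page_size - 1) / page_size); // large: whole pages
}

int main() {
  std::printf("100   -> %zu\n", round_size(100));    // 100   -> 128
  std::printf("20000 -> %zu\n", round_size(20000));  // 20000 -> 32768
}
```

Bucketing trades a little slack per allocation for far fewer distinct sizes in the cache, which raises the hit rate of reuse_from_cache.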
View File

@@ -24,7 +24,6 @@ void copy_gpu_inplace(
   auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(in);
   encoder.set_output_array(out);
-
   if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
     copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
     return;

View File

@@ -155,8 +155,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
 #pragma unroll
   for (int i = NDIM - 1; i >= 0; --i) {
     int dim_idx = elem % shape[i];
-    a_loc += dim_idx * a_strides[i];
-    b_loc += dim_idx * b_strides[i];
+    a_loc += dim_idx * IdxT(a_strides[i]);
+    b_loc += dim_idx * IdxT(b_strides[i]);
     elem /= shape[i];
   }
   return cuda::std::make_tuple(a_loc, b_loc);
@@ -175,9 +175,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
 #pragma unroll
   for (int i = NDIM - 1; i >= 0; --i) {
     int dim_idx = elem % shape[i];
-    a_loc += dim_idx * a_strides[i];
-    b_loc += dim_idx * b_strides[i];
-    c_loc += dim_idx * c_strides[i];
+    a_loc += dim_idx * IdxT(a_strides[i]);
+    b_loc += dim_idx * IdxT(b_strides[i]);
+    c_loc += dim_idx * IdxT(c_strides[i]);
     elem /= shape[i];
   }
   return cuda::std::make_tuple(a_loc, b_loc, c_loc);
@@ -206,8 +206,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
   IdxT b_loc = 0;
   for (int i = ndim - 1; i >= 0; --i) {
     int dim_idx = elem % shape[i];
-    a_loc += dim_idx * a_strides[i];
-    b_loc += dim_idx * b_strides[i];
+    a_loc += dim_idx * IdxT(a_strides[i]);
+    b_loc += dim_idx * IdxT(b_strides[i]);
     elem /= shape[i];
   }
   return cuda::std::make_tuple(a_loc, b_loc);
@@ -226,9 +226,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
   IdxT c_loc = 0;
   for (int i = ndim - 1; i >= 0; --i) {
     int dim_idx = elem % shape[i];
-    a_loc += dim_idx * a_strides[i];
-    b_loc += dim_idx * b_strides[i];
-    c_loc += dim_idx * c_strides[i];
+    a_loc += dim_idx * IdxT(a_strides[i]);
+    b_loc += dim_idx * IdxT(b_strides[i]);
+    c_loc += dim_idx * IdxT(c_strides[i]);
     elem /= shape[i];
   }
   return cuda::std::make_tuple(a_loc, b_loc, c_loc);

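These hunks wrap each stride in IdxT(...) before the multiply. If the stride elements are stored in a type narrower than IdxT (the apparent motivation here), dim_idx * a_strides[i] is computed in that narrower type and can overflow before the result ever reaches the 64-bit accumulator. A minimal host-side sketch with hypothetical values:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical index and stride for a very large array.
  int dim_idx = 3;
  int32_t stride = 1000000000;

  // The multiply happens in 32 bits and overflows (undefined behavior,
  // which in practice typically wraps) before the widening assignment.
  int64_t wrong = dim_idx * stride;

  // Widening one operand first keeps the arithmetic in 64 bits,
  // mirroring the IdxT(a_strides[i]) cast in the kernel.
  int64_t right = dim_idx * int64_t(stride);

  std::printf("wrong = %lld\nright = %lld\n", (long long)wrong, (long long)right);
}
```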
View File

@@ -162,11 +162,15 @@ class MatMul {
     }
   }

+  void* workspace_ptr = nullptr;
+  if (heuristic_.workspaceSize > 0) {
     array workspace(
         allocator::malloc(heuristic_.workspaceSize),
         {static_cast<int>(heuristic_.workspaceSize)},
         int8);
     encoder.add_temporary(workspace);
+    workspace_ptr = workspace.data<void>();
+  }

   encoder.launch_kernel([&](cudaStream_t stream) {
     CHECK_CUBLAS_ERROR(cublasLtMatmul(
@@ -183,8 +187,8 @@ class MatMul {
         out,
         out_desc_,
         &heuristic_.algo,
-        workspace.data<void>(),
-        workspace.nbytes(),
+        workspace_ptr,
+        heuristic_.workspaceSize,
         stream));
   });
 }
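Previously a workspace array was allocated and registered as a temporary on every call, even when the cuBLASLt heuristic asked for zero bytes. The guard above allocates only when workspaceSize is positive and otherwise passes a null pointer with size 0, which cublasLtMatmul permits. A sketch of the pattern, with a hypothetical Heuristic struct standing in for the cuBLASLt heuristic result:

```cpp
#include <cstddef>
#include <cstdlib>

// Hypothetical stand-in for cublasLtMatmulHeuristicResult_t.
struct Heuristic {
  std::size_t workspaceSize;
};

void launch(const Heuristic& heuristic) {
  // Allocate only when a workspace is actually required.
  void* workspace_ptr = nullptr;
  if (heuristic.workspaceSize > 0) {
    workspace_ptr = std::malloc(heuristic.workspaceSize);  // stand-in for allocator::malloc
  }

  // ... call the matmul with (workspace_ptr, heuristic.workspaceSize);
  // a null workspace pointer is valid as long as the size passed is zero ...

  std::free(workspace_ptr);  // free(nullptr) is a no-op
}
```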
@@ -358,9 +362,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       a_batch_strides.back(),
       b_batch_strides.back());

+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+
+  auto nbatch = batch_count / batch_shape.back();
+  if (nbatch == 1) {
+    matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
+    return;
+  }

   ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
   ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
-  for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
+  for (size_t i = 0; i < nbatch; ++i) {
     matmul.run(
         encoder,
         out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
@@ -444,10 +457,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       b_batch_strides.back(),
       c_batch_strides.back());

+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_input_array(c);
+  encoder.set_output_array(out);
+
+  auto nbatch = batch_count / batch_shape.back();
+  if (nbatch == 1) {
+    matmul.run(
+        encoder,
+        out.data<int8_t>(),
+        a.data<int8_t>(),
+        b.data<int8_t>(),
+        c.data<int8_t>(),
+        alpha_,
+        beta_);
+    return;
+  }
+
   ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
   ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
   ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
-  for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
+  for (size_t i = 0; i < nbatch; ++i) {
     matmul.run(
         encoder,
         out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,

View File

@@ -79,9 +79,6 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
 void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
   array out = out_;
   auto& encoder = cu::get_command_encoder(s);
-  encoder.set_input_array(in);
-  encoder.set_output_array(out);
-
   if (axis < 0) {
     axis += in.ndim();
   }
@@ -106,6 +103,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
         in.flags());
   }

+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
     MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
       if constexpr (!std::is_same_v<CTYPE, complex64_t>) {

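The sort change moves set_input_array/set_output_array to just before the kernel launch, after the point where in may have been replaced by a contiguous copy; registering the original array would record a dependency on a buffer the kernel never reads (the matmul hunks above add the same registrations). A sketch of the bug pattern, with hypothetical Encoder/Array stand-ins for the MLX types:

```cpp
// Hypothetical minimal stand-ins for the MLX command encoder and array.
struct Array {
  int buffer_id;
};

struct Encoder {
  int registered_input = -1;
  void set_input_array(const Array& a) { registered_input = a.buffer_id; }
  void set_output_array(const Array&) {}
};

Array make_contiguous(const Array& a) { return Array{a.buffer_id + 1}; }  // new buffer

void gpu_sort_sketch(Encoder& encoder, Array in, Array& out, bool needs_copy) {
  // Registering here (as the old code did) can record the wrong buffer:
  // encoder.set_input_array(in);

  if (needs_copy) {
    in = make_contiguous(in);  // `in` now refers to a different buffer
  }

  // Register the arrays the kernel will actually touch, then launch.
  encoder.set_input_array(in);
  encoder.set_output_array(out);
}
```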
View File

@@ -413,7 +413,7 @@ class Module(dict):
                     f'Module does not have sub-module named "{k}".'
                 )
             elif isinstance(modules, list):
-                for i in range(len(dst)):
+                for i in range(len(modules)):
                     current_value = dst[i]
                     new_value = modules[i]
                     if self.is_module(current_value) and self.is_module(new_value):

View File

@@ -259,6 +259,11 @@ class TestBase(mlx_tests.MLXTestCase):
         with self.assertRaises(ValueError):
             m = m.update_modules({"list": ["hi"]})

+        # Allow updating a strict subset
+        m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
+        m.update_modules({"layers": [{}, nn.Linear(3, 4)]})
+        self.assertEqual(m.layers[1].weight.shape, (4, 3))
+

 class TestLayers(mlx_tests.MLXTestCase):
     def test_identity(self):