Compare commits

..

1 Commits

Author SHA1 Message Date
Nripesh Niketan
ff71ede14d
Merge cc4de6a607 into 76831ed83d 2025-06-20 10:27:09 +12:00
7 changed files with 25 additions and 69 deletions

View File

@ -3,7 +3,6 @@
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/utils.h"
#include <cuda_runtime.h>
#include <fmt/format.h>
@ -15,11 +14,9 @@ namespace mlx::core {
namespace cu {
constexpr int page_size = 16384;
CudaAllocator::CudaAllocator()
: buffer_cache_(
page_size,
getpagesize(),
[](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) {
cuda_free(buf->data);
@ -34,14 +31,7 @@ CudaAllocator::CudaAllocator()
Buffer CudaAllocator::malloc(size_t size) {
// Find available buffer from cache.
auto orig_size = size;
std::unique_lock lock(mutex_);
if (size < page_size) {
size = next_power_of_2(size);
} else {
size = page_size * ((size + page_size - 1) / page_size);
}
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
if (!buf) {
// If we have a lot of memory pressure or are over the maximum cache size,

View File

@ -24,6 +24,7 @@ void copy_gpu_inplace(
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
return;

View File

@ -155,8 +155,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
#pragma unroll
for (int i = NDIM - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * IdxT(b_strides[i]);
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
elem /= shape[i];
}
return cuda::std::make_tuple(a_loc, b_loc);
@ -175,9 +175,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
#pragma unroll
for (int i = NDIM - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * IdxT(b_strides[i]);
c_loc += dim_idx * IdxT(c_strides[i]);
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
c_loc += dim_idx * c_strides[i];
elem /= shape[i];
}
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
@ -206,8 +206,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
IdxT b_loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * IdxT(b_strides[i]);
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
elem /= shape[i];
}
return cuda::std::make_tuple(a_loc, b_loc);
@ -226,9 +226,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
IdxT c_loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * IdxT(b_strides[i]);
c_loc += dim_idx * IdxT(c_strides[i]);
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
c_loc += dim_idx * c_strides[i];
elem /= shape[i];
}
return cuda::std::make_tuple(a_loc, b_loc, c_loc);

View File

@ -162,15 +162,11 @@ class MatMul {
}
}
void* workspace_ptr = nullptr;
if (heuristic_.workspaceSize > 0) {
array workspace(
allocator::malloc(heuristic_.workspaceSize),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
workspace_ptr = workspace.data<void>();
}
array workspace(
allocator::malloc(heuristic_.workspaceSize),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
encoder.launch_kernel([&](cudaStream_t stream) {
CHECK_CUBLAS_ERROR(cublasLtMatmul(
@ -187,8 +183,8 @@ class MatMul {
out,
out_desc_,
&heuristic_.algo,
workspace_ptr,
heuristic_.workspaceSize,
workspace.data<void>(),
workspace.nbytes(),
stream));
});
}
@ -362,18 +358,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
a_batch_strides.back(),
b_batch_strides.back());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto nbatch = batch_count / batch_shape.back();
if (nbatch == 1) {
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
return;
}
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
for (size_t i = 0; i < nbatch; ++i) {
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
matmul.run(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
@ -457,28 +444,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
b_batch_strides.back(),
c_batch_strides.back());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto nbatch = batch_count / batch_shape.back();
if (nbatch == 1) {
matmul.run(
encoder,
out.data<int8_t>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
alpha_,
beta_);
return;
}
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
for (size_t i = 0; i < nbatch; ++i) {
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
matmul.run(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,

View File

@ -79,6 +79,9 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
array out = out_;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
if (axis < 0) {
axis += in.ndim();
}
@ -103,8 +106,6 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.flags());
}
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {

View File

@ -413,7 +413,7 @@ class Module(dict):
f'Module does not have sub-module named "{k}".'
)
elif isinstance(modules, list):
for i in range(len(modules)):
for i in range(len(dst)):
current_value = dst[i]
new_value = modules[i]
if self.is_module(current_value) and self.is_module(new_value):

View File

@ -259,11 +259,6 @@ class TestBase(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
m = m.update_modules({"list": ["hi"]})
# Allow updating a strict subset
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
m.update_modules({"layers": [{}, nn.Linear(3, 4)]})
self.assertEqual(m.layers[1].weight.shape, (4, 3))
class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self):