mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Compare commits
3 Commits
106568caa1
...
9d86a4d5ba
Author | SHA1 | Date | |
---|---|---|---|
![]() |
9d86a4d5ba | ||
![]() |
5adf185f86 | ||
![]() |
c9a9180584 |
@ -3,6 +3,7 @@
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <fmt/format.h>
|
||||
@ -14,9 +15,11 @@ namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
constexpr int page_size = 16384;
|
||||
|
||||
CudaAllocator::CudaAllocator()
|
||||
: buffer_cache_(
|
||||
getpagesize(),
|
||||
page_size,
|
||||
[](CudaBuffer* buf) { return buf->size; },
|
||||
[this](CudaBuffer* buf) {
|
||||
cuda_free(buf->data);
|
||||
@ -31,7 +34,14 @@ CudaAllocator::CudaAllocator()
|
||||
|
||||
Buffer CudaAllocator::malloc(size_t size) {
|
||||
// Find available buffer from cache.
|
||||
auto orig_size = size;
|
||||
std::unique_lock lock(mutex_);
|
||||
if (size < page_size) {
|
||||
size = next_power_of_2(size);
|
||||
} else {
|
||||
size = page_size * ((size + page_size - 1) / page_size);
|
||||
}
|
||||
|
||||
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
||||
if (!buf) {
|
||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||
|
@ -24,7 +24,6 @@ void copy_gpu_inplace(
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
|
||||
copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
|
||||
return;
|
||||
|
@ -155,8 +155,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
|
||||
#pragma unroll
|
||||
for (int i = NDIM - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * a_strides[i];
|
||||
b_loc += dim_idx * b_strides[i];
|
||||
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc);
|
||||
@ -175,9 +175,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
|
||||
#pragma unroll
|
||||
for (int i = NDIM - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * a_strides[i];
|
||||
b_loc += dim_idx * b_strides[i];
|
||||
c_loc += dim_idx * c_strides[i];
|
||||
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||
c_loc += dim_idx * IdxT(c_strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||
@ -206,8 +206,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
||||
IdxT b_loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * a_strides[i];
|
||||
b_loc += dim_idx * b_strides[i];
|
||||
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc);
|
||||
@ -226,9 +226,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
|
||||
IdxT c_loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * a_strides[i];
|
||||
b_loc += dim_idx * b_strides[i];
|
||||
c_loc += dim_idx * c_strides[i];
|
||||
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||
c_loc += dim_idx * IdxT(c_strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||
|
@ -162,11 +162,15 @@ class MatMul {
|
||||
}
|
||||
}
|
||||
|
||||
array workspace(
|
||||
allocator::malloc(heuristic_.workspaceSize),
|
||||
{static_cast<int>(heuristic_.workspaceSize)},
|
||||
int8);
|
||||
encoder.add_temporary(workspace);
|
||||
void* workspace_ptr = nullptr;
|
||||
if (heuristic_.workspaceSize > 0) {
|
||||
array workspace(
|
||||
allocator::malloc(heuristic_.workspaceSize),
|
||||
{static_cast<int>(heuristic_.workspaceSize)},
|
||||
int8);
|
||||
encoder.add_temporary(workspace);
|
||||
workspace_ptr = workspace.data<void>();
|
||||
}
|
||||
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmul(
|
||||
@ -183,8 +187,8 @@ class MatMul {
|
||||
out,
|
||||
out_desc_,
|
||||
&heuristic_.algo,
|
||||
workspace.data<void>(),
|
||||
workspace.nbytes(),
|
||||
workspace_ptr,
|
||||
heuristic_.workspaceSize,
|
||||
stream));
|
||||
});
|
||||
}
|
||||
@ -358,9 +362,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
a_batch_strides.back(),
|
||||
b_batch_strides.back());
|
||||
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
auto nbatch = batch_count / batch_shape.back();
|
||||
if (nbatch == 1) {
|
||||
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
|
||||
return;
|
||||
}
|
||||
|
||||
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
|
||||
for (size_t i = 0; i < nbatch; ++i) {
|
||||
matmul.run(
|
||||
encoder,
|
||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
||||
@ -444,10 +457,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
b_batch_strides.back(),
|
||||
c_batch_strides.back());
|
||||
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_input_array(c);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto nbatch = batch_count / batch_shape.back();
|
||||
if (nbatch == 1) {
|
||||
matmul.run(
|
||||
encoder,
|
||||
out.data<int8_t>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
c.data<int8_t>(),
|
||||
alpha_,
|
||||
beta_);
|
||||
return;
|
||||
}
|
||||
|
||||
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
|
||||
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
|
||||
for (size_t i = 0; i < nbatch; ++i) {
|
||||
matmul.run(
|
||||
encoder,
|
||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
||||
|
@ -79,9 +79,6 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
array out = out_;
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
if (axis < 0) {
|
||||
axis += in.ndim();
|
||||
}
|
||||
@ -106,6 +103,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
in.flags());
|
||||
}
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
|
||||
|
@ -413,7 +413,7 @@ class Module(dict):
|
||||
f'Module does not have sub-module named "{k}".'
|
||||
)
|
||||
elif isinstance(modules, list):
|
||||
for i in range(len(dst)):
|
||||
for i in range(len(modules)):
|
||||
current_value = dst[i]
|
||||
new_value = modules[i]
|
||||
if self.is_module(current_value) and self.is_module(new_value):
|
||||
|
@ -259,6 +259,11 @@ class TestBase(mlx_tests.MLXTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
m = m.update_modules({"list": ["hi"]})
|
||||
|
||||
# Allow updating a strict subset
|
||||
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
|
||||
m.update_modules({"layers": [{}, nn.Linear(3, 4)]})
|
||||
self.assertEqual(m.layers[1].weight.shape, (4, 3))
|
||||
|
||||
|
||||
class TestLayers(mlx_tests.MLXTestCase):
|
||||
def test_identity(self):
|
||||
|
Loading…
Reference in New Issue
Block a user