[CUDA] Reduce use of managed memory (#2725)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled

* Use async cuda malloc managed with cuda 13

* add pool threshold

* refactor for regular cuda malloc

* load eval gpu for cuda

* remove use of cuda pool, use cuda free async

* fix

* fix

* fix

* fix

* fix + comment
This commit is contained in:
Awni Hannun
2025-11-05 16:05:23 -08:00
committed by GitHub
parent 27778156dc
commit df58b4133a
79 changed files with 795 additions and 515 deletions

View File

@@ -26,10 +26,6 @@ runs:
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
run: pip install -e ".[dev]" -v
- name: Check if build actually worked
shell: bash
run: python -c "import mlx.core"
- name: Run Python tests - CPU
if: inputs.run-tests == 'true'
shell: bash

View File

@@ -14,7 +14,7 @@ class Buffer {
void* ptr_;
public:
Buffer(void* ptr) : ptr_(ptr) {};
explicit Buffer(void* ptr) : ptr_(ptr) {};
// Get the raw data pointer from the buffer
void* raw_ptr();

View File

@@ -64,7 +64,7 @@ array array::unsafe_weak_copy(const array& other) {
other.strides(),
other.flags(),
[](auto) {});
cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
cpy.array_desc_->offset = other.array_desc_->offset;
return cpy;
}
@@ -141,7 +141,7 @@ bool array::is_tracer() const {
void array::set_data(allocator::Buffer buffer, Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->offset = 0;
array_desc_->data_size = size();
array_desc_->flags.contiguous = true;
array_desc_->flags.row_contiguous = true;
@@ -156,7 +156,7 @@ void array::set_data(
Flags flags,
Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->offset = 0;
array_desc_->data_size = data_size;
array_desc_->strides = std::move(strides);
array_desc_->flags = flags;
@@ -172,9 +172,8 @@ void array::copy_shared_buffer(
array_desc_->strides = strides;
array_desc_->flags = flags;
array_desc_->data_size = data_size;
auto char_offset = sizeof(char) * itemsize() * offset;
array_desc_->data_ptr = static_cast<void*>(
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
array_desc_->offset =
sizeof(char) * itemsize() * offset + other.array_desc_->offset;
}
void array::copy_shared_buffer(const array& other) {

View File

@@ -354,15 +354,23 @@ class array {
return array_desc_->data;
}
// Return a raw pointer to the arrays data
// Return a raw pointer to the arrays data. This function may do a copy if
// the underlying buffer is not accessible on the CPU. When accessing the
// data for GPU kernels, be sure to use the correct method / function for the
// given backend to access the GPU pointer.
template <typename T>
T* data() {
return static_cast<T*>(array_desc_->data_ptr);
return reinterpret_cast<T*>(
(static_cast<char*>(buffer().raw_ptr()) + array_desc_->offset));
}
template <typename T>
const T* data() const {
return static_cast<T*>(array_desc_->data_ptr);
return const_cast<array&>(*this).data<T>();
}
int64_t offset() const {
return array_desc_->offset;
}
enum Status {
@@ -466,8 +474,8 @@ class array {
// can share the underlying data buffer.
std::shared_ptr<Data> data;
// Properly offset data pointer
void* data_ptr{nullptr};
// Offset from beginning of data pointer
int64_t offset{0};
// The size in elements of the data buffer the array accesses
size_t data_size;

View File

@@ -38,20 +38,20 @@ inline void set_binary_op_output_data(
const array& a,
const array& b,
array& out,
BinaryOpType bopt) {
BinaryOpType bopt,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
bool b_donatable = is_donatable(b, out);
bool a_donatable = is_donatable(a, out);
switch (bopt) {
case BinaryOpType::ScalarScalar:
out.set_data(
allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
out.set_data(mallocfn(out.itemsize()), 1, a.strides(), a.flags());
break;
case BinaryOpType::ScalarVector:
if (b_donatable) {
out.copy_shared_buffer(b);
} else {
out.set_data(
allocator::malloc(b.data_size() * out.itemsize()),
mallocfn(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(a);
} else {
out.set_data(
allocator::malloc(a.data_size() * out.itemsize()),
mallocfn(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(b);
} else {
out.set_data(
allocator::malloc(a.data_size() * out.itemsize()),
mallocfn(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(mallocfn(out.nbytes()));
}
break;
}

View File

@@ -6,7 +6,7 @@ namespace mlx::core {
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
out.set_data(allocator::malloc(0));
return;
}
Strides strides(out.ndim(), 0);

View File

@@ -114,7 +114,9 @@ void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::function<bool(size_t)>& is_constant,
bool contiguous) {
bool contiguous,
const std::function<allocator::Buffer(size_t)>&
mallocfn /* = allocator::malloc */) {
if (contiguous) {
int o = 0;
Strides strides;
@@ -140,7 +142,7 @@ void compiled_allocate_outputs(
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
allocator::malloc(data_size * outputs[o].itemsize()),
mallocfn(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
@@ -163,7 +165,7 @@ void compiled_allocate_outputs(
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
outputs[o].set_data(mallocfn(outputs[o].nbytes()));
}
}
}

View File

@@ -58,7 +58,9 @@ void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::function<bool(size_t)>& is_constant,
bool contiguous);
bool contiguous,
const std::function<allocator::Buffer(size_t)>& mallocfn =
allocator::malloc);
// Collapse contiguous dims ignoring scalars and constants.
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(

View File

@@ -22,7 +22,11 @@ enum class CopyType {
GeneralGeneral
};
inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
inline bool set_copy_output_data(
const array& in,
array& out,
CopyType ctype,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types
// have the same size, then the input buffer can hold the output.
@@ -31,14 +35,14 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
return true;
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
mallocfn(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
return false;
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(mallocfn(out.nbytes()));
return false;
}
}

View File

@@ -45,7 +45,7 @@ void slice(
const Shape& start_indices,
const Shape& strides) {
if (out.size() == 0) {
out.set_data(nullptr);
out.set_data(allocator::malloc(0));
return;
}

View File

@@ -46,7 +46,8 @@ inline void set_ternary_op_output_data(
const array& b,
const array& c,
array& out,
TernaryOpType topt) {
TernaryOpType topt,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
auto maybe_donate = [&out](const array& x) {
if (is_donatable(x, out)) {
out.copy_shared_buffer(x);
@@ -57,13 +58,12 @@ inline void set_ternary_op_output_data(
switch (topt) {
case TernaryOpType::ScalarScalarScalar:
out.set_data(
allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags());
break;
case TernaryOpType::VectorVectorVector:
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
out.set_data(
allocator::malloc(out.itemsize() * b.data_size()),
mallocfn(out.itemsize() * b.data_size()),
b.data_size(),
b.strides(),
b.flags());
@@ -76,7 +76,7 @@ inline void set_ternary_op_output_data(
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
(b.flags().row_contiguous && maybe_donate(b)) ||
(c.flags().row_contiguous && maybe_donate(c)))) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(mallocfn(out.nbytes()));
}
break;
}

View File

@@ -7,19 +7,22 @@
namespace mlx::core {
inline void set_unary_output_data(const array& in, array& out) {
inline void set_unary_output_data(
const array& in,
array& out,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
if (in.flags().contiguous) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
mallocfn(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(mallocfn(out.nbytes()));
}
}

View File

@@ -333,7 +333,7 @@ void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
out.set_data(allocator::malloc(0));
return;
}
auto& in = inputs[0];
@@ -361,7 +361,7 @@ void DynamicSliceUpdate::eval_cpu(
const std::vector<array>& inputs,
array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
out.set_data(allocator::malloc(0));
return;
}
@@ -396,7 +396,7 @@ void DynamicSliceUpdate::eval_cpu(
void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
out.set_data(nullptr);
out.set_data(allocator::malloc(0));
return;
}

View File

@@ -32,6 +32,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp

View File

@@ -1,6 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/utils.h"
@@ -67,6 +68,7 @@ CudaBuffer* SmallSizePool::malloc() {
next_free_ = next_free_->next;
b->buf.data = static_cast<char*>(data_) + i * small_block_size;
b->buf.size = small_block_size;
b->buf.device = -1;
return &b->buf;
}
@@ -88,14 +90,40 @@ CudaAllocator::CudaAllocator()
page_size,
[](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) { cuda_free(buf); }) {
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.95;
max_pool_size_ = memory_limit_;
int device_count = 0;
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
int curr;
CHECK_CUDA_ERROR(cudaGetDevice(&curr));
for (int i = 0; i < device_count; ++i) {
CHECK_CUDA_ERROR(cudaSetDevice(i));
cudaStream_t s;
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
free_streams_.push_back(s);
}
CHECK_CUDA_ERROR(cudaSetDevice(curr));
}
Buffer CudaAllocator::malloc(size_t size) {
void copy_to_managed(CudaBuffer& buf) {
// TODO maybe make this async on a i/o stream to avoid synchronizing the
// device on malloc/and free
void* new_data;
CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, buf.size));
buf.device = -1;
CHECK_CUDA_ERROR(cudaMemcpy(new_data, buf.data, buf.size, cudaMemcpyDefault));
CHECK_CUDA_ERROR(cudaFree(buf.data));
buf.data = new_data;
}
Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
if (size == 0) {
return Buffer{new CudaBuffer{nullptr, 0, -1}};
}
// Find available buffer from cache.
std::unique_lock lock(mutex_);
if (size <= small_block_size) {
@@ -106,6 +134,11 @@ Buffer CudaAllocator::malloc(size_t size) {
size = page_size * ((size + page_size - 1) / page_size);
}
int device = -1;
if (size > small_block_size && stream != nullptr) {
CHECK_CUDA_ERROR(cudaStreamGetDevice(stream, &device));
}
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
if (!buf) {
// If we have a lot of memory pressure try to reclaim memory from the cache.
@@ -121,8 +154,13 @@ Buffer CudaAllocator::malloc(size_t size) {
}
lock.unlock();
if (!buf) {
buf = new CudaBuffer{nullptr, size};
cudaError_t err = cudaMallocManaged(&buf->data, size);
buf = new CudaBuffer{nullptr, size, device};
cudaError_t err;
if (device == -1) {
err = cudaMallocManaged(&buf->data, size);
} else {
err = cudaMallocAsync(&buf->data, size, stream);
}
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
@@ -137,14 +175,30 @@ Buffer CudaAllocator::malloc(size_t size) {
if (get_cache_memory() > max_pool_size_) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
}
// Copy to managed here if the buffer is not on the right device
if (buf->device != device) {
copy_to_managed(*buf);
}
return Buffer{buf};
}
Buffer CudaAllocator::malloc_async(size_t size, cudaStream_t stream) {
return malloc_impl(size, stream);
}
Buffer CudaAllocator::malloc(size_t size) {
return malloc_impl(size, nullptr);
}
void CudaAllocator::free(Buffer buffer) {
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
if (!buf) {
return;
}
if (buf->size == 0) {
delete buf;
return;
}
std::unique_lock lock(mutex_);
active_memory_ -= buf->size;
@@ -168,7 +222,11 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
if (scalar_pool_.in_pool(buf)) {
scalar_pool_.free(buf);
} else {
cudaFree(buf->data);
if (buf->device >= 0) {
cudaFreeAsync(buf->data, free_streams_[buf->device]);
} else {
cudaFree(buf->data);
}
delete buf;
}
}
@@ -219,6 +277,16 @@ CudaAllocator& allocator() {
return *allocator_;
}
Buffer malloc_async(size_t size, cudaStream_t stream) {
auto buffer = allocator().malloc_async(size, stream);
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc_async] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
} // namespace cu
namespace allocator {
@@ -231,7 +299,11 @@ void* Buffer::raw_ptr() {
if (!ptr_) {
return nullptr;
}
return static_cast<cu::CudaBuffer*>(ptr_)->data;
auto& cbuf = *static_cast<cu::CudaBuffer*>(ptr_);
if (cbuf.device != -1) {
copy_to_managed(cbuf);
}
return cbuf.data;
}
} // namespace allocator

View File

@@ -4,7 +4,9 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/buffer_cache.h"
#include "mlx/backend/cuda/cuda_utils.h"
#include <cuda_runtime.h>
#include <mutex>
#include <set>
#include <utility>
@@ -17,6 +19,7 @@ using allocator::Buffer;
struct CudaBuffer {
void* data;
size_t size;
int device; // -1 for managed
};
class SmallSizePool {
@@ -45,6 +48,7 @@ class SmallSizePool {
class CudaAllocator : public allocator::Allocator {
public:
Buffer malloc(size_t size) override;
Buffer malloc_async(size_t size, cudaStream_t stream);
void free(Buffer buffer) override;
size_t size(Buffer buffer) const override;
@@ -58,6 +62,7 @@ class CudaAllocator : public allocator::Allocator {
void clear_cache();
private:
Buffer malloc_impl(size_t size, cudaStream_t stream);
void cuda_free(CudaBuffer* buf);
CudaAllocator();
@@ -69,9 +74,12 @@ class CudaAllocator : public allocator::Allocator {
BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0};
size_t peak_memory_{0};
std::vector<cudaStream_t> free_streams_;
SmallSizePool scalar_pool_;
};
CudaAllocator& allocator();
Buffer malloc_async(size_t size, cudaStream_t stream);
} // namespace mlx::core::cu

View File

@@ -41,9 +41,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
return;
}
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cu::get_command_encoder(stream());
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
encoder.set_output_array(out);
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
@@ -58,7 +57,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
num_blocks,
block_dims,
0,
out.data<OutType>(),
gpu_ptr<OutType>(out),
out.data_size(),
static_cast<CTYPE>(start_),
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_));

View File

@@ -140,8 +140,10 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("ArgReduce::eval_gpu");
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
// Prepare the shapes, strides and axis arguments.
Shape shape = remove_index(in.shape(), axis_);
@@ -154,7 +156,6 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
int32_t ndim = shape.size();
// ArgReduce.
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
@@ -172,8 +173,8 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
num_blocks,
block_dim(),
0,
in.data<T>(),
out.data<uint32_t>(),
gpu_ptr<T>(in),
gpu_ptr<uint32_t>(out),
out.size(),
const_param(shape),
const_param(in_strides),

View File

@@ -292,9 +292,9 @@ void binary_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out),
rest,
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
@@ -310,9 +310,9 @@ void binary_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out),
rest,
const_param(shape),
const_param(a_strides),
@@ -339,9 +339,9 @@ void binary_op_gpu_inplace(
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out),
out.data_size());
});
}
@@ -365,7 +365,11 @@ void binary_op_gpu(
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cu::get_command_encoder(s);
set_binary_op_output_data(a, b, out, bopt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
binary_op_gpu_inplace<Op>(inputs, out, op, s);
}

View File

@@ -245,14 +245,18 @@ void binary_two_op_gpu_inplace(
auto& out_a = outputs[0];
auto& out_b = outputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
auto& encoder = cu::get_command_encoder(s);
set_binary_op_output_data(a, b, out_a, bopt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
set_binary_op_output_data(a, b, out_b, bopt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
if (out_a.size() == 0) {
return;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
@@ -313,10 +317,10 @@ void binary_two_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out_a),
gpu_ptr<OutType>(out_b),
rest,
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
@@ -332,10 +336,10 @@ void binary_two_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out_a),
gpu_ptr<OutType>(out_b),
rest,
const_param(shape),
const_param(a_strides),
@@ -366,10 +370,10 @@ void binary_two_op_gpu_inplace(
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out_a),
gpu_ptr<OutType>(out_b),
out_a.data_size());
});
}

View File

@@ -293,8 +293,13 @@ void Compiled::eval_gpu(
}
}
auto& encoder = cu::get_command_encoder(s);
// Put outputs.
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
compiled_allocate_outputs(
inputs, outputs, is_constant_, contiguous, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
for (auto& x : outputs) {
args.append(x);
}
@@ -324,7 +329,6 @@ void Compiled::eval_gpu(
kernel_name += fmt::format(
"_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
}
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
encoder.set_input_array(in);
}

View File

@@ -270,17 +270,16 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
if (out_.size() == 0) {
return;
}
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
assert(inputs.size() == 2);
array in = inputs[0];
array wt = inputs[1];
array out = out_;
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
Dtype dtype = out.dtype();
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
// Search cache.
ConvCacheKey cache_key{
encoder.device().cuda_device(),

View File

@@ -86,7 +86,7 @@ array unfold_inputs_nd(
int mat_N,
ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
encoder.add_temporary(unfolded);
int filter_size = params.C;
@@ -118,8 +118,8 @@ array unfold_inputs_nd(
num_blocks,
block_dims,
0,
in.data<DataType>(),
unfolded.data<DataType>(),
gpu_ptr<DataType>(in),
gpu_ptr<DataType>(unfolded),
filter_size,
out_pixels,
params);

View File

@@ -89,7 +89,7 @@ array grouped_unfold_transpose_inputs_nd(
int mat_N,
ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
encoder.add_temporary(unfolded);
int filter_size = params.C;
@@ -121,8 +121,8 @@ array grouped_unfold_transpose_inputs_nd(
num_blocks,
block_dims,
0,
in.data<DataType>(),
unfolded.data<DataType>(),
gpu_ptr<DataType>(in),
gpu_ptr<DataType>(unfolded),
filter_size,
out_pixels,
params);

View File

@@ -5,6 +5,22 @@
namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
auto& encoder = cu::get_command_encoder(s);
bool donated = set_copy_output_data(in, out, ctype, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_gpu_inplace(in, out, ctype, s);
}
void copy_gpu_inplace(
const array& in,
array& out,
@@ -87,11 +103,31 @@ void fill_gpu(const array& in, array& out, const Stream& s) {
if (out.size() == 0) {
return;
}
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
encoder.set_input_array(in);
encoder.set_output_array(out);
copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
}
void reshape_gpu(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
} // namespace mlx::core

View File

@@ -77,8 +77,8 @@ void copy_contiguous(
num_blocks,
block_dims,
0,
in.data<InType>() + in_offset,
out.data<OutType>() + out_offset,
gpu_ptr<InType>(in) + in_offset,
gpu_ptr<OutType>(out) + out_offset,
out.data_size());
});
});

View File

@@ -106,8 +106,8 @@ void copy_general(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size();
size_t data_size = 1;
for (auto& s : shape)

View File

@@ -69,8 +69,8 @@ void copy_general_dynamic(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
@@ -90,8 +90,8 @@ void copy_general_dynamic(
const_param<dims_constant()>(shape),
const_param<dims_constant()>(strides_in),
const_param<dims_constant()>(strides_out),
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
gpu_ptr<int64_t>(dynamic_offset_in),
gpu_ptr<int64_t>(dynamic_offset_out));
});
} else { // ndim >= 4
auto [num_blocks, block_dims] = get_launch_args(out, large());
@@ -107,8 +107,8 @@ void copy_general_dynamic(
const_param(strides_in),
const_param(strides_out),
ndim,
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
gpu_ptr<int64_t>(dynamic_offset_in),
gpu_ptr<int64_t>(dynamic_offset_out));
}
});
});

View File

@@ -92,8 +92,8 @@ void copy_general_input(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size();
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;

View File

@@ -0,0 +1,82 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_runtime.h>
namespace mlx::core {
// Throw exception if the cuda API does not succeed.
void check_cublas_error(const char* name, cublasStatus_t err);
void check_cuda_error(const char* name, cudaError_t err);
void check_cuda_error(const char* name, CUresult err);
// The macro version that prints the command that failed.
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
// Base class for RAII managed CUDA resources.
template <typename Handle, cudaError_t (*Destroy)(Handle)>
class CudaHandle {
public:
CudaHandle(Handle handle = nullptr) : handle_(handle) {}
CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
assert(this != &other);
other.handle_ = nullptr;
}
~CudaHandle() {
reset();
}
CudaHandle(const CudaHandle&) = delete;
CudaHandle& operator=(const CudaHandle&) = delete;
CudaHandle& operator=(CudaHandle&& other) {
assert(this != &other);
reset();
std::swap(handle_, other.handle_);
return *this;
}
void reset() {
if (handle_ != nullptr) {
CHECK_CUDA_ERROR(Destroy(handle_));
handle_ = nullptr;
}
}
operator Handle() const {
return handle_;
}
protected:
Handle handle_;
};
namespace cu {
class Device;
}; // namespace cu
// Wrappers of CUDA resources.
class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
public:
using CudaHandle::CudaHandle;
explicit CudaGraph(cu::Device& device);
void end_capture(cudaStream_t stream);
};
class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
public:
void instantiate(cudaGraph_t graph);
};
class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
public:
explicit CudaStream(cu::Device& device);
};
} // namespace mlx::core

View File

@@ -132,14 +132,18 @@ bool prepare_cudnn_plan(
void** data_ptrs,
F&& execute) {
int workspace_size = plan.getWorkspaceSize();
array workspace(
workspace_size > 0 ? allocator::malloc(workspace_size)
: allocator::Buffer(nullptr),
{workspace_size},
uint8);
void* workspace_ptr = nullptr;
if (workspace_size > 0) {
array workspace(
cu::malloc_async(workspace_size, encoder.stream()),
{workspace_size},
uint8);
encoder.add_temporary(workspace);
workspace_ptr = gpu_ptr<void>(workspace);
}
auto args = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace.data<void>())
.setWorkspacePointer(workspace_ptr)
.setDataPointers(num_args, data_ptrs)
.setUids(num_args, uids)
.build();
@@ -151,7 +155,6 @@ bool prepare_cudnn_plan(
return false;
}
encoder.add_temporary(workspace);
return true;
}

View File

@@ -3,6 +3,7 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/dtype_utils.h"
@@ -23,7 +24,7 @@ class CommandEncoder;
// Return pointer alignment of |x|'s data.
inline uint8_t get_alignment(const array& x) {
uint8_t alignment = 1;
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
uintptr_t address = reinterpret_cast<uintptr_t>(gpu_ptr<void>(x));
for (; alignment < 32; alignment *= 2) {
if (address % (alignment * 2)) {
return alignment;
@@ -56,7 +57,7 @@ inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
// Helpers used by get_data_ptrs to get pointers.
inline void* get_data_ptr(const array& arr) {
return const_cast<void*>(arr.data<void>());
return const_cast<void*>(gpu_ptr<void>(arr));
}
template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>

View File

@@ -279,6 +279,7 @@ void CustomKernel::eval_gpu(
std::vector<array>& outputs) {
nvtx3::scoped_range r("CustomKernel::eval_gpu");
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
std::vector<array> copies;
@@ -288,7 +289,7 @@ void CustomKernel::eval_gpu(
copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s);
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
}
}
@@ -356,7 +357,6 @@ void CustomKernel::eval_gpu(
dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
// Call the kernel
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : checked_inputs) {
encoder.set_input_array(in);
}

View File

@@ -3,6 +3,7 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/stream.h"

View File

@@ -15,8 +15,10 @@ void AllReduce::eval_gpu(
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto set_input_output =
[s = stream()](const array& in, array& out) -> std::pair<array, array> {
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
auto set_input_output = [&](const array& in,
array& out) -> std::pair<array, array> {
if (!in.flags().row_contiguous) {
copy_gpu(in, out, CopyType::General, s);
return {out, out};
@@ -24,19 +26,17 @@ void AllReduce::eval_gpu(
out.copy_shared_buffer(in);
return {in, out};
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
return {in, out};
}
};
auto [input, output] = set_input_output(inputs[0], outputs[0]);
auto& encoder = cu::get_command_encoder(stream());
encoder.set_input_array(input);
encoder.set_output_array(output);
auto capture = encoder.capture_context();
auto& s = stream();
switch (reduce_type_) {
case Sum:
@@ -74,7 +74,7 @@ void AllGather::eval_gpu(
};
auto input = ensure_contiguous(inputs[0]);
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder.stream()));
encoder.set_input_array(input);
encoder.set_output_array(outputs[0]);
@@ -103,7 +103,7 @@ void ReduceScatter::eval_gpu(
};
auto input = ensure_contiguous(inputs[0]);
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder.stream()));
encoder.set_input_array(input);
encoder.set_output_array(outputs[0]);

View File

@@ -1,6 +1,8 @@
// Copyright © 2025 Apple Inc.
#include "mlx/fence.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h"
namespace mlx::core {
@@ -20,8 +22,24 @@ void Fence::wait(Stream s, const array&) {
fence->event.wait(fence->count);
}
void Fence::update(Stream s, const array&) {
void Fence::update(Stream s, const array& a, bool cross_device) {
auto* fence = static_cast<FenceImpl*>(fence_.get());
if (cross_device) {
// Move to managed memory if there is a device switch
auto& cbuf =
*static_cast<cu::CudaBuffer*>(const_cast<array&>(a).buffer().ptr());
if (cbuf.device != -1) {
void* new_data;
CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size));
cbuf.device = -1;
auto& encoder = cu::device(s.device).get_command_encoder(s);
encoder.commit();
CHECK_CUDA_ERROR(cudaMemcpyAsync(
new_data, cbuf.data, cbuf.size, cudaMemcpyDefault, encoder.stream()));
CHECK_CUDA_ERROR(cudaFreeAsync(cbuf.data, encoder.stream()));
cbuf.data = new_data;
}
}
fence->count++;
fence->event.signal(s, fence->count);
}

View File

@@ -241,7 +241,7 @@ void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue,
sizeof(epilogue)));
auto* bias_ptr = bias.data<void>();
auto* bias_ptr = gpu_ptr<void>(bias);
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
@@ -278,9 +278,9 @@ void CublasGemm::run(
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
gpu_ptr<void>(out),
gpu_ptr<void>(a),
gpu_ptr<void>(b),
nullptr,
alpha);
}
@@ -321,10 +321,10 @@ void CublasGemm::run(
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
c.data<void>(),
gpu_ptr<void>(out),
gpu_ptr<void>(a),
gpu_ptr<void>(b),
gpu_ptr<void>(c),
alpha,
beta);
}
@@ -370,11 +370,11 @@ void CublasGemm::execute(
// Ensure workspace is 256-byte aligned
int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
array workspace(
allocator::malloc(nbytes),
cu::malloc_async(nbytes, encoder.stream()),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
workspace_ptr = workspace.data<void>();
workspace_ptr = gpu_ptr<void>(workspace);
}
auto capture = encoder.capture_context();

View File

@@ -25,9 +25,10 @@ void CublasGemm::run_batched(
for (size_t i = 0; i < nbatch; ++i) {
execute(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
gpu_ptr<int8_t>(out) +
out.itemsize() * i * batch_shape.back() * M_ * N_,
gpu_ptr<int8_t>(a) + a.itemsize() * a_it.loc,
gpu_ptr<int8_t>(b) + b.itemsize() * b_it.loc,
nullptr,
alpha);
a_it.step();
@@ -60,10 +61,11 @@ void CublasGemm::run_batched(
for (size_t i = 0; i < nbatch; ++i) {
execute(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
c.data<int8_t>() + c.itemsize() * c_it.loc,
gpu_ptr<int8_t>(out) +
out.itemsize() * i * batch_shape.back() * M_ * N_,
gpu_ptr<int8_t>(a) + a.itemsize() * a_it.loc,
gpu_ptr<int8_t>(b) + b.itemsize() * b_it.loc,
gpu_ptr<int8_t>(c) + c.itemsize() * c_it.loc,
alpha,
beta);
a_it.step();

View File

@@ -163,7 +163,7 @@ void CublasGemm::run_batched(
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(void*) * 3),
cu::malloc_async(batch_count * sizeof(void*) * 3, encoder.stream()),
{batch_count * 3},
uint64);
@@ -183,10 +183,10 @@ void CublasGemm::run_batched(
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
gpu_ptr<int8_t*>(pointers),
gpu_ptr<int8_t>(a),
gpu_ptr<int8_t>(b),
gpu_ptr<int8_t>(out),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
@@ -200,10 +200,10 @@ void CublasGemm::run_batched(
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
gpu_ptr<int8_t*>(pointers),
gpu_ptr<int8_t>(a),
gpu_ptr<int8_t>(b),
gpu_ptr<int8_t>(out),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
@@ -219,7 +219,7 @@ void CublasGemm::run_batched(
encoder.set_input_array(b);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto a_pointers = gpu_ptr<int8_t*>(pointers);
auto b_pointers = a_pointers + batch_count;
auto out_pointers = b_pointers + batch_count;
execute(
@@ -251,7 +251,7 @@ void CublasGemm::run_batched(
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder.stream()),
{batch_count * 4},
uint64);
@@ -271,11 +271,11 @@ void CublasGemm::run_batched(
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
gpu_ptr<int8_t*>(pointers),
gpu_ptr<int8_t>(a),
gpu_ptr<int8_t>(b),
gpu_ptr<int8_t>(c),
gpu_ptr<int8_t>(out),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
@@ -290,11 +290,11 @@ void CublasGemm::run_batched(
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
gpu_ptr<int8_t*>(pointers),
gpu_ptr<int8_t>(a),
gpu_ptr<int8_t>(b),
gpu_ptr<int8_t>(c),
gpu_ptr<int8_t>(out),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
@@ -312,7 +312,7 @@ void CublasGemm::run_batched(
encoder.set_input_array(c);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto a_pointers = gpu_ptr<int8_t*>(pointers);
auto b_pointers = a_pointers + batch_count;
auto c_pointers = b_pointers + batch_count;
auto out_pointers = c_pointers + batch_count;

View File

@@ -149,13 +149,13 @@ void gemv(
auto vec_strides = const_param(b_batch_strides);
if (M == 1) {
mat = b.data<DataType>();
vec = a.data<DataType>();
mat = gpu_ptr<DataType>(b);
vec = gpu_ptr<DataType>(a);
rows = N;
std::swap(mat_strides, vec_strides);
} else {
mat = a.data<DataType>();
vec = b.data<DataType>();
mat = gpu_ptr<DataType>(a);
vec = gpu_ptr<DataType>(b);
rows = M;
}
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
@@ -177,7 +177,7 @@ void gemv(
0,
mat,
vec,
out.data<DataType>(),
gpu_ptr<DataType>(out),
rows,
cols);
} else {
@@ -189,7 +189,7 @@ void gemv(
0,
mat,
vec,
out.data<DataType>(),
gpu_ptr<DataType>(out),
rows,
cols,
const_param(batch_shape),

View File

@@ -31,7 +31,7 @@ void append_indices_arg(
int idx_ndim) {
SmallVector<const void*> indices(nidx);
for (int i = 0; i < nidx; ++i) {
indices[i] = inputs[i + 1].data<void>();
indices[i] = gpu_ptr<void>(inputs[i + 1]);
}
args.append(std::move(indices));
SmallVector<int32_t> indices_shape(nidx * idx_ndim);
@@ -59,7 +59,9 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() > 0);
const auto& src = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
if (out.size() == 0) {
return;
}
@@ -80,7 +82,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
dtype_to_string(idx_dtype),
nidx);
auto& s = stream();
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
std::vector<std::string> kernel_names;
for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
@@ -121,7 +122,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_ndim,
large ? "int64_t" : "int32_t");
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
encoder.set_input_array(in);
}
@@ -239,7 +239,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
const auto& src = inputs[0];
const auto& idx = inputs[1];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
if (out.size() == 0) {
return;
}
@@ -251,7 +253,6 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
dtype_to_string(out.dtype()),
dtype_to_string(idx.dtype()));
auto& s = stream();
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
std::vector<std::string> kernel_names;
for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
@@ -312,7 +313,6 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
idx.flags().row_contiguous,
large ? "int64_t" : "int32_t");
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
encoder.set_input_array(in);
}

View File

@@ -31,7 +31,7 @@ struct KernelArgs {
}
void append(const array& a) {
append(reinterpret_cast<CUdeviceptr>(a.data<void>()));
append(reinterpret_cast<CUdeviceptr>(gpu_ptr<void>(a)));
}
template <typename T>

View File

@@ -9,6 +9,7 @@
#include <type_traits>
#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include <cuda.h>

View File

@@ -230,9 +230,10 @@ void LayerNorm::eval_gpu(
nvtx3::scoped_range r("LayerNorm::eval_gpu");
auto& s = stream();
auto& out = outputs[0];
auto& encoder = cu::get_command_encoder(s);
// Make sure that the last dimension is contiguous.
auto set_output = [&s, &out](const array& x) {
auto set_output = [&s, &out, &encoder](const array& x) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
@@ -243,7 +244,7 @@ void LayerNorm::eval_gpu(
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc(x.data_size() * x.itemsize()),
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
x.data_size(),
x.strides(),
x.flags());
@@ -265,7 +266,6 @@ void LayerNorm::eval_gpu(
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(b);
@@ -280,10 +280,10 @@ void LayerNorm::eval_gpu(
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
b.data<DataType>(),
out.data<DataType>(),
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(b),
gpu_ptr<DataType>(out),
eps_,
axis_size,
w_stride,
@@ -335,7 +335,7 @@ void LayerNormVJP::eval_gpu(
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(allocator::malloc(gx.nbytes()));
gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
}
if (g_copied && !g_in_gx) {
encoder.add_temporary(g);
@@ -355,7 +355,7 @@ void LayerNormVJP::eval_gpu(
g_in_gw = true;
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
encoder.add_temporary(gw_temp);
}
}
@@ -393,11 +393,11 @@ void LayerNormVJP::eval_gpu(
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
g.data<DataType>(),
gx.data<DataType>(),
gw_temp.data<DataType>(),
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(g),
gpu_ptr<DataType>(gx),
gpu_ptr<DataType>(gw_temp),
eps_,
axis_size,
w_stride);

60
mlx/backend/cuda/load.cpp Normal file
View File

@@ -0,0 +1,60 @@
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <utility>
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/primitives.h"
namespace {
template <const uint8_t scalar_size>
void swap_endianness(uint8_t* data_bytes, size_t N) {
struct Elem {
uint8_t bytes[scalar_size];
};
Elem* data = reinterpret_cast<Elem*>(data_bytes);
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < (scalar_size / 2); j++) {
std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]);
}
}
}
} // namespace
namespace mlx::core {
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& encoder = cu::get_command_encoder(stream());
auto size = out.size();
auto nbytes = size * out.itemsize();
out.set_data(cu::malloc_async(nbytes, encoder.stream()));
auto out_ptr = malloc(nbytes);
reader_->read(static_cast<char*>(out_ptr), nbytes, offset_);
if (swap_endianness_) {
switch (out.itemsize()) {
case 2:
swap_endianness<2>(reinterpret_cast<uint8_t*>(out_ptr), size);
break;
case 4:
swap_endianness<4>(reinterpret_cast<uint8_t*>(out_ptr), size);
break;
case 8:
swap_endianness<8>(reinterpret_cast<uint8_t*>(out_ptr), size);
break;
}
}
CHECK_CUDA_ERROR(cudaMemcpyAsync(
gpu_ptr<void>(out),
out_ptr,
nbytes,
cudaMemcpyDefault,
encoder.stream()));
CHECK_CUDA_ERROR(cudaLaunchHostFunc(encoder.stream(), free, out_ptr));
}
} // namespace mlx::core

View File

@@ -115,7 +115,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
@@ -130,7 +130,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
}
flags.col_contiguous = col_contig;
out.set_data(
allocator::malloc(in.nbytes() / n),
cu::malloc_async(in.nbytes() / n, encoder.stream()),
in.data_size() / n,
std::move(strides),
flags);
@@ -151,8 +151,8 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
n_rows,
block_dim(),
0,
in.data<DataType>(),
out.data<DataType>(),
gpu_ptr<DataType>(in),
gpu_ptr<DataType>(out),
axis_size);
});
});

View File

@@ -121,7 +121,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
int M = a_pre.shape(-2);
int N = b_pre.shape(-1);
@@ -163,7 +163,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
c.data_size() == out.shape(-1)) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
gemm_and_bias(
encoder,
M,
@@ -187,10 +187,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto sty = c.strides()[c.ndim() - 1];
if (sty == 1 && stx == c.shape(-1)) {
ldc = stx;
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
} else if (sty == 1 && stx == 0) {
ldc = 0;
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
} else {
// Copy C into out and set C to out
ldc = c.shape(-1);

View File

@@ -28,7 +28,6 @@ NO_GPU(FFT)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Hadamard)
NO_GPU(Load)
NO_GPU_MULTI(LUF)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)

View File

@@ -262,10 +262,10 @@ void affine_quantize(
num_blocks,
block_dims,
0,
w.data<T>(),
wq.data<uint8_t>(),
scales.data<T>(),
biases.data<T>(),
gpu_ptr<T>(w),
gpu_ptr<uint8_t>(wq),
gpu_ptr<T>(scales),
gpu_ptr<T>(biases),
w.size());
});
});
@@ -318,10 +318,10 @@ void affine_dequantize(
num_blocks,
block_dims,
0,
wq.data<uint8_t>(),
scales.data<T>(),
biases.data<T>(),
w.data<T>(),
gpu_ptr<uint8_t>(wq),
gpu_ptr<T>(scales),
gpu_ptr<T>(biases),
gpu_ptr<T>(w),
w.size());
});
});

View File

@@ -156,9 +156,9 @@ void fp_quantize(
num_blocks,
block_dims,
0,
w.data<T>(),
wq.data<uint8_t>(),
scales.data<uint8_t>(),
gpu_ptr<T>(w),
gpu_ptr<uint8_t>(wq),
gpu_ptr<uint8_t>(scales),
w.size());
} else {
throw std::runtime_error(
@@ -202,9 +202,9 @@ void fp_dequantize(
num_blocks,
block_dims,
0,
wq.data<uint8_t>(),
scales.data<T>(),
w.data<T>(),
gpu_ptr<uint8_t>(wq),
gpu_ptr<uint8_t>(scales),
gpu_ptr<T>(w),
w.size());
} else {
throw std::runtime_error(

View File

@@ -59,7 +59,7 @@ void fast::Quantize::eval_gpu(
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto& w = outputs[0];
w.set_data(allocator::malloc(w.nbytes()));
w.set_data(cu::malloc_async(w.nbytes(), enc.stream()));
if (mode_ == QuantizationMode::Affine) {
auto biases = ensure_row_contiguous(inputs[2], enc, s);
@@ -72,11 +72,11 @@ void fast::Quantize::eval_gpu(
auto& wq = outputs[0];
auto& scales = outputs[1];
wq.set_data(allocator::malloc(wq.nbytes()));
scales.set_data(allocator::malloc(scales.nbytes()));
wq.set_data(cu::malloc_async(wq.nbytes(), enc.stream()));
scales.set_data(cu::malloc_async(scales.nbytes(), enc.stream()));
if (mode_ == QuantizationMode::Affine) {
auto& biases = outputs[2];
biases.set_data(allocator::malloc(biases.nbytes()));
biases.set_data(cu::malloc_async(biases.nbytes(), enc.stream()));
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
} else {
fp_quantize(w, wq, scales, group_size_, bits_, enc, s);

View File

@@ -143,7 +143,9 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
uint32_t elems_per_key = out.size() / num_keys;
uint32_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
if (out.size() == 0) {
return;
}
@@ -152,8 +154,6 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
uint32_t half_size = out_per_key / 2;
bool odd = out_per_key % 2;
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(keys);
encoder.set_output_array(out);
dim3 grid_dims{num_keys, half_size + odd};
@@ -171,8 +171,8 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
grid,
block,
0,
keys.data<uint32_t>(),
out.data<uint8_t>(),
gpu_ptr<uint32_t>(keys),
gpu_ptr<uint8_t>(out),
grid_dims,
odd,
bytes_per_key);
@@ -182,8 +182,8 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
grid,
block,
0,
keys.data<uint32_t>(),
out.data<uint8_t>(),
gpu_ptr<uint32_t>(keys),
gpu_ptr<uint8_t>(out),
grid_dims,
odd,
bytes_per_key,

View File

@@ -66,7 +66,7 @@ void all_reduce(
Reduce::ReduceType reduce_type) {
constexpr int N_READS = 8;
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
auto get_args = [](size_t size, int N) {
int threads = std::min(512UL, (size + N - 1) / N);
@@ -100,14 +100,15 @@ void all_reduce(
Dtype dt = in.dtype();
// Cub doesn't like const pointers for load (sigh).
void* indata = const_cast<void*>(in.data<void>());
void* indata = const_cast<void*>(gpu_ptr<void>(in));
// Large array so allocate an intermediate and accumulate there
std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
encoder.set_input_array(in);
if (blocks > 1) {
array intermediate({blocks}, out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
intermediate.set_data(
cu::malloc_async(intermediate.nbytes(), encoder.stream()));
encoder.add_temporary(intermediate);
encoder.set_output_array(intermediate);
dispatch_all_types(dt, [&](auto type_tag) {
@@ -122,14 +123,14 @@ void all_reduce(
threads,
0,
static_cast<T*>(indata),
intermediate.data<U>(),
gpu_ptr<U>(intermediate),
block_step,
insize);
});
});
// Set the input for the next step and recalculate the blocks
indata = intermediate.data<void>();
indata = gpu_ptr<void>(intermediate);
dt = intermediate.dtype();
insize = intermediate.size();
std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
@@ -149,7 +150,7 @@ void all_reduce(
threads,
0,
static_cast<T*>(indata),
out.data<U>(),
gpu_ptr<U>(out),
block_step,
insize);
});

View File

@@ -250,7 +250,7 @@ void col_reduce_looped(
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
allocate_same_layout(out, in, axes, encoder);
encoder.set_input_array(in);
encoder.set_output_array(out);
@@ -261,7 +261,7 @@ void col_reduce_looped(
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
T* indata = const_cast<T*>(gpu_ptr<T>(in));
constexpr int N_READS = 4;
constexpr int BM = 32;
@@ -276,7 +276,7 @@ void col_reduce_looped(
blocks,
0,
indata,
out.data<U>(),
gpu_ptr<U>(out),
static_cast<cu::ColReduceArgs>(args));
});
});
@@ -293,7 +293,7 @@ void col_reduce_small(
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
allocate_same_layout(out, in, axes, encoder);
encoder.set_input_array(in);
encoder.set_output_array(out);
@@ -312,8 +312,8 @@ void col_reduce_small(
grid,
block,
0,
in.data<T>(),
out.data<U>(),
gpu_ptr<T>(in),
gpu_ptr<U>(out),
static_cast<cu::ColReduceArgs>(args),
out.size());
});

View File

@@ -28,7 +28,7 @@ void init_reduce(
Reduce::ReduceType reduce_type) {
// Allocate if needed
if (out.data_shared_ptr() == nullptr) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
}
encoder.set_output_array(out);
@@ -42,7 +42,7 @@ void init_reduce(
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
grid.x = (grid.x + 1023) / 1024;
encoder.add_kernel_node(
kernel, grid, block, 0, out.data<U>(), out.size());
kernel, grid, block, 0, gpu_ptr<U>(out), out.size());
});
});
}

View File

@@ -5,6 +5,7 @@
#include <numeric>
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include <cooperative_groups.h>
@@ -92,9 +93,10 @@ block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) {
inline void allocate_same_layout(
array& out,
const array& in,
const std::vector<int>& axes) {
const std::vector<int>& axes,
cu::CommandEncoder& encoder) {
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
return;
}
@@ -133,7 +135,7 @@ inline void allocate_same_layout(
fl.col_contiguous = cc;
fl.contiguous = true;
out.set_data(
allocator::malloc(out.nbytes()),
cu::malloc_async(out.nbytes(), encoder.stream()),
data_size,
final_strides,
fl,

View File

@@ -238,7 +238,7 @@ void row_reduce_simple(
const ReductionPlan& plan) {
// Allocate data for the output using in's layout to avoid elem_to_loc in the
// kernel.
allocate_same_layout(out, in, axes);
allocate_same_layout(out, in, axes, encoder);
// TODO: If out.size() < 1024 which will be a common case then write this in
// 2 passes. Something like 32 * out.size() and then do a warp reduce.
@@ -268,10 +268,10 @@ void row_reduce_simple(
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
}
T* indata = const_cast<T*>(in.data<T>());
T* indata = const_cast<T*>(gpu_ptr<T>(in));
int size = plan.shape.back();
encoder.add_kernel_node(
kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
kernel, grid, block, 0, indata, gpu_ptr<U>(out), out.size(), size);
});
});
}
@@ -286,7 +286,7 @@ void row_reduce_looped(
cu::RowReduceArgs args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
allocate_same_layout(out, in, axes, encoder);
encoder.set_input_array(in);
encoder.set_output_array(out);
@@ -315,7 +315,7 @@ void row_reduce_looped(
});
encoder.add_kernel_node(
kernel, grid, block, 0, in.data<T>(), out.data<U>(), args);
kernel, grid, block, 0, gpu_ptr<T>(in), gpu_ptr<U>(out), args);
});
});
}

View File

@@ -176,9 +176,10 @@ void RMSNorm::eval_gpu(
nvtx3::scoped_range r("RMSNorm::eval_gpu");
auto& s = stream();
auto& out = outputs[0];
auto& encoder = cu::get_command_encoder(s);
// Make sure that the last dimension is contiguous.
auto set_output = [&s, &out](const array& x) {
auto set_output = [&s, &out, &encoder](const array& x) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
@@ -189,7 +190,7 @@ void RMSNorm::eval_gpu(
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc(x.data_size() * x.itemsize()),
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
x.data_size(),
x.strides(),
x.flags());
@@ -209,7 +210,6 @@ void RMSNorm::eval_gpu(
int32_t n_rows = x.data_size() / axis_size;
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_output_array(out);
@@ -223,9 +223,9 @@ void RMSNorm::eval_gpu(
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
out.data<DataType>(),
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(out),
eps_,
axis_size,
w_stride);
@@ -274,7 +274,7 @@ void RMSNormVJP::eval_gpu(
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(allocator::malloc(gx.nbytes()));
gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
}
if (g_copied && !g_in_gx) {
encoder.add_temporary(g);
@@ -292,7 +292,7 @@ void RMSNormVJP::eval_gpu(
if (!g_in_gx && donate_g) {
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
encoder.add_temporary(gw_temp);
}
}
@@ -318,11 +318,11 @@ void RMSNormVJP::eval_gpu(
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
g.data<DataType>(),
gx.data<DataType>(),
gw_temp.data<DataType>(),
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(g),
gpu_ptr<DataType>(gx),
gpu_ptr<DataType>(gw_temp),
eps_,
axis_size,
w_stride);

View File

@@ -250,6 +250,7 @@ void RoPE::eval_gpu(
nvtx3::scoped_range r("RoPE::eval_gpu");
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
auto& in = inputs[0];
auto& offset = inputs[1];
auto& out = outputs[0];
@@ -291,14 +292,14 @@ void RoPE::eval_gpu(
donated = true;
out.copy_shared_buffer(in);
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
}
strides[0] = mat_size;
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
} else if (dispatch_ndim == 3) {
// Handle non-contiguous 3D inputs
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
strides[0] = in.strides()[ndim - 3];
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
@@ -319,7 +320,6 @@ void RoPE::eval_gpu(
bool single = in.flags().row_contiguous && B == 1 && T == 1;
bool with_freqs = inputs.size() == 3;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(donated ? out : in);
encoder.set_input_array(offset);
if (with_freqs) {
@@ -340,9 +340,9 @@ void RoPE::eval_gpu(
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
gpu_ptr<DataType>(donated ? out : in),
gpu_ptr<DataType>(out),
gpu_ptr<int32_t>(offset),
scale_,
std::log2(base_),
mat_size,
@@ -357,10 +357,10 @@ void RoPE::eval_gpu(
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
inputs[2].data<float>(),
gpu_ptr<DataType>(donated ? out : in),
gpu_ptr<DataType>(out),
gpu_ptr<int32_t>(offset),
gpu_ptr<float>(inputs[2]),
scale_,
mat_size,
dims,
@@ -381,10 +381,10 @@ void RoPE::eval_gpu(
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
inputs[2].data<float>(),
gpu_ptr<DataType>(donated ? out : in),
gpu_ptr<DataType>(out),
gpu_ptr<int32_t>(offset),
gpu_ptr<float>(inputs[2]),
scale_,
std::log2(base_),
strides,
@@ -408,9 +408,9 @@ void RoPE::eval_gpu(
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
gpu_ptr<DataType>(donated ? out : in),
gpu_ptr<DataType>(out),
gpu_ptr<int32_t>(offset),
scale_,
std::log2(base_),
strides,

View File

@@ -513,11 +513,11 @@ void sdpa_vector_1pass_fallback(
grid_dim,
block_dim,
0,
q.data<DataType>(),
k.data<DataType>(),
v.data<DataType>(),
o.data<DataType>(),
sinks ? (*sinks).data<DataType>() : nullptr,
gpu_ptr<DataType>(q),
gpu_ptr<DataType>(k),
gpu_ptr<DataType>(v),
gpu_ptr<DataType>(o),
sinks ? gpu_ptr<DataType>(*sinks) : nullptr,
params);
});
});
@@ -565,9 +565,10 @@ void sdpa_vector_2pass_fallback(
array sums(intermediate_shape, float32, nullptr, {});
array maxs(std::move(intermediate_shape), float32, nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
sums.set_data(allocator::malloc(sums.nbytes()));
maxs.set_data(allocator::malloc(maxs.nbytes()));
intermediate.set_data(
cu::malloc_async(intermediate.nbytes(), encoder.stream()));
sums.set_data(cu::malloc_async(sums.nbytes(), encoder.stream()));
maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder.stream()));
encoder.add_temporary(intermediate);
encoder.add_temporary(sums);
@@ -601,13 +602,13 @@ void sdpa_vector_2pass_fallback(
grid_dim,
block_dim,
0,
q.data<DataType>(),
k.data<DataType>(),
v.data<DataType>(),
sinks ? (*sinks).data<DataType>() : nullptr,
intermediate.data<float>(),
sums.data<float>(),
maxs.data<float>(),
gpu_ptr<DataType>(q),
gpu_ptr<DataType>(k),
gpu_ptr<DataType>(v),
sinks ? gpu_ptr<DataType>(*sinks) : nullptr,
gpu_ptr<float>(intermediate),
gpu_ptr<float>(sums),
gpu_ptr<float>(maxs),
params);
}
@@ -628,10 +629,10 @@ void sdpa_vector_2pass_fallback(
grid_dim,
block_dim,
0,
intermediate.data<float>(),
sums.data<float>(),
maxs.data<float>(),
o.data<DataType>(),
gpu_ptr<float>(intermediate),
gpu_ptr<float>(sums),
gpu_ptr<float>(maxs),
gpu_ptr<DataType>(o),
params);
}
});
@@ -787,7 +788,7 @@ void ScaledDotProductAttention::eval_gpu(
};
o.set_data(
allocator::malloc(o.nbytes()),
cu::malloc_async(o.nbytes(), encoder.stream()),
o.size(),
{str_oB, str_oH, str_oL, str_oD},
flags);

View File

@@ -367,13 +367,14 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto in = inputs[0];
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
if (in.flags().contiguous && in.strides()[axis_] != 0) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
in.data_size(),
in.strides(),
in.flags());
@@ -387,7 +388,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
int32_t axis_size = in.shape(axis_);
bool contiguous = in.strides()[axis_] == 1;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
@@ -415,8 +415,8 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
in.data_size() / axis_size,
block_dim,
0,
in.data<T>(),
out.data<U>(),
gpu_ptr<T>(in),
gpu_ptr<U>(out),
axis_size);
} else {
constexpr int BM = WARP_SIZE;
@@ -445,8 +445,8 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
num_blocks,
block_dim,
0,
in.data<T>(),
out.data<U>(),
gpu_ptr<T>(in),
gpu_ptr<U>(out),
axis_size,
stride,
stride_blocks);

View File

@@ -23,14 +23,15 @@ void concatenate_gpu(
}
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
auto strides = out.strides();
auto flags = out.flags();
flags.row_contiguous = false;
flags.col_contiguous = false;
flags.contiguous = false;
auto concurrent = cu::get_command_encoder(s).concurrent_context();
auto concurrent = encoder.concurrent_context();
for (int i = 0; i < inputs.size(); i++) {
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
size_t data_offset = strides[axis] * sizes[i];
@@ -80,6 +81,7 @@ array compute_dynamic_offset(
return std::make_tuple(false, std::move(source), std::vector{kernel_name});
});
auto& encoder = cu::get_command_encoder(s);
// Prepare output.
array offset({1}, int64, nullptr, {});
bool donate = indices.is_donatable() &&
@@ -87,10 +89,9 @@ array compute_dynamic_offset(
if (donate) {
offset.copy_shared_buffer(indices);
} else {
offset.set_data(allocator::malloc(offset.itemsize()));
offset.set_data(cu::malloc_async(offset.itemsize(), encoder.stream()));
}
auto& encoder = cu::get_command_encoder(s);
encoder.add_temporary(offset);
encoder.set_input_array(indices);
encoder.set_output_array(offset);

View File

@@ -109,15 +109,16 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Softmax::eval_gpu");
assert(inputs.size() == 1);
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
// Make sure that the last dimension is contiguous.
auto set_output = [&s, &out](const array& x) {
auto set_output = [&s, &out, &encoder](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
if (x.is_donatable()) {
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc(x.data_size() * x.itemsize()),
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
x.data_size(),
x.strides(),
x.flags());
@@ -136,7 +137,6 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
int axis_size = in.shape().back();
int n_rows = in.data_size() / axis_size;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) {
@@ -152,8 +152,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
n_rows,
block_dim(),
0,
in.data<DataType>(),
out.data<DataType>(),
gpu_ptr<DataType>(in),
gpu_ptr<DataType>(out),
axis_size);
});
});

View File

@@ -49,11 +49,14 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
array trans = swapaxes_in_eval(in, axis, last_dim);
in = contiguous_copy_gpu(trans, s);
encoder.add_temporary(in);
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
out = array(
cu::malloc_async(out.nbytes(), encoder.stream()),
in.shape(),
out.dtype());
encoder.add_temporary(out);
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
in.data_size(),
in.strides(),
in.flags());
@@ -70,22 +73,28 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
thrust::make_counting_iterator(0), OffsetTransform{nsort});
if (argsort) {
// Indices in the sorted dimension.
array indices(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
array indices(
cu::malloc_async(out.nbytes(), encoder.stream()),
in.shape(),
out.dtype());
encoder.add_temporary(indices);
// In argsort though we don't need the result of sorted values, the
// API requires us to provide an array to store it.
array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype());
array discard(
cu::malloc_async(in.nbytes(), encoder.stream()),
in.shape(),
in.dtype());
encoder.add_temporary(discard);
size_t size;
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
nullptr,
size,
in.data<Type>(),
discard.data<Type>(),
indices.data<uint32_t>(),
out.data<uint32_t>(),
gpu_ptr<Type>(in),
gpu_ptr<Type>(discard),
gpu_ptr<uint32_t>(indices),
gpu_ptr<uint32_t>(out),
in.data_size(),
in.data_size() / nsort,
offsets,
@@ -94,7 +103,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
sizeof(Type) * 8,
stream));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
array temp(
cu::malloc_async(size, encoder.stream()),
{static_cast<int>(size)},
uint8);
encoder.add_temporary(temp);
// Start capturing after allocations
@@ -103,16 +115,16 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
cu::thrust_policy(stream),
thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(indices.data_size()),
thrust::device_pointer_cast(indices.data<uint32_t>()),
thrust::device_pointer_cast(gpu_ptr<uint32_t>(indices)),
ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
temp.data<void>(),
gpu_ptr<void>(temp),
size,
in.data<Type>(),
discard.data<Type>(),
indices.data<uint32_t>(),
out.data<uint32_t>(),
gpu_ptr<Type>(in),
gpu_ptr<Type>(discard),
gpu_ptr<uint32_t>(indices),
gpu_ptr<uint32_t>(out),
in.data_size(),
in.data_size() / nsort,
offsets,
@@ -125,8 +137,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
nullptr,
size,
in.data<Type>(),
out.data<Type>(),
gpu_ptr<Type>(in),
gpu_ptr<Type>(out),
in.data_size(),
in.data_size() / nsort,
offsets,
@@ -135,16 +147,19 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
sizeof(Type) * 8,
stream));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
array temp(
cu::malloc_async(size, encoder.stream()),
{static_cast<int>(size)},
uint8);
encoder.add_temporary(temp);
// Start capturing after allocations
auto capture = encoder.capture_context();
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
temp.data<void>(),
gpu_ptr<void>(temp),
size,
in.data<Type>(),
out.data<Type>(),
gpu_ptr<Type>(in),
gpu_ptr<Type>(out),
in.data_size(),
in.data_size() / nsort,
offsets,

View File

@@ -168,10 +168,10 @@ void ternary_op_gpu_inplace(
num_blocks,
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
gpu_ptr<bool>(a),
gpu_ptr<DType>(b),
gpu_ptr<DType>(c),
gpu_ptr<DType>(out),
out.data_size());
});
} else {
@@ -211,10 +211,10 @@ void ternary_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
gpu_ptr<bool>(a),
gpu_ptr<DType>(b),
gpu_ptr<DType>(c),
gpu_ptr<DType>(out),
rest,
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
@@ -231,10 +231,10 @@ void ternary_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
gpu_ptr<bool>(a),
gpu_ptr<DType>(b),
gpu_ptr<DType>(c),
gpu_ptr<DType>(out),
rest,
const_param(shape),
const_param(a_strides),
@@ -256,7 +256,10 @@ void ternary_op_gpu(
auto& b = inputs[1];
auto& c = inputs[2];
auto topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt);
auto& encoder = cu::get_command_encoder(s);
set_ternary_op_output_data(a, b, c, out, topt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
ternary_op_gpu_inplace<Op>(inputs, out, s);
}

View File

@@ -158,8 +158,8 @@ void unary_op_gpu_inplace(
num_blocks,
block_dims,
0,
in.data<InType>(),
out.data<OutType>(),
gpu_ptr<InType>(in),
gpu_ptr<OutType>(out),
out.data_size());
} else {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
@@ -182,8 +182,8 @@ void unary_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
in.data<InType>(),
out.data<OutType>(),
gpu_ptr<InType>(in),
gpu_ptr<OutType>(out),
rest,
const_param(shape),
const_param(strides),
@@ -207,7 +207,10 @@ void unary_op_gpu(
array& out,
const char* op,
const Stream& s) {
set_unary_output_data(inputs[0], out);
auto& encoder = cu::get_command_encoder(s);
set_unary_output_data(inputs[0], out, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
unary_op_gpu_inplace<Op>(inputs, out, op, s);
}

View File

@@ -4,89 +4,12 @@
#pragma once
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/cuda_utils.h"
namespace mlx::core {
namespace cu {
class Device;
}
struct Dtype;
// Throw exception if the cuda API does not succeed.
void check_cublas_error(const char* name, cublasStatus_t err);
void check_cuda_error(const char* name, cudaError_t err);
void check_cuda_error(const char* name, CUresult err);
// The macro version that prints the command that failed.
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
// Convert Dtype to CUDA C++ types.
const char* dtype_to_cuda_type(const Dtype& dtype);
// Base class for RAII managed CUDA resources.
template <typename Handle, cudaError_t (*Destroy)(Handle)>
class CudaHandle {
public:
CudaHandle(Handle handle = nullptr) : handle_(handle) {}
CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
assert(this != &other);
other.handle_ = nullptr;
}
~CudaHandle() {
reset();
}
CudaHandle(const CudaHandle&) = delete;
CudaHandle& operator=(const CudaHandle&) = delete;
CudaHandle& operator=(CudaHandle&& other) {
assert(this != &other);
reset();
std::swap(handle_, other.handle_);
return *this;
}
void reset() {
if (handle_ != nullptr) {
CHECK_CUDA_ERROR(Destroy(handle_));
handle_ = nullptr;
}
}
operator Handle() const {
return handle_;
}
protected:
Handle handle_;
};
// Wrappers of CUDA resources.
class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
public:
using CudaHandle::CudaHandle;
explicit CudaGraph(cu::Device& device);
void end_capture(cudaStream_t stream);
};
class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
public:
void instantiate(cudaGraph_t graph);
};
class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
public:
explicit CudaStream(cu::Device& device);
};
template <typename T>
inline uint max_occupancy_block_dim(T kernel) {
int _, block_dim;
@@ -100,4 +23,22 @@ inline uint max_occupancy_block_dim(T kernel) {
return block_dim;
}
template <typename T>
inline T* gpu_ptr(array& arr) {
return reinterpret_cast<T*>(
static_cast<char*>(
static_cast<cu::CudaBuffer*>(arr.buffer().ptr())->data) +
arr.offset());
}
template <typename T>
inline const T* gpu_ptr(const array& arr) {
return gpu_ptr<T>(const_cast<array&>(arr));
}
struct Dtype;
// Convert Dtype to CUDA C++ types.
const char* dtype_to_cuda_type(const Dtype& dtype);
} // namespace mlx::core

View File

@@ -7,18 +7,7 @@
namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
bool donated = set_copy_output_data(in, out, ctype);
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_gpu_inplace(in, out, ctype, s);
}
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s);
void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());
@@ -52,25 +41,6 @@ array contiguous_copy_gpu(const array& arr, const Stream& s) {
return arr_copy;
}
void reshape_gpu(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) {
int ndim = x.ndim();
if (start_axis < 0) {

View File

@@ -83,7 +83,7 @@ void Depends::eval_gpu(
void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("DynamicSlice::eval_gpu");
if (out.size() == 0) {
out.set_data(nullptr);
out.set_data(allocator::malloc(0));
return;
}
@@ -112,7 +112,7 @@ void DynamicSliceUpdate::eval_gpu(
array& out) {
MLX_PROFILER_RANGE("DynamicSliceUpdate::eval_gpu");
if (out.size() == 0) {
out.set_data(nullptr);
out.set_data(allocator::malloc(0));
return;
}
@@ -209,7 +209,7 @@ void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Slice::eval_gpu");
assert(inputs.size() == 1);
if (out.size() == 0) {
out.set_data(nullptr);
out.set_data(allocator::malloc(0));
return;
}
@@ -220,7 +220,7 @@ void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
out.set_data(nullptr);
out.set_data(allocator::malloc(0));
return;
}

View File

@@ -10,6 +10,19 @@ namespace mlx::core {
constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
bool donated = set_copy_output_data(in, out, ctype);
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_gpu_inplace(in, out, ctype, s);
}
void copy_gpu_inplace(
const array& in,
array& out,
@@ -201,4 +214,23 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
void reshape_gpu(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
} // namespace mlx::core

View File

@@ -261,10 +261,7 @@ void CommandEncoder::set_input_array(
needs_barrier_ =
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc_->setBuffer(a_buf, base_offset, idx);
enc_->setBuffer(a_buf, a.offset() + offset, idx);
}
void CommandEncoder::set_output_array(
@@ -448,10 +445,8 @@ void Device::end_encoding(int index) {
auto& enc = *stream.encoder;
// Remove temporaries from inputs and outputs
for (auto& t : stream.temporaries) {
if (t.data<void>() != nullptr) {
enc.outputs().erase(t.buffer().ptr());
enc.inputs().erase(t.buffer().ptr());
}
enc.outputs().erase(t.buffer().ptr());
enc.inputs().erase(t.buffer().ptr());
}
// Keep references to the fences we waited on and put them

View File

@@ -31,7 +31,7 @@ struct FenceImpl {
auto p = metal::new_scoped_memory_pool();
static_cast<MTL::SharedEvent*>(fence)->release();
} else {
allocator::free(static_cast<MTL::Buffer*>(fence));
allocator::free(allocator::Buffer{static_cast<MTL::Buffer*>(fence)});
}
}
bool use_fast{false};
@@ -99,7 +99,7 @@ void Fence::wait(Stream stream, const array& x) {
[fence_ = fence_](MTL::CommandBuffer* cbuf) {});
}
void Fence::update(Stream stream, const array& x) {
void Fence::update(Stream stream, const array& x, bool cross_device) {
auto& f = *static_cast<FenceImpl*>(fence_.get());
f.count++;
@@ -130,21 +130,23 @@ void Fence::update(Stream stream, const array& x) {
// Launch input visibility kernels
auto& compute_encoder = d.get_command_encoder(idx);
auto kernel = d.get_kernel("input_coherent");
uint32_t nthreads =
(x.data_size() * x.itemsize() + sizeof(uint32_t) - 1) / sizeof(uint32_t);
MTL::Size group_dims = MTL::Size(1024, 1, 1);
MTL::Size grid_dims = MTL::Size((nthreads + 1024 - 1) / 1024, 1, 1);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_bytes(nthreads, 1);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
if (cross_device) {
auto kernel = d.get_kernel("input_coherent");
uint32_t nthreads = (x.data_size() * x.itemsize() + sizeof(uint32_t) - 1) /
sizeof(uint32_t);
MTL::Size group_dims = MTL::Size(1024, 1, 1);
MTL::Size grid_dims = MTL::Size((nthreads + 1024 - 1) / 1024, 1, 1);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_bytes(nthreads, 1);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
// Barrier on previous kernels
compute_encoder.barrier();
// Launch value update kernel
kernel = d.get_kernel("fence_update");
auto kernel = d.get_kernel("fence_update");
MTL::Size kernel_dims = MTL::Size(1, 1, 1);
compute_encoder.set_compute_pipeline_state(kernel);

View File

@@ -768,9 +768,12 @@ void nd_fft_op(
const Stream& s) {
// Perform ND FFT on GPU as a series of 1D FFTs
auto temp_shape = inverse ? in.shape() : out.shape();
array temp1(temp_shape, complex64, nullptr, {});
array temp2(temp_shape, complex64, nullptr, {});
std::vector<array> temp_arrs = {temp1, temp2};
std::vector<array> temp_arrs;
temp_arrs.emplace_back(temp_shape, complex64, nullptr, std::vector<array>{});
if (axes.size() > 2) {
temp_arrs.emplace_back(
temp_shape, complex64, nullptr, std::vector<array>{});
}
for (int i = axes.size() - 1; i >= 0; i--) {
int reverse_index = axes.size() - i - 1;
// For 5D and above, we don't want to reallocate our two temporary arrays
@@ -781,8 +784,8 @@ void nd_fft_op(
// Mirror np.fft.(i)rfftn and perform a real transform
// only on the final axis.
bool step_real = (real && index == axes.size() - 1);
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
array& out_arr = i == 0 ? out : temp_arrs[i % 2];
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[i % 2];
array& out_arr = i == 0 ? out : temp_arrs[1 - i % 2];
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
}

View File

@@ -36,7 +36,7 @@ void Fence::wait(Stream stream, const array&) {
}
}
void Fence::update(Stream stream, const array&) {
void Fence::update(Stream stream, const array&, bool) {
auto& f = *static_cast<FenceImpl*>(fence_.get());
f.count++;
if (stream.device == Device::cpu) {

View File

@@ -300,8 +300,8 @@ class NCCLGroup : public GroupImpl {
using T = typename decltype(type_tag)::type;
auto& encoder = cu::get_command_encoder(stream);
CHECK_NCCL(ncclAllGather(
input.data<T>(),
output.data<T>(),
gpu_ptr<T>(input),
gpu_ptr<T>(output),
input.size(),
dt,
comm_,
@@ -348,8 +348,8 @@ class NCCLGroup : public GroupImpl {
auto& encoder = cu::get_command_encoder(stream);
CHECK_NCCL(ncclAllReduce(
input.data<T>(),
output.data<T>(),
gpu_ptr<T>(input),
gpu_ptr<T>(output),
input.size(),
dt,
op,
@@ -367,8 +367,8 @@ class NCCLGroup : public GroupImpl {
auto& encoder = cu::get_command_encoder(stream);
CHECK_NCCL(ncclReduceScatter(
input.data<T>(),
output.data<T>(),
gpu_ptr<T>(input),
gpu_ptr<T>(output),
output.size(),
dt,
op,

View File

@@ -15,7 +15,7 @@ namespace mlx::core {
* `wait` returns. The array passed to `wait` will not be read until all
* previous calls to `update` have completed.
*
* Note, calls to `update` should always from the same thread or explicitly
* Note, calls to `update` should always be from the same thread or explicitly
* synchronized so that they occur in sequence. Calls to `wait` can be on any
* thread.
*
@@ -29,7 +29,7 @@ class Fence {
Fence() {};
explicit Fence(Stream stream);
void update(Stream stream, const array& x);
void update(Stream stream, const array& x, bool cross_device);
void wait(Stream stream, const array& x);
private:

View File

@@ -13,6 +13,7 @@
#include <windows.h>
#endif // _WIN32
#include "mlx/backend/cuda/cuda.h"
#include "mlx/io/load.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
@@ -226,10 +227,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
throw std::runtime_error("[load] Failed to open " + in_stream->label());
}
auto stream = to_stream(s, Device::cpu);
if (stream.device != Device::cpu) {
throw std::runtime_error("[load] Must run on a CPU stream.");
}
auto stream = cu::is_available() ? to_stream(s) : to_stream(s, Device::cpu);
////////////////////////////////////////////////////////
// Read header and prepare array details

View File

@@ -4,6 +4,7 @@
#include <memory>
#include <stack>
#include "mlx/backend/cuda/cuda.h"
#include "mlx/io.h"
#include "mlx/io/load.h"
#include "mlx/ops.h"
@@ -113,10 +114,7 @@ SafetensorsLoad load_safetensors(
"[load_safetensors] Failed to open " + in_stream->label());
}
auto stream = to_stream(s, Device::cpu);
if (stream.device != Device::cpu) {
throw std::runtime_error("[load_safetensors] Must run on a CPU stream.");
}
auto stream = cu::is_available() ? to_stream(s) : to_stream(s, Device::cpu);
uint64_t jsonHeaderLength = 0;
// This is the same limit as in the original Rust Safetensors code.

View File

@@ -62,7 +62,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
}
// Map of array id that needs fence and stream it's computed on
std::unordered_map<uintptr_t, uint32_t> needs_fence;
std::unordered_map<uintptr_t, std::pair<uint32_t, bool>> needs_fence;
auto synchronizer = array(
{}, bool_, std::make_shared<Synchronizer>(stream), std::move(outputs));
@@ -114,7 +114,14 @@ array eval_impl(std::vector<array> outputs, bool async) {
"https://github.com/ml-explore/mlx/issues.");
}
if (a.primitive().stream() != in.primitive().stream()) {
needs_fence.emplace(in.id(), in.primitive().stream().index);
bool device_switch =
a.primitive().stream().device != in.primitive().stream().device;
auto [it, inserted] = needs_fence.emplace(
in.id(),
std::make_pair(in.primitive().stream().index, device_switch));
if (!inserted) {
it->second.second |= device_switch;
}
}
}
@@ -190,7 +197,6 @@ array eval_impl(std::vector<array> outputs, bool async) {
}
std::unordered_set<int> open_streams;
while (!tape.empty()) {
auto arr = std::move(tape.back());
tape.pop_back();
@@ -216,7 +222,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
// Use fence to wait within a single eval
// Get the input array's stream fence and wait on the
// output arrays stream
fences[it->second].wait(stream, in);
fences[it->second.first].wait(stream, in);
} else if (in.event().valid()) {
if (in.event().is_signaled()) {
in.detach_event();
@@ -251,12 +257,12 @@ array eval_impl(std::vector<array> outputs, bool async) {
}
auto maybe_update_fence = [&fences, &needs_fence, stream](const array& a) {
if (needs_fence.find(a.id()) != needs_fence.end()) {
if (auto nf = needs_fence.find(a.id()); nf != needs_fence.end()) {
auto it = fences.find(stream.index);
if (it == fences.end()) {
it = fences.emplace(stream.index, Fence{stream}).first;
}
it->second.update(stream, a);
it->second.update(stream, a, nf->second.second);
}
};